random_parametric_agent.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from abc import ABC
  2. import ray
  3. import numpy as np
  4. from ray.rllib import Policy
  5. from ray.rllib.algorithms.algorithm import Algorithm
  6. from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
  7. from ray.rllib.examples.env.parametric_actions_cartpole import ParametricActionsCartPole
  8. from ray.rllib.models.modelv2 import restore_original_dimensions
  9. from ray.rllib.utils import override
  10. from ray.rllib.utils.typing import ResultDict
  11. from ray.tune.registry import register_env
  12. class RandomParametricPolicy(Policy, ABC):
  13. """
  14. Just pick a random legal action
  15. The outputted state of the environment needs to be a dictionary with an
  16. 'action_mask' key containing the legal actions for the agent.
  17. """
  18. def __init__(self, *args, **kwargs):
  19. super().__init__(*args, **kwargs)
  20. self.exploration = self._create_exploration()
  21. @override(Policy)
  22. def compute_actions(
  23. self,
  24. obs_batch,
  25. state_batches=None,
  26. prev_action_batch=None,
  27. prev_reward_batch=None,
  28. info_batch=None,
  29. episodes=None,
  30. **kwargs
  31. ):
  32. obs_batch = restore_original_dimensions(
  33. np.array(obs_batch, dtype=np.float32), self.observation_space, tensorlib=np
  34. )
  35. def pick_legal_action(legal_action):
  36. return np.random.choice(
  37. len(legal_action), 1, p=(legal_action / legal_action.sum())
  38. )[0]
  39. return [pick_legal_action(x) for x in obs_batch["action_mask"]], [], {}
  40. def learn_on_batch(self, samples):
  41. pass
  42. def get_weights(self):
  43. pass
  44. def set_weights(self, weights):
  45. pass
  46. class RandomParametricAlgorithm(Algorithm):
  47. """Algo with Policy and config defined above and overriding `training_step`.
  48. Overrides the `training_step` method, which only runs a (dummy)
  49. rollout and performs no learning.
  50. """
  51. @classmethod
  52. def get_default_policy_class(cls, config):
  53. return RandomParametricPolicy
  54. @override(Algorithm)
  55. def training_step(self) -> ResultDict:
  56. # Perform rollouts (only for collecting metrics later).
  57. synchronous_parallel_sample(worker_set=self.workers)
  58. # Return (empty) training metrics.
  59. return {}
  60. def main():
  61. register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
  62. algo = RandomParametricAlgorithm(env="pa_cartpole")
  63. result = algo.train()
  64. assert result["episode_reward_mean"] > 10, result
  65. print("Test: OK")
if __name__ == "__main__":
    # Script entry point: start a local Ray runtime, then run the
    # one-iteration smoke test defined in main().
    ray.init()
    main()