random_parametric_agent.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from abc import ABC
  2. import ray
  3. import numpy as np
  4. from ray.rllib import Policy
  5. from ray.rllib.agents import with_common_config
  6. from ray.rllib.agents.trainer_template import build_trainer
  7. from ray.rllib.evaluation.worker_set import WorkerSet
  8. from ray.rllib.execution.metric_ops import StandardMetricsReporting
  9. from ray.rllib.execution.rollout_ops import ParallelRollouts, SelectExperiences
  10. from ray.rllib.examples.env.parametric_actions_cartpole import \
  11. ParametricActionsCartPole
  12. from ray.rllib.models.modelv2 import restore_original_dimensions
  13. from ray.rllib.utils import override
  14. from ray.rllib.utils.typing import TrainerConfigDict
  15. from ray.util.iter import LocalIterator
  16. from ray.tune.registry import register_env
  17. DEFAULT_CONFIG = with_common_config({})
  18. class RandomParametriclPolicy(Policy, ABC):
  19. """
  20. Just pick a random legal action
  21. The outputted state of the environment needs to be a dictionary with an
  22. 'action_mask' key containing the legal actions for the agent.
  23. """
  24. def __init__(self, *args, **kwargs):
  25. super().__init__(*args, **kwargs)
  26. self.exploration = self._create_exploration()
  27. @override(Policy)
  28. def compute_actions(self,
  29. obs_batch,
  30. state_batches=None,
  31. prev_action_batch=None,
  32. prev_reward_batch=None,
  33. info_batch=None,
  34. episodes=None,
  35. **kwargs):
  36. obs_batch = restore_original_dimensions(
  37. np.array(obs_batch, dtype=np.float32),
  38. self.observation_space,
  39. tensorlib=np)
  40. def pick_legal_action(legal_action):
  41. return np.random.choice(
  42. len(legal_action), 1, p=(legal_action / legal_action.sum()))[0]
  43. return [pick_legal_action(x) for x in obs_batch["action_mask"]], [], {}
  44. def learn_on_batch(self, samples):
  45. pass
  46. def get_weights(self):
  47. pass
  48. def set_weights(self, weights):
  49. pass
  50. def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
  51. **kwargs) -> LocalIterator[dict]:
  52. rollouts = ParallelRollouts(workers, mode="async")
  53. # Collect batches for the trainable policies.
  54. rollouts = rollouts.for_each(
  55. SelectExperiences(workers.trainable_policies()))
  56. # Return training metrics.
  57. return StandardMetricsReporting(rollouts, workers, config)
  58. RandomParametricTrainer = build_trainer(
  59. name="RandomParametric",
  60. default_config=DEFAULT_CONFIG,
  61. default_policy=RandomParametriclPolicy,
  62. execution_plan=execution_plan)
  63. def main():
  64. register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
  65. trainer = RandomParametricTrainer(env="pa_cartpole")
  66. result = trainer.train()
  67. assert result["episode_reward_mean"] > 10, result
  68. print("Test: OK")
  69. if __name__ == "__main__":
  70. ray.init()
  71. main()