random_policy.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import random
  2. import numpy as np
  3. from gym.spaces import Box
  4. from ray.rllib.policy.policy import Policy
  5. from ray.rllib.utils.annotations import override
  6. from ray.rllib.utils.typing import ModelWeights
  class RandomPolicy(Policy):
      """Hand-coded policy that returns random actions.

      Nothing is learned and no weights are held: actions are drawn by calling
      ``sample()`` on ``self.action_space_for_sampling``, which is chosen once
      in ``__init__``.
      """

      def __init__(self, *args, **kwargs):
          """Set up the space that actions are sampled from.

          All arguments are forwarded unchanged to the base-class constructor,
          which is expected to provide ``self.config`` and
          ``self.action_space`` (both are read below).
          """
          super().__init__(*args, **kwargs)

          # Whether for compute_actions, the bounds given in action_space
          # should be ignored (default: False). This is to test action-clipping
          # and any Env's reaction to bounds breaches.
          ignore_bounds = self.config.get("ignore_action_bounds", False)
          if ignore_bounds and isinstance(self.action_space, Box):
              # Same shape and dtype as the real action space, but unbounded.
              self.action_space_for_sampling = Box(
                  -float("inf"),
                  float("inf"),
                  shape=self.action_space.shape,
                  dtype=self.action_space.dtype)
          else:
              self.action_space_for_sampling = self.action_space

      @override(Policy)
      def compute_actions(self,
                          obs_batch,
                          state_batches=None,
                          prev_action_batch=None,
                          prev_reward_batch=None,
                          **kwargs):
          """Sample one random action per entry of ``obs_batch``.

          Only the number of entries in ``obs_batch`` matters; the
          observations themselves and every other argument are ignored.

          Returns:
              A 3-tuple ``(actions, [], {})`` where ``actions`` is a list with
              one independently sampled action per observation. The empty
              list and empty dict presumably stand for "no state outputs" and
              "no extra info" -- confirm against the base ``Policy`` API.
          """
          # Alternatively, a numpy array would work here as well.
          # e.g.: np.array([random.choice([0, 1])] * len(obs_batch))
          sample = self.action_space_for_sampling.sample
          return [sample() for _ in obs_batch], [], {}

      @override(Policy)
      def learn_on_batch(self, samples):
          """No learning."""
          return {}

      @override(Policy)
      def compute_log_likelihoods(self,
                                  actions,
                                  obs_batch,
                                  state_batches=None,
                                  prev_action_batch=None,
                                  prev_reward_batch=None):
          """Return placeholder values, one per entry of ``obs_batch``.

          The result is a numpy array of length ``len(obs_batch)``. A single
          ``random.random()`` value is drawn and repeated for every entry, so
          all elements of the returned array are equal. The values do not
          depend on ``actions`` or on the observations.
          """
          # One draw, replicated across the whole batch.
          return np.array([random.random()] * len(obs_batch))

      @override(Policy)
      def get_weights(self) -> ModelWeights:
          """No weights to save."""
          return {}

      @override(Policy)
      def set_weights(self, weights: ModelWeights) -> None:
          """No weights to set."""
          pass