rock_paper_scissors_dummies.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import gym
  2. import numpy as np
  3. import random
  4. from ray.rllib.policy.policy import Policy
  5. from ray.rllib.policy.view_requirement import ViewRequirement
  6. ROCK = 0
  7. PAPER = 1
  8. SCISSORS = 2
  9. class AlwaysSameHeuristic(Policy):
  10. """Pick a random move and stick with it for the entire episode."""
  11. def __init__(self, *args, **kwargs):
  12. super().__init__(*args, **kwargs)
  13. self.exploration = self._create_exploration()
  14. self.view_requirements.update({
  15. "state_in_0": ViewRequirement(
  16. "state_out_0",
  17. shift=-1,
  18. space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
  19. })
  20. def get_initial_state(self):
  21. return [random.choice([ROCK, PAPER, SCISSORS])]
  22. def compute_actions(self,
  23. obs_batch,
  24. state_batches=None,
  25. prev_action_batch=None,
  26. prev_reward_batch=None,
  27. info_batch=None,
  28. episodes=None,
  29. **kwargs):
  30. return state_batches[0], state_batches, {}
  31. class BeatLastHeuristic(Policy):
  32. """Play the move that would beat the last move of the opponent."""
  33. def __init__(self, *args, **kwargs):
  34. super().__init__(*args, **kwargs)
  35. self.exploration = self._create_exploration()
  36. def compute_actions(self,
  37. obs_batch,
  38. state_batches=None,
  39. prev_action_batch=None,
  40. prev_reward_batch=None,
  41. info_batch=None,
  42. episodes=None,
  43. **kwargs):
  44. def successor(x):
  45. # Make this also work w/o one-hot preprocessing.
  46. if isinstance(self.observation_space, gym.spaces.Discrete):
  47. if x == ROCK:
  48. return PAPER
  49. elif x == PAPER:
  50. return SCISSORS
  51. elif x == SCISSORS:
  52. return ROCK
  53. else:
  54. return random.choice([ROCK, PAPER, SCISSORS])
  55. # One-hot (auto-preprocessed) inputs.
  56. else:
  57. if x[ROCK] == 1:
  58. return PAPER
  59. elif x[PAPER] == 1:
  60. return SCISSORS
  61. elif x[SCISSORS] == 1:
  62. return ROCK
  63. elif x[-1] == 1:
  64. return random.choice([ROCK, PAPER, SCISSORS])
  65. return [successor(x) for x in obs_batch], [], {}
  66. def learn_on_batch(self, samples):
  67. pass
  68. def get_weights(self):
  69. pass
  70. def set_weights(self, weights):
  71. pass