12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- import gym
- import numpy as np
- import random
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.view_requirement import ViewRequirement
# Integer codes for the three moves. The policies below return them as
# actions and also use them as indices into one-hot encoded observations.
ROCK, PAPER, SCISSORS = range(3)
class AlwaysSameHeuristic(Policy):
    """Heuristic policy that repeats one randomly drawn move.

    The move is sampled in ``get_initial_state`` and stored as the single
    state entry; ``compute_actions`` plays that state back as the action
    and passes the state through unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Policy`` and register the state view."""
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()
        # "state_in_0" is read from the "state_out_0" column with shift=-1.
        move_space = gym.spaces.Box(0, 100, shape=(), dtype=np.int32)
        state_view = ViewRequirement(
            "state_out_0", shift=-1, space=move_space)
        self.view_requirements.update({"state_in_0": state_view})

    def get_initial_state(self):
        """Return a one-element state list holding a randomly chosen move."""
        moves = [ROCK, PAPER, SCISSORS]
        return [random.choice(moves)]

    def compute_actions(
            self,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            info_batch=None,
            episodes=None,
            **kwargs):
        """Play back the move stored in the state.

        Only ``state_batches`` is used; every other argument is ignored.

        Returns:
            A 3-tuple ``(state_batches[0], state_batches, {})``: the first
            state entry as the actions, the state passed through unchanged,
            and an empty dict.
        """
        chosen_moves = state_batches[0]
        return chosen_moves, state_batches, {}
class BeatLastHeuristic(Policy):
    """Play the move that would beat the last move of the opponent.

    Each observation is read as the opponent's last move, either as a
    plain integer (``Discrete`` observation space) or as a one-hot vector.
    Observations that do not identify one of the three moves get a
    uniformly random move instead.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Policy`` and create the exploration."""
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Return, for each observation, the move that beats it.

        Only ``obs_batch`` is used; every other argument is ignored.

        Returns:
            A 3-tuple ``(actions, [], {})`` where ``actions`` is a list with
            one move per entry of ``obs_batch``; the state list and the
            extra dict are always empty.
        """

        def successor(x):
            """Map one observation to ROCK -> PAPER -> SCISSORS -> ROCK."""
            # Make this also work w/o one-hot preprocessing.
            if isinstance(self.observation_space, gym.spaces.Discrete):
                if x == ROCK:
                    return PAPER
                elif x == PAPER:
                    return SCISSORS
                elif x == SCISSORS:
                    return ROCK
                else:
                    return random.choice([ROCK, PAPER, SCISSORS])
            # One-hot (auto-preprocessed) inputs.
            if x[ROCK] == 1:
                return PAPER
            elif x[PAPER] == 1:
                return SCISSORS
            elif x[SCISSORS] == 1:
                return ROCK
            # No move slot is set: either the trailing slot is hot
            # (presumably "no previous move" -- confirm against the env) or
            # no entry equals 1 at all, e.g. an all-zero vector. Always fall
            # back to a random move so the action is never None.
            return random.choice([ROCK, PAPER, SCISSORS])

        return [successor(x) for x in obs_batch], [], {}

    def learn_on_batch(self, samples):
        """Do nothing; this method is a no-op and returns None."""
        pass

    def get_weights(self):
        """Do nothing and return None; no weights are stored."""
        pass

    def set_weights(self, weights):
        """Ignore ``weights``; no weights are stored."""
        pass
|