coin_game_vectorized_env.py

##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
# Some parts are originally from:
# https://github.com/julianstastny/openspiel-social-dilemmas/
# blob/master/games/coin_game_gym.py
##########
import copy
from collections.abc import Iterable

import numpy as np
from numba import jit, prange
from numba.typed import List
from ray.rllib.examples.env.coin_game_non_vectorized_env import CoinGame
from ray.rllib.utils import override


class VectorizedCoinGame(CoinGame):
    """
    Vectorized Coin Game environment.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        super().__init__(config)

        self.batch_size = config.get("batch_size", 1)
        self.force_vectorized = config.get("force_vectorize", False)
        assert self.grid_size == 3, \
            "grid_size == 3 is hardcoded in the place_coin helper (range(9))"

    @override(CoinGame)
    def _randomize_color_and_player_positions(self):
        # Reset the coin color and the player and coin positions.
        self.red_coin = np.random.randint(2, size=self.batch_size)
        self.red_pos = np.random.randint(
            self.grid_size, size=(self.batch_size, 2))
        self.blue_pos = np.random.randint(
            self.grid_size, size=(self.batch_size, 2))
        self.coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)

        self._players_do_not_overlap_at_start()

    @override(CoinGame)
    def _players_do_not_overlap_at_start(self):
        for i in range(self.batch_size):
            while _same_pos(self.red_pos[i], self.blue_pos[i]):
                self.blue_pos[i] = np.random.randint(self.grid_size, size=2)

    @override(CoinGame)
    def _generate_coin(self):
        generate = np.ones(self.batch_size, dtype=bool)
        self.coin_pos = generate_coin(
            self.batch_size, generate, self.red_coin, self.red_pos,
            self.blue_pos, self.coin_pos, self.grid_size)

    @override(CoinGame)
    def _generate_observation(self):
        obs = generate_observations_wt_numba_optimization(
            self.batch_size, self.red_pos, self.blue_pos, self.coin_pos,
            self.red_coin, self.grid_size)
        obs = self._get_obs_invariant_to_the_player_trained(obs)
        obs, _ = self._optional_unvectorize(obs)
        return obs

    def _optional_unvectorize(self, obs, rewards=None):
        # When batch_size == 1 and vectorization is not forced, drop the
        # leading batch dimension so the env behaves like the scalar CoinGame.
        if self.batch_size == 1 and not self.force_vectorized:
            obs = [one_obs[0, ...] for one_obs in obs]
            if rewards is not None:
                rewards[0], rewards[1] = rewards[0][0], rewards[1][0]
        return obs, rewards

    @override(CoinGame)
    def step(self, actions: Iterable):
        actions = self._from_RLLib_API_to_list(actions)
        self.step_count_in_current_episode += 1

        (self.red_pos, self.blue_pos, rewards, self.coin_pos, observation,
         self.red_coin, red_pick_any, red_pick_red, blue_pick_any,
         blue_pick_blue) = vectorized_step_wt_numba_optimization(
            actions, self.batch_size, self.red_pos, self.blue_pos,
            self.coin_pos, self.red_coin, self.grid_size, self.asymmetric,
            self.max_steps, self.both_players_can_pick_the_same_coin)

        if self.output_additional_info:
            self._accumulate_info(red_pick_any, red_pick_red, blue_pick_any,
                                  blue_pick_blue)

        obs = self._get_obs_invariant_to_the_player_trained(observation)
        obs, rewards = self._optional_unvectorize(obs, rewards)

        return self._to_RLLib_API(obs, rewards)

    @override(CoinGame)
    def _get_episode_info(self):
        player_red_info, player_blue_info = {}, {}

        if len(self.red_pick) > 0:
            red_pick = sum(self.red_pick)
            player_red_info["pick_speed"] = \
                red_pick / (len(self.red_pick) * self.batch_size)
            if red_pick > 0:
                player_red_info["pick_own_color"] = \
                    sum(self.red_pick_own) / red_pick

        if len(self.blue_pick) > 0:
            blue_pick = sum(self.blue_pick)
            player_blue_info["pick_speed"] = \
                blue_pick / (len(self.blue_pick) * self.batch_size)
            if blue_pick > 0:
                player_blue_info["pick_own_color"] = \
                    sum(self.blue_pick_own) / blue_pick

        return player_red_info, player_blue_info

    @override(CoinGame)
    def _from_RLLib_API_to_list(self, actions):
        ac_red = actions[self.player_red_id]
        ac_blue = actions[self.player_blue_id]
        if not isinstance(ac_red, Iterable):
            assert not isinstance(ac_blue, Iterable)
            ac_red, ac_blue = [ac_red], [ac_blue]
        actions = [ac_red, ac_blue]
        actions = np.array(actions).T
        return actions

    def _save_env(self):
        env_save_state = {
            "red_pos": self.red_pos,
            "blue_pos": self.blue_pos,
            "coin_pos": self.coin_pos,
            "red_coin": self.red_coin,
            "grid_size": self.grid_size,
            "asymmetric": self.asymmetric,
            "batch_size": self.batch_size,
            "step_count_in_current_episode":
                self.step_count_in_current_episode,
            "max_steps": self.max_steps,
            "red_pick": self.red_pick,
            "red_pick_own": self.red_pick_own,
            "blue_pick": self.blue_pick,
            "blue_pick_own": self.blue_pick_own,
            "both_players_can_pick_the_same_coin":
                self.both_players_can_pick_the_same_coin,
        }
        return copy.deepcopy(env_save_state)

    def _load_env(self, env_state):
        for k, v in env_state.items():
            self.__setattr__(k, v)


class AsymVectorizedCoinGame(VectorizedCoinGame):
    NAME = "AsymCoinGame"

    def __init__(self, config=None):
        if config is None:
            config = {}

        if "asymmetric" in config:
            assert config["asymmetric"]
        else:
            config["asymmetric"] = True
        super().__init__(config)
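

# ----------------------------------------------------------------------------
# Numba-jitted helpers used by the vectorized environments above.
# ----------------------------------------------------------------------------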


@jit(nopython=True)
def move_players(batch_size, actions, red_pos, blue_pos, grid_size):
    # Apply both players' moves; positions wrap around the grid (torus).
    moves = List([
        np.array([0, 1]),
        np.array([0, -1]),
        np.array([1, 0]),
        np.array([-1, 0]),
    ])

    for j in prange(batch_size):
        red_pos[j] = (red_pos[j] + moves[actions[j, 0]]) % grid_size
        blue_pos[j] = (blue_pos[j] + moves[actions[j, 1]]) % grid_size
    return red_pos, blue_pos


@jit(nopython=True)
def compute_reward(batch_size, red_pos, blue_pos, coin_pos, red_coin,
                   asymmetric, both_players_can_pick_the_same_coin):
    reward_red = np.zeros(batch_size)
    reward_blue = np.zeros(batch_size)
    generate = np.zeros(batch_size, dtype=np.bool_)
    red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0

    for i in prange(batch_size):
        red_first_if_both = None
        if not both_players_can_pick_the_same_coin:
            if _same_pos(red_pos[i], coin_pos[i]) and \
                    _same_pos(blue_pos[i], coin_pos[i]):
                # Break the tie at random: randint(0, 2) returns 0 or 1
                # (randint(0, 1) would always return 0).
                red_first_if_both = bool(np.random.randint(0, 2))

        if red_coin[i]:
            if _same_pos(red_pos[i], coin_pos[i]) and \
                    (red_first_if_both is None or red_first_if_both):
                generate[i] = True
                reward_red[i] += 1
                if asymmetric:
                    reward_red[i] += 3
                red_pick_any += 1
                red_pick_red += 1
            if _same_pos(blue_pos[i], coin_pos[i]) and \
                    (red_first_if_both is None or not red_first_if_both):
                generate[i] = True
                reward_red[i] += -2
                reward_blue[i] += 1
                blue_pick_any += 1
        else:
            if _same_pos(red_pos[i], coin_pos[i]) and \
                    (red_first_if_both is None or red_first_if_both):
                generate[i] = True
                reward_red[i] += 1
                reward_blue[i] += -2
                if asymmetric:
                    reward_red[i] += 3
                red_pick_any += 1
            if _same_pos(blue_pos[i], coin_pos[i]) and \
                    (red_first_if_both is None or not red_first_if_both):
                generate[i] = True
                reward_blue[i] += 1
                blue_pick_any += 1
                blue_pick_blue += 1

    reward = [reward_red, reward_blue]
    return reward, generate, \
        red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue


@jit(nopython=True)
def _same_pos(x, y):
    return (x == y).all()


@jit(nopython=True)
def _flatten_index(pos, grid_size):
    y_pos, x_pos = pos
    idx = grid_size * y_pos
    idx += x_pos
    return idx


@jit(nopython=True)
def _unflatten_index(pos, grid_size):
    x_idx = pos % grid_size
    y_idx = pos // grid_size
    return np.array([y_idx, x_idx])


@jit(nopython=True)
def generate_coin(batch_size, generate, red_coin, red_pos, blue_pos, coin_pos,
                  grid_size):
    # Flip the coin color and respawn the coin in every batch entry where it
    # was picked up.
    red_coin[generate] = 1 - red_coin[generate]
    for i in prange(batch_size):
        if generate[i]:
            coin_pos[i] = place_coin(red_pos[i], blue_pos[i], grid_size)
    return coin_pos


@jit(nopython=True)
def place_coin(red_pos_i, blue_pos_i, grid_size):
    # Sample a cell occupied by neither player; range(9) assumes grid_size == 3.
    red_pos_flat = _flatten_index(red_pos_i, grid_size)
    blue_pos_flat = _flatten_index(blue_pos_i, grid_size)
    possible_coin_pos = np.array([
        x for x in range(9) if ((x != blue_pos_flat) and (x != red_pos_flat))
    ])
    flat_coin_pos = np.random.choice(possible_coin_pos)
    return _unflatten_index(flat_coin_pos, grid_size)


@jit(nopython=True)
def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos,
                                                coin_pos, red_coin, grid_size):
    # One-hot grid per batch entry: channel 0 = red player, 1 = blue player,
    # 2 = red coin, 3 = blue coin.
    obs = np.zeros((batch_size, grid_size, grid_size, 4))
    for i in prange(batch_size):
        obs[i, red_pos[i][0], red_pos[i][1], 0] = 1
        obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1
        if red_coin[i]:
            obs[i, coin_pos[i][0], coin_pos[i][1], 2] = 1
        else:
            obs[i, coin_pos[i][0], coin_pos[i][1], 3] = 1
    return obs


@jit(nopython=True)
def vectorized_step_wt_numba_optimization(
        actions, batch_size, red_pos, blue_pos, coin_pos, red_coin,
        grid_size: int, asymmetric: bool, max_steps: int,
        both_players_can_pick_the_same_coin: bool):
    # Move both players, compute rewards, respawn picked coins, and build the
    # next observation in one numba-compiled pass over the batch.
    red_pos, blue_pos = move_players(batch_size, actions, red_pos, blue_pos,
                                     grid_size)

    reward, generate, red_pick_any, red_pick_red, \
        blue_pick_any, blue_pick_blue = compute_reward(
            batch_size, red_pos, blue_pos, coin_pos, red_coin,
            asymmetric, both_players_can_pick_the_same_coin)

    coin_pos = generate_coin(batch_size, generate, red_coin, red_pos,
                             blue_pos, coin_pos, grid_size)

    obs = generate_observations_wt_numba_optimization(
        batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size)

    return red_pos, blue_pos, reward, coin_pos, obs, red_coin, red_pick_any, \
        red_pick_red, blue_pick_any, blue_pick_blue
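

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the parent
# CoinGame accepts this config, follows the old-style RLlib multi-agent API
# (reset() -> obs dict, step() -> (obs, rewards, done, info)), and exposes the
# player_red_id / player_blue_id attributes used in _from_RLLib_API_to_list.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    env = VectorizedCoinGame({"batch_size": 2, "grid_size": 3})
    obs = env.reset()
    # With batch_size > 1, each player submits one action (0..3) per batch
    # entry; scalars are only accepted when batch_size == 1.
    actions = {
        env.player_red_id: np.random.randint(4, size=2),
        env.player_blue_id: np.random.randint(4, size=2),
    }
    obs, rewards, done, info = env.step(actions)
    print(rewards)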