coin_game_vectorized_env.py

##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
# Some parts are originally from:
# https://github.com/julianstastny/openspiel-social-dilemmas/
# blob/master/games/coin_game_gym.py
##########
import copy
from collections.abc import Iterable  # collections.Iterable was removed in Python 3.10

import numpy as np
from numba import jit, prange
from numba.typed import List
from ray.rllib.examples.env.coin_game_non_vectorized_env import CoinGame
from ray.rllib.utils import override


class VectorizedCoinGame(CoinGame):
    """
    Vectorized Coin Game environment.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        super().__init__(config)

        self.batch_size = config.get("batch_size", 1)
        self.force_vectorized = config.get("force_vectorize", False)
        assert self.grid_size == 3, "hardcoded in the generate_state function"

    @override(CoinGame)
    def _randomize_color_and_player_positions(self):
        # Reset coin color and the players and coin positions
        self.red_coin = np.random.randint(2, size=self.batch_size)
        self.red_pos = np.random.randint(self.grid_size, size=(self.batch_size, 2))
        self.blue_pos = np.random.randint(self.grid_size, size=(self.batch_size, 2))
        self.coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)

        self._players_do_not_overlap_at_start()

    @override(CoinGame)
    def _players_do_not_overlap_at_start(self):
        for i in range(self.batch_size):
            while _same_pos(self.red_pos[i], self.blue_pos[i]):
                self.blue_pos[i] = np.random.randint(self.grid_size, size=2)

    @override(CoinGame)
    def _generate_coin(self):
        generate = np.ones(self.batch_size, dtype=bool)
        self.coin_pos = generate_coin(
            self.batch_size,
            generate,
            self.red_coin,
            self.red_pos,
            self.blue_pos,
            self.coin_pos,
            self.grid_size,
        )

    @override(CoinGame)
    def _generate_observation(self):
        obs = generate_observations_wt_numba_optimization(
            self.batch_size,
            self.red_pos,
            self.blue_pos,
            self.coin_pos,
            self.red_coin,
            self.grid_size,
        )

        obs = self._get_obs_invariant_to_the_player_trained(obs)
        obs, _ = self._optional_unvectorize(obs)

        return obs

    def _optional_unvectorize(self, obs, rewards=None):
        if self.batch_size == 1 and not self.force_vectorized:
            obs = [one_obs[0, ...] for one_obs in obs]
            if rewards is not None:
                rewards[0], rewards[1] = rewards[0][0], rewards[1][0]
        return obs, rewards

    @override(CoinGame)
    def step(self, actions: Iterable):
        actions = self._from_RLlib_API_to_list(actions)
        self.step_count_in_current_episode += 1

        (
            self.red_pos,
            self.blue_pos,
            rewards,
            self.coin_pos,
            observation,
            self.red_coin,
            red_pick_any,
            red_pick_red,
            blue_pick_any,
            blue_pick_blue,
        ) = vectorized_step_wt_numba_optimization(
            actions,
            self.batch_size,
            self.red_pos,
            self.blue_pos,
            self.coin_pos,
            self.red_coin,
            self.grid_size,
            self.asymmetric,
            self.max_steps,
            self.both_players_can_pick_the_same_coin,
        )

        if self.output_additional_info:
            self._accumulate_info(
                red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
            )

        obs = self._get_obs_invariant_to_the_player_trained(observation)
        obs, rewards = self._optional_unvectorize(obs, rewards)

        return self._to_RLlib_API(obs, rewards)

    @override(CoinGame)
    def _get_episode_info(self):
        player_red_info, player_blue_info = {}, {}

        if len(self.red_pick) > 0:
            red_pick = sum(self.red_pick)
            player_red_info["pick_speed"] = red_pick / (
                len(self.red_pick) * self.batch_size
            )
            if red_pick > 0:
                player_red_info["pick_own_color"] = sum(self.red_pick_own) / red_pick

        if len(self.blue_pick) > 0:
            blue_pick = sum(self.blue_pick)
            player_blue_info["pick_speed"] = blue_pick / (
                len(self.blue_pick) * self.batch_size
            )
            if blue_pick > 0:
                player_blue_info["pick_own_color"] = sum(self.blue_pick_own) / blue_pick

        return player_red_info, player_blue_info

    @override(CoinGame)
    def _from_RLlib_API_to_list(self, actions):
        ac_red = actions[self.player_red_id]
        ac_blue = actions[self.player_blue_id]
        if not isinstance(ac_red, Iterable):
            assert not isinstance(ac_blue, Iterable)
            ac_red, ac_blue = [ac_red], [ac_blue]
        actions = [ac_red, ac_blue]
        actions = np.array(actions).T
        return actions

    def _save_env(self):
        env_save_state = {
            "red_pos": self.red_pos,
            "blue_pos": self.blue_pos,
            "coin_pos": self.coin_pos,
            "red_coin": self.red_coin,
            "grid_size": self.grid_size,
            "asymmetric": self.asymmetric,
            "batch_size": self.batch_size,
            "step_count_in_current_episode": self.step_count_in_current_episode,
            "max_steps": self.max_steps,
            "red_pick": self.red_pick,
            "red_pick_own": self.red_pick_own,
            "blue_pick": self.blue_pick,
            "blue_pick_own": self.blue_pick_own,
            "both_players_can_pick_the_same_coin": self.both_players_can_pick_the_same_coin,  # noqa: E501
        }
        return copy.deepcopy(env_save_state)

    def _load_env(self, env_state):
        for k, v in env_state.items():
            self.__setattr__(k, v)
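
    # Usage sketch: _save_env() / _load_env() snapshot and restore the full
    # vectorized episode state, e.g.
    #
    #   snapshot = env._save_env()
    #   ...  # keep stepping the environment
    #   env._load_env(snapshot)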


class AsymVectorizedCoinGame(VectorizedCoinGame):
    NAME = "AsymCoinGame"

    def __init__(self, config=None):
        if config is None:
            config = {}

        if "asymmetric" in config:
            assert config["asymmetric"]
        else:
            config["asymmetric"] = True
        super().__init__(config)


@jit(nopython=True)
def move_players(batch_size, actions, red_pos, blue_pos, grid_size):
    # The four discrete actions, applied with wrap-around on the grid (torus).
    moves = List(
        [
            np.array([0, 1]),
            np.array([0, -1]),
            np.array([1, 0]),
            np.array([-1, 0]),
        ]
    )

    for j in prange(batch_size):
        red_pos[j] = (red_pos[j] + moves[actions[j, 0]]) % grid_size
        blue_pos[j] = (blue_pos[j] + moves[actions[j, 1]]) % grid_size
    return red_pos, blue_pos


@jit(nopython=True)
def compute_reward(
    batch_size,
    red_pos,
    blue_pos,
    coin_pos,
    red_coin,
    asymmetric,
    both_players_can_pick_the_same_coin,
):
    # Picking any coin gives the picker +1; when the coin has the opponent's
    # color, the coin's owner gets -2. In the asymmetric variant, red gets an
    # extra +3 for every coin it picks.
    reward_red = np.zeros(batch_size)
    reward_blue = np.zeros(batch_size)
    generate = np.zeros(batch_size, dtype=np.bool_)
    red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0

    for i in prange(batch_size):
        red_first_if_both = None
        if not both_players_can_pick_the_same_coin:
            if _same_pos(red_pos[i], coin_pos[i]) and _same_pos(
                blue_pos[i], coin_pos[i]
            ):
                # Fair coin flip to decide which player gets the coin when both
                # land on it (the upper bound of randint is exclusive).
                red_first_if_both = bool(np.random.randint(0, 2))

        if red_coin[i]:
            if _same_pos(red_pos[i], coin_pos[i]) and (
                red_first_if_both is None or red_first_if_both
            ):
                generate[i] = True
                reward_red[i] += 1
                if asymmetric:
                    reward_red[i] += 3
                red_pick_any += 1
                red_pick_red += 1
            if _same_pos(blue_pos[i], coin_pos[i]) and (
                red_first_if_both is None or not red_first_if_both
            ):
                generate[i] = True
                reward_red[i] += -2
                reward_blue[i] += 1
                blue_pick_any += 1
        else:
            if _same_pos(red_pos[i], coin_pos[i]) and (
                red_first_if_both is None or red_first_if_both
            ):
                generate[i] = True
                reward_red[i] += 1
                reward_blue[i] += -2
                if asymmetric:
                    reward_red[i] += 3
                red_pick_any += 1
            if _same_pos(blue_pos[i], coin_pos[i]) and (
                red_first_if_both is None or not red_first_if_both
            ):
                generate[i] = True
                reward_blue[i] += 1
                blue_pick_any += 1
                blue_pick_blue += 1

    reward = [reward_red, reward_blue]
    return reward, generate, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue


@jit(nopython=True)
def _same_pos(x, y):
    return (x == y).all()


@jit(nopython=True)
def _flatten_index(pos, grid_size):
    y_pos, x_pos = pos
    idx = grid_size * y_pos
    idx += x_pos
    return idx


@jit(nopython=True)
def _unflatten_index(pos, grid_size):
    x_idx = pos % grid_size
    y_idx = pos // grid_size
    return np.array([y_idx, x_idx])
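

# Worked example: for grid_size == 3, _flatten_index maps (y, x) -> 3 * y + x,
# so (1, 2) -> 5, and _unflatten_index(5, 3) recovers array([1, 2]).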


@jit(nopython=True)
def generate_coin(
    batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size
):
    # For every environment flagged in `generate`, flip the coin color and
    # spawn a new coin on a cell not occupied by either player.
    red_coin[generate] = 1 - red_coin[generate]
    for i in prange(batch_size):
        if generate[i]:
            coin_pos[i] = place_coin(red_pos[i], blue_pos[i], grid_size)
    return coin_pos


@jit(nopython=True)
def place_coin(red_pos_i, blue_pos_i, grid_size):
    red_pos_flat = _flatten_index(red_pos_i, grid_size)
    blue_pos_flat = _flatten_index(blue_pos_i, grid_size)
    # range(9) hardcodes grid_size == 3 (enforced by the assert in __init__).
    possible_coin_pos = np.array(
        [x for x in range(9) if ((x != blue_pos_flat) and (x != red_pos_flat))]
    )
    flat_coin_pos = np.random.choice(possible_coin_pos)
    return _unflatten_index(flat_coin_pos, grid_size)


@jit(nopython=True)
def generate_observations_wt_numba_optimization(
    batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size
):
    # One-hot grid with 4 channels: red player, blue player, red coin, blue coin.
    obs = np.zeros((batch_size, grid_size, grid_size, 4))
    for i in prange(batch_size):
        obs[i, red_pos[i][0], red_pos[i][1], 0] = 1
        obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1
        if red_coin[i]:
            obs[i, coin_pos[i][0], coin_pos[i][1], 2] = 1
        else:
            obs[i, coin_pos[i][0], coin_pos[i][1], 3] = 1
    return obs


@jit(nopython=True)
def vectorized_step_wt_numba_optimization(
    actions,
    batch_size,
    red_pos,
    blue_pos,
    coin_pos,
    red_coin,
    grid_size: int,
    asymmetric: bool,
    max_steps: int,
    both_players_can_pick_the_same_coin: bool,
):
    red_pos, blue_pos = move_players(
        batch_size, actions, red_pos, blue_pos, grid_size
    )

    (
        reward,
        generate,
        red_pick_any,
        red_pick_red,
        blue_pick_any,
        blue_pick_blue,
    ) = compute_reward(
        batch_size,
        red_pos,
        blue_pos,
        coin_pos,
        red_coin,
        asymmetric,
        both_players_can_pick_the_same_coin,
    )

    coin_pos = generate_coin(
        batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size
    )

    obs = generate_observations_wt_numba_optimization(
        batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size
    )

    return (
        red_pos,
        blue_pos,
        reward,
        coin_pos,
        obs,
        red_coin,
        red_pick_any,
        red_pick_red,
        blue_pick_any,
        blue_pick_blue,
    )
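

# Minimal usage sketch, assuming the parent CoinGame exposes a Gym-style
# reset(), defines player_red_id / player_blue_id, and provides defaults for
# the remaining config keys; actions are integers in [0, 4) indexing the four
# moves in move_players, and step() returns whatever _to_RLlib_API produces.
if __name__ == "__main__":
    env = VectorizedCoinGame({"batch_size": 4, "force_vectorize": True})
    obs = env.reset()
    actions = {
        env.player_red_id: np.random.randint(4, size=env.batch_size),
        env.player_blue_id: np.random.randint(4, size=env.batch_size),
    }
    step_result = env.step(actions)
    print(step_result)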