policy_map.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. from collections import deque
  2. import gym
  3. import os
  4. import pickle
  5. import threading
  6. from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING
  7. from ray.rllib.policy.policy import PolicySpec
  8. from ray.rllib.utils.annotations import override
  9. from ray.rllib.utils.framework import try_import_tf
  10. from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
  11. from ray.rllib.utils.threading import with_lock
  12. from ray.rllib.utils.typing import PartialTrainerConfigDict, \
  13. PolicyID, TrainerConfigDict
  14. from ray.tune.utils.util import merge_dicts
  15. if TYPE_CHECKING:
  16. from ray.rllib.policy.policy import Policy
  17. tf1, tf, tfv = try_import_tf()
  18. class PolicyMap(dict):
  19. """Maps policy IDs to Policy objects.
  20. Thereby, keeps n policies in memory and - when capacity is reached -
  21. writes the least recently used to disk. This allows adding 100s of
  22. policies to a Trainer for league-based setups w/o running out of memory.
  23. """
  24. def __init__(
  25. self,
  26. worker_index: int,
  27. num_workers: int,
  28. capacity: Optional[int] = None,
  29. path: Optional[str] = None,
  30. policy_config: Optional[TrainerConfigDict] = None,
  31. session_creator: Optional[Callable[[], "tf1.Session"]] = None,
  32. seed: Optional[int] = None,
  33. ):
  34. """Initializes a PolicyMap instance.
  35. Args:
  36. worker_index (int): The worker index of the RolloutWorker this map
  37. resides in.
  38. num_workers (int): The total number of remote workers in the
  39. WorkerSet to which this map's RolloutWorker belongs to.
  40. capacity (int): The maximum number of policies to hold in memory.
  41. The least used ones are written to disk/S3 and retrieved
  42. when needed.
  43. path (str): The path to store the policy pickle files to. Files
  44. will have the name: [policy_id].[worker idx].policy.pkl.
  45. policy_config (TrainerConfigDict): The Trainer's base config dict.
  46. session_creator (Optional[Callable[[], tf1.Session]): An optional
  47. tf1.Session creation callable.
  48. seed (int): An optional seed (used to seed tf policies).
  49. """
  50. super().__init__()
  51. self.worker_index = worker_index
  52. self.num_workers = num_workers
  53. self.session_creator = session_creator
  54. self.seed = seed
  55. # The file extension for stashed policies (that are no longer available
  56. # in-memory but can be reinstated any time from storage).
  57. self.extension = f".{self.worker_index}.policy.pkl"
  58. # Dictionary of keys that may be looked up (cached or not).
  59. self.valid_keys: Set[str] = set()
  60. # The actual cache with the in-memory policy objects.
  61. self.cache: Dict[str, Policy] = {}
  62. # The doubly-linked list holding the currently in-memory objects.
  63. self.deque = deque(maxlen=capacity or 10)
  64. # The file path where to store overflowing policies.
  65. self.path = path or "."
  66. # The core config to use. Each single policy's config override is
  67. # added on top of this.
  68. self.policy_config: TrainerConfigDict = policy_config or {}
  69. # The orig classes/obs+act spaces, and config overrides of the
  70. # Policies.
  71. self.policy_specs: Dict[PolicyID, PolicySpec] = {}
  72. # Lock used for locking some methods on the object-level.
  73. # This prevents possible race conditions when accessing the map
  74. # and the underlying structures, like self.deque and others.
  75. self._lock = threading.RLock()
  76. def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
  77. observation_space: gym.Space, action_space: gym.Space,
  78. config_override: PartialTrainerConfigDict,
  79. merged_config: TrainerConfigDict) -> None:
  80. """Creates a new policy and stores it to the cache.
  81. Args:
  82. policy_id (PolicyID): The policy ID. This is the key under which
  83. the created policy will be stored in this map.
  84. policy_cls (Type[Policy]): The (original) policy class to use.
  85. This may still be altered in case tf-eager (and tracing)
  86. is used.
  87. observation_space (gym.Space): The observation space of the
  88. policy.
  89. action_space (gym.Space): The action space of the policy.
  90. config_override (PartialTrainerConfigDict): The config override
  91. dict for this policy. This is the partial dict provided by
  92. the user.
  93. merged_config (TrainerConfigDict): The entire config (merged
  94. default config + `config_override`).
  95. """
  96. framework = merged_config.get("framework", "tf")
  97. class_ = get_tf_eager_cls_if_necessary(policy_cls, merged_config)
  98. # Tf.
  99. if framework in ["tf2", "tf", "tfe"]:
  100. var_scope = policy_id + (("_wk" + str(self.worker_index))
  101. if self.worker_index else "")
  102. # For tf static graph, build every policy in its own graph
  103. # and create a new session for it.
  104. if framework == "tf":
  105. with tf1.Graph().as_default():
  106. if self.session_creator:
  107. sess = self.session_creator()
  108. else:
  109. sess = tf1.Session(
  110. config=tf1.ConfigProto(
  111. gpu_options=tf1.GPUOptions(allow_growth=True)))
  112. with sess.as_default():
  113. # Set graph-level seed.
  114. if self.seed is not None:
  115. tf1.set_random_seed(self.seed)
  116. with tf1.variable_scope(var_scope):
  117. self[policy_id] = class_(
  118. observation_space, action_space, merged_config)
  119. # For tf-eager: no graph, no session.
  120. else:
  121. with tf1.variable_scope(var_scope):
  122. self[policy_id] = \
  123. class_(observation_space, action_space, merged_config)
  124. # Non-tf: No graph, no session.
  125. else:
  126. class_ = policy_cls
  127. self[policy_id] = class_(observation_space, action_space,
  128. merged_config)
  129. # Store spec (class, obs-space, act-space, and config overrides) such
  130. # that the map will be able to reproduce on-the-fly added policies
  131. # from disk.
  132. self.policy_specs[policy_id] = PolicySpec(
  133. policy_class=policy_cls,
  134. observation_space=observation_space,
  135. action_space=action_space,
  136. config=config_override)
  137. @with_lock
  138. @override(dict)
  139. def __getitem__(self, item):
  140. # Never seen this key -> Error.
  141. if item not in self.valid_keys:
  142. raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")
  143. # Item already in cache -> Rearrange deque (least recently used) and
  144. # return.
  145. if item in self.cache:
  146. self.deque.remove(item)
  147. self.deque.append(item)
  148. # Item not currently in cache -> Get from disk and - if at capacity -
  149. # remove leftmost one.
  150. else:
  151. self._read_from_disk(policy_id=item)
  152. return self.cache[item]
  153. @with_lock
  154. @override(dict)
  155. def __setitem__(self, key, value):
  156. # Item already in cache -> Rearrange deque (least recently used).
  157. if key in self.cache:
  158. self.deque.remove(key)
  159. self.deque.append(key)
  160. self.cache[key] = value
  161. # Item not currently in cache -> store new value and - if at capacity -
  162. # remove leftmost one.
  163. else:
  164. # Cache at capacity -> Drop leftmost item.
  165. if len(self.deque) == self.deque.maxlen:
  166. self._stash_to_disk()
  167. self.deque.append(key)
  168. self.cache[key] = value
  169. self.valid_keys.add(key)
  170. @with_lock
  171. @override(dict)
  172. def __delitem__(self, key):
  173. # Make key invalid.
  174. self.valid_keys.remove(key)
  175. # Remove policy from memory if currently cached.
  176. if key in self.cache:
  177. policy = self.cache[key]
  178. self._close_session(policy)
  179. del self.cache[key]
  180. # Remove file associated with the policy, if it exists.
  181. filename = self.path + "/" + key + self.extension
  182. if os.path.isfile(filename):
  183. os.remove(filename)
  184. @override(dict)
  185. def __iter__(self):
  186. return iter(self.keys())
  187. @override(dict)
  188. def items(self):
  189. """Iterates over all policies, even the stashed-to-disk ones."""
  190. def gen():
  191. for key in self.valid_keys:
  192. yield (key, self[key])
  193. return gen()
  194. @override(dict)
  195. def keys(self):
  196. self._lock.acquire()
  197. ks = list(self.valid_keys)
  198. self._lock.release()
  199. def gen():
  200. for key in ks:
  201. yield key
  202. return gen()
  203. @override(dict)
  204. def values(self):
  205. self._lock.acquire()
  206. vs = [self[k] for k in self.valid_keys]
  207. self._lock.release()
  208. def gen():
  209. for value in vs:
  210. yield value
  211. return gen()
  212. @with_lock
  213. @override(dict)
  214. def update(self, __m, **kwargs):
  215. for k, v in __m.items():
  216. self[k] = v
  217. for k, v in kwargs.items():
  218. self[k] = v
  219. @with_lock
  220. @override(dict)
  221. def get(self, key):
  222. if key not in self.valid_keys:
  223. return None
  224. return self[key]
  225. @with_lock
  226. @override(dict)
  227. def __len__(self):
  228. """Returns number of all policies, including the stashed-to-disk ones.
  229. """
  230. return len(self.valid_keys)
  231. @with_lock
  232. @override(dict)
  233. def __contains__(self, item):
  234. return item in self.valid_keys
  235. def _stash_to_disk(self):
  236. """Writes the least-recently used policy to disk and rearranges cache.
  237. Also closes the session - if applicable - of the stashed policy.
  238. """
  239. # Get least recently used policy (all the way on the left in deque).
  240. delkey = self.deque.popleft()
  241. policy = self.cache[delkey]
  242. # Get its state for writing to disk.
  243. policy_state = policy.get_state()
  244. # Closes policy's tf session, if any.
  245. self._close_session(policy)
  246. # Remove from memory. This will clear the tf Graph as well.
  247. del self.cache[delkey]
  248. # Write state to disk.
  249. with open(self.path + "/" + delkey + self.extension, "wb") as f:
  250. pickle.dump(policy_state, file=f)
  251. def _read_from_disk(self, policy_id):
  252. """Reads a policy ID from disk and re-adds it to the cache.
  253. """
  254. # Make sure this policy ID is not in the cache right now.
  255. assert policy_id not in self.cache
  256. # Read policy state from disk.
  257. with open(self.path + "/" + policy_id + self.extension, "rb") as f:
  258. policy_state = pickle.load(f)
  259. # Get class and config override.
  260. merged_conf = merge_dicts(self.policy_config,
  261. self.policy_specs[policy_id].config)
  262. # Create policy object (from its spec: cls, obs-space, act-space,
  263. # config).
  264. self.create_policy(
  265. policy_id,
  266. self.policy_specs[policy_id].policy_class,
  267. self.policy_specs[policy_id].observation_space,
  268. self.policy_specs[policy_id].action_space,
  269. self.policy_specs[policy_id].config,
  270. merged_conf,
  271. )
  272. # Restore policy's state.
  273. policy = self[policy_id]
  274. policy.set_state(policy_state)
  275. def _close_session(self, policy):
  276. sess = policy.get_session()
  277. # Closes the tf session, if any.
  278. if sess is not None:
  279. sess.close()