from collections import deque
import os
import pickle
import threading
from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING

import gym

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import PartialTrainerConfigDict, \
    PolicyID, TrainerConfigDict
from ray.tune.utils.util import merge_dicts

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()


class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Thereby, keeps n policies in memory and - when capacity is reached -
    writes the least recently used to disk. This allows adding 100s of
    policies to a Trainer for league-based setups w/o running out of memory.
    """

    def __init__(
            self,
            worker_index: int,
            num_workers: int,
            capacity: Optional[int] = None,
            path: Optional[str] = None,
            policy_config: Optional[TrainerConfigDict] = None,
            session_creator: Optional[Callable[[], "tf1.Session"]] = None,
            seed: Optional[int] = None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            worker_index: The worker index of the RolloutWorker this map
                resides in.
            num_workers: The total number of remote workers in the WorkerSet
                to which this map's RolloutWorker belongs to.
            capacity: The maximum number of policies to hold in memory.
                The least recently used ones are written to disk/S3 and
                retrieved when needed. Defaults to 10.
            path: The path to store the policy pickle files to. Files will
                have the name: [policy_id].[worker idx].policy.pkl.
            policy_config: The Trainer's base config dict.
            session_creator: An optional tf1.Session creation callable.
            seed: An optional seed (used to seed tf policies).
        """
        super().__init__()

        self.worker_index = worker_index
        self.num_workers = num_workers
        self.session_creator = session_creator
        self.seed = seed

        # The file extension for stashed policies (that are no longer
        # available in-memory but can be reinstated any time from storage).
        self.extension = f".{self.worker_index}.policy.pkl"

        # Dictionary of keys that may be looked up (cached or not).
        self.valid_keys: Set[str] = set()
        # The actual cache with the in-memory policy objects.
        self.cache: Dict[str, Policy] = {}
        # The doubly-linked list holding the currently in-memory objects
        # in least- (left) to most-recently-used (right) order.
        self.deque = deque(maxlen=capacity or 10)
        # The file path where to store overflowing policies.
        self.path = path or "."
        # The core config to use. Each single policy's config override is
        # added on top of this.
        self.policy_config: TrainerConfigDict = policy_config or {}
        # The orig. classes/obs+act spaces, and config overrides of the
        # Policies. Needed to reconstruct stashed policies from disk.
        self.policy_specs: Dict[PolicyID, PolicySpec] = {}

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when accessing the map
        # and the underlying structures, like self.deque and others.
        self._lock = threading.RLock()

    def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
                      observation_space: gym.Space, action_space: gym.Space,
                      config_override: PartialTrainerConfigDict,
                      merged_config: TrainerConfigDict) -> None:
        """Creates a new policy and stores it to the cache.

        Args:
            policy_id: The policy ID. This is the key under which
                the created policy will be stored in this map.
            policy_cls: The (original) policy class to use.
                This may still be altered in case tf-eager (and tracing)
                is used.
            observation_space: The observation space of the policy.
            action_space: The action space of the policy.
            config_override: The config override dict for this policy.
                This is the partial dict provided by the user.
            merged_config: The entire config (merged default config
                + `config_override`).
        """
        framework = merged_config.get("framework", "tf")
        class_ = get_tf_eager_cls_if_necessary(policy_cls, merged_config)

        # Tf.
        if framework in ["tf2", "tf", "tfe"]:
            var_scope = policy_id + (("_wk" + str(self.worker_index))
                                     if self.worker_index else "")

            # For tf static graph, build every policy in its own graph
            # and create a new session for it.
            if framework == "tf":
                with tf1.Graph().as_default():
                    if self.session_creator:
                        sess = self.session_creator()
                    else:
                        sess = tf1.Session(
                            config=tf1.ConfigProto(
                                gpu_options=tf1.GPUOptions(
                                    allow_growth=True)))
                    with sess.as_default():
                        # Set graph-level seed.
                        if self.seed is not None:
                            tf1.set_random_seed(self.seed)
                        with tf1.variable_scope(var_scope):
                            self[policy_id] = class_(
                                observation_space, action_space,
                                merged_config)
            # For tf-eager: no graph, no session.
            else:
                with tf1.variable_scope(var_scope):
                    self[policy_id] = \
                        class_(observation_space, action_space, merged_config)
        # Non-tf: No graph, no session.
        else:
            class_ = policy_cls
            self[policy_id] = class_(observation_space, action_space,
                                     merged_config)

        # Store spec (class, obs-space, act-space, and config overrides) such
        # that the map will be able to reproduce on-the-fly added policies
        # from disk.
        self.policy_specs[policy_id] = PolicySpec(
            policy_class=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config_override)

    @with_lock
    @override(dict)
    def __getitem__(self, item):
        # Never seen this key -> Error.
        if item not in self.valid_keys:
            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

        # Item already in cache -> Rearrange deque (least recently used) and
        # return.
        if item in self.cache:
            self.deque.remove(item)
            self.deque.append(item)
        # Item not currently in cache -> Get from disk and - if at capacity -
        # remove leftmost one.
        else:
            self._read_from_disk(policy_id=item)

        return self.cache[item]

    @with_lock
    @override(dict)
    def __setitem__(self, key, value):
        # Item already in cache -> Rearrange deque (least recently used).
        if key in self.cache:
            self.deque.remove(key)
            self.deque.append(key)
            self.cache[key] = value
        # Item not currently in cache -> store new value and - if at capacity -
        # remove leftmost one.
        else:
            # Cache at capacity -> Drop leftmost (least recently used) item.
            if len(self.deque) == self.deque.maxlen:
                self._stash_to_disk()
            self.deque.append(key)
            self.cache[key] = value
        self.valid_keys.add(key)

    @with_lock
    @override(dict)
    def __delitem__(self, key):
        # Make key invalid (raises KeyError for unknown keys, matching dict).
        self.valid_keys.remove(key)
        # Remove policy from memory if currently cached.
        if key in self.cache:
            # Also drop the key from the LRU deque. Otherwise, a later
            # `_stash_to_disk()` would pop this stale key and KeyError on
            # the (already emptied) cache.
            self.deque.remove(key)
            policy = self.cache[key]
            self._close_session(policy)
            del self.cache[key]
        # Remove file associated with the policy, if it exists.
        filename = os.path.join(self.path, key + self.extension)
        if os.path.isfile(filename):
            os.remove(filename)

    @override(dict)
    def __iter__(self):
        return iter(self.keys())

    @override(dict)
    def items(self):
        """Iterates over all policies, even the stashed-to-disk ones."""
        # Snapshot the valid keys under the lock so concurrent
        # additions/removals don't break iteration.
        with self._lock:
            ks = list(self.valid_keys)

        def gen():
            for key in ks:
                yield (key, self[key])

        return gen()

    @override(dict)
    def keys(self):
        """Iterates over all policy IDs, even the stashed-to-disk ones."""
        # Use the lock as a context manager so it is released even if
        # taking the snapshot raises.
        with self._lock:
            ks = list(self.valid_keys)

        def gen():
            for key in ks:
                yield key

        return gen()

    @override(dict)
    def values(self):
        """Iterates over all policies, even the stashed-to-disk ones."""
        # Hold the (re-entrant) lock while materializing: `self[k]` may
        # trigger a disk read and must not race with other mutations.
        with self._lock:
            vs = [self[k] for k in self.valid_keys]

        def gen():
            for value in vs:
                yield value

        return gen()

    @with_lock
    @override(dict)
    def update(self, __m=None, **kwargs):
        # Match dict.update's signature: the positional mapping is optional.
        if __m is not None:
            for k, v in __m.items():
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @with_lock
    @override(dict)
    def get(self, key, default=None):
        # Match dict.get's signature: return `default` (not hard-coded
        # None) for unknown keys.
        if key not in self.valid_keys:
            return default
        return self[key]

    @with_lock
    @override(dict)
    def __len__(self):
        """Returns number of all policies, including the stashed-to-disk ones.
        """
        return len(self.valid_keys)

    @with_lock
    @override(dict)
    def __contains__(self, item):
        return item in self.valid_keys

    def _stash_to_disk(self):
        """Writes the least-recently used policy to disk and rearranges cache.

        Also closes the session - if applicable - of the stashed policy.
        """
        # Get least recently used policy (all the way on the left in deque).
        delkey = self.deque.popleft()
        policy = self.cache[delkey]
        # Get its state for writing to disk.
        policy_state = policy.get_state()
        # Closes policy's tf session, if any.
        self._close_session(policy)
        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[delkey]
        # Write state to disk.
        with open(os.path.join(self.path, delkey + self.extension),
                  "wb") as f:
            pickle.dump(policy_state, file=f)

    def _read_from_disk(self, policy_id):
        """Reads a policy ID from disk and re-adds it to the cache."""
        # Make sure this policy ID is not in the cache right now.
        assert policy_id not in self.cache
        # Read policy state from disk.
        # NOTE(review): pickle.load on these files assumes the stash
        # directory is trusted (we wrote the files ourselves).
        with open(os.path.join(self.path, policy_id + self.extension),
                  "rb") as f:
            policy_state = pickle.load(f)

        # Get class and config override.
        merged_conf = merge_dicts(self.policy_config,
                                  self.policy_specs[policy_id].config)

        # Create policy object (from its spec: cls, obs-space, act-space,
        # config).
        self.create_policy(
            policy_id,
            self.policy_specs[policy_id].policy_class,
            self.policy_specs[policy_id].observation_space,
            self.policy_specs[policy_id].action_space,
            self.policy_specs[policy_id].config,
            merged_conf,
        )
        # Restore policy's state.
        policy = self[policy_id]
        policy.set_state(policy_state)

    def _close_session(self, policy):
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()