# dynamic_tf_policy.py
from collections import namedtuple, OrderedDict
import gymnasium as gym
import logging
import re
import tree  # pip install dm_tree
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from ray.util.debug import log_once
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import (
    deprecation_warning,
    DEPRECATED_VALUE,
)
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
from ray.rllib.utils.tf_utils import get_placeholder
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
    AlgorithmConfigDict,
)

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

# Variable scope in which created variables will be placed under.
TOWER_SCOPE_NAME = "tower"


@DeveloperAPI
class DynamicTFPolicy(TFPolicy):
    """A TFPolicy that auto-defines placeholders dynamically at runtime.

    Do not sub-class this class directly (neither should you sub-class
    TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy
    to generate your custom tf (graph-mode or eager) Policy classes.
    """

    @DeveloperAPI
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        loss_fn: Callable[
            [Policy, ModelV2, Type[TFActionDistribution], SampleBatch], TensorType
        ],
        *,
        stats_fn: Optional[
            Callable[[Policy, SampleBatch], Dict[str, TensorType]]
        ] = None,
        grad_stats_fn: Optional[
            Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
        ] = None,
        before_loss_init: Optional[
            Callable[
                [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
            ]
        ] = None,
        make_model: Optional[
            Callable[
                [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
                ModelV2,
            ]
        ] = None,
        action_sampler_fn: Optional[
            Callable[
                [TensorType, List[TensorType]],
                Union[
                    Tuple[TensorType, TensorType],
                    Tuple[TensorType, TensorType, TensorType, List[TensorType]],
                ],
            ]
        ] = None,
        action_distribution_fn: Optional[
            Callable[
                [Policy, ModelV2, TensorType, TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]],
            ]
        ] = None,
        existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
        existing_model: Optional[ModelV2] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        obs_include_prev_action_reward=DEPRECATED_VALUE,
    ):
        """Initializes a DynamicTFPolicy instance.

        Initialization of this class occurs in two phases and defines the
        static graph.

        Phase 1: The model is created and model variables are initialized.

        Phase 2: A fake batch of data is created, sent to the trajectory
        postprocessor, and then used to create placeholders for the loss
        function. The loss and stats functions are initialized with these
        placeholders.

        Args:
            obs_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            loss_fn: Function that returns a loss tensor for the policy graph.
            stats_fn: Optional callable that - given the policy and batch
                input tensors - returns a dict mapping str to TF ops.
                These ops are fetched from the graph after loss calculations
                and the resulting values can be found in the results dict
                returned by e.g. `Algorithm.train()` or in tensorboard (if TB
                logging is enabled).
            grad_stats_fn: Optional callable that - given the policy, batch
                input tensors, and calculated loss gradient tensors - returns
                a dict mapping str to TF ops. These ops are fetched from the
                graph after loss and gradient calculations and the resulting
                values can be found in the results dict returned by e.g.
                `Algorithm.train()` or in tensorboard (if TB logging is
                enabled).
            before_loss_init: Optional function to run prior to
                loss init that takes the same arguments as __init__.
            make_model: Optional function that returns a ModelV2 object
                given policy, obs_space, action_space, and policy config.
                All policy variables should be created in this function. If
                not specified, a default model will be created.
            action_sampler_fn: A callable returning either a sampled action
                and its log-likelihood or a sampled action, its
                log-likelihood, action distribution inputs and updated state
                given Policy, ModelV2, observation inputs, explore, and
                is_training.
                Provide `action_sampler_fn` if you would like to have full
                control over the action computation step, including the
                model forward pass, possible sampling from a distribution,
                and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, obs_batch,
                state_batches (optional), seq_lens (optional),
                prev_actions_batch (optional), prev_rewards_batch (optional),
                explore, and is_training.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            existing_inputs: When copying a policy, this specifies an existing
                dict of placeholders to use instead of defining new ones.
            existing_model: When copying a policy, this specifies an existing
                model to clone and share weights with.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches. If None, will
                assume a value of 1.
        """
        if obs_include_prev_action_reward != DEPRECATED_VALUE:
            deprecation_warning(old="obs_include_prev_action_reward", error=True)
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._seq_lens = None
        self._is_tower = existing_inputs is not None

        dist_class = None
        if action_sampler_fn or action_distribution_fn:
            if not make_model:
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given"
                )
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"]
            )

        # Setup self.model.
        if existing_model:
            if isinstance(existing_model, list):
                self.model = existing_model[0]
                # TODO: (sven) hack, but works for `target_[q_]?model`.
                for i in range(1, len(existing_model)):
                    setattr(self, existing_model[i][0], existing_model[i][1])
        elif make_model:
            self.model = make_model(self, obs_space, action_space, config)
        else:
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="tf",
            )
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        # Input placeholders already given -> Use these.
        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items() if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
        # Create new input placeholders.
        else:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                    name=k,
                )
                for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = tf1.placeholder(
                    dtype=tf.int32, shape=[None], name="seq_lens"
                )

        # Use default settings.
        # Add NEXT_OBS, STATE_IN_0.., and others.
        self.view_requirements = self._get_default_view_requirements()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)
        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_training = False

        # Setup standard placeholders.
        if self._is_tower:
            timestep = existing_inputs["timestep"]
            explore = False
            self._input_dict, self._dummy_batch = self._get_input_dict_and_dummy_batch(
                self.view_requirements, existing_inputs
            )
        else:
            if not self.config.get("_disable_action_flattening"):
                action_ph = ModelCatalog.get_action_placeholder(action_space)
                prev_action_ph = {}
                if SampleBatch.PREV_ACTIONS not in self.view_requirements:
                    prev_action_ph = {
                        SampleBatch.PREV_ACTIONS: ModelCatalog.get_action_placeholder(
                            action_space, "prev_action"
                        )
                    }
                (
                    self._input_dict,
                    self._dummy_batch,
                ) = self._get_input_dict_and_dummy_batch(
                    self.view_requirements,
                    dict({SampleBatch.ACTIONS: action_ph}, **prev_action_ph),
                )
            else:
                (
                    self._input_dict,
                    self._dummy_batch,
                ) = self._get_input_dict_and_dummy_batch(self.view_requirements, {})
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(True, (), name="is_exploring")

        # Placeholder for `is_training` flag.
        self._input_dict.set_training(self._get_is_training_placeholder())

        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        extra_action_fetches = {}
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if action_sampler_fn:
                action_sampler_outputs = action_sampler_fn(
                    self,
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training,
                )
                if len(action_sampler_outputs) == 4:
                    (
                        sampled_action,
                        sampled_action_logp,
                        dist_inputs,
                        self._state_out,
                    ) = action_sampler_outputs
                else:
                    dist_inputs = None
                    self._state_out = []
                    sampled_action, sampled_action_logp = action_sampler_outputs
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if action_distribution_fn:
                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    try:
                        (
                            dist_inputs,
                            dist_class,
                            self._state_out,
                        ) = action_distribution_fn(
                            self,
                            self.model,
                            input_dict=in_dict,
                            state_batches=self._state_inputs,
                            seq_lens=self._seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=in_dict.is_training,
                        )
                    # Trying the old way (to stay backward compatible).
                    # TODO: Remove in future.
                    except TypeError as e:
                        if (
                            "positional argument" in e.args[0]
                            or "unexpected keyword argument" in e.args[0]
                        ):
                            (
                                dist_inputs,
                                dist_class,
                                self._state_out,
                            ) = action_distribution_fn(
                                self,
                                self.model,
                                obs_batch=in_dict[SampleBatch.CUR_OBS],
                                state_batches=self._state_inputs,
                                seq_lens=self._seq_lens,
                                prev_action_batch=in_dict.get(
                                    SampleBatch.PREV_ACTIONS
                                ),
                                prev_reward_batch=in_dict.get(
                                    SampleBatch.PREV_REWARDS
                                ),
                                explore=explore,
                                is_training=in_dict.is_training,
                            )
                        else:
                            raise e

                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        dist_inputs, self._state_out, extra_action_fetches = self.model(
                            self._input_dict
                        )
                    else:
                        dist_inputs, self._state_out = self.model(self._input_dict)

                action_dist = dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                (
                    sampled_action,
                    sampled_action_logp,
                ) = self.exploration.get_exploration_action(
                    action_distribution=action_dist, timestep=timestep, explore=explore
                )

        if dist_inputs is not None:
            extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        if sampled_action_logp is not None:
            extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
            extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
                tf.cast(sampled_action_logp, tf.float32)
            )

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session(
            config=tf1.ConfigProto(**self.config["tf_session_args"])
        )

        batch_divisibility_req = (
            get_batch_divisibility_req(self)
            if callable(get_batch_divisibility_req)
            else (get_batch_divisibility_req or 1)
        )

        prev_action_input = (
            self._input_dict[SampleBatch.PREV_ACTIONS]
            if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys
            else None
        )
        prev_reward_input = (
            self._input_dict[SampleBatch.PREV_REWARDS]
            if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys
            else None
        )

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=prev_action_input,
            prev_reward_input=prev_reward_input,
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep,
        )

        # Phase 2 init.
        if before_loss_init is not None:
            before_loss_init(self, obs_space, action_space, config)

        if hasattr(self, "_extra_action_fetches"):
            self._extra_action_fetches.update(extra_action_fetches)
        else:
            self._extra_action_fetches = extra_action_fetches

        # Loss initialization and model/postprocessing test calls.
        if not self._is_tower:
            self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)

            # Create MultiGPUTowerStacks, if we have at least one actual
            # GPU or >1 CPUs (fake GPUs).
            if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
                # Per-GPU graph copies created here must share vars with the
                # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
                # Adam nodes are created after all of the device copies are
                # created.
                with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                    self.multi_gpu_tower_stacks = [
                        TFMultiGPUTowerStack(policy=self)
                        for i in range(
                            self.config.get("num_multi_gpu_tower_stacks", 1)
                        )
                    ]

            # Initialize again after loss and tower init.
            self.get_session().run(tf1.global_variables_initializer())

    @override(TFPolicy)
    @DeveloperAPI
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
        """Creates a copy of self using existing input placeholders."""
        flat_loss_inputs = tree.flatten(self._loss_input_dict)
        flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)

        # Note that there might be RNN state inputs at the end of the list.
        if len(flat_loss_inputs) != len(existing_inputs):
            raise ValueError(
                "Tensor list mismatch",
                self._loss_input_dict,
                self._state_inputs,
                existing_inputs,
            )
        for i, v in enumerate(flat_loss_inputs_no_rnn):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError(
                    "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
                )

        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor.
        rnn_inputs = []
        for i in range(len(self._state_inputs)):
            rnn_inputs.append(
                (
                    "state_in_{}".format(i),
                    existing_inputs[len(flat_loss_inputs_no_rnn) + i],
                )
            )
        if rnn_inputs:
            rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))

        existing_inputs_unflattened = tree.unflatten_as(
            self._loss_input_dict_no_rnn,
            existing_inputs[: len(flat_loss_inputs_no_rnn)],
        )
        input_dict = OrderedDict(
            [("is_exploring", self._is_exploring), ("timestep", self._timestep)]
            + [
                (k, existing_inputs_unflattened[k])
                for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
            ]
            + rnn_inputs
        )

        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict,
            existing_model=[
                self.model,
                # Deprecated: Target models should all reside under
                # `policy.target_model` now.
                ("target_q_model", getattr(self, "target_q_model", None)),
                ("target_model", getattr(self, "target_model", None)),
            ],
        )

        instance._loss_input_dict = input_dict
        losses = instance._do_loss_init(SampleBatch(input_dict))
        loss_inputs = [
            (k, existing_inputs_unflattened[k])
            for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
        ]

        TFPolicy._initialize_loss(instance, losses, loss_inputs)
        if instance._grad_stats_fn:
            instance._stats_fetches.update(
                instance._grad_stats_fn(instance, input_dict, instance._grads)
            )
        return instance

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        if self.model:
            return self.model.get_initial_state()
        else:
            return []

    @override(Policy)
    @DeveloperAPI
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            self._loaded_single_cpu_batch = batch
            return len(batch)

        input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
        data_keys = tree.flatten(self._loss_input_dict_no_rnn)
        if self._state_inputs:
            state_keys = self._state_inputs + [self._seq_lens]
        else:
            state_keys = []
        inputs = [input_dict[k] for k in data_keys]
        state_inputs = [input_dict[k] for k in state_keys]

        return self.multi_gpu_tower_stacks[buffer_index].load_data(
            sess=self.get_session(),
            inputs=inputs,
            state_inputs=state_inputs,
            num_grad_updates=batch.num_grad_updates,
        )

    @override(Policy)
    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            return (
                len(self._loaded_single_cpu_batch)
                if self._loaded_single_cpu_batch is not None
                else 0
            )

        return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded

    @override(Policy)
    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            if self._loaded_single_cpu_batch is None:
                raise ValueError(
                    "Must call Policy.load_batch_into_buffer() before "
                    "Policy.learn_on_loaded_batch()!"
                )
            # Get the correct slice of the already loaded batch to use,
            # based on offset and batch size.
            batch_size = self.config.get(
                "sgd_minibatch_size", self.config["train_batch_size"]
            )
            if batch_size >= len(self._loaded_single_cpu_batch):
                sliced_batch = self._loaded_single_cpu_batch
            else:
                sliced_batch = self._loaded_single_cpu_batch.slice(
                    start=offset, end=offset + batch_size
                )
            return self.learn_on_batch(sliced_batch)

        tower_stack = self.multi_gpu_tower_stacks[buffer_index]
        results = tower_stack.optimize(self.get_session(), offset)
        self.num_grad_updates += 1

        results.update(
            {
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update
                # above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0)
                ),
            }
        )

        return results
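
    # Illustrative (hedged) sketch of how the three buffer methods above are
    # typically driven; `policy` and `train_batch` are hypothetical stand-ins
    # for objects created elsewhere. The single-device code path is shown;
    # on the multi-GPU path, `num_loaded` is a per-device count and the loop
    # steps by the per-device minibatch size instead.
    #
    #     num_loaded = policy.load_batch_into_buffer(train_batch, buffer_index=0)
    #     minibatch_size = policy.config.get(
    #         "sgd_minibatch_size", policy.config["train_batch_size"]
    #     )
    #     for offset in range(0, num_loaded, minibatch_size):
    #         results = policy.learn_on_loaded_batch(offset, buffer_index=0)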

    def _get_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.

        Args:
            view_requirements: The view requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            # This is a tower: Input placeholders already exist.
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others.
            else:
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    # Do not flatten actions if action flattening disabled.
                    if self.config.get("_disable_action_flattening") and view_col in [
                        SampleBatch.ACTIONS,
                        SampleBatch.PREV_ACTIONS,
                    ]:
                        flatten = False
                    # Do not flatten observations if no preprocessor API used.
                    elif (
                        view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                        and self.config["_disable_preprocessor_api"]
                    ):
                        flatten = False
                    # Flatten everything else.
                    else:
                        flatten = True
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis,
                        flatten=flatten,
                    )
        dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
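
    # For illustration (hypothetical spaces/shapes): with a Box(4,) observation
    # space and Discrete(2) actions, the returned input_dict would map e.g.
    # "obs" to a float32 placeholder of shape [None, 4] and "actions" to an
    # integer placeholder of shape [None], while dummy_batch maps the same keys
    # to small constant-filled numpy arrays with a leading batch dimension
    # of 32.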

    @override(Policy)
    def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True, stats_fn=None
    ) -> None:
        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0]

        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (
                not key.startswith("state_in_")
                and key not in self._input_dict.accessed_keys
            ):
                view_req.used_for_compute_actions = False
        for key, value in self._extra_action_fetches.items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
                ),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=value.shape.as_list()[1:],
                        dtype=value.dtype.name,
                    ),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=dummy_batch[key], name=key
                )
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
            )

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)
                )
            )

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = (
            train_batch.accessed_keys
            | dummy_batch.accessed_keys
            | dummy_batch.added_keys
            | set(self.model.view_requirements.keys())
        )

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
            + (
                [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
                if SampleBatch.SEQ_LENS in train_batch
                else []
            ),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads)
            )

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (
                    key not in train_batch.accessed_keys
                    and key not in self.model.view_requirements
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                        SampleBatch.OBS_EMBEDS,
                    ]
                ):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS,
            # no matter what.
            for key in list(self.view_requirements.keys()):
                if (
                    key not in all_accessed_keys
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                    ]
                    and key not in self.model.view_requirements
                ):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key)
                        )
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (
                    vr.data_col is not None
                    and vr.data_col not in self.view_requirements
                ):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training
                    )

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

    def _do_loss_init(self, train_batch: SampleBatch):
        losses = self._loss_fn(self, self.model, self.dist_class, train_batch)
        losses = force_list(losses)
        if self._stats_fn:
            self._stats_fetches.update(self._stats_fn(self, train_batch))
        # Override the update ops to be those of the model.
        self._update_ops = []
        if not isinstance(self.model, tf.keras.Model):
            self._update_ops = self.model.update_ops()
        return losses


@DeveloperAPI
class TFMultiGPUTowerStack:
    """Optimizer that runs in parallel across multiple local devices.

    TFMultiGPUTowerStack automatically splits up and loads training data
    onto specified local devices (e.g. GPUs) with `load_data()`. During a call
    to `optimize()`, the devices compute gradients over slices of the data in
    parallel. The gradients are then averaged and applied to the shared
    weights.

    The data loaded is pinned in device memory until the next call to
    `load_data`, so you can make multiple passes (possibly in randomized order)
    over the same data once loaded.

    This is similar to tf1.train.SyncReplicasOptimizer, but works within a
    single TensorFlow graph, i.e. implements in-graph replicated training:
    https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
    """

    def __init__(
        self,
        # Deprecated.
        optimizer=None,
        devices=None,
        input_placeholders=None,
        rnn_inputs=None,
        max_per_device_batch_size=None,
        build_graph=None,
        grad_norm_clipping=None,
        # Use only `policy` argument from here on.
        policy: TFPolicy = None,
    ):
        """Initializes a TFMultiGPUTowerStack instance.

        Args:
            policy: The TFPolicy object that this tower stack
                belongs to.
        """
        # Obsoleted usage, use only `policy` arg from here on.
        if policy is None:
            deprecation_warning(
                old="TFMultiGPUTowerStack(...)",
                new="TFMultiGPUTowerStack(policy=[Policy])",
                error=True,
            )
            self.policy = None
            self.optimizers = optimizer
            self.devices = devices
            self.max_per_device_batch_size = max_per_device_batch_size
            self.policy_copy = build_graph
        else:
            self.policy: TFPolicy = policy
            self.optimizers: List[LocalOptimizer] = self.policy._optimizers
            self.devices = self.policy.devices
            self.max_per_device_batch_size = (
                max_per_device_batch_size
                or policy.config.get(
                    "sgd_minibatch_size", policy.config.get("train_batch_size", 999999)
                )
            ) // len(self.devices)
            input_placeholders = tree.flatten(self.policy._loss_input_dict_no_rnn)
            rnn_inputs = []
            if self.policy._state_inputs:
                rnn_inputs = self.policy._state_inputs + [self.policy._seq_lens]
            grad_norm_clipping = self.policy.config.get("grad_clip")
            self.policy_copy = self.policy.copy

        assert len(self.devices) > 1 or "gpu" in self.devices[0]

        self.loss_inputs = input_placeholders + rnn_inputs

        shared_ops = tf1.get_collection(
            tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
        )

        # Then setup the per-device loss graphs that use the shared weights.
        self._batch_index = tf1.placeholder(tf.int32, name="batch_index")

        # Dynamic batch size, which may be shrunk if there isn't enough data.
        self._per_device_batch_size = tf1.placeholder(
            tf.int32, name="per_device_batch_size"
        )
        self._loaded_per_device_batch_size = max_per_device_batch_size

        # When loading RNN input, we dynamically determine the max seq len.
        self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len")
        self._loaded_max_seq_len = 1

        device_placeholders = [[] for _ in range(len(self.devices))]

        for t in tree.flatten(self.loss_inputs):
            # Split on the CPU in case the data doesn't fit in GPU memory.
            with tf.device("/cpu:0"):
                splits = tf.split(t, len(self.devices))
            for i, d in enumerate(self.devices):
                device_placeholders[i].append(splits[i])

        self._towers = []
        for tower_i, (device, placeholders) in enumerate(
            zip(self.devices, device_placeholders)
        ):
            self._towers.append(
                self._setup_device(
                    tower_i, device, placeholders, len(tree.flatten(input_placeholders))
                )
            )

        if self.policy.config["_tf_policy_handles_more_than_one_loss"]:
            avgs = []
            for i, optim in enumerate(self.optimizers):
                avg = _average_gradients([t.grads[i] for t in self._towers])
                if grad_norm_clipping:
                    clipped = []
                    for grad, _ in avg:
                        clipped.append(grad)
                    clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
                    for i, (grad, var) in enumerate(avg):
                        avg[i] = (clipped[i], var)
                avgs.append(avg)

            # Gather update ops for any batch norm layers.
            # TODO(ekl): here we will use all the ops found, which won't work
            # for DQN / DDPG, but those aren't supported with multi-gpu right
            # now anyways.
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
            )
            for op in shared_ops:
                self._update_ops.remove(op)  # only care about tower update ops
            if self._update_ops:
                logger.debug(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )

            with tf1.control_dependencies(self._update_ops):
                self._train_op = tf.group(
                    [o.apply_gradients(a) for o, a in zip(self.optimizers, avgs)]
                )
        else:
            avg = _average_gradients([t.grads for t in self._towers])
            if grad_norm_clipping:
                clipped = []
                for grad, _ in avg:
                    clipped.append(grad)
                clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
                for i, (grad, var) in enumerate(avg):
                    avg[i] = (clipped[i], var)

            # Gather update ops for any batch norm layers.
            # TODO(ekl): here we will use all the ops found, which won't work
            # for DQN / DDPG, but those aren't supported with multi-gpu right
            # now anyways.
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
            )
            for op in shared_ops:
                self._update_ops.remove(op)  # only care about tower update ops
            if self._update_ops:
                logger.debug(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )

            with tf1.control_dependencies(self._update_ops):
                self._train_op = self.optimizers[0].apply_gradients(avg)

        # The lifetime number of gradient updates that the policy having sent
        # some data (SampleBatchType) into this tower stack's GPU buffer(s)
        # has already undergone.
        self.num_grad_updates = 0

    def load_data(self, sess, inputs, state_inputs, num_grad_updates=None):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session.
            inputs: List of arrays matching the input placeholders, of shape
                [BATCH_SIZE, ...].
            state_inputs: List of RNN input arrays. These arrays have size
                [BATCH_SIZE / MAX_SEQ_LEN, ...].
            num_grad_updates: The lifetime number of gradient updates that the
                policy having collected the data has already undergone.

        Returns:
            The number of tuples loaded per device.
        """
        self.num_grad_updates = num_grad_updates

        if log_once("load_data"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(
                        {
                            "placeholders": self.loss_inputs,
                            "inputs": inputs,
                            "state_inputs": state_inputs,
                        }
                    )
                )
            )

        feed_dict = {}
        assert len(self.loss_inputs) == len(inputs + state_inputs), (
            self.loss_inputs,
            inputs,
            state_inputs,
        )

        # Let's suppose we have the following input data, and 2 devices:
        # 1 2 3 4 5 6 7                              <- state inputs shape
        # A A A B B B C C C D D D E E E F F F G G G  <- inputs shape
        # The data is truncated and split across devices as follows:
        # |---|                                 seq len = 3
        # |---------------------------------|   seq batch size = 6 seqs
        # |----------------|                    per device batch size = 9 tuples
        if len(state_inputs) > 0:
            smallest_array = state_inputs[0]
            seq_len = len(inputs[0]) // len(state_inputs[0])
            self._loaded_max_seq_len = seq_len
        else:
            smallest_array = inputs[0]
            self._loaded_max_seq_len = 1

        sequences_per_minibatch = (
            self.max_per_device_batch_size
            // self._loaded_max_seq_len
            * len(self.devices)
        )
        if sequences_per_minibatch < 1:
            logger.warning(
                (
                    "Target minibatch size is {}, however the rollout sequence "
                    "length is {}, hence the minibatch size will be raised to "
                    "{}."
                ).format(
                    self.max_per_device_batch_size,
                    self._loaded_max_seq_len,
                    self._loaded_max_seq_len * len(self.devices),
                )
            )
            sequences_per_minibatch = 1

        if len(smallest_array) < sequences_per_minibatch:
            # Dynamically shrink the batch size if insufficient data.
            sequences_per_minibatch = _make_divisible_by(
                len(smallest_array), len(self.devices)
            )

        if log_once("data_slicing"):
            logger.info(
                (
                    "Divided {} rollout sequences, each of length {}, among "
                    "{} devices."
                ).format(
                    len(smallest_array), self._loaded_max_seq_len, len(self.devices)
                )
            )

        if sequences_per_minibatch < len(self.devices):
            raise ValueError(
                "Must load at least 1 tuple sequence per device. Try "
                "increasing `sgd_minibatch_size` or reducing `max_seq_len` "
                "to ensure that at least one sequence fits per device."
            )
        self._loaded_per_device_batch_size = (
            sequences_per_minibatch // len(self.devices) * self._loaded_max_seq_len
        )

        if len(state_inputs) > 0:
            # First truncate the RNN state arrays to the sequences_per_minib.
            state_inputs = [
                _make_divisible_by(arr, sequences_per_minibatch)
                for arr in state_inputs
            ]
            # Then truncate the data inputs to match.
            inputs = [arr[: len(state_inputs[0]) * seq_len] for arr in inputs]
            assert len(state_inputs[0]) * seq_len == len(inputs[0]), (
                len(state_inputs[0]),
                sequences_per_minibatch,
                seq_len,
                len(inputs[0]),
            )
            for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
                feed_dict[ph] = arr
            truncated_len = len(inputs[0])
        else:
            truncated_len = 0
            for ph, arr in zip(self.loss_inputs, inputs):
                truncated_arr = _make_divisible_by(arr, sequences_per_minibatch)
                feed_dict[ph] = truncated_arr
                if truncated_len == 0:
                    truncated_len = len(truncated_arr)

        sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

        self.num_tuples_loaded = truncated_len
        samples_per_device = truncated_len // len(self.devices)
        assert samples_per_device > 0, "No data loaded?"
        assert samples_per_device % self._loaded_per_device_batch_size == 0

        # Return loaded samples per-device.
        return samples_per_device
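
    # Worked example of the arithmetic above (hypothetical numbers): with
    # 2 devices, max_per_device_batch_size=64 and RNN sequences of length 16,
    # sequences_per_minibatch = 64 // 16 * 2 = 8, and each device then holds
    # minibatches of 8 // 2 * 16 = 64 timesteps; any trailing data that does
    # not form a full set of sequences across both devices is discarded.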

    def optimize(self, sess, batch_index):
        """Runs a single step of SGD.

        Runs an SGD step over a slice of the preloaded batch with size given by
        self._loaded_per_device_batch_size and offset given by the batch_index
        argument.

        Updates shared model weights based on the averaged per-device
        gradients.

        Args:
            sess: TensorFlow session.
            batch_index: Offset into the preloaded data. This value must be
                between `0` and `tuples_per_device`. The amount of data to
                process is at most `max_per_device_batch_size`.

        Returns:
            The outputs of extra_ops evaluated over the batch.
        """
        feed_dict = {
            self._batch_index: batch_index,
            self._per_device_batch_size: self._loaded_per_device_batch_size,
            self._max_seq_len: self._loaded_max_seq_len,
        }
        for tower in self._towers:
            feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())

        fetches = {"train": self._train_op}
        for tower_num, tower in enumerate(self._towers):
            tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
            fetches["tower_{}".format(tower_num)] = tower_fetch

        return sess.run(fetches, feed_dict=feed_dict)
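
    # Illustrative offsets (hypothetical numbers): if `load_data()` reported
    # 192 tuples per device and `_loaded_per_device_batch_size` is 64, the
    # valid `batch_index` values passed to `optimize()` are 0, 64 and 128,
    # each selecting one per-device slice of the pinned data.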

    def get_device_losses(self):
        return [t.loss_graph for t in self._towers]

    def _setup_device(self, tower_i, device, device_input_placeholders, num_data_in):
        assert num_data_in <= len(device_input_placeholders)
        with tf.device(device):
            with tf1.name_scope(TOWER_SCOPE_NAME + f"_{tower_i}"):
                device_input_batches = []
                device_input_slices = []
                for i, ph in enumerate(device_input_placeholders):
                    current_batch = tf1.Variable(
                        ph, trainable=False, validate_shape=False, collections=[]
                    )
                    device_input_batches.append(current_batch)
                    if i < num_data_in:
                        scale = self._max_seq_len
                        granularity = self._max_seq_len
                    else:
                        scale = self._max_seq_len
                        granularity = 1
                    current_slice = tf.slice(
                        current_batch,
                        (
                            [self._batch_index // scale * granularity]
                            + [0] * len(ph.shape[1:])
                        ),
                        (
                            [self._per_device_batch_size // scale * granularity]
                            + [-1] * len(ph.shape[1:])
                        ),
                    )
                    current_slice.set_shape(ph.shape)
                    device_input_slices.append(current_slice)
                graph_obj = self.policy_copy(device_input_slices)
                device_grads = graph_obj.gradients(self.optimizers, graph_obj._losses)
            return _Tower(
                tf.group(*[batch.initializer for batch in device_input_batches]),
                device_grads,
                graph_obj,
            )


# Each tower is a copy of the loss graph pinned to a specific device.
_Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"])


def _make_divisible_by(a, n):
    if type(a) is int:
        return a - a % n
    return a[0 : a.shape[0] - a.shape[0] % n]
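

# Illustrative example: _make_divisible_by(10, 4) returns 8, and for a
# length-10 numpy array the first 8 rows are kept, so that the result splits
# evenly across e.g. 4 devices or sequences.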


def _average_gradients(tower_grads):
    """Averages gradients across towers.

    Calculate the average gradient for each shared variable across all towers.
    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over the individual towers. The inner list is over the
            (gradient, variable) pairs computed by each tower.

    Returns:
        List of pairs of (gradient, variable) where the gradient has been
            averaged across all towers.

    TODO(ekl): We could use NCCL if this becomes a bottleneck.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            if g is not None:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)
                # Append on a 'tower' dimension which we will average over
                # below.
                grads.append(expanded_g)
        if not grads:
            continue

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)

    return average_grads
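

# Illustrative (hedged) example: with two towers and two shared variables,
#   tower_grads = [[(g0_a, v0), (g1_a, v1)], [(g0_b, v0), (g1_b, v1)]]
# _average_gradients(tower_grads) returns
#   [(mean(g0_a, g0_b), v0), (mean(g1_a, g1_b), v1)],
# i.e. one averaged gradient per shared variable, reusing the first tower's
# variable handles.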