dynamic_tf_policy.py

from collections import namedtuple, OrderedDict
import gym
import logging
import re
from typing import Callable, Dict, List, Optional, Tuple, Type

from ray.util.debug import log_once
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import get_placeholder
from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \
    TensorType, TrainerConfigDict

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

# Variable scope under which created variables will be placed.
TOWER_SCOPE_NAME = "tower"


@DeveloperAPI
class DynamicTFPolicy(TFPolicy):
    """A TFPolicy that auto-defines placeholders dynamically at runtime.

    Do not sub-class this class directly (neither should you sub-class
    TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy
    to generate your custom tf (graph-mode or eager) Policy classes.
    """

    @DeveloperAPI
    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict,
            loss_fn: Callable[[
                Policy, ModelV2, Type[TFActionDistribution], SampleBatch
            ], TensorType],
            *,
            stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
                str, TensorType]]] = None,
            grad_stats_fn: Optional[Callable[[
                Policy, SampleBatch, ModelGradients
            ], Dict[str, TensorType]]] = None,
            before_loss_init: Optional[Callable[[
                Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
            ], None]] = None,
            make_model: Optional[Callable[[
                Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
            ], ModelV2]] = None,
            action_sampler_fn: Optional[Callable[[
                TensorType, List[TensorType]
            ], Tuple[TensorType, TensorType]]] = None,
            action_distribution_fn: Optional[Callable[[
                Policy, ModelV2, TensorType, TensorType, TensorType
            ], Tuple[TensorType, type, List[TensorType]]]] = None,
            existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
            existing_model: Optional[ModelV2] = None,
            get_batch_divisibility_req: Optional[Callable[[Policy],
                                                          int]] = None,
            obs_include_prev_action_reward=DEPRECATED_VALUE):
        """Initializes a DynamicTFPolicy instance.

        Initialization of this class occurs in two phases and defines the
        static graph.

        Phase 1: The model is created and model variables are initialized.

        Phase 2: A fake batch of data is created, sent to the trajectory
        postprocessor, and then used to create placeholders for the loss
        function. The loss and stats functions are initialized with these
        placeholders.

        Args:
            obs_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            loss_fn: Function that returns a loss tensor for the policy graph.
            stats_fn: Optional callable that - given the policy and batch
                input tensors - returns a dict mapping str to TF ops.
                These ops are fetched from the graph after loss calculations
                and the resulting values can be found in the results dict
                returned by e.g. `Trainer.train()` or in tensorboard (if TB
                logging is enabled).
            grad_stats_fn: Optional callable that - given the policy, batch
                input tensors, and calculated loss gradient tensors - returns
                a dict mapping str to TF ops. These ops are fetched from the
                graph after loss and gradient calculations and the resulting
                values can be found in the results dict returned by e.g.
                `Trainer.train()` or in tensorboard (if TB logging is
                enabled).
            before_loss_init: Optional function to run prior to
                loss init that takes the same arguments as __init__.
            make_model: Optional function that returns a ModelV2 object
                given policy, obs_space, action_space, and policy config.
                All policy variables should be created in this function. If
                not specified, a default model will be created.
            action_sampler_fn: A callable returning a sampled action and its
                log-likelihood given Policy, ModelV2, observation inputs,
                explore, and is_training.
                Provide `action_sampler_fn` if you would like to have full
                control over the action computation step, including the
                model forward pass, possible sampling from a distribution,
                and exploration logic.
                Note: If `action_sampler_fn` is given,
                `action_distribution_fn` must be None. If both
                `action_sampler_fn` and `action_distribution_fn` are None,
                RLlib will simply pass inputs through `self.model` to get
                distribution inputs, create the distribution object, sample
                from it, and apply some exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, obs_batch,
                state_batches (optional), seq_lens (optional),
                prev_actions_batch (optional), prev_rewards_batch (optional),
                explore, and is_training.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given,
                `action_sampler_fn` must be None. If both `action_sampler_fn`
                and `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            existing_inputs: When copying a policy, this specifies an existing
                dict of placeholders to use instead of defining new ones.
            existing_model: When copying a policy, this specifies an existing
                model to clone and share weights with.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches. If None, will
                assume a value of 1.
        """
        if obs_include_prev_action_reward != DEPRECATED_VALUE:
            deprecation_warning(
                old="obs_include_prev_action_reward", error=False)
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._seq_lens = None
        self._is_tower = existing_inputs is not None

        dist_class = None
        if action_sampler_fn or action_distribution_fn:
            if not make_model:
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given")
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        # Setup self.model.
        if existing_model:
            if isinstance(existing_model, list):
                self.model = existing_model[0]
                # TODO: (sven) hack, but works for `target_[q_]?model`.
                for i in range(1, len(existing_model)):
                    setattr(self, existing_model[i][0], existing_model[i][1])
        elif make_model:
            self.model = make_model(self, obs_space, action_space, config)
        else:
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="tf")
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        # Input placeholders already given -> Use these.
        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
        # Create new input placeholders.
        else:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                ) for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = tf1.placeholder(
                    dtype=tf.int32, shape=[None], name="seq_lens")

        # Use default settings.
        # Add NEXT_OBS, STATE_IN_0.., and others.
        self.view_requirements = self._get_default_view_requirements()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)
        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_training = False

        # Setup standard placeholders.
        if self._is_tower:
            timestep = existing_inputs["timestep"]
            explore = False
            self._input_dict, self._dummy_batch = \
                self._get_input_dict_and_dummy_batch(
                    self.view_requirements, existing_inputs)
        else:
            action_ph = ModelCatalog.get_action_placeholder(action_space)
            prev_action_ph = {}
            if SampleBatch.PREV_ACTIONS not in self.view_requirements:
                prev_action_ph = {
                    SampleBatch.PREV_ACTIONS: ModelCatalog.
                    get_action_placeholder(action_space, "prev_action")
                }
            self._input_dict, self._dummy_batch = \
                self._get_input_dict_and_dummy_batch(
                    self.view_requirements,
                    dict({SampleBatch.ACTIONS: action_ph},
                         **prev_action_ph))
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep")
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(
                True, (), name="is_exploring")

        # Placeholder for `is_training` flag.
        self._input_dict.set_training(self._get_is_training_placeholder())

        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if action_sampler_fn:
                sampled_action, sampled_action_logp = action_sampler_fn(
                    self,
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(
                        SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(
                        SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training)
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if action_distribution_fn:
                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    try:
                        dist_inputs, dist_class, self._state_out = \
                            action_distribution_fn(
                                self,
                                self.model,
                                input_dict=in_dict,
                                state_batches=self._state_inputs,
                                seq_lens=self._seq_lens,
                                explore=explore,
                                timestep=timestep,
                                is_training=in_dict.is_training)
                    # Trying the old way (to stay backward compatible).
                    # TODO: Remove in future.
                    except TypeError as e:
                        if "positional argument" in e.args[0] or \
                                "unexpected keyword argument" in e.args[0]:
                            dist_inputs, dist_class, self._state_out = \
                                action_distribution_fn(
                                    self, self.model,
                                    obs_batch=in_dict[SampleBatch.CUR_OBS],
                                    state_batches=self._state_inputs,
                                    seq_lens=self._seq_lens,
                                    prev_action_batch=in_dict.get(
                                        SampleBatch.PREV_ACTIONS),
                                    prev_reward_batch=in_dict.get(
                                        SampleBatch.PREV_REWARDS),
                                    explore=explore,
                                    is_training=in_dict.is_training)
                        else:
                            raise e

                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        dist_inputs, self._state_out, \
                            self._extra_action_fetches = \
                            self.model(self._input_dict)
                    else:
                        dist_inputs, self._state_out = self.model(
                            self._input_dict, self._state_inputs,
                            self._seq_lens)

                action_dist = dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                sampled_action, sampled_action_logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session(
            config=tf1.ConfigProto(**self.config["tf_session_args"]))

        batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=self._input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_input=self._input_dict.get(SampleBatch.PREV_REWARDS),
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep)

        # Phase 2 init.
        if before_loss_init is not None:
            before_loss_init(self, obs_space, action_space, config)

        # Loss initialization and model/postprocessing test calls.
        if not self._is_tower:
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True)

            # Create MultiGPUTowerStacks, if we have at least one actual
            # GPU or >1 CPUs (fake GPUs).
            if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
                # Per-GPU graph copies created here must share vars with the
                # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
                # Adam nodes are created after all of the device copies are
                # created.
                with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                    self.multi_gpu_tower_stacks = [
                        TFMultiGPUTowerStack(policy=self) for i in range(
                            self.config.get("num_multi_gpu_tower_stacks", 1))
                    ]

            # Initialize again after loss and tower init.
            self.get_session().run(tf1.global_variables_initializer())

    @override(TFPolicy)
    @DeveloperAPI
    def copy(self,
             existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
        """Creates a copy of self using existing input placeholders."""

        # Note that there might be RNN state inputs at the end of the list.
        if len(self._loss_input_dict) != len(existing_inputs):
            raise ValueError("Tensor list mismatch", self._loss_input_dict,
                             self._state_inputs, existing_inputs)
        for i, (k, v) in enumerate(self._loss_input_dict_no_rnn.items()):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError("Tensor shape mismatch", i, k, v.shape,
                                 existing_inputs[i].shape)

        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor.
        rnn_inputs = []
        for i in range(len(self._state_inputs)):
            rnn_inputs.append(
                ("state_in_{}".format(i),
                 existing_inputs[len(self._loss_input_dict_no_rnn) + i]))
        if rnn_inputs:
            rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))

        input_dict = OrderedDict(
            [("is_exploring", self._is_exploring), ("timestep",
                                                    self._timestep)] +
            [(k, existing_inputs[i])
             for i, k in enumerate(self._loss_input_dict_no_rnn.keys())] +
            rnn_inputs)

        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict,
            existing_model=[
                self.model,
                # Deprecated: Target models should all reside under
                # `policy.target_model` now.
                ("target_q_model", getattr(self, "target_q_model", None)),
                ("target_model", getattr(self, "target_model", None)),
            ])

        instance._loss_input_dict = input_dict
        losses = instance._do_loss_init(SampleBatch(input_dict))
        loss_inputs = [
            (k, existing_inputs[i])
            for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
        ]

        TFPolicy._initialize_loss(instance, losses, loss_inputs)
        if instance._grad_stats_fn:
            instance._stats_fetches.update(
                instance._grad_stats_fn(instance, input_dict,
                                        instance._grads))
        return instance

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        if self.model:
            return self.model.get_initial_state()
        else:
            return []

    @override(Policy)
    @DeveloperAPI
    def load_batch_into_buffer(
            self,
            batch: SampleBatch,
            buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            self._loaded_single_cpu_batch = batch
            return len(batch)

        input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
        data_keys = list(self._loss_input_dict_no_rnn.values())
        if self._state_inputs:
            state_keys = self._state_inputs + [self._seq_lens]
        else:
            state_keys = []
        inputs = [input_dict[k] for k in data_keys]
        state_inputs = [input_dict[k] for k in state_keys]

        return self.multi_gpu_tower_stacks[buffer_index].load_data(
            sess=self.get_session(),
            inputs=inputs,
            state_inputs=state_inputs,
        )

    @override(Policy)
    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            return len(self._loaded_single_cpu_batch) if \
                self._loaded_single_cpu_batch is not None else 0

        return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded

    @override(Policy)
    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            if self._loaded_single_cpu_batch is None:
                raise ValueError(
                    "Must call Policy.load_batch_into_buffer() before "
                    "Policy.learn_on_loaded_batch()!")
            # Get the correct slice of the already loaded batch to use,
            # based on offset and batch size.
            batch_size = self.config.get("sgd_minibatch_size",
                                         self.config["train_batch_size"])
            if batch_size >= len(self._loaded_single_cpu_batch):
                sliced_batch = self._loaded_single_cpu_batch
            else:
                sliced_batch = self._loaded_single_cpu_batch.slice(
                    start=offset, end=offset + batch_size)
            return self.learn_on_batch(sliced_batch)

        return self.multi_gpu_tower_stacks[buffer_index].optimize(
            self.get_session(), offset)

    def _get_input_dict_and_dummy_batch(self, view_requirements,
                                        existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.

        Args:
            view_requirements (ViewReqs): The view requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            # This is a tower: Input placeholders already exist.
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others.
            else:
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    flatten = view_col not in [
                        SampleBatch.OBS, SampleBatch.NEXT_OBS] or \
                        not self.config["_disable_preprocessor_api"]
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis,
                        flatten=flatten,
                    )
        dummy_batch = self._get_dummy_batch_from_view_requirements(
            batch_size=32)

        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch

    @override(Policy)
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0]

        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = value
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=dummy_batch[key], name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update({
                SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]
            })

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(self, losses, [
            (k, v) for k, v in train_batch.items() if k in all_accessed_keys
        ] + ([(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
             if SampleBatch.SEQ_LENS in train_batch else []))

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some
            # exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

    def _do_loss_init(self, train_batch: SampleBatch):
        losses = self._loss_fn(self, self.model, self.dist_class, train_batch)
        losses = force_list(losses)
        if self._stats_fn:
            self._stats_fetches.update(self._stats_fn(self, train_batch))
        # Override the update ops to be those of the model.
        self._update_ops = []
        if not isinstance(self.model, tf.keras.Model):
            self._update_ops = self.model.update_ops()
        return losses


class TFMultiGPUTowerStack:
    """Optimizer that runs in parallel across multiple local devices.

    TFMultiGPUTowerStack automatically splits up and loads training data
    onto specified local devices (e.g. GPUs) with `load_data()`. During a call
    to `optimize()`, the devices compute gradients over slices of the data in
    parallel. The gradients are then averaged and applied to the shared
    weights.

    The data loaded is pinned in device memory until the next call to
    `load_data`, so you can make multiple passes (possibly in randomized
    order) over the same data once loaded.

    This is similar to tf1.train.SyncReplicasOptimizer, but works within a
    single TensorFlow graph, i.e. implements in-graph replicated training:
    https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
    """

    def __init__(
            self,
            # Deprecated.
            optimizer=None,
            devices=None,
            input_placeholders=None,
            rnn_inputs=None,
            max_per_device_batch_size=None,
            build_graph=None,
            grad_norm_clipping=None,
            # Use only `policy` argument from here on.
            policy: TFPolicy = None,
    ):
        """Initializes a TFMultiGPUTowerStack instance.

        Args:
            policy (TFPolicy): The TFPolicy object that this tower stack
                belongs to.
        """
        # Obsoleted usage, use only `policy` arg from here on.
        if policy is None:
            deprecation_warning(
                old="TFMultiGPUTowerStack(...)",
                new="TFMultiGPUTowerStack(policy=[Policy])",
                error=False,
            )
            self.policy = None
            self.optimizers = optimizer
            self.devices = devices
            self.max_per_device_batch_size = max_per_device_batch_size
            self.policy_copy = build_graph
        else:
            self.policy: TFPolicy = policy
            self.optimizers: List[LocalOptimizer] = self.policy._optimizers
            self.devices = self.policy.devices
            self.max_per_device_batch_size = \
                (max_per_device_batch_size or
                 policy.config.get("sgd_minibatch_size", policy.config.get(
                     "train_batch_size", 999999))) // len(self.devices)
            input_placeholders = list(
                self.policy._loss_input_dict_no_rnn.values())
            rnn_inputs = []
            if self.policy._state_inputs:
                rnn_inputs = self.policy._state_inputs + [
                    self.policy._seq_lens
                ]
            grad_norm_clipping = self.policy.config.get("grad_clip")
            self.policy_copy = self.policy.copy

        assert len(self.devices) > 1 or "gpu" in self.devices[0]

        self.loss_inputs = input_placeholders + rnn_inputs

        shared_ops = tf1.get_collection(
            tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name)

        # Then setup the per-device loss graphs that use the shared weights.
        self._batch_index = tf1.placeholder(tf.int32, name="batch_index")

        # Dynamic batch size, which may be shrunk if there isn't enough data.
        self._per_device_batch_size = tf1.placeholder(
            tf.int32, name="per_device_batch_size")
        self._loaded_per_device_batch_size = max_per_device_batch_size

        # When loading RNN input, we dynamically determine the max seq len.
        self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len")
        self._loaded_max_seq_len = 1

        # Split on the CPU in case the data doesn't fit in GPU memory.
        with tf.device("/cpu:0"):
            data_splits = zip(
                *[tf.split(ph, len(self.devices)) for ph in self.loss_inputs])

        self._towers = []
        for tower_i, (device, device_placeholders) in enumerate(
                zip(self.devices, data_splits)):
            self._towers.append(
                self._setup_device(tower_i, device, device_placeholders,
                                   len(input_placeholders)))

        if self.policy.config["_tf_policy_handles_more_than_one_loss"]:
            avgs = []
            for i, optim in enumerate(self.optimizers):
                avg = average_gradients([t.grads[i] for t in self._towers])
                if grad_norm_clipping:
                    clipped = []
                    for grad, _ in avg:
                        clipped.append(grad)
                    clipped, _ = tf.clip_by_global_norm(
                        clipped, grad_norm_clipping)
                    for i, (grad, var) in enumerate(avg):
                        avg[i] = (clipped[i], var)
                avgs.append(avg)

            # Gather update ops for any batch norm layers. TODO(ekl) here we
            # will use all the ops found which won't work for DQN / DDPG, but
            # those aren't supported with multi-gpu right now anyways.
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name)
            for op in shared_ops:
                self._update_ops.remove(op)  # only care about tower update ops
            if self._update_ops:
                logger.debug("Update ops to run on apply gradient: {}".format(
                    self._update_ops))

            with tf1.control_dependencies(self._update_ops):
                self._train_op = tf.group([
                    o.apply_gradients(a)
                    for o, a in zip(self.optimizers, avgs)
                ])
        else:
            avg = average_gradients([t.grads for t in self._towers])
            if grad_norm_clipping:
                clipped = []
                for grad, _ in avg:
                    clipped.append(grad)
                clipped, _ = tf.clip_by_global_norm(clipped,
                                                    grad_norm_clipping)
                for i, (grad, var) in enumerate(avg):
                    avg[i] = (clipped[i], var)

            # Gather update ops for any batch norm layers. TODO(ekl) here we
            # will use all the ops found which won't work for DQN / DDPG, but
            # those aren't supported with multi-gpu right now anyways.
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name)
            for op in shared_ops:
                self._update_ops.remove(op)  # only care about tower update ops
            if self._update_ops:
                logger.debug("Update ops to run on apply gradient: {}".format(
                    self._update_ops))

            with tf1.control_dependencies(self._update_ops):
                self._train_op = self.optimizers[0].apply_gradients(avg)

    def load_data(self, sess, inputs, state_inputs):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session.
            inputs: List of arrays matching the input placeholders, of shape
                [BATCH_SIZE, ...].
            state_inputs: List of RNN input arrays. These arrays have size
                [BATCH_SIZE / MAX_SEQ_LEN, ...].

        Returns:
            The number of tuples loaded per device.
        """
        if log_once("load_data"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize({
                        "placeholders": self.loss_inputs,
                        "inputs": inputs,
                        "state_inputs": state_inputs
                    })))

        feed_dict = {}
        assert len(self.loss_inputs) == len(inputs + state_inputs), \
            (self.loss_inputs, inputs, state_inputs)

        # Let's suppose we have the following input data, and 2 devices:
        # 1 2 3 4 5 6 7                              <- state inputs shape
        # A A A B B B C C C D D D E E E F F F G G G  <- inputs shape
        # The data is truncated and split across devices as follows:
        # |---| seq len = 3
        # |---------------------------------| seq batch size = 6 seqs
        # |----------------| per device batch size = 9 tuples

        if len(state_inputs) > 0:
            smallest_array = state_inputs[0]
            seq_len = len(inputs[0]) // len(state_inputs[0])
            self._loaded_max_seq_len = seq_len
        else:
            smallest_array = inputs[0]
            self._loaded_max_seq_len = 1

        sequences_per_minibatch = (
            self.max_per_device_batch_size // self._loaded_max_seq_len * len(
                self.devices))
        if sequences_per_minibatch < 1:
            logger.warning(
                ("Target minibatch size is {}, however the rollout sequence "
                 "length is {}, hence the minibatch size will be raised to "
                 "{}.").format(self.max_per_device_batch_size,
                               self._loaded_max_seq_len,
                               self._loaded_max_seq_len * len(self.devices)))
            sequences_per_minibatch = 1

        if len(smallest_array) < sequences_per_minibatch:
            # Dynamically shrink the batch size if insufficient data.
            sequences_per_minibatch = make_divisible_by(
                len(smallest_array), len(self.devices))

        if log_once("data_slicing"):
            logger.info(
                ("Divided {} rollout sequences, each of length {}, among "
                 "{} devices.").format(
                     len(smallest_array), self._loaded_max_seq_len,
                     len(self.devices)))

        if sequences_per_minibatch < len(self.devices):
            raise ValueError(
                "Must load at least 1 tuple sequence per device. Try "
                "increasing `sgd_minibatch_size` or reducing `max_seq_len` "
                "to ensure that at least one sequence fits per device.")
        self._loaded_per_device_batch_size = (sequences_per_minibatch // len(
            self.devices) * self._loaded_max_seq_len)

        if len(state_inputs) > 0:
            # First truncate the RNN state arrays to sequences_per_minibatch.
            state_inputs = [
                make_divisible_by(arr, sequences_per_minibatch)
                for arr in state_inputs
            ]
            # Then truncate the data inputs to match.
            inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs]
            assert len(state_inputs[0]) * seq_len == len(inputs[0]), \
                (len(state_inputs[0]), sequences_per_minibatch, seq_len,
                 len(inputs[0]))
            for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
                feed_dict[ph] = arr
            truncated_len = len(inputs[0])
        else:
            truncated_len = 0
            for ph, arr in zip(self.loss_inputs, inputs):
                truncated_arr = make_divisible_by(arr, sequences_per_minibatch)
                feed_dict[ph] = truncated_arr
                if truncated_len == 0:
                    truncated_len = len(truncated_arr)

        sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

        self.num_tuples_loaded = truncated_len
        samples_per_device = truncated_len // len(self.devices)
        assert samples_per_device > 0, "No data loaded?"
        assert samples_per_device % self._loaded_per_device_batch_size == 0

        # Return loaded samples per-device.
        return samples_per_device

    def optimize(self, sess, batch_index):
        """Runs a single step of SGD.

        Runs an SGD step over a slice of the preloaded batch, with size given
        by self._loaded_per_device_batch_size and offset given by the
        batch_index argument.

        Updates shared model weights based on the averaged per-device
        gradients.

        Args:
            sess: TensorFlow session.
            batch_index: Offset into the preloaded data. This value must be
                between `0` and `tuples_per_device`. The amount of data to
                process is at most `max_per_device_batch_size`.

        Returns:
            The outputs of extra_ops evaluated over the batch.
        """
        feed_dict = {
            self._batch_index: batch_index,
            self._per_device_batch_size: self._loaded_per_device_batch_size,
            self._max_seq_len: self._loaded_max_seq_len,
        }
        for tower in self._towers:
            feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())

        fetches = {"train": self._train_op}
        for tower_num, tower in enumerate(self._towers):
            tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
            fetches["tower_{}".format(tower_num)] = tower_fetch

        return sess.run(fetches, feed_dict=feed_dict)

    def get_device_losses(self):
        return [t.loss_graph for t in self._towers]

    def _setup_device(self, tower_i, device, device_input_placeholders,
                      num_data_in):
        assert num_data_in <= len(device_input_placeholders)
        with tf.device(device):
            with tf1.name_scope(TOWER_SCOPE_NAME + f"_{tower_i}"):
                device_input_batches = []
                device_input_slices = []
                for i, ph in enumerate(device_input_placeholders):
                    current_batch = tf1.Variable(
                        ph,
                        trainable=False,
                        validate_shape=False,
                        collections=[])
                    device_input_batches.append(current_batch)
                    if i < num_data_in:
                        scale = self._max_seq_len
                        granularity = self._max_seq_len
                    else:
                        scale = self._max_seq_len
                        granularity = 1
                    current_slice = tf.slice(
                        current_batch,
                        ([self._batch_index // scale * granularity] +
                         [0] * len(ph.shape[1:])),
                        ([self._per_device_batch_size // scale * granularity]
                         + [-1] * len(ph.shape[1:])))
                    current_slice.set_shape(ph.shape)
                    device_input_slices.append(current_slice)
                graph_obj = self.policy_copy(device_input_slices)
                device_grads = graph_obj.gradients(self.optimizers,
                                                   graph_obj._losses)
            return Tower(
                tf.group(
                    *[batch.initializer for batch in device_input_batches]),
                device_grads, graph_obj)


# Each tower is a copy of the loss graph pinned to a specific device.
Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"])


def make_divisible_by(a, n):
    if type(a) is int:
        return a - a % n
    return a[0:a.shape[0] - a.shape[0] % n]
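# For example, make_divisible_by(10, 4) returns 8, and for an array `a` with
# a.shape[0] == 10, make_divisible_by(a, 4) keeps only the first 8 rows.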


def average_gradients(tower_grads):
    """Averages gradients across towers.

    Calculates the average gradient for each shared variable across all
    towers. Note that this function provides a synchronization point across
    all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers; the inner list is over the individual
            (gradient, variable) pairs computed on each tower.

    Returns:
        List of (gradient, variable) pairs, where each gradient has been
            averaged across all towers.

    TODO(ekl): We could use NCCL if this becomes a bottleneck.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):

        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            if g is not None:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)

                # Append on a 'tower' dimension which we will average over
                # below.
                grads.append(expanded_g)

        if not grads:
            continue

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # Keep in mind that the Variables are redundant because they are
        # shared across towers. So we will just return the first tower's
        # pointer to the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)

    return average_grads
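# Example sketch for `average_gradients` (eager-mode tf shown for brevity;
# `v` is a hypothetical shared variable): with two towers whose gradients for
# `v` are 1.0 and 3.0, the averaged gradient is 2.0.
#
#   v = tf.Variable(0.0)
#   towers = [[(tf.constant(1.0), v)], [(tf.constant(3.0), v)]]
#   (avg_grad, var), = average_gradients(towers)  # avg_grad == 2.0, var is v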