train_ops.py

import logging
import numpy as np
import math
from typing import Dict, List, Tuple, Any

import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import \
    AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \
    LAST_TARGET_UPDATE_TS, LEARN_ON_BATCH_TIMER, \
    LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \
    STEPS_TRAINED_COUNTER, STEPS_TRAINED_THIS_ITER_COUNTER, \
    WORKER_UPDATE_TIMER, _check_sample_batch_type, \
    _get_global_vars, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_ENV_STEPS_TRAINED, \
    NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \
    LEARNER_INFO
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


@ExperimentalAPI
def train_one_step(trainer, train_batch) -> Dict:
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    policies = local_worker.policies_to_train
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if num_sgd_iter > 1 or sgd_minibatch_size > 0:
            info = do_minibatch_sgd(
                train_batch,
                {pid: local_worker.get_policy(pid)
                 for pid in policies}, local_worker, num_sgd_iter,
                sgd_minibatch_size, [])
        # Single update step using train batch.
        else:
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    return info
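

# Illustrative sketch (not part of the original module): how a Trainer's
# training iteration might feed a sampled batch into `train_one_step`. The
# `sample_fn` argument is a hypothetical placeholder for whatever collects a
# SampleBatch/MultiAgentBatch from the trainer's rollout workers.
def _example_train_one_step_usage(trainer, sample_fn) -> Dict:
    # Collect a train batch from the rollout workers (placeholder callable).
    train_batch = sample_fn(trainer.workers)
    # Run (minibatch) SGD on the local worker and return the learner stats.
    return train_one_step(trainer, train_batch)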


@ExperimentalAPI
def multi_gpu_train_one_step(trainer, train_batch) -> Dict:
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    policies = local_worker.policies_to_train
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size",
                                    config["train_batch_size"])

    # Determine the number of devices (GPUs or 1 CPU) we use.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Make sure total batch size is dividable by the number of devices.
    # Batch size per tower.
    per_device_batch_size = sgd_minibatch_size // num_devices
    # Total batch size.
    batch_size = per_device_batch_size * num_devices
    assert batch_size % num_devices == 0
    assert batch_size >= num_devices, "Batch size too small!"

    # Handle everything as if multi-agent.
    train_batch = train_batch.as_multi_agent()

    # Load data into GPUs.
    load_timer = trainer._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in train_batch.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in policies:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffers, if we are training
            # asynchronously.
            num_loaded_samples[policy_id] = local_worker.policy_map[
                policy_id].load_batch_into_buffer(
                    batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(
                1,
                int(samples_per_device) // int(per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] * per_device_batch_size,
                        buffer_index=0)

                    learner_info_builder.add_learn_on_batch_results(
                        results, policy_id)

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if workers.remote_workers():
        with trainer._timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(workers.local_worker().get_weights(policies))
            for e in workers.remote_workers():
                e.set_weights.remote(weights)

    return learner_info
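

# Illustrative sketch (assumption, not an RLlib API): the per-device batch
# arithmetic used above. For example, num_gpus=3 and sgd_minibatch_size=128
# give 42 samples per tower and an effective total batch of 126; the remaining
# 2 samples are dropped so the batch divides evenly across devices.
def _example_per_device_batch_size(sgd_minibatch_size: int,
                                   num_gpus: int) -> Tuple[int, int]:
    # One (fake-GPU) CPU tower is used when no GPUs are configured.
    num_devices = int(math.ceil(num_gpus or 1))
    per_device_batch_size = sgd_minibatch_size // num_devices
    # Return (batch size per tower, effective total batch size).
    return per_device_batch_size, per_device_batch_size * num_devices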


class TrainOneStep:
    """Callable that improves the policy and updates workers.

    This should be used with the .for_each() operator. A tuple of the input
    and learner stats will be returned.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(TrainOneStep(workers))
        >>> print(next(train_op))  # This trains the policy on one batch.
        SampleBatch(...), {"learner_stats": ...}

    Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    def __init__(self,
                 workers: WorkerSet,
                 policies: List[PolicyID] = frozenset([]),
                 num_sgd_iter: int = 1,
                 sgd_minibatch_size: int = 0):
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.policies = policies
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size

    def __call__(self,
                 batch: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        with learn_timer:
            # Subsample minibatches (size=`sgd_minibatch_size`) from the
            # train batch and loop through train batch `num_sgd_iter` times.
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                lw = self.workers.local_worker()
                learner_info = do_minibatch_sgd(
                    batch, {
                        pid: lw.get_policy(pid)
                        for pid in self.policies
                        or self.local_worker.policies_to_train
                    }, lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
            # Single update step using train batch.
            else:
                learner_info = \
                    self.workers.local_worker().learn_on_batch(batch)

            metrics.info[LEARNER_INFO] = learner_info
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[
                AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()
        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies or self.local_worker.policies_to_train))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return batch, learner_info
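

# Illustrative sketch (not part of the original module): wiring `TrainOneStep`
# into an execution plan. `ParallelRollouts` is assumed to be available from
# `ray.rllib.execution.rollout_ops`; the concrete plan of a given algorithm
# may differ.
def _example_train_one_step_plan(workers: WorkerSet):
    from ray.rllib.execution.rollout_ops import ParallelRollouts
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Each item yielded by `train_op` is a (train_batch, learner_info) tuple.
    train_op = rollouts.for_each(TrainOneStep(workers))
    return train_op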


class MultiGPUTrainOneStep:
    """Multi-GPU version of TrainOneStep.

    This should be used with the .for_each() operator. A tuple of the input
    and learner stats will be returned.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(MultiGPUTrainOneStep(workers, ...))
        >>> print(next(train_op))  # This trains the policy on one batch.
        SampleBatch(...), {"learner_stats": ...}

    Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    def __init__(
            self,
            *,
            workers: WorkerSet,
            sgd_minibatch_size: int,
            num_sgd_iter: int,
            num_gpus: int,
            _fake_gpus: bool = False,
            # Deprecated args.
            shuffle_sequences=DEPRECATED_VALUE,
            framework=DEPRECATED_VALUE):
        if framework != DEPRECATED_VALUE or \
                shuffle_sequences != DEPRECATED_VALUE:
            deprecation_warning(
                old="MultiGPUTrainOneStep(framework=..., "
                "shuffle_sequences=...)",
                error=False)

        self.workers = workers
        self.local_worker = workers.local_worker()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.shuffle_sequences = shuffle_sequences

        # Collect actual GPU devices to use.
        if not num_gpus:
            _fake_gpus = True
            num_gpus = 1
        type_ = "cpu" if _fake_gpus else "gpu"
        self.devices = [
            "/{}:{}".format(type_, 0 if _fake_gpus else i)
            for i in range(int(math.ceil(num_gpus)))
        ]

        # Make sure total batch size is dividable by the number of devices.
        # Batch size per tower.
        self.per_device_batch_size = sgd_minibatch_size // len(self.devices)
        # Total batch size.
        self.batch_size = self.per_device_batch_size * len(self.devices)
        assert self.batch_size % len(self.devices) == 0
        assert self.batch_size >= len(self.devices), "Batch size too small!"

    def __call__(self,
                 samples: SampleBatchType
                 ) -> Tuple[SampleBatchType, List[dict]]:
        _check_sample_batch_type(samples)

        # Handle everything as if multi agent.
        samples = samples.as_multi_agent()

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

        # Load data into GPUs.
        with load_timer:
            num_loaded_samples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if policy_id not in self.local_worker.policies_to_train:
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have >1 buffers, if we are training
                # asynchronously.
                num_loaded_samples[policy_id] = self.local_worker.policy_map[
                    policy_id].load_batch_into_buffer(
                        batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(
                num_devices=len(self.devices))

            for policy_id, samples_per_device in num_loaded_samples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1,
                    int(samples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        results = policy.learn_on_loaded_batch(
                            permutation[batch_index] *
                            self.per_device_batch_size,
                            buffer_index=0)

                        learner_info_builder.add_learn_on_batch_results(
                            results, policy_id)

            # Tower reduce and finalize results.
            learner_info = learner_info_builder.finalize()

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = learner_info

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.local_worker.policies_to_train))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, learner_info


# Backward compatibility.
TrainTFMultiGPU = MultiGPUTrainOneStep
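

# Illustrative sketch (assumption, not an RLlib API): choosing between the
# simple and the multi-GPU training op based on a config dict, similar to how
# execution plans typically switch on the `simple_optimizer` setting.
def _example_select_train_op(workers: WorkerSet, config: dict):
    if config.get("simple_optimizer"):
        return TrainOneStep(
            workers,
            num_sgd_iter=config.get("num_sgd_iter", 1),
            sgd_minibatch_size=config.get("sgd_minibatch_size", 0))
    return MultiGPUTrainOneStep(
        workers=workers,
        sgd_minibatch_size=config.get("sgd_minibatch_size",
                                      config["train_batch_size"]),
        num_sgd_iter=config.get("num_sgd_iter", 1),
        num_gpus=config.get("num_gpus", 0),
        _fake_gpus=config.get("_fake_gpus", False))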


class ComputeGradients:
    """Callable that computes gradients with respect to the policy loss.

    This should be used with the .for_each() operator.

    Examples:
        >>> grads_op = rollouts.for_each(ComputeGradients(workers))
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the LEARNER_INFO info field in the local iterator context.
    """

    def __init__(self, workers: WorkerSet):
        self.workers = workers

    def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]:
        _check_sample_batch_type(samples)
        metrics = _get_shared_metrics()
        with metrics.timers[COMPUTE_GRADS_TIMER]:
            grad, info = self.workers.local_worker().compute_gradients(samples)
        # RolloutWorker.compute_gradients returns pure single agent stats
        # in a non-multi agent setup.
        if isinstance(samples, MultiAgentBatch):
            metrics.info[LEARNER_INFO] = info
        else:
            metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info}
        return grad, samples.count


class ApplyGradients:
    """Callable that applies gradients and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> apply_op = grads_op.for_each(ApplyGradients(workers))
        >>> print(next(apply_op))
        None

    Updates the STEPS_TRAINED_COUNTER counter in the local iterator context.
    """

    def __init__(self,
                 workers,
                 policies: List[PolicyID] = frozenset([]),
                 update_all=True):
        """Creates an ApplyGradients instance.

        Args:
            workers (WorkerSet): workers to apply gradients to.
            policies (List[PolicyID]): policies to update. If empty, defaults
                to the local worker's policies-to-train.
            update_all (bool): If true, updates all workers. Otherwise, only
                update the worker that produced the sample batch we are
                currently processing (i.e., A3C style).
        """
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.policies = policies
        self.update_all = update_all

    def __call__(self, item: Tuple[ModelGradients, int]) -> None:
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(
                    item))
        gradients, count = item
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.workers.local_worker().apply_gradients(gradients)
            apply_timer.push_units_processed(count)

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())

        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    weights = ray.put(self.workers.local_worker().get_weights(
                        self.policies or self.local_worker.policies_to_train))
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights, _get_global_vars())
        else:
            if metrics.current_actor is None:
                raise ValueError(
                    "Could not find actor to update. When "
                    "update_all=False, `current_actor` must be set "
                    "in the iterator context.")
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = self.workers.local_worker().get_weights(
                    self.policies or self.local_worker.policies_to_train)
                metrics.current_actor.set_weights.remote(
                    weights, _get_global_vars())
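

# Illustrative sketch (not part of the original module): a simple pipeline
# that computes gradients on the local worker and applies them to all
# workers. `ParallelRollouts` is assumed to come from
# `ray.rllib.execution.rollout_ops`.
def _example_compute_apply_plan(workers: WorkerSet):
    from ray.rllib.execution.rollout_ops import ParallelRollouts
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    grads_op = rollouts.for_each(ComputeGradients(workers))
    # With `update_all=False`, only the worker that produced the current
    # batch would be updated (A3C style); here all workers are updated.
    return grads_op.for_each(ApplyGradients(workers, update_all=True))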


class AverageGradients:
    """Callable that averages the gradients in a batch.

    This should be used with the .for_each() operator after a set of gradients
    have been batched with .batch().

    Examples:
        >>> batched_grads = grads_op.batch(32)
        >>> avg_grads = batched_grads.for_each(AverageGradients())
        >>> print(next(avg_grads))
        {"var_0": ..., ...}, 1600  # averaged grads, summed batch count
    """

    def __call__(self, gradients: List[Tuple[ModelGradients, int]]
                 ) -> Tuple[ModelGradients, int]:
        acc = None
        sum_count = 0
        # Accumulate (element-wise sum) the gradients of all microbatches and
        # sum up their sample counts.
        for grad, count in gradients:
            if acc is None:
                acc = grad
            else:
                acc = [a + b for a, b in zip(acc, grad)]
            sum_count += count
        logger.info("Computing average of {} microbatch gradients "
                    "({} samples total)".format(len(gradients), sum_count))
        return acc, sum_count
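

# Illustrative sketch (not part of the original module): accumulating
# gradients over several microbatches before applying them, as a microbatch
# A2C-style setup might. The microbatch count of 4 is an arbitrary example
# value; `ParallelRollouts` is assumed to come from
# `ray.rllib.execution.rollout_ops`.
def _example_microbatch_plan(workers: WorkerSet):
    from ray.rllib.execution.rollout_ops import ParallelRollouts
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Batch 4 (grads, count) tuples, combine them, then apply the result.
    grads_op = rollouts.for_each(ComputeGradients(workers)).batch(4)
    return grads_op.for_each(AverageGradients()) \
        .for_each(ApplyGradients(workers))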


class UpdateTargetNetwork:
    """Periodically call policy.update_target() on all trainable policies.

    This should be used with the .for_each() operator after a training step
    has been taken.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
        >>> update_op = train_op.for_each(
        ...     UpdateTargetNetwork(workers, target_update_freq=500))
        >>> print(next(update_op))
        None

    Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
    local iterator context. The value of the last update counter is used to
    track when we should update the target next.
    """

    def __init__(self,
                 workers: WorkerSet,
                 target_update_freq: int,
                 by_steps_trained: bool = False,
                 policies: List[PolicyID] = frozenset([])):
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.target_update_freq = target_update_freq
        self.policies = policies
        if by_steps_trained:
            self.metric = STEPS_TRAINED_COUNTER
        else:
            self.metric = STEPS_SAMPLED_COUNTER

    def __call__(self, _: Any) -> None:
        metrics = _get_shared_metrics()
        cur_ts = metrics.counters[self.metric]
        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update > self.target_update_freq:
            to_update = self.policies or self.local_worker.policies_to_train
            self.workers.local_worker().foreach_trainable_policy(
                lambda p, p_id: p_id in to_update and p.update_target())
            metrics.counters[NUM_TARGET_UPDATES] += 1
            metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
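

# Illustrative sketch (not part of the original module): appending periodic
# target-network updates to a training pipeline, as a DQN-like algorithm
# might. The update frequency of 500 is an arbitrary example value; `train_op`
# is assumed to be an iterator such as rollouts.for_each(TrainOneStep(...)).
def _example_target_update_plan(train_op, workers: WorkerSet):
    return train_op.for_each(
        UpdateTargetNetwork(workers, target_update_freq=500))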