impala_learner.py
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import tree  # pip install dm_tree

from ray.rllib.core.learner.learner import (
    Learner,
    LearnerHyperparameters,
)
from ray.rllib.core.rl_module.rl_module import ModuleID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
from ray.rllib.utils.metrics import (
    ALL_MODULES,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import ResultDict

LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff"
@dataclass
class ImpalaLearnerHyperparameters(LearnerHyperparameters):
    """LearnerHyperparameters for the ImpalaLearner sub-classes (framework specific).

    These should never be set directly by the user. Instead, use the IMPALAConfig
    class to configure your algorithm.
    See `ray.rllib.algorithms.impala.impala::IMPALAConfig::training()` for more
    details on the individual properties.

    Attributes:
        rollout_frag_or_episode_len: The length of a rollout fragment or episode.
            Used when making SampleBatches time major for computing loss.
        recurrent_seq_len: The length of a recurrent sequence. Used when making
            SampleBatches time major for computing loss.
    """

    rollout_frag_or_episode_len: int = None
    recurrent_seq_len: int = None
    discount_factor: float = None
    vtrace_clip_rho_threshold: float = None
    vtrace_clip_pg_rho_threshold: float = None
    vf_loss_coeff: float = None
    entropy_coeff: float = None
    entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None
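# Illustrative sketch (not part of the original module): these hyperparameters are
# normally populated from the algorithm config rather than constructed by hand.
# A config along the following lines would fill them in; the exact import path and
# class name (IMPALAConfig vs. ImpalaConfig) depend on the Ray version, so treat
# this snippet as an assumption:
#
#     from ray.rllib.algorithms.impala import ImpalaConfig
#
#     config = ImpalaConfig().training(
#         vf_loss_coeff=0.5,
#         entropy_coeff=0.01,
#         vtrace_clip_rho_threshold=1.0,
#     )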
class ImpalaLearner(Learner):
    @override(Learner)
    def build(self) -> None:
        super().build()

        # Dict mapping module IDs to the respective entropy Scheduler instance.
        # LambdaDefaultDict passes the missing key (the module ID) to the factory
        # lambda, so a Scheduler is created lazily on first access per module.
        self.entropy_coeff_schedulers_per_module: Dict[
            ModuleID, Scheduler
        ] = LambdaDefaultDict(
            lambda module_id: Scheduler(
                fixed_value_or_schedule=(
                    self.hps.get_hps_for_module(module_id).entropy_coeff
                ),
                framework=self.framework,
                device=self._device,
            )
        )
    @override(Learner)
    def remove_module(self, module_id: str):
        super().remove_module(module_id)
        self.entropy_coeff_schedulers_per_module.pop(module_id)

    @override(Learner)
    def additional_update_for_module(
        self, *, module_id: ModuleID, hps: ImpalaLearnerHyperparameters, timestep: int
    ) -> Dict[str, Any]:
        results = super().additional_update_for_module(
            module_id=module_id, hps=hps, timestep=timestep
        )

        # Update entropy coefficient via our Scheduler.
        new_entropy_coeff = self.entropy_coeff_schedulers_per_module[module_id].update(
            timestep=timestep
        )
        results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: new_entropy_coeff})

        return results
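# Illustrative sketch (assumed values and variable names, not from the original
# code): with a schedule such as entropy_coeff_schedule=[[0, 0.01], [1_000_000, 0.0]],
# repeated calls to `additional_update_for_module()` at increasing timesteps let the
# per-module Scheduler anneal the coefficient, and the current value is reported
# under the "curr_entropy_coeff" result key:
#
#     results = learner.additional_update_for_module(
#         module_id="default_policy", hps=hps, timestep=500_000
#     )
#     results[LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY]
#     # -> roughly 0.005, assuming linear interpolation between schedule points.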
def _reduce_impala_results(results: List[ResultDict]) -> ResultDict:
    """Reduce/Aggregate a list of results from Impala Learners.

    Average the values of the result dicts. Add keys for the number of agent and env
    steps trained (on all modules).

    Args:
        results: Result dicts to reduce.

    Returns:
        A reduced result dict.
    """
    result = tree.map_structure(lambda *x: np.mean(x), *results)
    agent_steps_trained = sum(r[ALL_MODULES][NUM_AGENT_STEPS_TRAINED] for r in results)
    env_steps_trained = sum(r[ALL_MODULES][NUM_ENV_STEPS_TRAINED] for r in results)
    result[ALL_MODULES][NUM_AGENT_STEPS_TRAINED] = agent_steps_trained
    result[ALL_MODULES][NUM_ENV_STEPS_TRAINED] = env_steps_trained
    return result
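# Illustrative usage (hypothetical result dicts; metric keys follow the constants
# imported above): per-key leaf values are averaged across the Learner results,
# while the agent/env step counters are then overwritten with their sums so they
# reflect the total number of steps trained:
#
#     r1 = {ALL_MODULES: {"loss": 0.5, NUM_AGENT_STEPS_TRAINED: 100, NUM_ENV_STEPS_TRAINED: 100}}
#     r2 = {ALL_MODULES: {"loss": 0.3, NUM_AGENT_STEPS_TRAINED: 300, NUM_ENV_STEPS_TRAINED: 300}}
#     reduced = _reduce_impala_results([r1, r2])
#     # reduced[ALL_MODULES]["loss"] == 0.4 (averaged)
#     # reduced[ALL_MODULES][NUM_AGENT_STEPS_TRAINED] == 400 (summed, not averaged)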