dreamerv3_learner.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """
  2. [1] Mastering Diverse Domains through World Models - 2023
  3. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  4. https://arxiv.org/pdf/2301.04104v1.pdf
  5. [2] Mastering Atari with Discrete World Models - 2021
  6. D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
  7. https://arxiv.org/pdf/2010.02193.pdf
  8. """
from dataclasses import dataclass
from typing import Any, DefaultDict, Dict, Optional

from ray.rllib.core.learner.learner import Learner, LearnerHyperparameters
from ray.rllib.core.rl_module.rl_module import ModuleID
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType
  16. @dataclass
  17. class DreamerV3LearnerHyperparameters(LearnerHyperparameters):
  18. """Hyperparameters for the DreamerV3Learner sub-classes (framework specific).
  19. These should never be set directly by the user. Instead, use the DreamerV3Config
  20. class to configure your algorithm.
  21. See `ray.rllib.algorithms.dreamerv3.dreamerv3::DreamerV3Config::training()` for
  22. more details on the individual properties.
  23. """
  24. model_size: str = None
  25. training_ratio: float = None
  26. batch_size_B: int = None
  27. batch_length_T: int = None
  28. horizon_H: int = None
  29. gamma: float = None
  30. gae_lambda: float = None
  31. entropy_scale: float = None
  32. return_normalization_decay: float = None
  33. world_model_lr: float = None
  34. actor_lr: float = None
  35. critic_lr: float = None
  36. train_critic: bool = None
  37. train_actor: bool = None
  38. use_curiosity: bool = None
  39. intrinsic_rewards_scale: float = None
  40. world_model_grad_clip_by_global_norm: float = None
  41. actor_grad_clip_by_global_norm: float = None
  42. critic_grad_clip_by_global_norm: float = None
  43. use_float16: bool = None
  44. # Reporting settings.
  45. report_individual_batch_item_stats: bool = None
  46. report_dream_data: bool = None
  47. report_images_and_videos: bool = None
  48. class DreamerV3Learner(Learner):
  49. """DreamerV3 specific Learner class.
  50. Only implements the `additional_update_for_module()` method to define the logic
  51. for updating the critic EMA-copy after each training step.
  52. """
  53. @override(Learner)
  54. def compile_results(
  55. self,
  56. *,
  57. batch: MultiAgentBatch,
  58. fwd_out: Dict[str, Any],
  59. loss_per_module: Dict[str, TensorType],
  60. metrics_per_module: DefaultDict[ModuleID, Dict[str, Any]],
  61. ) -> Dict[str, Any]:
  62. results = super().compile_results(
  63. batch=batch,
  64. fwd_out=fwd_out,
  65. loss_per_module=loss_per_module,
  66. metrics_per_module=metrics_per_module,
  67. )
  68. # Add the predicted obs distributions for possible (video) summarization.
  69. if self.hps.report_images_and_videos:
  70. for module_id, res in results.items():
  71. if module_id in fwd_out:
  72. res["WORLD_MODEL_fwd_out_obs_distribution_means_BxT"] = fwd_out[
  73. module_id
  74. ]["obs_distribution_means_BxT"]
  75. return results
  76. @override(Learner)
  77. def additional_update_for_module(
  78. self,
  79. *,
  80. module_id: ModuleID,
  81. hps: DreamerV3LearnerHyperparameters,
  82. timestep: int,
  83. ) -> Dict[str, Any]:
  84. """Updates the EMA weights of the critic network."""
  85. # Call the base class' method.
  86. results = super().additional_update_for_module(
  87. module_id=module_id, hps=hps, timestep=timestep
  88. )
  89. # Update EMA weights of the critic.
  90. self.module[module_id].critic.update_ema()
  91. return results