bc.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536
  1. """Behavioral Cloning (derived from MARWIL).
  2. Simply uses the MARWIL agent with beta force-set to 0.0.
  3. """
  4. from ray.rllib.agents.marwil.marwil import MARWILTrainer, \
  5. DEFAULT_CONFIG as MARWIL_CONFIG
  6. from ray.rllib.utils.typing import TrainerConfigDict
  7. # yapf: disable
  8. # __sphinx_doc_begin__
  9. BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
  10. MARWIL_CONFIG, {
  11. # No need to calculate advantages (or do anything else with the
  12. # rewards).
  13. "beta": 0.0,
  14. # Advantages (calculated during postprocessing) not important for
  15. # behavioral cloning.
  16. "postprocess_inputs": False,
  17. # No reward estimation.
  18. "input_evaluation": [],
  19. })
  20. # __sphinx_doc_end__
  21. # yapf: enable
  22. def validate_config(config: TrainerConfigDict) -> None:
  23. if config["beta"] != 0.0:
  24. raise ValueError(
  25. "For behavioral cloning, `beta` parameter must be 0.0!")
  26. BCTrainer = MARWILTrainer.with_updates(
  27. name="BC",
  28. default_config=BC_DEFAULT_CONFIG,
  29. validate_config=validate_config,
  30. )