from typing import Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None.
    "mixer": "qmix",
    # Size of the mixing network embedding.
    "mixing_embed_dim": 32,
    # Whether to use double Q-learning.
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.01,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 40000,

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ",
        #   "temperature": [float, e.g. 1.0],
        # }
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,

    # === Optimization ===
    # Learning rate for the RMSProp optimizer.
    "lr": 0.0005,
    # RMSProp alpha.
    "optim_alpha": 0.99,
    # RMSProp epsilon.
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if `async_updates` is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,

    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },
    # Only torch supported so far.
    "framework": "torch",
})
# __sphinx_doc_end__
# yapf: enable


class QMixTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` supported so far for QMixTrainer!")

    @override(SimpleQTrainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        return QMixTorchPolicy

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        assert len(kwargs) == 0, (
            "QMIX execution_plan does NOT take any additional parameters")

        # Collect rollouts in bulk-sync mode and store them in a local
        # replay buffer (sized in batches, not timesteps).
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])

        store_op = rollouts \
            .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        # Replay stored batches, concatenate them up to `train_batch_size`,
        # run one training step, and periodically update the target network.
        train_op = Replay(local_buffer=replay_buffer) \
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )) \
            .for_each(TrainOneStep(workers)) \
            .for_each(UpdateTargetNetwork(
                workers, config["target_network_update_freq"]))

        # Alternate between storing and training, but only report results
        # from the training op (output index 1).
        merged_op = Concurrently(
            [store_op, train_op], mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(merged_op, workers, config)
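

# Below is a minimal usage sketch. It assumes a grouped multi-agent env has
# already been registered with Tune under the hypothetical name "grouped_env"
# (QMIX requires a MultiAgentEnv wrapped via `with_agent_groups()` so that the
# grouped agents share a joint observation/action space). Swap in your own
# registered env and stopping criteria.
if __name__ == "__main__":
    import ray
    from ray import tune

    ray.init()
    tune.run(
        QMixTrainer,
        stop={"training_iteration": 10},
        config={
            "env": "grouped_env",  # hypothetical name; register your own env
            "num_workers": 0,
            "mixer": "qmix",
        },
    )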