custom_fast_model.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. """Example of using a custom image env and model.
  2. Both the model and env are trivial (and super-fast), so they are useful
  3. for running perf microbenchmarks.
  4. """
  5. import argparse
  6. import os
  7. import ray
  8. import ray.tune as tune
  9. from ray.tune import sample_from
  10. from ray.rllib.examples.env.fast_image_env import FastImageEnv
  11. from ray.rllib.examples.models.fast_model import FastModel, TorchFastModel
  12. from ray.rllib.models import ModelCatalog
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument("--num-cpus", type=int, default=4)
  15. parser.add_argument(
  16. "--framework",
  17. choices=["tf", "tf2", "tfe", "torch"],
  18. default="tf",
  19. help="The DL framework specifier.")
  20. parser.add_argument("--stop-iters", type=int, default=200)
  21. parser.add_argument("--stop-timesteps", type=int, default=100000)
  22. if __name__ == "__main__":
  23. args = parser.parse_args()
  24. ray.init(num_cpus=args.num_cpus or None)
  25. ModelCatalog.register_custom_model(
  26. "fast_model", TorchFastModel
  27. if args.framework == "torch" else FastModel)
  28. config = {
  29. "env": FastImageEnv,
  30. "compress_observations": True,
  31. "model": {
  32. "custom_model": "fast_model"
  33. },
  34. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  35. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  36. "num_workers": 2,
  37. "num_envs_per_worker": 10,
  38. "num_multi_gpu_tower_stacks": 1,
  39. "num_aggregation_workers": 1,
  40. "broadcast_interval": 50,
  41. "rollout_fragment_length": 100,
  42. "train_batch_size": sample_from(
  43. lambda spec: 1000 * max(1, spec.config.num_gpus or 1)),
  44. "fake_sampler": True,
  45. "framework": args.framework,
  46. }
  47. stop = {
  48. "training_iteration": args.stop_iters,
  49. "timesteps_total": args.stop_timesteps,
  50. }
  51. tune.run("IMPALA", config=config, stop=stop, verbose=1)
  52. ray.shutdown()