advanced_api.py 908 B

123456789101112131415161718192021222324252627282930313233343536373839
  1. # flake8: noqa
  2. # __rllib-adv_api_counter_begin__
  3. import ray
  4. @ray.remote
  5. class Counter:
  6. def __init__(self):
  7. self.count = 0
  8. def inc(self, n):
  9. self.count += n
  10. def get(self):
  11. return self.count
  12. # on the driver
  13. counter = Counter.options(name="global_counter").remote()
  14. print(ray.get(counter.get.remote())) # get the latest count
  15. # in your envs
  16. counter = ray.get_actor("global_counter")
  17. counter.inc.remote(1) # async call to increment the global count
  18. # __rllib-adv_api_counter_end__
  19. # __rllib-adv_api_explore_begin__
  20. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  21. config = AlgorithmConfig().exploration(
  22. exploration_config={
  23. # Special `type` key provides class information
  24. "type": "StochasticSampling",
  25. # Add any needed constructor args here.
  26. "constructor_arg": "value",
  27. }
  28. )
  29. # __rllib-adv_api_explore_end__