filter_manager.py 1.1 KB

12345678910111213141516171819202122232425262728293031
  1. import ray
  2. from ray.rllib.utils.annotations import DeveloperAPI
  3. @DeveloperAPI
  4. class FilterManager:
  5. """Manages filters and coordination across remote evaluators that expose
  6. `get_filters` and `sync_filters`.
  7. """
  8. @staticmethod
  9. @DeveloperAPI
  10. def synchronize(local_filters, remotes, update_remote=True):
  11. """Aggregates all filters from remote evaluators.
  12. Local copy is updated and then broadcasted to all remote evaluators.
  13. Args:
  14. local_filters (dict): Filters to be synchronized.
  15. remotes (list): Remote evaluators with filters.
  16. update_remote (bool): Whether to push updates to remote filters.
  17. """
  18. remote_filters = ray.get(
  19. [r.get_filters.remote(flush_after=True) for r in remotes])
  20. for rf in remote_filters:
  21. for k in local_filters:
  22. local_filters[k].apply_changes(rf[k], with_buffer=False)
  23. if update_remote:
  24. copies = {k: v.as_serializable() for k, v in local_filters.items()}
  25. remote_copy = ray.put(copies)
  26. [r.sync_filters.remote(remote_copy) for r in remotes]