test_space_utils.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
"""Test utils in rllib/utils/spaces/space_utils.py."""
  2. import unittest
  3. import numpy as np
  4. from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
  5. from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
  6. class TestSpaceUtils(unittest.TestCase):
  7. def test_convert_element_to_space_type(self):
  8. """Test if space converter works for all elements/space permutations"""
  9. box_space = Box(low=-1, high=1, shape=(2, ))
  10. discrete_space = Discrete(2)
  11. multi_discrete_space = MultiDiscrete([2, 2])
  12. multi_binary_space = MultiBinary(2)
  13. tuple_space = Tuple((box_space, discrete_space))
  14. dict_space = Dict({
  15. "box": box_space,
  16. "discrete": discrete_space,
  17. "multi_discrete": multi_discrete_space,
  18. "multi_binary": multi_binary_space,
  19. "dict_space": Dict({
  20. "box2": box_space,
  21. "discrete2": discrete_space,
  22. }),
  23. "tuple_space": tuple_space
  24. })
  25. box_space_uncoverted = box_space.sample().astype(np.float64)
  26. multi_discrete_unconverted = multi_discrete_space.sample().astype(
  27. np.int32)
  28. multi_binary_unconverted = multi_binary_space.sample().astype(np.int32)
  29. tuple_unconverted = (box_space_uncoverted, float(0))
  30. modified_element = {
  31. "box": box_space_uncoverted,
  32. "discrete": float(0),
  33. "multi_discrete": multi_discrete_unconverted,
  34. "multi_binary": multi_binary_unconverted,
  35. "tuple_space": tuple_unconverted,
  36. "dict_space": {
  37. "box2": box_space_uncoverted,
  38. "discrete2": float(0),
  39. }
  40. }
  41. element_with_correct_types = convert_element_to_space_type(
  42. modified_element, dict_space.sample())
  43. assert dict_space.contains(element_with_correct_types)
  44. if __name__ == "__main__":
  45. import pytest
  46. import sys
  47. sys.exit(pytest.main(["-v", __file__]))