123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- """Test utils in rllib/utils/spaces/space_utils.py."""
- import unittest
- import numpy as np
- from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
- from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
class TestSpaceUtils(unittest.TestCase):
    """Test case for `convert_element_to_space_type` from space_utils."""

    def test_convert_element_to_space_type(self):
        """Convert a nested, re-cast element and check the space contains it.

        A nested Dict space (Box, Discrete, MultiDiscrete, MultiBinary, a
        Tuple and an inner Dict) is built, an element with the same nesting
        is assembled from re-cast samples and Python floats, and the result
        of `convert_element_to_space_type` must be contained in the space.
        """
        # Leaf spaces; the same objects are reused inside the containers.
        box = Box(low=-1, high=1, shape=(2, ))
        discrete = Discrete(2)
        multi_discrete = MultiDiscrete([2, 2])
        multi_binary = MultiBinary(2)

        # Container spaces nesting the leaves above.
        box_and_discrete = Tuple((box, discrete))
        inner_dict = Dict({
            "box2": box,
            "discrete2": discrete,
        })
        space = Dict({
            "box": box,
            "discrete": discrete,
            "multi_discrete": multi_discrete,
            "multi_binary": multi_binary,
            "dict_space": inner_dict,
            "tuple_space": box_and_discrete,
        })

        # Samples are explicitly re-cast with astype(...), and the discrete
        # entries are given as Python floats -- presumably so that they do
        # not match the spaces' native types (verify against gym).
        recast_box = box.sample().astype(np.float64)
        recast_multi_discrete = multi_discrete.sample().astype(np.int32)
        recast_multi_binary = multi_binary.sample().astype(np.int32)
        float_discrete = float(0)

        unconverted = {
            "box": recast_box,
            "discrete": float_discrete,
            "multi_discrete": recast_multi_discrete,
            "multi_binary": recast_multi_binary,
            "tuple_space": (recast_box, float_discrete),
            "dict_space": {
                "box2": recast_box,
                "discrete2": float_discrete,
            },
        }

        # A fresh sample of the full space is passed as the second argument.
        reference_sample = space.sample()
        converted = convert_element_to_space_type(unconverted,
                                                  reference_sample)

        assert space.contains(converted)
if __name__ == "__main__":
    import sys

    import pytest

    # Run only this file's tests in verbose mode and hand pytest's return
    # value to sys.exit so it becomes the process exit status.
    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
|