# test_conv2d_default_stacks.py
  1. import gym
  2. import unittest
  3. from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
  4. from ray.rllib.models.tf.visionnet import VisionNetwork
  5. from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVision
  6. from ray.rllib.utils.framework import try_import_torch, try_import_tf
  7. from ray.rllib.utils.test_utils import framework_iterator
  8. torch, nn = try_import_torch()
  9. tf1, tf, tfv = try_import_tf()
  10. class TestConv2DDefaultStacks(unittest.TestCase):
  11. """Tests our ConvTranspose2D Stack modules/layers."""
  12. def test_conv2d_default_stacks(self):
  13. """Tests, whether conv2d defaults are available for img obs spaces.
  14. """
  15. action_space = gym.spaces.Discrete(2)
  16. shapes = [
  17. (480, 640, 3),
  18. (240, 320, 3),
  19. (96, 96, 3),
  20. (84, 84, 3),
  21. (42, 42, 3),
  22. (10, 10, 3),
  23. ]
  24. for shape in shapes:
  25. print(f"shape={shape}")
  26. obs_space = gym.spaces.Box(-1.0, 1.0, shape=shape)
  27. for fw in framework_iterator():
  28. model = ModelCatalog.get_model_v2(
  29. obs_space,
  30. action_space,
  31. 2,
  32. MODEL_DEFAULTS.copy(),
  33. framework=fw)
  34. self.assertTrue(
  35. isinstance(model, (VisionNetwork, TorchVision)))
  36. if fw == "torch":
  37. output, _ = model({
  38. "obs": torch.from_numpy(obs_space.sample()[None])
  39. })
  40. else:
  41. output, _ = model({"obs": obs_space.sample()[None]})
  42. # B x [action logits]
  43. self.assertTrue(output.shape == (1, 2))
  44. print("ok")
  45. if __name__ == "__main__":
  46. import pytest
  47. import sys
  48. sys.exit(pytest.main(["-v", __file__]))