1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- import gym
- import unittest
- from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
- from ray.rllib.models.tf.visionnet import VisionNetwork
- from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVision
- from ray.rllib.utils.framework import try_import_torch, try_import_tf
- from ray.rllib.utils.test_utils import framework_iterator
- torch, nn = try_import_torch()
- tf1, tf, tfv = try_import_tf()
class TestConv2DDefaultStacks(unittest.TestCase):
    """Tests the default Conv2D stacks used for image observation spaces."""

    def test_conv2d_default_stacks(self):
        """Tests, whether conv2d defaults are available for img obs spaces.

        For each 3D observation shape listed below and each framework
        yielded by `framework_iterator()`, a model is requested from
        `ModelCatalog.get_model_v2` using an unmodified copy of
        `MODEL_DEFAULTS`. The test passes if the returned model is one of
        the vision networks (tf or torch) and a forward pass on a single
        sampled observation produces an output of shape (1, 2).
        """
        action_space = gym.spaces.Discrete(2)
        shapes = [
            (480, 640, 3),
            (240, 320, 3),
            (96, 96, 3),
            (84, 84, 3),
            (42, 42, 3),
            (10, 10, 3),
        ]
        for shape in shapes:
            print(f"shape={shape}")
            obs_space = gym.spaces.Box(-1.0, 1.0, shape=shape)
            for fw in framework_iterator():
                # NOTE(review): the positional `2` is presumably the number
                # of model outputs (matches Discrete(2)) - confirm against
                # the `get_model_v2` signature.
                model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    2,
                    MODEL_DEFAULTS.copy(),
                    framework=fw)
                # assertIsInstance reports the actual type on failure,
                # unlike assertTrue(isinstance(...)).
                self.assertIsInstance(model, (VisionNetwork, TorchVision))
                # `[None]` prepends a batch axis of size 1 to the sample.
                obs_batch = obs_space.sample()[None]
                if fw == "torch":
                    # The torch model is fed a tensor rather than a
                    # numpy array.
                    obs_batch = torch.from_numpy(obs_batch)
                output, _ = model({"obs": obs_batch})
                # B x [action logits]
                # assertEqual performs the same `==` comparison but prints
                # both shapes when they differ.
                self.assertEqual(output.shape, (1, 2))
                print("ok")
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and
    # propagates pytest's return code as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|