test_convtranspose2d_stack.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import gymnasium as gym
  2. import numpy as np
  3. import os
  4. from pathlib import Path
  5. import unittest
  6. from ray.rllib.models.preprocessors import GenericPixelPreprocessor
  7. from ray.rllib.models.torch.modules.convtranspose2d_stack import ConvTranspose2DStack
  8. from ray.rllib.utils.framework import try_import_torch, try_import_tf
  9. from ray.rllib.utils.images import imread
  10. torch, nn = try_import_torch()
  11. tf1, tf, tfv = try_import_tf()
  12. class TestConvTranspose2DStack(unittest.TestCase):
  13. """Tests our ConvTranspose2D Stack modules/layers."""
  14. def test_convtranspose2d_stack(self):
  15. """Tests, whether the conv2d stack can be trained to predict an image."""
  16. batch_size = 128
  17. input_size = 1
  18. module = ConvTranspose2DStack(input_size=input_size)
  19. preprocessor = GenericPixelPreprocessor(
  20. gym.spaces.Box(0, 255, (64, 64, 3), np.uint8), options={"dim": 64}
  21. )
  22. optim = torch.optim.Adam(module.parameters(), lr=0.0001)
  23. rllib_dir = Path(__file__).parent.parent.parent
  24. img_file = os.path.join(rllib_dir, "tests/data/images/obstacle_tower.png")
  25. img = imread(img_file)
  26. # Preprocess.
  27. img = preprocessor.transform(img)
  28. # Make channels first.
  29. img = np.transpose(img, (2, 0, 1))
  30. # Add batch rank and repeat.
  31. imgs = np.reshape(img, (1,) + img.shape)
  32. imgs = np.repeat(imgs, batch_size, axis=0)
  33. # Move to torch.
  34. imgs = torch.from_numpy(imgs)
  35. init_loss = loss = None
  36. for _ in range(10):
  37. # Random inputs.
  38. inputs = torch.from_numpy(
  39. np.random.normal(0.0, 1.0, (batch_size, input_size))
  40. ).float()
  41. distribution = module(inputs)
  42. # Construct a loss.
  43. loss = -torch.mean(distribution.log_prob(imgs))
  44. if init_loss is None:
  45. init_loss = loss
  46. print("loss={}".format(loss))
  47. # Minimize loss.
  48. loss.backward()
  49. optim.step()
  50. self.assertLess(loss, init_loss)
  51. if __name__ == "__main__":
  52. import pytest
  53. import sys
  54. sys.exit(pytest.main(["-v", __file__]))