test_base_model.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import unittest
  2. from ray.rllib.models.base_model import (
  3. UnrollOutputType,
  4. Model,
  5. RecurrentModel,
  6. ForwardOutputType,
  7. )
  8. import numpy as np
  9. from ray.rllib.models.temp_spec_classes import TensorDict, SpecDict
  10. from ray.rllib.utils.annotations import override
  11. from ray.rllib.utils.test_utils import check
  12. class NpRecurrentModelImpl(RecurrentModel):
  13. """A numpy recurrent model for checking:
  14. (1) initial states
  15. (2) that model in/out is as expected
  16. (3) unroll logic
  17. (4) spec checking"""
  18. def __init__(self, input_check=None, output_check=None):
  19. super().__init__()
  20. self.input_check = input_check
  21. self.output_check = output_check
  22. @property
  23. @override(RecurrentModel)
  24. def input_specs(self):
  25. return SpecDict({"in": "h"}, h=3)
  26. @property
  27. @override(RecurrentModel)
  28. def output_specs(self):
  29. return SpecDict({"out": "o"}, o=2)
  30. @property
  31. @override(RecurrentModel)
  32. def next_state_spec(self):
  33. return SpecDict({"out": "i"}, i=4)
  34. @property
  35. @override(RecurrentModel)
  36. def prev_state_spec(self):
  37. return SpecDict({"in": "o"}, o=1)
  38. @override(RecurrentModel)
  39. def _update_inputs_and_prev_state(self, inputs, states):
  40. if self.input_check:
  41. self.input_check(inputs, states)
  42. return inputs, states
  43. @override(RecurrentModel)
  44. def _update_outputs_and_next_state(self, outputs, states):
  45. if self.output_check:
  46. self.output_check(outputs, states)
  47. return outputs, states
  48. @override(RecurrentModel)
  49. def _initial_state(self):
  50. return TensorDict({"in": np.arange(1)})
  51. @override(RecurrentModel)
  52. def _unroll(self, inputs: TensorDict, prev_state: TensorDict) -> UnrollOutputType:
  53. # Ensure unroll is passed the input/state as expected
  54. # and does not mutate/permute it in any way
  55. check(inputs["in"], np.arange(3))
  56. check(prev_state["in"], np.arange(1))
  57. return TensorDict({"out": np.arange(2)}), TensorDict({"out": np.arange(4)})
  58. class NpModelImpl(Model):
  59. """Non-recurrent extension of NPRecurrentModelImpl
  60. For testing:
  61. (1) rollout and forward_ logic
  62. (2) spec checking
  63. """
  64. def __init__(self, input_check=None, output_check=None):
  65. super().__init__()
  66. self.input_check = input_check
  67. self.output_check = output_check
  68. @property
  69. @override(Model)
  70. def input_specs(self):
  71. return SpecDict({"in": "h"}, h=3)
  72. @property
  73. @override(Model)
  74. def output_specs(self):
  75. return SpecDict({"out": "o"}, o=2)
  76. @override(Model)
  77. def _update_inputs(self, inputs):
  78. if self.input_check:
  79. return self.input_check(inputs)
  80. return inputs
  81. @override(Model)
  82. def _update_outputs(self, outputs):
  83. if self.output_check:
  84. self.output_check(outputs)
  85. return outputs
  86. @override(Model)
  87. def _forward(self, inputs: TensorDict) -> ForwardOutputType:
  88. # Ensure _forward is passed the input from unroll as expected
  89. # and does not mutate/permute it in any way
  90. check(inputs["in"], np.arange(3))
  91. return TensorDict({"out": np.arange(2)})
  92. class TestRecurrentModel(unittest.TestCase):
  93. def test_initial_state(self):
  94. """Check that the _initial state is corrected called by initial_state
  95. and outputs correct values."""
  96. output = NpRecurrentModelImpl().initial_state()
  97. desired = TensorDict({"in": np.arange(1)})
  98. for k in output.flatten().keys() | desired.flatten().keys():
  99. check(output[k], desired[k])
  100. def test_unroll(self):
  101. """Test that _unroll is correctly called by unroll and outputs are the
  102. correct values"""
  103. out, out_state = NpRecurrentModelImpl().unroll(
  104. inputs=TensorDict({"in": np.arange(3)}),
  105. prev_state=TensorDict({"in": np.arange(1)}),
  106. )
  107. desired, desired_state = (
  108. TensorDict({"out": np.arange(2)}),
  109. TensorDict({"out": np.arange(4)}),
  110. )
  111. for k in out.flatten().keys() | desired.flatten().keys():
  112. check(out[k], desired[k])
  113. for k in out_state.flatten().keys() | desired_state.flatten().keys():
  114. check(out_state[k], desired_state[k])
  115. def test_unroll_filter(self):
  116. """Test that unroll correctly filters unused data"""
  117. def in_check(inputs, states):
  118. assert "bork" not in inputs.keys() and "borkbork" not in states.keys()
  119. return inputs, states
  120. m = NpRecurrentModelImpl(input_check=in_check)
  121. out, state = m.unroll(
  122. inputs=TensorDict({"in": np.arange(3), "bork": np.zeros(1)}),
  123. prev_state=TensorDict({"in": np.arange(1), "borkbork": np.zeros(1)}),
  124. )
  125. def test_hooks(self):
  126. """Test that _update_inputs_and_prev_state and _update_outputs_and_prev_state
  127. are called during unroll"""
  128. class MyException(Exception):
  129. pass
  130. def exc(a, b):
  131. raise MyException()
  132. with self.assertRaises(MyException):
  133. m = NpRecurrentModelImpl(input_check=exc)
  134. m.unroll(
  135. inputs=TensorDict({"in": np.arange(3)}),
  136. prev_state=TensorDict({"in": np.arange(1)}),
  137. )
  138. with self.assertRaises(MyException):
  139. m = NpRecurrentModelImpl(output_check=exc)
  140. m.unroll(
  141. inputs=TensorDict({"in": np.arange(3)}),
  142. prev_state=TensorDict({"in": np.arange(1)}),
  143. )
  144. class TestModel(unittest.TestCase):
  145. def test_unroll(self):
  146. """Test that unroll correctly calls _forward. The outputs
  147. should be as expected."""
  148. m = NpModelImpl()
  149. output, nullstate = m.unroll(
  150. inputs=TensorDict({"in": np.arange(3)}), prev_state=TensorDict()
  151. )
  152. self.assertEqual(
  153. nullstate,
  154. TensorDict(),
  155. )
  156. check(output["out"], np.arange(2))
  157. def test_hooks(self):
  158. """Test that unroll correctly calls the filter functions
  159. before _forward"""
  160. class MyException(Exception):
  161. pass
  162. def exc(a):
  163. raise MyException()
  164. with self.assertRaises(MyException):
  165. NpModelImpl(input_check=exc).unroll(
  166. inputs=TensorDict({"in": np.arange(3)}),
  167. prev_state=TensorDict(),
  168. )
  169. with self.assertRaises(MyException):
  170. NpModelImpl(output_check=exc).unroll(
  171. inputs=TensorDict({"in": np.arange(3)}),
  172. prev_state=TensorDict(),
  173. )
  174. if __name__ == "__main__":
  175. import pytest
  176. import sys
  177. sys.exit(pytest.main(["-v", __file__]))