test_nested_dict.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import unittest
  2. from ray.rllib.utils.test_utils import check
  3. from ray.rllib.utils.nested_dict import NestedDict
  4. class TestNestedDict(unittest.TestCase):
  5. def test_basics(self):
  6. foo_dict = NestedDict()
  7. # test __setitem__
  8. def set_invalid_item_1():
  9. foo_dict[()] = 1
  10. def set_invalid_item_2():
  11. foo_dict[""] = 1
  12. self.assertRaises(IndexError, set_invalid_item_1)
  13. self.assertRaises(IndexError, set_invalid_item_2)
  14. desired_dict = {
  15. "aa": 100,
  16. "b": {"c": 200, "d": 300},
  17. "c": {"e": {"f": 400}},
  18. "d": {"g": {"h": {"i": 500}}},
  19. # An empty dict that has no leafs and thus should be ignored when
  20. # counting or iterating
  21. "j": {"k": {}},
  22. "l": {},
  23. }
  24. desired_keys = [
  25. ("aa",),
  26. ("b", "c"),
  27. ("b", "d"),
  28. ("c", "e", "f"),
  29. ("d", "g", "h", "i"),
  30. ("j", "k"),
  31. ("l",),
  32. ]
  33. # We have 5 leafs that are not empty and two empty leafs
  34. desired_values = [100, 200, 300, 400, 500, NestedDict(), NestedDict()]
  35. foo_dict["aa"] = 100
  36. foo_dict["b", "c"] = 200
  37. foo_dict[("b", "d")] = 300
  38. foo_dict["c", "e"] = {"f": 400}
  39. # test __len__
  40. # We have not yet included d, j and l in foo_dict
  41. self.assertEqual(len(foo_dict), len(desired_keys) - 3)
  42. # test __iter__
  43. self.assertEqual(list(iter(foo_dict)), desired_keys[:-3])
  44. # this call will use __len__ and __iter__
  45. foo_dict["d"] = {"g": NestedDict([("h", NestedDict({"i": 500}))])}
  46. foo_dict["j"] = {"k": {}}
  47. foo_dict["l"] = {}
  48. # test asdict
  49. check(foo_dict.asdict(), desired_dict)
  50. # test __len__ again
  51. # We have included d, j and l in foo_dict, but j and l don't contribute to
  52. # the length because they are empty sub-roots of the tree structure with no
  53. # leafs.
  54. self.assertEqual(len(foo_dict), len(desired_keys) - 2)
  55. # test __iter__ again
  56. self.assertEqual(list(iter(foo_dict)), desired_keys)
  57. # test __contains__
  58. self.assertTrue("aa" in foo_dict)
  59. self.assertTrue(("b", "c") in foo_dict)
  60. self.assertTrue(("b", "c") in foo_dict)
  61. self.assertTrue(("b", "d") in foo_dict)
  62. self.assertTrue(("d", "g", ("h", "i")) in foo_dict)
  63. self.assertFalse("f" in foo_dict)
  64. self.assertFalse(("b", "e") in foo_dict)
  65. # test get()
  66. self.assertEqual(foo_dict.get("aa"), 100)
  67. self.assertEqual(foo_dict.get("b").asdict(), {"c": 200, "d": 300})
  68. self.assertEqual(foo_dict.get(("b", "d")), 300)
  69. self.assertRaises(KeyError, lambda: foo_dict.get("e"))
  70. self.assertEqual(foo_dict.get("e", default=400), 400)
  71. # test __getitem__
  72. self.assertEqual(foo_dict["aa"], 100)
  73. self.assertEqual(foo_dict["b", "c"], 200)
  74. self.assertEqual(foo_dict["c", "e", "f"], 400)
  75. self.assertEqual(foo_dict["d", "g", "h", "i"], 500)
  76. self.assertEqual(foo_dict["b"], NestedDict({"c": 200, "d": 300}))
  77. # test __str__
  78. self.assertEqual(str(foo_dict), str(desired_dict))
  79. # test keys()
  80. self.assertEqual(list(foo_dict.keys()), desired_keys)
  81. # test values()
  82. self.assertEqual(list(foo_dict.values()), desired_values)
  83. # test items()
  84. self.assertEqual(
  85. list(foo_dict.items()), list(zip(desired_keys, desired_values))
  86. )
  87. # test shallow_keys()
  88. self.assertEqual(list(foo_dict.shallow_keys()), ["aa", "b", "c", "d", "j", "l"])
  89. # test copy()
  90. foo_dict_copy = foo_dict.copy()
  91. self.assertEqual(foo_dict_copy.asdict(), foo_dict.asdict())
  92. self.assertIsNot(foo_dict_copy, foo_dict)
  93. # test __delitem__
  94. del foo_dict["d", "g", "h", "i"]
  95. del desired_dict["d"]["g"]
  96. self.assertNotEqual(foo_dict.asdict(), desired_dict)
  97. del desired_dict["d"]
  98. self.assertEqual(foo_dict.asdict(), desired_dict)
  99. def test_filter(self):
  100. dict1 = NestedDict(
  101. [
  102. (("foo", "a"), 10),
  103. (("foo", "b"), 11),
  104. (("bar", "c"), 11),
  105. (("bar", "a"), 110),
  106. ]
  107. )
  108. dict2 = NestedDict([("foo", NestedDict(dict(a=33)))])
  109. dict3 = NestedDict(
  110. [("foo", NestedDict(dict(a=None))), ("bar", NestedDict(dict(d=None)))]
  111. )
  112. dict4 = NestedDict(
  113. [("foo", NestedDict(dict(a=None))), ("bar", NestedDict(dict(c=None)))]
  114. )
  115. self.assertEqual(dict1.filter(dict2).asdict(), {"foo": {"a": 10}})
  116. self.assertEqual(
  117. dict1.filter(dict4).asdict(), {"bar": {"c": 11}, "foo": {"a": 10}}
  118. )
  119. self.assertRaises(KeyError, lambda: dict1.filter(dict3).asdict())
  120. self.assertEqual(
  121. dict1.filter(dict3, ignore_missing=True).asdict(), {"foo": {"a": 10}}
  122. )
  123. def test_init(self):
  124. # test init with list
  125. foo_dict = NestedDict([(("a", "b"), 1), (("a", "c"), 2)])
  126. self.assertEqual(foo_dict.asdict(), {"a": {"b": 1, "c": 2}})
  127. # test init with dict
  128. foo_dict = NestedDict({"a": {"b": 1, "c": 2}})
  129. self.assertEqual(foo_dict.asdict(), {"a": {"b": 1, "c": 2}})
  130. # test init with NestedDict
  131. foo_dict = NestedDict(NestedDict({"a": {"b": 1, "c": 2}}))
  132. self.assertEqual(foo_dict.asdict(), {"a": {"b": 1, "c": 2}})
  133. # test init empty element
  134. foo_dict = NestedDict({"a": {}})
  135. self.assertEqual(foo_dict.asdict(), {"a": {}})
  136. # test init with nested empty element
  137. foo_dict = NestedDict({"a": {"b": {}, "c": 2}})
  138. self.assertEqual(foo_dict.asdict(), {"a": {"b": {}, "c": 2}})
  139. # test init with empty dict
  140. self.assertEqual(NestedDict().asdict(), {})
  141. if __name__ == "__main__":
  142. import pytest
  143. import sys
  144. sys.exit(pytest.main(["-v", __file__]))