nested_dict.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. """Custom NestedDict datatype."""
  2. from collections import abc
  3. import itertools
  4. from typing import (
  5. AbstractSet,
  6. Any,
  7. Dict,
  8. Generic,
  9. Iterable,
  10. Iterator,
  11. Mapping,
  12. MutableMapping,
  13. Optional,
  14. Sequence,
  15. Tuple,
  16. TypeVar,
  17. Union,
  18. )
  19. from ray.rllib.utils.annotations import ExperimentalAPI
  20. SeqStrType = Union[str, Sequence[str]]
  21. T = TypeVar("T")
  22. _NestedDictType = Dict[str, Any]
  23. _NestedMappingType = Mapping[SeqStrType, Any]
  24. NestedDictInputType = Union[
  25. Iterable[Tuple[SeqStrType, T]], _NestedMappingType, "NestedDict[T]"
  26. ]
  27. def _flatten_index(index: SeqStrType) -> Sequence[str]:
  28. if isinstance(index, str):
  29. return (index,)
  30. else:
  31. return tuple(itertools.chain.from_iterable([_flatten_index(y) for y in index]))
  32. @ExperimentalAPI
  33. class StrKey(str):
  34. """A string that can be compared to a string or sequence of strings representing a
  35. SeqStrType. This is needed for the tree functions to work.
  36. """
  37. def __lt__(self, other: SeqStrType):
  38. if isinstance(other, str):
  39. return str(self) < other
  40. else:
  41. return (self,) < tuple(other)
  42. def __gt__(self, other: SeqStrType):
  43. if isinstance(other, str):
  44. return str(self) > other
  45. else:
  46. return (self,) > tuple(other)
  47. @ExperimentalAPI
  48. class NestedDict(Generic[T], MutableMapping[str, Union[T, "NestedDict"]]):
  49. """A dict with special properties to support partial indexing.
  50. The main properties of NestedDict are::
  51. * The NestedDict gives access to nested elements as a sequence of
  52. strings.
  53. * These NestedDicts can also be used to filter a superset into a subset of
  54. nested elements with the filter function.
  55. * This can be instantiated with any mapping of strings, or an iterable of
  56. key value tuples where the values can themselves be recursively the values
  57. that a NestedDict can take.
  58. * The length of a NestedDict is the number of leaves in the tree, excluding
  59. empty leafs.
  60. * Iterating over a NestedDict yields the leaves of the tree, including empty
  61. leafs.
  62. Args:
  63. x: a representation of a NestedDict: it can be an iterable of `SeqStrType`
  64. to values. e.g. `[(("a", "b") , 1), ("b", 2)]` or a mapping of flattened
  65. keys to values. e.g. `{("a", "b"): 1, ("b",): 2}` or any nested mapping,
  66. e.g. `{"a": {"b": 1}, "b": {}}`.
  67. Example:
  68. Basic usage:
  69. >>> foo_dict = NestedDict()
  70. >>> # Setting elements, possibly nested:
  71. >>> foo_dict['a'] = 100 # foo_dict = {'a': 100}
  72. >>> foo_dict['b', 'c'] = 200 # foo_dict = {'a': 100, 'b': {'c': 200}}
  73. >>> foo_dict['b', 'd'] = 300 # foo_dict = {'a': 100,
  74. >>> # 'b': {'c': 200, 'd': 300}}
  75. >>> foo_dict['b', 'e'] = {} # foo_dict = {'a': 100,
  76. >>> # 'b': {'c': 200, 'd': 300}}
  77. >>> # Getting elements, possibly nested:
  78. >>> print(foo_dict['b', 'c']) # 200
  79. >>> print(foo_dict['b']) # {'c': 200, 'd': 300}
  80. >>> print(foo_dict.get('b')) # {'c': 200, 'd': 300}
  81. >>> print(foo_dict) # {'a': 100, 'b': {'c': 200, 'd': 300}}
  82. >>> # Converting to a dict:
  83. >>> foo_dict.asdict() # {'a': 100, 'b': {'c': 200, 'd': 300}}
  84. >>> # len function:
  85. >>> print(len(foo_dict)) # 3
  86. >>> # Iterating:
  87. >>> foo_dict.keys() # dict_keys(['a', ('b', 'c'), ('b', 'd')])
  88. >>> foo_dict.items() # dict_items([('a', 100), (('b', 'c'), 200), (('b', 'd'), 300)])
  89. >>> foo_dict.shallow_keys() # dict_keys(['a', 'b'])
  90. Filter:
  91. >>> dict1 = NestedDict([
  92. (('foo', 'a'), 10), (('foo', 'b'), 11),
  93. (('bar', 'c'), 11), (('bar', 'a'), 110)])
  94. >>> dict2 = NestedDict([('foo', NestedDict(dict(a=11)))])
  95. >>> dict3 = NestedDict([('foo', NestedDict(dict(a=100))),
  96. ('bar', NestedDict(dict(d=11)))])
  97. >>> dict4 = NestedDict([('foo', NestedDict(dict(a=100))),
  98. ('bar', NestedDict(dict(c=11)))])
  99. >>> dict1.filter(dict2).asdict() # {'foo': {'a': 10}}
  100. >>> dict1.filter(dict4).asdict() # {'bar': {'c': 11}, 'foo': {'a': 10}}
  101. >>> dict1.filter(dict3).asdict() # KeyError - ('bar', 'd') not in dict1
  102. """ # noqa: E501
  103. def __init__(
  104. self,
  105. x: Optional[NestedDictInputType] = None,
  106. ):
  107. # shallow dict
  108. self._data = dict() # type: Dict[str, Union[T, NestedDict[T]]]
  109. x = x if x is not None else {}
  110. if isinstance(x, NestedDict):
  111. self._data = x._data
  112. elif isinstance(x, abc.Mapping):
  113. for k in x:
  114. self[k] = x[k]
  115. elif isinstance(x, abc.Iterable):
  116. for k, v in x:
  117. self[k] = v
  118. else:
  119. raise ValueError(f"Input must be a Mapping or Iterable, got {type(x)}.")
  120. def __contains__(self, k: SeqStrType) -> bool:
  121. """Returns true if the key is in the NestedDict."""
  122. k = _flatten_index(k)
  123. data_ptr = self._data # type: Dict[str, Any]
  124. for key in k:
  125. # this is to avoid the recursion on __contains__
  126. if isinstance(data_ptr, NestedDict):
  127. data_ptr = data_ptr._data
  128. if not isinstance(data_ptr, Mapping) or key not in data_ptr:
  129. return False
  130. data_ptr = data_ptr[key]
  131. return True
  132. def get(
  133. self, k: SeqStrType, default: Optional[T] = None
  134. ) -> Union[T, "NestedDict[T]"]:
  135. """Returns `self[k]`, with partial indexing allowed.
  136. If `k` is not in the `NestedDict`, returns default. If default is `None`,
  137. and `k` is not in the `NestedDict`, a `KeyError` is raised.
  138. Args:
  139. k: The key to get. This can be a string or a sequence of strings.
  140. default: The default value to return if `k` is not in the `NestedDict`. If
  141. default is `None`, and `k` is not in the `NestedDict`, a `KeyError` is
  142. raised.
  143. Returns:
  144. The value of `self[k]`.
  145. Raises:
  146. KeyError: if `k` is not in the `NestedDict` and default is None.
  147. """
  148. k = _flatten_index(k)
  149. if k not in self:
  150. if default is not None:
  151. return default
  152. else:
  153. raise KeyError(k)
  154. data_ptr = self._data
  155. for key in k:
  156. # This is to avoid the recursion on __getitem__
  157. if isinstance(data_ptr, NestedDict):
  158. data_ptr = data_ptr._data
  159. data_ptr = data_ptr[key]
  160. return data_ptr
  161. def __getitem__(self, k: SeqStrType) -> T:
  162. output = self.get(k)
  163. return output
  164. def __setitem__(self, k: SeqStrType, v: Union[T, _NestedMappingType]) -> None:
  165. """Sets item at `k` to `v`.
  166. This is a zero-copy operation. The pointer to value if preserved in the
  167. internal data structure.
  168. """
  169. if not k:
  170. raise IndexError(
  171. f"Key for {self.__class__.__name__} cannot be empty. Got {k}."
  172. )
  173. k = _flatten_index(k)
  174. v = self.__class__(v) if isinstance(v, Mapping) else v
  175. data_ptr = self._data
  176. for k_indx, key in enumerate(k):
  177. # this is done to avoid recursion over __setitem__
  178. if isinstance(data_ptr, NestedDict):
  179. data_ptr = data_ptr._data
  180. if k_indx == len(k) - 1:
  181. data_ptr[key] = v
  182. elif key not in data_ptr:
  183. data_ptr[key] = self.__class__()
  184. data_ptr = data_ptr[key]
  185. def __iter__(self) -> Iterator[SeqStrType]:
  186. """Iterate over NestedDict, returning tuples of paths.
  187. Every iteration yields a tuple of strings, with each element of
  188. such a tuple representing a branch in the NestedDict. Each yielded tuple
  189. represents the path to a leaf. This includes leafs that are empty dicts.
  190. For example, if the NestedDict is: {'a': {'b': 1, 'c': {}}}, then this
  191. iterator will yield: ('a', 'b'), ('a', 'c').
  192. """
  193. data_ptr = self._data
  194. # do a DFS to get all the keys
  195. stack = [((StrKey(k),), v) for k, v in data_ptr.items()]
  196. while stack:
  197. k, v = stack.pop(0)
  198. if isinstance(v, NestedDict):
  199. if len(v._data) == 0:
  200. yield tuple(k)
  201. else:
  202. stack = [
  203. (k + (StrKey(k2),), v) for k2, v in v._data.items()
  204. ] + stack
  205. else:
  206. yield tuple(k)
  207. def __delitem__(self, k: SeqStrType) -> None:
  208. """Deletes item at `k`."""
  209. ks, ns = [], []
  210. data_ptr = self._data
  211. for k in _flatten_index(k):
  212. if isinstance(data_ptr, NestedDict):
  213. data_ptr = data_ptr._data
  214. if k not in data_ptr:
  215. raise KeyError(str(ks + [k]))
  216. ks.append(k)
  217. ns.append(data_ptr)
  218. data_ptr = data_ptr[k]
  219. del ns[-1][ks[-1]]
  220. for i in reversed(range(len(ks) - 1)):
  221. if not ns[i + 1]:
  222. del ns[i][ks[i]]
  223. def __len__(self) -> int:
  224. """Returns the length of the NestedDict.
  225. The length is defined as the number of leaf nodes in the `NestedDict` that
  226. are not of type Mapping. For example, if the `NestedDict` is: {'a': {'b': 1,
  227. 'c': {}}}, then the length is 1.
  228. """
  229. # do a DFS to count the number of leaf nodes
  230. count = 0
  231. stack = [self._data]
  232. while stack:
  233. node = stack.pop()
  234. if isinstance(node, NestedDict):
  235. node = node._data
  236. if isinstance(node, Mapping):
  237. stack.extend(node.values())
  238. else:
  239. count += 1
  240. return count
  241. def __str__(self) -> str:
  242. return str(self.asdict())
  243. def __repr__(self) -> str:
  244. return f"NestedDict({repr(self._data)})"
  245. def filter(
  246. self,
  247. other: Union[Sequence[SeqStrType], "NestedDict"],
  248. ignore_missing: bool = False,
  249. ) -> "NestedDict[T]":
  250. """Returns a NestedDict with only entries present in `other`.
  251. The values in the `other` NestedDict are ignored. Only the keys are used.
  252. Args:
  253. other: a NestedDict or a sequence of keys to filter by.
  254. ignore_missing: if True, ignore missing keys in `other`.
  255. Returns:
  256. A NestedDict with only keys present in `other`.
  257. """
  258. output = self.__class__()
  259. if isinstance(other, Sequence):
  260. keys = other
  261. else:
  262. keys = other.keys()
  263. for k in keys:
  264. if k not in self:
  265. if not ignore_missing:
  266. raise KeyError(k)
  267. else:
  268. output[k] = self.get(k)
  269. return output
  270. def asdict(self) -> _NestedDictType:
  271. """Returns a dictionary representation of the NestedDict."""
  272. output = dict()
  273. for k, v in self._data.items():
  274. if isinstance(v, NestedDict):
  275. output[k] = v.asdict()
  276. else:
  277. output[k] = v
  278. return output
  279. def copy(self) -> "NestedDict[T]":
  280. """Returns a shallow copy of the NestedDict."""
  281. return NestedDict(self.items())
  282. def __copy__(self) -> "NestedDict[T]":
  283. return self.copy()
  284. def shallow_keys(self) -> AbstractSet[str]:
  285. """Returns a set of the keys at the top level of the NestedDict."""
  286. return self._data.keys()