123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336 |
- """Custom NestedDict datatype."""
- from collections import abc
- import itertools
- from typing import (
- AbstractSet,
- Any,
- Dict,
- Generic,
- Iterable,
- Iterator,
- Mapping,
- MutableMapping,
- Optional,
- Sequence,
- Tuple,
- TypeVar,
- Union,
- )
- from ray.rllib.utils.annotations import ExperimentalAPI
# A key into a NestedDict: one string, or a sequence of strings forming a path.
SeqStrType = Union[str, Sequence[str]]

# Type variable for the leaf values stored in a NestedDict.
T = TypeVar("T")

# Plain nested-dict form, as produced by `NestedDict.asdict()`.
_NestedDictType = Dict[str, Any]

# Any mapping whose keys are strings or string paths.
_NestedMappingType = Mapping[SeqStrType, Any]

# Everything the NestedDict constructor accepts: an iterable of (key, value)
# pairs, a mapping, or another NestedDict.
NestedDictInputType = Union[
    Iterable[Tuple[SeqStrType, T]],
    _NestedMappingType,
    "NestedDict[T]",
]
def _flatten_index(index: SeqStrType) -> Sequence[str]:
    """Flatten a key into a flat tuple of strings.

    A single string becomes a one-element tuple. Any other input is iterated
    and each element is flattened recursively, so nested sequences of strings
    collapse into one flat path.
    """
    # Strings are the leaves of the recursion; they are never iterated.
    if isinstance(index, str):
        return (index,)

    flat_path = []
    for part in index:
        flat_path.extend(_flatten_index(part))
    return tuple(flat_path)
@ExperimentalAPI
class StrKey(str):
    """A `str` subclass that can be ordered against strings and string sequences.

    Only `<` and `>` are overridden. Against a string the comparison is the
    ordinary string comparison; against anything else, this key is treated as
    a one-element tuple and compared with `tuple(other)`. This is needed for
    the tree functions to work.
    """

    def __lt__(self, other: SeqStrType):
        # Non-string operands are compared as tuples: (self,) vs. tuple(other).
        return str(self) < other if isinstance(other, str) else (self,) < tuple(other)

    def __gt__(self, other: SeqStrType):
        # Mirror image of __lt__.
        return str(self) > other if isinstance(other, str) else (self,) > tuple(other)
@ExperimentalAPI
class NestedDict(Generic[T], MutableMapping[str, Union[T, "NestedDict"]]):
    """A dict with special properties to support partial indexing.

    The main properties of NestedDict are:

    * Nested elements are addressed with a string or a sequence of strings,
      e.g. ``d["a", "b"]`` addresses the entry ``"b"`` below ``"a"``.
    * A NestedDict can be used to filter a superset into a subset of nested
      elements with the `filter` method.
    * It can be instantiated from any mapping of strings, or from an iterable
      of key/value tuples, where the values can themselves recursively be
      anything a NestedDict can be built from.
    * The length of a NestedDict is the number of leaves in the tree,
      excluding empty leaves.
    * Iterating over a NestedDict yields the path (a tuple of strings) of
      every leaf of the tree, including empty leaves.

    Args:
        x: a representation of a NestedDict: it can be an iterable of
            `SeqStrType` to values, e.g. `[(("a", "b"), 1), ("b", 2)]`, or a
            mapping of flattened keys to values, e.g.
            `{("a", "b"): 1, ("b",): 2}`, or any nested mapping, e.g.
            `{"a": {"b": 1}, "b": {}}`. If `x` is itself a NestedDict, the new
            object shares its storage with `x` (no copy is made).

    Example:
        Basic usage:
            >>> foo_dict = NestedDict()
            >>> # Setting elements, possibly nested:
            >>> foo_dict['a'] = 100       # {'a': 100}
            >>> foo_dict['b', 'c'] = 200  # {'a': 100, 'b': {'c': 200}}
            >>> foo_dict['b', 'd'] = 300  # {'a': 100, 'b': {'c': 200, 'd': 300}}
            >>> # An empty mapping is stored as an empty leaf:
            >>> foo_dict['b', 'e'] = {}   # {'a': 100,
            >>>                           #  'b': {'c': 200, 'd': 300, 'e': {}}}
            >>> # Getting elements, possibly nested:
            >>> print(foo_dict['b', 'c'])   # 200
            >>> print(foo_dict['b'])        # {'c': 200, 'd': 300, 'e': {}}
            >>> print(foo_dict.get('b'))    # {'c': 200, 'd': 300, 'e': {}}
            >>> print(foo_dict)  # {'a': 100, 'b': {'c': 200, 'd': 300, 'e': {}}}
            >>> # Converting to a dict:
            >>> foo_dict.asdict()  # {'a': 100, 'b': {'c': 200, 'd': 300, 'e': {}}}
            >>> # len function (empty leaves are not counted):
            >>> print(len(foo_dict))  # 3
            >>> # Iterating (every key is a tuple; empty leaves are included):
            >>> list(foo_dict.keys())
            >>> # [('a',), ('b', 'c'), ('b', 'd'), ('b', 'e')]
            >>> list(foo_dict.items())
            >>> # [(('a',), 100), (('b', 'c'), 200), (('b', 'd'), 300),
            >>> #  (('b', 'e'), NestedDict({}))]
            >>> foo_dict.shallow_keys()  # dict_keys(['a', 'b'])

        Filter:
            >>> dict1 = NestedDict([
            ...     (('foo', 'a'), 10), (('foo', 'b'), 11),
            ...     (('bar', 'c'), 11), (('bar', 'a'), 110)])
            >>> dict2 = NestedDict([('foo', NestedDict(dict(a=11)))])
            >>> dict3 = NestedDict([('foo', NestedDict(dict(a=100))),
            ...                     ('bar', NestedDict(dict(d=11)))])
            >>> dict4 = NestedDict([('foo', NestedDict(dict(a=100))),
            ...                     ('bar', NestedDict(dict(c=11)))])
            >>> dict1.filter(dict2).asdict()  # {'foo': {'a': 10}}
            >>> dict1.filter(dict4).asdict()  # {'foo': {'a': 10}, 'bar': {'c': 11}}
            >>> dict1.filter(dict3).asdict()  # KeyError - ('bar', 'd') not in dict1
    """

    def __init__(
        self,
        x: Optional[NestedDictInputType] = None,
    ):
        # Only the top level is stored here; deeper levels are NestedDict
        # values held inside this dict.
        self._data = dict()  # type: Dict[str, Union[T, NestedDict[T]]]
        if x is None:
            return
        if isinstance(x, NestedDict):
            # Zero-copy: share the storage of the given NestedDict.
            self._data = x._data
        elif isinstance(x, abc.Mapping):
            for k in x:
                self[k] = x[k]
        elif isinstance(x, abc.Iterable):
            for k, v in x:
                self[k] = v
        else:
            raise ValueError(f"Input must be a Mapping or Iterable, got {type(x)}.")

    def __contains__(self, k: SeqStrType) -> bool:
        """Returns true if the key is in the NestedDict."""
        node = self._data  # type: Any
        for key in _flatten_index(k):
            # Unwrap manually to avoid recursing into __contains__.
            if isinstance(node, NestedDict):
                node = node._data
            if not isinstance(node, abc.Mapping) or key not in node:
                return False
            node = node[key]
        return True

    def get(
        self, k: SeqStrType, default: Optional[T] = None
    ) -> Union[T, "NestedDict[T]"]:
        """Returns `self[k]`, with partial indexing allowed.

        If `k` is not in the `NestedDict`, returns default. If default is `None`,
        and `k` is not in the `NestedDict`, a `KeyError` is raised.

        Args:
            k: The key to get. This can be a string or a sequence of strings.
            default: The default value to return if `k` is not in the
                `NestedDict`. If default is `None`, and `k` is not in the
                `NestedDict`, a `KeyError` is raised.

        Returns:
            The value of `self[k]`.

        Raises:
            KeyError: if `k` is not in the `NestedDict` and default is None.
        """
        path = _flatten_index(k)
        node = self._data  # type: Any
        # Single walk: membership is checked while descending, so the tree is
        # not traversed twice.
        for key in path:
            # Unwrap manually to avoid recursing into __getitem__.
            if isinstance(node, NestedDict):
                node = node._data
            if not isinstance(node, abc.Mapping) or key not in node:
                if default is not None:
                    return default
                raise KeyError(path)
            node = node[key]
        return node

    def __getitem__(self, k: SeqStrType) -> T:
        """Returns the item at `k`; raises `KeyError` if it does not exist."""
        return self.get(k)

    def __setitem__(self, k: SeqStrType, v: Union[T, _NestedMappingType]) -> None:
        """Sets item at `k` to `v`.

        This is a zero-copy operation. The pointer to value is preserved in the
        internal data structure. Mapping values are wrapped in an instance of
        this class; missing intermediate levels are created on the way.

        Raises:
            IndexError: if `k` is empty or flattens to an empty path.
        """
        if not k:
            raise IndexError(
                f"Key for {self.__class__.__name__} cannot be empty. Got {k}."
            )
        path = _flatten_index(k)
        if not path:
            # E.g. k == ((),): there is no position to write to. Fail loudly
            # instead of silently dropping the value.
            raise IndexError(
                f"Key for {self.__class__.__name__} cannot be empty. Got {k}."
            )
        v = self.__class__(v) if isinstance(v, abc.Mapping) else v

        node = self._data  # type: Any
        for key in path[:-1]:
            # Unwrap manually to avoid recursing into __setitem__.
            if isinstance(node, NestedDict):
                node = node._data
            if key not in node:
                node[key] = self.__class__()
            node = node[key]
        if isinstance(node, NestedDict):
            node = node._data
        node[path[-1]] = v

    def __iter__(self) -> Iterator[SeqStrType]:
        """Iterate over NestedDict, returning tuples of paths.

        Every iteration yields a tuple of strings, with each element of
        such a tuple representing a branch in the NestedDict. Each yielded tuple
        represents the path to a leaf. This includes leafs that are empty dicts.

        For example, if the NestedDict is: {'a': {'b': 1, 'c': {}}}, then this
        iterator will yield: ('a', 'b'), ('a', 'c').
        """
        # Depth-first, pre-order walk. Entries are pushed in reverse so that
        # popping from the end of the list visits them in insertion order.
        stack = [((StrKey(key),), value) for key, value in self._data.items()]
        stack.reverse()
        while stack:
            path, value = stack.pop()
            if isinstance(value, NestedDict) and value._data:
                children = [
                    (path + (StrKey(child_key),), child)
                    for child_key, child in value._data.items()
                ]
                stack.extend(reversed(children))
            else:
                # A regular value or an empty NestedDict: both are leaves.
                yield path

    def __delitem__(self, k: SeqStrType) -> None:
        """Deletes item at `k`.

        Parent levels that become empty through the deletion are removed too.

        Raises:
            IndexError: if `k` flattens to an empty path.
            KeyError: if `k` is not in the `NestedDict`.
        """
        path = _flatten_index(k)
        if not path:
            raise IndexError(
                f"Key for {self.__class__.__name__} cannot be empty. Got {k}."
            )

        # (container, key) for every level visited on the way down.
        trail = []
        node = self._data  # type: Any
        for depth, key in enumerate(path):
            if isinstance(node, NestedDict):
                node = node._data
            if not isinstance(node, abc.Mapping) or key not in node:
                raise KeyError(str(list(path[: depth + 1])))
            trail.append((node, key))
            node = node[key]

        container, last_key = trail.pop()
        del container[last_key]

        # Walk back up and drop every level that just became empty. Once a
        # non-empty level is found, all of its ancestors are non-empty too.
        child = container
        for parent, key in reversed(trail):
            if child:
                break
            del parent[key]
            child = parent

    def __len__(self) -> int:
        """Returns the length of the NestedDict.

        The length is defined as the number of leaf nodes in the `NestedDict` that
        are not of type Mapping. For example, if the `NestedDict` is: {'a': {'b': 1,
        'c': {}}}, then the length is 1.
        """
        # DFS over the values, counting everything that is not a mapping.
        count = 0
        stack = [self._data]  # type: list
        while stack:
            node = stack.pop()
            if isinstance(node, NestedDict):
                node = node._data
            if isinstance(node, abc.Mapping):
                stack.extend(node.values())
            else:
                count += 1
        return count

    def __str__(self) -> str:
        return str(self.asdict())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._data!r})"

    def filter(
        self,
        other: Union[Sequence[SeqStrType], "NestedDict"],
        ignore_missing: bool = False,
    ) -> "NestedDict[T]":
        """Returns a NestedDict with only entries present in `other`.

        The values in the `other` NestedDict are ignored. Only the keys are used.

        Args:
            other: a NestedDict or a sequence of keys to filter by.
            ignore_missing: if True, ignore missing keys in `other`.

        Returns:
            A NestedDict with only keys present in `other`.

        Raises:
            KeyError: if a key of `other` is not in this NestedDict and
                `ignore_missing` is False.
        """
        output = self.__class__()
        keys = other if isinstance(other, abc.Sequence) else other.keys()
        for k in keys:
            if k in self:
                output[k] = self.get(k)
            elif not ignore_missing:
                raise KeyError(k)
        return output

    def asdict(self) -> _NestedDictType:
        """Returns a dictionary representation of the NestedDict."""
        return {
            k: v.asdict() if isinstance(v, NestedDict) else v
            for k, v in self._data.items()
        }

    def copy(self) -> "NestedDict[T]":
        """Returns a shallow copy of the NestedDict.

        The nested structure (every level, including empty leaves) is rebuilt,
        so adding or removing keys in the copy never affects the original. The
        leaf values themselves are shared, not copied. The copy has the same
        class as the object it was made from.
        """
        duplicate = self.__class__()
        for key, value in self._data.items():
            duplicate._data[key] = (
                value.copy() if isinstance(value, NestedDict) else value
            )
        return duplicate

    def __copy__(self) -> "NestedDict[T]":
        return self.copy()

    def shallow_keys(self) -> AbstractSet[str]:
        """Returns a set of the keys at the top level of the NestedDict."""
        return self._data.keys()
|