lambda_defaultdict.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from collections import defaultdict
  2. from typing import Any, Callable
  3. class LambdaDefaultDict(defaultdict):
  4. """A defaultdict that creates default values based on the associated key.
  5. Note that the standard defaultdict can only produce default values (via its factory)
  6. that are independent of the key under which they are stored.
  7. As opposed to that, the lambda functions used as factories for this
  8. `LambdaDefaultDict` class do accept a single argument: The missing key.
  9. If a missing key is accessed by the user, the provided lambda function is called
  10. with this missing key as its argument. The returned value is stored in the
  11. dictionary under that key and returned.
  12. Example:
  13. In this example, if you try to access a key that doesn't exist, it will call
  14. the lambda function, passing it the missing key. The function will return a
  15. string, which will be stored in the dictionary under that key.
  16. .. testcode::
  17. from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
  18. default_dict = LambdaDefaultDict(lambda missing_key: f"Value for {missing_key}")
  19. print(default_dict["a"])
  20. .. testoutput::
  21. Value for a
  22. """ # noqa: E501
  23. def __init__(self, default_factory: Callable[[str], Any], *args, **kwargs):
  24. """Initializes a LambdaDefaultDict instance.
  25. Args:
  26. default_factory: The default factory callable, taking a string (key)
  27. and returning the default value to use for that key.
  28. """
  29. if not callable(default_factory):
  30. raise TypeError("First argument must be a Callable!")
  31. # We will handle the factory in __missing__ method.
  32. super().__init__(None, *args, **kwargs)
  33. self.default_factory = default_factory
  34. def __missing__(self, key):
  35. # Call default factory with the key as argument.
  36. self[key] = value = self.default_factory(key)
  37. return value