12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- from collections import defaultdict
- from typing import Any, Callable
class LambdaDefaultDict(defaultdict):
    """A defaultdict that creates default values based on the associated key.

    Note that the standard defaultdict can only produce default values (via its factory)
    that are independent of the key under which they are stored.
    As opposed to that, the lambda functions used as factories for this
    `LambdaDefaultDict` class do accept a single argument: The missing key.
    If a missing key is accessed by the user, the provided lambda function is called
    with this missing key as its argument. The returned value is stored in the
    dictionary under that key and returned.

    As with the standard defaultdict, setting the `default_factory` attribute to
    None after construction disables the automatic creation of values: Accessing a
    missing key then raises a KeyError.

    Example:

        In this example, if you try to access a key that doesn't exist, it will call
        the lambda function, passing it the missing key. The function will return a
        string, which will be stored in the dictionary under that key.

        .. testcode::

            from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict

            default_dict = LambdaDefaultDict(lambda missing_key: f"Value for {missing_key}")
            print(default_dict["a"])

        .. testoutput::

            Value for a
    """  # noqa: E501

    def __init__(self, default_factory: Callable[[str], Any], *args, **kwargs):
        """Initializes a LambdaDefaultDict instance.

        Args:
            default_factory: The default factory callable, taking a string (key)
                and returning the default value to use for that key.
            *args: Positional arguments used to populate the dictionary (same as
                for `dict`).
            **kwargs: Keyword arguments used to populate the dictionary (same as
                for `dict`).

        Raises:
            TypeError: If `default_factory` is not callable.
        """
        # Unlike the standard defaultdict, None is not accepted as a factory at
        # construction time.
        if not callable(default_factory):
            raise TypeError("First argument must be a Callable!")
        # The parent class only stores the factory in `self.default_factory`; the
        # actual (key-aware) call happens in our `__missing__` override.
        super().__init__(default_factory, *args, **kwargs)

    def __missing__(self, key):
        """Creates, stores, and returns the default value for a missing `key`.

        Args:
            key: The key that was looked up via `self[key]`, but not found.

        Returns:
            The value returned by `self.default_factory(key)`, which is also
            stored in the dictionary under `key`.

        Raises:
            KeyError: If `self.default_factory` has been set to None.
        """
        factory = self.default_factory
        # Honor the defaultdict contract: No factory means no auto-creation.
        if factory is None:
            raise KeyError(key)
        # Call default factory with the key as argument.
        value = factory(key)
        self[key] = value
        return value
|