123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
- import base64
- import importlib
- import io
- import zlib
- from typing import Any, Dict, Optional, Sequence, Type, Union
- import numpy as np
- import ray
- from ray.rllib.utils.annotations import DeveloperAPI
- from ray.rllib.utils.gym import try_import_gymnasium_and_gym
- from ray.rllib.utils.error import NotSerializable
- from ray.rllib.utils.spaces.flexdict import FlexDict
- from ray.rllib.utils.spaces.repeated import Repeated
- from ray.rllib.utils.spaces.simplex import Simplex
# Sentinel string; presumably used elsewhere in the package to mark values that
# could not be serialized (not referenced in the functions below).
NOT_SERIALIZABLE = "__not_serializable__"

# `gym` is the primary (gymnasium) module; `old_gym` is the legacy gym package,
# which may be falsy (e.g. when it is not installed) -- hence the guard below.
gym, old_gym = try_import_gymnasium_and_gym()

# Not every legacy gym release ships a `Text` space, so look it up defensively.
old_gym_text_class = getattr(old_gym.spaces, "Text", None) if old_gym else None
@DeveloperAPI
def convert_numpy_to_python_primitives(obj: Any):
    """Convert an object that is a numpy type to a python type.

    If the object is not a numpy type, it is returned unchanged.

    Args:
        obj: The object to convert.

    Returns:
        - `int` / `float` / `bool` / `str` for numpy integer / floating / bool /
          str scalars.
        - For an `np.ndarray`, the result of `tolist()` with every contained
          element converted as well: a (possibly nested) python list for arrays
          of rank >= 1, or a single python value for a 0-d array.
        - `obj` itself for anything else.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.str_):
        return str(obj)
    elif isinstance(obj, np.ndarray):

        def _walk(value):
            # `tolist()` nests lists once per array dimension. Object-dtype arrays
            # can still hold numpy values at any depth, so recurse into lists and
            # convert the leaves.
            if isinstance(value, list):
                return [_walk(v) for v in value]
            return convert_numpy_to_python_primitives(value)

        # For a 0-d array `tolist()` returns a bare scalar (not a list), which
        # `_walk` handles without trying to iterate it.
        return _walk(obj.tolist())
    else:
        return obj
def _serialize_ndarray(array: np.ndarray) -> str:
    """Encode a numpy array as an ASCII string suitable for JSON.

    The array is written in numpy's `.npy` format via `np.save()` (rather than
    with pickle, for compatibility). The bytes are then zlib-compressed and
    Base64-encoded.

    Args:
        array: The numpy ndarray to encode.

    Returns:
        The Base64 (ASCII) text of the compressed `.npy` bytes.
    """
    with io.BytesIO() as stream:
        np.save(stream, array)
        npy_bytes = stream.getvalue()
    compressed = zlib.compress(npy_bytes)
    return base64.b64encode(compressed).decode("ascii")
def _deserialize_ndarray(b64_string: str) -> np.ndarray:
    """Decode a string produced by `_serialize_ndarray` back into an ndarray.

    The steps are: Base64-decode, zlib-decompress, then read the bytes as `.npy`
    data. `np.load` is called with its default `allow_pickle` setting.

    Args:
        b64_string: Base64 text of zlib-compressed `.npy` bytes.

    Returns:
        The decoded numpy ndarray.
    """
    compressed = base64.b64decode(b64_string)
    npy_bytes = zlib.decompress(compressed)
    return np.load(io.BytesIO(npy_bytes))
@DeveloperAPI
def gym_space_to_dict(space: gym.spaces.Space) -> Dict:
    """Serialize a gym Space into a JSON-serializable dict.

    Supports gymnasium spaces, RLlib's custom spaces (Simplex, Repeated,
    FlexDict) and, when the legacy `gym` package is available, the corresponding
    legacy gym spaces. The keys written here are the ones read back by
    `gym_space_from_dict`.

    Args:
        space: The space to serialize.

    Returns:
        A dict whose "space" key names the space type; the remaining keys hold
        the data needed to re-construct the space. Sub-spaces of container
        spaces (Tuple, Dict, Repeated, FlexDict) are serialized recursively.

    Raises:
        ValueError: If `space` is of an unsupported type, or is a Text space
            without a character set.
    """

    def _box(sp: gym.spaces.Box) -> Dict:
        return {
            "space": "box",
            "low": _serialize_ndarray(sp.low),
            "high": _serialize_ndarray(sp.high),
            "shape": sp._shape,  # shape is a tuple.
            "dtype": sp.dtype.str,
        }

    def _discrete(sp: gym.spaces.Discrete) -> Dict:
        d = {
            "space": "discrete",
            "n": int(sp.n),
        }
        # Offset is a relatively new Discrete space feature.
        if hasattr(sp, "start"):
            d["start"] = int(sp.start)
        return d

    def _multi_binary(sp: gym.spaces.MultiBinary) -> Dict:
        return {
            "space": "multi-binary",
            "n": sp.n,
        }

    def _multi_discrete(sp: gym.spaces.MultiDiscrete) -> Dict:
        return {
            "space": "multi-discrete",
            "nvec": _serialize_ndarray(sp.nvec),
            "dtype": sp.dtype.str,
        }

    def _tuple(sp: gym.spaces.Tuple) -> Dict:
        return {
            "space": "tuple",
            "spaces": [gym_space_to_dict(sub) for sub in sp.spaces],
        }

    def _dict(sp: gym.spaces.Dict) -> Dict:
        return {
            "space": "dict",
            "spaces": {k: gym_space_to_dict(sub) for k, sub in sp.spaces.items()},
        }

    def _simplex(sp: Simplex) -> Dict:
        return {
            "space": "simplex",
            "shape": sp._shape,  # shape is a tuple.
            "concentration": sp.concentration,
            "dtype": sp.dtype.str,
        }

    def _repeated(sp: Repeated) -> Dict:
        return {
            "space": "repeated",
            "child_space": gym_space_to_dict(sp.child_space),
            "max_len": sp.max_len,
        }

    def _flex_dict(sp: FlexDict) -> Dict:
        d = {
            "space": "flex_dict",
        }
        # `sp.spaces` is a mapping (it is rebuilt from a dict on
        # de-serialization), so iterate its items; iterating the mapping itself
        # yields only keys, which cannot be unpacked into (key, space).
        for k, s in sp.spaces.items():
            d[k] = gym_space_to_dict(s)
        return d

    def _text(sp: "gym.spaces.Text") -> Dict:
        # Note (Kourosh): This only works in gym >= 0.25.0
        charset = getattr(sp, "character_set", None)
        if charset is None:
            charset = getattr(sp, "charset", None)
        if charset is None:
            raise ValueError(
                "Text space must have a character_set or charset attribute"
            )
        # A set-valued charset is not JSON-serializable. Store it as a sorted
        # string instead; Text's `charset` argument also accepts a string
        # (verify against the gym/gymnasium versions in use).
        if isinstance(charset, (set, frozenset)):
            charset = "".join(sorted(charset))
        return {
            "space": "text",
            "min_length": sp.min_length,
            "max_length": sp.max_length,
            "charset": charset,
        }

    if isinstance(space, gym.spaces.Box):
        return _box(space)
    elif isinstance(space, gym.spaces.Discrete):
        return _discrete(space)
    elif isinstance(space, gym.spaces.MultiBinary):
        return _multi_binary(space)
    elif isinstance(space, gym.spaces.MultiDiscrete):
        return _multi_discrete(space)
    elif isinstance(space, gym.spaces.Tuple):
        return _tuple(space)
    elif isinstance(space, gym.spaces.Dict):
        return _dict(space)
    elif isinstance(space, gym.spaces.Text):
        return _text(space)
    elif isinstance(space, Simplex):
        return _simplex(space)
    elif isinstance(space, Repeated):
        return _repeated(space)
    # NOTE(review): if FlexDict subclasses gym.spaces.Dict, the Dict branch above
    # matches first and this branch is never reached -- confirm.
    elif isinstance(space, FlexDict):
        return _flex_dict(space)
    # Old gym Spaces.
    elif old_gym and isinstance(space, old_gym.spaces.Box):
        return _box(space)
    elif old_gym and isinstance(space, old_gym.spaces.Discrete):
        return _discrete(space)
    elif old_gym and isinstance(space, old_gym.spaces.MultiBinary):
        return _multi_binary(space)
    elif old_gym and isinstance(space, old_gym.spaces.MultiDiscrete):
        return _multi_discrete(space)
    elif old_gym and isinstance(space, old_gym.spaces.Tuple):
        return _tuple(space)
    elif old_gym and isinstance(space, old_gym.spaces.Dict):
        return _dict(space)
    elif old_gym and old_gym_text_class and isinstance(space, old_gym_text_class):
        return _text(space)
    else:
        raise ValueError(f"Unknown space type for serialization, {type(space)}")
@DeveloperAPI
def space_to_dict(space: gym.spaces.Space) -> Dict:
    """Serialize a space together with its chain of `original_space` attributes.

    Args:
        space: The space to serialize. If its instance `__dict__` has an
            "original_space" entry, that space is serialized recursively as well.

    Returns:
        A dict `{"space": gym_space_to_dict(space)}`, plus an "original_space"
        entry (itself in this same nested format) when one is present.
    """
    serialized = dict(space=gym_space_to_dict(space))
    if "original_space" not in space.__dict__:
        return serialized
    serialized["original_space"] = space_to_dict(space.original_space)
    return serialized
@DeveloperAPI
def gym_space_from_dict(d: Dict) -> gym.spaces.Space:
    """De-serialize a dict into gym Space.

    The dict's "space" key selects the space type (see `space_map` below). The
    remaining keys hold that space's constructor data, as written by
    `gym_space_to_dict`.

    Args:
        d: A serialized space dict (e.g. loaded back from JSON).

    Returns:
        De-serialized gym space.

    Raises:
        ValueError: If `d["space"]` names an unknown space type.
    """

    def __common(d: Dict):
        """Common updates to the dict before we use it to construct spaces"""
        ret = d.copy()
        # The type tag is not a constructor argument.
        del ret["space"]
        if "dtype" in ret:
            # dtypes are stored as numpy dtype strings (e.g. "<f4").
            ret["dtype"] = np.dtype(ret["dtype"])
        return ret

    def _box(d: Dict) -> gym.spaces.Box:
        ret = d.copy()
        ret.update(
            {
                "low": _deserialize_ndarray(d["low"]),
                "high": _deserialize_ndarray(d["high"]),
            }
        )
        return gym.spaces.Box(**__common(ret))

    def _discrete(d: Dict) -> gym.spaces.Discrete:
        return gym.spaces.Discrete(**__common(d))

    def _multi_binary(d: Dict) -> gym.spaces.MultiBinary:
        return gym.spaces.MultiBinary(**__common(d))

    def _multi_discrete(d: Dict) -> gym.spaces.MultiDiscrete:
        ret = d.copy()
        ret.update(
            {
                "nvec": _deserialize_ndarray(ret["nvec"]),
            }
        )
        return gym.spaces.MultiDiscrete(**__common(ret))

    def _tuple(d: Dict) -> gym.spaces.Tuple:
        spaces = [gym_space_from_dict(sub) for sub in d["spaces"]]
        return gym.spaces.Tuple(spaces=spaces)

    def _dict(d: Dict) -> gym.spaces.Dict:
        spaces = {k: gym_space_from_dict(sub) for k, sub in d["spaces"].items()}
        return gym.spaces.Dict(spaces=spaces)

    def _simplex(d: Dict) -> Simplex:
        return Simplex(**__common(d))

    def _repeated(d: Dict) -> Repeated:
        child_space = gym_space_from_dict(d["child_space"])
        return Repeated(child_space=child_space, max_len=d["max_len"])

    def _flex_dict(d: Dict) -> FlexDict:
        # FlexDict sub-spaces are stored as top-level keys next to the type tag.
        spaces = {k: gym_space_from_dict(s) for k, s in d.items() if k != "space"}
        return FlexDict(spaces=spaces)

    def _text(d: Dict) -> "gym.spaces.Text":
        return gym.spaces.Text(**__common(d))

    space_map = {
        "box": _box,
        "discrete": _discrete,
        "multi-binary": _multi_binary,
        "multi-discrete": _multi_discrete,
        "tuple": _tuple,
        "dict": _dict,
        "simplex": _simplex,
        "repeated": _repeated,
        "flex_dict": _flex_dict,
        "text": _text,
    }

    space_type = d["space"]
    if space_type not in space_map:
        raise ValueError(f"Unknown space type for de-serialization, {space_type}")

    return space_map[space_type](d)
@DeveloperAPI
def space_from_dict(d: Dict) -> gym.spaces.Space:
    """Rebuild a space, and its `original_space` chain, from a serialized dict.

    Args:
        d: A dict with a "space" entry (a `gym_space_to_dict` result) and an
            optional "original_space" entry.

    Returns:
        The de-serialized space. If "original_space" was present, it is attached
        as the returned space's `original_space` attribute.
    """
    space = gym_space_from_dict(d["space"])
    if "original_space" not in d:
        return space

    serialized_original = d["original_space"]
    assert "space" in serialized_original
    # For backward compatibility: a plain string under "space" means the original
    # space was written by gym_space_to_dict. Otherwise it was written by
    # space_to_dict and may itself carry a nested original space.
    if isinstance(serialized_original["space"], str):
        rebuild = gym_space_from_dict
    else:
        rebuild = space_from_dict
    space.original_space = rebuild(serialized_original)
    return space
@DeveloperAPI
def check_if_args_kwargs_serializable(args: Sequence[Any], kwargs: Dict[str, Any]):
    """Check if parameters to a function are serializable by ray.

    Every value is round-tripped through Ray (`ray.put` followed by `ray.get`).
    A `TypeError` raised during that round trip marks the value as not
    serializable.

    Args:
        args: Positional arguments to be checked.
        kwargs: Keyword arguments to be checked.

    Raises:
        NotSerializable: If any of `args` or `kwargs` is not serializable by ray.
            The underlying `TypeError` is chained as the cause.
    """
    for arg in args:
        try:
            # If the object is truly serializable, we should be able to
            # ray.put and ray.get it.
            ray.get(ray.put(arg))
        except TypeError as e:
            raise NotSerializable(
                "RLModule constructor arguments must be serializable. "
                f"Found non-serializable argument: {arg}.\n"
                f"Original serialization error: {e}"
            ) from e
    for k, v in kwargs.items():
        try:
            # If the object is truly serializable, we should be able to
            # ray.put and ray.get it.
            ray.get(ray.put(v))
        except TypeError as e:
            raise NotSerializable(
                "RLModule constructor arguments must be serializable. "
                f"Found non-serializable keyword argument: {k} = {v}.\n"
                f"Original serialization error: {e}"
            ) from e
@DeveloperAPI
def serialize_type(type_: Union[Type, str]) -> str:
    """Converts a type into its full classpath ([module file] + "." + [class name]).

    Args:
        type_: The type to convert. A string is taken to be an already-serialized
            classpath and is returned unchanged.

    Returns:
        The full classpath of the given type, e.g. "ray.rllib.algorithms.ppo.PPOConfig".
    """
    # TODO (avnishn): find a way to incorporate the tune registry here.
    if not isinstance(type_, str):
        # `__qualname__` keeps the enclosing-class prefix of nested classes.
        return ".".join([type_.__module__, type_.__qualname__])
    # Already serialized.
    return type_
@DeveloperAPI
def deserialize_type(
    module: Union[str, Type], error: bool = False
) -> Optional[Union[str, Type]]:
    """Resolves a class path to a class.

    If the given module is already a class, it is returned as is.
    If the given module is a string, it is treated as a classpath
    ("pkg.module.ClassName"). Its module part is imported and the named
    attribute is returned. Nested classes ("pkg.module.Outer.Inner", as
    produced by `serialize_type` from `__qualname__`) are resolved as well.

    Args:
        module: The classpath (str) or type to resolve.
        error: Whether to throw a ValueError if `module` could not be resolved into
            a class. If False and `module` is not resolvable, the given string is
            returned unchanged (presumably so callers can still treat it as a
            registered name).

    Returns:
        The resolved class, or the unchanged `module` string (if `error` is False
        and no resolution was possible).

    Raises:
        ValueError: If `error` is True and `module` cannot be resolved, or if
            `module` is neither a type nor a string.
    """
    # Already a class, return as-is.
    if isinstance(module, type):
        return module
    if not isinstance(module, str):
        raise ValueError(f"`module` ({module}) must be type or string (classpath)!")

    parts = module.split(".")
    first_error: Optional[Exception] = None
    # Try the longest module path first ("pkg.module" + "ClassName"), then
    # shorter ones, so that attribute chains such as "Outer.Inner" also resolve.
    # `module` itself is never rebound, so a failed lookup returns the original
    # string rather than a (partially) imported module object.
    for split_at in range(len(parts) - 1, 0, -1):
        try:
            resolved = importlib.import_module(".".join(parts[:split_at]))
            for attr in parts[split_at:]:
                resolved = getattr(resolved, attr)
            return resolved
        # Module not found OR not a module (but a registered string?).
        except (ModuleNotFoundError, ImportError, AttributeError, ValueError) as e:
            # Keep the error from the most specific (longest) module path.
            if first_error is None:
                first_error = e

    # Ignore if error=False.
    if error:
        raise ValueError(
            f"Could not deserialize the given classpath `module={module}` into "
            "a valid python class! Make sure you have all necessary pip "
            "packages installed and all custom modules are in your "
            "`PYTHONPATH` env variable."
        ) from first_error
    return module
|