from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.typing import FileType, SampleBatchType
from ray.util.ml_utils.json import SafeFallbackEncoder
from typing import Any, List

logger = logging.getLogger(__name__)

WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]


@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    @PublicAPI
    def __init__(self,
                 path: str,
                 ioctx: IOContext = None,
                 max_file_size: int = 64 * 1024 * 1024,
                 compress_columns: List[str] = frozenset(["obs", "new_obs"])):
        """Initializes a JsonWriter instance.

        Args:
            path: a path/URI of the output directory to save files in.
            ioctx: current IO context object.
            max_file_size: max size of single files before rolling over.
            compress_columns: list of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug("Wrote {} bytes to {} in {}s".format(
            len(data), f, time.time() - start))

    def _get_file(self) -> FileType:
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path, "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index))
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path))
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file


def _to_jsonable(v, compress: bool) -> Any:
    if compress and compression_supported():
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v


def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns)
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return json.dumps(out, cls=SafeFallbackEncoder)
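

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the module's
# original API surface): drives JsonWriter directly with a single SampleBatch.
# The output directory "/tmp/rllib-json-out" and the batch columns below are
# assumptions chosen for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from ray.rllib.policy.sample_batch import SampleBatch

    writer = JsonWriter("/tmp/rllib-json-out", max_file_size=1024 * 1024)
    demo_batch = SampleBatch({
        "obs": np.array([[0.0, 1.0]]),
        "actions": np.array([0]),
        "rewards": np.array([1.0]),
    })
    # Each write() appends one JSON line to the current chunk file; the writer
    # rolls over to a new "output-<timestamp>_worker-<i>_<n>.json" file once
    # max_file_size bytes have been written.
    writer.write(demo_batch)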