json_writer.py

from datetime import datetime
import json
import logging
import numpy as np
import os
from urllib.parse import urlparse
import time

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.air._internal.json import SafeFallbackEncoder
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.typing import FileType, SampleBatchType
from typing import Any, Dict, List

logger = logging.getLogger(__name__)
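
# On Windows, urlparse() reports the drive letter of an absolute path (e.g. "c:\...")
# as the URL scheme, so drive letters are whitelisted to keep such paths local.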
WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]


# TODO(jungong): use DatasetWriter to back JsonWriter, so we reduce codebase
# complexity without losing existing functionality.
@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    @PublicAPI
    def __init__(
        self,
        path: str,
        ioctx: IOContext = None,
        max_file_size: int = 64 * 1024 * 1024,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a JsonWriter instance.

        Args:
            path: a path/URI of the output directory to save files in.
            ioctx: current IO context object.
            max_file_size: max size of single files before rolling over.
            compress_columns: list of sample batch columns to compress.
        """
        logger.info(
            "You are using JsonWriter. It is recommended to use "
            "DatasetWriter instead."
        )
        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist.
            os.makedirs(path, exist_ok=True)
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug(
            "Wrote {} bytes to {} in {}s".format(len(data), f, time.time() - start)
        )

    def _get_file(self) -> FileType:
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path,
                "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index
                ),
            )
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path)
                    )
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file
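

# The module-level helpers below turn a SampleBatch or MultiAgentBatch into the
# JSON string that write() emits as a single output line. Illustrative (made-up)
# line for a small, uncompressed single-agent batch:
#   {"type": "SampleBatch", "obs": [[0.1], [0.2]], "actions": [0, 1]}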
def _to_jsonable(v, compress: bool) -> Any:
    """Returns a JSON-serializable version of `v`, optionally packed/compressed."""
    if compress and compression_supported():
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v


def _to_json_dict(batch: SampleBatchType, compress_columns: List[str]) -> Dict:
    """Converts a (multi-agent) sample batch into a JSON-serializable dict."""
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns
                )
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return out


def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    """Serializes a sample batch to a single JSON string (one output line)."""
    out = _to_json_dict(batch, compress_columns)
    return json.dumps(out, cls=SafeFallbackEncoder)
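

if __name__ == "__main__":
    # Minimal usage sketch, not part of the upstream module: writes one tiny,
    # hypothetical SampleBatch to a temporary directory as a single JSON line.
    # The column names and values here are made up for illustration only.
    import tempfile

    from ray.rllib.policy.sample_batch import SampleBatch

    tmp_dir = tempfile.mkdtemp()
    writer = JsonWriter(tmp_dir)
    batch = SampleBatch(
        {
            "obs": np.array([[0.1, 0.2], [0.3, 0.4]]),
            "actions": np.array([0, 1]),
            "rewards": np.array([1.0, -1.0]),
        }
    )
    writer.write(batch)
    print("Wrote output files to", tmp_dir)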