json_writer.py

from datetime import datetime
import json
import logging
import os
import time
from typing import Any, List, Optional
from urllib.parse import urlparse

import numpy as np

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.typing import FileType, SampleBatchType
from ray.util.ml_utils.json import SafeFallbackEncoder

logger = logging.getLogger(__name__)
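
# urlparse() reports a bare Windows drive letter ("c:" through "z:") as
# the URL scheme, so drive letters must be whitelisted as local paths.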
WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]


@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    @PublicAPI
    def __init__(self,
                 path: str,
                 ioctx: Optional[IOContext] = None,
                 max_file_size: int = 64 * 1024 * 1024,
                 compress_columns: List[str] = frozenset(["obs", "new_obs"])):
        """Initializes a JsonWriter instance.

        Args:
            path: Path or URI of the output directory to save files in.
            ioctx: Current IO context object.
            max_file_size: Max size of a single file before rolling over
                to a new one.
            compress_columns: Names of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist.
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None
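
    # write() appends each batch as newline-delimited JSON (one object per
    # line), the same format RLlib's JsonReader consumes for offline input.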
    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug("Wrote {} bytes to {} in {}s".format(
            len(data), f,
            time.time() - start))
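
    # Chunk files are named "output-<timestamp>_worker-<index>_<n>.json"
    # and rolled over once `max_file_size` bytes have been written, so
    # multiple rollout workers can write to one directory without clashes.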
    def _get_file(self) -> FileType:
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path, "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index))
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path))
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file


def _to_jsonable(v, compress: bool) -> Any:
    if compress and compression_supported():
        # pack() LZ4-compresses and base64-encodes the value; store the
        # result as a string so it survives the JSON round trip.
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v
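

# A MultiAgentBatch is serialized with per-policy sub-batches nested under
# "policy_batches", while a plain SampleBatch stores its columns at the top
# level; both record types carry a "type" field so readers can tell them
# apart.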
def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns)
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return json.dumps(out, cls=SafeFallbackEncoder)
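

# Example usage (a minimal sketch; assumes RLlib is installed and that the
# rollout data is available as a single SampleBatch):
#
#   from ray.rllib.policy.sample_batch import SampleBatch
#
#   writer = JsonWriter("/tmp/rllib-out")
#   batch = SampleBatch({"obs": [0, 1], "new_obs": [1, 2], "actions": [0, 0]})
#   writer.write(batch)  # appends one JSON line to the current chunk file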