compression.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from ray.rllib.utils.annotations import DeveloperAPI
  2. import logging
  3. import time
  4. import base64
  5. import numpy as np
  6. from ray import cloudpickle as pickle
  7. from six import string_types
  8. logger = logging.getLogger(__name__)
  9. try:
  10. import lz4.frame
  11. LZ4_ENABLED = True
  12. except ImportError:
  13. logger.warning("lz4 not available, disabling sample compression. "
  14. "This will significantly impact RLlib performance. "
  15. "To install lz4, run `pip install lz4`.")
  16. LZ4_ENABLED = False
  17. @DeveloperAPI
  18. def compression_supported():
  19. return LZ4_ENABLED
  20. @DeveloperAPI
  21. def pack(data):
  22. if LZ4_ENABLED:
  23. data = pickle.dumps(data)
  24. data = lz4.frame.compress(data)
  25. # TODO(ekl) we shouldn't need to base64 encode this data, but this
  26. # seems to not survive a transfer through the object store if we don't.
  27. data = base64.b64encode(data).decode("ascii")
  28. return data
  29. @DeveloperAPI
  30. def pack_if_needed(data):
  31. if isinstance(data, np.ndarray):
  32. data = pack(data)
  33. return data
  34. @DeveloperAPI
  35. def unpack(data):
  36. if LZ4_ENABLED:
  37. data = base64.b64decode(data)
  38. data = lz4.frame.decompress(data)
  39. data = pickle.loads(data)
  40. return data
  41. @DeveloperAPI
  42. def unpack_if_needed(data):
  43. if is_compressed(data):
  44. data = unpack(data)
  45. return data
  46. @DeveloperAPI
  47. def is_compressed(data):
  48. return isinstance(data, bytes) or isinstance(data, string_types)
  49. # Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
  50. # Compression speed: 753.664 MB/s
  51. # Compression ratio: 87.4839812046
  52. # Decompression speed: 910.9504 MB/s
  53. if __name__ == "__main__":
  54. size = 32 * 80 * 80 * 4
  55. data = np.ones(size).reshape((32, 80, 80, 4))
  56. count = 0
  57. start = time.time()
  58. while time.time() - start < 1:
  59. pack(data)
  60. count += 1
  61. compressed = pack(data)
  62. print("Compression speed: {} MB/s".format(count * size * 4 / 1e6))
  63. print("Compression ratio: {}".format(round(size * 4 / len(compressed), 2)))
  64. count = 0
  65. start = time.time()
  66. while time.time() - start < 1:
  67. unpack(compressed)
  68. count += 1
  69. print("Decompression speed: {} MB/s".format(count * size * 4 / 1e6))