tf_run_builder.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import logging
  2. import os
  3. import time
  4. from ray.util.debug import log_once
  5. from ray.rllib.utils.framework import try_import_tf
  6. tf1, tf, tfv = try_import_tf()
  7. logger = logging.getLogger(__name__)
  8. class _TFRunBuilder:
  9. """Used to incrementally build up a TensorFlow run.
  10. This is particularly useful for batching ops from multiple different
  11. policies in the multi-agent setting.
  12. """
  13. def __init__(self, session, debug_name):
  14. self.session = session
  15. self.debug_name = debug_name
  16. self.feed_dict = {}
  17. self.fetches = []
  18. self._executed = None
  19. def add_feed_dict(self, feed_dict):
  20. assert not self._executed
  21. for k in feed_dict:
  22. if k in self.feed_dict:
  23. raise ValueError("Key added twice: {}".format(k))
  24. self.feed_dict.update(feed_dict)
  25. def add_fetches(self, fetches):
  26. assert not self._executed
  27. base_index = len(self.fetches)
  28. self.fetches.extend(fetches)
  29. return list(range(base_index, len(self.fetches)))
  30. def get(self, to_fetch):
  31. if self._executed is None:
  32. try:
  33. self._executed = _run_timeline(
  34. self.session,
  35. self.fetches,
  36. self.debug_name,
  37. self.feed_dict,
  38. os.environ.get("TF_TIMELINE_DIR"),
  39. )
  40. except Exception as e:
  41. logger.exception(
  42. "Error fetching: {}, feed_dict={}".format(
  43. self.fetches, self.feed_dict
  44. )
  45. )
  46. raise e
  47. if isinstance(to_fetch, int):
  48. return self._executed[to_fetch]
  49. elif isinstance(to_fetch, list):
  50. return [self.get(x) for x in to_fetch]
  51. elif isinstance(to_fetch, tuple):
  52. return tuple(self.get(x) for x in to_fetch)
  53. else:
  54. raise ValueError("Unsupported fetch type: {}".format(to_fetch))
  55. _count = 0
  56. def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
  57. if feed_dict is None:
  58. feed_dict = {}
  59. if timeline_dir:
  60. from tensorflow.python.client import timeline
  61. try:
  62. run_options = tf1.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  63. except AttributeError:
  64. run_options = None
  65. # In local mode, tf1.RunOptions is not available, see #26511
  66. if log_once("tf1.RunOptions_not_available"):
  67. logger.exception(
  68. "Can not access tf.RunOptions.FULL_TRACE. This may be because "
  69. "you have used `ray.init(local_mode=True)`. RLlib will use "
  70. "timeline without `options=tf.RunOptions.FULL_TRACE`."
  71. )
  72. run_metadata = tf1.RunMetadata()
  73. start = time.time()
  74. fetches = sess.run(
  75. ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict
  76. )
  77. trace = timeline.Timeline(step_stats=run_metadata.step_stats)
  78. global _count
  79. outf = os.path.join(
  80. timeline_dir,
  81. "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10),
  82. )
  83. _count += 1
  84. trace_file = open(outf, "w")
  85. logger.info(
  86. "Wrote tf timeline ({} s) to {}".format(
  87. time.time() - start, os.path.abspath(outf)
  88. )
  89. )
  90. trace_file.write(trace.generate_chrome_trace_format())
  91. else:
  92. if log_once("tf_timeline"):
  93. logger.info(
  94. "Executing TF run without tracing. To dump TF timeline traces "
  95. "to disk, set the TF_TIMELINE_DIR environment variable."
  96. )
  97. fetches = sess.run(ops, feed_dict=feed_dict)
  98. return fetches