123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import logging
- import os
- import time
- from ray.util.debug import log_once
- from ray.rllib.utils.framework import try_import_tf
# TensorFlow handles resolved through RLlib's optional-import helper.
# Presumably (tf.compat.v1 module, tf module, major-version int) -- confirm
# against ray.rllib.utils.framework.try_import_tf. `tf1` supplies the
# graph-mode RunOptions / RunMetadata used by _run_timeline below.
tf1, tf, tfv = try_import_tf()

# Logger scoped to this module's import path.
logger = logging.getLogger(name=__name__)
- class _TFRunBuilder:
- """Used to incrementally build up a TensorFlow run.
- This is particularly useful for batching ops from multiple different
- policies in the multi-agent setting.
- """
- def __init__(self, session, debug_name):
- self.session = session
- self.debug_name = debug_name
- self.feed_dict = {}
- self.fetches = []
- self._executed = None
- def add_feed_dict(self, feed_dict):
- assert not self._executed
- for k in feed_dict:
- if k in self.feed_dict:
- raise ValueError("Key added twice: {}".format(k))
- self.feed_dict.update(feed_dict)
- def add_fetches(self, fetches):
- assert not self._executed
- base_index = len(self.fetches)
- self.fetches.extend(fetches)
- return list(range(base_index, len(self.fetches)))
- def get(self, to_fetch):
- if self._executed is None:
- try:
- self._executed = _run_timeline(
- self.session,
- self.fetches,
- self.debug_name,
- self.feed_dict,
- os.environ.get("TF_TIMELINE_DIR"),
- )
- except Exception as e:
- logger.exception(
- "Error fetching: {}, feed_dict={}".format(
- self.fetches, self.feed_dict
- )
- )
- raise e
- if isinstance(to_fetch, int):
- return self._executed[to_fetch]
- elif isinstance(to_fetch, list):
- return [self.get(x) for x in to_fetch]
- elif isinstance(to_fetch, tuple):
- return tuple(self.get(x) for x in to_fetch)
- else:
- raise ValueError("Unsupported fetch type: {}".format(to_fetch))
# Rolling counter used by _run_timeline to pick timeline file names.
_count = 0


def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
    """Run ``ops`` in ``sess``, optionally dumping a TF timeline trace.

    Args:
        sess: Session-like object; its ``run`` method is called with ``ops``.
        ops: Fetches passed to ``sess.run``.
        debug_name: Label embedded in the timeline file name.
        feed_dict: Optional feeds for ``sess.run``; ``None`` means ``{}``.
        timeline_dir: If truthy, a Chrome-trace JSON file named
            ``timeline-<debug_name>-<pid>-<k>.json`` is written into this
            directory. ``k`` is a global counter modulo 10, so at most 10
            files exist per (debug_name, pid) and older ones are overwritten.

    Returns:
        Whatever ``sess.run`` returns for ``ops``.
    """
    global _count

    if feed_dict is None:
        feed_dict = {}

    if not timeline_dir:
        if log_once("tf_timeline"):
            logger.info(
                "Executing TF run without tracing. To dump TF timeline traces "
                "to disk, set the TF_TIMELINE_DIR environment variable."
            )
        return sess.run(ops, feed_dict=feed_dict)

    from tensorflow.python.client import timeline

    try:
        # Use tf1 for both the class and the enum: the top-level TF2 module has
        # no `RunOptions`, so `tf.RunOptions.FULL_TRACE` would always raise
        # AttributeError and silently disable full tracing.
        run_options = tf1.RunOptions(trace_level=tf1.RunOptions.FULL_TRACE)
    except AttributeError:
        run_options = None
        # In local mode, tf1.RunOptions is not available, see #26511
        if log_once("tf1.RunOptions_not_available"):
            logger.exception(
                "Can not access tf.RunOptions.FULL_TRACE. This may be because "
                "you have used `ray.init(local_mode=True)`. RLlib will use "
                "timeline without `options=tf.RunOptions.FULL_TRACE`."
            )
    run_metadata = tf1.RunMetadata()
    start = time.time()
    fetches = sess.run(
        ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict
    )
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    outf = os.path.join(
        timeline_dir,
        "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10),
    )
    _count += 1
    # Elapsed time covers the traced run and the trace construction, as before.
    elapsed = time.time() - start
    # `with` guarantees the file is flushed and closed (it previously leaked).
    with open(outf, "w", encoding="utf-8") as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    # Report success only after the trace has actually been written.
    logger.info(
        "Wrote tf timeline ({} s) to {}".format(elapsed, os.path.abspath(outf))
    )
    return fetches
|