1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Context that captures profile and performs profiling/dumping.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import os
23import random
24import sys
25import threading
26
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.python.client import session
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.platform import gfile
32from tensorflow.python.profiler import model_analyzer
33from tensorflow.python.util import _pywrap_tfprof as print_mdl
34from tensorflow.python.util import compat
35
# Automatic tracing (no explicit trace_steps) only starts after this step:
# steps up to and including WARMUP_STEPS are never auto-traced.
WARMUP_STEPS = 10
# Limit on the number of explicit trace steps and on traced steps overall;
# also the default dump step when dump_steps is not given.
MAX_TRACED_STEPS = 100
38
39
40def _profiled_init(self, target='', graph=None, config=None):
41  """Overwrites the session.__init__."""
42  self._profiler_init_internal(target, graph, config)  # pylint: disable=protected-access
43
44
45def _profiled_run(self,
46                  fetches,
47                  feed_dict=None,
48                  options=None,
49                  run_metadata=None):
50  """Overwrites the session.run()."""
51  # pylint: disable=protected-access
52  # Count the session steps.
53  with self.profile_context._new_step() as state:
54    step, locked = state
55    # Fast path if no need for profiling.
56    if locked and not self.profile_context._is_fast_path(step):
57      # Maybe trace this step.
58      if self.profile_context._should_trace(step, self.graph, fetches):
59        if self.profile_context._debug:
60          sys.stderr.write('debug: tracing step: %d\n' % step)
61        # Enable tracing, perform auto profiling or auto dump.
62        if not run_metadata:
63          run_metadata = config_pb2.RunMetadata()
64
65        if not options:
66          options = config_pb2.RunOptions(
67              trace_level=config_pb2.RunOptions.FULL_TRACE)
68          old_trace_level = options.trace_level
69        else:
70          old_trace_level = options.trace_level
71          options.trace_level = config_pb2.RunOptions.FULL_TRACE
72
73        ret = self._profiler_run_internal(
74            fetches, feed_dict, options, run_metadata)
75        if self.profile_context._debug:
76          self.profile_context._dump_file(run_metadata, 'run_meta_%d' % step)
77
78        self.profile_context.profiler._graph = self.graph
79        self.profile_context.profiler.add_step(step, run_metadata)
80        options.trace_level = old_trace_level
81      else:
82        ret = self._profiler_run_internal(fetches, feed_dict, options)
83
84      # Maybe dump profile.
85      self.profile_context._maybe_dump(step)
86
87      # Maybe profile:
88      to_profiles = self.profile_context._profile_candidates()
89      for to_prof in to_profiles:
90        cmd, opts, _ = to_prof
91        saved_views = self.profile_context._views.setdefault(cmd, {})
92        if self.profile_context._debug:
93          sys.stderr.write('debug: profiling %s step: %d\n' % (cmd, step))
94        if cmd == 'graph':
95          saved_views[step] = self.profile_context.profiler.profile_graph(opts)
96        elif cmd == 'scope':
97          saved_views[step] = self.profile_context.profiler.profile_name_scope(
98              opts)
99        elif cmd == 'op':
100          saved_views[step] = self.profile_context.profiler.profile_operations(
101              opts)
102        elif cmd == 'code':
103          saved_views[step] = self.profile_context.profiler.profile_python(opts)
104        else:
105          raise ValueError('Unknown cmd: %s\n' % cmd)
106      return ret
107  # Fast no lock path.
108  return self._profiler_run_internal(
109      fetches, feed_dict, options, run_metadata)
110  # pylint: enable=protected-access
111
112
class ProfileContext(object):
  """A Context that captures RunMetadata and performs profiling.

  ```python
    # Trace steps 100~200, profile at [150, 200] and dump profile at 200.
    with profile_context.ProfileContext('/tmp/train_dir',
                                        trace_steps=range(100, 200, 3),
                                        dump_steps=[200]) as pctx:
      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
      pctx.add_auto_profiling('op', opts, [150, 200])
      train_loop()

    # Tracing only.
    with profile_context.ProfileContext('/tmp/train_dir') as pctx:
      # Run train/eval loop for at least few hundred steps. Profiles will be
      # dumped to train_dir. Use web UI or command line to do profiling.
      train_loop()

    # When session object is available, do explicit trace, profile and dump.
    with profile_context.ProfileContext('/tmp/train_dir',
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
      pctx.trace_next_step()
      _ = session.run(train_op)
      pctx.profiler.profile_operations(options=opts)
  ```

  While the context is active, `session.BaseSession.run` and
  `session.BaseSession.__init__` are replaced by profiling wrappers; leaving
  the context restores the originals.

  Args:
    profile_dir: Directory to store profiles.
    trace_steps: An iterable of session run steps to trace. If None, steps
        are chosen automatically once `WARMUP_STEPS` steps have passed.
    dump_steps: An iterable of steps at which to dump the profile to
        `profile_dir`. If None, the profile is dumped at step
        `MAX_TRACED_STEPS`.
    enabled: If false, everything is disabled with minimal overhead. It allows
        user to only enable profiling when needed.
    debug: If true, also dumps the raw trace RunMetadata text file to
        profile_dir and prints debugging messages. Useful for bug reports.

  Raises:
    ValueError: if enabled and `profile_dir` is empty, or if more than
        `MAX_TRACED_STEPS` distinct trace steps are given.
  """

  def __init__(self,
               profile_dir,
               trace_steps=None,
               dump_steps=None,
               enabled=True,
               debug=False):
    self._enabled = enabled
    # Set even when disabled so get_profiles() raises its documented
    # ValueError instead of an AttributeError.
    self._views = {}
    if not self._enabled:
      return

    self._debug = debug
    if not profile_dir:
      raise ValueError('Must have a directory for profile.\n')
    self._profiler_dir = profile_dir

    if trace_steps is None:
      self._trace_steps = set()
      self._auto_tracing = True
    else:
      # set() accepts any iterable (list, range, set, generator); the limit
      # applies to distinct steps.
      self._trace_steps = set(trace_steps)
      if len(self._trace_steps) > MAX_TRACED_STEPS:
        raise ValueError(
            'Only support tracing up to %d steps.\n' % MAX_TRACED_STEPS)
      self._auto_tracing = False

    if dump_steps is None:
      self._dump_steps = set([MAX_TRACED_STEPS])
    else:
      self._dump_steps = set(dump_steps)

    # Fixed seed: the random auto-tracing choices in _should_trace() are
    # reproducible across runs.
    self._rng = random.Random(111)
    # Joined fetch names already traced by auto-tracing.
    self._fetched = set()
    # Steps that must not take the fast path in session.run.
    self._slow_path_steps = self._dump_steps | self._trace_steps
    self._trace_next_step = False
    self._dump_next_step = False
    # Index of the next session.run step.
    self._step = 0
    # Number of steps traced so far.
    self._traced_steps = 0
    # (cmd, options, steps) tuples registered by add_auto_profiling().
    self._auto_profiles = []
    # Created lazily by the `profiler` property.
    self._profiler = None
    # Acquired without blocking in _new_step(): only the call that obtains it
    # takes the slow (profiling) path.
    self._lock = threading.Lock()

  def get_profiles(self, cmd):
    """Returns profiling results for each step at which `cmd` was run.

    Args:
      cmd: string, profiling command used in an `add_auto_profiling` call.

    Returns:
      dict[int: (MultiGraphNodeProto | GraphNodeProto)]. Keys are steps at which
      the profiling command was run. Values are the outputs of profiling.
      For "code" and "op" commands this will be a `MultiGraphNodeProto`, for
      "scope" and "graph" commands this will be a `GraphNodeProto`.

    Raises:
      ValueError: if `cmd` was never run (because no session.run call was
      made, because there was no `add_auto_profiling` call with the specified
      `cmd`, or because the context is disabled).
    """
    if cmd not in self._views:
      raise ValueError('No autoprofiler for command: {}, was run'.format(cmd))
    return self._views[cmd]

  def add_auto_profiling(self, cmd, options, profile_steps):
    """Traces and profiles at some session run steps.

    Args:
      cmd: The profiling command: one of 'graph', 'scope', 'op' or 'code'.
      options: The profiling options.
      profile_steps: An iterable (e.g. list or set) of integers. The
          profiling command and options will be run automatically at these
          steps. Each step is a session.run. These steps are also traced.
    """
    if not self._enabled:
      return
    # Materialize once: slicing (`[:]`) fails on sets, and a generator could
    # only be consumed a single time.
    steps = set(profile_steps)
    self._auto_profiles.append((cmd, options, steps))
    self._slow_path_steps |= steps
    self._trace_steps |= steps

  @property
  def profiler(self):
    """Returns the current profiler object, creating it on first access.

    Returns None when the context is disabled.
    """
    if not self._enabled:
      return None
    if self._profiler is None:
      self._profiler = model_analyzer.Profiler(ops.get_default_graph())
    return self._profiler

  def trace_next_step(self):
    """Enables tracing and adds traces to profiler at next step."""
    if not self._enabled:
      return
    self._trace_next_step = True
    self._slow_path_steps.add(self._step)

  def dump_next_step(self):
    """Dumps the profile to `profile_dir` at the next step."""
    if not self._enabled:
      return
    self._dump_next_step = True
    self._slow_path_steps.add(self._step)

  def _is_fast_path(self, step):
    """Returns True if `step` needs no tracing, dumping or profiling."""
    if step in self._slow_path_steps:
      return False
    # When user doesn't set the tracing steps explicitly, auto decide it.
    # The budget check must agree with the one in _should_trace().
    if (self._auto_tracing and step > WARMUP_STEPS and
        self._traced_steps < MAX_TRACED_STEPS):
      return False
    return True

  def _should_trace(self, step, graph, fetches):
    """Whether should do tracing at current step."""
    # At most MAX_TRACED_STEPS steps are ever traced.
    if self._traced_steps >= MAX_TRACED_STEPS:
      return False
    # Check user-set tracing steps.
    if step in self._trace_steps or self._trace_next_step:
      self._traced_steps += 1
      return True

    # If no user-set tracing steps set and passes warm up steps, auto trace.
    if self._auto_tracing and step > WARMUP_STEPS:
      # If the fetches have not been seen before, trace it.
      with graph.as_default():
        fetch_names = [f.name for f in
                       session._FetchMapper.for_fetch(fetches).unique_fetches()]  # pylint: disable=protected-access
      fetch_name = '-'.join(sorted(fetch_names))
      if self._debug:
        sys.stderr.write('debug: trace fetches: %s\n' % fetch_name)
      if fetch_name not in self._fetched:
        self._fetched.add(fetch_name)
        self._traced_steps += 1
        return True
      # If the trace coverage is low, does some random tracing.
      if (self.profiler._coverage < 0.5 and step < MAX_TRACED_STEPS and  # pylint: disable=protected-access
          self._rng.randint(0, 10) < 2):
        self._traced_steps += 1
        return True
    return False

  def _maybe_dump(self, step):
    """Writes `profile_<step>` under `profile_dir` if this is a dump step.

    A step is a dump step if it is in `dump_steps` or `dump_next_step()` was
    called before it.
    """
    if not (step in self._dump_steps or self._dump_next_step):
      return
    if self._debug:
      sys.stderr.write('debug: dumping file at step: %d\n' % step)
    if not gfile.Exists(self._profiler_dir):
      gfile.MakeDirs(self._profiler_dir)

    filename = os.path.join(compat.as_bytes(self._profiler_dir),
                            compat.as_bytes('profile_%d' % step))
    self.profiler._write_profile(filename)  # pylint: disable=protected-access

  def _dump_file(self, pb, basename):
    """Writes `str(pb)` to `profile_dir/basename`, creating the dir if needed."""
    if not gfile.Exists(self._profiler_dir):
      gfile.MakeDirs(self._profiler_dir)
    with gfile.Open(os.path.join(self._profiler_dir, basename), 'w') as f:
      f.write('%s' % pb)

  @contextlib.contextmanager
  def _new_step(self):
    """Yields `(step, acquired)` for one session.run call.

    `acquired` is True if this call obtained the profiling lock without
    blocking. The step counter advances, the one-shot trace/dump flags are
    cleared and the lock is released even if the wrapped run raises; without
    this, one failing run would keep the lock held and silently disable
    profiling for every later step.
    """
    acquired = self._lock.acquire(False)
    try:
      yield (self._step, acquired)
    finally:
      self._step += 1
      self._trace_next_step = False
      self._dump_next_step = False
      if acquired:
        self._lock.release()

  def _profile_candidates(self):
    """Returns the auto-profiling entries registered for the current step."""
    return [prof for prof in self._auto_profiles if self._step in prof[2]]

  def __enter__(self):
    """Installs the profiling wrappers on `session.BaseSession`.

    Returns:
      self.

    Raises:
      errors.InternalError: if `BaseSession` lacks `run` or `__init__`, or if
        wrappers from another (or an uncleaned) context are still installed.
    """
    if not self._enabled:
      return self
    self.old_run = getattr(session.BaseSession, 'run', None)
    self.old_init = getattr(session.BaseSession, '__init__', None)
    if not self.old_run:
      raise errors.InternalError(None, None, 'BaseSession misses run method.')
    if not self.old_init:
      raise errors.InternalError(None, None,
                                 'BaseSession misses __init__ method.')
    if (getattr(session.BaseSession, '_profiler_run_internal', None) or
        getattr(session.BaseSession, '_profiler_init_internal', None)):
      raise errors.InternalError(None, None,
                                 'Already in context or context not cleaned.')
    setattr(session.BaseSession, 'run', _profiled_run)
    setattr(session.BaseSession, '__init__', _profiled_init)
    setattr(session.BaseSession, '_profiler_run_internal', self.old_run)
    setattr(session.BaseSession, '_profiler_init_internal', self.old_init)
    setattr(session.BaseSession, 'profile_context', self)
    return self

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Calls `print_mdl.DeleteProfiler()` and restores `BaseSession`."""
    if not self._enabled:
      return
    try:
      print_mdl.DeleteProfiler()
    finally:
      # Always unpatch BaseSession; otherwise later contexts would fail with
      # 'Already in context or context not cleaned.'.
      setattr(session.BaseSession, 'run', self.old_run)
      setattr(session.BaseSession, '__init__', self.old_init)
      setattr(session.BaseSession, '_profiler_run_internal', None)
      setattr(session.BaseSession, '_profiler_init_internal', None)
      setattr(session.BaseSession, 'profile_context', None)
363