1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Context that captures profile and performs profiling/dumping. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import os 23import random 24import sys 25import threading 26 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.python.client import session 29from tensorflow.python.framework import errors 30from tensorflow.python.framework import ops 31from tensorflow.python.platform import gfile 32from tensorflow.python.profiler import model_analyzer 33from tensorflow.python.util import _pywrap_tfprof as print_mdl 34from tensorflow.python.util import compat 35 36WARMUP_STEPS = 10 37MAX_TRACED_STEPS = 100 38 39 40def _profiled_init(self, target='', graph=None, config=None): 41 """Overwrites the session.__init__.""" 42 self._profiler_init_internal(target, graph, config) # pylint: disable=protected-access 43 44 45def _profiled_run(self, 46 fetches, 47 feed_dict=None, 48 options=None, 49 run_metadata=None): 50 """Overwrites the session.run().""" 51 # pylint: disable=protected-access 52 # Count the session steps. 53 with self.profile_context._new_step() as state: 54 step, locked = state 55 # Fast path if no need for profiling. 56 if locked and not self.profile_context._is_fast_path(step): 57 # Maybe trace this step. 58 if self.profile_context._should_trace(step, self.graph, fetches): 59 if self.profile_context._debug: 60 sys.stderr.write('debug: tracing step: %d\n' % step) 61 # Enable tracing, perform auto profiling or auto dump. 62 if not run_metadata: 63 run_metadata = config_pb2.RunMetadata() 64 65 if not options: 66 options = config_pb2.RunOptions( 67 trace_level=config_pb2.RunOptions.FULL_TRACE) 68 old_trace_level = options.trace_level 69 else: 70 old_trace_level = options.trace_level 71 options.trace_level = config_pb2.RunOptions.FULL_TRACE 72 73 ret = self._profiler_run_internal( 74 fetches, feed_dict, options, run_metadata) 75 if self.profile_context._debug: 76 self.profile_context._dump_file(run_metadata, 'run_meta_%d' % step) 77 78 self.profile_context.profiler._graph = self.graph 79 self.profile_context.profiler.add_step(step, run_metadata) 80 options.trace_level = old_trace_level 81 else: 82 ret = self._profiler_run_internal(fetches, feed_dict, options) 83 84 # Maybe dump profile. 85 self.profile_context._maybe_dump(step) 86 87 # Maybe profile: 88 to_profiles = self.profile_context._profile_candidates() 89 for to_prof in to_profiles: 90 cmd, opts, _ = to_prof 91 saved_views = self.profile_context._views.setdefault(cmd, {}) 92 if self.profile_context._debug: 93 sys.stderr.write('debug: profiling %s step: %d\n' % (cmd, step)) 94 if cmd == 'graph': 95 saved_views[step] = self.profile_context.profiler.profile_graph(opts) 96 elif cmd == 'scope': 97 saved_views[step] = self.profile_context.profiler.profile_name_scope( 98 opts) 99 elif cmd == 'op': 100 saved_views[step] = self.profile_context.profiler.profile_operations( 101 opts) 102 elif cmd == 'code': 103 saved_views[step] = self.profile_context.profiler.profile_python(opts) 104 else: 105 raise ValueError('Unknown cmd: %s\n' % cmd) 106 return ret 107 # Fast no lock path. 108 return self._profiler_run_internal( 109 fetches, feed_dict, options, run_metadata) 110 # pylint: enable=protected-access 111 112 113class ProfileContext(object): 114 """A Context that captures RunMetadata and performs profiling. 115 116 ```python 117 # Trace steps 100~200, profile at [150, 200] and dump profile at 200. 118 with profile_context.ProfileContext('/tmp/train_dir', 119 trace_steps=range(100, 200, 3), 120 dump_steps=[200]) as pctx: 121 opts = tf.profiler.ProfileOptionBuilder.time_and_memory() 122 pctx.add_auto_profiling('op', opts, [150, 200]) 123 train_loop(). 124 125 # Tracing only. 126 with profile_context.tfprof.ProfileContext('/tmp/train_dir') as pctx: 127 # Run train/eval loop for at least few hundred steps. Profiles will be 128 # dumped to train_dir. Use web UI or command line to do profiling. 129 train_loop(). 130 131 # When session object is available, do explicit trace, profile and dump. 132 with profile_context.ProfileContext('/tmp/train_dir', 133 trace_steps=[], 134 dump_steps=[]) as pctx: 135 opts = tf.profiler.ProfileOptionBuilder.time_and_memory() 136 pctx.trace_next_step() 137 _ = session.run(train_op) 138 pctx.profiler.profile_operations(options=opts) 139 ``` 140 141 Args: 142 profile_dir: Directory to store profiles. 143 trace_steps: A list of session run steps to trace. If None, use 144 pre-defined steps. 145 dump_steps: A list of steps to dump the profile to `profile_dir`. If None, 146 use pre-defined steps. 147 enabled: If false, everything is disabled with minimal overhead. It allows 148 user to only enable profiling when needed. 149 debug: If true, also dumps the raw trace RunMetadata text file to 150 profile_dir. And print debugging message. Useful for bug report. 151 """ 152 153 def __init__(self, 154 profile_dir, 155 trace_steps=None, 156 dump_steps=None, 157 enabled=True, 158 debug=False): 159 self._enabled = enabled 160 if not self._enabled: 161 return 162 163 self._debug = debug 164 if not profile_dir: 165 raise ValueError('Must have a directory for profile.\n') 166 self._profiler_dir = profile_dir 167 168 if trace_steps is None: 169 self._trace_steps = set() 170 self._auto_tracing = True 171 else: 172 if len(trace_steps) > MAX_TRACED_STEPS: 173 raise ValueError('Only support tracing up to 100 steps.\n') 174 self._trace_steps = set(trace_steps[:]) 175 self._auto_tracing = False 176 177 if dump_steps is None: 178 self._dump_steps = set([MAX_TRACED_STEPS]) 179 else: 180 self._dump_steps = set(dump_steps[:]) 181 182 self._rng = random.Random(111) 183 self._fetched = set() 184 self._slow_path_steps = self._dump_steps | self._trace_steps 185 self._trace_next_step = False 186 self._dump_next_step = False 187 self._step = 0 188 self._traced_steps = 0 189 self._auto_profiles = [] 190 self._profiler = None 191 self._views = {} 192 self._lock = threading.Lock() 193 194 def get_profiles(self, cmd): 195 """Returns profiling results for each step at which `cmd` was run. 196 197 Args: 198 cmd: string, profiling command used in an `add_auto_profiling` call. 199 200 Returns: 201 dict[int: (MultiGraphNodeProto | GraphNodeProto)]. Keys are steps at which 202 the profiling command was run. Values are the outputs of profiling. 203 For "code" and "op" commands this will be a `MultiGraphNodeProto`, for 204 "scope" and "graph" commands this will be a `GraphNodeProto. 205 206 Raises: 207 ValueError: if `cmd` was never run (either because no session.run call was 208 made or because there was no `add_auto_profiling` call with the specified 209 `cmd`. 210 """ 211 if cmd not in self._views: 212 raise ValueError('No autoprofiler for command: {}, was run'.format(cmd)) 213 return self._views[cmd] 214 215 def add_auto_profiling(self, cmd, options, profile_steps): 216 """Traces and profiles at some session run steps. 217 218 Args: 219 cmd: The profiling commands. (i.e. scope, op, python, graph) 220 options: The profiling options. 221 profile_steps: A list/set of integers. The profiling command and options 222 will be run automatically at these integer steps. Each step is 223 a session.run. 224 """ 225 if not self._enabled: 226 return 227 self._auto_profiles.append((cmd, options, profile_steps[:])) 228 self._slow_path_steps |= set(profile_steps) 229 self._trace_steps |= set(profile_steps) 230 231 @property 232 def profiler(self): 233 """Returns the current profiler object.""" 234 if not self._enabled: 235 return None 236 if not self._profiler: 237 self._profiler = model_analyzer.Profiler(ops.get_default_graph()) 238 return self._profiler 239 240 def trace_next_step(self): 241 """Enables tracing and adds traces to profiler at next step.""" 242 if not self._enabled: 243 return 244 self._trace_next_step = True 245 self._slow_path_steps.add(self._step) 246 247 def dump_next_step(self): 248 """Enable tracing and dump profiles at next step.""" 249 if not self._enabled: 250 return 251 self._dump_next_step = True 252 self._slow_path_steps.add(self._step) 253 254 def _is_fast_path(self, step): 255 if step in self._slow_path_steps: 256 return False 257 # When user doesn't set the tracing steps explicitly, auto decide it. 258 if (self._auto_tracing and step > WARMUP_STEPS and 259 self._traced_steps <= MAX_TRACED_STEPS): 260 return False 261 return True 262 263 def _should_trace(self, step, graph, fetches): 264 """Whether should do tracing at current step.""" 265 if self._traced_steps > MAX_TRACED_STEPS: 266 return False 267 # Check user-set tracing steps. 268 if step in self._trace_steps or self._trace_next_step: 269 self._traced_steps += 1 270 return True 271 272 # If no user-set tracing steps set and passes warm up steps, auto trace. 273 if self._auto_tracing and step > WARMUP_STEPS: 274 # If the fetches have not been seen before, trace it. 275 with graph.as_default(): 276 fetch_names = [f.name for f in 277 session._FetchMapper.for_fetch(fetches).unique_fetches()] # pylint: disable=protected-access 278 fetch_name = '-'.join(sorted(fetch_names)) 279 if self._debug: 280 sys.stderr.write('debug: trace fetches: %s\n' % fetch_name) 281 if fetch_name not in self._fetched: 282 self._fetched.add(fetch_name) 283 self._traced_steps += 1 284 return True 285 # If the trace coverage is low, does some random tracing. 286 if (self.profiler._coverage < 0.5 and step < MAX_TRACED_STEPS and # pylint: disable=protected-access 287 self._rng.randint(0, 10) < 2): 288 self._traced_steps += 1 289 return True 290 return False 291 292 def _maybe_dump(self, step): 293 """Maybe dump the profile file.""" 294 if not (step in self._dump_steps or self._dump_next_step): 295 return 296 if self._debug: 297 sys.stderr.write('debug: dumping file at step: %d\n' % step) 298 if not gfile.Exists(self._profiler_dir): 299 gfile.MakeDirs(self._profiler_dir) 300 301 filename = os.path.join(compat.as_bytes(self._profiler_dir), 302 compat.as_bytes('profile_%d' % step)) 303 self.profiler._write_profile(filename) # pylint: disable=protected-access 304 305 def _dump_file(self, pb, basename): 306 if not gfile.Exists(self._profiler_dir): 307 gfile.MakeDirs(self._profiler_dir) 308 with gfile.Open(os.path.join(self._profiler_dir, basename), 'w') as f: 309 f.write('%s' % pb) 310 311 @contextlib.contextmanager 312 def _new_step(self): 313 acquired = self._lock.acquire(False) 314 yield (self._step, acquired) 315 self._step += 1 316 self._trace_next_step = False 317 self._dump_next_step = False 318 if acquired: 319 self._lock.release() 320 321 def _profile_candidates(self): 322 to_profile = [] 323 for auto_prof in self._auto_profiles: 324 _, _, prof_steps = auto_prof 325 if self._step in prof_steps: 326 to_profile.append(auto_prof) 327 return to_profile 328 329 def __enter__(self): 330 if self._enabled: 331 self.old_run = getattr(session.BaseSession, 'run', None) 332 self.old_init = getattr(session.BaseSession, '__init__', None) 333 if not self.old_run: 334 raise errors.InternalError(None, None, 'BaseSession misses run method.') 335 elif not self.old_init: 336 raise errors.InternalError(None, None, 337 'BaseSession misses __init__ method.') 338 elif getattr(session.BaseSession, '_profiler_run_internal', None): 339 raise errors.InternalError(None, None, 340 'Already in context or context not cleaned.') 341 elif getattr(session.BaseSession, '_profiler_init_internal', None): 342 raise errors.InternalError(None, None, 343 'Already in context or context not cleaned.') 344 else: 345 setattr(session.BaseSession, 'run', _profiled_run) 346 setattr(session.BaseSession, '__init__', _profiled_init) 347 setattr(session.BaseSession, '_profiler_run_internal', self.old_run) 348 setattr(session.BaseSession, '_profiler_init_internal', self.old_init) 349 setattr(session.BaseSession, 'profile_context', self) 350 return self 351 else: 352 return self 353 354 def __exit__(self, exec_type, exec_value, exec_tb): 355 if not self._enabled: 356 return 357 print_mdl.DeleteProfiler() 358 setattr(session.BaseSession, 'run', self.old_run) 359 setattr(session.BaseSession, '__init__', self.old_init) 360 setattr(session.BaseSession, '_profiler_run_internal', None) 361 setattr(session.BaseSession, '_profiler_init_internal', None) 362 setattr(session.BaseSession, 'profile_context', None) 363