1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental API for TensorFlow's "Eager" mode of execution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import copy
24import random
25import threading
26
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.python import pywrap_tensorflow
29from tensorflow.python.framework import c_api_util
30from tensorflow.python.framework import device as pydev
31from tensorflow.python.framework import errors
32from tensorflow.python.util import compat
33from tensorflow.python.util import is_in_graph_mode
34from tensorflow.python.util import tf_contextlib
35
# Execution modes: ops either record into a graph (GRAPH_MODE) or execute
# immediately (EAGER_MODE).
GRAPH_MODE = 0
EAGER_MODE = 1

# Default execution mode.
_default_mode = GRAPH_MODE

# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}

# Largest signed 32-bit integer; used as the inclusive upper bound for
# generated per-operation seeds (see Context._internal_operation_seed).
_MAXINT32 = 2**31 - 1

# Device placement policies, re-exported from the TFE C API. See the
# `Context` constructor docstring for the meaning of each policy.
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
    pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
55
56
57# TODO(agarwal): better name ?
# TODO(agarwal): better name ?
class _EagerContext(threading.local):
  """Thread-local eager state.

  Each thread carries its own execution mode, device placement, name scope,
  summary-writer state and per-device scalar cache.
  """

  def __init__(self):
    super(_EagerContext, self).__init__()
    self.mode = _default_mode
    self.scope_name = ""
    # Default placement is the local CPU.
    spec = pydev.DeviceSpec.from_string(
        "/job:localhost/replica:0/task:0/device:CPU:0")
    self.device_spec = spec
    self.device_name = spec.to_string()
    self.recording_summaries = False
    self.summary_writer_resource = None
    self.scalar_cache = {}
71
72
# One recorded context switch: whether it installed a function-building graph,
# plus the callable (e.g. `graph.as_default` or `eager_mode`) that re-enters
# that context.
ContextStackEntry = collections.namedtuple(
    "ContextStackEntry", "is_building_function enter_context_fn")
75
76
class ContextStack(threading.local):
  """A thread-local stack recording graph/eager context switches."""

  def __init__(self):
    super(ContextStack, self).__init__()
    # Stack of ContextStackEntry, most recent switch last.
    self.stack = []

  def push(self, is_building_function, enter_context_fn):
    """Records a context switch on top of the stack.

    A context switch can take one of two forms: installing a graph as the
    default graph, or entering the eager context.

    Args:
      is_building_function: (bool.) Whether the context is building a function.
      enter_context_fn: (function.) A callable that executes the context switch.
        For example, `graph.as_default` or `eager_mode`.
    """
    entry = ContextStackEntry(is_building_function, enter_context_fn)
    self.stack.append(entry)

  def pop(self):
    """Removes the most recent context switch from the stack."""
    self.stack.pop()
103
104
# Process-wide (but thread-local per ContextStack) record of context switches.
context_stack = ContextStack()
106
107
108# TODO(agarwal): rename to EagerContext / EagerRuntime ?
109# TODO(agarwal): consider keeping the corresponding Graph here.
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute."""

  def __init__(self, config=None, device_policy=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
       options for the Context. Note that a lot of these options may be
       currently unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
       operation on a device with inputs which are not on that device.
       Valid values:
         tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
           correct.
         tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
           right device but raises a warning.
         tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
           hide performance problems.
         tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
           raising errors on the other ones.
    """
    self._eager_context = _EagerContext()
    # The TFE_Context handle and the device list are created lazily by
    # _initialize_handle_and_devices(), guarded by _initialize_lock.
    self._context_handle = None
    self._context_devices = None
    self._post_execution_callbacks = []
    self._config = config
    self._seed = None
    # Start with an unseeded generator so _internal_operation_seed() works
    # even if _set_global_seed() is never called. (Previously self._rng was
    # only created inside _set_global_seed(), so requesting an operation seed
    # before setting a global seed raised AttributeError.)
    self._rng = random.Random()
    self._initialize_lock = threading.Lock()
    self._device_policy = device_policy

  def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops."""
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds
    if self._context_handle is not None:
      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)

  def _internal_operation_seed(self):
    """Returns a fake operation seed.

      In eager mode, user shouldn't set or depend on operation seed.
      Here, we generate a random seed based on global seed to make
      operation's randomness different and depend on the global seed.

    Returns:
      A fake operation seed based on global seed.
    """
    return self._rng.randint(0, _MAXINT32)

  def _initialize_handle_and_devices(self):
    """Initializes the TFE_Context handle and the list of devices (once)."""
    with self._initialize_lock:
      if self._context_handle is not None:
        # Another thread won the race; nothing to do.
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TFE_NewContextOptions()
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          if self._config is not None:
            config_str = self._config.SerializeToString()
            pywrap_tensorflow.TFE_ContextOptionsSetConfig(
                opts, config_str, len(config_str), status)
          if self._device_policy is not None:
            pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
                opts, self._device_policy)
          self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
      finally:
        # Options are only needed at creation time; always release them.
        pywrap_tensorflow.TFE_DeleteContextOptions(opts)
      # Store list of devices
      self._context_devices = []
      with errors.raise_exception_on_not_ok_status() as status:
        device_list = pywrap_tensorflow.TFE_ContextListDevices(
            self._context_handle, status)
      try:
        self._num_gpus = 0
        for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
          with errors.raise_exception_on_not_ok_status() as status:
            dev_name = pywrap_tensorflow.TF_DeviceListName(
                device_list, i, status)
          self._context_devices.append(pydev.canonical_name(dev_name))
          with errors.raise_exception_on_not_ok_status() as status:
            dev_type = pywrap_tensorflow.TF_DeviceListType(
                device_list, i, status)
          if dev_type == "GPU":
            self._num_gpus += 1

      finally:
        pywrap_tensorflow.TF_DeleteDeviceList(device_list)

  @property
  def _handle(self):
    """The underlying TFE_Context handle, initializing it if necessary."""
    ctx = self._context_handle
    if ctx is None:
      self._initialize_handle_and_devices()
      return self._context_handle
    else:
      return ctx

  @property
  def _devices(self):
    """Canonical device names, initializing the context if necessary."""
    devices = self._context_devices
    if devices is None:
      self._initialize_handle_and_devices()
      return self._context_devices
    else:
      return devices

  def __str__(self):
    if self._context_handle is None:
      # Avoid forcing initialization just to print the context.
      return "Eager TensorFlow Context. Devices currently uninitialized."
    else:
      devices = self._devices
      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
      for i, d in enumerate(devices):
        lines.append("   Device %d: %s" % (i, d))
      return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    """Context-manager that temporarily sets the thread's execution mode."""
    ctx = self._eager_context
    old_mode = ctx.mode
    ctx.mode = mode
    if mode == EAGER_MODE:
      # Entering eager counts as a context switch; record it so function
      # building can find the enclosing context.
      context_stack.push(False, eager_mode)
    try:
      yield
    finally:
      ctx.mode = old_mode
      if mode == EAGER_MODE:
        context_stack.pop()

  def in_graph_mode(self):
    """Returns True if current thread is in GRAPH mode."""
    return self._eager_context.mode == GRAPH_MODE

  def in_eager_mode(self):
    """Returns True if current thread is in EAGER mode."""
    return self._eager_context.mode == EAGER_MODE

  def scalar_cache(self):
    """Per-device cache for scalars."""
    return self._eager_context.scalar_cache

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._eager_context.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._eager_context.scope_name = s

  @property
  def summary_writer_resource(self):
    """Returns summary writer resource."""
    return self._eager_context.summary_writer_resource

  @summary_writer_resource.setter
  def summary_writer_resource(self, resource):
    """Sets summary writer resource."""
    self._eager_context.summary_writer_resource = resource

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    return self._eager_context.device_name

  @property
  def device_spec(self):
    """Returns the device spec for the current thread."""
    return self._eager_context.device_spec

  @tf_contextlib.contextmanager
  def device(self, name):
    """Context-manager to force placement of operations and Tensors on a device.

    Args:
      name: Name of the device or None to get default placement.

    Yields:
      Nothing.

    Raises:
      ValueError: If name is not a string or is an invalid device name.
    """
    eager_context = self._eager_context
    old_device_name = eager_context.device_name
    old_device_spec = eager_context.device_spec
    cache_key = (old_device_name, name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key: `name` was unhashable,
      # so it cannot be a valid device string.
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(name), name))
    except KeyError:
      # Handle a cache miss.
      if name is not None:
        if not isinstance(name, str):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(name), name))
        device_spec = pydev.DeviceSpec.from_string(name)
        if old_device_name:
          new_device_spec = copy.copy(old_device_spec)
        else:
          new_device_spec = pydev.DeviceSpec.from_string(
              "/job:localhost/replica:0/task:0/device:CPU:0")
        # Fields set in `name` override the enclosing device context.
        new_device_spec.merge_from(device_spec)
      else:
        # name=None requests automatic placement: an empty spec.
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      # Relies on the GIL + idempotent writes for thread safety (see the
      # comment at _device_parsing_cache).
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)

    try:
      eager_context.device_name = new_device_name
      eager_context.device_spec = new_device_spec
      yield
    finally:
      eager_context.device_name = old_device_name
      eager_context.device_spec = old_device_spec

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    self._initialize_handle_and_devices()
    return self._num_gpus

  def add_function(self, fn):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TFE_ContextAddFunction(
          self._handle,  # pylint: disable=protected-access
          fn,
          status)

  def add_function_def(self, fdef):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fdef: A FunctionDef protocol buffer message.
    """
    fdef_string = fdef.SerializeToString()
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TFE_ContextAddFunctionDef(
          self._handle,  # pylint: disable=protected-access
          fdef_string,
          len(fdef_string),
          status)

  def add_post_execution_callback(self, callback):
    """Add a post-execution callback to the context.

    A post-execution callback is invoked immediately after an eager operation or
    function has finished execution, providing access to the op's type, name
    input and output tensors. Multiple execution callbacks can be added, in
    which case the callbacks will be invoked in the order in which they are
    added.

    Args:
      callback: a callable of the signature
      `f(op_type, op_name, attrs, inputs, outputs)`.
      `op_type` is the type of the operation that was just executed (e.g.,
        `MatMul`).
      `op_name` is the name of the operation that has was just executed. This
        name is set by the client who created the operation and can be `None` if
        it is unset.
      `attrs` contains the attributes of the operation as a `tuple` of
        alternating attribute names and attribute values.
      `inputs` is the `list` of input `Tensor`(s) to the op.
      `outputs` is the `list` of output `Tensor`(s) from the op.
       Return value(s) from the callback are ignored.
    """
    # TODO(cais): (b/64674139) Allow access to function-internal operations.
    self._post_execution_callbacks.append(callback)

  def clear_post_execution_callbacks(self):
    """Clear all post-execution callbacks added to the context."""
    del self._post_execution_callbacks[:]

  @property
  def post_execution_callbacks(self):
    """Get the list of post-execution callbacks added to the context."""
    return self._post_execution_callbacks

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata call context.export_run_metadata()
    and to stop tracing call context.disable_run_metadata().
    """
    if not self._context_handle:
      self._initialize_handle_and_devices()
    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)

  @tf_contextlib.contextmanager
  def device_policy(self, policy):
    """Context-manager that temporarily overrides the device placement policy.

    Args:
      policy: One of the DEVICE_PLACEMENT_* module constants.

    Yields:
      Nothing. Restores the previous thread-local policy on exit.
    """
    if not self._context_handle:
      self._initialize_handle_and_devices()
    old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
        self._context_handle)
    pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
        self._handle, policy)
    try:
      yield
    finally:
      pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
          self._handle, old)

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata."""
    if not self._context_handle:
      # Tracing was never enabled (the context was never initialized).
      return
    pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information since the most recent call
    to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer. Or None if not enabled.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.TFE_ContextExportRunMetadata(
            self._context_handle, buffer_, status)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata
459
# Process-wide singleton eager Context; created lazily by
# _initialize_context() under _context_lock.
_context = None
_context_lock = threading.Lock()
462
463
def _initialize_context():
  """Creates the singleton Context exactly once, in a thread-safe way."""
  global _context
  with _context_lock:
    if _context is not None:
      # Another thread already created it.
      return
    _context = Context()
469
470
def context():
  """Returns a singleton context object."""
  if _context is not None:
    return _context
  _initialize_context()
  return _context
476
477
478# TODO(agarwal): remove this.
# TODO(agarwal): remove this.
def get_default_context():
  """Same as context."""
  if _context is not None:
    return _context
  _initialize_context()
  return _context
484
485
def set_global_seed(seed):
  """Sets the eager mode seed."""
  ctx = context()
  ctx._set_global_seed(seed)  # pylint: disable=protected-access
489
490
def global_seed():
  """Returns the eager mode seed."""
  ctx = context()
  return ctx._seed  # pylint: disable=protected-access
494
495
def internal_operation_seed():
  """Returns the operation seed generated based on global seed."""
  ctx = context()
  return ctx._internal_operation_seed()  # pylint: disable=protected-access
499
500
def in_graph_mode():
  """Returns True if current thread is in GRAPH mode for default context."""
  ctx = context()
  return ctx.in_graph_mode()
504
505
def in_eager_mode():
  """Returns True if current thread is in EAGER mode for default context."""
  ctx = context()
  return ctx.in_eager_mode()
509
510
def graph_mode():
  """Context-manager to enable GRAPH mode for current thread."""
  ctx = context()
  return ctx._mode(GRAPH_MODE)  # pylint: disable=protected-access
514
515
def eager_mode():
  """Context-manager to enable EAGER mode for current thread."""
  ctx = context()
  return ctx._mode(EAGER_MODE)  # pylint: disable=protected-access
519
520
# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
  """ContextManager for creating hierarchical name scopes."""
  ctx = context()
  previous = ctx.scope_name
  if previous:
    ctx.scope_name = "%s/%s" % (previous, name)
  else:
    ctx.scope_name = name
  try:
    yield
  finally:
    ctx.scope_name = previous
532
533
def scope_name():
  """Name of the current scope."""
  ctx = context()
  return ctx.scope_name
537
538
def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.truncated_normal(shape, tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to
      perform automatic placement.

  Returns:
    Context manager for setting the device.
  """
  ctx = context()
  return ctx.device(name)
560
561
def list_devices():
  """List the names of the available devices.

  Returns:
    Names of the available devices, as a `list`.
  """
  ctx = context()
  return ctx.devices()
569
570
def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  ctx = context()
  return ctx.num_gpus()
578
579
def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata call context.export_run_metadata()
  and to stop tracing call context.disable_run_metadata().
  """
  ctx = context()
  ctx.enable_run_metadata()
587
588
def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  ctx = context()
  ctx.disable_run_metadata()
592
593
def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information since the most recent call
  to either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer.
  """
  ctx = context()
  return ctx.export_run_metadata()
604
605
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions; this hook lets code that
# cannot import this module directly query the execution mode.
is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode
611