1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""State management for eager execution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import copy
24import os
25import random
26import threading
27
28from absl import logging
29import numpy as np
30import six
31
32from tensorflow.core.framework import function_pb2
33from tensorflow.core.protobuf import config_pb2
34from tensorflow.core.protobuf import rewriter_config_pb2
35from tensorflow.python import pywrap_tfe
36from tensorflow.python import tf2
37from tensorflow.python.client import pywrap_tf_session
38from tensorflow.python.eager import executor
39from tensorflow.python.eager import monitoring
40from tensorflow.python.framework import c_api_util
41from tensorflow.python.framework import device as pydev
42from tensorflow.python.framework import tfrt_utils
43from tensorflow.python.util import compat
44from tensorflow.python.util import is_in_graph_mode
45from tensorflow.python.util import tf_contextlib
46from tensorflow.python.util.deprecation import deprecated
47from tensorflow.python.util.tf_export import tf_export
48
49GRAPH_MODE = 0
50EAGER_MODE = 1
51
52default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE
53
54# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
55# new_device_spec).
56# Note that we do not protect this with a lock and instead rely on python's GIL
57# and the idempotent nature of writes to provide thread safety.
58_device_parsing_cache = {}
59_starting_device_spec = pydev.DeviceSpec.from_string("")
60
61_MAXINT32 = 2**31 - 1
62
63DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
64DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
65DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
66DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
67    pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
68
69SYNC = 0
70ASYNC = 1
71
72_KEEP_ALIVE_SECS = 600
73
74_python_eager_context_create_counter = monitoring.Counter(
75    "/tensorflow/api/python/eager_context_create_counter",
76    "Counter for number of eager contexts created in Python.")
77
78# Re-exporting through context.
79is_tfrt_enabled = tfrt_utils.enabled
80
81# Expose it as internally public APIs for Keras use cases in b/171080602.
82tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled)
83
84
85class _EagerTensorCache(object):
  """Simple FIFO cache for eager tensors.

  Evicts the oldest entry once `max_items` entries are stored, and skips
  caching tensors with more than `max_tensor_size` elements.
  """
87
88  __slots__ = ["_data", "_max_items", "_max_tensor_size"]
89
90  def __init__(self, max_items=256, max_tensor_size=10000):
91    self._data = collections.OrderedDict()
92    self._max_items = max_items
93    self._max_tensor_size = max_tensor_size
94
95  def put(self, key, value):
96    if value._num_elements() > self._max_tensor_size:  # pylint: disable=protected-access
97      return
98
99    self._data[key] = value
100
101    if len(self._data) > self._max_items:
102      self._data.popitem(last=False)
103
104  def get(self, key):
105    return self._data.get(key, None)
106
107  def flush(self):
108    self._data.clear()
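
# A minimal illustrative sketch (not executed here) of how _EagerTensorCache
# behaves: entries are evicted FIFO once `max_items` is exceeded, and values
# with more than `max_tensor_size` elements are never cached. The tensors
# below are hypothetical placeholders.
#
#   cache = _EagerTensorCache(max_items=2, max_tensor_size=10000)
#   cache.put("a", small_tensor_a)
#   cache.put("b", small_tensor_b)
#   cache.put("c", small_tensor_c)   # evicts "a" (oldest entry)
#   cache.get("a")                   # -> None
#   cache.put("d", huge_tensor)      # skipped: too many elements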
109
110
111class FunctionCallOptions(object):
112  """Options applied at call sites of eager functions.
113
114  Eager functions are functions decorated with tf.contrib.eager.defun.
115  """
116
117  __slots__ = ["_config_proto_serialized", "_executor_type"]
118
119  def __init__(self, executor_type=None, config_proto=None):
120    """Constructor.
121
    Args:
      executor_type: (optional) name of the executor to be used to execute the
        eager function. If None or an empty string, the default TensorFlow
        executor will be used.
      config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized
        string of that proto. This is the config used by Grappler when
        optimizing the function graph. Each concrete function is optimized the
        first time it is called. Changing config_proto after the first call has
        no effect. If config_proto is None, an empty RewriterConfig will be
        used.
132    """
133    self.config_proto_serialized = config_proto
134    self.executor_type = executor_type
135
136  @property
137  def executor_type(self):
138    return self._executor_type
139
140  @executor_type.setter
141  def executor_type(self, executor_type):
142    self._executor_type = executor_type
143
144  @property
145  def config_proto_serialized(self):
146    return self._config_proto_serialized
147
148  @config_proto_serialized.setter
149  def config_proto_serialized(self, config):
150    if isinstance(config, config_pb2.ConfigProto):
151      self._config_proto_serialized = config.SerializeToString(
152          deterministic=True)
153    elif isinstance(config, str):
154      self._config_proto_serialized = config
155    elif config is None:
156      self._config_proto_serialized = (
157          config_pb2.ConfigProto().SerializeToString())
158    else:
      raise ValueError("The rewriter config must be a config_pb2.ConfigProto, "
                       "a serialized string of that proto, or None. "
                       "Got: {}".format(type(config)))
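
# Illustrative sketch of how FunctionCallOptions normalizes its config: a
# ConfigProto is serialized deterministically, a string is stored as-is, and
# None falls back to an empty proto. The option values below are assumptions
# for the example only.
#
#   options = FunctionCallOptions(
#       executor_type="",  # empty string -> default TensorFlow executor
#       config_proto=config_pb2.ConfigProto(allow_soft_placement=True))
#   serialized = options.config_proto_serialized  # deterministic bytes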
162
163
164# Map from context_id (an int) to _TensorCaches.
165# Dicts are thread safe in CPython.
166# TODO(iga): Remove this once TensorCaches are moved to C++.
167_tensor_caches_map = {}
168
169
170class _TensorCaches(threading.local):
171  """Thread local tensor caches."""
172
173  __slots__ = ["_ones_rank_cache", "_zeros_cache"]
174
175  def __init__(self):
176    super(_TensorCaches, self).__init__()
177    self._ones_rank_cache = None
178    self._zeros_cache = None
179
180  @property
181  def ones_rank_cache(self):
182    if not self._ones_rank_cache:
183      self._ones_rank_cache = _EagerTensorCache()
184    return self._ones_rank_cache
185
186  @property
187  def zeros_cache(self):
188    if not self._zeros_cache:
189      self._zeros_cache = _EagerTensorCache()
190    return self._zeros_cache
191
192
193ContextSwitch = collections.namedtuple(
194    "ContextSwitch", ["is_building_function", "enter_context_fn",
195                      "device_stack"])
196
197
198# `_ContextSwitchStack` is a `threading.local` to match the semantics of
# `_DefaultGraphStack`, which is also a `threading.local`.
200class _ContextSwitchStack(threading.local):
201  """A thread-local stack of context switches."""
202
203  def __init__(self, eager):
204    super(_ContextSwitchStack, self).__init__()
205    self.stack = []
206    if eager:
207      # Initialize the stack with a pointer to enter the eager context; this
208      # ensures that the fact that eager execution was enabled is propagated
209      # across threads, since (1) `enable_eager_execution` modifies a
210      # process-level flag (`default_execution_mode`) and (2) `__init__` is
211      # called each time a threading.local object is used in a separate thread.
212      self.push(is_building_function=False, enter_context_fn=eager_mode,
213                device_stack=None)
214
215  def push(self, is_building_function, enter_context_fn, device_stack):
216    """Push metadata about a context switch onto the stack.
217
    A context switch can take one of two forms: installing a graph as the
    default graph, or entering the eager context. For each context switch, we
    record whether or not the entered context is building a function.
221
222    Args:
223      is_building_function: (bool.) Whether the context is building a function.
224      enter_context_fn: (function.) A callable that executes the context switch.
225        For example, `graph.as_default` or `eager_mode`.
226      device_stack: If applicable, the device function stack for this
227        graph. When breaking out of graphs in init_scope, the innermost nonempty
228        device stack is used. Eager contexts put `None` here and the value is
229        never used.
230    """
231
232    self.stack.append(
233        ContextSwitch(is_building_function, enter_context_fn, device_stack))
234
235  def pop(self):
236    """Pop the stack."""
237
238    self.stack.pop()
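
# Illustrative sketch of how _ContextSwitchStack is used internally: every
# switch into a graph or into the eager context pushes a ContextSwitch record
# and pops it when the scope exits. `some_graph` is a hypothetical tf.Graph.
#
#   switches = _ContextSwitchStack(eager=True)
#   switches.push(is_building_function=False,
#                 enter_context_fn=some_graph.as_default,
#                 device_stack=None)  # graphs would pass their device stack
#   ...
#   switches.pop()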
239
240
241@tf_export("config.LogicalDevice")
242class LogicalDevice(
243    collections.namedtuple("LogicalDevice", ["name", "device_type"])):
244  """Abstraction for a logical device initialized by the runtime.
245
246  A `tf.config.LogicalDevice` corresponds to an initialized logical device on a
247  `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors
248  and operations can be placed on a specific logical device by calling
249  `tf.device` with a specified `tf.config.LogicalDevice`.
250
251  Fields:
252    name: The fully qualified name of the device. Can be used for Op or function
253      placement.
254    device_type: String declaring the type of device such as "CPU" or "GPU".
255  """
256  pass
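
# Illustrative usage sketch for `tf.config.LogicalDevice` (assumes at least
# one CPU logical device is available once the runtime is initialized):
#
#   logical_cpus = tf.config.list_logical_devices("CPU")
#   with tf.device(logical_cpus[0].name):
#     x = tf.constant([1.0, 2.0])  # placed on the chosen logical device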
257
258
259@tf_export("config.LogicalDeviceConfiguration",
260           "config.experimental.VirtualDeviceConfiguration")
261class LogicalDeviceConfiguration(
262    collections.namedtuple("LogicalDeviceConfiguration",
263                           ["memory_limit", "experimental_priority"])):
  """Configuration class for a logical device.
265
266  The class specifies the parameters to configure a `tf.config.PhysicalDevice`
267  as it is initialized to a `tf.config.LogicalDevice` during runtime
268  initialization. Not all fields are valid for all device types.
269
270  See `tf.config.get_logical_device_configuration` and
271  `tf.config.set_logical_device_configuration` for usage examples.
272
273  Fields:
274    memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
275      device. Currently only supported for GPUs.
276    experimental_priority: (optional) Priority to assign to a virtual device.
277      Lower values have higher priorities and 0 is the default.
278      Within a physical GPU, the GPU scheduler will prioritize ops on virtual
279      devices with higher priority. Currently only supported for Nvidia GPUs.
280  """
281
282  def __new__(cls, memory_limit=None, experimental_priority=None):
283    return super(LogicalDeviceConfiguration,
284                 cls).__new__(cls, memory_limit, experimental_priority)
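
# Illustrative sketch: splitting one physical GPU into two logical devices
# with memory limits. This must run before the runtime is initialized; the
# 1024 MB limits are arbitrary example values.
#
#   gpus = tf.config.list_physical_devices("GPU")
#   if gpus:
#     tf.config.set_logical_device_configuration(
#         gpus[0],
#         [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
#          tf.config.LogicalDeviceConfiguration(memory_limit=1024)])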
285
286
287@tf_export("config.PhysicalDevice")
288class PhysicalDevice(
289    collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
290  """Abstraction for a locally visible physical device.
291
292  TensorFlow can utilize various devices such as the CPU or multiple GPUs
293  for computation. Before initializing a local device for use, the user can
  customize certain properties of the device such as its visibility or memory
295  configuration.
296
  Once a visible `tf.config.PhysicalDevice` is initialized, one or more
298  `tf.config.LogicalDevice` objects are created. Use
299  `tf.config.set_visible_devices` to configure the visibility of a physical
300  device and `tf.config.set_logical_device_configuration` to configure multiple
301  `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is
302  useful when separation between models is needed or to simulate a multi-device
303  environment.
304
305  Fields:
306    name: Unique identifier for device.
307    device_type: String declaring the type of device such as "CPU" or "GPU".
308  """
309  pass
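
# Illustrative sketch of configuring a physical device before the runtime is
# initialized (the device indices are assumptions for the example):
#
#   gpus = tf.config.list_physical_devices("GPU")
#   if gpus:
#     tf.config.set_visible_devices(gpus[:1], "GPU")  # hide the other GPUs
#     tf.config.experimental.set_memory_growth(gpus[0], True)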
310
311
312class _AtomicCounter(object):
313  """A simple atomic counter."""
314
315  __slots__ = ["_value", "_lock"]
316
317  def __init__(self):
318    self._value = 0
319    self._lock = threading.Lock()
320
321  def increment_and_get(self):
322    with self._lock:
323      self._value += 1
324      return self._value
325
326
327_context_id_counter = _AtomicCounter()
328
329
330class _TensorCacheDeleter(object):
331  """Deletes tensor caches for a given context."""
332
333  __slots__ = ["_context_id"]
334
335  def __init__(self, context_id):
336    self._context_id = context_id
337
338  def __del__(self):
339    if _tensor_caches_map is None:
340      return
341    if self._context_id in _tensor_caches_map:
342      del _tensor_caches_map[self._context_id]
343
344
345# TODO(agarwal): rename to EagerContext / EagerRuntime ?
346# TODO(agarwal): consider keeping the corresponding Graph here.
347class Context(object):
348  """Environment in which eager operations execute."""
349
350  # TODO(agarwal): create and link in some documentation for `execution_mode`.
351  # pylint: disable=redefined-outer-name
352  def __init__(self,
353               config=None,
354               device_policy=None,
355               execution_mode=None,
356               server_def=None):
357    """Creates a new Context.
358
359    Args:
360      config: (Optional.) A `ConfigProto` protocol buffer with configuration
361        options for the Context. Note that a lot of these options may be
362        currently unimplemented or irrelevant when eager execution is enabled.
363      device_policy: (Optional.) What policy to use when trying to run an
364        operation on a device with inputs which are not on that device.
365        When set to None, an appropriate value will be picked automatically.
366        The value picked may change between TensorFlow releases.
367
368        Defaults to DEVICE_PLACEMENT_SILENT.
369        Valid values:
370        - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is
371          not correct.
372        - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
373          right device but raises a warning.
374        - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
375          hide performance problems.
376        - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
377          raising errors on the other ones.
378      execution_mode: (Optional.) Policy controlling how operations dispatched
379        are actually executed. When set to None, an appropriate value will be
380        picked automatically. The value picked may change between TensorFlow
381        releases.
382        Valid values:
383        - SYNC: executes each operation synchronously.
384        - ASYNC: executes each operation asynchronously. These
385          operations may return "non-ready" handles.
386      server_def: (Optional.) A tensorflow::ServerDef proto.
387        Enables execution on remote devices. GrpcServers need to be started by
388        creating an identical server_def to this, and setting the appropriate
389        task_indexes, so that the servers can communicate. It will then be
390        possible to execute operations on remote devices.
391
392    Raises:
      ValueError: If execution_mode is not valid.
394    """
395    # This _id is used only to index the tensor caches.
396    # TODO(iga): Remove this when tensor caches are moved to C++.
397    self._id = _context_id_counter.increment_and_get()
398    self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
399    _tensor_caches_map[self._id] = _TensorCaches()
400
401    self._config = config
402    self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
403        self,
404        is_eager=lambda: default_execution_mode == EAGER_MODE,
405        device_spec=_starting_device_spec)
406    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
407    self._context_handle = None
408    self._context_devices = None
409    self._seed = None
410    self._initialize_lock = threading.Lock()
411    self._initialized = False
412    if device_policy is None:
413      device_policy = DEVICE_PLACEMENT_SILENT
414    self._device_policy = device_policy
415    self._mirroring_policy = None
416    if execution_mode not in (None, SYNC, ASYNC):
417      raise ValueError(
418          "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
419    if execution_mode is None:
420      execution_mode = SYNC
421    self._default_is_async = execution_mode == ASYNC
422    self._use_tfrt = is_tfrt_enabled()
423    self._server_def = server_def
424    self._collective_ops_server_def = None
425    self._collective_leader = None
426    self._collective_scoped_allocator_enabled_ops = None
427    self._collective_use_nccl_communication = None
428    self._collective_device_filters = None
429
430    self._device_lock = threading.Lock()
431    self._physical_devices = None
432    self._physical_device_to_index = None
433    self._visible_device_list = []
434    self._memory_growth_map = None
435    self._virtual_device_map = {}
436
437    # Values set after construction
438    self._optimizer_jit = None
439    self._intra_op_parallelism_threads = None
440    self._inter_op_parallelism_threads = None
441    self._soft_device_placement = None
442    self._log_device_placement = None
443    self._enable_mlir_graph_optimization = None
444    self._optimizer_experimental_options = {}
445
446    _python_eager_context_create_counter.get_cell().increase_by(1)
447  # pylint: enable=redefined-outer-name
448
449  def _set_global_seed(self, seed):
450    """Set a global eager mode seed for random ops."""
451    self._seed = seed
452    # `random.Random(seed)` needs `seed` to be hashable, while values of type
453    # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them
454    # to int.
455    try:
456      hash(seed)
457    except TypeError:
458      seed = int(np.array(seed))
459    self._rng = random.Random(seed)
460    # Also clear the kernel cache, to reset any existing seeds
461    if self._context_handle is not None:
462      pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
463
464  def _internal_operation_seed(self):
465    """Returns a fake operation seed.
466
      In eager mode, the user shouldn't set or depend on the operation seed.
      Here, we generate a random seed derived from the global seed, so that
      each operation's randomness differs from op to op while still depending
      on the global seed.
470
471    Returns:
472      A fake operation seed based on global seed.
473    """
474    return self._rng.randint(0, _MAXINT32)
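
  # Illustrative sketch of the seeding scheme above: the global seed feeds a
  # Python `random.Random`, and each op draws its own pseudo-random operation
  # seed from it, so op-level randomness varies per op but is reproducible for
  # a fixed global seed. `ctx` is a hypothetical initialized Context.
  #
  #   ctx._set_global_seed(42)
  #   op_seed_a = ctx._internal_operation_seed()  # in [0, 2**31 - 1]
  #   op_seed_b = ctx._internal_operation_seed()  # typically differs from a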
475
476  def _initialize_logical_devices(self):
477    """Helper to initialize devices."""
478    # Store list of devices
479    logical_devices = []
480    context_devices = []
481    device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
482    try:
483      self._num_gpus = 0
484      for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
485        dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
486        context_devices.append(pydev.canonical_name(dev_name))
487        spec = pydev.DeviceSpec.from_string(dev_name)
488        # If the job is localhost, we assume that the cluster has not yet been
489        # configured and thus clear the job, replica & task.
490        if spec.job == "localhost":
491          spec = spec.replace(job=None, replica=None, task=None)
492        logical_devices.append(
493            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
494        dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
495        if dev_type == "GPU":
496          self._num_gpus += 1
497
498    finally:
499      self._logical_devices = logical_devices
500      self._context_devices = context_devices
501      pywrap_tfe.TF_DeleteDeviceList(device_list)
502
503  def ensure_initialized(self):
    """Initialize the handle and devices if not already initialized."""
505    if self._initialized:
506      return
507    with self._initialize_lock:
508      if self._initialized:
509        return
510      assert self._context_devices is None
511      opts = pywrap_tfe.TFE_NewContextOptions()
512      try:
513        config_str = self.config.SerializeToString()
514        pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
515        if self._device_policy is not None:
516          pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
517              opts, self._device_policy)
518        if self._mirroring_policy is not None:
519          pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
520              opts, self._mirroring_policy)
521        if self._default_is_async == ASYNC:
522          pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
523        if self._use_tfrt is not None:
524          pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
525        context_handle = pywrap_tfe.TFE_NewContext(opts)
526      finally:
527        pywrap_tfe.TFE_DeleteContextOptions(opts)
528      assert not (self._server_def and self._collective_ops_server_def), (
529          "Cannot enable remote execution as well as collective ops at the "
530          "moment. If this is important to you, please file an issue.")
531      if self._server_def is not None:
532        server_def_str = self._server_def.SerializeToString()
533        pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
534                                           server_def_str)
535      elif self._collective_ops_server_def is not None:
536        server_def_str = self._collective_ops_server_def.SerializeToString()
537        pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
538
539      self._context_handle = context_handle
540      self._initialize_logical_devices()
541      self._initialized = True
542
543  def _clear_caches(self):
544    self.ones_rank_cache().flush()
545    self.zeros_cache().flush()
546    pywrap_tfe.TFE_ClearScalarCache()
547
548  def get_server_def(self):
549    return self._server_def
550
551  def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
552    """Allow setting a server_def on the context.
553
    When a server def is replaced, it effectively clears the caches within the
    context. If you attempt to use a tensor object that was pointing to a
    tensor on the remote device, it will raise an error.
557
558    Args:
559      server_def: A tensorflow::ServerDef proto.
560        Enables execution on remote devices.
561      keep_alive_secs: Num. seconds after which the remote end will hang up.
562        As long as the client is still alive, the server state for the context
563        will be kept alive. If the client is killed (or there is some failure),
564        the server will clean up its context keep_alive_secs after the final RPC
565        it receives.
566
567    Raises:
568      ValueError: if server_def is None.
569    """
570    if not server_def:
571      raise ValueError("server_def is None.")
572
573    self._server_def = server_def
574
575    if self._context_handle:
576      server_def_str = server_def.SerializeToString()
577      pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
578                                         server_def_str)
579      self._initialize_logical_devices()
580
581    # Clear all the caches in case there are remote tensors in them.
582    self._clear_caches()
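
  # Illustrative sketch (assumptions: a two-task "worker" job reachable at the
  # hypothetical addresses below) of building a ServerDef and installing it on
  # the current context:
  #
  #   from tensorflow.core.protobuf import tensorflow_server_pb2
  #   cluster = tf.train.ClusterSpec(
  #       {"worker": ["host1:2222", "host2:2222"]}).as_cluster_def()
  #   server_def = tensorflow_server_pb2.ServerDef(
  #       cluster=cluster, job_name="worker", task_index=0, protocol="grpc")
  #   context().set_server_def(server_def)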
583
584  def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
585    """Update a server_def on the context.
586
587    Args:
588      server_def: A tensorflow::ServerDef proto. Enables execution on remote
589        devices.
590      keep_alive_secs: Num. seconds after which the remote end will hang up. As
591        long as the client is still alive, the server state for the context will
592        be kept alive. If the client is killed (or there is some failure), the
593        server will clean up its context keep_alive_secs after the final RPC it
594        receives.
595
596    Raises:
597      ValueError: if server_def is None.
598    """
599    if not server_def:
600      raise ValueError("server_def is None.")
601
602    self._server_def = server_def
603
604    if self._context_handle:
605      server_def_str = server_def.SerializeToString()
606      pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
607                                            keep_alive_secs, server_def_str)
608      self._initialize_logical_devices()
609
610    self._clear_caches()
611
612  def check_alive(self, worker_name):
613    """Checks whether a remote worker is alive or not.
614
615    Args:
616      worker_name: a string representing the remote worker. It must be a fully
        specified name like "/job:worker/replica:0/task:0".
618
619    Returns:
620      a boolean indicating whether the remote worker is alive or not.
621
622    Raises:
623      ValueError: if context is not initialized.
624    """
625    # TODO(yuefengz): support checking multiple workers.
626    if self._context_handle:
627      return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
628    else:
629      raise ValueError("Context is not initialized.")
630
631  def sync_executors(self):
632    """Sync both local executors and the ones on remote workers.
633
634    In async execution mode, local function calls can return before the
635    corresponding remote op/function execution requests are completed. Calling
636    this method creates a synchronization barrier for remote executors. It only
    returns when all remote pending nodes are finished, potentially with errors
    if any remote executors are in an error state.
639
640    Raises:
641      ValueError: if context is not initialized.
642    """
643    if self._context_handle:
644      pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
645    else:
646      raise ValueError("Context is not initialized.")
647
648  def clear_executor_errors(self):
649    """Clear errors in both local executors and remote workers.
650
    After receiving errors from remote workers, additional in-flight requests
    could further taint the status on the remote workers due to the async
    nature of remote execution. Calling this method blocks until all pending
    nodes in the remote executors finish, and then clears their error statuses.
655
656    Raises:
657      ValueError: if context is not initialized.
658    """
659    if self._context_handle:
660      pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
661    else:
662      raise ValueError("Context is not initialized.")
663
664  def clear_kernel_cache(self):
665    """Clear kernel cache and reset all stateful kernels.
666
667    Raises:
668      ValueError: if context is not initialized.
669    """
670    if self._context_handle is not None:
671      pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
672    else:
673      raise ValueError("Context is not initialized.")
674
675  def enable_collective_ops(self, server_def):
676    """Enable distributed collective ops with an appropriate server_def.
677
678    Args:
679      server_def: A tensorflow::ServerDef proto. Enables execution on remote
680        devices.
681
682    Raises:
683      ValueError: if server_def is None.
684      RuntimeError: if this method is not called at program startup.
685    """
686    if not server_def:
687      raise ValueError("server_def is None.")
688
689    self._collective_ops_server_def = server_def
690
691    # TODO(b/129298253): Allow creating datasets/tensors before enabling
692    # collective ops.
693    if self._context_handle is not None:
694      logging.warning("Enabling collective ops after program startup may cause "
695                      "error when accessing previously created tensors.")
696      with self._initialize_lock:
697        assert self._initialized
698        server_def_str = self._collective_ops_server_def.SerializeToString()
699        pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str)
700        self._initialize_logical_devices()
701        self._clear_caches()
702
703  def configure_collective_ops(
704      self,
705      collective_leader="",
706      scoped_allocator_enabled_ops=("CollectiveReduce",),
707      use_nccl_communication=False,
708      device_filters=None):
709    """Configure collective ops.
710
      A collective group leader is necessary for collective ops to run; the
      other configurations mainly affect performance.
713
714    Args:
715      collective_leader: a device string for collective leader, e.g.
716        "/job:worker/replica:0/task:0"; empty string means local execution of
        collective ops.
718      scoped_allocator_enabled_ops: a tuple or a list of op names for scoped
719        allocator to run with.
720      use_nccl_communication: whether to use nccl communication for collective
721        ops.
722      device_filters: a tuple or a list of device strings. If set, corresponding
723        task can only see the devices filtered by these device filters.
724
725    Raises:
726      RuntimeError: if this method is not called at program startup.
727    """
728    if self._collective_leader is not None:
729      if (self._collective_leader != collective_leader or
730          self._collective_scoped_allocator_enabled_ops !=
731          scoped_allocator_enabled_ops or
732          self._collective_use_nccl_communication != use_nccl_communication or
733          self._collective_device_filters != device_filters):
734        raise ValueError("Collective ops are already configured.")
735      else:
736        return
737
738    if self._context_handle is not None:
739      raise RuntimeError("Collective ops must be configured at program startup")
740
741    self._collective_leader = collective_leader
742    self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops
743    self._collective_use_nccl_communication = use_nccl_communication
744    self._collective_device_filters = device_filters
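
  # Illustrative sketch: configuring collective ops at program startup, before
  # the context handle is created. The leader task below is an example value.
  #
  #   context().configure_collective_ops(
  #       collective_leader="/job:worker/replica:0/task:0",
  #       use_nccl_communication=True)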
745
746  def abort_collective_ops(self, code, message):
747    """Abort the collective ops.
748
749    This is intended to be used when a peer failure is detected, which allows
750    the user to handle the case instead of hanging. This aborts all on-going
    collectives. Afterwards, all subsequent collectives error immediately, and
    you need to reset_context() to use collectives again.
753
754    Args:
755      code: a `tf.errors` error code.
756      message: a string. The error message.
757    """
758    self.ensure_initialized()
759    pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
760
761  def check_collective_ops_peer_health(self, task, timeout_in_ms):
762    """Check collective peer health.
763
    This probes each task to see if it is still alive. Note that a restarted
    task is considered a different task and is treated as not healthy.

    This should only be used in multi-client, multi-worker training.
768
769    Args:
770      task: a task string, must be in the format of /job:xxx/replica:0/task:N.
771      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
772
773    Raises:
774      tf.errors.UnavailableError: when a peer is down.
775      tf.errors.FailedPreconditionError: when a peer is a different one from the
776        one this task has talked to, e.g. the peer has restarted.
777      tf.errors.InvalidArgumentError: when the task string is invalid.
778    """
779    self.ensure_initialized()
780    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
781                                                timeout_in_ms)
782
783  @property
784  def _handle(self):
785    if self._context_handle is None:
786      raise AssertionError("Context must be initialized first.")
787
788    return self._context_handle
789
790  @property
791  def _devices(self):
792    if self._context_devices is None:
793      raise AssertionError("Context must be initialized first.")
794
795    return self._context_devices
796
797  def __str__(self):
798    if self._context_handle is None:
799      return "Eager TensorFlow Context. Devices currently uninitialized."
800    else:
801      devices = self._devices
802      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
803      for i, d in enumerate(devices):
804        lines.append("   Device %d: %s" % (i, d))
805      return "\n".join(lines)
806
807  @tf_contextlib.contextmanager
808  def _mode(self, mode):
809    """A context manager to allow setting the mode to EAGER/GRAPH."""
810    ctx = self._thread_local_data
811    old_is_eager = ctx.is_eager
812    ctx.is_eager = mode == EAGER_MODE
813    if mode == EAGER_MODE:
814      # Entering graph mode does not provide us with sufficient information to
815      # record a context switch; graph-based context switches are only logged
816      # when a graph is registered as the default graph.
817      self.context_switches.push(False, eager_mode, None)
818    try:
819      yield
820    finally:
821      ctx.is_eager = old_is_eager
822      if mode == EAGER_MODE:
823        self.context_switches.pop()
824
825  def executing_eagerly(self):
    """Returns True if the current thread has eager execution enabled."""
827    return self._thread_local_data.is_eager
828
829  def ones_rank_cache(self):
830    """Per-device cache for scalars."""
831    return _tensor_caches_map[self._id].ones_rank_cache
832
833  def zeros_cache(self):
834    """Per-device cache for scalars."""
835    return _tensor_caches_map[self._id].zeros_cache
836
837  @property
838  def scope_name(self):
839    """Returns scope name for the current thread."""
840    return self._thread_local_data.scope_name
841
842  @scope_name.setter
843  def scope_name(self, s):
844    """Sets scope name for the current thread."""
845    self._thread_local_data.scope_name = s
846
847  @property
848  def device_name(self):
849    """Returns the device name for the current thread."""
850    return self._thread_local_data.device_name
851
852  @property
853  def device_spec(self):
854    """Returns the device spec for the current thread."""
855    return self._thread_local_data.device_spec
856
857  def _set_device(self, device_name, device_spec):
858    self._thread_local_data.device_name = device_name
859    self._thread_local_data.device_spec = device_spec
860
861  def device(self, name):
862    """Context-manager to force placement of operations and Tensors on a device.
863
864    Args:
865      name: Name of the device or None to get default placement.
866
867    Returns:
868      Context manager that forces device placement.
869
870    Raises:
871      ValueError: If name is not a string or is an invalid device name.
872      RuntimeError: If device scopes are not properly nested.
873    """
874    if isinstance(name, LogicalDevice):
875      name = name.name
876    elif pydev.is_device_spec(name):
877      name = name.to_string()
878    return _EagerDeviceContext(self, name)
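
  # Illustrative sketch of the device() context manager defined above (the
  # device string assumes a GPU is available; otherwise use a CPU name):
  #
  #   ctx = context()
  #   with ctx.device("/job:localhost/replica:0/task:0/device:GPU:0"):
  #     ...  # ops created here are placed on that device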
879
880  def devices(self):
881    """List of the names of devices available to execute operations."""
882    return self._devices
883
884  def host_address_space(self):
885    self.ensure_initialized()
886    with c_api_util.tf_buffer() as buffer_:
887      pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
888      address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
889    return address_space
890
891  # TODO(fishx): remove this property.
892  @property
893  def execution_mode(self):
894    """Gets execution mode for current thread."""
895    return ASYNC if self.is_async() else SYNC
896
897  @execution_mode.setter
898  def execution_mode(self, mode):
899    """Sets execution mode for current thread."""
900    if mode not in (None, SYNC, ASYNC):
901      raise ValueError(
902          "Execution mode should be None/SYNC/ASYNC. Got %s" % mode)
903
904    if mode is None:
905      mode = SYNC
906
907    enable_async = (mode == ASYNC)
908    if self.is_async() != enable_async:
909      # Only set the execution mode if the context has already been initialized
910      if self._context_handle is not None:
911        self.executor.wait()
912        executor_new = executor.new_executor(enable_async)
913        self._thread_local_data.executor = executor_new
914        pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
915                                                   executor_new.handle())
916      else:
917        self._default_is_async = enable_async
918
919  def is_async(self):
920    if self._context_handle is not None:
921      return self.executor.is_async()
922    else:
923      return self._default_is_async
924
925  @property
926  def executor(self):
927    self.ensure_initialized()
928    return executor.Executor(
929        pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
930
931  @executor.setter
932  def executor(self, e):
933    self.ensure_initialized()
934    pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
935
936  @property
937  def config(self):
938    """Return the ConfigProto with all runtime deltas applied."""
939    # Ensure physical devices have been discovered and config has been imported
940    self._initialize_physical_devices()
941
942    config = config_pb2.ConfigProto()
943    if self._config is not None:
944      config.CopyFrom(self._config)
945
946    if self._optimizer_jit is not None:
947      config.graph_options.optimizer_options.global_jit_level = (
948          config_pb2.OptimizerOptions.ON_1
949          if self._optimizer_jit else config_pb2.OptimizerOptions.OFF)
950    if self._intra_op_parallelism_threads is not None:
951      config.intra_op_parallelism_threads = self._intra_op_parallelism_threads
952    if self._inter_op_parallelism_threads is not None:
953      config.inter_op_parallelism_threads = self._inter_op_parallelism_threads
954
955    if self._soft_device_placement is not None:
956      config.allow_soft_placement = self._soft_device_placement
957    else:
958      config.allow_soft_placement = self.executing_eagerly()
959
960    if self._log_device_placement is not None:
961      config.log_device_placement = self._log_device_placement
962
963    is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
964    config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
965    if (is_mlir_bridge_enabled ==
966        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
967      config.experimental.enable_mlir_bridge = True
968
969    if self._enable_mlir_graph_optimization is not None:
970      config.experimental.enable_mlir_graph_optimization = (
971          self._enable_mlir_graph_optimization)
972
973    def rewriter_toggle(option):
974      toggle = self._optimizer_experimental_options.get(option, None)
975      if toggle is None:
976        return
977
978      setattr(config.graph_options.rewrite_options,
979              option,
980              (rewriter_config_pb2.RewriterConfig.ON
981               if toggle else rewriter_config_pb2.RewriterConfig.OFF))
982
983    def rewriter_bool(option):
984      toggle = self._optimizer_experimental_options.get(option, None)
985      if toggle is None:
986        return
987
988      setattr(config.graph_options.rewrite_options,
989              option,
990              toggle)
991
992    rewriter_toggle("layout_optimizer")
993    rewriter_toggle("constant_folding")
994    rewriter_toggle("shape_optimization")
995    rewriter_toggle("remapping")
996    rewriter_toggle("arithmetic_optimization")
997    rewriter_toggle("dependency_optimization")
998    rewriter_toggle("loop_optimization")
999    rewriter_toggle("function_optimization")
1000    rewriter_toggle("debug_stripper")
1001    rewriter_bool("disable_model_pruning")
1002    rewriter_toggle("scoped_allocator_optimization")
1003    rewriter_toggle("pin_to_host_optimization")
1004    rewriter_toggle("implementation_selector")
1005    rewriter_toggle("auto_mixed_precision")
1006    rewriter_bool("disable_meta_optimizer")
1007    nodes = self._optimizer_experimental_options.get("min_graph_nodes", None)
1008    if nodes is not None:
1009      config.graph_options.rewrite_options.min_graph_nodes = nodes
1010
1011    # Compute device counts
1012    config.device_count["CPU"] = 0
1013    config.device_count["GPU"] = 0
1014    for dev in self._physical_devices:
1015      if dev not in self._visible_device_list:
1016        continue
1017
1018      virtual_devices = self._virtual_device_map.get(dev)
1019      if virtual_devices is None:
1020        config.device_count[dev.device_type] += 1
1021      else:
1022        config.device_count[dev.device_type] += len(virtual_devices)
1023
1024    # Configure gpu_options
1025    gpu_options = self._compute_gpu_options()
1026    config.gpu_options.MergeFrom(gpu_options)
1027
1028    # Configure collective ops
1029    if self._collective_leader:
1030      config.experimental.collective_group_leader = self._collective_leader
1031    if self._collective_scoped_allocator_enabled_ops:
1032      rewrite_options = config.graph_options.rewrite_options
1033      rewrite_options.scoped_allocator_optimization = (
1034          rewriter_config_pb2.RewriterConfig.ON)
1035      del rewrite_options.scoped_allocator_opts.enable_op[:]
1036      for op in self._collective_scoped_allocator_enabled_ops:
1037        rewrite_options.scoped_allocator_opts.enable_op.append(op)
1038    if self._collective_use_nccl_communication:
1039      config.experimental.collective_nccl = True
1040    if self._collective_device_filters:
1041      del config.device_filters[:]
1042      for f in self._collective_device_filters:
1043        config.device_filters.append(f)
1044
1045    return config
1046
1047  def _compute_gpu_options(self):
1048    """Build the GPUOptions proto."""
1049    visible_device_list = []
1050    virtual_devices = []
1051    gpu_index = -1
1052    memory_growths = set()
1053    for dev in self.list_physical_devices("GPU"):
1054      gpu_index += 1
1055
1056      if dev not in self._visible_device_list:
1057        continue
1058
1059      growth = self._memory_growth_map[dev]
1060      memory_growths.add(growth)
1061      visible_device_list.append(str(gpu_index))
1062
1063      if self._virtual_device_map:
1064        vdevs = self._virtual_device_map.get(dev, [])
1065        device_limits = []
1066        priority = []
1067        for virt_dev in vdevs:
1068          device_limits.append(virt_dev.memory_limit)
1069          if virt_dev.experimental_priority is not None:
1070            priority.append(virt_dev.experimental_priority)
1071        # If priority is specified, it must be specified for all virtual
1072        # devices.
1073        if priority and len(device_limits) != len(priority):
1074          raise ValueError("priority must be specified for all virtual devices")
1075
1076        virtual_devices.append(
1077            config_pb2.GPUOptions.Experimental.VirtualDevices(
1078                memory_limit_mb=device_limits, priority=priority))
1079
1080    # Only compute growth if virtual devices have not been configured and we
1081    # have GPUs
1082    if not virtual_devices and memory_growths:
1083      if len(memory_growths) > 1:
1084        raise ValueError("Memory growth cannot differ between GPU devices")
1085      allow_growth = memory_growths.pop()
1086    else:
1087      allow_growth = None
1088
1089    return config_pb2.GPUOptions(
1090        allow_growth=allow_growth,
1091        visible_device_list=",".join(visible_device_list),
1092        experimental=config_pb2.GPUOptions.Experimental(
1093            virtual_devices=virtual_devices))
1094
1095  @property
1096  def function_call_options(self):
1097    """Returns function call options for current thread.
1098
1099    Note that the returned object is still referenced by the eager context.
1100
1101    Returns: the FunctionCallOptions for current thread.
1102    """
1103    if self._thread_local_data.function_call_options is None:
1104      config = self.config
1105
1106      # Default to soft placement for functions unless specified
1107      if self._soft_device_placement is None:
1108        config.allow_soft_placement = True
1109      self._thread_local_data.function_call_options = FunctionCallOptions(
1110          config_proto=config)
1111
1112    return self._thread_local_data.function_call_options
1113
1114  @function_call_options.setter
1115  def function_call_options(self, options):
    """Sets function call options for the current thread."""
1117    self._thread_local_data.function_call_options = options
1118
1119  def num_gpus(self):
1120    """The number of GPUs available to execute operations."""
1121    self.ensure_initialized()
1122    return self._num_gpus
1123
1124  def add_function(self, fn):
1125    """Add a function definition to the context.
1126
1127    Once added, the function (identified by its name) can be executed like any
1128    other operation.
1129
1130    Args:
1131      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
1132    """
1133    self.ensure_initialized()
1134    pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
1135
1136  def add_function_def(self, fdef):
1137    """Add a function definition to the context.
1138
1139    Once added, the function (identified by its name) can be executed like any
1140    other operation.
1141
1142    Args:
1143      fdef: A FunctionDef protocol buffer message.
1144    """
1145    self.ensure_initialized()
1146    fdef_string = fdef.SerializeToString()
1147    pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
1148                                         len(fdef_string))
1149
1150  def get_function_def(self, name):
1151    """Get a function definition from the context.
1152
1153    Args:
1154      name: function signature name.
1155
1156    Returns:
1157      The requested FunctionDef.
1158
1159    Raises:
1160      tf.errors.NotFoundError: if name is not the name of a registered function.
1161    """
1162    with c_api_util.tf_buffer() as buffer_:
1163      pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
1164      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1165    function_def = function_pb2.FunctionDef()
1166    function_def.ParseFromString(proto_data)
1167
1168    return function_def
1169
1170  def register_custom_device(self, device_capsule, device_name,
1171                             device_info_capsule):
1172    """Calls TFE_RegisterCustomDevice. See the non-member function."""
1173    self.ensure_initialized()
1174    pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
1175                                           device_name, device_info_capsule)
1176
1177  def pack_eager_tensors(self, tensors):
1178    """Pack multiple `EagerTensor`s of the same dtype and shape.
1179
1180    Args:
1181      tensors: a list of EagerTensors to pack.
1182
1183    Returns:
1184      A packed EagerTensor.
1185    """
1186    self.ensure_initialized()
1187    return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
1188
1189  def list_function_names(self):
1190    """Get a list of names of registered functions.
1191
1192    Returns:
1193      A set of names of all registered functions for the context.
1194    """
1195    self.ensure_initialized()
1196    return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
1197
1198  def remove_function(self, name):
1199    """Remove a function from the context.
1200
1201    Once removed, the function cannot be executed anymore.
1202
1203    Args:
1204      name: function signature name.
1205    """
1206    self.ensure_initialized()
1207    pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
1208
1209  def has_function(self, name):
1210    """Check if a function `name` is registered."""
1211    self.ensure_initialized()
1212    return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
1213
1214  def add_op_callback(self, callback):
1215    """Add a post-op callback to the context.
1216
    A post-op callback is invoked immediately after an eager operation or
    function has finished execution or after an op has been added to a graph,
    providing access to the op's type, name, and input and output tensors.
    Multiple op callbacks can be added, in which case the callbacks will be
    invoked in the order in which they are added.
1222
1223    Args:
1224      callback: a callable of the signature
1225        `f(op_type, inputs, attrs, outputs, op_name=None, graph=None)`.
1226        See doc strings in `op_callbacks.py` for details on the function
1227        signature and its semantics.
1228    """
1229    if callback not in self._thread_local_data.op_callbacks:
1230      self._thread_local_data.op_callbacks.append(callback)
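
  # Illustrative sketch of a post-op callback matching the signature described
  # above; it just logs the op type and returns None so outputs are unchanged.
  #
  #   def log_callback(op_type, inputs, attrs, outputs, op_name=None,
  #                    graph=None):
  #     logging.info("ran %s", op_type)
  #
  #   context().add_op_callback(log_callback)
  #   ...
  #   context().remove_op_callback(log_callback)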
1231
1232  def remove_op_callback(self, callback):
1233    """Remove an already-registered op callback.
1234
1235    Args:
1236      callback: The op callback to be removed.
1237
1238    Raises:
1239      KeyError: If `callback` is not already registered.
1240    """
1241    if callback not in self._thread_local_data.op_callbacks:
1242      raise KeyError(
1243          "The specified op callback has not been registered, "
1244          "and hence cannot be removed.")
1245    del self._thread_local_data.op_callbacks[
1246        self._thread_local_data.op_callbacks.index(callback)]
1247
1248  @property
1249  def op_callbacks(self):
1250    return self._thread_local_data.op_callbacks
1251
1252  @property
1253  def invoking_op_callbacks(self):
1254    return self._thread_local_data.invoking_op_callbacks
1255
1256  @invoking_op_callbacks.setter
1257  def invoking_op_callbacks(self, value):
1258    self._thread_local_data.invoking_op_callbacks = value
1259
1260  def _initialize_physical_devices(self, reinitialize=False):
1261    """Gets local devices visible to the system.
1262
    Args:
      reinitialize: If True, reinitializes self._physical_devices so that
        dynamically registered devices will also be visible to the Python
        front-end.
    """
    # We lazily initialize self._physical_devices since we do not want to do
    # this in the constructor, as the backend may not be initialized yet.
1269    with self._device_lock:
1270      if not reinitialize and self._physical_devices is not None:
1271        return
1272
1273      devs = pywrap_tfe.TF_ListPhysicalDevices()
1274      self._physical_devices = [
1275          PhysicalDevice(name=d.decode(),
1276                         device_type=d.decode().split(":")[1]) for d in devs]
1277      self._physical_device_to_index = {
1278          p: i for i, p in enumerate(self._physical_devices)
1279      }
1280
1281      self._visible_device_list = list(self._physical_devices)
1282      self._memory_growth_map = {
1283          d: None for d in self._physical_devices if d.device_type == "GPU"
1284      }
1285
1286    # Import device settings that may have been passed into the constructor
1287    self._import_config()
1288
1289  def reinitialize_physical_devices(self):
1290    """Gets local devices visible to the system."""
1291    # Reinitialize the physical device list after registering
1292    # the pluggable device.
1293    self._initialize_physical_devices(True)
1294
1295  def list_physical_devices(self, device_type=None):
1296    """List local devices visible to the system.
1297
1298    This API allows a client to query the devices before they have been
    initialized by the eager runtime. Additionally, a user can filter by device
    type to get only CPUs or GPUs.
1301
1302    Args:
1303      device_type: Optional device type to limit results to
1304
1305    Returns:
1306      List of PhysicalDevice objects.
1307    """
1308    self._initialize_physical_devices()
1309
1310    if device_type is None:
1311      return list(self._physical_devices)
1312
1313    return [d for d in self._physical_devices if d.device_type == device_type]
1314
1315  def get_device_details(self, device):  # pylint: disable=redefined-outer-name
    """Returns details about a physical device.
1317
1318    Args:
1319      device: A `tf.config.PhysicalDevice` returned by
1320        `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
1321
1322    Returns:
1323      A dict with string keys.
1324    """
1325    if not isinstance(device, PhysicalDevice):
1326      raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
1327                       "%s" % (device,))
1328    if (self._physical_device_to_index is None or
1329        device not in self._physical_device_to_index):
1330      raise ValueError("The PhysicalDevice must be one obtained from "
1331                       "calling `tf.config.list_physical_devices`, but got: "
1332                       "%s" % (device,))
1333    index = self._physical_device_to_index[device]
1334    details = pywrap_tfe.TF_GetDeviceDetails(index)
1335
1336    # Change compute_capability from a string to a tuple
1337    if "compute_capability" in details:
1338      try:
1339        major, minor = details["compute_capability"].split(".")
1340        details["compute_capability"] = (int(major), int(minor))
1341      except ValueError:
        raise RuntimeError("Device returned compute capability in an invalid "
                           "format: %s" % details["compute_capability"])
1344    return details
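
  # Illustrative sketch using the public wrapper for the method above (the
  # returned keys, e.g. "device_name" and "compute_capability", depend on the
  # device type and may be absent):
  #
  #   gpus = tf.config.list_physical_devices("GPU")
  #   if gpus:
  #     details = tf.config.experimental.get_device_details(gpus[0])
  #     details.get("compute_capability")  # e.g. (7, 0)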
1345
1346  def _import_config(self):
1347    """Import config if passed in during construction.
1348
1349    If Context was created with a ConfigProto such as when calling
1350    tf.compat.v1.enable_eager_execution(), then we need to pull out the
    various pieces we might be replacing and import them into our internal
1352    class representation.
1353    """
1354    if self._config is None:
1355      return
1356
1357    num_cpus = self._config.device_count.get("CPU", 1)
1358    if num_cpus != 1:
1359      cpus = [d for d in self._physical_devices if d.device_type == "CPU"]
1360      if num_cpus == 0:
1361        self.set_visible_devices([], "CPU")
1362      elif num_cpus > 1:
1363        self.set_logical_device_configuration(
1364            cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)])
1365
1366    # Parse GPU options
1367    gpus = [d for d in self._physical_devices if d.device_type == "GPU"]
1368
1369    # If there are no GPUs detected, simply ignore all the GPU options passed in
1370    # rather than doing any validation checks.
1371    if not gpus:
1372      return
1373
1374    gpu_count = self._config.device_count.get("GPU", None)
1375
1376    visible_gpus = []
1377    # TODO(gjn): Handle importing existing virtual GPU configuration
1378    visible_indices = self._config.gpu_options.visible_device_list
1379    if visible_indices:
1380      for index in visible_indices.split(","):
1381        if int(index) >= len(gpus):
1382          raise ValueError("Invalid visible device index: %s" % index)
1383        visible_gpus.append(gpus[int(index)])
1384    else:
1385      visible_gpus = gpus
1386
1387    if gpu_count is not None:
1388      visible_gpus = visible_gpus[:gpu_count]
1389
1390    self.set_visible_devices(visible_gpus, "GPU")
1391
1392  def list_logical_devices(self, device_type=None):
1393    """Return logical devices."""
1394    self.ensure_initialized()
1395    if device_type is None:
1396      return list(self._logical_devices)
1397
1398    return [d for d in self._logical_devices if d.device_type == device_type]
1399
1400  def get_visible_devices(self, device_type=None):
1401    """Get the list of visible devices."""
1402    self._initialize_physical_devices()
1403
1404    if device_type is None:
1405      return list(self._visible_device_list)
1406
1407    return [
1408        d for d in self._visible_device_list if d.device_type == device_type
1409    ]
1410
1411  def set_visible_devices(self, devices, device_type=None):
1412    """Set the list of visible devices."""
1413    self._initialize_physical_devices()
1414
1415    if not isinstance(devices, list):
1416      devices = [devices]
1417
1418    for d in devices:
1419      if d not in self._physical_devices:
1420        raise ValueError("Unrecognized device: %s" % repr(d))
1421      if device_type is not None and d.device_type != device_type:
1422        raise ValueError("Unrecognized device: %s" % repr(d))
1423
1424    visible_device_list = []
1425    if device_type is not None:
1426      visible_device_list = [
1427          d for d in self._visible_device_list if d.device_type != device_type
1428      ]
1429
1430    visible_device_list += devices
1431
1432    if self._visible_device_list == visible_device_list:
1433      return
1434
1435    if self._context_handle is not None:
1436      raise RuntimeError(
1437          "Visible devices cannot be modified after being initialized")
1438
1439    self._visible_device_list = visible_device_list
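
  # Example for `set_visible_devices` above (an illustrative sketch; assumes
  # `ctx` is a `Context` instance and that its GPU `PhysicalDevice`s can be
  # obtained via something like `ctx.list_physical_devices("GPU")`):
  #
  #   gpus = ctx.list_physical_devices("GPU")
  #   ctx.set_visible_devices(gpus[:1], "GPU")  # expose only the first GPU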
1440
1441  def get_memory_info(self, dev):
1442    """Returns a dict of memory info for the device."""
1443    self._initialize_physical_devices()
1444    self.ensure_initialized()
1445    return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
1446
1447  # TODO(reedwm): Remove this function
1448  def get_total_memory_usage(self, dev):
1449    """Returns total memory usage in bytes for the current device."""
1450    return self.get_memory_info(dev)["current"]
1451
1452  def get_memory_growth(self, dev):
1453    """Get if memory growth is enabled for a PhysicalDevice."""
1454    self._initialize_physical_devices()
1455
1456    if dev not in self._physical_devices:
1457      raise ValueError("Unrecognized device: %s" % repr(dev))
1458
1459    return self._memory_growth_map[dev]
1460
1461  def set_memory_growth(self, dev, enable):
1462    """Set if memory growth should be enabled for a PhysicalDevice."""
1463    self._initialize_physical_devices()
1464
1465    if dev not in self._physical_devices:
1466      raise ValueError("Unrecognized device: %s" % repr(dev))
1467
1468    if dev in self._virtual_device_map:
      raise ValueError("Cannot set memory growth on device when virtual "
                       "devices are configured")
1471
1472    if dev.device_type != "GPU":
1473      raise ValueError("Cannot set memory growth on non-GPU devices")
1474
1475    if self._memory_growth_map.get(dev) == enable:
1476      return
1477
1478    if self._context_handle is not None:
1479      raise RuntimeError(
1480          "Physical devices cannot be modified after being initialized")
1481
1482    self._memory_growth_map[dev] = enable
1483
1484  def get_logical_device_configuration(self, dev):
1485    """Get the virtual device configuration for a PhysicalDevice."""
1486    self._initialize_physical_devices()
1487
1488    if dev not in self._physical_devices:
1489      raise ValueError("Unrecognized device: %s" % repr(dev))
1490
1491    return self._virtual_device_map.get(dev)
1492
1493  def set_logical_device_configuration(self, dev, virtual_devices):
1494    """Set the virtual device configuration for a PhysicalDevice."""
1495    self._initialize_physical_devices()
1496
1497    if dev not in self._physical_devices:
1498      raise ValueError("Unrecognized device: %s" % repr(dev))
1499
1500    if dev.device_type == "CPU":
1501      for vdev in virtual_devices:
1502        if vdev.memory_limit is not None:
1503          raise ValueError("Setting memory limit on CPU virtual devices is "
1504                           "currently not supported")
1505        if vdev.experimental_priority is not None:
1506          raise ValueError("Setting experimental_priority on CPU virtual "
1507                           " devices is currently not supported")
1508    elif dev.device_type == "GPU":
1509      for vdev in virtual_devices:
1510        if vdev.memory_limit is None:
1511          raise ValueError(
1512              "Setting memory limit is required for GPU virtual devices")
1513    else:
1514      raise ValueError("Virtual devices are not supported for %s" %
1515                       dev.device_type)
1516
1517    if self._virtual_device_map.get(dev) == virtual_devices:
1518      return
1519
1520    if self._context_handle is not None:
1521      raise RuntimeError(
1522          "Virtual devices cannot be modified after being initialized")
1523
1524    self._virtual_device_map[dev] = virtual_devices
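
  # Example for `set_logical_device_configuration` above (an illustrative
  # sketch; assumes `ctx` is a `Context` instance and `gpu` is a GPU
  # `PhysicalDevice`):
  #
  #   ctx.set_logical_device_configuration(
  #       gpu,
  #       [LogicalDeviceConfiguration(memory_limit=1024),
  #        LogicalDeviceConfiguration(memory_limit=1024)])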
1525
1526  def get_compiler_ir(self, device_name, function_name, args, stage="hlo"):
1527    return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
1528                                       stage, device_name, args)
1529
1530  @deprecated(
1531      None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
1532  def enable_xla_devices(self):
1533    """Enables XLA:CPU and XLA:GPU devices registration."""
1534    pywrap_tfe.TF_EnableXlaDevices()
1535
1536  @property
1537  def enable_mlir_bridge(self):
1538    return pywrap_tfe.TF_IsMlirBridgeEnabled()
1539
1540  @property
1541  def enable_mlir_graph_optimization(self):
1542    return self._enable_mlir_graph_optimization
1543
1544  @enable_mlir_bridge.setter
1545  def enable_mlir_bridge(self, enabled):
1546    pywrap_tfe.TF_EnableMlirBridge(enabled)
1547    self._thread_local_data.function_call_options = None
1548
1549  @enable_mlir_graph_optimization.setter
1550  def enable_mlir_graph_optimization(self, enabled):
1551    self._enable_mlir_graph_optimization = enabled
1552    self._thread_local_data.function_call_options = None
1553
1554  @property
1555  def optimizer_jit(self):
1556    level = self.config.graph_options.optimizer_options.global_jit_level
1557    return (level == config_pb2.OptimizerOptions.ON_1 or
1558            level == config_pb2.OptimizerOptions.ON_2)
1559
1560  @optimizer_jit.setter
1561  def optimizer_jit(self, enabled):
1562    self._optimizer_jit = enabled
1563
1564    self._thread_local_data.function_call_options = None
1565
1566  def get_optimizer_experimental_options(self):
1567    """Get experimental options for the optimizer.
1568
1569    Returns:
1570      Dictionary of current option values
1571    """
1572    rewrite_options = self.config.graph_options.rewrite_options
1573    options = {}
1574
1575    def rewriter_toggle(option):
1576      attr = getattr(rewrite_options, option)
1577      if attr != 0:
1578        options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)
1579
1580    def rewriter_bool(option):
1581      options[option] = getattr(rewrite_options, option)
1582
1583    rewriter_toggle("layout_optimizer")
1584    rewriter_toggle("constant_folding")
1585    rewriter_toggle("shape_optimization")
1586    rewriter_toggle("remapping")
1587    rewriter_toggle("arithmetic_optimization")
1588    rewriter_toggle("dependency_optimization")
1589    rewriter_toggle("loop_optimization")
1590    rewriter_toggle("function_optimization")
1591    rewriter_toggle("debug_stripper")
1592    rewriter_bool("disable_model_pruning")
1593    rewriter_toggle("scoped_allocator_optimization")
1594    rewriter_toggle("pin_to_host_optimization")
1595    rewriter_toggle("implementation_selector")
1596    rewriter_toggle("auto_mixed_precision")
1597    rewriter_bool("disable_meta_optimizer")
1598
1599    if rewrite_options.min_graph_nodes != 0:
1600      options["min_graph_nodes"] = rewrite_options.min_graph_nodes
1601
1602    return options
1603
1604  def set_optimizer_experimental_options(self, options):
1605    """Set experimental options for the optimizer.
1606
1607    Args:
1608      options: Dictionary of options to modify
1609    """
1610    self._optimizer_experimental_options.update(options)
1611
1612    self._thread_local_data.function_call_options = None
1613
1614  @property
1615  def intra_op_parallelism_threads(self):
1616    return self.config.intra_op_parallelism_threads
1617
1618  @intra_op_parallelism_threads.setter
1619  def intra_op_parallelism_threads(self, num_threads):
1620    if self._intra_op_parallelism_threads == num_threads:
1621      return
1622
1623    if self._context_handle is not None:
1624      raise RuntimeError(
1625          "Intra op parallelism cannot be modified after initialization.")
1626
1627    self._intra_op_parallelism_threads = num_threads
1628
1629  @property
1630  def inter_op_parallelism_threads(self):
1631    return self.config.inter_op_parallelism_threads
1632
1633  @inter_op_parallelism_threads.setter
1634  def inter_op_parallelism_threads(self, num_threads):
1635    if self._inter_op_parallelism_threads == num_threads:
1636      return
1637
1638    if self._context_handle is not None:
1639      raise RuntimeError(
1640          "Inter op parallelism cannot be modified after initialization.")
1641
1642    self._inter_op_parallelism_threads = num_threads
1643
1644  @property
1645  def soft_device_placement(self):
1646    return self.config.allow_soft_placement
1647
1648  @soft_device_placement.setter
1649  def soft_device_placement(self, enable):
1650    if self._context_handle is not None:
1651      pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
1652
1653    self._soft_device_placement = enable
1654    self._thread_local_data.function_call_options = None
1655
1656  @property
1657  def log_device_placement(self):
1658    return self.config.log_device_placement
1659
1660  @log_device_placement.setter
1661  def log_device_placement(self, enable):
1662    if self._context_handle is not None:
1663      pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
1664
1665    self._log_device_placement = enable
1666    self._thread_local_data.function_call_options = None
1667
1668  @property
1669  def device_policy(self):
1670    # Only get the policy from the context if it has already been initialized
1671    if self._context_handle is not None:
1672      return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
1673
1674    return self._device_policy
1675
1676  @device_policy.setter
1677  def device_policy(self, policy):
1678    if policy is None:
1679      policy = DEVICE_PLACEMENT_SILENT
1680
1681    if self._device_policy != policy:
1682      self._device_policy = policy
1683
1684      # Only set the policy if the context has already been initialized
1685      if self._context_handle is not None:
1686        pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
1687            self._handle, self._device_policy)
1688
1689  @property
1690  def use_tfrt(self):
1691    return self._use_tfrt
1692
1693  @use_tfrt.setter
1694  def use_tfrt(self, tfrt):
1695    """Sets whether to use TFRT."""
1696    if not isinstance(tfrt, bool):
1697      raise ValueError("Expecting a boolean but got %s" % type(tfrt))
1698
1699    if self._use_tfrt != tfrt:
1700      if self._initialized:
1701        raise ValueError("use_tfrt should be set before being initialized.")
1702      self._use_tfrt = tfrt
1703
1704  def enable_run_metadata(self):
1705    """Enables tracing of op execution via RunMetadata.
1706
1707    To retrieve the accumulated metadata call context.export_run_metadata()
1708    and to stop tracing call context.disable_run_metadata().
1709    """
1710    self.ensure_initialized()
1711    pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
1712
1713  def disable_run_metadata(self):
1714    """Disables tracing of op execution via RunMetadata."""
1715    if not self._context_handle:
1716      return
1717    pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
1718
1719  def enable_graph_collection(self):
1720    """Enables graph collection of executed functions.
1721
1722    To retrieve the accumulated graphs call context.export_run_metadata()
1723    and to stop collecting graphs call context.disable_graph_collection().
1724    """
1725    self.ensure_initialized()
1726    pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
1727
1728  def disable_graph_collection(self):
1729    """Disables graph collection of executed functions."""
1730    if not self._context_handle:
1731      return
1732    pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
1733
1734  def export_run_metadata(self):
1735    """Returns a RunMetadata proto with accumulated information.
1736
1737    The returned protocol buffer contains information since the most recent call
1738    to either enable_run_metadata or export_run_metadata.
1739
1740    Returns:
      A RunMetadata protocol buffer, or None if not enabled.
1742    """
1743    if not self._context_handle:
1744      return None
1745    with c_api_util.tf_buffer() as buffer_:
1746      pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
1747      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1748    run_metadata = config_pb2.RunMetadata()
1749    run_metadata.ParseFromString(compat.as_bytes(proto_data))
1750    return run_metadata
1751
1752  @property
1753  def context_switches(self):
1754    """Returns a stack of context switches."""
1755    return self._context_switches
1756
1757
1758class _EagerDeviceContext(object):
1759  """Context-manager forcing placement of ops and Tensors on a device."""
1760
1761  __slots__ = ["_device_name", "_ctx", "_stack"]
1762
1763  def __init__(self, ctx, device_name):
1764    self._device_name = device_name
1765    self._ctx = ctx
1766    self._stack = []
1767
1768  def __enter__(self):
1769    ctx = self._ctx
1770    old_device_name = ctx.device_name
1771    old_device_spec = ctx.device_spec
1772    new_device_name = self._device_name
1773    cache_key = (old_device_name, new_device_name)
1774    try:
1775      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
1776    except TypeError:
1777      # Error while trying to compute the cache key.
1778      raise ValueError("Expecting a string device name. Got %s(%s)" %
1779                       (type(new_device_name), new_device_name))
1780    except KeyError:
1781      # Handle a cache miss.
1782      if new_device_name is not None:
1783        if not isinstance(new_device_name, six.string_types):
1784          raise ValueError("Expecting a string device name. Got %s(%s)" %
1785                           (type(new_device_name), new_device_name))
1786        device_spec = pydev.DeviceSpec.from_string(new_device_name)
1787        if old_device_name:
1788          new_device_spec = copy.copy(old_device_spec)
1789        else:
1790          ctx.ensure_initialized()
1791          new_device_spec = pydev.DeviceSpec.from_string(
1792              ctx._context_devices[0])  # pylint: disable=protected-access
1793        new_device_spec = new_device_spec.make_merged_spec(device_spec)
1794      else:
1795        new_device_spec = pydev.DeviceSpec.from_string("")
1796      new_device_name = new_device_spec.to_string()
1797      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
1798
1799    ctx._set_device(new_device_name, new_device_spec)  # pylint: disable=protected-access
1800    self._stack.append((old_device_name, old_device_spec, new_device_spec))
1801
1802  def __exit__(self, *ex_info):
1803    ctx = self._ctx
1804    old_device_name, old_device_spec, new_device_spec = self._stack[-1]
1805    if ctx.device_spec is not new_device_spec:
1806      raise RuntimeError(
1807          "Exiting device scope without proper scope nesting")
1808    del self._stack[-1]
1809    ctx._set_device(old_device_name, old_device_spec)  # pylint: disable=protected-access
1810
1811
1812# Do not set directly. Use _set_context.
1813_context = None
1814_context_lock = threading.Lock()
1815
1816
1817def _set_context_locked(ctx):
1818  global _context
1819  pywrap_tfe.TFE_Py_SetEagerContext(ctx)
1820  _context = ctx
1821
1822
1823def _set_context(ctx):
1824  with _context_lock:
1825    _set_context_locked(ctx)
1826
1827
1828def _create_context():
1829  with _context_lock:
1830    if _context is None:
1831      ctx = Context()
1832      _set_context_locked(ctx)
1833
1834
1835def _reset_context():
1836  """Clears and re-initializes the singleton context.
1837
1838  Should only be used for testing.
1839  """
1840  global _context
1841  global _device_parsing_cache
1842  with _context_lock:
1843    if _context is not None:
1844      _context._clear_caches()
1845      _context = None
1846  _create_context()
1847  _device_parsing_cache = {}
1848  pywrap_tfe.TFE_ClearScalarCache()
1849
1850
1851def context():
1852  """Returns a singleton context object."""
1853  if _context is None:
1854    _create_context()
1855  return _context
1856
1857
1858def context_safe():
1859  """Returns current context (or None if one hasn't been initialized)."""
1860  return _context
1861
1862
1863def ensure_initialized():
1864  """Initialize the context."""
1865  context().ensure_initialized()
1866
1867
1868def set_global_seed(seed):
1869  """Sets the eager mode seed."""
1870  context()._set_global_seed(seed)  # pylint: disable=protected-access
1871
1872
1873def global_seed():
1874  """Returns the eager mode seed."""
1875  return context()._seed  # pylint: disable=protected-access
1876
1877
1878def internal_operation_seed():
1879  """Returns the operation seed generated based on global seed."""
1880  return context()._internal_operation_seed()  # pylint: disable=protected-access
1881
1882
1883@tf_export("executing_eagerly", v1=[])
1884def executing_eagerly():
1885  """Checks whether the current thread has eager execution enabled.
1886
1887  Eager execution is enabled by default and this API returns `True`
  in most cases. However, this API might return `False` in the following use
1889  cases.
1890
  *  Executing inside `tf.function`, unless under `tf.init_scope` or unless
     `tf.config.run_functions_eagerly(True)` was previously called.
  *  Executing inside a transformation function for `tf.data.Dataset`.
  *  `tf.compat.v1.disable_eager_execution()` is called.
1895
1896  General case:
1897
1898  >>> print(tf.executing_eagerly())
1899  True
1900
1901  Inside `tf.function`:
1902
1903  >>> @tf.function
1904  ... def fn():
1905  ...   with tf.init_scope():
1906  ...     print(tf.executing_eagerly())
1907  ...   print(tf.executing_eagerly())
1908  >>> fn()
1909  True
1910  False
1911
1912  Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
1913
1914  >>> tf.config.run_functions_eagerly(True)
1915  >>> @tf.function
1916  ... def fn():
1917  ...   with tf.init_scope():
1918  ...     print(tf.executing_eagerly())
1919  ...   print(tf.executing_eagerly())
1920  >>> fn()
1921  True
1922  True
1923  >>> tf.config.run_functions_eagerly(False)
1924
  Inside a transformation function for `tf.data.Dataset`:
1926
1927  >>> def data_fn(x):
1928  ...   print(tf.executing_eagerly())
1929  ...   return x
1930  >>> dataset = tf.data.Dataset.range(100)
1931  >>> dataset = dataset.map(data_fn)
1932  False
1933
1934  Returns:
1935    `True` if the current thread has eager execution enabled.
1936  """
1937  ctx = context_safe()
1938  if ctx is None:
1939    return default_execution_mode == EAGER_MODE
1940
1941  return ctx.executing_eagerly()
1942
1943
1944@tf_export(v1=["executing_eagerly"])
1945def executing_eagerly_v1():
1946  """Checks whether the current thread has eager execution enabled.
1947
1948  Eager execution is typically enabled via
1949  `tf.compat.v1.enable_eager_execution`, but may also be enabled within the
1950  context of a Python function via tf.contrib.eager.py_func.
1951
1952  When eager execution is enabled, returns `True` in most cases. However,
1953  this API might return `False` in the following use cases.
1954
  *  Executing inside `tf.function`, unless under `tf.init_scope` or unless
     `tf.config.run_functions_eagerly(True)` was previously called.
  *  Executing inside a transformation function for `tf.data.Dataset`.
  *  `tf.compat.v1.disable_eager_execution()` is called.
1959
1960  >>> tf.compat.v1.enable_eager_execution()
1961
1962  General case:
1963
1964  >>> print(tf.executing_eagerly())
1965  True
1966
1967  Inside `tf.function`:
1968
1969  >>> @tf.function
1970  ... def fn():
1971  ...   with tf.init_scope():
1972  ...     print(tf.executing_eagerly())
1973  ...   print(tf.executing_eagerly())
1974  >>> fn()
1975  True
1976  False
1977
  Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
1980
1981  >>> tf.config.run_functions_eagerly(True)
1982  >>> @tf.function
1983  ... def fn():
1984  ...   with tf.init_scope():
1985  ...     print(tf.executing_eagerly())
1986  ...   print(tf.executing_eagerly())
1987  >>> fn()
1988  True
1989  True
1990  >>> tf.config.run_functions_eagerly(False)
1991
  Inside a transformation function for `tf.data.Dataset`:
1993
1994  >>> def data_fn(x):
1995  ...   print(tf.executing_eagerly())
1996  ...   return x
1997  >>> dataset = tf.data.Dataset.range(100)
1998  >>> dataset = dataset.map(data_fn)
1999  False
2000
2001  Returns:
2002    `True` if the current thread has eager execution enabled.
2003  """
2004  return executing_eagerly()
2005
2006
2007def in_eager_mode():
2008  """Use executing_eagerly() instead. This function will be removed."""
2009  return executing_eagerly()
2010
2011
2012def shared_name(name=None):
2013  """Returns the anonymous shared name GUID if no shared name is specified.
2014
2015  In eager mode we need to use a unique shared name to avoid spurious sharing
2016  issues. The runtime generates a unique name on our behalf when the reserved
2017  GUID is used as a shared name.
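
  For example (illustrative only):

  ```
  shared_name("my_table")  # -> "my_table"
  shared_name()  # -> the reserved GUID when executing eagerly, else None
  ```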
2018
2019  Args:
2020    name: Optional shared name
2021
2022  Returns:
2023    Eager compatible shared name.
2024  """
2025  if name or not executing_eagerly():
2026    return name
2027
2028  # Ensure a unique name when eager execution is enabled to avoid spurious
2029  # sharing issues.
2030  return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
2031
2032
2033def graph_mode():
2034  """Context-manager to disable eager execution for the current thread."""
2035  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access
2036
2037
2038# Used by b/167638505 for keras backend API and Lambda layer.
2039@tf_export("__internal__.eager_context.eager_mode", v1=[])
2040def eager_mode():
2041  """Context-manager to enable eager execution for the current thread."""
2042  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
2043
2044
2045def scope_name():
2046  """Name of the current scope."""
2047  return context().scope_name
2048
2049
2050def device(name):
2051  """Context-manager to force placement of operations and Tensors on a device.
2052
2053  Example:
2054  ```python
2055  with tf.device('gpu:0'):
2056    with tf.device('cpu:0'):
2057      shape = tf.constant([], dtype=tf.int32)
    x = tf.random.truncated_normal(shape, dtype=tf.float32)
2059  ```
2060  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
2061  operation runs on GPU 0.
2062
2063  Args:
2064    name: Name of the device (see context().devices()), or None to
2065      perform automatic placement.
2066
2067  Returns:
2068    Context manager for setting the device.
2069  """
2070  ensure_initialized()
2071  return context().device(name)
2072
2073
2074# Expose some properties of Context as internally public APIs (b/160348781).
2075@tf_export("__internal__.eager_context.get_config", v1=[])
2076def get_config():
2077  """Get the ConfigProto of Context.
2078
2079  Returns:
2080    The ConfigProto of Context.
2081  """
2082  return context().config
2083
2084
2085@tf_export("__internal__.eager_context.get_device_name", v1=[])
2086def get_device_name():
2087  """Get the device name for the current thread.
2088
2089  Returns:
2090    The device name for the current thread.
2091  """
2092  return context().device_name
2093
2094
2095@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
2096def set_soft_device_placement(enabled):
2097  """Set if soft device placements should be allowed.
2098
2099  Args:
2100    enabled: Whether to enable soft device placement.
2101  """
2102  context().soft_device_placement = enabled
2103
2104
2105@tf_export("__internal__.eager_context.get_executor", v1=[])
2106def get_executor():
2107  """Get the Executor of the current thread.
2108
2109  Returns:
2110    The Executor of the current thread.
2111  """
2112  return context().executor
2113
2114
2115@tf_export("debugging.get_log_device_placement")
2116def get_log_device_placement():
2117  """Get if device placements are logged.
2118
2119  Returns:
    Whether device placements are logged.
2121  """
2122  return context().log_device_placement
2123
2124
2125@tf_export("debugging.set_log_device_placement")
2126def set_log_device_placement(enabled):
2127  """Set if device placements should be logged.
2128
2129  Args:
    enabled: Whether to enable device placement logging.
2131  """
2132  context().log_device_placement = enabled
2133
2134
2135@tf_contextlib.contextmanager
2136def device_policy(policy):
2137  """Context manager for setting device placement policy for current thread."""
2138  ctx = context()
2139  old_policy = ctx.device_policy
2140  try:
2141    ctx.device_policy = policy
2142    yield
2143  finally:
2144    ctx.device_policy = old_policy
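
# Example for `device_policy` above (an illustrative sketch;
# `DEVICE_PLACEMENT_SILENT` is one of the module-level placement policy
# constants):
#
#   with device_policy(DEVICE_PLACEMENT_SILENT):
#     ...  # within this block, tensors may be silently copied between devices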
2145
2146
2147def set_execution_mode(mode):
2148  """Sets execution mode for the current thread."""
2149  context().execution_mode = mode
2150
2151
2152# TODO(fishx): remove this method.
2153@tf_contextlib.contextmanager
2154def execution_mode(mode):
2155  """Context manager for setting execution mode for current thread."""
2156  if mode is None:
2157    yield
2158  else:
2159    ctx = context()
2160    executor_new = executor.new_executor(mode == ASYNC)
2161    executor_old = ctx.executor
2162    try:
2163      executor_old.wait()
2164      ctx.executor = executor_new
2165      yield
2166    finally:
2167      ctx.executor = executor_old
2168      executor_new.wait()
2169
2170
2171@tf_contextlib.contextmanager
2172def executor_scope(e):
2173  """Context manager for changing executor for current thread.
2174
2175  Args:
    e: An Executor to execute eager ops under this scope. Setting it to None
      will switch back to using the default executor for the context.
2178
2179  Yields:
2180    Context manager for setting the executor for current thread.
2181  """
2182  ctx = context()
2183  executor_old = ctx.executor
2184  try:
2185    ctx.executor = e
2186    yield
2187  finally:
2188    ctx.executor = executor_old
2189
2190
2191@tf_export("experimental.function_executor_type")
2192@tf_contextlib.contextmanager
2193def function_executor_type(executor_type):
2194  """Context manager for setting the executor of eager defined functions.
2195
2196  Eager defined functions are functions decorated by tf.contrib.eager.defun.
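
  For example (a sketch; "SINGLE_THREADED_EXECUTOR" is assumed here to name an
  executor type registered with the runtime):

  ```
  @tf.function
  def f(x):
    return x + 1.

  with tf.experimental.function_executor_type("SINGLE_THREADED_EXECUTOR"):
    f(tf.constant(1.))
  ```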
2197
2198  Args:
2199    executor_type: a string for the name of the executor to be used to execute
2200      functions defined by tf.contrib.eager.defun.
2201
2202  Yields:
2203    Context manager for setting the executor of eager defined functions.
2204  """
2205  current_options = context().function_call_options
2206  old_options = copy.copy(current_options)
2207  try:
2208    current_options.executor_type = executor_type
2209    yield
2210  finally:
2211    context().function_call_options = old_options
2212
2213
2214def is_async():
2215  """Returns true if current thread is in async mode."""
2216  return context().is_async()
2217
2218
2219def num_gpus():
2220  """Get the number of available GPU devices.
2221
2222  Returns:
2223    The number of available GPU devices.
2224  """
2225  return context().num_gpus()
2226
2227
2228def enable_run_metadata():
2229  """Enables tracing of op execution via RunMetadata.
2230
2231  To retrieve the accumulated metadata call context.export_run_metadata()
2232  and to stop tracing call context.disable_run_metadata().
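
  For example (a sketch using this module's own helpers):

  ```
  enable_run_metadata()
  ...  # execute some eager ops here
  metadata = export_run_metadata()  # a RunMetadata proto
  disable_run_metadata()
  ```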
2233  """
2234  context().enable_run_metadata()
2235
2236
2237def disable_run_metadata():
2238  """Disables tracing of op execution via RunMetadata."""
2239  context().disable_run_metadata()
2240
2241
2242def enable_graph_collection():
2243  """Enables graph collection of executed functions.
2244
2245  To retrieve the accumulated graphs call context.export_run_metadata()
2246  and to stop collecting graphs call context.disable_graph_collection().
2247  """
2248  context().enable_graph_collection()
2249
2250
2251def disable_graph_collection():
2252  """Disables graph collection of executed functions."""
2253  context().disable_graph_collection()
2254
2255
2256def export_run_metadata():
2257  """Returns a RunMetadata proto with accumulated information.
2258
2259  The returned protocol buffer contains information since the most recent call
2260  to either enable_run_metadata or export_run_metadata.
2261
2262  Returns:
2263    A RunMetadata protocol buffer.
2264  """
2265  return context().export_run_metadata()
2266
2267
2268@contextlib.contextmanager
2269def collect_graphs(optimized=True):
2270  """Collects a flat list of pre- or post-optimization graphs.
2271
2272  The collected graphs include device placements, which can be useful for
2273  testing.
2274
2275  Usage:
2276
2277  ```
2278  @def_function.function
2279  def f(x):
2280    return x + constant_op.constant(1.)
2281
2282  with context.collect_graphs() as graphs:
2283    with ops.device("CPU:0"):
2284      f(constant_op.constant(1.))
2285
2286  graph, = graphs  # `graph` contains a single GraphDef for inspection
2287  ```
2288
2289  Args:
    optimized: whether to collect optimized graphs or non-optimized graphs.

  Yields:
2292    A list of GraphDefs, populated when the context manager exits.
2293  """
2294  ctx = context()
2295  ctx.enable_graph_collection()
2296  try:
2297    graphs = []
2298    yield graphs
2299    metadata = ctx.export_run_metadata()
2300  finally:
2301    ctx.disable_graph_collection()
2302  for graph in metadata.function_graphs:
2303    if optimized:
2304      graphs.append(graph.post_optimization_graph)
2305    else:
2306      graphs.append(graph.pre_optimization_graph)
2307
2308
2309def get_server_def():
2310  return context().get_server_def()
2311
2312
2313def set_server_def(server_def):
2314  context().set_server_def(server_def)
2315
2316
2317def update_server_def(server_def):
2318  context().update_server_def(server_def)
2319
2320
2321def check_alive(worker_name):
2322  return context().check_alive(worker_name)
2323
2324
2325@tf_export("experimental.async_scope")
2326@tf_contextlib.contextmanager
2327def async_scope():
2328  """Context manager for grouping async operations.
2329
2330  Ops/function calls inside the scope can return before finishing the actual
2331  execution. When exiting the async scope, a synchronization barrier will be
2332  automatically added to ensure the completion of all async op and function
2333  execution, potentially raising exceptions if async execution results in
2334  an error state.
2335
2336  Users may write the following code to asynchronously invoke `train_step_fn`
2337  and log the `loss` metric for every `num_steps` steps in a training loop.
  `train_step_fn` internally consumes data using `iterator.get_next()`, and may
  raise an OutOfRangeError when running out of data. In that case:
2340
2341  ```
2342  try:
2343    with tf.experimental.async_scope():
2344      for _ in range(num_steps):
2345        # Step function updates the metric `loss` internally
2346        train_step_fn()
2347  except tf.errors.OutOfRangeError:
2348    tf.experimental.async_clear_error()
2349  logging.info('loss = %s', loss.numpy())
2350  ```
2351
2352  Yields:
2353    Context manager for grouping async operations.
2354  """
2355  # TODO(haoyuzhang): replace env var once we have a config method to turn on
2356  # and off async streaming RPC
2357  remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
2358  old_policy = os.environ.get(remote_async_env_var)
2359  try:
2360    os.environ[remote_async_env_var] = str(True)
2361    yield
2362    # Note: sync local and remote executors iff the async block does not raise
2363    # an exception. Triggering sync after an exception may lead to derived
2364    # runtime errors and unexpected exception types.
2365    context().sync_executors()
2366  finally:
2367    if old_policy is None:
2368      del os.environ[remote_async_env_var]
2369    else:
2370      os.environ[remote_async_env_var] = old_policy
2371
2372
2373def async_wait():
2374  """Sync all async operations and raise any errors during execution.
2375
2376  In async execution mode, an op/function call can return before finishing the
2377  actual execution. Calling this method creates a synchronization barrier for
2378  all async op and function execution. It only returns when all pending nodes
2379  are finished, potentially raising exceptions if async execution results in
2380  an error state.
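
  For example (a sketch; assumes async execution was enabled beforehand, e.g.
  via `tf.config.experimental.set_synchronous_execution(False)`):

  ```
  for batch in batches:
    train_step_fn(batch)  # may return before the step actually finishes
  async_wait()  # blocks until all pending work completes, or raises
  ```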
2381  """
2382  context().sync_executors()
2383
2384
2385@tf_export("experimental.async_clear_error")
2386def async_clear_error():
2387  """Clear pending operations and error statuses in async execution.
2388
2389  In async execution mode, an error in op/function execution can lead to errors
  in subsequent ops/functions that are scheduled but not yet executed. Calling
  this method clears all pending operations and resets the async execution
  state.
2392
2393  Example:
2394
2395  ```
2396  while True:
2397    try:
2398      # Step function updates the metric `loss` internally
2399      train_step_fn()
2400    except tf.errors.OutOfRangeError:
2401      tf.experimental.async_clear_error()
2402      break
2403  logging.info('loss = %s', loss.numpy())
2404  ```
2405  """
2406  context().clear_executor_errors()
2407
2408
2409def add_function(fdef):
2410  """Add a function definition to the context."""
2411  context().add_function(fdef)
2412
2413
2414def remove_function(name):
2415  """Remove a function from the context."""
2416  context().remove_function(name)
2417
2418
2419def get_function_def(name):
2420  return context().get_function_def(name)
2421
2422
2423def register_custom_device(device_capsule, device_name, device_info_capsule):
2424  """Calls TFE_RegisterCustomDevice to register a custom device with Python.
2425
2426  Enables using C extensions specifying a custom device from Python. See the
2427  experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2428  details.
2429
2430  Note that custom devices are not currently supported inside `tf.function`s.
2431
2432  Args:
2433    device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
2434      containing a pointer to a TFE_CustomDevice struct. The capsule retains
2435      ownership of the memory.
2436    device_name: A string indicating the name to register the custom device
2437      under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
2438      subsequently be passed to `with tf.device(...):`.
2439    device_info_capsule: A PyCapsule with the name set to
2440      'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
2441      struct with the initial state of the custom device (the void* device_info
2442      argument to TFE_RegisterCustomDevice). This method takes ownership of the
2443      memory and clears the capsule destructor.
2444  """
2445  context().register_custom_device(device_capsule, device_name,
2446                                   device_info_capsule)
2447
2448
2449# Not every user creates a Context via context.context()
2450# (for example, enable_eager_execution in python/framework/ops.py),
2451# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
2452# in_graph_mode are both parameterless functions.
2453def _tmp_in_graph_mode():
2454  if context_safe() is None:
2455    # Context not yet initialized. Assume graph mode following the
2456    # default implementation in `is_in_graph_mode`.
2457    return True
2458  return not executing_eagerly()
2459
2460
2461is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode
2462