1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions used to construct graphs."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import linecache
24import os
25import re
26import sys
27import threading
28
29import numpy as np
30import six
31from six.moves import xrange  # pylint: disable=redefined-builtin
32
33from tensorflow.core.framework import attr_value_pb2
34from tensorflow.core.framework import function_pb2
35from tensorflow.core.framework import graph_pb2
36from tensorflow.core.framework import node_def_pb2
37from tensorflow.core.framework import op_def_pb2
38from tensorflow.core.framework import versions_pb2
39from tensorflow.core.protobuf import config_pb2
40from tensorflow.python import pywrap_tensorflow as c_api
41from tensorflow.python.eager import context
42from tensorflow.python.eager import core
43from tensorflow.python.eager import tape
44from tensorflow.python.framework import c_api_util
45from tensorflow.python.framework import device as pydev
46from tensorflow.python.framework import dtypes
47from tensorflow.python.framework import errors
48from tensorflow.python.framework import op_def_registry
49from tensorflow.python.framework import registry
50from tensorflow.python.framework import tensor_shape
51from tensorflow.python.framework import versions
52from tensorflow.python.ops import control_flow_util
53from tensorflow.python.platform import app
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.util import compat
56from tensorflow.python.util import decorator_utils
57from tensorflow.python.util import tf_contextlib
58from tensorflow.python.util.tf_export import tf_export
59
60
61# Temporary global switch determining if we should enable the work-in-progress
62# calls to the C API. Currently disabled by default but can be manually enabled
63# in code or via the environment variable. This will be removed once all
64# functionality is supported and there's no performance penalty with it enabled.
65_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "0") is not "0"
66
67
68def tensor_id(tensor):
69  """Returns a unique identifier for this Tensor."""
70  return tensor._id  # pylint: disable=protected-access
71
72
73class _NullContextmanager(object):
74
75  def __enter__(self):
76    pass
77
78  def __exit__(self, type_arg, value_arg, traceback_arg):
79    return False  # False values do not suppress exceptions
80
81
82def _override_helper(clazz_object, operator, func):
83  """Overrides (string) operator on Tensors to call func.
84
85  Args:
86    clazz_object: the class to override for; either Tensor or SparseTensor.
87    operator: the string name of the operator to override.
88    func: the function that replaces the overridden operator.
89
90  Raises:
91    ValueError: If operator has already been overwritten,
92      or if operator is not allowed to be overwritten.
93  """
94  existing = getattr(clazz_object, operator, None)
95  if existing is not None:
96    # Check to see if this is a default method-wrapper or slot wrapper which
97    # will be true for the comparison operators.
98    if not isinstance(existing, type(object.__lt__)):
99      raise ValueError("operator %s cannot be overwritten again on class %s." %
100                       (operator, clazz_object))
101  if operator not in Tensor.OVERLOADABLE_OPERATORS:
102    raise ValueError("Overriding %s is disallowed" % operator)
103  setattr(clazz_object, operator, func)
104
105
106def _as_graph_element(obj):
107  """Convert `obj` to a graph element if possible, otherwise return `None`.
108
109  Args:
110    obj: Object to convert.
111
112  Returns:
113    The result of `obj._as_graph_element()` if that method is available;
114        otherwise `None`.
115  """
116  conv_fn = getattr(obj, "_as_graph_element", None)
117  if conv_fn and callable(conv_fn):
118    return conv_fn()
119  return None
120
121
122_TENSOR_LIKE_TYPES = tuple()
123
124
125def is_dense_tensor_like(t):
126  """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
127
128  See `register_dense_tensor_like_type()` for the current definition of a
129  "tensor-like type".
130
131  Args:
132    t: An object.
133
134  Returns:
135    True iff `t` is an instance of one of the registered "tensor-like" types.
136  """
137  return isinstance(t, _TENSOR_LIKE_TYPES)
138
139
140def register_dense_tensor_like_type(tensor_type):
141  """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
142
143  A "tensor-like type" can represent a single dense tensor, and implements
144  the `name` and `dtype` properties.
145
146  Args:
147    tensor_type: A type implementing the tensor interface.
148
149  Raises:
150    TypeError: If `tensor_type` does not implement the tensor interface.
151  """
152  try:
153    if not isinstance(tensor_type.name, property):
154      raise TypeError("Type %s does not define a `name` property" %
155                      tensor_type.__name__)
156  except AttributeError:
157    raise TypeError("Type %s does not define a `name` property" %
158                    tensor_type.__name__)
159  try:
160    if not isinstance(tensor_type.dtype, property):
161      raise TypeError("Type %s does not define a `dtype` property" %
162                      tensor_type.__name__)
163  except AttributeError:
164    raise TypeError("Type %s does not define a `dtype` property" %
165                    tensor_type.__name__)
166  # We expect this list to be small, so choose quadratic complexity
167  # for registration, so that we have a tuple that can be used for
168  # more efficient `isinstance` checks later.
169  global _TENSOR_LIKE_TYPES
170  _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
171
172
173def uid():
174  """A unique (within this program execution) integer."""
175  return c_api.TFE_Py_UID()
176
177
178def numpy_text(tensor, is_repr=False):
179  """Human readable representation of a tensor's numpy value."""
180  if tensor.dtype.is_numpy_compatible:
181    text = repr(tensor.numpy()) if is_repr else str(tensor.numpy())
182  else:
183    text = "<unprintable>"
184  if "\n" in text:
185    text = "\n" + text
186  return text
187
188
189# NOTE(ebrevdo): Do not subclass this.  If you do, I will break you on purpose.
190class _TensorLike(object):
191  """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
192  pass
193
194
195@tf_export("Tensor")
196class Tensor(_TensorLike):
197  """Represents one of the outputs of an `Operation`.
198
199  A `Tensor` is a symbolic handle to one of the outputs of an
200  `Operation`. It does not hold the values of that operation's output,
201  but instead provides a means of computing those values in a
202  TensorFlow @{tf.Session}.
203
204  This class has two primary purposes:
205
206  1. A `Tensor` can be passed as an input to another `Operation`.
207     This builds a dataflow connection between operations, which
208     enables TensorFlow to execute an entire `Graph` that represents a
209     large, multi-step computation.
210
211  2. After the graph has been launched in a session, the value of the
212     `Tensor` can be computed by passing it to
213     @{tf.Session.run}.
214     `t.eval()` is a shortcut for calling
215     `tf.get_default_session().run(t)`.
216
217  In the following example, `c`, `d`, and `e` are symbolic `Tensor`
218  objects, whereas `result` is a numpy array that stores a concrete
219  value:
220
221  ```python
222  # Build a dataflow graph.
223  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
224  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
225  e = tf.matmul(c, d)
226
227  # Construct a `Session` to execute the graph.
228  sess = tf.Session()
229
230  # Execute the graph and store the value that `e` represents in `result`.
231  result = sess.run(e)
232  ```
233  """
234
235  # List of Python operators that we allow to override.
236  OVERLOADABLE_OPERATORS = {
237      # Binary.
238      "__add__",
239      "__radd__",
240      "__sub__",
241      "__rsub__",
242      "__mul__",
243      "__rmul__",
244      "__div__",
245      "__rdiv__",
246      "__truediv__",
247      "__rtruediv__",
248      "__floordiv__",
249      "__rfloordiv__",
250      "__mod__",
251      "__rmod__",
252      "__lt__",
253      "__le__",
254      "__gt__",
255      "__ge__",
256      "__and__",
257      "__rand__",
258      "__or__",
259      "__ror__",
260      "__xor__",
261      "__rxor__",
262      "__getitem__",
263      "__pow__",
264      "__rpow__",
265      # Unary.
266      "__invert__",
267      "__neg__",
268      "__abs__",
269      "__matmul__",
270      "__rmatmul__"
271  }
272
273  def __init__(self, op, value_index, dtype):
274    """Creates a new `Tensor`.
275
276    Args:
277      op: An `Operation`. `Operation` that computes this tensor.
278      value_index: An `int`. Index of the operation's endpoint that produces
279        this tensor.
280      dtype: A `DType`. Type of elements stored in this tensor.
281
282    Raises:
283      TypeError: If the op is not an `Operation`.
284    """
285    if not isinstance(op, Operation):
286      raise TypeError("op needs to be an Operation: %s" % op)
287    self._op = op
288    self._value_index = value_index
289    self._dtype = dtypes.as_dtype(dtype)
290    self._shape_val = tensor_shape.unknown_shape()
291    # List of operations that use this Tensor as input.  We maintain this list
292    # to easily navigate a computation graph.
293    self._consumers = []
294
295    # Attributes used for C++ shape inference. Not inspected, only forwarded.
296    # If set, will be a HandleData object from cpp_shape_inference.proto.
297    self._handle_data = None
298    self._id = uid()
299
300  @property
301  def op(self):
302    """The `Operation` that produces this tensor as an output."""
303    return self._op
304
305  @property
306  def dtype(self):
307    """The `DType` of elements in this tensor."""
308    return self._dtype
309
310  @property
311  def graph(self):
312    """The `Graph` that contains this tensor."""
313    return self._op.graph
314
315  @property
316  def name(self):
317    """The string name of this tensor."""
318    if not self._op.name:
319      raise ValueError("Operation was not named: %s" % self._op)
320    return "%s:%d" % (self._op.name, self._value_index)
321
322  @property
323  def device(self):
324    """The name of the device on which this tensor will be produced, or None."""
325    return self._op.device
326
327  @property
328  def shape(self):
329    """Returns the `TensorShape` that represents the shape of this tensor.
330
331    The shape is computed using shape inference functions that are
332    registered in the Op for each `Operation`.  See
333    @{tf.TensorShape}
334    for more details of what a shape represents.
335
336    The inferred shape of a tensor is used to provide shape
337    information without having to launch the graph in a session. This
338    can be used for debugging, and providing early error messages. For
339    example:
340
341    ```python
342    c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
343
344    print(c.shape)
345    ==> TensorShape([Dimension(2), Dimension(3)])
346
347    d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
348
349    print(d.shape)
350    ==> TensorShape([Dimension(4), Dimension(2)])
351
352    # Raises a ValueError, because `c` and `d` do not have compatible
353    # inner dimensions.
354    e = tf.matmul(c, d)
355
356    f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
357
358    print(f.shape)
359    ==> TensorShape([Dimension(3), Dimension(4)])
360    ```
361
362    In some cases, the inferred shape may have unknown dimensions. If
363    the caller has additional information about the values of these
364    dimensions, `Tensor.set_shape()` can be used to augment the
365    inferred shape.
366
367    Returns:
368      A `TensorShape` representing the shape of this tensor.
369
370    """
371    if _USE_C_API:
372      graph = self._op._graph._c_graph  # pylint: disable=protected-access
373      with errors.raise_exception_on_not_ok_status() as status:
374        num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
375                                                  status)
376      if num_dims == -1:
377        dim_list = None
378      else:
379        with errors.raise_exception_on_not_ok_status() as status:
380          dim_list = c_api.TF_GraphGetTensorShape_wrapper(
381              graph, self._as_tf_output(), num_dims, status)
382        dim_list = [None if i == -1 else i for i in dim_list]
383      return tensor_shape.TensorShape(dim_list)
384    return self._shape_val
385
386  @property
387  def _shape(self):
388    logging.warning("Tensor._shape is private, use Tensor.shape "
389                    "instead. Tensor._shape will eventually be removed.")
390    return self.shape
391
392  @_shape.setter
393  def _shape(self, value):
394    raise ValueError(
395        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
396
397  def __iter__(self):
398    if context.in_graph_mode():
399      raise TypeError(
400          "`Tensor` objects are not iterable when eager execution is not "
401          "enabled. To iterate over this tensor use `tf.map_fn`.")
402    shape = self._shape_tuple()
403    if shape is None:
404      raise TypeError("Cannot iterate over a tensor with unknown shape.")
405    if not shape:
406      raise TypeError("Cannot iterate over a scalar tensor.")
407    if shape[0] is None:
408      raise TypeError(
409          "Cannot iterate over a tensor with unknown first dimension.")
410    for i in xrange(shape[0]):
411      yield self[i]
412
413  def _shape_as_list(self):
414    if self.shape.ndims is not None:
415      return [dim.value for dim in self.shape.dims]
416    else:
417      return None
418
419  def _shape_tuple(self):
420    shape = self._shape_as_list()
421    if shape is None:
422      return None
423    return tuple(shape)
424
425  def _rank(self):
426    """Integer rank of this Tensor, if known, else None.
427
428    Returns:
429      Integer rank or None
430    """
431    return self.shape.ndims
432
433  def get_shape(self):
434    """Alias of Tensor.shape."""
435    return self.shape
436
437  def set_shape(self, shape):
438    """Updates the shape of this tensor.
439
440    This method can be called multiple times, and will merge the given
441    `shape` with the current shape of this tensor. It can be used to
442    provide additional information about the shape of this tensor that
443    cannot be inferred from the graph alone. For example, this can be used
444    to provide additional information about the shapes of images:
445
446    ```python
447    _, image_data = tf.TFRecordReader(...).read(...)
448    image = tf.image.decode_png(image_data, channels=3)
449
450    # The height and width dimensions of `image` are data dependent, and
451    # cannot be computed without executing the op.
452    print(image.shape)
453    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
454
455    # We know that each image in this dataset is 28 x 28 pixels.
456    image.set_shape([28, 28, 3])
457    print(image.shape)
458    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
459    ```
460
461    Args:
462      shape: A `TensorShape` representing the shape of this tensor, a
463      `TensorShapeProto`, a list, a tuple, or None.
464
465    Raises:
466      ValueError: If `shape` is not compatible with the current shape of
467        this tensor.
468    """
469    if not _USE_C_API:
470      self._shape_val = self._shape_val.merge_with(shape)
471      return
472    if not isinstance(shape, tensor_shape.TensorShape):
473      shape = tensor_shape.TensorShape(shape)
474    dim_list = []
475    if shape.dims is None:
476      unknown_shape = True
477    else:
478      unknown_shape = False
479      for dim in shape.dims:
480        if dim.value is None:
481          dim_list.append(-1)
482        else:
483          dim_list.append(dim.value)
484    try:
485      with errors.raise_exception_on_not_ok_status() as status:
486        c_api.TF_GraphSetTensorShape_wrapper(
487            self._op._graph._c_graph,  # pylint: disable=protected-access
488            self._as_tf_output(),
489            dim_list,
490            unknown_shape,
491            status)
492    except errors.InvalidArgumentError as e:
493      # Convert to ValueError for backwards compatibility.
494      raise ValueError(str(e))
495
496  @property
497  def value_index(self):
498    """The index of this tensor in the outputs of its `Operation`."""
499    return self._value_index
500
501  def consumers(self):
502    """Returns a list of `Operation`s that consume this tensor.
503
504    Returns:
505      A list of `Operation`s.
506    """
507    if self._op._c_op:  # pylint: disable=protected-access
508      consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
509          self._as_tf_output())
510      # pylint: disable=protected-access
511      return [
512          self.graph._get_operation_by_name_unsafe(name)
513          for name in consumer_names
514      ]
515      # pylint: enable=protected-access
516    else:
517      return self._consumers
518
519  def _add_consumer(self, consumer):
520    """Add a consumer to this tensor.
521
522    Args:
523      consumer: an Operation.
524
525    Raises:
526      TypeError: if the consumer is not an Operation.
527    """
528    # pylint: disable=protected-access
529    assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API"
530    # pylint: enable=protected-access
531    if not isinstance(consumer, Operation):
532      raise TypeError("Consumer must be an Operation: %s" % consumer)
533    self._consumers.append(consumer)
534
535  def _as_node_def_input(self):
536    """Return a value to use for the NodeDef "input" attribute.
537
538    The returned string can be used in a NodeDef "input" attribute
539    to indicate that the NodeDef uses this Tensor as input.
540
541    Raises:
542      ValueError: if this Tensor's Operation does not have a name.
543
544    Returns:
545      a string.
546    """
547    if not self._op.name:
548      raise ValueError("Operation was not named: %s" % self._op)
549    if self._value_index == 0:
550      return self._op.name
551    else:
552      return "%s:%d" % (self._op.name, self._value_index)
553
554  def _as_tf_output(self):
555    # pylint: disable=protected-access
556    assert self.op._c_op
557    return c_api_util.tf_output(self.op._c_op, self.value_index)
558    # pylint: enable=protected-access
559
560  def __str__(self):
561    return "Tensor(\"%s\"%s%s%s)" % (
562        self.name, (", shape=%s" % self.get_shape())
563        if self.get_shape().ndims is not None else "",
564        (", dtype=%s" % self._dtype.name)
565        if self._dtype else "", (", device=%s" % self.device)
566        if self.device else "")
567
568  def __repr__(self):
569    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
570                                                   self._dtype.name)
571
572  def __hash__(self):
573    # Necessary to support Python's collection membership operators
574    return id(self)
575
576  def __eq__(self, other):
577    # Necessary to support Python's collection membership operators
578    return id(self) == id(other)
579
580  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
581  # operators to run when the left operand is an ndarray, because it
582  # accords the Tensor class higher priority than an ndarray, or a
583  # numpy matrix.
584  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
585  # mechanism, which allows more control over how Tensors interact
586  # with ndarrays.
587  __array_priority__ = 100
588
589  @staticmethod
590  def _override_operator(operator, func):
591    _override_helper(Tensor, operator, func)
592
593  def __bool__(self):
594    """Dummy method to prevent a tensor from being used as a Python `bool`.
595
596    This overload raises a `TypeError` when the user inadvertently
597    treats a `Tensor` as a boolean (e.g. in an `if` statement). For
598    example:
599
600    ```python
601    if tf.constant(True):  # Will raise.
602      # ...
603
604    if tf.constant(5) < tf.constant(7):  # Will raise.
605      # ...
606    ```
607
608    This disallows ambiguities between testing the Python value vs testing the
609    dynamic condition of the `Tensor`.
610
611    Raises:
612      `TypeError`.
613    """
614    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
615                    "Use `if t is not None:` instead of `if t:` to test if a "
616                    "tensor is defined, and use TensorFlow ops such as "
617                    "tf.cond to execute subgraphs conditioned on the value of "
618                    "a tensor.")
619
620  def __nonzero__(self):
621    """Dummy method to prevent a tensor from being used as a Python `bool`.
622
623    This is the Python 2.x counterpart to `__bool__()` above.
624
625    Raises:
626      `TypeError`.
627    """
628    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
629                    "Use `if t is not None:` instead of `if t:` to test if a "
630                    "tensor is defined, and use TensorFlow ops such as "
631                    "tf.cond to execute subgraphs conditioned on the value of "
632                    "a tensor.")
633
634  def eval(self, feed_dict=None, session=None):
635    """Evaluates this tensor in a `Session`.
636
637    Calling this method will execute all preceding operations that
638    produce the inputs needed for the operation that produces this
639    tensor.
640
641    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
642    launched in a session, and either a default session must be
643    available, or `session` must be specified explicitly.
644
645    Args:
646      feed_dict: A dictionary that maps `Tensor` objects to feed values.
647        See @{tf.Session.run} for a
648        description of the valid feed values.
649      session: (Optional.) The `Session` to be used to evaluate this tensor. If
650        none, the default session will be used.
651
652    Returns:
653      A numpy array corresponding to the value of this tensor.
654
655    """
656    return _eval_using_default_session(self, feed_dict, self.graph, session)
657
658
659# TODO(agarwal): consider getting rid of this.
660class _EagerTensorBase(Tensor):
661  """Base class for EagerTensor."""
662
663  @property
664  def dtype(self):
665    # Note: using the intern table directly here as this is
666    # performance-sensitive in some models.
667    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
668
669  def numpy(self):
670    """Returns a numpy array or a scalar with the same contents as the Tensor.
671
672    TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
673    buffer but instead always explicitly copy? Note that currently it may or may
674    not copy based on whether the numpy data is properly aligned or not.
675
676    Returns:
677      A numpy array or a scalar. Numpy array may share memory with the
678      Tensor object. Any changes to one may be reflected in the other. A scalar
679      value is returned when self has rank 0.
680
681    Raises:
682      ValueError: if the type of this Tensor is not representable in numpy.
683    """
684    if self.dtype == dtypes.resource:
685      raise ValueError("Resource handles are not convertible to numpy.")
686    return self.cpu()._numpy()  # pylint: disable=protected-access
687
688  # __int__ and  __float__ may copy the tensor to CPU and
689  # only work for scalars; values are cast as per numpy.
690  def __int__(self):
691    return int(self.numpy())
692
693  def __float__(self):
694    return float(self.numpy())
695
696  def __array__(self, dtype=None):
697    return np.array(self.numpy(), dtype=dtype)
698
699  def __format__(self, format_spec):
700    return self.numpy().__format__(format_spec)
701
702  def _numpy(self):
703    raise NotImplementedError()
704
705  def __copy__(self):
706    # Eager Tensors are immutable so it's safe to return themselves as a copy.
707    return self
708
709  def __deepcopy__(self, memo):
710    # Eager Tensors are immutable so it's safe to return themselves as a copy.
711    del memo
712    return self
713
714  def _datatype_enum(self):
715    raise NotImplementedError()
716
717  def _shape_tuple(self):
718    """The shape of this Tensor, as a tuple.
719
720    This is more performant than tuple(shape().as_list()) as it avoids
721    two list and one object creation. Marked private for now as from an API
722    perspective, it would be better to have a single performant way of
723    getting a shape rather than exposing shape() and shape_tuple()
724    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
725    but ideally one would work things out and remove the need for this method.
726
727    Returns:
728      tuple with the shape.
729    """
730    raise NotImplementedError()
731
732  def _rank(self):
733    """Integer rank of this Tensor.
734
735    Unlike regular Tensors, the rank is always known for EagerTensors.
736
737    This is more performant than len(self._shape_tuple())
738
739    Returns:
740      Integer rank
741    """
742    raise NotImplementedError()
743
744  def _copy_to_device(self, context, device):  # pylint: disable=redefined-outer-name
745    raise NotImplementedError()
746
747  def __str__(self):
748    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self),
749                                                  self.shape,
750                                                  self.dtype.name)
751
752  def __repr__(self):
753    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
754        self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True))
755
756  @staticmethod
757  def _override_operator(name, func):
758    setattr(_EagerTensorBase, name, func)
759
760  def _copy(self, ctx=None, device_name=None):
761    """Copies tensor to dest device."""
762    # pylint: disable=protected-access
763    # Creates a new tensor on the dest device.
764    if ctx is None:
765      ctx = context.context()
766    if device_name is None:
767      device_name = ctx.device_name
768    # pylint: disable=protected-access
769    try:
770      new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
771    except core._NotOkStatusException as e:
772      six.raise_from(core._status_to_exception(e.code, e.message), None)
773
774    # Record the copy on tape and define backprop copy as well.
775    if not context.in_graph_mode():
776      self_device = self.device
777      def grad_fun(dresult):
778        return [dresult._copy(device_name=self_device)]
779      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
780    return new_tensor
781    # pylint: enable=protected-access
782
783  @property
784  def shape(self):
785    return tensor_shape.TensorShape(self._shape_tuple())
786
787  def get_shape(self):
788    """Alias of Tensor.shape."""
789    return self.shape
790
791  def _shape_as_list(self):
792    """The shape of the tensor as a list."""
793    return list(self._shape_tuple())
794
795  @property
796  def ndim(self):
797    """Returns the number of Tensor dimensions."""
798    return self.shape.ndims
799
800  def cpu(self):
801    """A copy of this Tensor with contents backed by host memory."""
802    return self._copy(context.context(), "CPU:0")
803
804  def gpu(self, gpu_index=0):
805    """A copy of this Tensor with contents backed by memory on the GPU.
806
807    Arguments:
808      gpu_index: Identifies which GPU to place the contents on the returned
809        Tensor in.
810
811    Returns:
812      A GPU-memory backed Tensor object initialized with the same contents
813      as this Tensor.
814    """
815    return self._copy(context.context(), "GPU:" + str(gpu_index))
816
817  def __bool__(self):
818    if self._shape_tuple() != ():  # pylint: disable=g-explicit-bool-comparison
819      raise ValueError(
820          "Non-scalar tensor %s cannot be converted to boolean." % repr(self))
821    if self.dtype != dtypes.bool:
822      raise ValueError(
823          "Non-boolean tensor %s cannot be converted to boolean." % repr(self))
824    return bool(self.cpu().numpy())
825
826  def __nonzero__(self):
827    return self.__bool__()
828
829  def set_shape(self, shape):
830    if not self.shape.is_compatible_with(shape):
831      raise ValueError(
832          "EagerTensor's shape %s is not compatible with supplied shape %s" %
833          (self.shape, shape))
834
835  # Methods not supported / implemented for Eager Tensors.
836  @property
837  def op(self):
838    raise AttributeError("op not supported for Eager Tensors.")
839
840  @property
841  def graph(self):
842    raise AttributeError("graph not supported for Eager Tensors.")
843
844  @property
845  def name(self):
846    raise AttributeError("name not supported for Eager Tensors.")
847
848  @property
849  def value_index(self):
850    raise AttributeError("value_index not supported for Eager Tensors.")
851
852  def consumers(self):
853    raise NotImplementedError("consumers not supported for Eager Tensors.")
854
855  def _add_consumer(self, consumer):
856    raise NotImplementedError("_add_consumer not supported for Eager Tensors.")
857
858  def _as_node_def_input(self):
859    raise NotImplementedError(
860        "_as_node_def_input not supported for Eager Tensors.")
861
862  def _as_tf_output(self):
863    raise NotImplementedError("_as_tf_output not supported for Eager Tensors.")
864
865  def eval(self, feed_dict=None, session=None):
866    raise NotImplementedError("eval not supported for Eager Tensors.")
867
868
869# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
870# registers it with the current module.
871EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
872
873
874def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
875  _ = name, as_ref
876  if dtype and not dtype.is_compatible_with(t.dtype):
877    raise ValueError(
878        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
879        (dtype.name, t.dtype.name, str(t)))
880  return t
881
882
883_tensor_conversion_func_registry = {
884    0: [(Tensor, _TensorTensorConversionFunction)]
885}
886_tensor_conversion_func_cache = {}
887_tensor_conversion_func_lock = threading.Lock()
888register_dense_tensor_like_type(Tensor)
889
890
891@tf_export("convert_to_tensor")
892def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
893  """Converts the given `value` to a `Tensor`.
894
895  This function converts Python objects of various types to `Tensor`
896  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
897  and Python scalars. For example:
898
899  ```python
900  import numpy as np
901
902  def my_func(arg):
903    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
904    return tf.matmul(arg, arg) + arg
905
906  # The following calls are equivalent.
907  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
908  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
909  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
910  ```
911
912  This function can be useful when composing a new operation in Python
913  (such as `my_func` in the example above). All standard Python op
914  constructors apply this function to each of their Tensor-valued
915  inputs, which allows those ops to accept numpy arrays, Python lists,
916  and scalars in addition to `Tensor` objects.
917
918  Note: This function diverges from default Numpy behavior for `float` and
919    `string` types when `None` is present in a Python list or scalar. Rather
920    than silently converting `None` values, an error will be thrown.
921
922  Args:
923    value: An object whose type has a registered `Tensor` conversion function.
924    dtype: Optional element type for the returned tensor. If missing, the
925      type is inferred from the type of `value`.
926    name: Optional name to use if a new `Tensor` is created.
927    preferred_dtype: Optional element type for the returned tensor,
928      used when dtype is None. In some cases, a caller may not have a
929      dtype in mind when converting to a tensor, so preferred_dtype
930      can be used as a soft preference.  If the conversion to
931      `preferred_dtype` is not possible, this argument has no effect.
932
933  Returns:
934    An `Output` based on `value`.
935
936  Raises:
937    TypeError: If no conversion function is registered for `value`.
938    RuntimeError: If a registered conversion function returns an invalid value.
939
940  """
941  return internal_convert_to_tensor(
942      value=value,
943      dtype=dtype,
944      name=name,
945      preferred_dtype=preferred_dtype,
946      as_ref=False)
947
948
949def _error_prefix(name):
950  return "" if name is None else "%s: " % name
951
952
953def internal_convert_to_tensor(value,
954                               dtype=None,
955                               name=None,
956                               as_ref=False,
957                               preferred_dtype=None,
958                               ctx=None):
959  """Converts the given `value` to an `Tensor`.
960
961  This function converts Python objects of various types to `Tensor`
962  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
963  and Python scalars. For example:
964
965  This function can be useful when composing a new operation in Python
966  All standard Python op constructors apply this function to each of their
967  Tensor-valued inputs, which allows those ops to accept numpy arrays, Python
968  lists, and scalars in addition to `Tensor` objects.
969
970  Args:
971    value: An object whose type has a registered `Tensor` conversion function.
972    dtype: Optional element type for the returned tensor. If missing, the
973      type is inferred from the type of `value`.
974    name: Optional name to use if a new `Tensor` is created.
975    as_ref: True if we want the mutable view of Variables, if applicable.
976    preferred_dtype: Optional element type for the returned tensor,
977      used when dtype is None. In some cases, a caller may not have a
978      dtype in mind when converting to a tensor, so preferred_dtype
979      can be used as a soft preference.  If the conversion to
980      `preferred_dtype` is not possible, this argument has no effect.
981    ctx: Optional: The value of context.context().
982
983  Returns:
984    A `Tensor` based on `value`.
985
986  Raises:
987    TypeError: If no conversion function is registered for `value`.
988    RuntimeError: If a registered conversion function returns an invalid value.
989
990  """
991  if ctx is None: ctx = context.context()
992  if ctx.in_eager_mode():
993    # Fast path for EagerTensors that don't need any conversion.
994    if isinstance(value, EagerTensor):
995      # Note that we don't check that value's dtype matches the dtype
996      # argument.  We expect that the C runtime will do that checking
997      # when we execute the kernel.
998      return value
999
1000  if dtype is not None:
1001    dtype = dtypes.as_dtype(dtype)
1002  unwrapped_type = type(value)
1003  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
1004  if conversion_func_list is None:
1005    with _tensor_conversion_func_lock:
1006      conversion_func_list = []
1007      for _, funcs_at_priority in sorted(
1008          _tensor_conversion_func_registry.items()):
1009        for base_type, conversion_func in funcs_at_priority:
1010          if isinstance(value, base_type):
1011            conversion_func_list.append((base_type, conversion_func))
1012      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
1013
1014  for base_type, conversion_func in conversion_func_list:
1015    # If dtype is None but preferred_dtype is not None, we try to
1016    # cast to preferred_dtype first.
1017    ret = None
1018    if dtype is None and preferred_dtype is not None:
1019      try:
1020        ret = conversion_func(
1021            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
1022      except (TypeError, ValueError, errors.UnimplementedError,
1023              errors.InvalidArgumentError):
1024        # Could not coerce the conversion to use the preferred dtype.
1025        ret = None
1026
1027      if ret is not None and ret is not NotImplemented:
1028        if (ret.dtype.base_dtype !=
1029            dtypes.as_dtype(preferred_dtype).base_dtype):
1030          raise TypeError("convert_to_tensor did not convert to "
1031                          "the preferred dtype: %s vs %s " %
1032                          (ret.dtype.base_dtype,
1033                           dtypes.as_dtype(preferred_dtype).base_dtype))
1034
1035    if ret is None:
1036      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1037
1038    if ret is NotImplemented:
1039      continue
1040
1041    if not isinstance(ret, Tensor):
1042      raise RuntimeError(
1043          "%sConversion function %r for type %s returned non-Tensor: %r" %
1044          (_error_prefix(name), conversion_func, base_type, ret))
1045    if dtype and not dtype.is_compatible_with(ret.dtype):
1046      raise RuntimeError(
1047          "%sConversion function %r for type %s returned incompatible "
1048          "dtype: requested = %s, actual = %s" %
1049          (_error_prefix(name), conversion_func, base_type, dtype.name,
1050           ret.dtype.name))
1051    return ret
1052  raise TypeError("%sCannot convert %r with type %s to Tensor: "
1053                  "no conversion function registered." %
1054                  (_error_prefix(name), value, unwrapped_type))
1055
1056
1057def internal_convert_n_to_tensor(values,
1058                                 dtype=None,
1059                                 name=None,
1060                                 as_ref=False,
1061                                 preferred_dtype=None,
1062                                 ctx=None):
1063  """Converts `values` to a list of `Tensor` objects.
1064
1065  Args:
1066    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1067    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1068    name: (Optional.) A name prefix to used when a new `Tensor` is
1069      created, in which case element `i` will be given the name `name
1070      + '_' + i`.
1071    as_ref: True if the caller wants the results as ref tensors.
1072    preferred_dtype: Optional element type for the returned tensors,
1073      used when dtype is None. In some cases, a caller may not have a
1074      dtype in mind when converting to a tensor, so preferred_dtype
1075      can be used as a soft preference.  If the conversion to
1076      `preferred_dtype` is not possible, this argument has no effect.
1077    ctx: The value of context.context().
1078
1079  Returns:
1080    A list of `Tensor` and/or `IndexedSlices` objects.
1081
1082  Raises:
1083    TypeError: If no conversion function is registered for an element in
1084      `values`.
1085    RuntimeError: If a registered conversion function returns an invalid
1086      value.
1087  """
1088  if not isinstance(values, collections.Sequence):
1089    raise TypeError("values must be a list.")
1090  ret = []
1091  if ctx is None: ctx = context.context()
1092  for i, value in enumerate(values):
1093    n = None if name is None else "%s_%d" % (name, i)
1094    ret.append(
1095        internal_convert_to_tensor(
1096            value,
1097            dtype=dtype,
1098            name=n,
1099            as_ref=as_ref,
1100            preferred_dtype=preferred_dtype,
1101            ctx=ctx))
1102  return ret
1103
1104
1105def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
1106  """Converts `values` to a list of `Tensor` objects.
1107
1108  Args:
1109    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1110    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1111    name: (Optional.) A name prefix to used when a new `Tensor` is
1112      created, in which case element `i` will be given the name `name
1113      + '_' + i`.
1114    preferred_dtype: Optional element type for the returned tensors,
1115      used when dtype is None. In some cases, a caller may not have a
1116      dtype in mind when converting to a tensor, so preferred_dtype
1117      can be used as a soft preference.  If the conversion to
1118      `preferred_dtype` is not possible, this argument has no effect.
1119
1120  Returns:
1121    A list of `Tensor` and/or `IndexedSlices` objects.
1122
1123  Raises:
1124    TypeError: If no conversion function is registered for an element in
1125      `values`.
1126    RuntimeError: If a registered conversion function returns an invalid
1127      value.
1128  """
1129  return internal_convert_n_to_tensor(
1130      values=values,
1131      dtype=dtype,
1132      name=name,
1133      preferred_dtype=preferred_dtype,
1134      as_ref=False)
1135
1136
1137@tf_export("convert_to_tensor_or_indexed_slices")
1138def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
1139  """Converts the given object to a `Tensor` or an `IndexedSlices`.
1140
1141  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
1142  unmodified. Otherwise, it is converted to a `Tensor` using
1143  `convert_to_tensor()`.
1144
1145  Args:
1146    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
1147      by `convert_to_tensor()`.
1148    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1149      `IndexedSlices`.
1150    name: (Optional.) A name to use if a new `Tensor` is created.
1151
1152  Returns:
1153    An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
1154
1155  Raises:
1156    ValueError: If `dtype` does not match the element type of `value`.
1157  """
1158  return internal_convert_to_tensor_or_indexed_slices(
1159      value=value, dtype=dtype, name=name, as_ref=False)
1160
1161
1162def internal_convert_to_tensor_or_indexed_slices(value,
1163                                                 dtype=None,
1164                                                 name=None,
1165                                                 as_ref=False):
1166  """Converts the given object to an `Tensor` or an `IndexedSlices`.
1167
1168  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
1169  unmodified. Otherwise, it is converted to a `Tensor` using
1170  `convert_to_tensor()`.
1171
1172  Args:
1173    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
1174      by `convert_to_tensor()`.
1175    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1176      `IndexedSlices`.
1177    name: (Optional.) A name to use if a new `Tensor` is created.
1178    as_ref: True if the caller wants the results as ref tensors.
1179
1180  Returns:
1181    An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
1182
1183  Raises:
1184    ValueError: If `dtype` does not match the element type of `value`.
1185  """
1186  if isinstance(value, _TensorLike):
1187    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
1188      raise ValueError(
1189          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1190          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
1191    return value
1192  else:
1193    return internal_convert_to_tensor(
1194        value, dtype=dtype, name=name, as_ref=as_ref)
1195
1196
1197def internal_convert_n_to_tensor_or_indexed_slices(values,
1198                                                   dtype=None,
1199                                                   name=None,
1200                                                   as_ref=False):
1201  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
1202
1203  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
1204  unmodified.
1205
1206  Args:
1207    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
1208      can be consumed by `convert_to_tensor()`.
1209    dtype: (Optional.) The required `DType` of the returned `Tensor`
1210      `IndexedSlices`.
1211    name: (Optional.) A name prefix to used when a new `Tensor` is
1212      created, in which case element `i` will be given the name `name
1213      + '_' + i`.
1214    as_ref: True if the caller wants the results as ref tensors.
1215
1216  Returns:
1217    A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
1218
1219  Raises:
1220    TypeError: If no conversion function is registered for an element in
1221      `values`.
1222    RuntimeError: If a registered conversion function returns an invalid
1223      value.
1224  """
1225  if not isinstance(values, collections.Sequence):
1226    raise TypeError("values must be a list.")
1227  ret = []
1228  for i, value in enumerate(values):
1229    if value is None:
1230      ret.append(value)
1231    else:
1232      n = None if name is None else "%s_%d" % (name, i)
1233      ret.append(
1234          internal_convert_to_tensor_or_indexed_slices(
1235              value, dtype=dtype, name=n, as_ref=as_ref))
1236  return ret
1237
1238
1239def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
1240  """Converts `values` to a list of `Output` or `IndexedSlices` objects.
1241
1242  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
1243  unmodified.
1244
1245  Args:
1246    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
1247      can be consumed by `convert_to_tensor()`.
1248    dtype: (Optional.) The required `DType` of the returned `Tensor`
1249      `IndexedSlices`.
1250    name: (Optional.) A name prefix to used when a new `Tensor` is
1251      created, in which case element `i` will be given the name `name
1252      + '_' + i`.
1253
1254  Returns:
1255    A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
1256
1257  Raises:
1258    TypeError: If no conversion function is registered for an element in
1259      `values`.
1260    RuntimeError: If a registered conversion function returns an invalid
1261      value.
1262  """
1263  return internal_convert_n_to_tensor_or_indexed_slices(
1264      values=values, dtype=dtype, name=name, as_ref=False)
1265
1266
1267# TODO(josh11b): Add ctx argument to conversion_func() signature.
1268@tf_export("register_tensor_conversion_function")
1269def register_tensor_conversion_function(base_type,
1270                                        conversion_func,
1271                                        priority=100):
1272  """Registers a function for converting objects of `base_type` to `Tensor`.
1273
1274  The conversion function must have the following signature:
1275
1276  ```python
1277      def conversion_func(value, dtype=None, name=None, as_ref=False):
1278        # ...
1279  ```
1280
1281  It must return a `Tensor` with the given `dtype` if specified. If the
1282  conversion function creates a new `Tensor`, it should use the given
1283  `name` if specified. All exceptions will be propagated to the caller.
1284
1285  The conversion function may return `NotImplemented` for some
1286  inputs. In this case, the conversion process will continue to try
1287  subsequent conversion functions.
1288
1289  If `as_ref` is true, the function must return a `Tensor` reference,
1290  such as a `Variable`.
1291
1292  NOTE: The conversion functions will execute in order of priority,
1293  followed by order of registration. To ensure that a conversion function
1294  `F` runs before another conversion function `G`, ensure that `F` is
1295  registered with a smaller priority than `G`.
1296
1297  Args:
1298    base_type: The base type or tuple of base types for all objects that
1299      `conversion_func` accepts.
1300    conversion_func: A function that converts instances of `base_type` to
1301      `Tensor`.
1302    priority: Optional integer that indicates the priority for applying this
1303      conversion function. Conversion functions with smaller priority values
1304      run earlier than conversion functions with larger priority values.
1305      Defaults to 100.
1306
1307  Raises:
1308    TypeError: If the arguments do not have the appropriate type.
1309
1310  """
1311  global _tensor_conversion_func_cache
1312  with _tensor_conversion_func_lock:
1313    if not (isinstance(base_type, type) or
1314            (isinstance(base_type, tuple) and
1315             all(isinstance(x, type) for x in base_type))):
1316      raise TypeError("base_type must be a type or a tuple of types.")
1317    if not callable(conversion_func):
1318      raise TypeError("conversion_func must be callable.")
1319
1320    try:
1321      funcs_at_priority = _tensor_conversion_func_registry[priority]
1322    except KeyError:
1323      funcs_at_priority = []
1324      _tensor_conversion_func_registry[priority] = funcs_at_priority
1325    funcs_at_priority.append((base_type, conversion_func))
1326    _tensor_conversion_func_cache = {}
1327
1328
1329@tf_export("IndexedSlices")
1330class IndexedSlices(_TensorLike):
1331  """A sparse representation of a set of tensor slices at given indices.
1332
1333  This class is a simple wrapper for a pair of `Tensor` objects:
1334
1335  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
1336  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
1337
1338  An `IndexedSlices` is typically used to represent a subset of a larger
1339  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
1340  The values in `indices` are the indices in the first dimension of
1341  the slices that have been extracted from the larger tensor.
1342
1343  The dense tensor `dense` represented by an `IndexedSlices` `slices` has
1344
1345  ```python
1346  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
1347  ```
1348
1349  The `IndexedSlices` class is used principally in the definition of
1350  gradients for operations that have sparse gradients
1351  (e.g. @{tf.gather}).
1352
1353  Contrast this representation with
1354  @{tf.SparseTensor},
1355  which uses multi-dimensional indices and scalar values.
1356  """
1357
1358  def __init__(self, values, indices, dense_shape=None):
1359    """Creates an `IndexedSlices`."""
1360    _get_graph_from_inputs([values, indices, dense_shape])
1361    self._values = values
1362    self._indices = indices
1363    self._dense_shape = dense_shape
1364
1365  @property
1366  def values(self):
1367    """A `Tensor` containing the values of the slices."""
1368    return self._values
1369
1370  @property
1371  def indices(self):
1372    """A 1-D `Tensor` containing the indices of the slices."""
1373    return self._indices
1374
1375  @property
1376  def dense_shape(self):
1377    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
1378    return self._dense_shape
1379
1380  @property
1381  def name(self):
1382    """The name of this `IndexedSlices`."""
1383    return self.values.name
1384
1385  @property
1386  def device(self):
1387    """The name of the device on which `values` will be produced, or `None`."""
1388    return self.values.device
1389
1390  @property
1391  def op(self):
1392    """The `Operation` that produces `values` as an output."""
1393    return self.values.op
1394
1395  @property
1396  def dtype(self):
1397    """The `DType` of elements in this tensor."""
1398    return self.values.dtype
1399
1400  @property
1401  def graph(self):
1402    """The `Graph` that contains the values, indices, and shape tensors."""
1403    return self._values.graph
1404
1405  def __str__(self):
1406    return "IndexedSlices(indices=%s, values=%s%s)" % (
1407        self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
1408        if self._dense_shape is not None else "")
1409
1410  def __neg__(self):
1411    return IndexedSlices(-self.values, self.indices, self.dense_shape)
1412
1413
1414IndexedSlicesValue = collections.namedtuple(
1415    "IndexedSlicesValue", ["values", "indices", "dense_shape"])
1416
1417
1418def _device_string(dev_spec):
1419  if isinstance(dev_spec, pydev.DeviceSpec):
1420    return dev_spec.to_string()
1421  else:
1422    return dev_spec
1423
1424
1425def _NodeDef(op_type, name, device=None, attrs=None):  # pylint: disable=redefined-outer-name
1426  """Create a NodeDef proto.
1427
1428  Args:
1429    op_type: Value for the "op" attribute of the NodeDef proto.
1430    name: Value for the "name" attribute of the NodeDef proto.
1431    device: string, device, or function from NodeDef to string.
1432      Value for the "device" attribute of the NodeDef proto.
1433    attrs: Optional dictionary where the key is the attribute name (a string)
1434      and the value is the respective "attr" attribute of the NodeDef proto (an
1435      AttrValue).
1436
1437  Returns:
1438    A node_def_pb2.NodeDef protocol buffer.
1439  """
1440  node_def = node_def_pb2.NodeDef()
1441  node_def.op = compat.as_bytes(op_type)
1442  node_def.name = compat.as_bytes(name)
1443  if attrs is not None:
1444    for k, v in six.iteritems(attrs):
1445      node_def.attr[k].CopyFrom(v)
1446  if device is not None:
1447    if callable(device):
1448      node_def.device = device(node_def)
1449    else:
1450      node_def.device = _device_string(device)
1451  return node_def
1452
1453
1454# Copied from core/framework/node_def_util.cc
1455# TODO(mrry,josh11b): Consolidate this validation in C++ code.
1456_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
1457_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
1458
1459
1460def _create_c_op(graph, node_def, inputs, control_inputs):
1461  """Creates a TF_Operation.
1462
1463  Args:
1464    graph: a `Graph`.
1465    node_def: `node_def_pb2.NodeDef` for the operation to create.
1466    inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of
1467      `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
1468      "list(int64)"). The length of the list should be equal to the number of
1469      inputs specified by this operation's op def.
1470    control_inputs: A list of `Operation`s to set as control dependencies.
1471
1472  Returns:
1473    A wrapped TF_Operation*.
1474  """
1475  # pylint: disable=protected-access
1476  op_desc = c_api.TF_NewOperation(graph._c_graph,
1477                                  compat.as_str(node_def.op),
1478                                  compat.as_str(node_def.name))
1479  # Add inputs
1480  for op_input in inputs:
1481    if isinstance(op_input, (list, tuple)):
1482      c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
1483    else:
1484      c_api.TF_AddInput(op_desc, op_input._as_tf_output())
1485
1486  # Add control inputs
1487  for control_input in control_inputs:
1488    c_api.TF_AddControlInput(op_desc, control_input._c_op)
1489  # pylint: enable=protected-access
1490
1491  # Add attrs
1492  for name, attr_value in node_def.attr.items():
1493    serialized = attr_value.SerializeToString()
1494    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
1495    # It might be worth creating a convenient way to re-use the same status.
1496    with errors.raise_exception_on_not_ok_status() as status:
1497      c_api.TF_SetAttrValueProto(op_desc,
1498                                 compat.as_str(name), serialized, status)
1499
1500  try:
1501    with errors.raise_exception_on_not_ok_status() as status:
1502      c_op = c_api.TF_FinishOperation(op_desc, status)
1503  except errors.InvalidArgumentError as e:
1504    # Convert to ValueError for backwards compatibility.
1505    raise ValueError(str(e))
1506
1507  return c_op
1508
1509
1510@tf_export("Operation")
1511class Operation(object):
1512  """Represents a graph node that performs computation on tensors.
1513
1514  An `Operation` is a node in a TensorFlow `Graph` that takes zero or
1515  more `Tensor` objects as input, and produces zero or more `Tensor`
1516  objects as output. Objects of type `Operation` are created by
1517  calling a Python op constructor (such as
1518  @{tf.matmul})
1519  or @{tf.Graph.create_op}.
1520
1521  For example `c = tf.matmul(a, b)` creates an `Operation` of type
1522  "MatMul" that takes tensors `a` and `b` as input, and produces `c`
1523  as output.
1524
1525  After the graph has been launched in a session, an `Operation` can
1526  be executed by passing it to
1527  @{tf.Session.run}.
1528  `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
1529  """
1530
1531  def __init__(self,
1532               node_def,
1533               g,
1534               inputs=None,
1535               output_types=None,
1536               control_inputs=None,
1537               input_types=None,
1538               original_op=None,
1539               op_def=None):
1540    r"""Creates an `Operation`.
1541
1542    NOTE: This constructor validates the name of the `Operation` (passed
1543    as `node_def.name`). Valid `Operation` names match the following
1544    regular expression:
1545
1546        [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
1547
1548    Args:
1549      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.
1550        Used for attributes of `node_def_pb2.NodeDef`, typically `name`,
1551        `op`, and `device`.  The `input` attribute is irrelevant here
1552        as it will be computed when generating the model.
1553      g: `Graph`. The parent graph.
1554      inputs: list of `Tensor` objects. The inputs to this `Operation`.
1555      output_types: list of `DType` objects.  List of the types of the
1556        `Tensors` computed by this operation.  The length of this list indicates
1557        the number of output endpoints of the `Operation`.
1558      control_inputs: list of operations or tensors from which to have a
1559        control dependency.
1560      input_types: List of `DType` objects representing the
1561        types of the tensors accepted by the `Operation`.  By default
1562        uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
1563        reference-typed inputs must specify these explicitly.
1564      original_op: Optional. Used to associate the new `Operation` with an
1565        existing `Operation` (for example, a replica with the op that was
1566        replicated).
1567      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the
1568        op type that this `Operation` represents.
1569
1570    Raises:
1571      TypeError: if control inputs are not Operations or Tensors,
1572        or if `node_def` is not a `NodeDef`,
1573        or if `g` is not a `Graph`,
1574        or if `inputs` are not tensors,
1575        or if `inputs` and `input_types` are incompatible.
1576      ValueError: if the `node_def` name is not valid.
1577    """
1578    # For internal use only: `node_def` can be set to a TF_Operation to create
1579    # an Operation for that op. This is useful for creating Operations for ops
1580    # indirectly created by C API methods, e.g. the ops created by
1581    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
1582    # should be None.
1583
1584    if isinstance(node_def, node_def_pb2.NodeDef):
1585      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
1586        raise ValueError(
1587            "Cannot create a tensor proto whose content is larger than 2GB.")
1588      if not _VALID_OP_NAME_REGEX.match(node_def.name):
1589        raise ValueError("'%s' is not a valid node name" % node_def.name)
1590      c_op = None
1591    elif type(node_def).__name__ == "SwigPyObject":
1592      assert inputs is None
1593      assert output_types is None
1594      assert control_inputs is None
1595      assert input_types is None
1596      assert original_op is None
1597      assert op_def is None
1598      c_op = node_def
1599    else:
1600      raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
1601
1602    if not isinstance(g, Graph):
1603      raise TypeError("g needs to be a Graph: %s" % g)
1604    self._graph = g
1605
1606    if inputs is None:
1607      inputs = []
1608    elif not isinstance(inputs, list):
1609      raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
1610    for a in inputs:
1611      if not isinstance(a, Tensor):
1612        raise TypeError("input needs to be a Tensor: %s" % a)
1613    if input_types is None:
1614      input_types = [i.dtype.base_dtype for i in inputs]
1615    else:
1616      if not all(
1617          x.is_compatible_with(i.dtype)
1618          for i, x in zip(inputs, input_types)):
1619        raise TypeError("In op '%s', input types (%s) are not compatible "
1620                        "with expected types (%s)" %
1621                        (node_def.name, [i.dtype for i in inputs],
1622                         input_types))
1623
1624    # Build the list of control inputs.
1625    control_input_ops = []
1626    if control_inputs:
1627      for c in control_inputs:
1628        control_op = None
1629        if isinstance(c, Operation):
1630          control_op = c
1631        elif isinstance(c, (Tensor, IndexedSlices)):
1632          control_op = c.op
1633        else:
1634          raise TypeError("Control input must be an Operation, "
1635                          "a Tensor, or IndexedSlices: %s" % c)
1636        control_input_ops.append(control_op)
1637
1638    # Don't set private fields with C API enabled to catch users who need to
1639    # switch to public API.
1640    # TODO(skyewm): delete these fields once we remove _USE_C_API
1641    if not self._graph._c_graph:
1642      self._inputs_val = list(inputs)  # Defensive copy.
1643      self._input_types_val = input_types
1644      self._control_inputs_val = control_input_ops
1645      self._node_def_val = copy.deepcopy(node_def)
1646      self._op_def_val = op_def
1647
1648    self._id_value = self._graph._next_id()  # pylint: disable=protected-access
1649    self._original_op = original_op
1650    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
1651    self._control_flow_context = self.graph._get_control_flow_context()  # pylint: disable=protected-access
1652
1653    # Initialize self._c_op.
1654    if c_op:
1655      # TODO(skyewm): remove this assert when we remove USE_C_API
1656      assert self._graph._c_graph  # pylint: disable=protected-access
1657      self._c_op = c_op
1658    elif self._graph._c_graph:  # pylint: disable=protected-access
1659      if op_def is None:
1660        op_def = self._graph._get_op_def(node_def.op)
1661      # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
1662      # Refactor so we don't have to do this here.
1663      grouped_inputs = self._reconstruct_sequence_inputs(
1664          op_def, inputs, node_def.attr)
1665      self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
1666                                control_input_ops)
1667    else:
1668      self._c_op = None
1669
1670    # Mark that we consume the inputs. This is unnecessary and unsupported with
1671    # the C API enabled, since the C API tracks the tensor consumers instead.
1672    if not self._c_op:
1673      for input_tensor in self._inputs_val:
1674        input_tensor._add_consumer(self)  # pylint: disable=protected-access
1675
1676    # Initialize self._outputs.
1677    if self._c_op:
1678      num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
1679      output_types = [
1680          c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
1681          for i in range(num_outputs)]
1682      assert output_types is not None
1683    elif output_types is None:
1684      output_types = []
1685    self._output_types_val = output_types
1686    self._outputs = [
1687        Tensor(self, i, output_type)
1688        for i, output_type in enumerate(output_types)
1689    ]
1690
1691    if not c_op:
1692      self._control_flow_post_processing()
1693
1694  def _control_flow_post_processing(self):
1695    """Add this op to its control flow context.
1696
1697    This may add new ops and change this op's inputs. self.inputs must be
1698    available before calling this method.
1699    """
1700    for input_tensor in self.inputs:
1701      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
1702    if self._control_flow_context is not None:
1703      self._control_flow_context.AddOp(self)
1704    self._recompute_node_def()
1705
1706  def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
1707    """Regroups a flat list of input tensors into scalar and sequence inputs.
1708
1709    Args:
1710      op_def: The `op_def_pb2.OpDef` (for knowing the input types)
1711      inputs: a list of input `Tensor`s to the op.
1712      attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
1713        how long each sequence is)
1714
1715    Returns:
1716      A list of `Tensor`s (corresponding to scalar inputs) and lists of
1717      `Tensor`s (corresponding to sequence inputs).
1718    """
1719    grouped_inputs = []
1720    i = 0
1721    for input_arg in op_def.input_arg:
1722      if input_arg.number_attr:
1723        input_len = attrs[input_arg.number_attr].i
1724        is_sequence = True
1725      elif input_arg.type_list_attr:
1726        input_len = len(attrs[input_arg.type_list_attr].list.type)
1727        is_sequence = True
1728      else:
1729        input_len = 1
1730        is_sequence = False
1731
1732      if is_sequence:
1733        grouped_inputs.append(inputs[i:i + input_len])
1734      else:
1735        grouped_inputs.append(inputs[i])
1736      i += input_len
1737
1738    assert i == len(inputs)
1739    return grouped_inputs
1740
1741  def colocation_groups(self):
1742    """Returns the list of colocation groups of the op."""
1743    default_colocation_group = [
1744        compat.as_bytes("loc:@%s" % self.name)
1745    ]
1746    try:
1747      class_attr = self.get_attr("_class")
1748    except ValueError:
1749      # This op has no explicit colocation group, so it is itself its
1750      # own root of a colocation group.
1751      return default_colocation_group
1752
1753    attr_groups = [
1754        class_name for class_name in class_attr
1755        if class_name.startswith(b"loc:@")
1756    ]
1757
1758    # If there are no colocation groups in the explicit _class field,
1759    # return the default colocation group.
1760    return attr_groups if attr_groups else default_colocation_group
1761
1762  def values(self):
1763    """DEPRECATED: Use outputs."""
1764    return tuple(self.outputs)
1765
1766  def _get_control_flow_context(self):
1767    """Returns the control flow context of this op.
1768
1769    Returns:
1770      A context object.
1771    """
1772    return self._control_flow_context
1773
1774  def _set_control_flow_context(self, ctx):
1775    """Sets the current control flow context of this op.
1776
1777    Args:
1778      ctx: a context object.
1779    """
1780    self._control_flow_context = ctx
1781
1782  @property
1783  def name(self):
1784    """The full name of this operation."""
1785    if self._c_op:
1786      return c_api.TF_OperationName(self._c_op)
1787    else:
1788      return self._node_def_val.name
1789
1790  @property
1791  def _id(self):
1792    """The unique integer id of this operation."""
1793    return self._id_value
1794
1795  @property
1796  def device(self):
1797    """The name of the device to which this op has been assigned, if any.
1798
1799    Returns:
1800      The string name of the device to which this op has been
1801      assigned, or an empty string if it has not been assigned to a
1802      device.
1803    """
1804    if self._c_op:
1805      return c_api.TF_OperationDevice(self._c_op)
1806    else:
1807      return self._node_def_val.device
1808
1809  @property
1810  def _output_types(self):
1811    """List this operation's output types.
1812
1813    Returns:
1814      List of the types of the Tensors computed by this operation.
1815      Each element in the list is an integer whose value is one of
1816      the TF_DataType enums defined in c_api.h
1817      The length of this list indicates the number of output endpoints
1818      of the operation.
1819    """
1820    if self._c_op:
1821      num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
1822      output_types = [
1823          c_api.TF_OperationOutputType(self._tf_output(i))
1824          for i in xrange(num_outputs)
1825      ]
1826      # TODO(iga): Remove this assert after converting to C API by default.
1827      # Just being a bit paranoid here.
1828      assert self._output_types_val == output_types
1829      # In all the tests we have output_types that are passed into
1830      # Operation.__init__ are a list of ints (which is illegal according
1831      # to the docstring), but input_types are instances of DType.
1832      # This extra assert is to catch if we ever use DType for output_types.
1833      if output_types:
1834        assert isinstance(output_types[0], int)
1835      return output_types
1836    else:
1837      return self._output_types_val
1838
1839  def _tf_output(self, output_idx):
1840    """Create and return a new TF_Output for output_idx'th output of this op."""
1841    assert self._c_op
1842    tf_output = c_api.TF_Output()
1843    tf_output.oper = self._c_op
1844    tf_output.index = output_idx
1845    return tf_output
1846
1847  def _tf_input(self, input_idx):
1848    """Create and return a new TF_Input for input_idx'th input of this op."""
1849    assert self._c_op
1850    tf_input = c_api.TF_Input()
1851    tf_input.oper = self._c_op
1852    tf_input.index = input_idx
1853    return tf_input
1854
1855  def _set_device(self, device):  # pylint: disable=redefined-outer-name
1856    """Set the device of this operation.
1857
1858    Args:
1859      device: string or device..  The device to set.
1860    """
1861    if self._c_op:
1862      c_api.SetRequestedDevice(
1863          self._graph._c_graph,  # pylint: disable=protected-access
1864          self._c_op,  # pylint: disable=protected-access
1865          compat.as_str(_device_string(device)))
1866    else:
1867      self._node_def_val.device = _device_string(device)
1868
1869  def _add_input(self, tensor, dtype=None):
1870    """Add a new input to this operation.
1871
1872    Args:
1873      tensor: the Tensor to add as an input.
1874      dtype: tf.DType: type of the input; defaults to
1875        the tensor's dtype.
1876
1877    Raises:
1878      TypeError: if tensor is not a Tensor,
1879        or if input tensor type is not convertible to dtype.
1880      ValueError: if the Tensor is from a different graph.
1881    """
1882    assert not self._c_op, (
1883        "Operation._add_input doesn't work with C API")
1884    if not isinstance(tensor, Tensor):
1885      raise TypeError("tensor must be a Tensor: %s" % tensor)
1886    _assert_same_graph(self, tensor)
1887    if dtype is None:
1888      dtype = tensor.dtype
1889    else:
1890      dtype = dtypes.as_dtype(dtype)
1891      if not dtype.is_compatible_with(tensor.dtype):
1892        raise TypeError(
1893            "Cannot convert a tensor of type %s to an input of type %s" %
1894            (tensor.dtype.name, dtype.name))
1895    self._inputs_val.append(tensor)
1896    self._input_types_val.append(dtype)
1897    tensor._add_consumer(self)  # pylint: disable=protected-access
1898    self._recompute_node_def()
1899
1900  def _update_input(self, index, tensor):
1901    """Update the input to this operation at the given index.
1902
1903    NOTE: This is for TF internal use only. Please don't use it.
1904
1905    Args:
1906      index: the index of the input to update.
1907      tensor: the Tensor to be used as the input at the given index.
1908
1909    Raises:
1910      TypeError: if tensor is not a Tensor,
1911        or if input tensor type is not convertible to dtype.
1912      ValueError: if the Tensor is from a different graph.
1913    """
1914    if not isinstance(tensor, Tensor):
1915      raise TypeError("tensor must be a Tensor: %s" % tensor)
1916    _assert_same_graph(self, tensor)
1917    if self._c_op:
1918      with errors.raise_exception_on_not_ok_status() as status:
1919        c_api.UpdateEdge(
1920            self._graph._c_graph,  # pylint: disable=protected-access
1921            tensor._as_tf_output(),  # pylint: disable=protected-access
1922            self._tf_input(index),
1923            status)
1924    else:
1925      self._inputs_val[index].consumers().remove(self)
1926      self._inputs_val[index] = tensor
1927      self._input_types_val[index] = tensor.dtype
1928      tensor._add_consumer(self)  # pylint: disable=protected-access
1929      self._recompute_node_def()
1930
1931  def _add_control_inputs(self, ops):
1932    """Add a list of new control inputs to this operation.
1933
1934    Args:
1935      ops: the list of Operations to add as control input.
1936
1937    Raises:
1938      TypeError: if ops is not a list of Operations.
1939      ValueError: if any op in ops is from a different graph.
1940    """
1941    if self._c_op:
1942      for op in ops:
1943        if not isinstance(op, Operation):
1944          raise TypeError("op must be an Operation: %s" % op)
1945        c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
1946    else:
1947      if ops:
1948        for op in ops:
1949          if not isinstance(op, Operation):
1950            raise TypeError("op must be an Operation: %s" % op)
1951          _assert_same_graph(self, op)
1952          self._control_inputs_val.append(op)
1953        self._recompute_node_def()
1954
1955  def _add_control_input(self, op):
1956    """Add a new control input to this operation.
1957
1958    Args:
1959      op: the Operation to add as control input.
1960
1961    Raises:
1962      TypeError: if op is not an Operation.
1963      ValueError: if op is from a different graph.
1964    """
1965    if self._c_op:
1966      if not isinstance(op, Operation):
1967        raise TypeError("op must be an Operation: %s" % op)
1968      c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
1969    else:
1970      self._add_control_inputs([op])
1971
1972  def _remove_all_control_inputs(self):
1973    """Removes any control inputs to this operation."""
1974    if self._c_op:
1975      c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
1976    else:
1977      del self.control_inputs[:]
1978
1979  # Methods below are used when building the NodeDef and Graph proto.
1980  def _recompute_node_def(self):
1981    # TODO(skyewm): remove this function when we switch to C API
1982    if self._c_op: return
1983
1984    del self._node_def_val.input[:]
1985    # pylint: disable=protected-access
1986    self._node_def_val.input.extend(
1987        [t._as_node_def_input() for t in self._inputs_val])
1988    # pylint: enable=protected-access
1989    if self._control_inputs_val:
1990      self._node_def_val.input.extend(
1991          ["^%s" % op.name for op in self._control_inputs_val])
1992
1993  def __str__(self):
1994    return str(self.node_def)
1995
1996  def __repr__(self):
1997    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
1998
1999  @property
2000  def outputs(self):
2001    """The list of `Tensor` objects representing the outputs of this op."""
2002    return self._outputs
2003
2004# pylint: disable=protected-access
2005
2006  class _InputList(object):
2007    """Immutable input list wrapper."""
2008
2009    def __init__(self, inputs):
2010      self._inputs = inputs
2011
2012    def __iter__(self):
2013      return iter(self._inputs)
2014
2015    def __len__(self):
2016      return len(self._inputs)
2017
2018    def __bool__(self):
2019      return bool(self._inputs)
2020
2021    # Python 3 wants __bool__, Python 2.7 wants __nonzero__
2022    __nonzero__ = __bool__
2023
2024    def __getitem__(self, i):
2025      return self._inputs[i]
2026
2027# pylint: enable=protected-access
2028
2029  @property
2030  def inputs(self):
2031    """The list of `Tensor` objects representing the data inputs of this op."""
2032    if self._c_op:
2033      tf_outputs = c_api.GetOperationInputs(self._c_op)
2034      # pylint: disable=protected-access
2035      retval = [
2036          self.graph._get_tensor_by_tf_output(tf_output)
2037          for tf_output in tf_outputs
2038      ]
2039      # pylint: enable=protected-access
2040      return Operation._InputList(retval)
2041    return Operation._InputList(self._inputs_val)
2042
2043  @property
2044  def _inputs(self):
2045    logging.warning("Operation._inputs is private, use Operation.inputs "
2046                    "instead. Operation._inputs will eventually be removed.")
2047    return self.inputs
2048
2049  @_inputs.setter
2050  def _inputs(self, value):
2051    raise ValueError("Cannot assign _inputs")
2052
2053  @property
2054  def _input_types(self):
2055    if self._c_op:
2056      num_inputs = c_api.TF_OperationNumInputs(self._c_op)
2057      input_types = [
2058          dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
2059          for i in xrange(num_inputs)
2060      ]
2061      return input_types
2062    else:
2063      return self._input_types_val
2064
2065  @_input_types.setter
2066  def _input_types(self, value):
2067    raise ValueError("Cannot assign _input_types")
2068
2069  @property
2070  def control_inputs(self):
2071    """The `Operation` objects on which this op has a control dependency.
2072
2073    Before this op is executed, TensorFlow will ensure that the
2074    operations in `self.control_inputs` have finished executing. This
2075    mechanism can be used to run ops sequentially for performance
2076    reasons, or to ensure that the side effects of an op are observed
2077    in the correct order.
2078
2079    Returns:
2080      A list of `Operation` objects.
2081
2082    """
2083    if self._c_op:
2084      control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
2085      # pylint: disable=protected-access
2086      return [
2087          self.graph._get_operation_by_name_unsafe(
2088              c_api.TF_OperationName(c_op)) for c_op in control_c_ops
2089      ]
2090      # pylint: enable=protected-access
2091    else:
2092      return self._control_inputs_val
2093
2094  @property
2095  def _control_inputs(self):
2096    logging.warning("Operation._control_inputs is private, use "
2097                    "Operation.control_inputs instead. "
2098                    "Operation._control_inputs will eventually be removed.")
2099    return self.control_inputs
2100
2101  @_control_inputs.setter
2102  def _control_inputs(self, value):
2103    logging.warning("Operation._control_inputs is private, use "
2104                    "Operation.control_inputs instead. "
2105                    "Operation._control_inputs will eventually be removed.")
2106    # Copy value because it may be self._control_inputs_val (in particular if
2107    # this is called from self._control_inputs += ...), and we don't want to
2108    # clear value below.
2109    value = copy.copy(value)
2110    self._remove_all_control_inputs()
2111    self._add_control_inputs(value)
2112
2113  @property
2114  def type(self):
2115    """The type of the op (e.g. `"MatMul"`)."""
2116    if self._c_op:
2117      op_type = c_api.TF_OperationOpType(self._c_op)
2118      return op_type
2119    else:
2120      return self._node_def_val.op
2121
2122  @property
2123  def graph(self):
2124    """The `Graph` that contains this operation."""
2125    return self._graph
2126
2127  @property
2128  def node_def(self):
2129    # pylint: disable=line-too-long
2130    """Returns the `NodeDef` representation of this operation.
2131
2132    Returns:
2133      A
2134      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
2135      protocol buffer.
2136    """
2137    # pylint: enable=line-too-long
2138    if self._c_op:
2139      with c_api_util.tf_buffer() as buf:
2140        with errors.raise_exception_on_not_ok_status() as status:
2141          c_api.TF_OperationToNodeDef(self._c_op, buf, status)
2142        data = c_api.TF_GetBuffer(buf)
2143      node_def = node_def_pb2.NodeDef()
2144      node_def.ParseFromString(compat.as_bytes(data))
2145      return node_def
2146    else:
2147      return self._node_def_val
2148
2149  @property
2150  def _node_def(self):
2151    logging.warning("Operation._node_def is private, use Operation.node_def "
2152                    "instead. Operation._node_def will eventually be removed.")
2153    return self.node_def
2154
2155  @property
2156  def op_def(self):
2157    # pylint: disable=line-too-long
2158    """Returns the `OpDef` proto that represents the type of this op.
2159
2160    Returns:
2161      An
2162      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
2163      protocol buffer.
2164    """
2165    # pylint: enable=line-too-long
2166    if self._c_op:
2167      return self._graph._get_op_def(self.type)
2168    else:
2169      return self._op_def_val
2170
2171  @property
2172  def _op_def(self):
2173    logging.warning("Operation._op_def is private, use Operation.op_def "
2174                    "instead. Operation._op_def will eventually be removed.")
2175    return self.op_def
2176
2177  @property
2178  def traceback(self):
2179    """Returns the call stack from when this operation was constructed."""
2180    return self._graph._convert_stack(self._traceback)  # pylint: disable=protected-access
2181
2182  @property
2183  def traceback_with_start_lines(self):
2184    """Same as traceback but includes start line of function definition.
2185
2186    Returns:
2187      A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
2188    """
2189    return self._graph._convert_stack(  # pylint: disable=protected-access
2190        self._traceback,
2191        include_func_start_lineno=True)
2192
2193  def _set_attr(self, attr_name, attr_value):
2194    """Private method used to set an attribute in the node_def."""
2195    if self._c_op:
2196      buf = c_api.TF_NewBufferFromString(
2197          compat.as_bytes(attr_value.SerializeToString()))
2198      try:
2199        with errors.raise_exception_on_not_ok_status() as status:
2200          # pylint: disable=protected-access
2201          c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
2202                        status)
2203          # pylint: enable=protected-access
2204      finally:
2205        c_api.TF_DeleteBuffer(buf)
2206    else:
2207      self._node_def_val.attr[attr_name].CopyFrom(attr_value)
2208
2209  def get_attr(self, name):
2210    """Returns the value of the attr of this op with the given `name`.
2211
2212    Args:
2213      name: The name of the attr to fetch.
2214
2215    Returns:
2216      The value of the attr, as a Python object.
2217
2218    Raises:
2219      ValueError: If this op does not have an attr with the given `name`.
2220    """
2221    fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
2222    if self._c_op:
2223      try:
2224        with c_api_util.tf_buffer() as buf:
2225          with errors.raise_exception_on_not_ok_status() as status:
2226            c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
2227          data = c_api.TF_GetBuffer(buf)
2228      except errors.InvalidArgumentError as e:
2229        # Convert to ValueError for backwards compatibility.
2230        raise ValueError(str(e))
2231      x = attr_value_pb2.AttrValue()
2232      x.ParseFromString(data)
2233    else:
2234      if name not in self._node_def_val.attr:
2235        raise ValueError(
2236            "No attr named '" + name + "' in " + str(self._node_def_val))
2237      x = self._node_def_val.attr[name]
2238
2239    # Treat an empty oneof value as an empty list.
2240    if not x.WhichOneof("value"):
2241      return []
2242    if x.HasField("list"):
2243      for f in fields:
2244        if getattr(x.list, f):
2245          if f == "type":
2246            return [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
2247          else:
2248            return list(getattr(x.list, f))
2249      return []
2250    else:
2251      for f in fields:
2252        if x.HasField(f):
2253          if f == "type":
2254            return dtypes.as_dtype(getattr(x, f))
2255          else:
2256            return getattr(x, f)
2257      assert False, "Unsupported field type in " + str(x)
2258
2259  def run(self, feed_dict=None, session=None):
2260    """Runs this operation in a `Session`.
2261
2262    Calling this method will execute all preceding operations that
2263    produce the inputs needed for this operation.
2264
2265    *N.B.* Before invoking `Operation.run()`, its graph must have been
2266    launched in a session, and either a default session must be
2267    available, or `session` must be specified explicitly.
2268
2269    Args:
2270      feed_dict: A dictionary that maps `Tensor` objects to feed values.
2271        See @{tf.Session.run}
2272        for a description of the valid feed values.
2273      session: (Optional.) The `Session` to be used to run to this operation. If
2274        none, the default session will be used.
2275    """
2276    _run_using_default_session(self, feed_dict, self.graph, session)
2277
2278_gradient_registry = registry.Registry("gradient")
2279
2280
2281@tf_export("RegisterGradient")
2282class RegisterGradient(object):
2283  """A decorator for registering the gradient function for an op type.
2284
2285  This decorator is only used when defining a new op type. For an op
2286  with `m` inputs and `n` outputs, the gradient function is a function
2287  that takes the original `Operation` and `n` `Tensor` objects
2288  (representing the gradients with respect to each output of the op),
2289  and returns `m` `Tensor` objects (representing the partial gradients
2290  with respect to each input of the op).
2291
2292  For example, assuming that operations of type `"Sub"` take two
2293  inputs `x` and `y`, and return a single output `x - y`, the
2294  following gradient function would be registered:
2295
2296  ```python
2297  @tf.RegisterGradient("Sub")
2298  def _sub_grad(unused_op, grad):
2299    return grad, tf.negative(grad)
2300  ```
2301
2302  The decorator argument `op_type` is the string type of an
2303  operation. This corresponds to the `OpDef.name` field for the proto
2304  that defines the operation.
2305  """
2306
2307  def __init__(self, op_type):
2308    """Creates a new decorator with `op_type` as the Operation type.
2309
2310    Args:
2311      op_type: The string type of an operation. This corresponds to the
2312        `OpDef.name` field for the proto that defines the operation.
2313    """
2314    if not isinstance(op_type, six.string_types):
2315      raise TypeError("op_type must be a string")
2316    self._op_type = op_type
2317
2318  def __call__(self, f):
2319    """Registers the function `f` as gradient function for `op_type`."""
2320    _gradient_registry.register(f, self._op_type)
2321    return f
2322
2323
2324@tf_export("NoGradient", "NotDifferentiable")
2325def NotDifferentiable(op_type):
2326  """Specifies that ops of type `op_type` is not differentiable.
2327
2328  This function should *not* be used for operations that have a
2329  well-defined gradient that is not yet implemented.
2330
2331  This function is only used when defining a new op type. It may be
2332  used for ops such as `tf.size()` that are not differentiable.  For
2333  example:
2334
2335  ```python
2336  tf.NotDifferentiable("Size")
2337  ```
2338
2339  The gradient computed for 'op_type' will then propagate zeros.
2340
2341  For ops that have a well-defined gradient but are not yet implemented,
2342  no declaration should be made, and an error *must* be thrown if
2343  an attempt to request its gradient is made.
2344
2345  Args:
2346    op_type: The string type of an operation. This corresponds to the
2347      `OpDef.name` field for the proto that defines the operation.
2348
2349  Raises:
2350    TypeError: If `op_type` is not a string.
2351
2352  """
2353  if not isinstance(op_type, six.string_types):
2354    raise TypeError("op_type must be a string")
2355  _gradient_registry.register(None, op_type)
2356
2357
2358# Alias for the old name, will be eventually removed.
2359NoGradient = NotDifferentiable
2360
2361
2362def get_gradient_function(op):
2363  """Returns the function that computes gradients for "op"."""
2364  if not op.inputs:
2365    return None
2366  try:
2367    op_type = op.get_attr("_gradient_op_type")
2368  except ValueError:
2369    op_type = op.type
2370  return _gradient_registry.lookup(op_type)
2371
2372
2373_shape_registry = registry.Registry("shape functions")
2374_default_shape_function_registry = registry.Registry("default shape functions")
2375
2376# These are set to common_shapes.call_cpp_shape_fn by op generated code
2377# (generated by python_op_gen.cc).
2378# It is set outside ops.py to avoid a circular dependency.
2379_call_cpp_shape_fn = None
2380_call_cpp_shape_fn_and_require_op = None
2381
2382
2383def _set_call_cpp_shape_fn(call_cpp_shape_fn):
2384  """Sets default shape fns from passed common_shapes.call_cpp_shape_fn."""
2385  global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op
2386  if _call_cpp_shape_fn:
2387    return  # already registered
2388
2389  def call_without_requiring(op):
2390    return call_cpp_shape_fn(op, require_shape_fn=False)
2391
2392  _call_cpp_shape_fn = call_without_requiring
2393
2394  def call_with_requiring(op):
2395    return call_cpp_shape_fn(op, require_shape_fn=True)
2396
2397  _call_cpp_shape_fn_and_require_op = call_with_requiring
2398
2399
2400class RegisterShape(object):
2401  """No longer used.  Was: A decorator for registering a shape function.
2402
2403  Shape functions must now be registered via the SetShapeFn on the
2404  original Op specification in C++.
2405
2406  """
2407
2408  def __init__(self, op_type):
2409    """Saves the `op_type` as the `Operation` type."""
2410    if not isinstance(op_type, six.string_types):
2411      raise TypeError("op_type must be a string")
2412    self._op_type = op_type
2413
2414  def __call__(self, f):
2415    """Registers "f" as the shape function for "op_type"."""
2416    if f is None:
2417      assert _call_cpp_shape_fn
2418
2419      # None is a special "weak" value that provides a default shape function,
2420      # and can be overridden by a non-None registration.
2421      try:
2422        _default_shape_function_registry.register(_call_cpp_shape_fn,
2423                                                  self._op_type)
2424      except KeyError:
2425        # Ignore duplicate registrations of the weak value. This can
2426        # occur if the op library input to wrapper generation
2427        # inadvertently links in one or more of the standard op
2428        # libraries.
2429        pass
2430    else:
2431      _shape_registry.register(f, self._op_type)
2432    return f
2433
2434
2435def _set_shapes_for_outputs_c_api(op):
2436  """set_shapes_for_outputs implementation when C API is enabled."""
2437  # The C API computes the shapes when the TF_Operation is created. Fetch the
2438  # output shapes from the C object.
2439  for output in op.outputs:
2440    with errors.raise_exception_on_not_ok_status() as status:
2441      # pylint: disable=protected-access
2442      shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
2443          op._graph._c_graph, output._as_tf_output(), status)
2444      # pylint: enable=protected-access
2445    if unknown_shape:
2446      output.set_shape(tensor_shape.unknown_shape())
2447    elif not shape_vector:
2448      output.set_shape(tensor_shape.scalar())
2449    else:
2450      shape_vector = [None if d == -1 else d for d in shape_vector]
2451      output.set_shape(tensor_shape.TensorShape(shape_vector))
2452
2453
2454# TODO(skyewm): remove this when _USE_C_API flag is removed.
2455def _set_shapes_for_outputs(op):
2456  """set_shapes_for_outputs implementation when C API is disabled."""
2457  try:
2458    shape_func = _shape_registry.lookup(op.type)
2459  except LookupError:
2460    try:
2461      shape_func = _default_shape_function_registry.lookup(op.type)
2462    except LookupError:
2463      shape_func = _call_cpp_shape_fn_and_require_op
2464
2465  shapes = shape_func(op)
2466  if shapes is None:
2467    raise RuntimeError(
2468        "Shape function for op %s did not return any shapes" % op)
2469  elif isinstance(shapes, dict):
2470    # Returned by call_cpp_shape_fn
2471    shapes_dict = shapes
2472    shapes = shapes_dict["shapes"]
2473    handle_datas = shapes_dict["handle_data"]
2474    for output, handle_data in zip(op.outputs, handle_datas):
2475      # pylint: disable=protected-access
2476      output._handle_data = handle_data
2477      # pylint: enable=protected-access
2478
2479  if len(op.outputs) != len(shapes):
2480    raise RuntimeError(
2481        "Shape function for op %s returned %d shapes but expected %d %s %s" %
2482        (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
2483  for output, s in zip(op.outputs, shapes):
2484    output.set_shape(s)
2485
2486
2487def set_shapes_for_outputs(op):
2488  """Set the shapes for op's outputs."""
2489  if op._c_op:  # pylint: disable=protected-access
2490    return _set_shapes_for_outputs_c_api(op)
2491  else:
2492    return _set_shapes_for_outputs(op)
2493
2494
2495class OpStats(object):
2496  """A holder for statistics about an operator.
2497
2498  This class holds information about the resource requirements for an op,
2499  including the size of its weight parameters on-disk and how many FLOPS it
2500  requires to execute forward inference.
2501
2502  If you define a new operation, you can create a function that will return a
2503  set of information about its usage of the CPU and disk space when serialized.
2504  The function itself takes a Graph object that's been set up so you can call
2505  methods like get_tensor_by_name to help calculate the results, and a NodeDef
2506  argument.
2507
2508  """
2509
2510  def __init__(self, statistic_type, value=None):
2511    """Sets up the initial placeholders for the statistics."""
2512    self.statistic_type = statistic_type
2513    self.value = value
2514
2515  @property
2516  def statistic_type(self):
2517    return self._statistic_type
2518
2519  @statistic_type.setter
2520  def statistic_type(self, statistic_type):
2521    self._statistic_type = statistic_type
2522
2523  @property
2524  def value(self):
2525    return self._value
2526
2527  @value.setter
2528  def value(self, value):
2529    self._value = value
2530
2531  def __iadd__(self, other):
2532    if other.statistic_type != self.statistic_type:
2533      raise ValueError("Can't add an OpStat of type %s to one of %s." %
2534                       (self.statistic_type, other.statistic_type))
2535    if self.value is None:
2536      self.value = other.value
2537    elif other.value is not None:
2538      self._value += other.value
2539    return self
2540
2541
2542_stats_registry = registry.Registry("statistical functions")
2543
2544
2545class RegisterStatistics(object):
2546  """A decorator for registering the statistics function for an op type.
2547
2548  This decorator can be defined for an op type so that it gives a
2549  report on the resources used by an instance of an operator, in the
2550  form of an OpStats object.
2551
2552  Well-known types of statistics include these so far:
2553
2554  - flops: When running a graph, the bulk of the computation happens doing
2555    numerical calculations like matrix multiplications. This type allows a node
2556    to return how many floating-point operations it takes to complete. The
2557    total number of FLOPs for a graph is a good guide to its expected latency.
2558
2559  You can add your own statistics just by picking a new type string, registering
2560  functions for the ops you care about, and then calling get_stats_for_node_def.
2561
2562  If a statistic for an op is registered multiple times, a KeyError will be
2563  raised.
2564
2565  Since the statistics is counted on a per-op basis. It is not suitable for
2566  model parameters (capacity), which is expected to be counted only once, even
2567  if it is shared by multiple ops. (e.g. RNN)
2568
2569  For example, you can define a new metric called doohickey for a Foo operation
2570  by placing this in your code:
2571
2572  ```python
2573  @ops.RegisterStatistics("Foo", "doohickey")
2574  def _calc_foo_bojangles(unused_graph, unused_node_def):
2575    return ops.OpStats("doohickey", 20)
2576  ```
2577
2578  Then in client code you can retrieve the value by making this call:
2579
2580  ```python
2581  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
2582  ```
2583
2584  If the NodeDef is for an op with a registered doohickey function, you'll get
2585  back the calculated amount in doohickey.value, or None if it's not defined.
2586
2587  """
2588
2589  def __init__(self, op_type, statistic_type):
2590    """Saves the `op_type` as the `Operation` type."""
2591    if not isinstance(op_type, six.string_types):
2592      raise TypeError("op_type must be a string.")
2593    if "," in op_type:
2594      raise TypeError("op_type must not contain a comma.")
2595    self._op_type = op_type
2596    if not isinstance(statistic_type, six.string_types):
2597      raise TypeError("statistic_type must be a string.")
2598    if "," in statistic_type:
2599      raise TypeError("statistic_type must not contain a comma.")
2600    self._statistic_type = statistic_type
2601
2602  def __call__(self, f):
2603    """Registers "f" as the statistics function for "op_type"."""
2604    _stats_registry.register(f, self._op_type + "," + self._statistic_type)
2605    return f
2606
2607
2608def get_stats_for_node_def(graph, node, statistic_type):
2609  """Looks up the node's statistics function in the registry and calls it.
2610
2611  This function takes a Graph object and a NodeDef from a GraphDef, and if
2612  there's an associated statistics method, calls it and returns a result. If no
2613  function has been registered for the particular node type, it returns an empty
2614  statistics object.
2615
2616  Args:
2617    graph: A Graph object that's been set up with the node's graph.
2618    node: A NodeDef describing the operator.
2619    statistic_type: A string identifying the statistic we're interested in.
2620  Returns:
2621    An OpStats object containing information about resource usage.
2622  """
2623
2624  try:
2625    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
2626    result = stats_func(graph, node)
2627  except LookupError:
2628    result = OpStats(statistic_type)
2629  return result
2630
2631
2632def _name_from_scope_name(name):
2633  """Returns the name of an op given the name of its scope.
2634
2635  Args:
2636    name: the name of the scope.
2637
2638  Returns:
2639    the name of the op (equal to scope name minus any trailing slash).
2640  """
2641  return name[:-1] if (name and name[-1] == "/") else name
2642
2643
2644@tf_export("Graph")
2645class Graph(object):
2646  """A TensorFlow computation, represented as a dataflow graph.
2647
2648  A `Graph` contains a set of
2649  @{tf.Operation} objects,
2650  which represent units of computation; and
2651  @{tf.Tensor} objects, which represent
2652  the units of data that flow between operations.
2653
2654  A default `Graph` is always registered, and accessible by calling
2655  @{tf.get_default_graph}.
2656  To add an operation to the default graph, simply call one of the functions
2657  that defines a new `Operation`:
2658
2659  ```python
2660  c = tf.constant(4.0)
2661  assert c.graph is tf.get_default_graph()
2662  ```
2663
2664  Another typical usage involves the
2665  @{tf.Graph.as_default}
2666  context manager, which overrides the current default graph for the
2667  lifetime of the context:
2668
2669  ```python
2670  g = tf.Graph()
2671  with g.as_default():
2672    # Define operations and tensors in `g`.
2673    c = tf.constant(30.0)
2674    assert c.graph is g
2675  ```
2676
2677  Important note: This class *is not* thread-safe for graph construction. All
2678  operations should be created from a single thread, or external
2679  synchronization must be provided. Unless otherwise specified, all methods
2680  are not thread-safe.
2681
2682  A `Graph` instance supports an arbitrary number of "collections"
2683  that are identified by name. For convenience when building a large
2684  graph, collections can store groups of related objects: for
2685  example, the `tf.Variable` uses a collection (named
2686  @{tf.GraphKeys.GLOBAL_VARIABLES}) for
2687  all variables that are created during the construction of a graph. The caller
2688  may define additional collections by specifying a new name.
2689  """
2690
2691  def __init__(self):
2692    """Creates a new, empty Graph."""
2693    # Protects the core state that may be accessed by multiple readers.
2694    # Only state that can be returned via public accessors (`as_graph_def()`,
2695    # `get_operations()`, `as_graph_element()`, `get_collection()`, and
2696    # `get_collection_ref()`) is by the lock. Thread-safety is provided on a
2697    # best-effort basis to support buggy programs, and is not guaranteed by the
2698    # public `tf.Graph` API.
2699    # NOTE(mrry): This does not protect the various stacks. A warning will
2700    # be reported if these are used from multiple threads
2701    self._lock = threading.Lock()
2702    self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
2703    self._next_id_counter = 0  # GUARDED_BY(self._lock)
2704    self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
2705    self._version = 0  # GUARDED_BY(self._lock)
2706    # Current name stack: uniquified names
2707    self._name_stack = ""
2708    # Maps a name used in the graph to the next id to use for that name.
2709    self._names_in_use = {}
2710    # Functions that will be applied to choose a device if none is specified.
2711    self._device_function_stack = []
2712    # Default original_op applied to new ops.
2713    self._default_original_op = None
2714    # Current control flow context. It could be either CondContext or
2715    # WhileContext defined in ops/control_flow_ops.py
2716    self._control_flow_context = None
2717    # A new node will depend of the union of all of the nodes in the stack.
2718    self._control_dependencies_stack = []
2719    # Arbitrary collections of objects.
2720    self._collections = {}
2721    # The graph-level random seed
2722    self._seed = None
2723    # A dictionary of attributes that should be applied to all ops.
2724    self._attr_scope_map = {}
2725    # A map from op type to the kernel label that should be used.
2726    self._op_to_kernel_label_map = {}
2727    # A map from op type to an alternative op type that should be used when
2728    # computing gradients.
2729    self._gradient_override_map = {}
2730    # True if the graph is considered "finalized".  In that case no
2731    # new operations can be added.
2732    self._finalized = False
2733    # Functions defined in the graph
2734    self._functions = collections.OrderedDict()
2735    # Default GraphDef versions
2736    self._graph_def_versions = versions_pb2.VersionDef(
2737        producer=versions.GRAPH_DEF_VERSION,
2738        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
2739    self._building_function = False
2740    # Stack of colocate_with ops
2741    self._colocation_stack = []
2742    # Set of tensors that are dangerous to feed!
2743    self._unfeedable_tensors = set()
2744    # Set of operations that are dangerous to fetch!
2745    self._unfetchable_ops = set()
2746    # A map of tensor handle placeholder to tensor dtype.
2747    self._handle_feeders = {}
2748    # A map from tensor handle to its read op.
2749    self._handle_readers = {}
2750    # A map from tensor handle to its move op.
2751    self._handle_movers = {}
2752    # A map from tensor handle to its delete op.
2753    self._handle_deleters = {}
2754    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
2755    # will be shared when defining function graphs, for example, so optimizers
2756    # being called inside function definitions behave as if they were seeing the
2757    # actual outside graph).
2758    self._graph_key = "grap-key-%d/" % (uid(),)
2759    self._container = ""
2760    self._registered_ops = op_def_registry.get_registered_ops()
2761
2762    # TODO(skyewm): fold as much of the above as possible into the C
2763    # implementation
2764    if _USE_C_API or self._use_c_api_hack():
2765      self._scoped_c_graph = c_api_util.ScopedTFGraph()
2766    else:
2767      self._scoped_c_graph = None
2768    self._variable_creator_stack = []
2769
2770  # TODO(apassos) remove once the C API is used by default.
2771  def _use_c_api_hack(self):
2772    """Temporary hack; can be overridden to force C API usage."""
2773    return False
2774
2775  def _convert_stack(self, stack, include_func_start_lineno=False):
2776    """Converts a stack extracted using _extract_stack() to a traceback stack.
2777
2778    Args:
2779      stack: A list of n 5-tuples,
2780        (filename, lineno, name, frame_globals, func_start_lineno).
2781      include_func_start_lineno: True if function start line number should be
2782        included as the 5th entry in return tuples.
2783
2784    Returns:
2785      A list of n 4-tuples or 5-tuples
2786      (filename, lineno, name, code, [optional: func_start_lineno]), where the
2787      code tuple element is calculated from the corresponding elements of the
2788      input tuple.
2789    """
2790    ret = []
2791    for (filename, lineno, name, frame_globals, func_start_lineno,
2792         unused_frame_info) in stack:
2793      linecache.checkcache(filename)
2794      line = linecache.getline(filename, lineno, frame_globals)
2795      if line:
2796        line = line.strip()
2797      else:
2798        line = None
2799      if include_func_start_lineno:
2800        ret.append((filename, lineno, name, line, func_start_lineno))
2801      else:
2802        ret.append((filename, lineno, name, line))
2803    return ret
2804
2805  # Note: this method is private because the API of tf.Graph() is public and
2806  # frozen, and this functionality is still not ready for public visibility.
2807  @tf_contextlib.contextmanager
2808  def _variable_creator_scope(self, creator):
2809    old = list(self._variable_creator_stack)
2810    self._variable_creator_stack.append(creator)
2811    try:
2812      yield
2813    finally:
2814      self._variable_creator_stack = old
2815
2816  # Note: this method is private because the API of tf.Graph() is public and
2817  # frozen, and this functionality is still not ready for public visibility.
2818  def _get_variable_creator_stack(self):
2819    return list(self._variable_creator_stack)
2820
2821  def _extract_stack(self):
2822    """A lightweight, extensible re-implementation of traceback.extract_stack.
2823
2824    NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
2825      each stack frame using linecache, which results in an abundance of stat()
2826      calls. This implementation does not retrieve the code, and any consumer
2827      should apply _convert_stack to the result to obtain a traceback that can
2828      be formatted etc. using traceback methods.
2829
2830    Derived classes can implement _extract_frame_info() to add extra information
2831    to the traceback.
2832
2833    Returns:
2834      A list of 6-tuples
2835      (filename, lineno, name, frame_globals, func_start_lineno, custom_info)
2836      corresponding to the call stack of the current thread.
2837    """
2838    try:
2839      raise ZeroDivisionError
2840    except ZeroDivisionError:
2841      f = sys.exc_info()[2].tb_frame.f_back
2842    ret = []
2843    while f is not None:
2844      lineno = f.f_lineno
2845      co = f.f_code
2846      filename = co.co_filename
2847      name = co.co_name
2848      frame_globals = f.f_globals
2849      func_start_lineno = co.co_firstlineno
2850      frame_info = self._extract_frame_info(f)
2851      ret.append((filename, lineno, name, frame_globals, func_start_lineno,
2852                  frame_info))
2853      f = f.f_back
2854    ret.reverse()
2855    return ret
2856
2857  def _extract_frame_info(self, frame):  # pylint: disable=unused-argument
2858    """Extracts custom information from a frame in an op traceback."""
2859    return None
2860
2861  def _check_not_finalized(self):
2862    """Check if the graph is finalized.
2863
2864    Raises:
2865      RuntimeError: If the graph finalized.
2866    """
2867    if self._finalized:
2868      raise RuntimeError("Graph is finalized and cannot be modified.")
2869
2870  def _add_op(self, op):
2871    """Adds 'op' to the graph.
2872
2873    Args:
2874      op: the Operator or Tensor to add.
2875
2876    Raises:
2877      TypeError: if op is not an Operation or Tensor.
2878      ValueError: if the op.name or op._id are already used.
2879    """
2880    self._check_not_finalized()
2881    if not isinstance(op, (Tensor, Operation)):
2882      raise TypeError("op must be a Tensor or Operation: %s" % op)
2883    with self._lock:
2884      # pylint: disable=protected-access
2885      if op._id in self._nodes_by_id:
2886        raise ValueError("cannot add an op with id %d as it already "
2887                         "exists in the graph" % op._id)
2888      if op.name in self._nodes_by_name:
2889        raise ValueError("cannot add op with name %s as that name "
2890                         "is already used" % op.name)
2891      self._nodes_by_id[op._id] = op
2892      self._nodes_by_name[op.name] = op
2893      self._version = max(self._version, op._id)
2894      # pylint: enable=protected-access
2895
2896  @property
2897  def _c_graph(self):
2898    if self._scoped_c_graph:
2899      return self._scoped_c_graph.graph
2900    return None
2901
2902  @property
2903  def version(self):
2904    """Returns a version number that increases as ops are added to the graph.
2905
2906    Note that this is unrelated to the
2907    @{tf.Graph.graph_def_versions}.
2908
2909    Returns:
2910       An integer version that increases as ops are added to the graph.
2911    """
2912    if self._finalized:
2913      return self._version
2914
2915    with self._lock:
2916      return self._version
2917
2918  @property
2919  def graph_def_versions(self):
2920    # pylint: disable=line-too-long
2921    """The GraphDef version information of this graph.
2922
2923    For details on the meaning of each version, see
2924    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
2925
2926    Returns:
2927      A `VersionDef`.
2928    """
2929    # pylint: enable=line-too-long
2930    if self._c_graph:
2931      with c_api_util.tf_buffer() as buf:
2932        with errors.raise_exception_on_not_ok_status() as status:
2933          c_api.TF_GraphVersions(self._c_graph, buf, status)
2934        data = c_api.TF_GetBuffer(buf)
2935      version_def = versions_pb2.VersionDef()
2936      version_def.ParseFromString(compat.as_bytes(data))
2937      return version_def
2938    else:
2939      return self._graph_def_versions
2940
2941  @property
2942  def seed(self):
2943    """The graph-level random seed of this graph."""
2944    return self._seed
2945
2946  @seed.setter
2947  def seed(self, seed):
2948    self._seed = seed
2949
2950  @property
2951  def finalized(self):
2952    """True if this graph has been finalized."""
2953    return self._finalized
2954
2955  def finalize(self):
2956    """Finalizes this graph, making it read-only.
2957
2958    After calling `g.finalize()`, no new operations can be added to
2959    `g`.  This method is used to ensure that no operations are added
2960    to a graph when it is shared between multiple threads, for example
2961    when using a @{tf.train.QueueRunner}.
2962    """
2963    self._finalized = True
2964
2965  def _unsafe_unfinalize(self):
2966    """Opposite of `finalize`. Internal interface.
2967
2968    NOTE: Unfinalizing a graph could have negative impact on performance,
2969    especially in a multi-threaded environment.  Unfinalizing a graph
2970    when it is in use by a Session may lead to undefined behavior. Ensure
2971    that all sessions using a graph are closed before calling this method.
2972    """
2973    self._finalized = False
2974
2975  def _get_control_flow_context(self):
2976    """Returns the current control flow context.
2977
2978    Returns:
2979      A context object.
2980    """
2981    return self._control_flow_context
2982
2983  def _set_control_flow_context(self, ctx):
2984    """Sets the current control flow context.
2985
2986    Args:
2987      ctx: a context object.
2988    """
2989    self._control_flow_context = ctx
2990
2991  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
2992    """If this graph contains functions, copy them to `graph_def`."""
2993    bytesize = starting_bytesize
2994    for f in self._functions.values():
2995      bytesize += f.definition.ByteSize()
2996      if bytesize >= (1 << 31) or bytesize < 0:
2997        raise ValueError("GraphDef cannot be larger than 2GB.")
2998      graph_def.library.function.extend([f.definition])
2999      if f.grad_func_name:
3000        grad_def = function_pb2.GradientDef()
3001        grad_def.function_name = f.name
3002        grad_def.gradient_func = f.grad_func_name
3003        graph_def.library.gradient.extend([grad_def])
3004
3005  def _as_graph_def(self, from_version=None, add_shapes=False):
3006    # pylint: disable=line-too-long
3007    """Returns a serialized `GraphDef` representation of this graph.
3008
3009    The serialized `GraphDef` can be imported into another `Graph`
3010    (using @{tf.import_graph_def}) or used with the
3011    [C++ Session API](../../../../api_docs/cc/index.md).
3012
3013    This method is thread-safe.
3014
3015    Args:
3016      from_version: Optional.  If this is set, returns a `GraphDef`
3017        containing only the nodes that were added to this graph since
3018        its `version` property had the given value.
3019      add_shapes: If true, adds an "_output_shapes" list attr to each
3020        node with the inferred shapes of each of its outputs.
3021
3022    Returns:
3023      A tuple containing a
3024      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3025      protocol buffer, and the version of the graph to which that
3026      `GraphDef` corresponds.
3027
3028    Raises:
3029      ValueError: If the `graph_def` would be too large.
3030
3031    """
3032    # pylint: enable=line-too-long
3033    if _USE_C_API:
3034      with self._lock:
3035        with c_api_util.tf_buffer() as buf:
3036          with errors.raise_exception_on_not_ok_status() as status:
3037            c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
3038          data = c_api.TF_GetBuffer(buf)
3039        graph = graph_pb2.GraphDef()
3040        graph.ParseFromString(compat.as_bytes(data))
3041        # Strip the experimental library field iff it's empty.
3042        if not graph.library.function:
3043          graph.ClearField("library")
3044
3045        if add_shapes:
3046          for node in graph.node:
3047            op = self._nodes_by_name[node.name]
3048            if op.outputs:
3049              node.attr["_output_shapes"].list.shape.extend(
3050                  [output.get_shape().as_proto() for output in op.outputs])
3051    else:
3052      with self._lock:
3053        graph = graph_pb2.GraphDef()
3054        graph.versions.CopyFrom(self._graph_def_versions)
3055        bytesize = 0
3056        for op_id in sorted(self._nodes_by_id):
3057          op = self._nodes_by_id[op_id]
3058          if from_version is None or op_id > from_version:
3059            graph.node.extend([op.node_def])
3060            if op.outputs and add_shapes:
3061              assert "_output_shapes" not in graph.node[-1].attr
3062              graph.node[-1].attr["_output_shapes"].list.shape.extend(
3063                  [output.get_shape().as_proto() for output in op.outputs])
3064            bytesize += op.node_def.ByteSize()
3065            if bytesize >= (1 << 31) or bytesize < 0:
3066              raise ValueError("GraphDef cannot be larger than 2GB.")
3067        self._copy_functions_to_graph_def(graph, bytesize)
3068    return graph, self._version
3069
3070  def as_graph_def(self, from_version=None, add_shapes=False):
3071    # pylint: disable=line-too-long
3072    """Returns a serialized `GraphDef` representation of this graph.
3073
3074    The serialized `GraphDef` can be imported into another `Graph`
3075    (using @{tf.import_graph_def}) or used with the
3076    [C++ Session API](../../api_docs/cc/index.md).
3077
3078    This method is thread-safe.
3079
3080    Args:
3081      from_version: Optional.  If this is set, returns a `GraphDef`
3082        containing only the nodes that were added to this graph since
3083        its `version` property had the given value.
3084      add_shapes: If true, adds an "_output_shapes" list attr to each
3085        node with the inferred shapes of each of its outputs.
3086
3087    Returns:
3088      A
3089      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3090      protocol buffer.
3091
3092    Raises:
3093      ValueError: If the `graph_def` would be too large.
3094    """
3095    # pylint: enable=line-too-long
3096    result, _ = self._as_graph_def(from_version, add_shapes)
3097    return result
3098
3099  def _is_function(self, name):
3100    """Tests whether 'name' is registered in this graph's function library.
3101
3102    Args:
3103      name: string op name.
3104    Returns:
3105      bool indicating whether or not 'name' is registered in function library.
3106    """
3107    return name in self._functions
3108
3109  def _get_function(self, name):
3110    """Returns the function definition for 'name'.
3111
3112    Args:
3113      name: string function name.
3114    Returns:
3115      The function def proto.
3116    """
3117    return self._functions.get(name, None)
3118
3119  def _add_function(self, function):
3120    """Adds a function to the graph.
3121
3122    After the function has been added, you can call to the function by
3123    passing the function name in place of an op name to
3124    `Graph.create_op()`.
3125
3126    Args:
3127      function: A `_DefinedFunction` object.
3128
3129
3130    Raises:
3131      ValueError: if another function is defined with the same name.
3132    """
3133    name = function.name
3134    # Sanity checks on gradient definition.
3135    if (function.grad_func_name is not None) and (function.python_grad_func is
3136                                                  not None):
3137      raise ValueError("Gradient defined twice for function %s" % name)
3138
3139    # Add function to graph
3140    # pylint: disable=protected-access
3141    if self._c_graph:
3142      # Handle functions created without using the C API. TODO(apassos,skyewm)
3143      # remove this when all functions are generated using the C API by default
3144      # as this will be unnecessary.
3145      if not function._c_func:
3146        with errors.raise_exception_on_not_ok_status() as status:
3147          serialized = function.definition.SerializeToString()
3148          function._c_func = c_api.TF_FunctionImportFunctionDef(
3149              serialized, status)
3150      with errors.raise_exception_on_not_ok_status() as status:
3151        gradient = function._grad_func._c_func if function._grad_func else None
3152        c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
3153                                   status)
3154    else:
3155      # If there is already a function with the same name, raise an error
3156      # if bodies are different. Else, do nothing. The C API version above
3157      # has the same behavior.
3158      previous = self._functions.get(name, None)
3159      if previous:
3160        # This check is not ideal as we can have a hash collision with only
3161        # 32 bits in the hash, but the non C API mode is being deprecated.
3162        # Don't bother changing it now.
3163        if previous._hash_str == function._hash_str:
3164          return
3165        else:
3166          raise ValueError("Cannot add function (%s, hash %s) to graph (%s). "
3167                           "Another function (%s, hash %s) is already defined "
3168                           "with that name (%s)" % (
3169                               function, function._hash_str, self,
3170                               previous, previous._hash_str, name))
3171    # pylint: enable=protected-access
3172
3173    self._functions[name] = function
3174
3175    # Need a new-enough consumer to support the functions we add to the graph.
3176    if self._graph_def_versions.min_consumer < 12:
3177      self._graph_def_versions.min_consumer = 12
3178
3179  @property
3180  def building_function(self):
3181    """Returns True iff this graph represents a function."""
3182    return self._building_function
3183
3184  # Helper functions to create operations.
3185  def create_op(
3186      self,
3187      op_type,
3188      inputs,
3189      dtypes,  # pylint: disable=redefined-outer-name
3190      input_types=None,
3191      name=None,
3192      attrs=None,
3193      op_def=None,
3194      compute_shapes=True,
3195      compute_device=True):
3196    """Creates an `Operation` in this graph.
3197
3198    This is a low-level interface for creating an `Operation`. Most
3199    programs will not call this method directly, and instead use the
3200    Python op constructors, such as `tf.constant()`, which add ops to
3201    the default graph.
3202
3203    Args:
3204      op_type: The `Operation` type to create. This corresponds to the
3205        `OpDef.name` field for the proto that defines the operation.
3206      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3207      dtypes: A list of `DType` objects that will be the types of the tensors
3208        that the operation produces.
3209      input_types: (Optional.) A list of `DType`s that will be the types of
3210        the tensors that the operation consumes. By default, uses the base
3211        `DType` of each input in `inputs`. Operations that expect
3212        reference-typed inputs must specify `input_types` explicitly.
3213      name: (Optional.) A string name for the operation. If not specified, a
3214        name is generated based on `op_type`.
3215      attrs: (Optional.) A dictionary where the key is the attribute name (a
3216        string) and the value is the respective `attr` attribute of the
3217        `NodeDef` proto that will represent the operation (an `AttrValue`
3218        proto).
3219      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3220        the operation will have.
3221      compute_shapes: (Optional.) If True, shape inference will be performed
3222        to compute the shapes of the outputs.
3223      compute_device: (Optional.) If True, device functions will be executed
3224        to compute the device property of the Operation.
3225
3226    Raises:
3227      TypeError: if any of the inputs is not a `Tensor`.
3228      ValueError: if colocation conflicts with existing device assignment.
3229
3230    Returns:
3231      An `Operation` object.
3232
3233    """
3234    self._check_not_finalized()
3235    for idx, a in enumerate(inputs):
3236      if not isinstance(a, Tensor):
3237        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
3238    if name is None:
3239      name = op_type
3240    # If a names ends with a '/' it is a "name scope" and we use it as-is,
3241    # after removing the trailing '/'.
3242    if name and name[-1] == "/":
3243      name = _name_from_scope_name(name)
3244    else:
3245      name = self.unique_name(name)
3246
3247    node_def = _NodeDef(op_type, name, device=None, attrs=attrs)
3248
3249    input_ops = set([t.op for t in inputs])
3250    control_inputs = self._control_dependencies_for_inputs(input_ops)
3251    ret = Operation(
3252        node_def,
3253        self,
3254        inputs=inputs,
3255        output_types=dtypes,
3256        control_inputs=control_inputs,
3257        input_types=input_types,
3258        original_op=self._default_original_op,
3259        op_def=op_def)
3260    self._create_op_helper(ret, compute_shapes=compute_shapes,
3261                           compute_device=compute_device)
3262    return ret
3263
3264  def _create_op_from_tf_operation(self, c_op, compute_device=True):
3265    """Creates an `Operation` in this graph from the supplied TF_Operation.
3266
3267    This method is like create_op() except the new Operation is constructed
3268    using `c_op`. The returned Operation will have `c_op` as its _c_op
3269    field. This is used to create Operation objects around TF_Operations created
3270    indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
3271
3272    This function does not call Operation._control_flow_post_processing or
3273    Graph._control_dependencies_for_inputs (since the inputs may not be
3274    available yet). The caller is responsible for calling these methods.
3275
3276    Args:
3277      c_op: a wrapped TF_Operation
3278      compute_device: (Optional.) If True, device functions will be executed
3279        to compute the device property of the Operation.
3280
3281    Returns:
3282      An `Operation` object.
3283    """
3284    self._check_not_finalized()
3285    ret = Operation(c_op, self)
3286    assert ret.name not in self._names_in_use
3287    self._names_in_use[ret.name] = 1
3288    self._create_op_helper(ret, compute_device=compute_device)
3289    return ret
3290
3291  def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
3292    """Common logic for creating an op in this graph."""
3293    # TODO(vrv): Instead of eagerly filling in shape property for every op, only
3294    # populate the shape when requested.
3295    #
3296    # TODO(skyewm): unlike in the original Python implementation, the C API
3297    # always computes shape information (even for function calls, which the
3298    # original Python shape inference code doesn't handle). Deprecate the
3299    # compute_shapes argument.
3300    if op._c_op or compute_shapes:  # pylint: disable=protected-access
3301      set_shapes_for_outputs(op)
3302    # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
3303    self._add_op(op)
3304
3305    # Apply any additional attributes requested. Do not overwrite any existing
3306    # attributes.
3307    for key, value in self._attr_scope_map.items():
3308      try:
3309        op.get_attr(key)
3310      except ValueError:
3311        if callable(value):
3312          value = value(op.node_def)
3313          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
3314            raise TypeError(
3315                "Callable for scope map key '%s' must return either None or "
3316                "an AttrValue protocol buffer; but it returned: %s" % (key,
3317                                                                       value))
3318        if value:
3319          op._set_attr(key, value)  # pylint: disable=protected-access
3320
3321    # Apply a kernel label if one has been specified for this op type.
3322    try:
3323      kernel_label = self._op_to_kernel_label_map[op.type]
3324      op._set_attr("_kernel",  # pylint: disable=protected-access
3325                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
3326    except KeyError:
3327      pass
3328
3329    # Apply the overriding op type for gradients if one has been specified for
3330    # this op type.
3331    try:
3332      mapped_op_type = self._gradient_override_map[op.type]
3333      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
3334                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
3335    except KeyError:
3336      pass
3337
3338    self._record_op_seen_by_control_dependencies(op)
3339
3340    if compute_device:
3341      self._apply_device_functions(op)
3342
3343    if self._colocation_stack:
3344      all_colocation_groups = []
3345      for colocation_op in self._colocation_stack:
3346        all_colocation_groups.extend(colocation_op.colocation_groups())
3347        if colocation_op.device:
3348          # Make this device match the device of the colocated op, to provide
3349          # consistency between the device and the colocation property.
3350          if (op.device and pydev.canonical_name(op.device) !=
3351              pydev.canonical_name(colocation_op.device)):
3352            logging.warning("Tried to colocate %s with an op %s that had "
3353                            "a different device: %s vs %s. "
3354                            "Ignoring colocation property.", op.name,
3355                            colocation_op.name, op.device,
3356                            colocation_op.device)
3357          else:
3358            op._set_device(colocation_op.device)  # pylint: disable=protected-access
3359
3360      all_colocation_groups = sorted(set(all_colocation_groups))
3361      # pylint: disable=protected-access
3362      op._set_attr("_class", attr_value_pb2.AttrValue(
3363          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
3364      # pylint: enable=protected-access
3365
3366    # Sets "container" attribute if
3367    # (1) self._container is not None
3368    # (2) "is_stateful" is set in OpDef
3369    # (3) "container" attribute is in OpDef
3370    # (4) "container" attribute is None
3371    # TODO(skyewm): remove op.op_def check when _USE_C_API is removed.
3372    if self._container and op.op_def and op.op_def.is_stateful:
3373      try:
3374        container_attr = op.get_attr("container")
3375      except ValueError:
3376        # "container" attribute is not in OpDef
3377        pass
3378      else:
3379        if not container_attr:
3380          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
3381              s=compat.as_bytes(self._container)))
3382
3383  def _add_new_tf_operations(self, compute_devices=True):
3384    """Creates `Operations` in this graph for any new TF_Operations.
3385
3386    This is useful for when TF_Operations are indirectly created by the C API
3387    outside of the Operation constructor (e.g. by TF_ImportGraphDef,
3388    TF_FinishWhile). This ensures there are corresponding Operations for all
3389    TF_Operations in the underlying TF_Graph.
3390
3391    Args:
3392      compute_devices: (Optional.) If True, device functions will be executed
3393        to compute the device properties of each new Operation.
3394
3395    Returns:
3396      A list of the new `Operation` objects.
3397    """
3398    # Create all Operation objects before accessing their inputs since an op may
3399    # be created before its inputs.
3400    new_ops = [
3401        self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
3402        for c_op in c_api_util.new_tf_operations(self)
3403    ]
3404
3405    for op in new_ops:
3406      new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
3407      # pylint: disable=protected-access
3408      op._add_control_inputs(new_control_inputs)
3409      op._control_flow_post_processing()
3410      # pylint: enable=protected-access
3411
3412    return new_ops
3413
3414  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
3415    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
3416
3417    This function validates that `obj` represents an element of this
3418    graph, and gives an informative error message if it is not.
3419
3420    This function is the canonical way to get/validate an object of
3421    one of the allowed types from an external argument reference in the
3422    Session API.
3423
3424    This method may be called concurrently from multiple threads.
3425
3426    Args:
3427      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
3428        Can also be any object with an `_as_graph_element()` method that returns
3429        a value of one of these types.
3430      allow_tensor: If true, `obj` may refer to a `Tensor`.
3431      allow_operation: If true, `obj` may refer to an `Operation`.
3432
3433    Returns:
3434      The `Tensor` or `Operation` in the Graph corresponding to `obj`.
3435
3436    Raises:
3437      TypeError: If `obj` is not a type we support attempting to convert
3438        to types.
3439      ValueError: If `obj` is of an appropriate type but invalid. For
3440        example, an invalid string.
3441      KeyError: If `obj` is not an object in the graph.
3442    """
3443    if self._finalized:
3444      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3445
3446    with self._lock:
3447      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3448
3449  def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
3450    """See `Graph.as_graph_element()` for details."""
3451    # The vast majority of this function is figuring
3452    # out what an API user might be doing wrong, so
3453    # that we can give helpful error messages.
3454    #
3455    # Ideally, it would be nice to split it up, but we
3456    # need context to generate nice error messages.
3457
3458    if allow_tensor and allow_operation:
3459      types_str = "Tensor or Operation"
3460    elif allow_tensor:
3461      types_str = "Tensor"
3462    elif allow_operation:
3463      types_str = "Operation"
3464    else:
3465      raise ValueError("allow_tensor and allow_operation can't both be False.")
3466
3467    temp_obj = _as_graph_element(obj)
3468    if temp_obj is not None:
3469      obj = temp_obj
3470
3471    # If obj appears to be a name...
3472    if isinstance(obj, compat.bytes_or_text_types):
3473      name = compat.as_str(obj)
3474
3475      if ":" in name and allow_tensor:
3476        # Looks like a Tensor name and can be a Tensor.
3477        try:
3478          op_name, out_n = name.split(":")
3479          out_n = int(out_n)
3480        except:
3481          raise ValueError("The name %s looks a like a Tensor name, but is "
3482                           "not a valid one. Tensor names must be of the "
3483                           "form \"<op_name>:<output_index>\"." % repr(name))
3484        if op_name in self._nodes_by_name:
3485          op = self._nodes_by_name[op_name]
3486        else:
3487          raise KeyError("The name %s refers to a Tensor which does not "
3488                         "exist. The operation, %s, does not exist in the "
3489                         "graph." % (repr(name), repr(op_name)))
3490        try:
3491          return op.outputs[out_n]
3492        except:
3493          raise KeyError("The name %s refers to a Tensor which does not "
3494                         "exist. The operation, %s, exists but only has "
3495                         "%s outputs." % (repr(name), repr(op_name),
3496                                          len(op.outputs)))
3497
3498      elif ":" in name and not allow_tensor:
3499        # Looks like a Tensor name but can't be a Tensor.
3500        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
3501                         (repr(name), types_str))
3502
3503      elif ":" not in name and allow_operation:
3504        # Looks like an Operation name and can be an Operation.
3505        if name not in self._nodes_by_name:
3506          raise KeyError("The name %s refers to an Operation not in the "
3507                         "graph." % repr(name))
3508        return self._nodes_by_name[name]
3509
3510      elif ":" not in name and not allow_operation:
3511        # Looks like an Operation name but can't be an Operation.
3512        if name in self._nodes_by_name:
3513          # Yep, it's an Operation name
3514          err_msg = ("The name %s refers to an Operation, not a %s." %
3515                     (repr(name), types_str))
3516        else:
3517          err_msg = ("The name %s looks like an (invalid) Operation name, "
3518                     "not a %s." % (repr(name), types_str))
3519        err_msg += (" Tensor names must be of the form "
3520                    "\"<op_name>:<output_index>\".")
3521        raise ValueError(err_msg)
3522
3523    elif isinstance(obj, Tensor) and allow_tensor:
3524      # Actually obj is just the object it's referring to.
3525      if obj.graph is not self:
3526        raise ValueError("Tensor %s is not an element of this graph." % obj)
3527      return obj
3528    elif isinstance(obj, Operation) and allow_operation:
3529      # Actually obj is just the object it's referring to.
3530      if obj.graph is not self:
3531        raise ValueError("Operation %s is not an element of this graph." % obj)
3532      return obj
3533    else:
3534      # We give up!
3535      raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
3536                                                           types_str))
3537
3538  def get_operations(self):
3539    """Return the list of operations in the graph.
3540
3541    You can modify the operations in place, but modifications
3542    to the list such as inserts/delete have no effect on the
3543    list of operations known to the graph.
3544
3545    This method may be called concurrently from multiple threads.
3546
3547    Returns:
3548      A list of Operations.
3549    """
3550    if self._finalized:
3551      return list(self._nodes_by_id.values())
3552
3553    with self._lock:
3554      return list(self._nodes_by_id.values())
3555
3556  def get_operation_by_name(self, name):
3557    """Returns the `Operation` with the given `name`.
3558
3559    This method may be called concurrently from multiple threads.
3560
3561    Args:
3562      name: The name of the `Operation` to return.
3563
3564    Returns:
3565      The `Operation` with the given `name`.
3566
3567    Raises:
3568      TypeError: If `name` is not a string.
3569      KeyError: If `name` does not correspond to an operation in this graph.
3570    """
3571
3572    if not isinstance(name, six.string_types):
3573      raise TypeError("Operation names are strings (or similar), not %s." %
3574                      type(name).__name__)
3575    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
3576
3577  def _get_operation_by_name_unsafe(self, name):
3578    """Returns the `Operation` with the given `name`.
3579
3580    This is a internal unsafe version of get_operation_by_name. It skips many
3581    checks and does not have user friedly error messages but runs considerably
3582    faster. This method may be called concurrently from multiple threads.
3583
3584    Args:
3585      name: The name of the `Operation` to return.
3586
3587    Returns:
3588      The `Operation` with the given `name`.
3589
3590    Raises:
3591      KeyError: If `name` does not correspond to an operation in this graph.
3592    """
3593
3594    if self._finalized:
3595      return self._nodes_by_name[name]
3596
3597    with self._lock:
3598      return self._nodes_by_name[name]
3599
3600  def _get_operation_by_tf_operation(self, tf_oper):
3601    op_name = c_api.TF_OperationName(tf_oper)
3602    return self._get_operation_by_name_unsafe(op_name)
3603
3604  def get_tensor_by_name(self, name):
3605    """Returns the `Tensor` with the given `name`.
3606
3607    This method may be called concurrently from multiple threads.
3608
3609    Args:
3610      name: The name of the `Tensor` to return.
3611
3612    Returns:
3613      The `Tensor` with the given `name`.
3614
3615    Raises:
3616      TypeError: If `name` is not a string.
3617      KeyError: If `name` does not correspond to a tensor in this graph.
3618    """
3619    # Names should be strings.
3620    if not isinstance(name, six.string_types):
3621      raise TypeError("Tensor names are strings (or similar), not %s." %
3622                      type(name).__name__)
3623    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
3624
3625  def _get_tensor_by_tf_output(self, tf_output):
3626    """Returns the `Tensor` representing `tf_output`.
3627
3628    Note that there is only one such `Tensor`, i.e. multiple calls to this
3629    function with the same TF_Output value will always return the same `Tensor`
3630    object.
3631
3632    Args:
3633      tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
3634
3635    Returns:
3636      The `Tensor` that represents `tf_output`.
3637    """
3638    op = self._get_operation_by_tf_operation(tf_output.oper)
3639    return op.outputs[tf_output.index]
3640
3641  def _next_id(self):
3642    """Id for next Operation instance. Also increments the internal id."""
3643    self._check_not_finalized()
3644    with self._lock:
3645      self._next_id_counter += 1
3646      return self._next_id_counter
3647
3648  @property
3649  def _last_id(self):
3650    return self._next_id_counter
3651
3652  def _get_op_def(self, type):  # pylint: disable=redefined-builtin
3653    """Returns the `OpDef` proto for `type`. `type` is a string."""
3654    if self._c_graph:
3655      with c_api_util.tf_buffer() as buf:
3656        with errors.raise_exception_on_not_ok_status() as status:
3657          # pylint: disable=protected-access
3658          c_api.TF_GraphGetOpDef(self._c_graph,
3659                                 compat.as_bytes(type), buf, status)
3660          # pylint: enable=protected-access
3661        data = c_api.TF_GetBuffer(buf)
3662      op_def = op_def_pb2.OpDef()
3663      op_def.ParseFromString(compat.as_bytes(data))
3664      return op_def
3665    else:
3666      return self._registered_ops[type]
3667
3668  def as_default(self):
3669    """Returns a context manager that makes this `Graph` the default graph.
3670
3671    This method should be used if you want to create multiple graphs
3672    in the same process. For convenience, a global default graph is
3673    provided, and all ops will be added to this graph if you do not
3674    create a new graph explicitly. Use this method with the `with` keyword
3675    to specify that ops created within the scope of a block should be
3676    added to this graph.
3677
3678    The default graph is a property of the current thread. If you
3679    create a new thread, and wish to use the default graph in that
3680    thread, you must explicitly add a `with g.as_default():` in that
3681    thread's function.
3682
3683    The following code examples are equivalent:
3684
3685    ```python
3686    # 1. Using Graph.as_default():
3687    g = tf.Graph()
3688    with g.as_default():
3689      c = tf.constant(5.0)
3690      assert c.graph is g
3691
3692    # 2. Constructing and making default:
3693    with tf.Graph().as_default() as g:
3694      c = tf.constant(5.0)
3695      assert c.graph is g
3696    ```
3697
3698    Returns:
3699      A context manager for using this graph as the default graph.
3700    """
3701    return _default_graph_stack.get_controller(self)
3702
3703  @property
3704  def collections(self):
3705    """Returns the names of the collections known to this graph."""
3706    return list(self._collections)
3707
3708  def add_to_collection(self, name, value):
3709    """Stores `value` in the collection with the given `name`.
3710
3711    Note that collections are not sets, so it is possible to add a value to
3712    a collection several times.
3713
3714    Args:
3715      name: The key for the collection. The `GraphKeys` class
3716        contains many standard names for collections.
3717      value: The value to add to the collection.
3718    """  # pylint: disable=g-doc-exception
3719    _assert_collection_is_ok(name)
3720    self._check_not_finalized()
3721    with self._lock:
3722      if name not in self._collections:
3723        self._collections[name] = [value]
3724      else:
3725        self._collections[name].append(value)
3726
3727  def add_to_collections(self, names, value):
3728    """Stores `value` in the collections given by `names`.
3729
3730    Note that collections are not sets, so it is possible to add a value to
3731    a collection several times. This function makes sure that duplicates in
3732    `names` are ignored, but it will not check for pre-existing membership of
3733    `value` in any of the collections in `names`.
3734
3735    `names` can be any iterable, but if `names` is a string, it is treated as a
3736    single collection name.
3737
3738    Args:
3739      names: The keys for the collections to add to. The `GraphKeys` class
3740        contains many standard names for collections.
3741      value: The value to add to the collections.
3742    """
3743    # Make sure names are unique, but treat strings as a single collection name
3744    names = (names,) if isinstance(names, six.string_types) else set(names)
3745    for name in names:
3746      self.add_to_collection(name, value)
3747
3748  def get_collection_ref(self, name):
3749    """Returns a list of values in the collection with the given `name`.
3750
3751    If the collection exists, this returns the list itself, which can
3752    be modified in place to change the collection.  If the collection does
3753    not exist, it is created as an empty list and the list is returned.
3754
3755    This is different from `get_collection()` which always returns a copy of
3756    the collection list if it exists and never creates an empty collection.
3757
3758    Args:
3759      name: The key for the collection. For example, the `GraphKeys` class
3760        contains many standard names for collections.
3761
3762    Returns:
3763      The list of values in the collection with the given `name`, or an empty
3764      list if no value has been added to that collection.
3765    """  # pylint: disable=g-doc-exception
3766    _assert_collection_is_ok(name)
3767    with self._lock:
3768      coll_list = self._collections.get(name, None)
3769      if coll_list is None:
3770        coll_list = []
3771        self._collections[name] = coll_list
3772      return coll_list
3773
3774  def get_collection(self, name, scope=None):
3775    """Returns a list of values in the collection with the given `name`.
3776
3777    This is different from `get_collection_ref()` which always returns the
3778    actual collection list if it exists in that it returns a new list each time
3779    it is called.
3780
3781    Args:
3782      name: The key for the collection. For example, the `GraphKeys` class
3783        contains many standard names for collections.
3784      scope: (Optional.) A string. If supplied, the resulting list is filtered
3785        to include only items whose `name` attribute matches `scope` using
3786        `re.match`. Items without a `name` attribute are never returned if a
3787        scope is supplied. The choice of `re.match` means that a `scope` without
3788        special tokens filters by prefix.
3789
3790    Returns:
3791      The list of values in the collection with the given `name`, or
3792      an empty list if no value has been added to that collection. The
3793      list contains the values in the order under which they were
3794      collected.
3795    """  # pylint: disable=g-doc-exception
3796    _assert_collection_is_ok(name)
3797    with self._lock:
3798      collection = self._collections.get(name, None)
3799      if collection is None:
3800        return []
3801      if scope is None:
3802        return list(collection)
3803      else:
3804        c = []
3805        regex = re.compile(scope)
3806        for item in collection:
3807          if hasattr(item, "name") and regex.match(item.name):
3808            c.append(item)
3809        return c
3810
3811  def get_all_collection_keys(self):
3812    """Returns a list of collections used in this graph."""
3813    with self._lock:
3814      return [x for x in self._collections if isinstance(x, six.string_types)]
3815
3816  def clear_collection(self, name):
3817    """Clears all values in a collection.
3818
3819    Args:
3820      name: The key for the collection. The `GraphKeys` class contains many
3821        standard names for collections.
3822    """
3823    self._check_not_finalized()
3824    with self._lock:
3825      if name in self._collections:
3826        del self._collections[name]
3827
3828  @tf_contextlib.contextmanager
3829  def _original_op(self, op):
3830    """Python 'with' handler to help annotate ops with their originator.
3831
3832    An op may have an 'original_op' property that indicates the op on which
3833    it was based. For example a replica op is based on the op that was
3834    replicated and a gradient op is based on the op that was differentiated.
3835
3836    All ops created in the scope of this 'with' handler will have
3837    the given 'op' as their original op.
3838
3839    Args:
3840      op: The Operation that all ops created in this scope will have as their
3841        original op.
3842
3843    Yields:
3844      Nothing.
3845    """
3846    old_original_op = self._default_original_op
3847    try:
3848      self._default_original_op = op
3849      yield
3850    finally:
3851      self._default_original_op = old_original_op
3852
3853  # pylint: disable=g-doc-return-or-yield,line-too-long
3854  @tf_contextlib.contextmanager
3855  def name_scope(self, name):
3856    r"""Returns a context manager that creates hierarchical names for operations.
3857
3858    A graph maintains a stack of name scopes. A `with name_scope(...):`
3859    statement pushes a new name onto the stack for the lifetime of the context.
3860
3861    The `name` argument will be interpreted as follows:
3862
3863    * A string (not ending with '/') will create a new name scope, in which
3864      `name` is appended to the prefix of all operations created in the
3865      context. If `name` has been used before, it will be made unique by
3866      calling `self.unique_name(name)`.
3867    * A scope previously captured from a `with g.name_scope(...) as
3868      scope:` statement will be treated as an "absolute" name scope, which
3869      makes it possible to re-enter existing scopes.
3870    * A value of `None` or the empty string will reset the current name scope
3871      to the top-level (empty) name scope.
3872
3873    For example:
3874
3875    ```python
3876    with tf.Graph().as_default() as g:
3877      c = tf.constant(5.0, name="c")
3878      assert c.op.name == "c"
3879      c_1 = tf.constant(6.0, name="c")
3880      assert c_1.op.name == "c_1"
3881
3882      # Creates a scope called "nested"
3883      with g.name_scope("nested") as scope:
3884        nested_c = tf.constant(10.0, name="c")
3885        assert nested_c.op.name == "nested/c"
3886
3887        # Creates a nested scope called "inner".
3888        with g.name_scope("inner"):
3889          nested_inner_c = tf.constant(20.0, name="c")
3890          assert nested_inner_c.op.name == "nested/inner/c"
3891
3892        # Create a nested scope called "inner_1".
3893        with g.name_scope("inner"):
3894          nested_inner_1_c = tf.constant(30.0, name="c")
3895          assert nested_inner_1_c.op.name == "nested/inner_1/c"
3896
3897          # Treats `scope` as an absolute name scope, and
3898          # switches to the "nested/" scope.
3899          with g.name_scope(scope):
3900            nested_d = tf.constant(40.0, name="d")
3901            assert nested_d.op.name == "nested/d"
3902
3903            with g.name_scope(""):
3904              e = tf.constant(50.0, name="e")
3905              assert e.op.name == "e"
3906    ```
3907
3908    The name of the scope itself can be captured by `with
3909    g.name_scope(...) as scope:`, which stores the name of the scope
3910    in the variable `scope`. This value can be used to name an
3911    operation that represents the overall result of executing the ops
3912    in a scope. For example:
3913
3914    ```python
3915    inputs = tf.constant(...)
3916    with g.name_scope('my_layer') as scope:
3917      weights = tf.Variable(..., name="weights")
3918      biases = tf.Variable(..., name="biases")
3919      affine = tf.matmul(inputs, weights) + biases
3920      output = tf.nn.relu(affine, name=scope)
3921    ```
3922
3923    NOTE: This constructor validates the given `name`. Valid scope
3924    names match one of the following regular expressions:
3925
3926        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
3927        [A-Za-z0-9_.\\-/]* (for other scopes)
3928
3929    Args:
3930      name: A name for the scope.
3931
3932    Returns:
3933      A context manager that installs `name` as a new name scope.
3934
3935    Raises:
3936      ValueError: If `name` is not a valid scope name, according to the rules
3937        above.
3938    """
3939    if name:
3940      if isinstance(name, compat.bytes_or_text_types):
3941        name = compat.as_str(name)
3942
3943      if self._name_stack:
3944        # Scopes created in a nested scope may have initial characters
3945        # that are illegal as the initial character of an op name
3946        # (viz. '-', '\', '/', and '_').
3947        if not _VALID_SCOPE_NAME_REGEX.match(name):
3948          raise ValueError("'%s' is not a valid scope name" % name)
3949      else:
3950        # Scopes created in the root must match the more restrictive
3951        # op name regex, which constrains the initial character.
3952        if not _VALID_OP_NAME_REGEX.match(name):
3953          raise ValueError("'%s' is not a valid scope name" % name)
3954    try:
3955      old_stack = self._name_stack
3956      if not name:  # Both for name=None and name="" we re-set to empty scope.
3957        new_stack = None
3958      elif name[-1] == "/":
3959        new_stack = _name_from_scope_name(name)
3960      else:
3961        new_stack = self.unique_name(name)
3962      self._name_stack = new_stack
3963      yield "" if new_stack is None else new_stack + "/"
3964    finally:
3965      self._name_stack = old_stack
3966
3967  # pylint: enable=g-doc-return-or-yield,line-too-long
3968
3969  def unique_name(self, name, mark_as_used=True):
3970    """Return a unique operation name for `name`.
3971
3972    Note: You rarely need to call `unique_name()` directly.  Most of
3973    the time you just need to create `with g.name_scope()` blocks to
3974    generate structured names.
3975
3976    `unique_name` is used to generate structured names, separated by
3977    `"/"`, to help identify operations when debugging a graph.
3978    Operation names are displayed in error messages reported by the
3979    TensorFlow runtime, and in various visualization tools such as
3980    TensorBoard.
3981
3982    If `mark_as_used` is set to `True`, which is the default, a new
3983    unique name is created and marked as in use. If it's set to `False`,
3984    the unique name is returned without actually being marked as used.
3985    This is useful when the caller simply wants to know what the name
3986    to be created will be.
3987
3988    Args:
3989      name: The name for an operation.
3990      mark_as_used: Whether to mark this name as being used.
3991
3992    Returns:
3993      A string to be passed to `create_op()` that will be used
3994      to name the operation being created.
3995    """
3996    if self._name_stack:
3997      name = self._name_stack + "/" + name
3998    i = self._names_in_use.get(name, 0)
3999    # Increment the number for "name".
4000    if mark_as_used:
4001      self._names_in_use[name] = i + 1
4002    if i > 0:
4003      base_name = name
4004      # Make sure the composed name is not already used.
4005      while name in self._names_in_use:
4006        name = "%s_%d" % (base_name, i)
4007        i += 1
4008      # Mark the composed name as used in case someone wants
4009      # to call unique_name("name_1").
4010      if mark_as_used:
4011        self._names_in_use[name] = 1
4012    return name
4013
4014  def get_name_scope(self):
4015    """Returns the current name scope.
4016
4017    For example:
4018
4019    ```python
4020    with tf.name_scope('scope1'):
4021      with tf.name_scope('scope2'):
4022        print(tf.get_default_graph().get_name_scope())
4023    ```
4024    would print the string `scope1/scope2`.
4025
4026    Returns:
4027      A string representing the current name scope.
4028    """
4029    return self._name_stack
4030
4031  @tf_contextlib.contextmanager
4032  def colocate_with(self, op, ignore_existing=False):
4033    """Returns a context manager that specifies an op to colocate with.
4034
4035    Note: this function is not for public use, only for internal libraries.
4036
4037    For example:
4038
4039    ```python
4040    a = tf.Variable([1.0])
4041    with g.colocate_with(a):
4042      b = tf.constant(1.0)
4043      c = tf.add(a, b)
4044    ```
4045
4046    `b` and `c` will always be colocated with `a`, no matter where `a`
4047    is eventually placed.
4048
4049    **NOTE** Using a colocation scope resets any existing device constraints.
4050
4051    If `op` is `None` then `ignore_existing` must be `True` and the new
4052    scope resets all colocation and device constraints.
4053
4054    Args:
4055      op: The op to colocate all created ops with, or `None`.
4056      ignore_existing: If true, only applies colocation of this op within
4057        the context, rather than applying all colocation properties
4058        on the stack.  If `op` is `None`, this value must be `True`.
4059
4060    Raises:
4061      ValueError: if op is None but ignore_existing is False.
4062
4063    Yields:
4064      A context manager that specifies the op with which to colocate
4065      newly created ops.
4066
4067    """
4068    if op is None and not ignore_existing:
4069      raise ValueError("Trying to reset colocation (op is None) but "
4070                       "ignore_existing is not True")
4071
4072    if op is not None and not isinstance(op, Operation):
4073      # We always want to colocate with the reference op.
4074      op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
4075
4076    # By default, colocate_with resets the device function stack,
4077    # since colocate_with is typically used in specific internal
4078    # library functions where colocation is intended to be "stronger"
4079    # than device functions.
4080    #
4081    # In the future, a caller may specify that device_functions win
4082    # over colocation, in which case we can add support.
4083    device_fn_tmp = self._device_function_stack
4084    self._device_function_stack = []
4085
4086    if ignore_existing:
4087      current_stack = self._colocation_stack
4088      self._colocation_stack = []
4089
4090    if op is not None:
4091      self._colocation_stack.append(op)
4092
4093    try:
4094      yield
4095    finally:
4096      # Restore device function stack
4097      self._device_function_stack = device_fn_tmp
4098      if op is not None:
4099        self._colocation_stack.pop()
4100
4101      # Reset the colocation stack if requested.
4102      if ignore_existing:
4103        self._colocation_stack = current_stack
4104
4105  @tf_contextlib.contextmanager
4106  def device(self, device_name_or_function):
4107    # pylint: disable=line-too-long
4108    """Returns a context manager that specifies the default device to use.
4109
4110    The `device_name_or_function` argument may either be a device name
4111    string, a device function, or None:
4112
4113    * If it is a device name string, all operations constructed in
4114      this context will be assigned to the device with that name, unless
4115      overridden by a nested `device()` context.
4116    * If it is a function, it will be treated as a function from
4117      Operation objects to device name strings, and invoked each time
4118      a new Operation is created. The Operation will be assigned to
4119      the device with the returned name.
4120    * If it is None, all `device()` invocations from the enclosing context
4121      will be ignored.
4122
4123    For information about the valid syntax of device name strings, see
4124    the documentation in
4125    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
4126
4127    For example:
4128
4129    ```python
4130    with g.device('/device:GPU:0'):
4131      # All operations constructed in this context will be placed
4132      # on GPU 0.
4133      with g.device(None):
4134        # All operations constructed in this context will have no
4135        # assigned device.
4136
4137    # Defines a function from `Operation` to device string.
4138    def matmul_on_gpu(n):
4139      if n.type == "MatMul":
4140        return "/device:GPU:0"
4141      else:
4142        return "/cpu:0"
4143
4144    with g.device(matmul_on_gpu):
4145      # All operations of type "MatMul" constructed in this context
4146      # will be placed on GPU 0; all other operations will be placed
4147      # on CPU 0.
4148    ```
4149
4150    **N.B.** The device scope may be overridden by op wrappers or
4151    other library code. For example, a variable assignment op
4152    `v.assign()` must be colocated with the `tf.Variable` `v`, and
4153    incompatible device scopes will be ignored.
4154
4155    Args:
4156      device_name_or_function: The device name or function to use in
4157        the context.
4158
4159    Yields:
4160      A context manager that specifies the default device to use for newly
4161      created ops.
4162
4163    """
4164    # pylint: enable=line-too-long
4165    if (device_name_or_function is not None and
4166        not callable(device_name_or_function)):
4167      device_function = pydev.merge_device(device_name_or_function)
4168    else:
4169      device_function = device_name_or_function
4170
4171    try:
4172      self._device_function_stack.append(device_function)
4173      yield
4174    finally:
4175      self._device_function_stack.pop()
4176
4177  def _apply_device_functions(self, op):
4178    """Applies the current device function stack to the given operation."""
4179    # Apply any device functions in reverse order, so that the most recently
4180    # pushed function has the first chance to apply a device to the op.
4181    # We apply here because the result can depend on the Operation's
4182    # signature, which is computed in the Operation constructor.
4183    for device_function in reversed(self._device_function_stack):
4184      if device_function is None:
4185        break
4186      op._set_device(device_function(op))  # pylint: disable=protected-access
4187
4188  # pylint: disable=g-doc-return-or-yield
4189  @tf_contextlib.contextmanager
4190  def container(self, container_name):
4191    """Returns a context manager that specifies the resource container to use.
4192
4193    Stateful operations, such as variables and queues, can maintain their
4194    states on devices so that they can be shared by multiple processes.
4195    A resource container is a string name under which these stateful
4196    operations are tracked. These resources can be released or cleared
4197    with `tf.Session.reset()`.
4198
4199    For example:
4200
4201    ```python
4202    with g.container('experiment0'):
4203      # All stateful Operations constructed in this context will be placed
4204      # in resource container "experiment0".
4205      v1 = tf.Variable([1.0])
4206      v2 = tf.Variable([2.0])
4207      with g.container("experiment1"):
4208        # All stateful Operations constructed in this context will be
4209        # placed in resource container "experiment1".
4210        v3 = tf.Variable([3.0])
4211        q1 = tf.FIFOQueue(10, tf.float32)
4212      # All stateful Operations constructed in this context will be
4213      # be created in the "experiment0".
4214      v4 = tf.Variable([4.0])
4215      q1 = tf.FIFOQueue(20, tf.float32)
4216      with g.container(""):
4217        # All stateful Operations constructed in this context will be
4218        # be placed in the default resource container.
4219        v5 = tf.Variable([5.0])
4220        q3 = tf.FIFOQueue(30, tf.float32)
4221
4222    # Resets container "experiment0", after which the state of v1, v2, v4, q1
4223    # will become undefined (such as uninitialized).
4224    tf.Session.reset(target, ["experiment0"])
4225    ```
4226
4227    Args:
4228      container_name: container name string.
4229
4230    Returns:
4231      A context manager for defining resource containers for stateful ops,
4232        yields the container name.
4233    """
4234    original_container = self._container
4235    try:
4236      self._container = container_name
4237      yield self._container
4238    finally:
4239      self._container = original_container
4240
4241  # pylint: enable=g-doc-return-or-yield
4242
4243  class _ControlDependenciesController(object):
4244    """Context manager for `control_dependencies()`."""
4245
4246    def __init__(self, graph, control_inputs):
4247      """Create a new `_ControlDependenciesController`.
4248
4249      A `_ControlDependenciesController` is the context manager for
4250      `with tf.control_dependencies()` blocks.  These normally nest,
4251      as described in the documentation for `control_dependencies()`.
4252
4253      The `control_inputs` argument list control dependencies that must be
4254      added to the current set of control dependencies.  Because of
4255      uniquification the set can be empty even if the caller passed a list of
4256      ops.  The special value `None` indicates that we want to start a new
4257      empty set of control dependencies instead of extending the current set.
4258
4259      In that case we also clear the current control flow context, which is an
4260      additional mechanism to add control dependencies.
4261
4262      Args:
4263        graph: The graph that this controller is managing.
4264        control_inputs: List of ops to use as control inputs in addition
4265          to the current control dependencies.  None to indicate that
4266          the dependencies should be cleared.
4267      """
4268      self._graph = graph
4269      if control_inputs is None:
4270        self._control_inputs_val = []
4271        self._new_stack = True
4272      else:
4273        self._control_inputs_val = control_inputs
4274        self._new_stack = False
4275      self._seen_nodes = set()
4276      self._old_stack = None
4277      self._old_control_flow_context = None
4278
4279# pylint: disable=protected-access
4280
4281    def __enter__(self):
4282      if self._new_stack:
4283        # Clear the control_dependencies graph.
4284        self._old_stack = self._graph._control_dependencies_stack
4285        self._graph._control_dependencies_stack = []
4286        # Clear the control_flow_context too.
4287        self._old_control_flow_context = self._graph._get_control_flow_context()
4288        self._graph._set_control_flow_context(None)
4289      self._graph._push_control_dependencies_controller(self)
4290
4291    def __exit__(self, unused_type, unused_value, unused_traceback):
4292      self._graph._pop_control_dependencies_controller(self)
4293      if self._new_stack:
4294        self._graph._control_dependencies_stack = self._old_stack
4295        self._graph._set_control_flow_context(self._old_control_flow_context)
4296
4297# pylint: enable=protected-access
4298
4299    @property
4300    def control_inputs(self):
4301      return self._control_inputs_val
4302
4303    def add_op(self, op):
4304      self._seen_nodes.add(op)
4305
4306    def op_in_group(self, op):
4307      return op in self._seen_nodes
4308
4309  def _push_control_dependencies_controller(self, controller):
4310    self._control_dependencies_stack.append(controller)
4311
4312  def _pop_control_dependencies_controller(self, controller):
4313    assert self._control_dependencies_stack[-1] is controller
4314    self._control_dependencies_stack.pop()
4315
4316  def _current_control_dependencies(self):
4317    ret = set()
4318    for controller in self._control_dependencies_stack:
4319      for op in controller.control_inputs:
4320        ret.add(op)
4321    return ret
4322
4323  def _control_dependencies_for_inputs(self, input_ops):
4324    """For an op that takes `input_ops` as inputs, compute control inputs.
4325
4326    The returned control dependencies should yield an execution that
4327    is equivalent to adding all control inputs in
4328    self._control_dependencies_stack to a newly created op. However,
4329    this function attempts to prune the returned control dependencies
4330    by observing that nodes created within the same `with
4331    control_dependencies(...):` block may have data dependencies that make
4332    the explicit approach redundant.
4333
4334    Args:
4335      input_ops: The data input ops for an op to be created.
4336
4337    Returns:
4338      A list of control inputs for the op to be created.
4339    """
4340    ret = []
4341    for controller in self._control_dependencies_stack:
4342      # If any of the input_ops already depends on the inputs from controller,
4343      # we say that the new op is dominated (by that input), and we therefore
4344      # do not need to add control dependencies for this controller's inputs.
4345      dominated = False
4346      for op in input_ops:
4347        if controller.op_in_group(op):
4348          dominated = True
4349          break
4350      if not dominated:
4351        # Don't add a control input if we already have a data dependency on i.
4352        # NOTE(mrry): We do not currently track transitive data dependencies,
4353        #   so we may add redundant control inputs.
4354        ret.extend([c for c in controller.control_inputs if c not in input_ops])
4355    return ret
4356
4357  def _record_op_seen_by_control_dependencies(self, op):
4358    """Record that the given op depends on all registered control dependencies.
4359
4360    Args:
4361      op: An Operation.
4362    """
4363    for controller in self._control_dependencies_stack:
4364      controller.add_op(op)
4365
4366  def control_dependencies(self, control_inputs):
4367    """Returns a context manager that specifies control dependencies.
4368
4369    Use with the `with` keyword to specify that all operations constructed
4370    within the context should have control dependencies on
4371    `control_inputs`. For example:
4372
4373    ```python
4374    with g.control_dependencies([a, b, c]):
4375      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
4376      d = ...
4377      e = ...
4378    ```
4379
4380    Multiple calls to `control_dependencies()` can be nested, and in
4381    that case a new `Operation` will have control dependencies on the union
4382    of `control_inputs` from all active contexts.
4383
4384    ```python
4385    with g.control_dependencies([a, b]):
4386      # Ops constructed here run after `a` and `b`.
4387      with g.control_dependencies([c, d]):
4388        # Ops constructed here run after `a`, `b`, `c`, and `d`.
4389    ```
4390
4391    You can pass None to clear the control dependencies:
4392
4393    ```python
4394    with g.control_dependencies([a, b]):
4395      # Ops constructed here run after `a` and `b`.
4396      with g.control_dependencies(None):
4397        # Ops constructed here run normally, not waiting for either `a` or `b`.
4398        with g.control_dependencies([c, d]):
4399          # Ops constructed here run after `c` and `d`, also not waiting
4400          # for either `a` or `b`.
4401    ```
4402
4403    *N.B.* The control dependencies context applies *only* to ops that
4404    are constructed within the context. Merely using an op or tensor
4405    in the context does not add a control dependency. The following
4406    example illustrates this point:
4407
4408    ```python
4409    # WRONG
4410    def my_func(pred, tensor):
4411      t = tf.matmul(tensor, tensor)
4412      with tf.control_dependencies([pred]):
4413        # The matmul op is created outside the context, so no control
4414        # dependency will be added.
4415        return t
4416
4417    # RIGHT
4418    def my_func(pred, tensor):
4419      with tf.control_dependencies([pred]):
4420        # The matmul op is created in the context, so a control dependency
4421        # will be added.
4422        return tf.matmul(tensor, tensor)
4423    ```
4424
4425    Args:
4426      control_inputs: A list of `Operation` or `Tensor` objects which
4427        must be executed or computed before running the operations
4428        defined in the context.  Can also be `None` to clear the control
4429        dependencies.
4430
4431    Returns:
4432     A context manager that specifies control dependencies for all
4433     operations constructed within the context.
4434
4435    Raises:
4436      TypeError: If `control_inputs` is not a list of `Operation` or
4437        `Tensor` objects.
4438    """
4439    if control_inputs is None:
4440      return self._ControlDependenciesController(self, None)
4441    # First convert the inputs to ops, and deduplicate them.
4442    # NOTE(mrry): Other than deduplication, we do not currently track direct
4443    #   or indirect dependencies between control_inputs, which may result in
4444    #   redundant control inputs.
4445    control_ops = []
4446    current = self._current_control_dependencies()
4447    for c in control_inputs:
4448      if isinstance(c, IndexedSlices):
4449        c = c.op
4450      c = self.as_graph_element(c)
4451      if isinstance(c, Tensor):
4452        c = c.op
4453      elif not isinstance(c, Operation):
4454        raise TypeError("Control input must be Operation or Tensor: %s" % c)
4455      if c not in current:
4456        control_ops.append(c)
4457        current.add(c)
4458    return self._ControlDependenciesController(self, control_ops)
4459
4460  # pylint: disable=g-doc-return-or-yield
4461  @tf_contextlib.contextmanager
4462  def _attr_scope(self, attr_map):
4463    """EXPERIMENTAL: A context manager for setting attributes on operators.
4464
4465    This context manager can be used to add additional
4466    attributes to operators within the scope of the context.
4467
4468    For example:
4469
4470       with ops.Graph().as_default() as g:
4471         f_1 = Foo()  # No extra attributes
4472         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
4473           f_2 = Foo()  # Additional attribute _a=False
4474           with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
4475             f_3 = Foo()  # Additional attribute _a=False
4476             with g._attr_scope({"_a": None}):
4477               f_4 = Foo()  # No additional attributes.
4478
4479    Args:
4480      attr_map: A dictionary mapping attr name strings to
4481        AttrValue protocol buffers or None.
4482
4483    Returns:
4484      A context manager that sets the kernel label to be used for one or more
4485      ops created in that context.
4486
4487    Raises:
4488      TypeError: If attr_map is not a dictionary mapping
4489        strings to AttrValue protobufs.
4490    """
4491    if not isinstance(attr_map, dict):
4492      raise TypeError("attr_map must be a dictionary mapping "
4493                      "strings to AttrValue protocol buffers")
4494    # The saved_attrs dictionary stores any currently-set labels that
4495    # will be overridden by this context manager.
4496    saved_attrs = {}
4497    # Install the given attribute
4498    for name, attr in attr_map.items():
4499      if not (isinstance(name, six.string_types) and
4500              (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
4501               callable(attr))):
4502        raise TypeError("attr_map must be a dictionary mapping "
4503                        "strings to AttrValue protocol buffers or "
4504                        "callables that emit AttrValue protocol buffers")
4505      try:
4506        saved_attrs[name] = self._attr_scope_map[name]
4507      except KeyError:
4508        pass
4509      if attr is None:
4510        del self._attr_scope_map[name]
4511      else:
4512        self._attr_scope_map[name] = attr
4513    try:
4514      yield  # The code within the context runs here.
4515    finally:
4516      # Remove the attributes set for this context, and restore any saved
4517      # attributes.
4518      for name, attr in attr_map.items():
4519        try:
4520          self._attr_scope_map[name] = saved_attrs[name]
4521        except KeyError:
4522          del self._attr_scope_map[name]
4523
4524  # pylint: enable=g-doc-return-or-yield
4525
4526  # pylint: disable=g-doc-return-or-yield
4527  @tf_contextlib.contextmanager
4528  def _kernel_label_map(self, op_to_kernel_label_map):
4529    """EXPERIMENTAL: A context manager for setting kernel labels.
4530
4531    This context manager can be used to select particular
4532    implementations of kernels within the scope of the context.
4533
4534    For example:
4535
4536        with ops.Graph().as_default() as g:
4537          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
4538          with g.kernel_label_map({"Foo": "v_2"}):
4539            f_2 = Foo()  # Uses the registered kernel with label "v_2"
4540                         # for the Foo op.
4541            with g.kernel_label_map({"Foo": "v_3"}):
4542              f_3 = Foo()  # Uses the registered kernel with label "v_3"
4543                           # for the Foo op.
4544              with g.kernel_label_map({"Foo": ""}):
4545                f_4 = Foo()  # Uses the default registered kernel
4546                             # for the Foo op.
4547
4548    Args:
4549      op_to_kernel_label_map: A dictionary mapping op type strings to
4550        kernel label strings.
4551
4552    Returns:
4553      A context manager that sets the kernel label to be used for one or more
4554      ops created in that context.
4555
4556    Raises:
4557      TypeError: If op_to_kernel_label_map is not a dictionary mapping
4558        strings to strings.
4559    """
4560    if not isinstance(op_to_kernel_label_map, dict):
4561      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4562                      "strings to strings")
4563    # The saved_labels dictionary stores any currently-set labels that
4564    # will be overridden by this context manager.
4565    saved_labels = {}
4566    # Install the given label
4567    for op_type, label in op_to_kernel_label_map.items():
4568      if not (isinstance(op_type, six.string_types) and
4569              isinstance(label, six.string_types)):
4570        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4571                        "strings to strings")
4572      try:
4573        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
4574      except KeyError:
4575        pass
4576      self._op_to_kernel_label_map[op_type] = label
4577    try:
4578      yield  # The code within the context runs here.
4579    finally:
4580      # Remove the labels set for this context, and restore any saved labels.
4581      for op_type, label in op_to_kernel_label_map.items():
4582        try:
4583          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
4584        except KeyError:
4585          del self._op_to_kernel_label_map[op_type]
4586
4587  # pylint: enable=g-doc-return-or-yield
4588
4589  # pylint: disable=g-doc-return-or-yield
4590  @tf_contextlib.contextmanager
4591  def gradient_override_map(self, op_type_map):
4592    """EXPERIMENTAL: A context manager for overriding gradient functions.
4593
4594    This context manager can be used to override the gradient function
4595    that will be used for ops within the scope of the context.
4596
4597    For example:
4598
4599    ```python
4600    @tf.RegisterGradient("CustomSquare")
4601    def _custom_square_grad(op, grad):
4602      # ...
4603
4604    with tf.Graph().as_default() as g:
4605      c = tf.constant(5.0)
4606      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
4607      with g.gradient_override_map({"Square": "CustomSquare"}):
4608        s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
4609                              # gradient of s_2.
4610    ```
4611
4612    Args:
4613      op_type_map: A dictionary mapping op type strings to alternative op
4614        type strings.
4615
4616    Returns:
4617      A context manager that sets the alternative op type to be used for one
4618      or more ops created in that context.
4619
4620    Raises:
4621      TypeError: If `op_type_map` is not a dictionary mapping strings to
4622        strings.
4623    """
4624    if not isinstance(op_type_map, dict):
4625      raise TypeError("op_type_map must be a dictionary mapping "
4626                      "strings to strings")
4627    # The saved_mappings dictionary stores any currently-set mappings that
4628    # will be overridden by this context manager.
4629    saved_mappings = {}
4630    # Install the given label
4631    for op_type, mapped_op_type in op_type_map.items():
4632      if not (isinstance(op_type, six.string_types) and
4633              isinstance(mapped_op_type, six.string_types)):
4634        raise TypeError("op_type_map must be a dictionary mapping "
4635                        "strings to strings")
4636      try:
4637        saved_mappings[op_type] = self._gradient_override_map[op_type]
4638      except KeyError:
4639        pass
4640      self._gradient_override_map[op_type] = mapped_op_type
4641    try:
4642      yield  # The code within the context runs here.
4643    finally:
4644      # Remove the labels set for this context, and restore any saved labels.
4645      for op_type, mapped_op_type in op_type_map.items():
4646        try:
4647          self._gradient_override_map[op_type] = saved_mappings[op_type]
4648        except KeyError:
4649          del self._gradient_override_map[op_type]
4650
4651  # pylint: enable=g-doc-return-or-yield
4652
4653  def prevent_feeding(self, tensor):
4654    """Marks the given `tensor` as unfeedable in this graph."""
4655    self._unfeedable_tensors.add(tensor)
4656
4657  def is_feedable(self, tensor):
4658    """Returns `True` if and only if `tensor` is feedable."""
4659    return tensor not in self._unfeedable_tensors
4660
4661  def prevent_fetching(self, op):
4662    """Marks the given `op` as unfetchable in this graph."""
4663    self._unfetchable_ops.add(op)
4664
4665  def is_fetchable(self, tensor_or_op):
4666    """Returns `True` if and only if `tensor_or_op` is fetchable."""
4667    if isinstance(tensor_or_op, Tensor):
4668      return tensor_or_op.op not in self._unfetchable_ops
4669    else:
4670      return tensor_or_op not in self._unfetchable_ops
4671
4672
4673# TODO(agarwal): currently device directives in an outer eager scope will not
4674# apply to inner graph mode code. Fix that.
4675
4676
4677@tf_export("device")
4678def device(device_name_or_function):
4679  """Wrapper for `Graph.device()` using the default graph.
4680
4681  See
4682  @{tf.Graph.device}
4683  for more details.
4684
4685  Args:
4686    device_name_or_function: The device name or function to use in
4687      the context.
4688
4689  Returns:
4690    A context manager that specifies the default device to use for newly
4691    created ops.
4692
4693  Raises:
4694    RuntimeError: If eager execution is enabled and a function is passed in.
4695  """
4696  if context.in_graph_mode():
4697    return get_default_graph().device(device_name_or_function)
4698  else:
4699    # TODO(agarwal): support device functions in EAGER mode.
4700    if callable(device_name_or_function):
4701      raise RuntimeError(
4702          "tf.device does not support functions when eager execution "
4703          "is enabled.")
4704    return context.device(device_name_or_function)
4705
4706
4707@tf_export("container")
4708def container(container_name):
4709  """Wrapper for `Graph.container()` using the default graph.
4710
4711  Args:
4712    container_name: The container string to use in the context.
4713
4714  Returns:
4715    A context manager that specifies the default container to use for newly
4716    created stateful ops.
4717  """
4718  return get_default_graph().container(container_name)
4719
4720
4721@tf_export("colocate_with")
4722def colocate_with(op, ignore_existing=False):
4723  if context.in_graph_mode():
4724    return get_default_graph().colocate_with(op, ignore_existing)
4725  else:
4726    if op is not None:
4727      return device(op.device)
4728    else:
4729      return _NullContextmanager()
4730
4731
4732@tf_export("control_dependencies")
4733def control_dependencies(control_inputs):
4734  """Wrapper for `Graph.control_dependencies()` using the default graph.
4735
4736  See @{tf.Graph.control_dependencies}
4737  for more details.
4738
4739  Args:
4740    control_inputs: A list of `Operation` or `Tensor` objects which
4741      must be executed or computed before running the operations
4742      defined in the context.  Can also be `None` to clear the control
4743      dependencies.
4744
4745  Returns:
4746   A context manager that specifies control dependencies for all
4747   operations constructed within the context.
4748  """
4749  if context.in_graph_mode():
4750    return get_default_graph().control_dependencies(control_inputs)
4751  else:
4752    return _NullContextmanager()
4753
4754
4755class _DefaultStack(threading.local):
4756  """A thread-local stack of objects for providing implicit defaults."""
4757
4758  def __init__(self):
4759    super(_DefaultStack, self).__init__()
4760    self._enforce_nesting = True
4761    self.stack = []
4762
4763  def get_default(self):
4764    return self.stack[-1] if len(self.stack) >= 1 else None
4765
4766  def reset(self):
4767    self.stack = []
4768
4769  def is_cleared(self):
4770    return not self.stack
4771
4772  @property
4773  def enforce_nesting(self):
4774    return self._enforce_nesting
4775
4776  @enforce_nesting.setter
4777  def enforce_nesting(self, value):
4778    self._enforce_nesting = value
4779
4780  @tf_contextlib.contextmanager
4781  def get_controller(self, default):
4782    """A context manager for manipulating a default stack."""
4783    try:
4784      self.stack.append(default)
4785      yield default
4786    finally:
4787      # stack may be empty if reset() was called
4788      if self.stack:
4789        if self._enforce_nesting:
4790          if self.stack[-1] is not default:
4791            raise AssertionError(
4792                "Nesting violated for default stack of %s objects" %
4793                type(default))
4794          self.stack.pop()
4795        else:
4796          self.stack.remove(default)
4797
4798
4799_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
4800
4801
4802def default_session(session):
4803  """Python "with" handler for defining a default session.
4804
4805  This function provides a means of registering a session for handling
4806  Tensor.eval() and Operation.run() calls. It is primarily intended for use
4807  by session.Session, but can be used with any object that implements
4808  the Session.run() interface.
4809
4810  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
4811  invocations within the scope of a block should be executed by a particular
4812  session.
4813
4814  The default session applies to the current thread only, so it is always
4815  possible to inspect the call stack and determine the scope of a default
4816  session. If you create a new thread, and wish to use the default session
4817  in that thread, you must explicitly add a "with ops.default_session(sess):"
4818  block in that thread's function.
4819
4820  Example:
4821    The following code examples are equivalent:
4822
4823    # 1. Using the Session object directly:
4824    sess = ...
4825    c = tf.constant(5.0)
4826    sess.run(c)
4827
4828    # 2. Using default_session():
4829    sess = ...
4830    with ops.default_session(sess):
4831      c = tf.constant(5.0)
4832      result = c.eval()
4833
4834    # 3. Overriding default_session():
4835    sess = ...
4836    with ops.default_session(sess):
4837      c = tf.constant(5.0)
4838      with ops.default_session(...):
4839        c.eval(session=sess)
4840
4841  Args:
4842    session: The session to be installed as the default session.
4843
4844  Returns:
4845    A context manager for the default session.
4846  """
4847  return _default_session_stack.get_controller(session)
4848
4849
4850@tf_export("get_default_session")
4851def get_default_session():
4852  """Returns the default session for the current thread.
4853
4854  The returned `Session` will be the innermost session on which a
4855  `Session` or `Session.as_default()` context has been entered.
4856
4857  NOTE: The default session is a property of the current thread. If you
4858  create a new thread, and wish to use the default session in that
4859  thread, you must explicitly add a `with sess.as_default():` in that
4860  thread's function.
4861
4862  Returns:
4863    The default `Session` being used in the current thread.
4864  """
4865  return _default_session_stack.get_default()
4866
4867
4868def _eval_using_default_session(tensors, feed_dict, graph, session=None):
4869  """Uses the default session to evaluate one or more tensors.
4870
4871  Args:
4872    tensors: A single Tensor, or a list of Tensor objects.
4873    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
4874      numpy ndarrays, TensorProtos, or strings.
4875    graph: The graph in which the tensors are defined.
4876    session: (Optional) A different session to use to evaluate "tensors".
4877
4878  Returns:
4879    Either a single numpy ndarray if "tensors" is a single tensor; or a list
4880    of numpy ndarrays that each correspond to the respective element in
4881    "tensors".
4882
4883  Raises:
4884    ValueError: If no default session is available; the default session
4885      does not have "graph" as its graph; or if "session" is specified,
4886      and it does not have "graph" as its graph.
4887  """
4888  if session is None:
4889    session = get_default_session()
4890    if session is None:
4891      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
4892                       "session is registered. Use `with "
4893                       "sess.as_default()` or pass an explicit session to "
4894                       "`eval(session=sess)`")
4895    if session.graph is not graph:
4896      raise ValueError("Cannot use the default session to evaluate tensor: "
4897                       "the tensor's graph is different from the session's "
4898                       "graph. Pass an explicit session to "
4899                       "`eval(session=sess)`.")
4900  else:
4901    if session.graph is not graph:
4902      raise ValueError("Cannot use the given session to evaluate tensor: "
4903                       "the tensor's graph is different from the session's "
4904                       "graph.")
4905  return session.run(tensors, feed_dict)
4906
4907
4908def _run_using_default_session(operation, feed_dict, graph, session=None):
4909  """Uses the default session to run "operation".
4910
4911  Args:
4912    operation: The Operation to be run.
4913    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
4914      numpy ndarrays, TensorProtos, or strings.
4915    graph: The graph in which "operation" is defined.
4916    session: (Optional) A different session to use to run "operation".
4917
4918  Raises:
4919    ValueError: If no default session is available; the default session
4920      does not have "graph" as its graph; or if "session" is specified,
4921      and it does not have "graph" as its graph.
4922  """
4923  if session is None:
4924    session = get_default_session()
4925    if session is None:
4926      raise ValueError("Cannot execute operation using `run()`: No default "
4927                       "session is registered. Use `with "
4928                       "sess.as_default():` or pass an explicit session to "
4929                       "`run(session=sess)`")
4930    if session.graph is not graph:
4931      raise ValueError("Cannot use the default session to execute operation: "
4932                       "the operation's graph is different from the "
4933                       "session's graph. Pass an explicit session to "
4934                       "run(session=sess).")
4935  else:
4936    if session.graph is not graph:
4937      raise ValueError("Cannot use the given session to execute operation: "
4938                       "the operation's graph is different from the session's "
4939                       "graph.")
4940  session.run(operation, feed_dict)
4941
4942
4943class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
4944  """A thread-local stack of objects for providing an implicit default graph."""
4945
4946  def __init__(self):
4947    super(_DefaultGraphStack, self).__init__()
4948    self._global_default_graph = None
4949
4950  def get_default(self):
4951    """Override that returns a global default if the stack is empty."""
4952    ret = super(_DefaultGraphStack, self).get_default()
4953    if ret is None:
4954      ret = self._GetGlobalDefaultGraph()
4955    return ret
4956
4957  def _GetGlobalDefaultGraph(self):
4958    if self._global_default_graph is None:
4959      # TODO(mrry): Perhaps log that the default graph is being used, or set
4960      #   provide some other feedback to prevent confusion when a mixture of
4961      #   the global default graph and an explicit graph are combined in the
4962      #   same process.
4963      self._global_default_graph = Graph()
4964    return self._global_default_graph
4965
4966  def reset(self):
4967    super(_DefaultGraphStack, self).reset()
4968    self._global_default_graph = None
4969
4970  @tf_contextlib.contextmanager
4971  def get_controller(self, default):
4972    try:
4973      context.context_stack.push(default.building_function, default.as_default)
4974      with super(_DefaultGraphStack, self).get_controller(default) as g:
4975        yield g
4976    finally:
4977      context.context_stack.pop()
4978
4979
4980_default_graph_stack = _DefaultGraphStack()
4981
4982
4983# pylint: disable=g-doc-return-or-yield,line-too-long
4984@tf_contextlib.contextmanager
4985def init_scope():
4986  """A context manager that lifts ops out of control-flow scopes and function-building graphs.
4987
4988  There is often a need to lift variable initialization ops out of control-flow
4989  scopes, function-building graphs, and gradient tapes. Entering an
4990  `init_scope` is a mechanism for satisfying these desiderata. In particular,
4991  entering an `init_scope` has three effects:
4992
4993    (1) All control dependencies are cleared the moment the scope is entered;
4994        this is equivalent to entering the context manager returned from
4995        `control_dependencies(None)`, which has the side-effect of exiting
4996        control-flow scopes like `tf.cond` and `tf.while_loop`.
4997
4998    (2) All operations that are created while the scope is active are lifted
4999        into the lowest context on the `context_stack` that is not building a
5000        graph function. Here, a context is defined as either a graph or an eager
5001        context. Every context switch, i.e., every installation of a graph as
5002        the default graph and every switch into eager mode, is logged in a
5003        thread-local stack called the `context_stack`; the log entry for a
5004        context switch is popped from the stack when the context is exited.
5005        Entering an `init_scope` is equivalent to crawling up the
5006        `context_stack`, finding the first context that is not building a graph
5007        function, and entering it. A caveat is that if graph mode is enabled
5008        but the default graph stack is empty, then entering an `init_scope`
5009        will simply install a fresh graph as the default one.
5010
5011    (3) The gradient tape is paused while the scope is active.
5012  """
5013  # pylint: enable=g-doc-return-or-yield,line-too-long
5014
5015  in_graph_mode = context.in_graph_mode()
5016  # Retrieve the active name scope: entering an `init_scope` preserves
5017  # the name scope of the current context.
5018  if in_graph_mode:
5019    default_graph = get_default_graph()
5020    scope = default_graph.get_name_scope()
5021  else:
5022    scope = context.context().scope_name
5023  if scope and scope[-1] != '/':
5024    # Names that end with trailing slashes are treated by `name_scope` as
5025    # absolute.
5026    scope = scope + '/'
5027
5028  outer_context = None
5029  if in_graph_mode and not _default_graph_stack.stack:
5030    outer_context = default_graph.as_default
5031  else:
5032    for stack_entry in reversed(context.context_stack.stack):
5033      if not stack_entry.is_building_function:
5034        outer_context = stack_entry.enter_context_fn
5035        break
5036
5037  if outer_context is None:
5038    raise AssertionError("All graphs are building functions, and no "
5039                         "eager context was previously active.")
5040
5041  try:
5042    with outer_context(), name_scope(scope), control_dependencies(
5043        None), tape.stop_recording():
5044      yield
5045  finally:
5046    pass
5047
5048
5049def enable_eager_execution(config=None, device_policy=None):
5050  """Enables, for the rest of the lifetime of this program, eager execution.
5051
5052  If not called immediately on startup risks creating breakage and bugs.
5053
5054  Example:
5055  ```python
5056  tfe.enable_eager_execution()
5057
5058  # After eager execution is enabled, operations are executed as they are
5059  # defined and `Tensor`s hold concrete values, which can be accessed as
5060  # `numpy.ndarray`s through the `numpy()` method.
5061  assert tf.multiply(6, 7).numpy() == 42
5062  ```
5063
5064  Args:
5065    config: (Optional.) A `ConfigProto` protocol buffer with configuration
5066     options for the Context. Note that a lot of these options may be
5067     currently unimplemented or irrelevant when eager execution is enabled.
5068    device_policy: (Optional.) What policy to use when trying to run an
5069     operation on a device with inputs which are not on that device.
5070     Valid values:
5071       tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
5072         correct.
5073       tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
5074         right device but raises a warning.
5075       tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
5076         hide performance problems.
5077       tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
5078         raising errors on the other ones.
5079
5080  Raises:
5081    ValueError: If trying to create a context after using graph operations
5082     or if trying to create a context with nontrivial options which differ
5083     from those of the existing context.
5084  """
5085  if config is not None and not isinstance(config, config_pb2.ConfigProto):
5086    raise TypeError(
5087        "config must be a tf.ConfigProto, but got %s" % type(config))
5088  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
5089                           context.DEVICE_PLACEMENT_WARN,
5090                           context.DEVICE_PLACEMENT_SILENT,
5091                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
5092    raise ValueError(
5093        "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*"
5094    )
5095  # pylint: disable=protected-access
5096  if context._default_mode == context.GRAPH_MODE:
5097    graph_mode_has_been_used = (
5098        _default_session_stack.stack or
5099        _default_graph_stack._global_default_graph is not None)
5100    if graph_mode_has_been_used:
5101      raise ValueError(
5102          "tfe.enable_eager_execution has to be called at program startup.")
5103  context._default_mode = context.EAGER_MODE
5104  if context._context is None:
5105    context._context = context.Context(config=config,
5106                                       device_policy=device_policy)
5107    if context.context_stack.stack:
5108      raise AssertionError("Invariant violated: The context stack must "
5109                           "be empty when eager execution is enabled.")
5110    # Log that eager execution has been enabled by pushing an entry onto the
5111    # context stack; this entry won't ever be popped, as it's impossible to
5112    # disable eager execution
5113    context.context_stack.push(False, context.eager_mode)
5114  elif ((config is not None and config is not context._context._config)
5115        or (device_policy is not None
5116            and device_policy is not context._context._device_policy)):
5117    raise ValueError("Trying to change the options of an active eager"
5118                     " execution. Context config: %s, specified config:"
5119                     " %s. Context device policy: %s; specified device"
5120                     " policy: %s." % (config, context._context._config,
5121                                       device_policy,
5122                                       context._context._device_policy))
5123  else:
5124    raise ValueError(
5125        "tfe.enable_eager_execution has to be called at program startup.")
5126
5127
5128def eager_run(main=None, argv=None):
5129  """Runs the program with an optional main function and argv list.
5130
5131  The program will run with eager execution enabled.
5132
5133  Example:
5134  ```python
5135  import tensorflow as tf
5136  # Import subject to future changes:
5137  from tensorflow.contrib.eager.python import tfe
5138
5139  def main(_):
5140    u = tf.constant(6.0)
5141    v = tf.constant(7.0)
5142    print(u * v)
5143
5144  if __name__ == "__main__":
5145    tfe.run()
5146  ```
5147
5148  Args:
5149    main: the main function to run.
5150    argv: the arguments to pass to it.
5151  """
5152  enable_eager_execution()
5153  app.run(main, argv)
5154
5155
5156@tf_export("reset_default_graph")
5157def reset_default_graph():
5158  """Clears the default graph stack and resets the global default graph.
5159
5160  NOTE: The default graph is a property of the current thread. This
5161  function applies only to the current thread.  Calling this function while
5162  a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
5163  behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
5164  after calling this function will result in undefined behavior.
5165  Raises:
5166    AssertionError: If this function is called within a nested graph.
5167  """
5168  if not _default_graph_stack.is_cleared():
5169    raise AssertionError("Do not use tf.reset_default_graph() to clear "
5170                         "nested graphs. If you need a cleared graph, "
5171                         "exit the nesting and create a new graph.")
5172  _default_graph_stack.reset()
5173
5174
5175@tf_export("get_default_graph")
5176def get_default_graph():
5177  """Returns the default graph for the current thread.
5178
5179  The returned graph will be the innermost graph on which a
5180  `Graph.as_default()` context has been entered, or a global default
5181  graph if none has been explicitly created.
5182
5183  NOTE: The default graph is a property of the current thread. If you
5184  create a new thread, and wish to use the default graph in that
5185  thread, you must explicitly add a `with g.as_default():` in that
5186  thread's function.
5187
5188  Returns:
5189    The default `Graph` being used in the current thread.
5190  """
5191  return _default_graph_stack.get_default()
5192
5193
5194def get_name_scope():
5195  """Returns the current name scope in the default_graph.
5196
5197  For example:
5198
5199  ```python
5200  with tf.name_scope('scope1'):
5201    with tf.name_scope('scope2'):
5202      print(tf.get_name_scope())
5203  ```
5204  would print the string `scope1/scope2`.
5205
5206  Returns:
5207    A string representing the current name scope.
5208  """
5209  return get_default_graph().get_name_scope()
5210
5211
5212def _assert_same_graph(original_item, item):
5213  """Fail if the 2 items are from different graphs.
5214
5215  Args:
5216    original_item: Original item to check against.
5217    item: Item to check.
5218
5219  Raises:
5220    ValueError: if graphs do not match.
5221  """
5222  if original_item.graph is not item.graph:
5223    raise ValueError("%s must be from the same graph as %s." % (item,
5224                                                                original_item))
5225
5226
5227def _get_graph_from_inputs(op_input_list, graph=None):
5228  """Returns the appropriate graph to use for the given inputs.
5229
5230  This library method provides a consistent algorithm for choosing the graph
5231  in which an Operation should be constructed:
5232
5233  1. If the default graph is being used to construct a function, we
5234     use the default graph.
5235  2. If the "graph" is specified explicitly, we validate that all of the inputs
5236     in "op_input_list" are compatible with that graph.
5237  3. Otherwise, we attempt to select a graph from the first Operation-
5238     or Tensor-valued input in "op_input_list", and validate that all other
5239     such inputs are in the same graph.
5240  4. If the graph was not specified and it could not be inferred from
5241     "op_input_list", we attempt to use the default graph.
5242
5243  Args:
5244    op_input_list: A list of inputs to an operation, which may include `Tensor`,
5245      `Operation`, and other objects that may be converted to a graph element.
5246    graph: (Optional) The explicit graph to use.
5247
5248  Raises:
5249    TypeError: If op_input_list is not a list or tuple, or if graph is not a
5250      Graph.
5251    ValueError: If a graph is explicitly passed and not all inputs are from it,
5252      or if the inputs are from multiple graphs, or we could not find a graph
5253      and there was no default graph.
5254
5255  Returns:
5256    The appropriate graph to use for the given inputs.
5257
5258  """
5259  if get_default_graph().building_function:
5260    return get_default_graph()
5261
5262  op_input_list = tuple(op_input_list)  # Handle generators correctly
5263  if graph and not isinstance(graph, Graph):
5264    raise TypeError("Input graph needs to be a Graph: %s" % graph)
5265
5266  # 1. We validate that all of the inputs are from the same graph. This is
5267  #    either the supplied graph parameter, or the first one selected from one
5268  #    the graph-element-valued inputs. In the latter case, we hold onto
5269  #    that input in original_graph_element so we can provide a more
5270  #    informative error if a mismatch is found.
5271  original_graph_element = None
5272  for op_input in op_input_list:
5273    # Determine if this is a valid graph_element.
5274    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
5275    # up.
5276    graph_element = None
5277    if (isinstance(op_input, (Operation, _TensorLike)) and
5278        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
5279      graph_element = op_input
5280    else:
5281      graph_element = _as_graph_element(op_input)
5282
5283    if graph_element is not None:
5284      if not graph:
5285        original_graph_element = graph_element
5286        graph = graph_element.graph
5287      elif original_graph_element is not None:
5288        _assert_same_graph(original_graph_element, graph_element)
5289      elif graph_element.graph is not graph:
5290        raise ValueError("%s is not from the passed-in graph." % graph_element)
5291
5292  # 2. If all else fails, we use the default graph, which is always there.
5293  return graph or get_default_graph()
5294
5295
5296@tf_export("GraphKeys")
5297class GraphKeys(object):
5298  """Standard names to use for graph collections.
5299
5300  The standard library uses various well-known names to collect and
5301  retrieve values associated with a graph. For example, the
5302  `tf.Optimizer` subclasses default to optimizing the variables
5303  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
5304  specified, but it is also possible to pass an explicit list of
5305  variables.
5306
5307  The following standard keys are defined:
5308
5309  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
5310    across distributed environment (model variables are subset of these). See
5311    @{tf.global_variables}
5312    for more details.
5313    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
5314    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
5315  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
5316    machine. Usually used for temporarily variables, like counters.
5317    Note: use `tf.contrib.framework.local_variable` to add to this collection.
5318  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
5319    model for inference (feed forward). Note: use
5320    `tf.contrib.framework.model_variable` to add to this collection.
5321  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
5322    be trained by an optimizer. See
5323    @{tf.trainable_variables}
5324    for more details.
5325  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
5326    graph. See
5327    @{tf.summary.merge_all}
5328    for more details.
5329  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
5330    produce input for a computation. See
5331    @{tf.train.start_queue_runners}
5332    for more details.
5333  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
5334    keep moving averages.  See
5335    @{tf.moving_average_variables}
5336    for more details.
5337  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
5338    construction.
5339
5340  The following standard keys are _defined_, but their collections are **not**
5341  automatically populated as many of the others are:
5342
5343  * `WEIGHTS`
5344  * `BIASES`
5345  * `ACTIVATIONS`
5346  """
5347
5348  # Key to collect Variable objects that are global (shared across machines).
5349  # Default collection for all variables, except local ones.
5350  GLOBAL_VARIABLES = "variables"
5351  # Key to collect local variables that are local to the machine and are not
5352  # saved/restored.
5353  LOCAL_VARIABLES = "local_variables"
5354  # Key to collect local variables which are used to accumulate interal state
5355  # to be used in tf.metrics.*.
5356  METRIC_VARIABLES = "metric_variables"
5357  # Key to collect model variables defined by layers.
5358  MODEL_VARIABLES = "model_variables"
5359  # Key to collect Variable objects that will be trained by the
5360  # optimizers.
5361  TRAINABLE_VARIABLES = "trainable_variables"
5362  # Key to collect summaries.
5363  SUMMARIES = "summaries"
5364  # Key to collect QueueRunners.
5365  QUEUE_RUNNERS = "queue_runners"
5366  # Key to collect table initializers.
5367  TABLE_INITIALIZERS = "table_initializer"
5368  # Key to collect asset filepaths. An asset represents an external resource
5369  # like a vocabulary file.
5370  ASSET_FILEPATHS = "asset_filepaths"
5371  # Key to collect Variable objects that keep moving averages.
5372  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
5373  # Key to collect regularization losses at graph construction.
5374  REGULARIZATION_LOSSES = "regularization_losses"
5375  # Key to collect concatenated sharded variables.
5376  CONCATENATED_VARIABLES = "concatenated_variables"
5377  # Key to collect savers.
5378  SAVERS = "savers"
5379  # Key to collect weights
5380  WEIGHTS = "weights"
5381  # Key to collect biases
5382  BIASES = "biases"
5383  # Key to collect activations
5384  ACTIVATIONS = "activations"
5385  # Key to collect update_ops
5386  UPDATE_OPS = "update_ops"
5387  # Key to collect losses
5388  LOSSES = "losses"
5389  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
5390  SAVEABLE_OBJECTS = "saveable_objects"
5391  # Key to collect all shared resources used by the graph which need to be
5392  # initialized once per cluster.
5393  RESOURCES = "resources"
5394  # Key to collect all shared resources used in this graph which need to be
5395  # initialized once per session.
5396  LOCAL_RESOURCES = "local_resources"
5397  # Trainable resource-style variables.
5398  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
5399
5400  # Key to indicate various ops.
5401  INIT_OP = "init_op"
5402  LOCAL_INIT_OP = "local_init_op"
5403  READY_OP = "ready_op"
5404  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
5405  SUMMARY_OP = "summary_op"
5406  GLOBAL_STEP = "global_step"
5407
5408  # Used to count the number of evaluations performed during a single evaluation
5409  # run.
5410  EVAL_STEP = "eval_step"
5411  TRAIN_OP = "train_op"
5412
5413  # Key for control flow context.
5414  COND_CONTEXT = "cond_context"
5415  WHILE_CONTEXT = "while_context"
5416
5417  # Used to store v2 summary names.
5418  _SUMMARY_COLLECTION = "_SUMMARY_V2"
5419
5420  # List of all collections that keep track of variables.
5421  _VARIABLE_COLLECTIONS = [
5422      GLOBAL_VARIABLES,
5423      LOCAL_VARIABLES,
5424      METRIC_VARIABLES,
5425      MODEL_VARIABLES,
5426      TRAINABLE_VARIABLES,
5427      MOVING_AVERAGE_VARIABLES,
5428      CONCATENATED_VARIABLES,
5429      TRAINABLE_RESOURCE_VARIABLES,
5430  ]
5431
5432  # Key for streaming model ports.
5433  # NOTE(yuanbyu): internal and experimental.
5434  _STREAMING_MODEL_PORTS = "streaming_model_ports"
5435
5436  @decorator_utils.classproperty
5437  def VARIABLES(cls):  # pylint: disable=no-self-argument
5438    logging.log_first_n(logging.WARN,
5439                        "VARIABLES collection name is deprecated, please use "
5440                        "GLOBAL_VARIABLES instead; VARIABLES will be removed "
5441                        "after 2017-03-02.", 1)
5442    return cls.GLOBAL_VARIABLES
5443
5444
5445@tf_export("add_to_collection")
5446def add_to_collection(name, value):
5447  """Wrapper for `Graph.add_to_collection()` using the default graph.
5448
5449  See @{tf.Graph.add_to_collection}
5450  for more details.
5451
5452  Args:
5453    name: The key for the collection. For example, the `GraphKeys` class
5454      contains many standard names for collections.
5455    value: The value to add to the collection.
5456
5457  @compatibility(eager)
5458  Collections are not supported when eager execution is enabled.
5459  @end_compatibility
5460  """
5461  get_default_graph().add_to_collection(name, value)
5462
5463
5464def add_to_collections(names, value):
5465  """Wrapper for `Graph.add_to_collections()` using the default graph.
5466
5467  See @{tf.Graph.add_to_collections}
5468  for more details.
5469
5470  Args:
5471    names: The key for the collections. The `GraphKeys` class
5472      contains many standard names for collections.
5473    value: The value to add to the collections.
5474
5475  @compatibility(eager)
5476  Collections are not supported when eager execution is enabled.
5477  @end_compatibility
5478  """
5479  get_default_graph().add_to_collections(names, value)
5480
5481
5482@tf_export("get_collection_ref")
5483def get_collection_ref(key):
5484  """Wrapper for `Graph.get_collection_ref()` using the default graph.
5485
5486  See @{tf.Graph.get_collection_ref}
5487  for more details.
5488
5489  Args:
5490    key: The key for the collection. For example, the `GraphKeys` class
5491      contains many standard names for collections.
5492
5493  Returns:
5494    The list of values in the collection with the given `name`, or an empty
5495    list if no value has been added to that collection.  Note that this returns
5496    the collection list itself, which can be modified in place to change the
5497    collection.
5498
5499  @compatibility(eager)
5500  Collections are not supported when eager execution is enabled.
5501  @end_compatibility
5502  """
5503  return get_default_graph().get_collection_ref(key)
5504
5505
5506@tf_export("get_collection")
5507def get_collection(key, scope=None):
5508  """Wrapper for `Graph.get_collection()` using the default graph.
5509
5510  See @{tf.Graph.get_collection}
5511  for more details.
5512
5513  Args:
5514    key: The key for the collection. For example, the `GraphKeys` class
5515      contains many standard names for collections.
5516    scope: (Optional.) If supplied, the resulting list is filtered to include
5517      only items whose `name` attribute matches using `re.match`. Items
5518      without a `name` attribute are never returned if a scope is supplied and
5519      the choice or `re.match` means that a `scope` without special tokens
5520      filters by prefix.
5521
5522  Returns:
5523    The list of values in the collection with the given `name`, or
5524    an empty list if no value has been added to that collection. The
5525    list contains the values in the order under which they were
5526    collected.
5527
5528  @compatibility(eager)
5529  Collections are not supported when eager execution is enabled.
5530  @end_compatibility
5531  """
5532  return get_default_graph().get_collection(key, scope)
5533
5534
5535def get_all_collection_keys():
5536  """Returns a list of collections used in the default graph."""
5537  return get_default_graph().get_all_collection_keys()
5538
5539
5540name_scope_cache = {}
5541
5542
5543# Named like a function for backwards compatibility with the
5544# @tf_contextlib.contextmanager version, which was switched to a class to avoid
5545# some object creation overhead.
5546@tf_export("name_scope", "keras.backend.name_scope")
5547class name_scope(object):  # pylint: disable=invalid-name
5548  """A context manager for use when defining a Python op.
5549
5550  This context manager validates that the given `values` are from the
5551  same graph, makes that graph the default graph, and pushes a
5552  name scope in that graph (see
5553  @{tf.Graph.name_scope}
5554  for more details on that).
5555
5556  For example, to define a new Python op called `my_op`:
5557
5558  ```python
5559  def my_op(a, b, c, name=None):
5560    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
5561      a = tf.convert_to_tensor(a, name="a")
5562      b = tf.convert_to_tensor(b, name="b")
5563      c = tf.convert_to_tensor(c, name="c")
5564      # Define some computation that uses `a`, `b`, and `c`.
5565      return foo_op(..., name=scope)
5566  ```
5567  """
5568
5569  @property
5570  def name(self):
5571    return self._name
5572
5573  def __init__(self, name, default_name=None, values=None):
5574    """Initialize the context manager.
5575
5576    Args:
5577      name: The name argument that is passed to the op function.
5578      default_name: The default name to use if the `name` argument is `None`.
5579      values: The list of `Tensor` arguments that are passed to the op function.
5580    """
5581    self._name = default_name if name is None else name
5582    self._default_name = default_name
5583    self._values = values
5584    self._ctx = context.context()
5585    self._in_eager_mode = self._ctx.in_eager_mode()
5586
5587  def __enter__(self):
5588    """Start the scope block.
5589
5590    Returns:
5591      The scope name.
5592
5593    Raises:
5594      ValueError: if neither `name` nor `default_name` is provided
5595        but `values` are.
5596    """
5597    if self._in_eager_mode:
5598      self._old_name = self._ctx.scope_name
5599      if not self._name:
5600        scope_name = ""
5601      else:
5602        cache_key = self._name, self._old_name, self._default_name
5603        if cache_key in name_scope_cache:
5604          self._ctx.scope_name = name_scope_cache[cache_key]
5605          return self._ctx.scope_name
5606        elif self._name[-1] == "/":
5607          # A trailing slash breaks out of nested name scopes, indicating a
5608          # fully specified scope name, for compatibility with Graph.name_scope.
5609          scope_name = self._name
5610        else:
5611          name_with_trailing_slash = self._name + "/"
5612          scope_name = (
5613              self._old_name + name_with_trailing_slash
5614              if self._old_name else name_with_trailing_slash)
5615        name_scope_cache[cache_key] = scope_name
5616      self._ctx.scope_name = scope_name
5617      return scope_name
5618    else:
5619      if self._name is None and self._values is not None:
5620        # We only raise an error if values is not None (provided) because
5621        # currently tf.name_scope(None) (values=None then) is sometimes used as
5622        # an idiom to reset to top scope.
5623        raise ValueError(
5624            "At least one of name (%s) and default_name (%s) must be provided."
5625            % (self._name, self._default_name))
5626      if self._values is None:
5627        self._values = []
5628      g = _get_graph_from_inputs(self._values)
5629      self._g_manager = g.as_default()
5630      self._g_manager.__enter__()
5631      try:
5632        self._name_scope = g.name_scope(self._name)
5633        return self._name_scope.__enter__()
5634      except:
5635        self._g_manager.__exit__(*sys.exc_info())
5636        raise
5637
5638  def __exit__(self, type_arg, value_arg, traceback_arg):
5639    if self._in_eager_mode:
5640      self._ctx.scope_name = self._old_name
5641    else:
5642      self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
5643      self._g_manager.__exit__(type_arg, value_arg, traceback_arg)
5644    return False  # False values do not suppress exceptions
5645
5646
5647def strip_name_scope(name, export_scope):
5648  """Removes name scope from a name.
5649
5650  Args:
5651    name: A `string` name.
5652    export_scope: Optional `string`. Name scope to remove.
5653
5654  Returns:
5655    Name with name scope removed, or the original name if export_scope
5656    is None.
5657  """
5658  if export_scope:
5659    try:
5660      # Strips export_scope/, export_scope///,
5661      # ^export_scope/, loc:@export_scope/.
5662      str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
5663      return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
5664    except TypeError as e:
5665      # If the name is not of a type we can process, simply return it.
5666      logging.warning(e)
5667      return name
5668  else:
5669    return name
5670
5671
5672def prepend_name_scope(name, import_scope):
5673  """Prepends name scope to a name.
5674
5675  Args:
5676    name: A `string` name.
5677    import_scope: Optional `string`. Name scope to add.
5678
5679  Returns:
5680    Name with name scope added, or the original name if import_scope
5681    is None.
5682  """
5683  if import_scope:
5684    try:
5685      str_to_replace = r"([\^]|loc:@|^)(.*)"
5686      return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
5687                    compat.as_str(name))
5688    except TypeError as e:
5689      # If the name is not of a type we can process, simply return it.
5690      logging.warning(e)
5691      return name
5692  else:
5693    return name
5694
5695
5696# pylint: disable=g-doc-return-or-yield
5697# pylint: disable=not-context-manager
5698@tf_export("op_scope")
5699@tf_contextlib.contextmanager
5700def op_scope(values, name, default_name=None):
5701  """DEPRECATED. Same as name_scope above, just different argument order."""
5702  logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
5703               " use tf.name_scope(name, default_name, values)")
5704  with name_scope(name, default_name=default_name, values=values) as scope:
5705    yield scope
5706
5707
5708_proto_function_registry = registry.Registry("proto functions")
5709
5710
5711def register_proto_function(collection_name,
5712                            proto_type=None,
5713                            to_proto=None,
5714                            from_proto=None):
5715  """Registers `to_proto` and `from_proto` functions for collection_name.
5716
5717  `to_proto` function converts a Python object to the corresponding protocol
5718  buffer, and returns the protocol buffer.
5719
5720  `from_proto` function converts protocol buffer into a Python object, and
5721  returns the object..
5722
5723  Args:
5724    collection_name: Name of the collection.
5725    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
5726      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`..
5727    to_proto: Function that implements Python object to protobuf conversion.
5728    from_proto: Function that implements protobuf to Python object conversion.
5729  """
5730  if to_proto and not callable(to_proto):
5731    raise TypeError("to_proto must be callable.")
5732  if from_proto and not callable(from_proto):
5733    raise TypeError("from_proto must be callable.")
5734
5735  _proto_function_registry.register((proto_type, to_proto, from_proto),
5736                                    collection_name)
5737
5738
5739def get_collection_proto_type(collection_name):
5740  """Returns the proto_type for collection_name."""
5741  try:
5742    return _proto_function_registry.lookup(collection_name)[0]
5743  except LookupError:
5744    return None
5745
5746
5747def get_to_proto_function(collection_name):
5748  """Returns the to_proto function for collection_name."""
5749  try:
5750    return _proto_function_registry.lookup(collection_name)[1]
5751  except LookupError:
5752    return None
5753
5754
5755def get_from_proto_function(collection_name):
5756  """Returns the from_proto function for collection_name."""
5757  try:
5758    return _proto_function_registry.lookup(collection_name)[2]
5759  except LookupError:
5760    return None
5761
5762
5763def _assert_collection_is_ok(collection_name):
5764  if context.in_eager_mode():
5765    if collection_name in GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
5766      raise ValueError("When Eager Execution is enabled, variable "
5767                       "collections are not supported.")
5768
5769
5770def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
5771  """Produce a nice error if someone converts an Operation to a Tensor."""
5772  raise TypeError(("Can't convert Operation '%s' to Tensor "
5773                   "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
5774                                                               name, as_ref))
5775
5776
5777register_tensor_conversion_function(Operation, _operation_conversion_error)
5778