1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to use variables as resources."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import contextlib
23import functools
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import variable_pb2
27from tensorflow.python import pywrap_tensorflow
28from tensorflow.python.eager import context
29from tensorflow.python.eager import tape
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import cpp_shape_inference_pb2
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import gen_array_ops
37from tensorflow.python.ops import gen_resource_variable_ops
38from tensorflow.python.ops import gen_state_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variables
42# go/tf-wildcard-import
43# pylint: disable=wildcard-import
44from tensorflow.python.ops.gen_resource_variable_ops import *
45# pylint: enable=wildcard-import
46from tensorflow.python.training.tracking import base as trackable
47from tensorflow.python.util import compat
48from tensorflow.python.util.deprecation import deprecated
49
50
def get_resource_handle_data(graph_op):
  """Fetches the shape-inference handle data recorded for a graph tensor.

  Args:
    graph_op: An `ops.Tensor`. The exact type is required; instances of
      subclasses fail the assertion below.

  Returns:
    A `CppShapeInferenceResult.HandleData` proto parsed from the serialized
    handle data reported for `graph_op`.
  """
  # Deliberately an exact type comparison rather than isinstance().
  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck

  # pylint: disable=protected-access
  c_graph = graph_op.graph._c_graph
  tf_output = graph_op._as_tf_output()
  # pylint: enable=protected-access
  serialized = pywrap_tensorflow.GetHandleShapeAndType(c_graph, tf_output)

  handle_data_cls = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
  return handle_data_cls.FromString(compat.as_bytes(serialized))
59
60
def get_eager_safe_handle_data(handle):
  """Returns the handle data of `handle` for both eager and graph tensors.

  Args:
    handle: An `ops.Tensor` (graph tensor or `ops.EagerTensor`).

  Returns:
    For an `EagerTensor`, whatever is stored in its `_handle_data` attribute;
    otherwise the result of `get_resource_handle_data(handle)`.
  """
  assert isinstance(handle, ops.Tensor)

  if not isinstance(handle, ops.EagerTensor):
    # Graph tensor: query the handle data through the graph.
    return get_resource_handle_data(handle)
  return handle._handle_data  # pylint: disable=protected-access
69
70
71def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
72  """Sets the shape inference result HandleData on tensor.
73
74  Args:
75    tensor: A `Tensor` or `EagerTensor`.
76    handle_data: A `CppShapeInferenceResult.HandleData`.
77    graph_mode: A python bool.
78  """
79  tensor._handle_data = handle_data  # pylint: disable=protected-access
80  if not graph_mode:
81    return
82
83  # Not an EagerTensor, so a graph tensor.
84  shapes, types = zip(*[(pair.shape, pair.dtype)
85                        for pair in handle_data.shape_and_type])
86  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
87  shapes = [[d.size for d in s.dim]
88            if not s.unknown_rank else None for s in shapes]
89  pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
90      tensor._op._graph._c_graph,  # pylint: disable=protected-access
91      tensor._as_tf_output(),  # pylint: disable=protected-access
92      shapes, ranks, types)
93
94
def _combine_handle_data(handle, initial_value):
  """Merges the HandleData of `handle` with that of a variant `initial_value`.

  Args:
    handle: A `Tensor` of dtype `resource`.
    initial_value: A `Tensor`.

  Returns:
    A `CppShapeInferenceResult.HandleData`.  When `initial_value` is not of
    dtype `variant`, or carries no set handle data, this is just the handle
    data of `handle`.  Otherwise the `shape_and_type` entries of
    `initial_value` are appended to those of `handle`.

  Raises:
    RuntimeError: If `initial_value` is a variant with set handle data but
      `handle`, which was returned by VarHandleOp, either has no handle data
      or its len(handle_data.shape_and_type) != 1.
  """
  assert handle.dtype == dtypes.resource

  combined = get_eager_safe_handle_data(handle)
  if initial_value.dtype != dtypes.variant:
    return combined

  variant_data = get_eager_safe_handle_data(initial_value)
  if variant_data is None or not variant_data.is_set:
    # Nothing to append.
    return combined

  handle_has_single_entry = (
      combined is not None
      and combined.is_set
      and len(combined.shape_and_type) == 1)
  if not handle_has_single_entry:
    raise RuntimeError(
        "Expected VarHandleOp to return a length==1 shape_and_type, "
        "but saw: '%s'" % (combined,))

  combined.shape_and_type.extend(variant_data.shape_and_type)
  return combined
129
130
def eager_safe_variable_handle(initial_value, shared_name, name, graph_mode):
  """Builds a variable handle that carries shape-inference information.

  The shape and dtype are taken from `initial_value` and recorded in the handle
  data of the returned resource tensor.

  When `initial_value.dtype == tf.variant`, the handle data (if any) of
  `initial_value` is appended as well, so the returned tensor's handle data
  looks like

  ```
  is_set: true
  shape_and_type {
    shape {
      // initial_value.shape
    }
    dtype: DT_VARIANT
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[0]
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[1]
  }
  ...
  ```

  Ops that read from this tensor, such as `ReadVariableOp` and
  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
  correspond to the handle data of the variant(s) stored in the Variable.

  Args:
    initial_value: A `Tensor`.
    shared_name: A string.
    name: A string.
    graph_mode: A python bool.

  Returns:
    The handle, a `Tensor` of type `resource`.

  Raises:
    ValueError: In eager mode, if the resource behind the new handle is
      already initialized.
  """
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  handle_kwargs = dict(
      shape=initial_value.get_shape(),
      dtype=initial_value.dtype.base_dtype,
      shared_name=shared_name,
      name=name,
      container="" if container is None else container)
  handle = gen_resource_variable_ops.var_handle_op(**handle_kwargs)

  if graph_mode:
    _set_handle_shapes_and_types(
        handle, _combine_handle_data(handle, initial_value), graph_mode)
    return handle

  # We do not want two distinct ResourceVariable objects for the same
  # underlying resource in the runtime.
  # When in eager mode, explicitly ensure so here. When in graph mode, it's
  # ensured by always generating different variable names.
  already_exists = gen_resource_variable_ops.var_is_initialized_op(handle)
  if already_exists:
    raise ValueError("variable object with name '%s' already created. Use "
                     "get_variable() if reuse is desired." %
                     shared_name)

  with context.graph_mode(), ops.Graph().as_default() as scratch_graph:
    graph_handle = gen_resource_variable_ops.var_handle_op(**handle_kwargs)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode we copy this data here for
    # when the handle is captured by an eager mode function.
    _set_handle_shapes_and_types(
        handle, _combine_handle_data(graph_handle, initial_value), graph_mode)

  # Clean up op->graph->op reference cycles.
  ops.dismantle_graph(scratch_graph)
  return handle
212
213
@contextlib.contextmanager
def _handle_graph(handle):
  """Context manager that enters `handle.graph` only when it is needed.

  The handle's graph is made the default graph only if we are not executing
  eagerly, `handle` is not an `EagerTensor`, and there is no default graph
  already.  In every other case the body runs in the current context.

  Args:
    handle: The variable handle `Tensor`.

  Yields:
    None.
  """
  # Note: might have an eager tensor but not be executing eagerly when building
  # functions.
  if (not context.executing_eagerly()
      and not isinstance(handle, ops.EagerTensor)
      and not ops.has_default_graph()):
    with handle.graph.as_default():
      yield
    return
  yield
224
225
class EagerResourceDeleter(object):
  """Destroys the resource behind a handle when this object is collected.

  This replaces a `__del__` method on the owning object. A ResourceVariable
  (or any other object owning a resource handle) is expected to hold the only
  reference to its deleter, so the deleter is collected together with its
  owner. Because the owner itself defines no `__del__`, it stays collectable
  even when it is part of a reference cycle.
  """

  def __init__(self, handle, handle_device):
    """Stores the handle and the device on which to destroy it.

    Args:
      handle: The resource handle `Tensor`.
      handle_device: Device passed to `ops.device` when destroying the
        resource.

    Raises:
      ValueError: If `handle` is not an `ops.Tensor`.
    """
    if not isinstance(handle, ops.Tensor):
      message = (
          "Passed handle=%s to EagerResourceDeleter. Was expecting a handle "
          "Tensor." % (handle,))
      raise ValueError(message)
    self._handle_device = handle_device
    self._handle = handle

  def __del__(self):
    # Resources follow object-identity when executing eagerly, so it is safe to
    # delete the resource we have a handle to.
    try:
      # This resource was created in eager mode. However, this destructor may be
      # running in graph mode (especially during unit tests). To clean up
      # successfully, we switch back into eager mode temporarily.
      with context.eager_mode(), ops.device(self._handle_device):
        gen_resource_variable_ops.destroy_resource_op(
            self._handle, ignore_lookup_error=True)
    except (TypeError, AttributeError):
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      #
      # TypeError: 'NoneType' object is not callable when the handle has been
      # partially unloaded.
      # AttributeError: 'NoneType' object has no attribute 'eager_mode' when
      # context has been unloaded. Will catch other module unloads as well.
      pass
266
267
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
  """Assigns `value` to the variable at `handle` after a shape check.

  Args:
    handle: The variable's resource handle `Tensor`.
    shape: The variable's shape; must provide `assert_is_compatible_with`.
    value: The value to assign; converted to a `Tensor` first.
    name: Optional name for the assign op.

  Returns:
    The result of `assign_variable_op`.
  """
  # Only the conversion happens under the handle's graph context.
  with _handle_graph(handle):
    converted = ops.convert_to_tensor(value)
  shape.assert_is_compatible_with(converted.shape)
  assign_op = gen_resource_variable_ops.assign_variable_op(
      handle, converted, name=name)
  return assign_op
276
277
def _maybe_set_handle_data(dtype, handle, tensor):
  """Copies the stored variant's handle data from `handle` onto `tensor`.

  Only acts for `variant` variables.  For those, `shape_and_type[1:]` of the
  handle's handle data describes the variant(s) stored in the variable (see
  `eager_safe_variable_handle`); that tail is wrapped in a new `HandleData`
  and stored on `tensor._handle_data`.  In all other cases `tensor` is left
  untouched.

  Args:
    dtype: The dtype of the variable behind `handle`.
    handle: The variable's resource handle `Tensor`.
    tensor: The `Tensor` whose `_handle_data` may be set in place.
  """
  if dtype == dtypes.variant:
    # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
    # variant's handle data.  Extract it.
    handle_data = get_eager_safe_handle_data(handle)
    # The helper can yield None (`_combine_handle_data` guards against the
    # same case), so check before touching `.is_set`.
    if handle_data is None or not handle_data.is_set:
      return
    if len(handle_data.shape_and_type) > 1:
      tensor._handle_data = (  # pylint: disable=protected-access
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
              is_set=True,
              shape_and_type=handle_data.shape_and_type[1:]))
288
289
290class ResourceVariable(variables.VariableV1):
291  """Variable based on resource handles.
292
293  See the [Variables How To](https://tensorflow.org/guide/variables)
294  for a high level overview.
295
296  A `ResourceVariable` allows you to maintain state across subsequent calls to
297  session.run.
298
299  The `ResourceVariable` constructor requires an initial value for the variable,
300  which can be a `Tensor` of any type and shape. The initial value defines the
301  type and shape of the variable. After construction, the type and shape of
302  the variable are fixed. The value can be changed using one of the assign
303  methods.
304
305  Just like any `Tensor`, variables created with
306  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
307  graph. Additionally, all the operators overloaded for the `Tensor` class are
308  carried over to variables, so you can also add nodes to the graph by just
309  doing arithmetic on variables.
310
311  Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each
312  usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
313  to the graph. The Tensors returned by a read_value operation are guaranteed to
314  see all modifications to the value of the variable which happen in any
  operation on which the read_value depends (either directly, indirectly, or
316  via a control dependency) and guaranteed to not see any modification to the
317  value of the variable from operations that depend on the read_value operation.
318  Updates from operations that have no dependency relationship to the read_value
319  operation might or might not be visible to read_value.
320
321  For example, if there is more than one assignment to a ResourceVariable in
322  a single session.run call there is a well-defined value for each operation
323  which uses the variable's value if the assignments and the read are connected
324  by edges in the graph. Consider the following example, in which two writes
325  can cause tf.Variable and tf.ResourceVariable to behave differently:
326
327  ```python
328  a = tf.Variable(1.0, use_resource=True)
329  a.initializer.run()
330
331  assign = a.assign(2.0)
332  with tf.control_dependencies([assign]):
333    b = a.read_value()
334  with tf.control_dependencies([b]):
335    other_assign = a.assign(3.0)
336  with tf.control_dependencies([other_assign]):
337    # Will print 2.0 because the value was read before other_assign ran. If
338    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
339    tf.Print(b, [b]).eval()
340  ```
341  """
342
343  def __init__(self,
344               initial_value=None,
345               trainable=True,
346               collections=None,
347               validate_shape=True,  # pylint: disable=unused-argument
348               caching_device=None,
349               name=None,
350               dtype=None,
351               variable_def=None,
352               import_scope=None,
353               constraint=None,
354               distribute_strategy=None):
355    """Creates a variable.
356
357    Args:
358      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
359        which is the initial value for the Variable. Can also be a
360        callable with no argument that returns the initial value when called.
361        (Note that initializer functions from init_ops.py must first be bound
362         to a shape before being used here.)
363      trainable: If `True`, the default, also adds the variable to the graph
364        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
365        the default list of variables to use by the `Optimizer` classes.
366      collections: List of graph collections keys. The new variable is added to
367        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
368      validate_shape: Ignored. Provided for compatibility with tf.Variable.
369      caching_device: Optional device string or function describing where the
370        Variable should be cached for reading.  Defaults to the Variable's
371        device.  If not `None`, caches on another device.  Typical use is to
372        cache on the device where the Ops using the Variable reside, to
373        deduplicate copying through `Switch` and other conditional statements.
374      name: Optional name for the variable. Defaults to `'Variable'` and gets
375        uniquified automatically.
376      dtype: If set, initial_value will be converted to the given type.
377        If None, either the datatype will be kept (if initial_value is
378        a Tensor) or float32 will be used (if it is a Python object convertible
379        to a Tensor).
380      variable_def: `VariableDef` protocol buffer. If not None, recreates the
381        `ResourceVariable` object with its contents. `variable_def` and other
382        arguments (except for import_scope) are mutually exclusive.
383      import_scope: Optional `string`. Name scope to add to the
384        ResourceVariable. Only used when `variable_def` is provided.
385      constraint: An optional projection function to be applied to the variable
386        after being updated by an `Optimizer` (e.g. used to implement norm
387        constraints or value constraints for layer weights). The function must
388        take as input the unprojected Tensor representing the value of the
389        variable and return the Tensor for the projected value
390        (which must have the same shape). Constraints are not safe to
391        use when doing asynchronous distributed training.
392      distribute_strategy: The tf.distribute.Strategy this variable is being
393        created inside of.
394
395    Raises:
396      ValueError: If the initial value is not specified, or does not have a
397        shape and `validate_shape` is `True`.
398
399    @compatibility(eager)
400    When Eager Execution is enabled, the default for the `collections` argument
401    is `None`, which signifies that this `Variable` will not be added to any
402    collections.
403    @end_compatibility
404    """
405    self._distribute_strategy = distribute_strategy
406    if variable_def:
407      if initial_value is not None:
408        raise ValueError("variable_def and initial_value are mutually "
409                         "exclusive.")
410      if context.executing_eagerly():
411        raise ValueError("Creating ResourceVariable from variable_def is "
412                         "not supported when eager execution is enabled.")
413      self._init_from_proto(variable_def, import_scope=import_scope)
414    else:
415      self._init_from_args(
416          initial_value=initial_value,
417          trainable=trainable,
418          collections=collections,
419          caching_device=caching_device,
420          name=name,
421          dtype=dtype,
422          constraint=constraint)
423
424  def __repr__(self):
425    if context.executing_eagerly() and not self._in_graph_mode:
426      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
427          self.name, self.get_shape(), self.dtype.name,
428          ops.numpy_text(self.read_value(), is_repr=True))
429    else:
430      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
431          self.name, self.get_shape(), self.dtype.name)
432
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None):
    """Creates a variable from an initial value.

    Creates the variable handle and, in graph mode, the is-initialized,
    initializer and read ops.  When executing eagerly the initial value is
    assigned immediately and a deleter for the resource is installed.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If `initial_value` is not specified; if it is a `Tensor`
        created while building a function; if `collections` is not a list,
        tuple or set; if `constraint` is given but not callable; or, in graph
        mode, if the initializer comes from inside a control-flow construct.

    @compatibility(eager)
    When Eager Execution is enabled, the variable is not added to the
    collections passed in `collections` (nor implicitly to `GLOBAL_VARIABLES`
    or `TRAINABLE_VARIABLES`); the only exception handled here is
    `GraphKeys.GLOBAL_STEP`, which is honored when present in `collections`.
    @end_compatibility
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    # A callable initial_value is invoked later, inside the "Initializer" name
    # scope.
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError("Tensor-typed variable initializers must either be "
                       "wrapped in an init_scope or callable "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`) when building "
                       "functions. Please file a feature request if this "
                       "restriction inconveniences you.")

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    # Unwrap checkpoint-provided initial values, remembering the restore uid.
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      # Graph vs. eager mode is decided inside init_scope, not by the context
      # this constructor was called from.
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          # In eager mode the unique_id gets a uid suffix and the shared_name
          # is taken from the eager context, to prevent accidental sharing.
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = context.shared_name()
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        device_context_manager = (
            ops.device if self._in_graph_mode else ops.NullContextmanager)
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % handle_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), device_context_manager(None):
            initial_value = ops.convert_to_tensor(
                initial_value() if init_from_fn else initial_value,
                name="initial_value", dtype=dtype)
          self._handle = eager_safe_variable_handle(
              initial_value=initial_value,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
        self._shape = initial_value.shape
        # pylint: disable=protected-access
        if (self._in_graph_mode and initial_value is not None and
            initial_value.op._get_control_flow_context() is not None):
          raise ValueError(
              "Initializer for variable %s is from inside a control-flow "
              "construct, such as a loop or conditional. When creating a "
              "variable inside a loop or conditional, use a lambda as the "
              "initializer." % name)
        # pylint: enable=protected-access
        self._unique_id = unique_id
        # The initial value tensor is only retained in graph mode.
        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          # Graph mode: build the is-initialized, initializer and read ops.
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              # pylint: disable=protected-access
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      variables._try_guard_against_uninitialized_dependencies(
                          name,
                          initial_value),
                      name=n))
              # pylint: enable=protected-access
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle.device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec.  Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          # Eager mode: assign the initial value right away; no initializer,
          # is-initialized op or graph element is kept.
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if not context.executing_eagerly():
          # Eager variables are only added to collections if they are part of an
          # eager variable store (otherwise in an interactive session they would
          # hog memory and cause OOM). This is done in ops/variable_scope.py.
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          # Exception to the rule above: the global step collection is still
          # populated when executing eagerly.
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so that
      # when this object is garbage collected the deleter will be too. This
      # means ResourceVariables can be part of reference cycles without those
      # cycles being uncollectable, and means that no __del__ will be defined at
      # all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle.device)
630
631  def _init_from_proto(self, variable_def, import_scope=None):
632    """Initializes from `VariableDef` proto."""
633    # Note that init_from_proto is currently not supported in Eager mode.
634    assert not context.executing_eagerly()
635    self._in_graph_mode = True
636    assert isinstance(variable_def, variable_pb2.VariableDef)
637    if not variable_def.is_resource:
638      raise ValueError("Trying to restore Variable as ResourceVariable.")
639
640    # Create from variable_def.
641    g = ops.get_default_graph()
642    self._handle = g.as_graph_element(
643        ops.prepend_name_scope(
644            variable_def.variable_name, import_scope=import_scope))
645    self._shape = tensor_shape.TensorShape(
646        self._handle.op.get_attr("shape"))
647    self._handle_name = self._handle.name
648    self._unique_id = self._handle_name
649    self._initializer_op = g.as_graph_element(
650        ops.prepend_name_scope(
651            variable_def.initializer_name, import_scope=import_scope))
652    # Check whether initial_value_name exists for backwards compatibility.
653    if (hasattr(variable_def, "initial_value_name") and
654        variable_def.initial_value_name):
655      self._initial_value = g.as_graph_element(
656          ops.prepend_name_scope(variable_def.initial_value_name,
657                                 import_scope=import_scope))
658    else:
659      self._initial_value = None
660    self._trainable = getattr(variable_def, "trainable", True)
661    if variable_def.snapshot_name:
662      snapshot = g.as_graph_element(
663          ops.prepend_name_scope(
664              variable_def.snapshot_name, import_scope=import_scope))
665      if snapshot.op.type != "ReadVariableOp":
666        self._cached_value = snapshot
667      else:
668        self._cached_value = None
669      while snapshot.op.type != "ReadVariableOp":
670        snapshot = snapshot.op.inputs[0]
671      self._graph_element = snapshot
672    else:
673      self._cached_value = None
674      # Legacy case for protos without the snapshot name; assume it's the
675      # following.
676      self._graph_element = g.get_tensor_by_name(
677          self._handle.op.name + "/Read/ReadVariableOp:0")
678    if variable_def.HasField("save_slice_info_def"):
679      self._save_slice_info = variables.Variable.SaveSliceInfo(
680          save_slice_info_def=variable_def.save_slice_info_def,
681          import_scope=import_scope)
682    else:
683      self._save_slice_info = None
684    self._caching_device = None
685    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
686    self._constraint = None
687
688  @contextlib.contextmanager
689  def _assign_dependencies(self):
690    """Makes assignments depend on the cached value, if any.
691
692    This prevents undefined behavior with reads not ordered wrt writes.
693
694    Yields:
695      None.
696    """
697    if self._cached_value is not None:
698      with ops.control_dependencies([self._cached_value]):
699        yield
700    else:
701      yield
702
703  def __nonzero__(self):
704    return self.__bool__()
705
706  def __bool__(self):
707    return bool(self.read_value())
708
709  def __copy__(self):
710    return self
711
712  def __deepcopy__(self, memo):
713    if not context.executing_eagerly():
714      raise NotImplementedError(
715          "__deepcopy__() is only available when eager execution is enabled.")
716    copied_variable = ResourceVariable(
717        initial_value=self.read_value(),
718        trainable=self._trainable,
719        constraint=self._constraint,
720        dtype=self._dtype,
721        name=self._shared_name + "_copy",
722        distribute_strategy=self._distribute_strategy)
723    memo[self._unique_id] = copied_variable
724    return copied_variable
725
726  @property
727  def dtype(self):
728    """The dtype of this variable."""
729    return self._dtype
730
731  @property
732  def device(self):
733    """The device this variable is on."""
734    return self._handle.device
735
736  @property
737  def graph(self):
738    """The `Graph` of this variable."""
739    return self._handle.graph
740
741  @property
742  def name(self):
743    """The name of the handle for this variable."""
744    return self._handle_name
745
746  @property
747  def shape(self):
748    """The shape of this variable."""
749    return self._shape
750
751  def _shape_as_list(self):
752    if self.shape.ndims is None:
753      return None
754    return [dim.value for dim in self.shape.dims]
755
756  def _shape_tuple(self):
757    shape = self._shape_as_list()
758    if shape is None:
759      return None
760    return tuple(shape)
761
762  @property
763  def create(self):
764    """The op responsible for initializing this variable."""
765    if not self._in_graph_mode:
766      raise RuntimeError("Calling create is not supported when eager execution"
767                         " is enabled.")
768    return self._initializer_op
769
770  @property
771  def handle(self):
772    """The handle by which this variable can be accessed."""
773    return self._handle
774
775  def value(self):
776    """A cached operation which reads the value of this variable."""
777    if self._cached_value is not None:
778      return self._cached_value
779    with ops.colocate_with(None, ignore_existing=True):
780      with ops.device(self._handle.device):
781        return self._read_variable_op()
782
783  def _as_graph_element(self):
784    """Conversion function for Graph.as_graph_element()."""
785    return self._graph_element
786
787  @property
788  def initializer(self):
789    """The op responsible for initializing this variable."""
790    return self._initializer_op
791
792  @property
793  def initial_value(self):
794    """Returns the Tensor used as the initial value for the variable."""
795    if context.executing_eagerly():
796      raise RuntimeError("initial_value not supported in EAGER mode.")
797    return self._initial_value
798
799  @property
800  def constraint(self):
801    """Returns the constraint function associated with this variable.
802
803    Returns:
804      The constraint function that was passed to the variable constructor.
805      Can be `None` if no constraint was passed.
806    """
807    return self._constraint
808
809  @property
810  def op(self):
811    """The op for this variable."""
812    return self._handle.op
813
814  @property
815  def trainable(self):
816    return self._trainable
817
818  def eval(self, session=None):
819    """Evaluates and returns the value of this variable."""
820    if context.executing_eagerly():
821      raise RuntimeError("Trying to eval in EAGER mode")
822    return self._graph_element.eval(session=session)
823
824  def numpy(self):
825    if context.executing_eagerly():
826      return self.read_value().numpy()
827    raise NotImplementedError(
828        "numpy() is only available when eager execution is enabled.")
829
830  @deprecated(None, "Prefer Dataset.range instead.")
831  def count_up_to(self, limit):
832    """Increments this variable until it reaches `limit`.
833
834    When that Op is run it tries to increment the variable by `1`. If
835    incrementing the variable would bring it above `limit` then the Op raises
836    the exception `OutOfRangeError`.
837
838    If no error is raised, the Op outputs the value of the variable before
839    the increment.
840
841    This is essentially a shortcut for `count_up_to(self, limit)`.
842
843    Args:
844      limit: value at which incrementing the variable raises an error.
845
846    Returns:
847      A `Tensor` that will hold the variable value before the increment. If no
848      other Op modifies this variable, the values produced will all be
849      distinct.
850    """
851    return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
852                                              T=self.dtype)
853
854  def _read_variable_op(self):
855    if self.trainable:
856      tape.variable_accessed(self)
857    result = gen_resource_variable_ops.read_variable_op(self._handle,
858                                                        self._dtype)
859    _maybe_set_handle_data(self._dtype, self._handle, result)
860
861    if not context.executing_eagerly():
862      # Note that if a control flow context is active the input of the read op
863      # might not actually be the handle. This line bypasses it.
864      tape.record_operation(
865          "ReadVariableOp", [result], [self._handle], lambda x: [x])
866    return result
867
868  def read_value(self):
869    """Constructs an op which reads the value of this variable.
870
871    Should be used when there are multiple reads, or when it is desirable to
872    read the value only after some condition is true.
873
874    Returns:
875     the read operation.
876    """
877    with ops.name_scope("Read"):
878      # Ensure we read the variable in the same device as the handle.
879      with ops.device(self._handle.device):
880        value = self._read_variable_op()
881    # Return an identity so it can get placed on whatever device the context
882    # specifies instead of the device where the variable is.
883    return array_ops.identity(value)
884
885  def sparse_read(self, indices, name=None):
886    """Reads the value of this variable sparsely, using `gather`."""
887    with ops.name_scope("Gather" if name is None else name) as name:
888      if self.trainable:
889        tape.variable_accessed(self)
890      value = gen_resource_variable_ops.resource_gather(
891          self._handle, indices, dtype=self._dtype, name=name)
892
893      if self._dtype == dtypes.variant:
894        # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
895        # variant's handle data.  Extract it.
896        handle_data = get_eager_safe_handle_data(self._handle)
897        if handle_data.is_set and len(handle_data.shape_and_type) > 1:
898          value._handle_data = (  # pylint: disable=protected-access
899              cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
900                  is_set=True,
901                  shape_and_type=handle_data.shape_and_type[1:]))
902
903    return array_ops.identity(value)
904
905  def to_proto(self, export_scope=None):
906    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
907
908    Args:
909      export_scope: Optional `string`. Name scope to remove.
910
911    Raises:
912      RuntimeError: If run in EAGER mode.
913
914    Returns:
915      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
916      in the specified name scope.
917    """
918    if context.executing_eagerly():
919      raise RuntimeError("to_proto not supported in EAGER mode.")
920    if export_scope is None or self.handle.name.startswith(export_scope):
921      var_def = variable_pb2.VariableDef()
922      var_def.variable_name = ops.strip_name_scope(self.handle.name,
923                                                   export_scope)
924      if self._initial_value is not None:
925        # This is inside an if-statement for backwards compatibility, since
926        # self._initial_value might be None for variables constructed from old
927        # protos.
928        var_def.initial_value_name = ops.strip_name_scope(
929            self._initial_value.name, export_scope)
930      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
931                                                      export_scope)
932      if self._cached_value is not None:
933        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
934                                                     export_scope)
935      else:
936        # Store the graph_element here
937        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
938                                                     export_scope)
939      var_def.is_resource = True
940      var_def.trainable = self.trainable
941      if self._save_slice_info:
942        var_def.save_slice_info_def.MergeFrom(
943            self._save_slice_info.to_proto(export_scope=export_scope))
944      return var_def
945    else:
946      return None
947
948  @staticmethod
949  def from_proto(variable_def, import_scope=None):
950    if context.executing_eagerly():
951      raise RuntimeError("from_proto not supported in EAGER mode.")
952    return ResourceVariable(
953        variable_def=variable_def, import_scope=import_scope)
954
955  def set_shape(self, shape):
956    """Unsupported."""
957    raise NotImplementedError("ResourceVariable does not implement set_shape()")
958
  # NOTE(review): `__array_priority__` is the attribute NumPy consults when an
  # ndarray is combined with another object in a binary operation; presumably
  # set so that this class's operator overloads take precedence -- TODO confirm.
  __array_priority__ = 100
960
961  def is_initialized(self, name=None):
962    """Checks whether a resource variable has been initialized.
963
964    Outputs boolean scalar indicating whether the tensor has been initialized.
965
966    Args:
967      name: A name for the operation (optional).
968
969    Returns:
970      A `Tensor` of type `bool`.
971    """
972    return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
973
974  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
975    """Subtracts a value from this variable.
976
977    Args:
978      delta: A `Tensor`. The value to subtract from this variable.
979      use_locking: If `True`, use locking during the operation.
980      name: The name to use for the operation.
981      read_value: A `bool`. Whether to read and return the new value of the
982          variable or not.
983
984    Returns:
985      If `read_value` is `True`, this method will return the new value of the
986      variable after the assignment has completed. Otherwise, when in graph mode
987      it will return the `Operation` that does the assignment, and when in eager
988      mode it will return `None`.
989    """
990    # TODO(apassos): this here and below is not atomic. Consider making it
991    # atomic if there's a way to do so without a performance cost for those who
992    # don't need it.
993    with _handle_graph(self.handle), self._assign_dependencies():
994      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
995          self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
996          name=name)
997    if read_value:
998      return self._lazy_read(assign_sub_op)
999    return assign_sub_op
1000
1001  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
1002    """Adds a value to this variable.
1003
1004    Args:
1005      delta: A `Tensor`. The value to add to this variable.
1006      use_locking: If `True`, use locking during the operation.
1007      name: The name to use for the operation.
1008      read_value: A `bool`. Whether to read and return the new value of the
1009          variable or not.
1010
1011    Returns:
1012      If `read_value` is `True`, this method will return the new value of the
1013      variable after the assignment has completed. Otherwise, when in graph mode
1014      it will return the `Operation` that does the assignment, and when in eager
1015      mode it will return `None`.
1016    """
1017    with _handle_graph(self.handle), self._assign_dependencies():
1018      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
1019          self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
1020          name=name)
1021    if read_value:
1022      return self._lazy_read(assign_add_op)
1023    return assign_add_op
1024
1025  def _lazy_read(self, op):
1026    if self.trainable:
1027      tape.variable_accessed(self)
1028    return _UnreadVariable(
1029        handle=self._handle, dtype=self.dtype, shape=self._shape,
1030        in_graph_mode=self._in_graph_mode,
1031        deleter=self._handle_deleter if not self._in_graph_mode else None,
1032        parent_op=op, unique_id=self._unique_id)
1033
1034  def assign(self, value, use_locking=None, name=None, read_value=True):
1035    """Assigns a new value to this variable.
1036
1037    Args:
1038      value: A `Tensor`. The new value for this variable.
1039      use_locking: If `True`, use locking during the assignment.
1040      name: The name to use for the assignment.
1041      read_value: A `bool`. Whether to read and return the new value of the
1042          variable or not.
1043
1044    Returns:
1045      If `read_value` is `True`, this method will return the new value of the
1046      variable after the assignment has completed. Otherwise, when in graph mode
1047      it will return the `Operation` that does the assignment, and when in eager
1048      mode it will return `None`.
1049    """
1050    # Note: not depending on the cached value here since this can used to
1051    # initialize the variable.
1052    with _handle_graph(self.handle):
1053      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
1054      self._shape.assert_is_compatible_with(value_tensor.shape)
1055      assign_op = gen_resource_variable_ops.assign_variable_op(
1056          self.handle, value_tensor, name=name)
1057      if read_value:
1058        return self._lazy_read(assign_op)
1059    return assign_op
1060
1061  def __reduce__(self):
1062    # The implementation mirrors that of __deepcopy__.
1063    return functools.partial(
1064        ResourceVariable,
1065        initial_value=self.numpy(),
1066        trainable=self.trainable,
1067        name=self._shared_name,
1068        dtype=self.dtype,
1069        constraint=self.constraint,
1070        distribute_strategy=self._distribute_strategy), ()
1071
1072  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
1073    """Subtracts `IndexedSlices` from this variable.
1074
1075    Args:
1076      sparse_delta: `IndexedSlices` to be subtracted from this variable.
1077      use_locking: If `True`, use locking during the operation.
1078      name: the name of the operation.
1079
1080    Returns:
1081      A `Tensor` that will hold the new value of this variable after
1082      the scattered subtraction has completed.
1083
1084    Raises:
1085      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1086    """
1087    if not isinstance(sparse_delta, ops.IndexedSlices):
1088      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1089    return self._lazy_read(gen_resource_variable_ops.resource_scatter_sub(
1090        self.handle, sparse_delta.indices,
1091        ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
1092
1093  def scatter_add(self, sparse_delta, use_locking=False, name=None):
1094    """Adds `IndexedSlices` from this variable.
1095
1096    Args:
1097      sparse_delta: `IndexedSlices` to be added to this variable.
1098      use_locking: If `True`, use locking during the operation.
1099      name: the name of the operation.
1100
1101    Returns:
1102      A `Tensor` that will hold the new value of this variable after
1103      the scattered subtraction has completed.
1104
1105    Raises:
1106      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1107    """
1108    if not isinstance(sparse_delta, ops.IndexedSlices):
1109      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1110    return self._lazy_read(gen_resource_variable_ops.resource_scatter_add(
1111        self.handle, sparse_delta.indices,
1112        ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
1113
1114  def scatter_update(self, sparse_delta, use_locking=False, name=None):
1115    """Assigns `IndexedSlices` to this variable.
1116
1117    Args:
1118      sparse_delta: `IndexedSlices` to be assigned to this variable.
1119      use_locking: If `True`, use locking during the operation.
1120      name: the name of the operation.
1121
1122    Returns:
1123      A `Tensor` that will hold the new value of this variable after
1124      the scattered subtraction has completed.
1125
1126    Raises:
1127      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1128    """
1129    if not isinstance(sparse_delta, ops.IndexedSlices):
1130      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1131    return self._lazy_read(gen_resource_variable_ops.resource_scatter_update(
1132        self.handle, sparse_delta.indices,
1133        ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
1134
1135  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
1136    """Assigns `IndexedSlices` to this variable batch-wise.
1137
1138    Analogous to `batch_gather`. This assumes that this variable and the
1139    sparse_delta IndexedSlices have a series of leading dimensions that are the
1140    same for all of them, and the updates are performed on the last dimension of
1141    indices. In other words, the dimensions should be the following:
1142
1143    `num_prefix_dims = sparse_delta.indices.ndims - 1`
1144    `batch_dim = num_prefix_dims + 1`
1145    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
1146         batch_dim:]`
1147
1148    where
1149
1150    `sparse_delta.updates.shape[:num_prefix_dims]`
1151    `== sparse_delta.indices.shape[:num_prefix_dims]`
1152    `== var.shape[:num_prefix_dims]`
1153
1154    And the operation performed can be expressed as:
1155
1156    `var[i_1, ..., i_n,
1157         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
1158            i_1, ..., i_n, j]`
1159
1160    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
1161    `scatter_update`.
1162
1163    To avoid this operation one can looping over the first `ndims` of the
1164    variable and using `scatter_update` on the subtensors that result of slicing
1165    the first dimension. This is a valid option for `ndims = 1`, but less
1166    efficient than this implementation.
1167
1168    Args:
1169      sparse_delta: `IndexedSlices` to be assigned to this variable.
1170      use_locking: If `True`, use locking during the operation.
1171      name: the name of the operation.
1172
1173    Returns:
1174      A `Tensor` that will hold the new value of this variable after
1175      the scattered subtraction has completed.
1176
1177    Raises:
1178      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1179    """
1180    return self._lazy_read(state_ops.batch_scatter_update(
1181        self, sparse_delta.indices, sparse_delta.values,
1182        use_locking=use_locking, name=name))
1183
1184  def scatter_nd_sub(self, indices, updates, name=None):
1185    """Applies sparse subtraction to individual values or slices in a Variable.
1186
1187    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1188
1189    `indices` must be integer tensor, containing indices into `ref`.
1190    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1191
1192    The innermost dimension of `indices` (with length `K`) corresponds to
1193    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1194    dimension of `ref`.
1195
1196    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1197
1198    ```
1199    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1200    ```
1201
1202    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1203    8 elements. In Python, that update would look like this:
1204
1205    ```python
1206        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1207        indices = tf.constant([[4], [3], [1] ,[7]])
1208        updates = tf.constant([9, 10, 11, 12])
1209        op = ref.scatter_nd_sub(indices, updates)
1210        with tf.Session() as sess:
1211          print sess.run(op)
1212    ```
1213
1214    The resulting update to ref would look like this:
1215
1216        [1, -9, 3, -6, -6, 6, 7, -4]
1217
1218    See `tf.scatter_nd` for more details about how to make updates to
1219    slices.
1220
1221    Args:
1222      indices: The indices to be used in the operation.
1223      updates: The values to be used in the operation.
1224      name: the name of the operation.
1225
1226    Returns:
1227      A `Tensor` that will hold the new value of this variable after
1228      the scattered subtraction has completed.
1229
1230    Raises:
1231      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1232    """
1233    return self._lazy_read(gen_state_ops.resource_scatter_nd_sub(
1234        self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
1235        name=name))
1236
1237  def scatter_nd_add(self, indices, updates, name=None):
1238    """Applies sparse addition to individual values or slices in a Variable.
1239
1240    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1241
1242    `indices` must be integer tensor, containing indices into `ref`.
1243    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1244
1245    The innermost dimension of `indices` (with length `K`) corresponds to
1246    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1247    dimension of `ref`.
1248
1249    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1250
1251    ```
1252    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1253    ```
1254
1255    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1256    8 elements. In Python, that update would look like this:
1257
1258    ```python
1259        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1260        indices = tf.constant([[4], [3], [1] ,[7]])
1261        updates = tf.constant([9, 10, 11, 12])
1262        add = ref.scatter_nd_add(indices, updates)
1263        with tf.Session() as sess:
1264          print sess.run(add)
1265    ```
1266
1267    The resulting update to ref would look like this:
1268
1269        [1, 13, 3, 14, 14, 6, 7, 20]
1270
1271    See `tf.scatter_nd` for more details about how to make updates to
1272    slices.
1273
1274    Args:
1275      indices: The indices to be used in the operation.
1276      updates: The values to be used in the operation.
1277      name: the name of the operation.
1278
1279    Returns:
1280      A `Tensor` that will hold the new value of this variable after
1281      the scattered subtraction has completed.
1282
1283    Raises:
1284      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1285    """
1286    return self._lazy_read(gen_state_ops.resource_scatter_nd_add(
1287        self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
1288        name=name))
1289
1290  def scatter_nd_update(self, indices, updates, name=None):
1291    """Applies sparse assignment to individual values or slices in a Variable.
1292
1293    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1294
1295    `indices` must be integer tensor, containing indices into `ref`.
1296    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1297
1298    The innermost dimension of `indices` (with length `K`) corresponds to
1299    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1300    dimension of `ref`.
1301
1302    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1303
1304    ```
1305    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1306    ```
1307
1308    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1309    8 elements. In Python, that update would look like this:
1310
1311    ```python
1312        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1313        indices = tf.constant([[4], [3], [1] ,[7]])
1314        updates = tf.constant([9, 10, 11, 12])
1315        op = ref.scatter_nd_update(indices, updates)
1316        with tf.Session() as sess:
1317          print sess.run(op)
1318    ```
1319
1320    The resulting update to ref would look like this:
1321
1322        [1, 11, 3, 10, 9, 6, 7, 12]
1323
1324    See `tf.scatter_nd` for more details about how to make updates to
1325    slices.
1326
1327    Args:
1328      indices: The indices to be used in the operation.
1329      updates: The values to be used in the operation.
1330      name: the name of the operation.
1331
1332    Returns:
1333      A `Tensor` that will hold the new value of this variable after
1334      the scattered subtraction has completed.
1335
1336    Raises:
1337      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1338    """
1339    return self._lazy_read(gen_state_ops.resource_scatter_nd_update(
1340        self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
1341        name=name))
1342
1343  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
1344                            end_mask, ellipsis_mask, new_axis_mask,
1345                            shrink_axis_mask):
1346    with _handle_graph(self.handle), self._assign_dependencies():
1347      return self._lazy_read(
1348          gen_array_ops.resource_strided_slice_assign(
1349              ref=self.handle,
1350              begin=begin,
1351              end=end,
1352              strides=strides,
1353              value=ops.convert_to_tensor(value, dtype=self.dtype),
1354              name=name,
1355              begin_mask=begin_mask,
1356              end_mask=end_mask,
1357              ellipsis_mask=ellipsis_mask,
1358              new_axis_mask=new_axis_mask,
1359              shrink_axis_mask=shrink_axis_mask))
1360
1361  def __int__(self):
1362    if self.dtype != dtypes.int32 and self.dtype != dtypes.int64:
1363      raise TypeError("Non-integer variable can't be converted to integer.")
1364    return int(self.value().numpy())
1365
1366  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1367    del name
1368    if dtype is not None and not dtype.is_compatible_with(self.dtype):
1369      raise ValueError(
1370          "Incompatible type conversion requested to type {!r} for variable "
1371          "of type {!r}".format(dtype.name, self.dtype.name))
1372    if as_ref:
1373      return self.read_value().op.inputs[0]
1374    else:
1375      return self.value()
1376
1377  def __iadd__(self, unused_other):
1378    raise RuntimeError("Variable += value not supported. Use "
1379                       "variable.assign_add(value) to modify the variable "
1380                       "value and variable = variable + value to get a new "
1381                       "Tensor object.")
1382
1383  def __isub__(self, unused_other):
1384    raise RuntimeError("Variable -= value not supported. Use "
1385                       "variable.assign_sub(value) to modify the variable "
1386                       "value and variable = variable - value to get a new "
1387                       "Tensor object.")
1388
1389  def __imul__(self, unused_other):
1390    raise RuntimeError("Variable *= value not supported. Use "
1391                       "`var.assign(var * value)` to modify the variable or "
1392                       "`var = var * value` to get a new Tensor object.")
1393
1394  def __idiv__(self, unused_other):
1395    raise RuntimeError("Variable /= value not supported. Use "
1396                       "`var.assign(var / value)` to modify the variable or "
1397                       "`var = var / value` to get a new Tensor object.")
1398
1399  def __itruediv__(self, unused_other):
1400    raise RuntimeError("Variable /= value not supported. Use "
1401                       "`var.assign(var / value)` to modify the variable or "
1402                       "`var = var / value` to get a new Tensor object.")
1403
1404  def __irealdiv__(self, unused_other):
1405    raise RuntimeError("Variable /= value not supported. Use "
1406                       "`var.assign(var / value)` to modify the variable or "
1407                       "`var = var / value` to get a new Tensor object.")
1408
1409  def __ipow__(self, unused_other):
1410    raise RuntimeError("Variable **= value not supported. Use "
1411                       "`var.assign(var ** value)` to modify the variable or "
1412                       "`var = var ** value` to get a new Tensor object.")
1413
1414
# Hand the `ResourceVariable` class to `pywrap_tensorflow` and to `math_ops`.
# NOTE(review): what each consumer does with the type is not visible in this
# file; presumably both use it to recognize resource variables -- TODO confirm.
pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable)
math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access
1417
1418
1419def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
1420  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
1421
1422
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors. The registered
# function forwards to the variable's own `_dense_var_to_tensor` method.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
# Also register the class itself as a dense-tensor-like type.
ops.register_dense_tensor_like_type(ResourceVariable)
1427
1428
class _UnreadVariable(ResourceVariable):
  """A variable view whose reads are sequenced after a given operation.

  `ResourceVariable._lazy_read` returns one of these from its assignment-style
  methods. It shares the parent's handle, and every read it performs carries a
  control dependency on `parent_op`. Pretends to be the tensor if anyone looks.
  """

  def __init__(self, handle, dtype,  # pylint: disable=super-init-not-called
               shape, in_graph_mode, deleter, parent_op, unique_id):
    """Stores the shared variable state without creating a new resource.

    Args:
      handle: Resource handle of the variable being wrapped.
      dtype: Dtype of the variable's value.
      shape: Static shape of the variable.
      in_graph_mode: Whether the wrapped variable was built in graph mode.
      deleter: Handle deleter to keep a reference to, or `None`.
      parent_op: Operation that every read of this object must follow.
      unique_id: Unique id of the wrapped variable.
    """
    # We do not call super init on purpose.
    self._trainable = False
    self._save_slice_info = None
    default_graph = ops.get_default_graph()
    self._graph_key = default_graph._graph_key  # pylint: disable=protected-access
    self._in_graph_mode = in_graph_mode
    self._handle = handle
    self._shape = shape
    self._initial_value = None
    # Eager tensor handles get an empty name; others use the handle's name.
    self._handle_name = (
        "" if isinstance(handle, ops.EagerTensor) else handle.name)
    self._unique_id = unique_id
    self._dtype = dtype
    self._constraint = None
    self._cached_value = None
    self._is_initialized_op = None
    self._initializer_op = None
    self._parent_op = parent_op
    # Outside eager execution a read is built right away to serve as the
    # graph element; it relies on the handle, dtype and parent op set above.
    self._graph_element = (
        None if context.executing_eagerly() else self.read_value())
    self._handle_deleter = deleter

  @property
  def name(self):
    """The parent op's name in graph mode, else the fixed "UnreadVariable"."""
    return self._parent_op.name if self._in_graph_mode else "UnreadVariable"

  def value(self):
    """Returns a fresh read ordered after the parent op."""
    ordered_read = self._read_variable_op()
    return ordered_read

  def read_value(self):
    """Returns a fresh read ordered after the parent op."""
    ordered_read = self._read_variable_op()
    return ordered_read

  def _read_variable_op(self):
    """Creates a read of the handle under a dependency on the parent op."""
    with ops.control_dependencies([self._parent_op]):
      handle, value_dtype = self._handle, self._dtype
      read_result = gen_resource_variable_ops.read_variable_op(
          handle, value_dtype)
      _maybe_set_handle_data(value_dtype, handle, read_result)
    return read_result

  @property
  def op(self):
    """The op for this variable."""
    parent = self._parent_op
    return parent
1487
1488
# Register _UnreadVariable as a dense tensor-like type as well, mirroring the
# registration made for ResourceVariable.
ops.register_dense_tensor_like_type(_UnreadVariable)
1490
1491
@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient function registered for ReadVariableOp.

  Args:
    _: First argument of the gradient function; not consulted.
    grad: The gradient handed to this function.

  Returns:
    `grad`, unchanged.
  """
  return grad
1496
1497
def variable_shape(handle, out_type=dtypes.int32):
  """Returns the shape of the variable behind `handle`.

  When `handle` carries set handle data whose first shape entry has a known
  rank and no dimension of size -1, that shape is returned as a constant.
  In every other case a `VariableShape` op is emitted instead.

  Args:
    handle: The variable handle to inspect.
    out_type: Dtype of the returned shape; defaults to `dtypes.int32`.

  Returns:
    The result of `constant_op.constant` or of
    `gen_resource_variable_ops.variable_shape`, with dtype `out_type`.
  """
  # pylint: disable=protected-access
  handle_data = getattr(handle, "_handle_data", None)
  # pylint: enable=protected-access
  static_dims = None
  if handle_data is not None and handle_data.is_set:
    shape_proto = handle_data.shape_and_type[0].shape
    if not shape_proto.unknown_rank:
      static_dims = [dim.size for dim in shape_proto.dim]
  # No usable static shape, or at least one unknown dimension.
  if static_dims is None or -1 in static_dims:
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  return constant_op.constant(static_dims, dtype=out_type)
1506
1507
@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient function registered for ResourceGather.

  Args:
    op: The gather op; input 0 is the variable handle, input 1 the indices.
    grad: The gradient handed to this function.

  Returns:
    A pair: an `ops.IndexedSlices` built from `grad` and the flattened
    indices, and `None` for the indices input.
  """
  var_handle, gather_indices = op.inputs[0], op.inputs[1]
  dense_shape = variable_shape(var_handle)
  # Number of gathered indices, as a one-element vector.
  num_indices = array_ops.expand_dims(array_ops.size(gather_indices), 0)
  # One slice per index, each shaped like the variable minus its first dim.
  slice_shape = array_ops.concat([num_indices, dense_shape[1:]], 0)
  grad_values = array_ops.reshape(grad, slice_shape)
  flat_indices = array_ops.reshape(gather_indices, num_indices)
  slices = ops.IndexedSlices(grad_values, flat_indices, dense_shape)
  return (slices, None)
1520
1521
def _to_proto_fn(v, export_scope=None):
  """Serializes a Variable or ResourceVariable for collection export.

  Args:
    v: Object exposing a `to_proto` method.
    export_scope: Forwarded unchanged as the `export_scope` keyword.

  Returns:
    Whatever `v.to_proto` returns.
  """
  proto = v.to_proto(export_scope=export_scope)
  return proto
1525
1526
def _from_proto_fn(v, import_scope=None):
  """Rebuilds a Variable or ResourceVariable from a VariableDef.

  Args:
    v: The proto to restore from; its `is_resource` field picks the class.
    import_scope: Forwarded unchanged as the `import_scope` keyword.

  Returns:
    The result of `ResourceVariable.from_proto` when `v.is_resource` is
    truthy, otherwise the result of `variables.Variable.from_proto`.
  """
  variable_cls = ResourceVariable if v.is_resource else variables.Variable
  return variable_cls.from_proto(v, import_scope=import_scope)
1532
1533
# Every variable collection below is (de)serialized as VariableDef protos
# through the same pair of helpers; the keys are registered in this order.
for _collection_key in (
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.GLOBAL_STEP):
  ops.register_proto_function(
      _collection_key,
      proto_type=variable_pb2.VariableDef,
      to_proto=_to_proto_fn,
      from_proto=_from_proto_fn)
# Do not leave the loop variable behind in the module namespace.
del _collection_key
1564
1565
def is_resource_variable(var):
  """Returns True if `var` is to be considered a ResourceVariable.

  Args:
    var: Any object.

  Returns:
    True when `var` is a `ResourceVariable` instance or has a
    `_should_act_as_resource_variable` attribute; False otherwise.
  """
  if isinstance(var, ResourceVariable):
    return True
  return hasattr(var, "_should_act_as_resource_variable")
1570
1571
def copy_to_graph_uninitialized(var):
  """Copies an existing variable to a new graph, with no initializer.

  Works like ResourceVariable.__deepcopy__ except that the copy is not given
  an initializer: its initial value is a placeholder with the shape and dtype
  of `var`.

  Args:
    var: The variable to copy; its shape, dtype, trainable flag, constraint
      and shared name are carried over.

  Returns:
    The newly constructed `ResourceVariable`.
  """
  # pylint: disable=protected-access
  placeholder_value = array_ops.placeholder(
      shape=var.shape,
      dtype=var.dtype,
      name="unused_initial_variable_value")
  copied = ResourceVariable(
      initial_value=placeholder_value,
      trainable=var.trainable,
      constraint=var._constraint,
      dtype=var.dtype,
      name=var._shared_name)
  copied._maybe_initialize_trackable()
  # pylint: enable=protected-access
  return copied
1588
# Declare these op types as not differentiable, so no gradient function is
# registered for them.
ops.NotDifferentiable("VarIsInitializedOp")
ops.NotDifferentiable("VariableShape")
1591