1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Variable class."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import enum  # pylint: disable=g-bad-import-order
21import functools
22import os
23import six
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import variable_pb2
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import gen_array_ops
34from tensorflow.python.ops import gen_state_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import state_ops
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.training.tracking import base as trackable
39from tensorflow.python.util import compat
40from tensorflow.python.util import tf_should_use
41from tensorflow.python.util.deprecation import deprecated
42from tensorflow.python.util.tf_export import tf_export
43
44
def default_variable_creator(_, **kwds):
  """Stand-in v1 variable creator that always fails.

  Creating a variable through this placeholder signals that `variable_scope`
  must be imported first.

  Raises:
    NotImplementedError: unconditionally.
  """
  del kwds  # Accepted only to match the creator signature.
  raise NotImplementedError("variable_scope needs to be imported")
48
49
def default_variable_creator_v2(_, **kwds):
  """Stand-in v2 variable creator that always fails.

  Creating a variable through this placeholder signals that `variable_scope`
  must be imported first.

  Raises:
    NotImplementedError: unconditionally.
  """
  del kwds  # Accepted only to match the creator signature.
  raise NotImplementedError("variable_scope needs to be imported")
53
54
55def _make_getter(captured_getter, captured_previous):
56  """To avoid capturing loop variables."""
57  def getter(**kwargs):
58    return captured_getter(captured_previous, **kwargs)
59  return getter
60
61
@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: Indicates that the synchronization will be determined by the current
    `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g. when checkpointing or when evaluating an op that uses
    the variable).
  """
  AUTO = 0      # Strategy decides.
  NONE = 1      # Single copy; nothing to sync.
  ON_WRITE = 2  # Synced on every write.
  ON_READ = 3   # Aggregated when read.
81
82
@tf_export("VariableAggregation", v1=[])
class VariableAggregationV2(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.contrib.distribute.DistributionStrategy` distributes a model by making
  multiple copies (called "replicas") acting data-parallel on different elements
  of the input batch. When performing some variable-update operation, say
  `var.assign_add(x)`, in a model, we need to resolve how to combine the
  different values for `x` computed in the different replicas.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple replicas.
  * `SUM`: Add the updates across replicas.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas.
  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
    update, but we only want to perform the update once. Used, e.g., for the
    global step counter.
  """
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3

  def __hash__(self):
    # Hash by value, matching the v1 `VariableAggregation.__hash__`, so that
    # members comparing equal across the two enums also hash equal.
    return hash(self.value)

  def __eq__(self, other):
    """Equal to itself, or to a v1 `VariableAggregation` of the same value."""
    if self is other:
      return True
    elif isinstance(other, VariableAggregation):
      return int(self.value) == int(other.value)
    else:
      return False

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`; without this override a v2
    # member could be both `==` and `!=` to the same-valued v1 member.
    return not self.__eq__(other)
116
117
@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
  # The class docstring is assigned after the class definition, derived from
  # `VariableAggregationV2.__doc__`.
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3
  # DEPRECATED: shares the value of ONLY_FIRST_REPLICA, so it is an enum alias
  # of that member.
  ONLY_FIRST_TOWER = 3

  def __hash__(self):
    # Hash by value so members hash consistently with the same-valued
    # `VariableAggregationV2` members that compare equal to them.
    return hash(self._value_)
128
129
# The v1 enum reuses the v2 documentation plus a note on its deprecated alias.
# Under `python -OO` docstrings are stripped and `__doc__` is None, in which
# case `None + str` would raise TypeError at import time; skip the assignment.
if VariableAggregationV2.__doc__ is not None:
  VariableAggregation.__doc__ = (
      VariableAggregationV2.__doc__ +
      "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n  ")
133
134
class VariableMetaclass(type):
  """Metaclass to allow construction of tf.Variable to be overridden.

  Calling `VariableV1(...)` or `Variable(...)` does not instantiate the class
  directly. The keyword arguments are passed through the creators on the
  default graph's `_variable_creator_stack`. The last entry on the stack is
  called first and receives the next creator as its first argument. The chain
  bottoms out at `default_variable_creator` (v1) or
  `default_variable_creator_v2` (v2). Any other class using this metaclass,
  including subclasses, is constructed normally.
  """

  def _variable_v1_call(cls,
                        initial_value=None,
                        trainable=None,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        expected_shape=None,
                        import_scope=None,
                        constraint=None,
                        use_resource=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE):
    """Call on Variable class. Useful to force the signature."""
    creator = lambda **kwargs: default_variable_creator(None, **kwargs)
    graph = ops.get_default_graph()
    for _, next_creator in graph._variable_creator_stack:  # pylint: disable=protected-access
      creator = _make_getter(next_creator, creator)

    return creator(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        expected_shape=expected_shape,
        import_scope=import_scope,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        # An explicit `aggregation=None` is normalized to the enum NONE.
        aggregation=(VariableAggregation.NONE
                     if aggregation is None else aggregation))

  def _variable_v2_call(cls,
                        initial_value=None,
                        trainable=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        import_scope=None,
                        constraint=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE):
    """Call on Variable class. Useful to force the signature."""
    creator = lambda **kws: default_variable_creator_v2(None, **kws)
    graph = ops.get_default_graph()
    for _, next_creator in graph._variable_creator_stack:  # pylint: disable=protected-access
      creator = _make_getter(next_creator, creator)

    return creator(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        # An explicit `aggregation=None` is normalized to the enum NONE.
        aggregation=(VariableAggregation.NONE
                     if aggregation is None else aggregation))

  def __call__(cls, *args, **kwargs):
    # Only the two exact base classes are routed through creators.
    if cls is VariableV1:
      return cls._variable_v1_call(*args, **kwargs)
    if cls is Variable:
      return cls._variable_v2_call(*args, **kwargs)
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
217
218
219@tf_export("Variable", v1=[])
220class Variable(six.with_metaclass(VariableMetaclass,
221                                  trackable.Trackable)):
222  """See the [Variables Guide](https://tensorflow.org/guide/variables).
223
224  A variable maintains state in the graph across calls to `run()`. You add a
225  variable to the graph by constructing an instance of the class `Variable`.
226
227  The `Variable()` constructor requires an initial value for the variable,
228  which can be a `Tensor` of any type and shape. The initial value defines the
229  type and shape of the variable. After construction, the type and shape of
230  the variable are fixed. The value can be changed using one of the assign
231  methods.
232
233  If you want to change the shape of a variable later you have to use an
234  `assign` Op with `validate_shape=False`.
235
236  Just like any `Tensor`, variables created with `Variable()` can be used as
237  inputs for other Ops in the graph. Additionally, all the operators
238  overloaded for the `Tensor` class are carried over to variables, so you can
239  also add nodes to the graph by just doing arithmetic on variables.
240
241  ```python
242  import tensorflow as tf
243
244  # Create a variable.
245  w = tf.Variable(<initial-value>, name=<optional-name>)
246
247  # Use the variable in the graph like any Tensor.
248  y = tf.matmul(w, ...another variable or tensor...)
249
250  # The overloaded operators are available too.
251  z = tf.sigmoid(w + y)
252
253  # Assign a new value to the variable with `assign()` or a related method.
254  w.assign(w + 1.0)
255  w.assign_add(1.0)
256  ```
257
258  When you launch the graph, variables have to be explicitly initialized before
259  you can run Ops that use their value. You can initialize a variable by
260  running its *initializer op*, restoring the variable from a save file, or
261  simply running an `assign` Op that assigns a value to the variable. In fact,
262  the variable *initializer op* is just an `assign` Op that assigns the
263  variable's initial value to the variable itself.
264
265  ```python
266  # Launch the graph in a session.
267  with tf.Session() as sess:
268      # Run the variable initializer.
269      sess.run(w.initializer)
270      # ...you now can run ops that use the value of 'w'...
271  ```
272
273  The most common initialization pattern is to use the convenience function
274  `global_variables_initializer()` to add an Op to the graph that initializes
275  all the variables. You then run that Op after launching the graph.
276
277  ```python
278  # Add an Op to initialize global variables.
279  init_op = tf.global_variables_initializer()
280
281  # Launch the graph in a session.
282  with tf.Session() as sess:
283      # Run the Op that initializes global variables.
284      sess.run(init_op)
285      # ...you can now run any Op that uses variable values...
286  ```
287
288  If you need to create a variable with an initial value dependent on another
289  variable, use the other variable's `initialized_value()`. This ensures that
290  variables are initialized in the right order.
291
292  All variables are automatically collected in the graph where they are
293  created. By default, the constructor adds the new variable to the graph
294  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
295  `global_variables()` returns the contents of that collection.
296
297  When building a machine learning model it is often convenient to distinguish
298  between variables holding the trainable model parameters and other variables
299  such as a `global step` variable used to count training steps. To make this
300  easier, the variable constructor supports a `trainable=<bool>` parameter. If
301  `True`, the new variable is also added to the graph collection
302  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
303  `trainable_variables()` returns the contents of this collection. The
304  various `Optimizer` classes use this collection as the default list of
305  variables to optimize.
306
307  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
308  Variable is represented internally as a mutable Tensor which can
309  non-deterministically alias other Tensors in a graph. The set of operations
310  which consume a Variable and can lead to aliasing is undetermined and can
311  change across TensorFlow versions. Avoid writing code which relies on the
312  value of a Variable either changing or not changing as other operations
313  happen. For example, using Variable objects or simple functions thereof as
314  predicates in a `tf.cond` is dangerous and error-prone:
315
316  ```
317  v = tf.Variable(True)
318  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
319  ```
320
321  Here, adding `use_resource=True` when constructing the variable will
322  fix any nondeterminism issues:
323
324  ```
325  v = tf.Variable(True, use_resource=True)
326  tf.cond(v, lambda: v.assign(False), my_false_fn)
327  ```
328
329  To use the replacement for variables which does
330  not have these issues:
331
332  * Add `use_resource=True` when constructing `tf.Variable`;
333  * Call `tf.get_variable_scope().set_use_resource(True)` inside a
334    `tf.variable_scope` before the `tf.get_variable()` call.
335  """
336
337  def __init__(self,
338               initial_value=None,
339               trainable=True,
340               validate_shape=True,
341               caching_device=None,
342               name=None,
343               variable_def=None,
344               dtype=None,
345               import_scope=None,
346               constraint=None,
347               synchronization=VariableSynchronization.AUTO,
348               aggregation=VariableAggregation.NONE):
349    """Creates a new variable with value `initial_value`.
350
351    The new variable is added to the graph collections listed in `collections`,
352    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
353
354    If `trainable` is `True` the variable is also added to the graph collection
355    `GraphKeys.TRAINABLE_VARIABLES`.
356
357    This constructor creates both a `variable` Op and an `assign` Op to set the
358    variable to its initial value.
359
360    Args:
361      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
362        which is the initial value for the Variable. The initial value must have
363        a shape specified unless `validate_shape` is set to False. Can also be a
364        callable with no argument that returns the initial value when called. In
365        that case, `dtype` must be specified. (Note that initializer functions
366        from init_ops.py must first be bound to a shape before being used here.)
367      trainable: If `True`, the default, GradientTapes automatically watch uses
368        of this variable.
369      validate_shape: If `False`, allows the variable to be initialized with a
370        value of unknown shape. If `True`, the default, the shape of
371        `initial_value` must be known.
372      caching_device: Optional device string describing where the Variable
373        should be cached for reading.  Defaults to the Variable's device.
374        If not `None`, caches on another device.  Typical use is to cache
375        on the device where the Ops using the Variable reside, to deduplicate
376        copying through `Switch` and other conditional statements.
377      name: Optional name for the variable. Defaults to `'Variable'` and gets
378        uniquified automatically.
379      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
380        the Variable object with its contents, referencing the variable's nodes
381        in the graph, which must already exist. The graph is not changed.
382        `variable_def` and the other arguments are mutually exclusive.
383      dtype: If set, initial_value will be converted to the given type.
384        If `None`, either the datatype will be kept (if `initial_value` is
385        a Tensor), or `convert_to_tensor` will decide.
386      import_scope: Optional `string`. Name scope to add to the
387        `Variable.` Only used when initializing from protocol buffer.
388      constraint: An optional projection function to be applied to the variable
389        after being updated by an `Optimizer` (e.g. used to implement norm
390        constraints or value constraints for layer weights). The function must
391        take as input the unprojected Tensor representing the value of the
392        variable and return the Tensor for the projected value
393        (which must have the same shape). Constraints are not safe to
394        use when doing asynchronous distributed training.
395      synchronization: Indicates when a distributed a variable will be
396        aggregated. Accepted values are constants defined in the class
397        `tf.VariableSynchronization`. By default the synchronization is set to
398        `AUTO` and the current `DistributionStrategy` chooses
399        when to synchronize. If `synchronization` is set to `ON_READ`,
400        `trainable` must not be set to `True`.
401      aggregation: Indicates how a distributed variable will be aggregated.
402        Accepted values are constants defined in the class
403        `tf.VariableAggregation`.
404
405    Raises:
406      ValueError: If both `variable_def` and initial_value are specified.
407      ValueError: If the initial value is not specified, or does not have a
408        shape and `validate_shape` is `True`.
409      RuntimeError: If eager execution is enabled.
410    """
411    raise NotImplementedError
412
413  def __repr__(self):
414    raise NotImplementedError
415
416  def value(self):
417    """Returns the last snapshot of this variable.
418
419    You usually do not need to call this method as all ops that need the value
420    of the variable call it automatically through a `convert_to_tensor()` call.
421
422    Returns a `Tensor` which holds the value of the variable.  You can not
423    assign a new value to this tensor as it is not a reference to the variable.
424
425    To avoid copies, if the consumer of the returned value is on the same device
426    as the variable, this actually returns the live value of the variable, not
427    a copy.  Updates to the variable are seen by the consumer.  If the consumer
428    is on a different device it will get a copy of the variable.
429
430    Returns:
431      A `Tensor` containing the value of the variable.
432    """
433    raise NotImplementedError
434
435  def read_value(self):
436    """Returns the value of this variable, read in the current context.
437
438    Can be different from value() if it's on another device, with control
439    dependencies, etc.
440
441    Returns:
442      A `Tensor` containing the value of the variable.
443    """
444    raise NotImplementedError
445
446  def set_shape(self, shape):
447    """Overrides the shape for this variable.
448
449    Args:
450      shape: the `TensorShape` representing the overridden shape.
451    """
452    raise NotImplementedError
453
454  @property
455  def trainable(self):
456    raise NotImplementedError
457
458  def eval(self, session=None):
459    """In a session, computes and returns the value of this variable.
460
461    This is not a graph construction method, it does not add ops to the graph.
462
463    This convenience method requires a session where the graph
464    containing this variable has been launched. If no session is
465    passed, the default session is used.  See `tf.Session` for more
466    information on launching a graph and on sessions.
467
468    ```python
469    v = tf.Variable([1, 2])
470    init = tf.global_variables_initializer()
471
472    with tf.Session() as sess:
473        sess.run(init)
474        # Usage passing the session explicitly.
475        print(v.eval(sess))
476        # Usage with the default session.  The 'with' block
477        # above makes 'sess' the default session.
478        print(v.eval())
479    ```
480
481    Args:
482      session: The session to use to evaluate this variable. If
483        none, the default session is used.
484
485    Returns:
486      A numpy `ndarray` with a copy of the value of this variable.
487    """
488    raise NotImplementedError
489
490  @deprecated(
491      None,
492      "Use Variable.read_value. Variables in 2.X are initialized "
493      "automatically both in eager and graph (inside tf.defun) contexts.")
494  def initialized_value(self):
495    """Returns the value of the initialized variable.
496
497    You should use this instead of the variable itself to initialize another
498    variable with a value that depends on the value of this variable.
499
500    ```python
501    # Initialize 'v' with a random tensor.
502    v = tf.Variable(tf.truncated_normal([10, 40]))
503    # Use `initialized_value` to guarantee that `v` has been
504    # initialized before its value is used to initialize `w`.
505    # The random values are picked only once.
506    w = tf.Variable(v.initialized_value() * 2.0)
507    ```
508
509    Returns:
510      A `Tensor` holding the value of this variable after its initializer
511      has run.
512    """
513    with ops.init_scope():
514      return control_flow_ops.cond(is_variable_initialized(self),
515                                   self.read_value,
516                                   lambda: self.initial_value)
517
518  @property
519  def initial_value(self):
520    """Returns the Tensor used as the initial value for the variable.
521
522    Note that this is different from `initialized_value()` which runs
523    the op that initializes the variable before returning its value.
524    This method returns the tensor that is used by the op that initializes
525    the variable.
526
527    Returns:
528      A `Tensor`.
529    """
530    raise NotImplementedError
531
532  @property
533  def constraint(self):
534    """Returns the constraint function associated with this variable.
535
536    Returns:
537      The constraint function that was passed to the variable constructor.
538      Can be `None` if no constraint was passed.
539    """
540    raise NotImplementedError
541
542  def assign(self, value, use_locking=False, name=None, read_value=True):
543    """Assigns a new value to the variable.
544
545    This is essentially a shortcut for `assign(self, value)`.
546
547    Args:
548      value: A `Tensor`. The new value for this variable.
549      use_locking: If `True`, use locking during the assignment.
550      name: The name of the operation to be created
551      read_value: if True, will return something which evaluates to the
552        new value of the variable; if False will return the assign op.
553
554    Returns:
555      A `Tensor` that will hold the new value of this variable after
556      the assignment has completed.
557    """
558    raise NotImplementedError
559
560  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
561    """Adds a value to this variable.
562
563     This is essentially a shortcut for `assign_add(self, delta)`.
564
565    Args:
566      delta: A `Tensor`. The value to add to this variable.
567      use_locking: If `True`, use locking during the operation.
568      name: The name of the operation to be created
569      read_value: if True, will return something which evaluates to the
570        new value of the variable; if False will return the assign op.
571
572    Returns:
573      A `Tensor` that will hold the new value of this variable after
574      the addition has completed.
575    """
576    raise NotImplementedError
577
578  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
579    """Subtracts a value from this variable.
580
581    This is essentially a shortcut for `assign_sub(self, delta)`.
582
583    Args:
584      delta: A `Tensor`. The value to subtract from this variable.
585      use_locking: If `True`, use locking during the operation.
586      name: The name of the operation to be created
587      read_value: if True, will return something which evaluates to the
588        new value of the variable; if False will return the assign op.
589
590    Returns:
591      A `Tensor` that will hold the new value of this variable after
592      the subtraction has completed.
593    """
594    raise NotImplementedError
595
596  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
597    """Subtracts `IndexedSlices` from this variable.
598
599    Args:
600      sparse_delta: `IndexedSlices` to be subtracted from this variable.
601      use_locking: If `True`, use locking during the operation.
602      name: the name of the operation.
603
604    Returns:
605      A `Tensor` that will hold the new value of this variable after
606      the scattered subtraction has completed.
607
608    Raises:
609      ValueError: if `sparse_delta` is not an `IndexedSlices`.
610    """
611    raise NotImplementedError
612
613  def scatter_add(self, sparse_delta, use_locking=False, name=None):
614    """Adds `IndexedSlices` to this variable.
615
616    Args:
617      sparse_delta: `IndexedSlices` to be assigned to this variable.
618      use_locking: If `True`, use locking during the operation.
619      name: the name of the operation.
620
621    Returns:
622      A `Tensor` that will hold the new value of this variable after
623      the scattered subtraction has completed.
624
625    Raises:
626      ValueError: if `sparse_delta` is not an `IndexedSlices`.
627    """
628    raise NotImplementedError
629
630  def scatter_update(self, sparse_delta, use_locking=False, name=None):
631    """Assigns `IndexedSlices` to this variable.
632
633    Args:
634      sparse_delta: `IndexedSlices` to be assigned to this variable.
635      use_locking: If `True`, use locking during the operation.
636      name: the name of the operation.
637
638    Returns:
639      A `Tensor` that will hold the new value of this variable after
640      the scattered subtraction has completed.
641
642    Raises:
643      ValueError: if `sparse_delta` is not an `IndexedSlices`.
644    """
645    raise NotImplementedError
646
647  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
648    """Assigns `IndexedSlices` to this variable batch-wise.
649
650    Analogous to `batch_gather`. This assumes that this variable and the
651    sparse_delta IndexedSlices have a series of leading dimensions that are the
652    same for all of them, and the updates are performed on the last dimension of
653    indices. In other words, the dimensions should be the following:
654
655    `num_prefix_dims = sparse_delta.indices.ndims - 1`
656    `batch_dim = num_prefix_dims + 1`
657    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
658         batch_dim:]`
659
660    where
661
662    `sparse_delta.updates.shape[:num_prefix_dims]`
663    `== sparse_delta.indices.shape[:num_prefix_dims]`
664    `== var.shape[:num_prefix_dims]`
665
666    And the operation performed can be expressed as:
667
668    `var[i_1, ..., i_n,
669         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
670            i_1, ..., i_n, j]`
671
672    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
673    `scatter_update`.
674
675    To avoid this operation one can looping over the first `ndims` of the
676    variable and using `scatter_update` on the subtensors that result of slicing
677    the first dimension. This is a valid option for `ndims = 1`, but less
678    efficient than this implementation.
679
680    Args:
681      sparse_delta: `IndexedSlices` to be assigned to this variable.
682      use_locking: If `True`, use locking during the operation.
683      name: the name of the operation.
684
685    Returns:
686      A `Tensor` that will hold the new value of this variable after
687      the scattered subtraction has completed.
688
689    Raises:
690      ValueError: if `sparse_delta` is not an `IndexedSlices`.
691    """
692    raise NotImplementedError
693
694  def scatter_nd_sub(self, indices, updates, name=None):
695    """Applies sparse subtraction to individual values or slices in a Variable.
696
697    Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
698
699    `indices` must be integer tensor, containing indices into self.
700    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
701
702    The innermost dimension of `indices` (with length `K`) corresponds to
703    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
704    dimension of self.
705
706    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
707
708    ```
709    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
710    ```
711
712    For example, say we want to add 4 scattered elements to a rank-1 tensor to
713    8 elements. In Python, that update would look like this:
714
715    ```python
716        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
717        indices = tf.constant([[4], [3], [1] ,[7]])
718        updates = tf.constant([9, 10, 11, 12])
719        op = v.scatter_nd_sub(indices, updates)
720        with tf.Session() as sess:
721          print sess.run(op)
722    ```
723
724    The resulting update to v would look like this:
725
726        [1, -9, 3, -6, -6, 6, 7, -4]
727
728    See `tf.scatter_nd` for more details about how to make updates to
729    slices.
730
731    Args:
732      indices: The indices to be used in the operation.
733      updates: The values to be used in the operation.
734      name: the name of the operation.
735
736    Returns:
737      A `Tensor` that will hold the new value of this variable after
738      the scattered subtraction has completed.
739
740    Raises:
741      ValueError: if `sparse_delta` is not an `IndexedSlices`.
742    """
743    raise NotImplementedError
744
745  def scatter_nd_add(self, indices, updates, name=None):
746    """Applies sparse addition to individual values or slices in a Variable.
747
748    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
749
750    `indices` must be integer tensor, containing indices into self.
751    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
752
753    The innermost dimension of `indices` (with length `K`) corresponds to
754    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
755    dimension of self.
756
757    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
758
759    ```
760    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
761    ```
762
763    For example, say we want to add 4 scattered elements to a rank-1 tensor to
764    8 elements. In Python, that update would look like this:
765
766    ```python
767        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
768        indices = tf.constant([[4], [3], [1] ,[7]])
769        updates = tf.constant([9, 10, 11, 12])
770        add = v.scatter_nd_add(indices, updates)
771        with tf.Session() as sess:
772          print sess.run(add)
773    ```
774
775    The resulting update to v would look like this:
776
777        [1, 13, 3, 14, 14, 6, 7, 20]
778
779    See `tf.scatter_nd` for more details about how to make updates to
780    slices.
781
782    Args:
783      indices: The indices to be used in the operation.
784      updates: The values to be used in the operation.
785      name: the name of the operation.
786
787    Returns:
788      A `Tensor` that will hold the new value of this variable after
789      the scattered subtraction has completed.
790
791    Raises:
792      ValueError: if `sparse_delta` is not an `IndexedSlices`.
793    """
794    raise NotImplementedError
795
796  def scatter_nd_update(self, indices, updates, name=None):
797    """Applies sparse assignment to individual values or slices in a Variable.
798
799    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
800
801    `indices` must be integer tensor, containing indices into self.
802    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
803
804    The innermost dimension of `indices` (with length `K`) corresponds to
805    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
806    dimension of self.
807
808    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
809
810    ```
811    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
812    ```
813
814    For example, say we want to add 4 scattered elements to a rank-1 tensor to
815    8 elements. In Python, that update would look like this:
816
817    ```python
818        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
819        indices = tf.constant([[4], [3], [1] ,[7]])
820        updates = tf.constant([9, 10, 11, 12])
821        op = v.scatter_nd_assign(indices, updates)
822        with tf.Session() as sess:
823          print sess.run(op)
824    ```
825
826    The resulting update to v would look like this:
827
828        [1, 11, 3, 10, 9, 6, 7, 12]
829
830    See `tf.scatter_nd` for more details about how to make updates to
831    slices.
832
833    Args:
834      indices: The indices to be used in the operation.
835      updates: The values to be used in the operation.
836      name: the name of the operation.
837
838    Returns:
839      A `Tensor` that will hold the new value of this variable after
840      the scattered subtraction has completed.
841
842    Raises:
843      ValueError: if `sparse_delta` is not an `IndexedSlices`.
844    """
845    raise NotImplementedError
846
847  @deprecated(None, "Prefer Dataset.range instead.")
848  def count_up_to(self, limit):
849    """Increments this variable until it reaches `limit`.
850
851    When that Op is run it tries to increment the variable by `1`. If
852    incrementing the variable would bring it above `limit` then the Op raises
853    the exception `OutOfRangeError`.
854
855    If no error is raised, the Op outputs the value of the variable before
856    the increment.
857
858    This is essentially a shortcut for `count_up_to(self, limit)`.
859
860    Args:
861      limit: value at which incrementing the variable raises an error.
862
863    Returns:
864      A `Tensor` that will hold the variable value before the increment. If no
865      other Op modifies this variable, the values produced will all be
866      distinct.
867    """
868    raise NotImplementedError
869
870  @deprecated(
871      None,
872      "Prefer Variable.assign which has equivalent behavior in 2.X.")
873  def load(self, value, session=None):
874    """Load new value into this variable.
875
876    Writes new value to variable's memory. Doesn't add ops to the graph.
877
878    This convenience method requires a session where the graph
879    containing this variable has been launched. If no session is
880    passed, the default session is used.  See `tf.Session` for more
881    information on launching a graph and on sessions.
882
883    ```python
884    v = tf.Variable([1, 2])
885    init = tf.global_variables_initializer()
886
887    with tf.Session() as sess:
888        sess.run(init)
889        # Usage passing the session explicitly.
890        v.load([2, 3], sess)
891        print(v.eval(sess)) # prints [2 3]
892        # Usage with the default session.  The 'with' block
893        # above makes 'sess' the default session.
894        v.load([3, 4], sess)
895        print(v.eval()) # prints [3 4]
896    ```
897
898    Args:
899        value: New variable value
900        session: The session to use to evaluate this variable. If
901          none, the default session is used.
902
903    Raises:
904        ValueError: Session is not passed and no default session
905    """
906    if context.executing_eagerly():
907      self.assign(value)
908    else:
909      session = session or ops.get_default_session()
910      if session is None:
911        raise ValueError(
912            "Either session argument should be provided or default session "
913            "should be established")
914      session.run(self.initializer, {self.initializer.inputs[1]: value})
915
916  # Conversion to tensor.
917  @staticmethod
918  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
919    """Utility function for converting a Variable to a Tensor."""
920    _ = name
921    if dtype and not dtype.is_compatible_with(v.dtype):
922      raise ValueError(
923          "Incompatible type conversion requested to type '%s' for variable "
924          "of type '%s'" % (dtype.name, v.dtype.name))
925    if as_ref:
926      return v._ref()  # pylint: disable=protected-access
927    else:
928      return v.value()
929
930  @classmethod
931  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
932    """Register overloads for all operators."""
933    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
934      cls._OverloadOperator(operator)
935    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
936    # instead)
937    # pylint: disable=protected-access
938    setattr(cls, "__getitem__", array_ops._SliceHelperVar)
939
940  @classmethod
941  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
942    """Defer an operator overload to `ops.Tensor`.
943
944    We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
945
946    Args:
947      operator: string. The operator name.
948    """
949    tensor_oper = getattr(ops.Tensor, operator)
950
951    def _run_op(a, *args, **kwargs):
952      # pylint: disable=protected-access
953      return tensor_oper(a.value(), *args, **kwargs)
954
955    functools.update_wrapper(_run_op, tensor_oper)
956    setattr(cls, operator, _run_op)
957
958  def __iter__(self):
959    """Dummy method to prevent iteration. Do not call.
960
961    NOTE(mrry): If we register __getitem__ as an overloaded operator,
962    Python will valiantly attempt to iterate over the variable's Tensor from 0
963    to infinity.  Declaring this method prevents this unintended behavior.
964
965    Raises:
966      TypeError: when invoked.
967    """
968    raise TypeError("'Variable' object is not iterable.")
969
970  # NOTE(mrry): This enables the Variable's overloaded "right" binary
971  # operators to run when the left operand is an ndarray, because it
972  # accords the Variable class higher priority than an ndarray, or a
973  # numpy matrix.
974  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
975  # mechanism, which allows more control over how Variables interact
976  # with ndarrays.
977  __array_priority__ = 100
978
979  @property
980  def name(self):
981    """The name of this variable."""
982    raise NotImplementedError
983
984  @property
985  def _shared_name(self):
986    """The shared name of the variable.
987
988      Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified
989      name with name scope prefix.
990
991    Returns:
992      variable name.
993    """
994    return self.name[:self.name.index(":")]
995
996  @property
997  def initializer(self):
998    """The initializer operation for this variable."""
999    raise NotImplementedError
1000
1001  @property
1002  def device(self):
1003    """The device of this variable."""
1004    raise NotImplementedError
1005
1006  @property
1007  def dtype(self):
1008    """The `DType` of this variable."""
1009    raise NotImplementedError
1010
1011  @property
1012  def op(self):
1013    """The `Operation` of this variable."""
1014    raise NotImplementedError
1015
1016  @property
1017  def graph(self):
1018    """The `Graph` of this variable."""
1019    raise NotImplementedError
1020
1021  @property
1022  def shape(self):
1023    """The `TensorShape` of this variable.
1024
1025    Returns:
1026      A `TensorShape`.
1027    """
1028    raise NotImplementedError
1029
1030  def get_shape(self):
1031    """Alias of `Variable.shape`."""
1032    return self.shape
1033
1034  def _gather_saveables_for_checkpoint(self):
1035    """For implementing `Trackable`. This object is saveable on its own."""
1036    return {trackable.VARIABLE_VALUE_KEY: self}
1037
1038  def to_proto(self, export_scope=None):
1039    """Converts a `Variable` to a `VariableDef` protocol buffer.
1040
1041    Args:
1042      export_scope: Optional `string`. Name scope to remove.
1043
1044    Returns:
1045      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
1046      in the specified name scope.
1047    """
1048    raise NotImplementedError
1049
1050  @staticmethod
1051  def from_proto(variable_def, import_scope=None):
1052    """Returns a `Variable` object created from `variable_def`."""
1053    return RefVariable(variable_def=variable_def,
1054                       import_scope=import_scope)
1055
1056  def _set_save_slice_info(self, save_slice_info):
1057    """Sets the slice info for this `Variable`.
1058
1059    Args:
1060      save_slice_info: A `Variable.SaveSliceInfo` object.
1061    """
1062    self._save_slice_info = save_slice_info
1063
1064  def _get_save_slice_info(self):
1065    return self._save_slice_info
1066
1067  class SaveSliceInfo(object):
1068    """Information on how to save this Variable as a slice.
1069
1070    Provides internal support for saving variables as slices of a larger
1071    variable.  This API is not public and is subject to change.
1072
1073    Available properties:
1074
1075    * full_name
1076    * full_shape
1077    * var_offset
1078    * var_shape
1079    """
1080
1081    def __init__(self,
1082                 full_name=None,
1083                 full_shape=None,
1084                 var_offset=None,
1085                 var_shape=None,
1086                 save_slice_info_def=None,
1087                 import_scope=None):
1088      """Create a `SaveSliceInfo`.
1089
1090      Args:
1091        full_name: Name of the full variable of which this `Variable` is a
1092            slice.
1093        full_shape: Shape of the full variable, as a list of int.
1094        var_offset: Offset of this `Variable` into the full variable, as a
1095            list of int.
1096        var_shape: Shape of this `Variable`, as a list of int.
1097        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
1098          recreates the SaveSliceInfo object its contents.
1099          `save_slice_info_def` and other arguments are mutually
1100          exclusive.
1101        import_scope: Optional `string`. Name scope to add. Only used
1102          when initializing from protocol buffer.
1103      """
1104      if save_slice_info_def:
1105        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
1106        self.full_name = ops.prepend_name_scope(
1107            save_slice_info_def.full_name, import_scope=import_scope)
1108        self.full_shape = [i for i in save_slice_info_def.full_shape]
1109        self.var_offset = [i for i in save_slice_info_def.var_offset]
1110        self.var_shape = [i for i in save_slice_info_def.var_shape]
1111      else:
1112        self.full_name = full_name
1113        self.full_shape = full_shape
1114        self.var_offset = var_offset
1115        self.var_shape = var_shape
1116
1117    @property
1118    def spec(self):
1119      """Computes the spec string used for saving."""
1120      full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " "
1121      sl_spec = ":".join([
1122          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)
1123      ])
1124      return full_shape_str + sl_spec
1125
1126    def to_proto(self, export_scope=None):
1127      """Returns a SaveSliceInfoDef() proto.
1128
1129      Args:
1130        export_scope: Optional `string`. Name scope to remove.
1131
1132      Returns:
1133        A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
1134        in the specified name scope.
1135      """
1136      if (export_scope is None or
1137          self.full_name.startswith(export_scope)):
1138        save_slice_info_def = variable_pb2.SaveSliceInfoDef()
1139        save_slice_info_def.full_name = ops.strip_name_scope(
1140            self.full_name, export_scope)
1141        for i in self.full_shape:
1142          save_slice_info_def.full_shape.append(i)
1143        for i in self.var_offset:
1144          save_slice_info_def.var_offset.append(i)
1145        for i in self.var_shape:
1146          save_slice_info_def.var_shape.append(i)
1147        return save_slice_info_def
1148      else:
1149        return None
1150
1151
# Install the `ops.Tensor` operator overloads (plus the variable-specific
# `__getitem__`) on `Variable`. This must run at module level after the class
# definition, because `_OverloadAllOperators` sets attributes on the class
# object itself.
Variable._OverloadAllOperators()  # pylint: disable=protected-access
1153
1154
@tf_export(v1=["Variable"])
class VariableV1(Variable):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized before
  you can run Ops that use their value. You can initialize a variable by
  running its *initializer op*, restoring the variable from a save file, or
  simply running an `assign` Op that assigns a value to the variable. In fact,
  the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.Session() as sess:
      # Run the variable initializer.
      sess.run(w.initializer)
      # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.global_variables_initializer()

  # Launch the graph in a session.
  with tf.Session() as sess:
      # Run the Op that initializes global variables.
      sess.run(init_op)
      # ...you can now run any Op that uses variable values...
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to distinguish
  between variables holding the trainable model parameters and other variables
  such as a `global step` variable used to count training steps. To make this
  easier, the variable constructor supports a `trainable=<bool>` parameter. If
  `True`, the new variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
  `trainable_variables()` returns the contents of this collection. The
  various `Optimizer` classes use this collection as the default list of
  variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
  Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph. The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:
  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables which does
  not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.get_variable_scope().set_use_resource(True)` inside a
    `tf.variable_scope` before the `tf.get_variable()` call.
  """

  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               expected_shape=None,
               import_scope=None,
               constraint=None,
               use_resource=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph collection
    `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set the
    variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
        the Variable object with its contents, referencing the variable's nodes
        in the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected
        to have this shape.
      import_scope: Optional `string`. Name scope to add to the
        `Variable.` Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      use_resource: whether to use resource variables.
      synchronization: unused
      aggregation: unused

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """
    # NOTE(review): this constructor deliberately has no body and does not
    # call the base-class constructor (see the pylint disable on the
    # signature). Presumably real construction is routed elsewhere (e.g. a
    # metaclass or variable creator attached to `Variable`) -- confirm before
    # relying on any of the behaviour documented above from this class alone.

  # Alias so that `VariableV1.SaveSliceInfo` is the very same class object as
  # `Variable.SaveSliceInfo`.
  SaveSliceInfo = Variable.SaveSliceInfo
1349
1350
1351# TODO(apassos): do not repeat all comments here
1352class RefVariable(VariableV1):
1353  """Ref-based implementation of variables."""
1354
1355  def __init__(self,  # pylint: disable=super-init-not-called
1356               initial_value=None,
1357               trainable=True,
1358               collections=None,
1359               validate_shape=True,
1360               caching_device=None,
1361               name=None,
1362               variable_def=None,
1363               dtype=None,
1364               expected_shape=None,
1365               import_scope=None,
1366               constraint=None):
1367    """Creates a new variable with value `initial_value`.
1368
1369    The new variable is added to the graph collections listed in `collections`,
1370    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1371
1372    If `trainable` is `True` the variable is also added to the graph collection
1373    `GraphKeys.TRAINABLE_VARIABLES`.
1374
1375    This constructor creates both a `variable` Op and an `assign` Op to set the
1376    variable to its initial value.
1377
1378    Args:
1379      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1380        which is the initial value for the Variable. The initial value must have
1381        a shape specified unless `validate_shape` is set to False. Can also be a
1382        callable with no argument that returns the initial value when called. In
1383        that case, `dtype` must be specified. (Note that initializer functions
1384        from init_ops.py must first be bound to a shape before being used here.)
1385      trainable: If `True`, the default, also adds the variable to the graph
1386        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1387        the default list of variables to use by the `Optimizer` classes.
1388      collections: List of graph collections keys. The new variable is added to
1389        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1390      validate_shape: If `False`, allows the variable to be initialized with a
1391        value of unknown shape. If `True`, the default, the shape of
1392        `initial_value` must be known.
1393      caching_device: Optional device string describing where the Variable
1394        should be cached for reading.  Defaults to the Variable's device.
1395        If not `None`, caches on another device.  Typical use is to cache
1396        on the device where the Ops using the Variable reside, to deduplicate
1397        copying through `Switch` and other conditional statements.
1398      name: Optional name for the variable. Defaults to `'Variable'` and gets
1399        uniquified automatically.
1400      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
1401        the Variable object with its contents, referencing the variable's nodes
1402        in the graph, which must already exist. The graph is not changed.
1403        `variable_def` and the other arguments are mutually exclusive.
1404      dtype: If set, initial_value will be converted to the given type.
1405        If `None`, either the datatype will be kept (if `initial_value` is
1406        a Tensor), or `convert_to_tensor` will decide.
1407      expected_shape: A TensorShape. If set, initial_value is expected
1408        to have this shape.
1409      import_scope: Optional `string`. Name scope to add to the
1410        `Variable.` Only used when initializing from protocol buffer.
1411      constraint: An optional projection function to be applied to the variable
1412        after being updated by an `Optimizer` (e.g. used to implement norm
1413        constraints or value constraints for layer weights). The function must
1414        take as input the unprojected Tensor representing the value of the
1415        variable and return the Tensor for the projected value
1416        (which must have the same shape). Constraints are not safe to
1417        use when doing asynchronous distributed training.
1418
1419    Raises:
1420      ValueError: If both `variable_def` and initial_value are specified.
1421      ValueError: If the initial value is not specified, or does not have a
1422        shape and `validate_shape` is `True`.
1423      RuntimeError: If eager execution is enabled.
1424    """
1425    self._in_graph_mode = True
1426    if variable_def:
1427      # If variable_def is provided, recreates the variable from its fields.
1428      if initial_value:
1429        raise ValueError("variable_def and initial_value are mutually "
1430                         "exclusive.")
1431      self._init_from_proto(variable_def, import_scope=import_scope)
1432    else:
1433      # Create from initial_value.
1434      self._init_from_args(
1435          initial_value=initial_value,
1436          trainable=trainable,
1437          collections=collections,
1438          validate_shape=validate_shape,
1439          caching_device=caching_device,
1440          name=name,
1441          dtype=dtype,
1442          expected_shape=expected_shape,
1443          constraint=constraint)
1444
1445  def __repr__(self):
1446    if context.executing_eagerly() and not self._in_graph_mode:
1447      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
1448          self.name, self.get_shape(), self.dtype.name,
1449          ops.numpy_text(self.read_value(), is_repr=True))
1450    else:
1451      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
1452          self.name, self.get_shape(), self.dtype.name)
1453
1454  def _init_from_args(self,
1455                      initial_value=None,
1456                      trainable=True,
1457                      collections=None,
1458                      validate_shape=True,
1459                      caching_device=None,
1460                      name=None,
1461                      dtype=None,
1462                      expected_shape=None,
1463                      constraint=None):
1464    """Creates a new variable from arguments.
1465
1466    Args:
1467      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1468        which is the initial value for the Variable. The initial value must have
1469        a shape specified unless `validate_shape` is set to False. Can also be a
1470        callable with no argument that returns the initial value when called.
1471        (Note that initializer functions from init_ops.py must first be bound
1472         to a shape before being used here.)
1473      trainable: If `True`, the default, also adds the variable to the graph
1474        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1475        the default list of variables to use by the `Optimizer` classes.
1476      collections: List of graph collections keys. The new variable is added to
1477        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1478      validate_shape: If `False`, allows the variable to be initialized with a
1479        value of unknown shape. If `True`, the default, the shape of
1480        `initial_value` must be known.
1481      caching_device: Optional device string or function describing where the
1482        Variable should be cached for reading.  Defaults to the Variable's
1483        device.  If not `None`, caches on another device.  Typical use is to
1484        cache on the device where the Ops using the Variable reside, to
1485        deduplicate copying through `Switch` and other conditional statements.
1486      name: Optional name for the variable. Defaults to `'Variable'` and gets
1487        uniquified automatically.
1488      dtype: If set, initial_value will be converted to the given type.
1489        If None, either the datatype will be kept (if initial_value is
1490       a Tensor) or float32 will be used (if it is a Python object convertible
1491       to a Tensor).
1492      expected_shape: Deprecated. Ignored.
1493      constraint: An optional projection function to be applied to the variable
1494        after being updated by an `Optimizer` (e.g. used to implement norm
1495        constraints or value constraints for layer weights). The function must
1496        take as input the unprojected Tensor representing the value of the
1497        variable and return the Tensor for the projected value
1498        (which must have the same shape). Constraints are not safe to
1499        use when doing asynchronous distributed training.
1500
1501    Raises:
1502      ValueError: If the initial value is not specified, or does not have a
1503        shape and `validate_shape` is `True`.
1504      RuntimeError: If lifted into the eager context.
1505    """
1506    _ = expected_shape
1507    if initial_value is None:
1508      raise ValueError("initial_value must be specified.")
1509    init_from_fn = callable(initial_value)
1510
1511    if collections is None:
1512      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1513    if not isinstance(collections, (list, tuple, set)):
1514      raise ValueError(
1515          "collections argument to Variable constructor must be a list, tuple, "
1516          "or set. Got %s of type %s" % (collections, type(collections)))
1517    if constraint is not None and not callable(constraint):
1518      raise ValueError("The `constraint` argument must be a callable.")
1519
1520    # Store the graph key so optimizers know how to only retrieve variables from
1521    # this graph.
1522    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
1523    if isinstance(initial_value, trackable.CheckpointInitialValue):
1524      self._maybe_initialize_trackable()
1525      self._update_uid = initial_value.checkpoint_position.restore_uid
1526      initial_value = initial_value.wrapped_value
1527
1528    self._trainable = trainable
1529    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1530      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1531    with ops.init_scope():
1532      # Ensure that we weren't lifted into the eager context.
1533      if context.executing_eagerly():
1534        raise RuntimeError(
1535            "RefVariable not supported when eager execution is enabled. ")
1536      with ops.name_scope(name, "Variable", [] if init_from_fn else
1537                          [initial_value]) as name:
1538
1539        if init_from_fn:
1540          # Use attr_scope and device(None) to simulate the behavior of
1541          # colocate_with when the variable we want to colocate with doesn't
1542          # yet exist.
1543          true_name = ops._name_from_scope_name(name)  # pylint: disable=protected-access
1544          attr = attr_value_pb2.AttrValue(
1545              list=attr_value_pb2.AttrValue.ListValue(
1546                  s=[compat.as_bytes("loc:@%s" % true_name)]))
1547          # pylint: disable=protected-access
1548          with ops.get_default_graph()._attr_scope({"_class": attr}):
1549            with ops.name_scope("Initializer"), ops.device(None):
1550              self._initial_value = ops.convert_to_tensor(
1551                  initial_value(), name="initial_value", dtype=dtype)
1552              shape = (self._initial_value.get_shape()
1553                       if validate_shape else tensor_shape.unknown_shape())
1554            self._variable = state_ops.variable_op_v2(
1555                shape,
1556                self._initial_value.dtype.base_dtype,
1557                name=name)
1558          # pylint: enable=protected-access
1559
1560        # Or get the initial value from a Tensor or Python object.
1561        else:
1562          self._initial_value = ops.convert_to_tensor(
1563              initial_value, name="initial_value", dtype=dtype)
1564          # pylint: disable=protected-access
1565          if self._initial_value.op._get_control_flow_context() is not None:
1566            raise ValueError(
1567                "Initializer for variable %s is from inside a control-flow "
1568                "construct, such as a loop or conditional. When creating a "
1569                "variable inside a loop or conditional, use a lambda as the "
1570                "initializer." % name)
1571          # pylint: enable=protected-access
1572          shape = (self._initial_value.get_shape()
1573                   if validate_shape else tensor_shape.unknown_shape())
1574          # In this case, the variable op can't be created until after the
1575          # initial_value has been converted to a Tensor with a known type.
1576          self._variable = state_ops.variable_op_v2(
1577              shape,
1578              self._initial_value.dtype.base_dtype,
1579              name=name)
1580
1581        # Manually overrides the variable's shape with the initial value's.
1582        if validate_shape:
1583          initial_value_shape = self._initial_value.get_shape()
1584          if not initial_value_shape.is_fully_defined():
1585            raise ValueError("initial_value must have a shape specified: %s" %
1586                             self._initial_value)
1587
1588        # If 'initial_value' makes use of other variables, make sure we don't
1589        # have an issue if these other variables aren't initialized first by
1590        # using their initialized_value() method.
1591        self._initializer_op = state_ops.assign(
1592            self._variable,
1593            _try_guard_against_uninitialized_dependencies(
1594                name,
1595                self._initial_value),
1596            validate_shape=validate_shape).op
1597
1598        # TODO(vrv): Change this class to not take caching_device, but
1599        # to take the op to colocate the snapshot with, so we can use
1600        # colocation rather than devices.
1601        if caching_device is not None:
1602          with ops.device(caching_device):
1603            self._snapshot = array_ops.identity(self._variable, name="read")
1604        else:
1605          with ops.colocate_with(self._variable.op):
1606            self._snapshot = array_ops.identity(self._variable, name="read")
1607      ops.add_to_collections(collections, self)
1608
1609    self._caching_device = caching_device
1610    self._save_slice_info = None
1611    self._constraint = constraint
1612
1613  def _init_from_proto(self, variable_def, import_scope=None):
1614    """Recreates the Variable object from a `VariableDef` protocol buffer.
1615
1616    Args:
1617      variable_def: `VariableDef` protocol buffer, describing a variable
1618          whose nodes already exists in the graph.
1619      import_scope: Optional `string`. Name scope to add.
1620    """
1621    assert isinstance(variable_def, variable_pb2.VariableDef)
1622    # Create from variable_def.
1623    g = ops.get_default_graph()
1624    self._variable = g.as_graph_element(
1625        ops.prepend_name_scope(variable_def.variable_name,
1626                               import_scope=import_scope))
1627    self._initializer_op = g.as_graph_element(
1628        ops.prepend_name_scope(variable_def.initializer_name,
1629                               import_scope=import_scope))
1630    # Tests whether initial_value_name exists first for backwards compatibility.
1631    if (hasattr(variable_def, "initial_value_name") and
1632        variable_def.initial_value_name):
1633      self._initial_value = g.as_graph_element(
1634          ops.prepend_name_scope(variable_def.initial_value_name,
1635                                 import_scope=import_scope))
1636    else:
1637      self._initial_value = None
1638    self._trainable = getattr(variable_def, "trainable", True)
1639    self._snapshot = g.as_graph_element(
1640        ops.prepend_name_scope(variable_def.snapshot_name,
1641                               import_scope=import_scope))
1642    if variable_def.HasField("save_slice_info_def"):
1643      self._save_slice_info = Variable.SaveSliceInfo(
1644          save_slice_info_def=variable_def.save_slice_info_def,
1645          import_scope=import_scope)
1646    else:
1647      self._save_slice_info = None
1648    self._caching_device = None
1649    self._constraint = None
1650
1651  def _as_graph_element(self):
1652    """Conversion function for Graph.as_graph_element()."""
1653    return self._variable
1654
1655  def value(self):
1656    """Returns the last snapshot of this variable.
1657
1658    You usually do not need to call this method as all ops that need the value
1659    of the variable call it automatically through a `convert_to_tensor()` call.
1660
1661    Returns a `Tensor` which holds the value of the variable.  You can not
1662    assign a new value to this tensor as it is not a reference to the variable.
1663
1664    To avoid copies, if the consumer of the returned value is on the same device
1665    as the variable, this actually returns the live value of the variable, not
1666    a copy.  Updates to the variable are seen by the consumer.  If the consumer
1667    is on a different device it will get a copy of the variable.
1668
1669    Returns:
1670      A `Tensor` containing the value of the variable.
1671    """
1672    return self._snapshot
1673
1674  def read_value(self):
1675    """Returns the value of this variable, read in the current context.
1676
1677    Can be different from value() if it's on another device, with control
1678    dependencies, etc.
1679
1680    Returns:
1681      A `Tensor` containing the value of the variable.
1682    """
1683    return array_ops.identity(self._variable, name="read")
1684
1685  def _ref(self):
1686    """Returns a reference to this variable.
1687
1688    You usually do not need to call this method as all ops that need a reference
1689    to the variable call it automatically.
1690
1691    Returns is a `Tensor` which holds a reference to the variable.  You can
1692    assign a new value to the variable by passing the tensor to an assign op.
1693    See `tf.Variable.value` if you want to get the value of the
1694    variable.
1695
1696    Returns:
1697      A `Tensor` that is a reference to the variable.
1698    """
1699    return self._variable
1700
1701  def set_shape(self, shape):
1702    """Overrides the shape for this variable.
1703
1704    Args:
1705      shape: the `TensorShape` representing the overridden shape.
1706    """
1707    self._ref().set_shape(shape)
1708    self.value().set_shape(shape)
1709
1710  @property
1711  def trainable(self):
1712    return self._trainable
1713
1714  def eval(self, session=None):
1715    """In a session, computes and returns the value of this variable.
1716
1717    This is not a graph construction method, it does not add ops to the graph.
1718
1719    This convenience method requires a session where the graph
1720    containing this variable has been launched. If no session is
1721    passed, the default session is used.  See `tf.Session` for more
1722    information on launching a graph and on sessions.
1723
1724    ```python
1725    v = tf.Variable([1, 2])
1726    init = tf.global_variables_initializer()
1727
1728    with tf.Session() as sess:
1729        sess.run(init)
1730        # Usage passing the session explicitly.
1731        print(v.eval(sess))
1732        # Usage with the default session.  The 'with' block
1733        # above makes 'sess' the default session.
1734        print(v.eval())
1735    ```
1736
1737    Args:
1738      session: The session to use to evaluate this variable. If
1739        none, the default session is used.
1740
1741    Returns:
1742      A numpy `ndarray` with a copy of the value of this variable.
1743    """
1744    return self._variable.eval(session=session)
1745
1746  @property
1747  def initial_value(self):
1748    """Returns the Tensor used as the initial value for the variable.
1749
1750    Note that this is different from `initialized_value()` which runs
1751    the op that initializes the variable before returning its value.
1752    This method returns the tensor that is used by the op that initializes
1753    the variable.
1754
1755    Returns:
1756      A `Tensor`.
1757    """
1758    return self._initial_value
1759
1760  @property
1761  def constraint(self):
1762    """Returns the constraint function associated with this variable.
1763
1764    Returns:
1765      The constraint function that was passed to the variable constructor.
1766      Can be `None` if no constraint was passed.
1767    """
1768    return self._constraint
1769
1770  def assign(self, value, use_locking=False, name=None, read_value=True):
1771    """Assigns a new value to the variable.
1772
1773    This is essentially a shortcut for `assign(self, value)`.
1774
1775    Args:
1776      value: A `Tensor`. The new value for this variable.
1777      use_locking: If `True`, use locking during the assignment.
1778      name: The name of the operation to be created
1779      read_value: if True, will return something which evaluates to the
1780        new value of the variable; if False will return the assign op.
1781
1782    Returns:
1783      A `Tensor` that will hold the new value of this variable after
1784      the assignment has completed.
1785    """
1786    assign = state_ops.assign(self._variable, value, use_locking=use_locking,
1787                              name=name)
1788    if read_value:
1789      return assign
1790    return assign.op
1791
1792  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
1793    """Adds a value to this variable.
1794
1795     This is essentially a shortcut for `assign_add(self, delta)`.
1796
1797    Args:
1798      delta: A `Tensor`. The value to add to this variable.
1799      use_locking: If `True`, use locking during the operation.
1800      name: The name of the operation to be created
1801      read_value: if True, will return something which evaluates to the
1802        new value of the variable; if False will return the assign op.
1803
1804    Returns:
1805      A `Tensor` that will hold the new value of this variable after
1806      the addition has completed.
1807    """
1808    assign = state_ops.assign_add(
1809        self._variable, delta, use_locking=use_locking, name=name)
1810    if read_value:
1811      return assign
1812    return assign.op
1813
1814  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
1815    """Subtracts a value from this variable.
1816
1817    This is essentially a shortcut for `assign_sub(self, delta)`.
1818
1819    Args:
1820      delta: A `Tensor`. The value to subtract from this variable.
1821      use_locking: If `True`, use locking during the operation.
1822      name: The name of the operation to be created
1823      read_value: if True, will return something which evaluates to the
1824        new value of the variable; if False will return the assign op.
1825
1826    Returns:
1827      A `Tensor` that will hold the new value of this variable after
1828      the subtraction has completed.
1829    """
1830    assign = state_ops.assign_sub(
1831        self._variable, delta, use_locking=use_locking, name=name)
1832    if read_value:
1833      return assign
1834    return assign.op
1835
1836  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
1837    """Subtracts `IndexedSlices` from this variable.
1838
1839    Args:
1840      sparse_delta: `IndexedSlices` to be subtracted from this variable.
1841      use_locking: If `True`, use locking during the operation.
1842      name: the name of the operation.
1843
1844    Returns:
1845      A `Tensor` that will hold the new value of this variable after
1846      the scattered subtraction has completed.
1847
1848    Raises:
1849      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1850    """
1851    if not isinstance(sparse_delta, ops.IndexedSlices):
1852      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1853    return gen_state_ops.scatter_sub(
1854        self._variable,
1855        sparse_delta.indices,
1856        sparse_delta.values,
1857        use_locking=use_locking,
1858        name=name)
1859
1860  def scatter_add(self, sparse_delta, use_locking=False, name=None):
1861    """Adds `IndexedSlices` from this variable.
1862
1863    Args:
1864      sparse_delta: `IndexedSlices` to be added to this variable.
1865      use_locking: If `True`, use locking during the operation.
1866      name: the name of the operation.
1867
1868    Returns:
1869      A `Tensor` that will hold the new value of this variable after
1870      the scattered subtraction has completed.
1871
1872    Raises:
1873      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1874    """
1875    if not isinstance(sparse_delta, ops.IndexedSlices):
1876      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1877    return gen_state_ops.scatter_add(
1878        self._variable,
1879        sparse_delta.indices,
1880        sparse_delta.values,
1881        use_locking=use_locking,
1882        name=name)
1883
1884  def scatter_update(self, sparse_delta, use_locking=False, name=None):
1885    """Assigns `IndexedSlices` to this variable.
1886
1887    Args:
1888      sparse_delta: `IndexedSlices` to be assigned to this variable.
1889      use_locking: If `True`, use locking during the operation.
1890      name: the name of the operation.
1891
1892    Returns:
1893      A `Tensor` that will hold the new value of this variable after
1894      the scattered subtraction has completed.
1895
1896    Raises:
1897      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1898    """
1899    if not isinstance(sparse_delta, ops.IndexedSlices):
1900      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1901    return gen_state_ops.scatter_update(
1902        self._variable,
1903        sparse_delta.indices,
1904        sparse_delta.values,
1905        use_locking=use_locking,
1906        name=name)
1907
1908  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
1909    """Assigns `IndexedSlices` to this variable batch-wise.
1910
1911    Analogous to `batch_gather`. This assumes that this variable and the
1912    sparse_delta IndexedSlices have a series of leading dimensions that are the
1913    same for all of them, and the updates are performed on the last dimension of
1914    indices. In other words, the dimensions should be the following:
1915
1916    `num_prefix_dims = sparse_delta.indices.ndims - 1`
1917    `batch_dim = num_prefix_dims + 1`
1918    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
1919         batch_dim:]`
1920
1921    where
1922
1923    `sparse_delta.updates.shape[:num_prefix_dims]`
1924    `== sparse_delta.indices.shape[:num_prefix_dims]`
1925    `== var.shape[:num_prefix_dims]`
1926
1927    And the operation performed can be expressed as:
1928
1929    `var[i_1, ..., i_n,
1930         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
1931            i_1, ..., i_n, j]`
1932
1933    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
1934    `scatter_update`.
1935
1936    To avoid this operation one can looping over the first `ndims` of the
1937    variable and using `scatter_update` on the subtensors that result of slicing
1938    the first dimension. This is a valid option for `ndims = 1`, but less
1939    efficient than this implementation.
1940
1941    Args:
1942      sparse_delta: `IndexedSlices` to be assigned to this variable.
1943      use_locking: If `True`, use locking during the operation.
1944      name: the name of the operation.
1945
1946    Returns:
1947      A `Tensor` that will hold the new value of this variable after
1948      the scattered subtraction has completed.
1949
1950    Raises:
1951      ValueError: if `sparse_delta` is not an `IndexedSlices`.
1952    """
1953    return state_ops.batch_scatter_update(
1954        self, sparse_delta.indices, sparse_delta.values,
1955        use_locking=use_locking, name=name)
1956
1957  def scatter_nd_sub(self, indices, updates, name=None):
1958    """Applies sparse subtraction to individual values or slices in a Variable.
1959
1960    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1961
1962    `indices` must be integer tensor, containing indices into `ref`.
1963    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1964
1965    The innermost dimension of `indices` (with length `K`) corresponds to
1966    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1967    dimension of `ref`.
1968
1969    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1970
1971    ```
1972    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1973    ```
1974
1975    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1976    8 elements. In Python, that update would look like this:
1977
1978    ```python
1979        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1980        indices = tf.constant([[4], [3], [1] ,[7]])
1981        updates = tf.constant([9, 10, 11, 12])
1982        op = ref.scatter_nd_sub(indices, updates)
1983        with tf.Session() as sess:
1984          print sess.run(op)
1985    ```
1986
1987    The resulting update to ref would look like this:
1988
1989        [1, -9, 3, -6, -6, 6, 7, -4]
1990
1991    See `tf.scatter_nd` for more details about how to make updates to
1992    slices.
1993
1994    Args:
1995      indices: The indices to be used in the operation.
1996      updates: The values to be used in the operation.
1997      name: the name of the operation.
1998
1999    Returns:
2000      A `Tensor` that will hold the new value of this variable after
2001      the scattered subtraction has completed.
2002
2003    Raises:
2004      ValueError: if `sparse_delta` is not an `IndexedSlices`.
2005    """
2006    return gen_state_ops.scatter_nd_sub(
2007        self._variable, indices, updates, use_locking=True, name=name)
2008
2009  def scatter_nd_add(self, indices, updates, name=None):
2010    """Applies sparse addition to individual values or slices in a Variable.
2011
2012    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2013
2014    `indices` must be integer tensor, containing indices into `ref`.
2015    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2016
2017    The innermost dimension of `indices` (with length `K`) corresponds to
2018    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2019    dimension of `ref`.
2020
2021    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2022
2023    ```
2024    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2025    ```
2026
2027    For example, say we want to add 4 scattered elements to a rank-1 tensor to
2028    8 elements. In Python, that update would look like this:
2029
2030    ```python
2031        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2032        indices = tf.constant([[4], [3], [1] ,[7]])
2033        updates = tf.constant([9, 10, 11, 12])
2034        add = ref.scatter_nd_add(indices, updates)
2035        with tf.Session() as sess:
2036          print sess.run(add)
2037    ```
2038
2039    The resulting update to ref would look like this:
2040
2041        [1, 13, 3, 14, 14, 6, 7, 20]
2042
2043    See `tf.scatter_nd` for more details about how to make updates to
2044    slices.
2045
2046    Args:
2047      indices: The indices to be used in the operation.
2048      updates: The values to be used in the operation.
2049      name: the name of the operation.
2050
2051    Returns:
2052      A `Tensor` that will hold the new value of this variable after
2053      the scattered subtraction has completed.
2054
2055    Raises:
2056      ValueError: if `sparse_delta` is not an `IndexedSlices`.
2057    """
2058    return gen_state_ops.scatter_nd_add(
2059        self._variable, indices, updates, use_locking=True, name=name)
2060
2061  def scatter_nd_update(self, indices, updates, name=None):
2062    """Applies sparse assignment to individual values or slices in a Variable.
2063
2064    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2065
2066    `indices` must be integer tensor, containing indices into `ref`.
2067    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2068
2069    The innermost dimension of `indices` (with length `K`) corresponds to
2070    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2071    dimension of `ref`.
2072
2073    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2074
2075    ```
2076    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2077    ```
2078
2079    For example, say we want to add 4 scattered elements to a rank-1 tensor to
2080    8 elements. In Python, that update would look like this:
2081
2082    ```python
2083        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2084        indices = tf.constant([[4], [3], [1] ,[7]])
2085        updates = tf.constant([9, 10, 11, 12])
2086        op = ref.scatter_nd_update(indices, updates)
2087        with tf.Session() as sess:
2088          print sess.run(op)
2089    ```
2090
2091    The resulting update to ref would look like this:
2092
2093        [1, 11, 3, 10, 9, 6, 7, 12]
2094
2095    See `tf.scatter_nd` for more details about how to make updates to
2096    slices.
2097
2098    Args:
2099      indices: The indices to be used in the operation.
2100      updates: The values to be used in the operation.
2101      name: the name of the operation.
2102
2103    Returns:
2104      A `Tensor` that will hold the new value of this variable after
2105      the scattered subtraction has completed.
2106
2107    Raises:
2108      ValueError: if `sparse_delta` is not an `IndexedSlices`.
2109    """
2110    return gen_state_ops.scatter_nd_update(
2111        self._variable, indices, updates, use_locking=True, name=name)
2112
2113  def _strided_slice_assign(self,
2114                            begin,
2115                            end,
2116                            strides,
2117                            value,
2118                            name,
2119                            begin_mask,
2120                            end_mask,
2121                            ellipsis_mask,
2122                            new_axis_mask,
2123                            shrink_axis_mask):
2124    return gen_array_ops.strided_slice_assign(ref=self._ref(),
2125                                              begin=begin,
2126                                              end=end,
2127                                              strides=strides,
2128                                              value=value,
2129                                              name=name,
2130                                              begin_mask=begin_mask,
2131                                              end_mask=end_mask,
2132                                              ellipsis_mask=ellipsis_mask,
2133                                              new_axis_mask=new_axis_mask,
2134                                              shrink_axis_mask=shrink_axis_mask)
2135
2136  @deprecated(None, "Prefer Dataset.range instead.")
2137  def count_up_to(self, limit):
2138    """Increments this variable until it reaches `limit`.
2139
2140    When that Op is run it tries to increment the variable by `1`. If
2141    incrementing the variable would bring it above `limit` then the Op raises
2142    the exception `OutOfRangeError`.
2143
2144    If no error is raised, the Op outputs the value of the variable before
2145    the increment.
2146
2147    This is essentially a shortcut for `count_up_to(self, limit)`.
2148
2149    Args:
2150      limit: value at which incrementing the variable raises an error.
2151
2152    Returns:
2153      A `Tensor` that will hold the variable value before the increment. If no
2154      other Op modifies this variable, the values produced will all be
2155      distinct.
2156    """
2157    return state_ops.count_up_to(self._variable, limit=limit)
2158
2159  # Conversion to tensor.
2160  @staticmethod
2161  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
2162    """Utility function for converting a Variable to a Tensor."""
2163    _ = name
2164    if dtype and not dtype.is_compatible_with(v.dtype):
2165      raise ValueError(
2166          "Incompatible type conversion requested to type '%s' for variable "
2167          "of type '%s'" % (dtype.name, v.dtype.name))
2168    if as_ref:
2169      return v._ref()  # pylint: disable=protected-access
2170    else:
2171      return v.value()
2172
  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Variables interact
  # with ndarrays.
  # NOTE(review): NumPy shipped that mechanism under the name
  # `__array_ufunc__` (NumPy 1.13); `__numpy_ufunc__` was never released.
  __array_priority__ = 100
2181
2182  @property
2183  def name(self):
2184    """The name of this variable."""
2185    return self._variable.name
2186
2187  @property
2188  def initializer(self):
2189    """The initializer operation for this variable."""
2190    return self._initializer_op
2191
2192  @property
2193  def device(self):
2194    """The device of this variable."""
2195    return self._variable.device
2196
2197  @property
2198  def dtype(self):
2199    """The `DType` of this variable."""
2200    return self._variable.dtype
2201
2202  @property
2203  def op(self):
2204    """The `Operation` of this variable."""
2205    return self._variable.op
2206
2207  @property
2208  def graph(self):
2209    """The `Graph` of this variable."""
2210    return self._variable.graph
2211
2212  @property
2213  def _distribute_strategy(self):
2214    """The `tf.distribute.Strategy` that this variable was created under."""
2215    return None   # Ref variables are never created inside a strategy.
2216
2217  @property
2218  def shape(self):
2219    """The `TensorShape` of this variable.
2220
2221    Returns:
2222      A `TensorShape`.
2223    """
2224    return self._variable.get_shape()
2225
2226  def to_proto(self, export_scope=None):
2227    """Converts a `Variable` to a `VariableDef` protocol buffer.
2228
2229    Args:
2230      export_scope: Optional `string`. Name scope to remove.
2231
2232    Returns:
2233      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
2234      in the specified name scope.
2235    """
2236    if (export_scope is None or
2237        self._variable.name.startswith(export_scope)):
2238      var_def = variable_pb2.VariableDef()
2239      var_def.variable_name = ops.strip_name_scope(
2240          self._variable.name, export_scope)
2241      if self._initial_value is not None:
2242        # For backwards compatibility.
2243        var_def.initial_value_name = ops.strip_name_scope(
2244            self._initial_value.name, export_scope)
2245      var_def.trainable = self.trainable
2246      var_def.initializer_name = ops.strip_name_scope(
2247          self.initializer.name, export_scope)
2248      var_def.snapshot_name = ops.strip_name_scope(
2249          self._snapshot.name, export_scope)
2250      if self._save_slice_info:
2251        var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
2252            export_scope=export_scope))
2253      return var_def
2254    else:
2255      return None
2256
2257  def __iadd__(self, other):
2258    logging.log_first_n(
2259        logging.WARN,
2260        "Variable += will be deprecated. Use variable.assign_add"
2261        " if you want assignment to the variable value or 'x = x + y'"
2262        " if you want a new python Tensor object.", 1)
2263    return self + other
2264
2265  def __isub__(self, other):
2266    logging.log_first_n(
2267        logging.WARN,
2268        "Variable -= will be deprecated. Use variable.assign_sub"
2269        " if you want assignment to the variable value or 'x = x - y'"
2270        " if you want a new python Tensor object.", 1)
2271    return self - other
2272
2273  def __imul__(self, other):
2274    logging.log_first_n(
2275        logging.WARN,
2276        "Variable *= will be deprecated. Use `var.assign(var * other)`"
2277        " if you want assignment to the variable value or `x = x * y`"
2278        " if you want a new python Tensor object.", 1)
2279    return self * other
2280
2281  def __idiv__(self, other):
2282    logging.log_first_n(
2283        logging.WARN,
2284        "Variable /= will be deprecated. Use `var.assign(var / other)`"
2285        " if you want assignment to the variable value or `x = x / y`"
2286        " if you want a new python Tensor object.", 1)
2287    return self / other
2288
2289  def __itruediv__(self, other):
2290    logging.log_first_n(
2291        logging.WARN,
2292        "Variable /= will be deprecated. Use `var.assign(var / other)`"
2293        " if you want assignment to the variable value or `x = x / y`"
2294        " if you want a new python Tensor object.", 1)
2295    return self / other
2296
2297  def __irealdiv__(self, other):
2298    logging.log_first_n(
2299        logging.WARN,
2300        "Variable /= will be deprecated. Use `var.assign(var / other)`"
2301        " if you want assignment to the variable value or `x = x / y`"
2302        " if you want a new python Tensor object.", 1)
2303    return self / other
2304
2305  def __ipow__(self, other):
2306    logging.log_first_n(
2307        logging.WARN,
2308        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
2309        " if you want assignment to the variable value or `x = x ** y`"
2310        " if you want a new python Tensor object.", 1)
2311    return self ** other
2312
2313
def _try_guard_against_uninitialized_dependencies(name, initial_value):
  """Attempt to guard against dependencies on uninitialized variables.

  Variable references reachable from `initial_value` are swapped for the
  corresponding variables' initialized values: conditional subgraphs that
  yield the variable's value when it is initialized and its `initial_value`
  otherwise. The rewrite is best effort only:

  - A cyclic `initial_value` graph is returned untouched.
  - Variables missing from both `GLOBAL_VARIABLES` and `LOCAL_VARIABLES` are
    not replaced.

  In those cases the caller remains responsible for making the
  `initial_value` graph use initialized variables, or for guarding variable
  access through their `initialized_value` method.

  Args:
    name: Variable name.
    initial_value: `Tensor`. The initial value.
  Returns:
    A `Tensor` suitable to initialize a variable.
  Raises:
    TypeError: If `initial_value` is not a `Tensor`.
  """
  if not isinstance(initial_value, ops.Tensor):
    raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)

  if not _has_cycle(initial_value.op, path=set()):
    return _safe_initial_value_from_tensor(name, initial_value, op_cache={})
  # Cyclic dependency graphs are left exactly as given.
  return initial_value
2347
2348
def _has_cycle(op, path, finished=None):
  """Detect cycles in the dependencies of `initial_value`.

  Runs an iterative depth-first search over `op`'s data inputs (each input
  tensor's `.op`) followed by its control inputs. Every op is expanded at most
  once, so graphs with heavily shared sub-expressions are handled in time
  linear in the number of reachable ops. Long dependency chains do not hit
  Python's recursion limit.

  Args:
    op: The `Operation` to start the search from.
    path: A set of op names on the current search path. Callers normally
      pass an empty set. When no cycle is found it is left with its original
      contents.
    finished: Optional set of op names already fully explored without finding
      a cycle. It may be shared across calls to reuse work; a fresh set is
      used when omitted.

  Returns:
    True if a cycle is reachable from `op`, False otherwise.
  """
  if finished is None:
    finished = set()
  if op.name in path:
    return True
  if op.name in finished:
    return False

  def _dependencies(node):
    # Data inputs first, then control inputs (same order as before); lazy so
    # `control_inputs` is only read once the data inputs are exhausted.
    for op_input in node.inputs:
      yield op_input.op
    for control_input in node.control_inputs:
      yield control_input

  path.add(op.name)
  stack = [(op, _dependencies(op))]
  while stack:
    node, pending = stack[-1]
    for dep in pending:
      if dep.name in path:
        # Back edge to an op on the current path: cycle.
        return True
      if dep.name not in finished:
        path.add(dep.name)
        stack.append((dep, _dependencies(dep)))
        break
    else:
      # Every dependency of `node` was explored without finding a cycle.
      stack.pop()
      path.remove(node.name)
      finished.add(node.name)
  return False
2362
2363
def _safe_initial_value_from_tensor(name, tensor, op_cache):
  """Replace dependencies on variables with their initialized values.

  Looks up (or computes and memoizes) the rewritten producer of `tensor` and
  returns the output at the same index.

  Args:
    name: Variable name.
    tensor: A `Tensor`. The tensor to replace.
    op_cache: A dict mapping operation names to `Operation`s. It memoizes
      rewritten ops so no redundant operations are created. An entry whose
      value is `None` is treated as missing and recomputed.
  Returns:
    A `Tensor` compatible with `tensor`. Inputs that lead to variable values
    are replaced with a graph that uses the variable's initialized values, on
    a best-effort basis. If nothing needs to change, `tensor` is returned.
  """
  source_op = tensor.op
  safe_op = op_cache.get(source_op.name)
  if safe_op is None:
    safe_op = op_cache[source_op.name] = _safe_initial_value_from_op(
        name, source_op, op_cache)
  return safe_op.outputs[tensor.value_index]
2384
2385
def _safe_initial_value_from_op(name, op, op_cache):
  """Replace dependencies on variables with their initialized values.

  Args:
    name: Variable name.
    op: An `Operation`. The operation to replace.
    op_cache: A dict mapping operation names to `Operation`s. It memoizes
      rewritten ops so no redundant operations are created.
  Returns:
    An `Operation` compatible with `op`. Inputs that lead to variable values
    are replaced with a graph that uses the variable's initialized values, on
    a best-effort basis. If nothing needs to change, `op` is returned.
  """
  op_type = op.node_def.op
  # These ops are returned as-is: their inputs are never rewritten.
  if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
                 "ReadVariableOp"):
    return op

  # Attempt to find the initialized_value of any variable reference / handles.
  # TODO(b/70206927): Fix handling of ResourceVariables.
  if op_type in ("Variable", "VariableV2", "VarHandleOp"):
    initialized_value = _find_initialized_value_for_variable(op)
    return op if initialized_value is None else initialized_value.op

  # Recursively build initializer expressions for inputs.
  old_inputs = list(op.inputs)
  new_op_inputs = [
      _safe_initial_value_from_tensor(name, op_input, op_cache)
      for op_input in old_inputs
  ]
  # An unchanged input comes back as the very same Tensor object, so compare
  # by identity rather than relying on Tensor's overloadable `!=`.
  modified = any(new_input is not old_input
                 for new_input, old_input in zip(new_op_inputs, old_inputs))
  if not modified:
    return op

  # At least one input was replaced: clone the op over the new inputs.
  # RefSwitch is recreated as a plain Switch (presumably because the new
  # inputs are not ref-typed -- confirm).
  new_op_type = "Switch" if op_type == "RefSwitch" else op_type
  new_op_name = (op.node_def.name + "_" + name).replace(":", "_")
  return op.graph.create_op(
      new_op_type, new_op_inputs,
      op._output_types,  # pylint: disable=protected-access
      name=new_op_name, attrs=op.node_def.attr)
2432
2433
def _find_initialized_value_for_variable(variable_op):
  """Find the initialized value for a variable op.

  The variable is looked up by name in the op graph's `GLOBAL_VARIABLES`
  collection, then its `LOCAL_VARIABLES` collection. A match is an entry
  whose `name` equals the op name or the op name followed by ":0".

  Collection entries without a `name` attribute (e.g. an incomplete
  user-defined variable type) are skipped. A single malformed entry therefore
  no longer hides matching variables that come after it.

  Args:
    variable_op: A variable `Operation`.
  Returns:
    A `Tensor` representing the initialized value for the variable or `None`
    if the initialized value could not be found.
  """
  try:
    op_name = variable_op.node_def.name
    graph = variable_op.graph
  except AttributeError:
    return None
  var_names = (op_name, op_name + ":0")
  for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
                          ops.GraphKeys.LOCAL_VARIABLES):
    for var in graph.get_collection(collection_name):
      if getattr(var, "name", None) not in var_names:
        continue
      try:
        return var.initialized_value()
      except AttributeError:
        # Return None when an incomplete user-defined variable type was put in
        # the collection under the matching name.
        return None
  return None
2457
2458
class PartitionedVariable(object):
  """A container for partitioned `Variable` objects.

  @compatibility(eager) `tf.PartitionedVariable` is not compatible with
  eager execution.  Use `tf.Variable` instead which is compatible
  with both eager execution and graph construction.  See [the
  TensorFlow Eager Execution
  guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
  for details on how variables work in eager execution.
  @end_compatibility
  """

  def __init__(self, name, shape, dtype, variable_list, partitions):
    """Creates a new partitioned variable wrapper.

    Variables passed via the variable_list must contain a save_slice_info
    field.  Concatenation and iteration is in lexicographic order according
    to the var_offset property of the save_slice_info.

    Args:
      name: String. Overall name of the variables.
      shape: List of integers.  Overall shape of the variables.
      dtype: Type of the variables.
      variable_list: List of `Variable` that comprise this partitioned variable.
      partitions: List of integers.  Number of partitions for each dimension.

    Raises:
      TypeError: If `variable_list` is not a list of `Variable` objects, or
        `partitions` is not a list.
      ValueError: If `variable_list` is empty, or the `Variable` shape
        information does not match `shape`, or `partitions` has invalid values.
    """
    if not isinstance(variable_list, (list, tuple)):
      raise TypeError(
          "variable_list is not a list or tuple: %s" % variable_list)
    if not isinstance(partitions, (list, tuple)):
      raise TypeError("partitions is not a list or tuple: %s" % partitions)
    if not all(p >= 1 for p in partitions):
      raise ValueError("partition values must be positive: %s" % partitions)
    if not variable_list:
      raise ValueError("variable_list may not be empty")
    # pylint: disable=protected-access
    # These checks do not depend on any single variable, so run them once
    # rather than once per variable.
    if not all(v._get_save_slice_info() is not None for v in variable_list):
      raise ValueError(
          "All variables must have a save_slice_info available: %s"
          % [v.name for v in variable_list])
    if len(shape) != len(partitions):
      raise ValueError("len(shape) != len(partitions): %s vs. %s"
                       % (shape, partitions))
    full_shapes = [v._get_save_slice_info().full_shape for v in variable_list]
    if any(full_shape != shape for full_shape in full_shapes):
      # Report every full shape so the offending partition(s) are visible.
      raise ValueError(
          "All variables' full shapes must match shape: %s; "
          "but full shapes were: %s"
          % (shape, str(full_shapes)))
    # Sort the variable_list lexicographically according to var offset value,
    # which fixes the order used by iteration, concatenation and assignment.
    self._variable_list = sorted(
        variable_list, key=lambda v: v._get_save_slice_info().var_offset)
    # pylint: enable=protected-access

    self._name = name
    self._shape = shape
    self._dtype = dtype
    self._partitions = partitions
    self._as_tensor = None

  def __iter__(self):
    """Return an iterable for accessing the underlying partition Variables."""
    return iter(self._variable_list)

  def __len__(self):
    """Number of partitions; only defined when at most one axis is split."""
    num_partition_axes = len(self._partition_axes())
    if num_partition_axes > 1:
      raise ValueError("Cannot get a length for %d > 1 partition axes"
                       % num_partition_axes)
    return len(self._variable_list)

  def _partition_axes(self):
    """Axes with more than one partition, or `[0]` if nothing is split."""
    if all(p == 1 for p in self._partitions):
      return [0]
    else:
      return [i for i, p in enumerate(self._partitions) if p > 1]

  def _concat(self):
    """Returns the overall concatenated value as a `Tensor`.

    This is different from using the partitioned variable directly as a tensor
    (through tensor conversion and `as_tensor`) in that it creates a new set of
    operations that keeps the control dependencies from its scope.

    Returns:
      `Tensor` containing the concatenated value.
    """
    if len(self._variable_list) == 1:
      with ops.name_scope(None):
        return array_ops.identity(self._variable_list[0], name=self._name)

    partition_axes = self._partition_axes()

    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot concatenate along more than one dimension: %s.  "
          "Multi-axis partition concat is not supported" % str(partition_axes))
    partition_ix = partition_axes[0]

    with ops.name_scope(self._name + "/ConcatPartitions/"):
      concatenated = array_ops.concat(self._variable_list, partition_ix)

    with ops.name_scope(None):
      return array_ops.identity(concatenated, name=self._name)

  def as_tensor(self):
    """Returns the overall concatenated value as a `Tensor`.

    The returned tensor will not inherit the control dependencies from the scope
    where the value is used, which is similar to getting the value of
    `Variable`.

    Returns:
      `Tensor` containing the concatenated value.
    """
    with ops.control_dependencies(None):
      return self._concat()

  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
    # pylint: disable=invalid-name
    _ = name
    if dtype is not None and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise NotImplementedError(
          "PartitionedVariable doesn't support being used as a reference.")
    else:
      return v.as_tensor()

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    return self.get_shape()

  @property
  def _distribute_strategy(self):
    """The `tf.distribute.Strategy` that this variable was created under."""
    # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy.
    return None

  def get_shape(self):
    return self._shape

  def _get_variable_list(self):
    return self._variable_list

  def _get_partitions(self):
    return self._partitions

  def _apply_assign_fn(self, assign_fn, value):
    """Applies `assign_fn(var, value_part)` to each partition in order.

    `value` may be a list with one entry per partition, another
    `PartitionedVariable` (iterated partition by partition), or a tensor-like
    value that is split along the single partition axis to match the
    partition sizes.
    """
    partition_axes = self._partition_axes()
    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot do assign action along more than one dimension: %s.  "
          "Multi-axis partition assign action is not supported " %
          str(partition_axes))
    if isinstance(value, list):
      assert len(value) == len(self._variable_list)
      value_list = value
    elif isinstance(value, PartitionedVariable):
      value_list = list(value)
    else:
      partition_ix = partition_axes[0]
      size_splits_list = [
          tensor_shape.dimension_value(var.shape[partition_ix])
          for var in self._variable_list
      ]
      value_list = array_ops.split(value, size_splits_list, axis=partition_ix)

    op_list = [
        assign_fn(var, value_list[idx])
        for idx, var in enumerate(self._variable_list)
    ]
    return op_list

  def _assign_partitions(self, method_name, value, use_locking, name,
                         read_value):
    """Calls `var.<method_name>(...)` on every partition with its slice.

    Returns the per-partition results when `read_value` is true, otherwise
    the `.op` of each result.
    """
    def assign_fn(var, r_value):
      return getattr(var, method_name)(
          r_value, use_locking=use_locking,
          name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns `value` partition-wise via each partition's `assign`."""
    return self._assign_partitions(
        "assign", value, use_locking, name, read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    """Adds `value` partition-wise via each partition's `assign_add`."""
    return self._assign_partitions(
        "assign_add", value, use_locking, name, read_value)

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    """Subtracts `value` partition-wise via each partition's `assign_sub`."""
    return self._assign_partitions(
        "assign_sub", value, use_locking, name, read_value)
2676
2677
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
# Runs once at import time; conversion goes through
# `RefVariable._TensorConversionFunction`.
ops.register_tensor_conversion_function(
    RefVariable,
    RefVariable._TensorConversionFunction)  # pylint: disable=protected-access
# Also register RefVariable with `ops.register_dense_tensor_like_type`; see
# that function's documentation for what the registration enables.
ops.register_dense_tensor_like_type(RefVariable)
2684
2685
@tf_export(v1=["global_variables"])
def global_variables(scope=None):
  """Returns global variables.

  Global variables are variables that are shared across machines in a
  distributed environment. The `Variable()` constructor or `get_variable()`
  automatically adds new variables to the graph collection
  `GraphKeys.GLOBAL_VARIABLES`, and this convenience function returns the
  contents of that collection.

  The alternative to global variables is local variables; see
  `tf.local_variables`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` objects.
  """
  collection_key = ops.GraphKeys.GLOBAL_VARIABLES
  return ops.get_collection(collection_key, scope)
2710
2711
@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
  """See `tf.global_variables`.

  Deprecated alias returning `global_variables()` with no scope filter.
  """
  return global_variables(scope=None)
2717
2718
def _all_saveable_objects(scope=None):
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  The result is the `GLOBAL_VARIABLES` collection followed by the
  `SAVEABLE_OBJECTS` collection, both filtered by `scope`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` to be checkpointed
  """
  # TODO(andreasst): make this function public once things are settled.
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
  saveables = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope)
  return global_vars + saveables
2735
2736
@tf_export(v1=["local_variables"])
def local_variables(scope=None):
  """Returns local variables.

  Local variables are per-process variables, usually not saved to or restored
  from checkpoints and used for temporary or intermediate values. For
  example, they can serve as counters for metrics computation or for the
  number of epochs this machine has read data.
  `tf.contrib.framework.local_variable()` automatically adds the new variable
  to `GraphKeys.LOCAL_VARIABLES`, and this convenience function returns the
  contents of that collection.

  The alternative to local variables is global variables; see
  `tf.global_variables`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of local `Variable` objects.
  """
  collection_key = ops.GraphKeys.LOCAL_VARIABLES
  return ops.get_collection(collection_key, scope)
2763
2764
@tf_export(v1=["model_variables"])
def model_variables(scope=None):
  """Returns all variables in the MODEL_VARIABLES collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of the Variable objects in the `GraphKeys.MODEL_VARIABLES`
    collection.
  """
  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
2780
2781
@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
  """Returns all variables created with `trainable=True`.

  With `trainable=True`, the `Variable()` constructor automatically adds new
  variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES`; this
  convenience function returns the contents of that collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of Variable objects.
  """
  collection_key = ops.GraphKeys.TRAINABLE_VARIABLES
  return ops.get_collection(collection_key, scope)
2802
2803
@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
  """Returns all variables that maintain their moving averages.

  When an `ExponentialMovingAverage` object is created and its `apply()`
  method is called on a list of variables, those variables are added to the
  `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. This convenience function
  returns the contents of that collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of Variable objects.
  """
  collection_key = ops.GraphKeys.MOVING_AVERAGE_VARIABLES
  return ops.get_collection(collection_key, scope)
2824
2825
@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
  """Returns an Op that initializes a list of variables.

  After you launch the graph in a session, you can run the returned Op to
  initialize all the variables in `var_list`. The Op runs all the
  initializers of the variables in `var_list` in parallel.

  Calling this function is equivalent to passing the list of initializers to
  `tf.group()`.

  If `var_list` is empty, or when executing eagerly, `no_op(name=name)` is
  returned instead; in graph mode that is an Op that can be run but has no
  effect.

  Args:
    var_list: List of `Variable` objects to initialize.
    name: Optional name for the returned operation.

  Returns:
    An Op that runs the initializers of all the specified variables.
  """
  if not var_list or context.executing_eagerly():
    return control_flow_ops.no_op(name=name)
  initializers = [var.initializer for var in var_list]
  return control_flow_ops.group(*initializers, name=name)
2850
2851
@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
  """See `tf.variables_initializer`.

  Deprecated alias that forwards both arguments unchanged.
  """
  return variables_initializer(name=name, var_list=var_list)
2858
2859
@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
  """Returns an Op that initializes global variables.

  This is just a shortcut for `variables_initializer(global_variables())`.
  When executing eagerly, a `no_op` named "global_variables_initializer" is
  returned instead.

  Returns:
    An Op that initializes global variables in the graph.
  """
  if not context.executing_eagerly():
    return variables_initializer(global_variables())
  return control_flow_ops.no_op(name="global_variables_initializer")
2872
2873
@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
  """See `tf.global_variables_initializer`.

  Deprecated alias with identical behavior.
  """
  initializer = global_variables_initializer()
  return initializer
2880
2881
@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
  """Returns an Op that initializes all local variables.

  This is just a shortcut for `variables_initializer(local_variables())`.
  When executing eagerly, a `no_op` named "local_variables_initializer" is
  returned instead.

  Returns:
    An Op that initializes all local variables in the graph.
  """
  if not context.executing_eagerly():
    return variables_initializer(local_variables())
  return control_flow_ops.no_op(name="local_variables_initializer")
2894
2895
@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
  """See `tf.local_variables_initializer`.

  Deprecated alias with identical behavior.
  """
  initializer = local_variables_initializer()
  return initializer
2902
2903
@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Tests whether `variable` has been initialized.

  Thin wrapper around `state_ops.is_variable_initialized`.

  Args:
    variable: A `Variable`.

  Returns:
    A scalar boolean Tensor: `True` if the variable has been initialized,
    `False` otherwise.
  """
  initialized_check = state_ops.is_variable_initialized(variable)
  return initialized_check
2917
2918
@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.

  NOTE: This function is obsolete and will be removed in 6 months.  Please
  change your implementation to use `report_uninitialized_variables()`.

  When run, the returned Op will raise the exception `FailedPreconditionError`
  if any of the variables has not yet been initialized.

  Note: This function is implemented by trying to fetch the values of the
  variables. If one of the variables is not initialized a message may be
  logged by the C++ runtime. This is expected.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`. If the resulting list is
      empty (including when an empty list is passed explicitly), the outputs
      of old-style "Variable", "VariableV2" and "AutoReloadVariable" ops in
      the default graph are checked instead.

  Returns:
    A `Tensor` whose evaluation reads every checked variable. It is the rank
    of the variable when there is exactly one, otherwise the stacked ranks.
    Returns None if there are no variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
  # Backwards compatibility for old-style variables. TODO(touts): remove.
  # This fallback also applies when an empty var_list is passed explicitly.
  if not var_list:
    var_list = [op.outputs[0]
                for op in ops.get_default_graph().get_operations()
                if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]]
  if not var_list:
    return None
  ranks = []
  for var in var_list:
    # `optimize=False` presumably prevents the rank from being folded from the
    # static shape, so evaluating it actually reads the variable.
    with ops.colocate_with(var.op):
      ranks.append(array_ops.rank_internal(var, optimize=False))
  if len(ranks) == 1:
    return ranks[0]
  return array_ops.stack(ranks)
2960
2961
@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
  """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`. If that default is empty, the
      outputs of old-style "Variable", "VariableV2" and "AutoReloadVariable"
      ops in the default graph are checked instead.
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
      var_list = [op.outputs[0]
                  for op in ops.get_default_graph().get_operations()
                  if op.type in ["Variable", "VariableV2",
                                 "AutoReloadVariable"]]
  with ops.name_scope(name):
    # The per-variable checks are created before entering the reporting
    # device scope below, so they are NOT pinned to that device.
    init_vars = ([state_ops.is_variable_initialized(v) for v in var_list]
                 if var_list else [])
    # Only the aggregation ops below are pinned: to "/cpu:0" by default, or to
    # the device named by TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING.
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      # 1-D boolean tensor that is True for each *uninitialized* variable.
      variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
      # 1-D string tensor containing all the variable names.
      variable_names_tensor = array_ops.constant(
          [s.op.name for s in var_list])
      # Return a 1-D tensor containing all the names of
      # uninitialized variables.
      return array_ops.boolean_mask(variable_names_tensor, variables_mask)
3008
3009
# Register PartitionedVariable so it can be used where a Tensor is expected.
# Conversion goes through `PartitionedVariable._TensorConversionFunction`,
# which rejects incompatible dtypes and `as_ref=True`, and otherwise returns
# the concatenated value from `as_tensor()`.
ops.register_tensor_conversion_function(
    PartitionedVariable,
    PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access
3013