1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Variables.
17
18See the [Variables](https://www.tensorflow.org/guide/variables) guide.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gen_math_ops
29from tensorflow.python.ops import gen_resource_variable_ops
30from tensorflow.python.ops import gen_state_ops
31# go/tf-wildcard-import
32# pylint: disable=wildcard-import
33from tensorflow.python.ops.gen_state_ops import *
34# pylint: enable=wildcard-import
35from tensorflow.python.util import deprecation
36from tensorflow.python.util.deprecation import deprecated
37from tensorflow.python.util.tf_export import tf_export
38
39
40# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead.

  Creates a (ref-type) variable op holding a tensor of the given shape and
  dtype.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    set_shape: If True, set the static shape of the returned tensor to
      `shape`; otherwise the variable op is created with an unknown shape.
    container: An optional string. Defaults to "". If non-empty, this
      variable is placed in the given container. Otherwise, a default
      container is used.
    shared_name: An optional string. Defaults to "". If non-empty, this
      variable is named in the given bucket with this shared_name. Otherwise,
      the node name is used instead.

  Returns:
    A variable tensor.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
53
54
def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Creates a VariableV2 operation.

  See also variables.Variable.

  Args:
    shape: Shape of the tensor held by this variable.
    dtype: Data type of the tensor values.
    name: Optional name for the variable op.
    container: Optional string, defaults to "". When non-empty, the variable
      lives in the named container; otherwise a default container is used.
    shared_name: Optional string, defaults to "". When non-empty, the variable
      is registered in its bucket under this name; otherwise the node name is
      used instead.

  Returns:
    A variable tensor.
  """
  return gen_state_ops.variable_v2(
      shape=shape, dtype=dtype, name=name, container=container,
      shared_name=shared_name)
80
81
def init_variable(v, init, name="init"):
  """Initializes variable `v` from `init`.

  Behaves as follows:
  if `init` is (convertible to) a Tensor, assigns `v = init`;
  if `init` is callable, assigns `v = init(shape_of_v, v.dtype)`.

  Args:
    v: Variable to initialize.
    init: Either a Tensor to assign to `v`, an object convertible to a Tensor
      (e.g. an nparray), or an Initializer — a callable invoked as
      init(shape, dtype) that produces the tensor `v` should be set to.
    name: Optional name for the op.

  Returns:
    The operation that initializes v.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if not callable(init):
          # Plain tensor-like value: assign it directly.
          init = ops.convert_to_tensor(init, name="init")
          return gen_state_ops.assign(v, init, name=scope)
        assert v.get_shape().is_fully_defined(), "Variable shape unknown."
        # TODO(mrry): Convert to v.shape when the property and
        # accessor are reconciled (and all initializers support
        # tf.TensorShape objects).
        value = init(v.get_shape().as_list(), v.dtype.base_dtype)
        value = ops.convert_to_tensor(value, name="value")
        return gen_state_ops.assign(v, value, name=scope)
115
116
def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Produces a boolean scalar that indicates whether the tensor has been
  initialized.

  Args:
    ref: A mutable `Tensor`. Should come from a `Variable` node; may be
      uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables report their status through their own method.
    return ref.is_initialized(name=name)
  return gen_state_ops.is_variable_initialized(ref=ref, name=name)
134
135
@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  The op returns `ref` once the subtraction has been applied, which makes it
  convenient to chain operations that need the decremented value. Unlike
  `tf.math.subtract`, no broadcasting occurs: `ref` and `value` must have
  identical shapes.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should
      come from a `Variable` node.
    value: A `Tensor` with the same shape and dtype as `ref`; the amount to
      subtract from the variable.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      subtraction is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as "ref".  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: delegate to the variable's own method.
    return ref.assign_sub(value)
  return gen_state_ops.assign_sub(
      ref, value, use_locking=use_locking, name=name)
165
166
@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  The op returns "ref" once the addition has been applied, which makes it
  convenient to chain operations that need the incremented value. Unlike
  `tf.math.add`, no broadcasting occurs: `ref` and `value` must have
  identical shapes.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should
      come from a `Variable` node.
    value: A `Tensor` with the same shape and dtype as `ref`; the amount to
      add to the variable.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      addition is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as "ref".  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: delegate to the variable's own method.
    return ref.assign_add(value)
  return gen_state_ops.assign_add(
      ref, value, use_locking=use_locking, name=name)
196
197
@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  Produces a Tensor holding the new value of `ref` once the assignment has
  completed, which makes it convenient to chain operations that need the
  reset value.

  Args:
    ref: A mutable `Tensor`. Should come from a `Variable` node; may be
      uninitialized.
    value: A `Tensor` with the same shape and dtype as `ref`; the value to
      assign to the variable.
    validate_shape: An optional `bool`, defaulting to `True`. When true, the
      op verifies that the shape of 'value' matches the shape of the Tensor
      being assigned to; when false, 'ref' takes on the shape of 'value'.
    use_locking: An optional `bool`, defaulting to `True`. When True the
      assignment is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
      the assignment has completed.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: delegate to the variable's own method.
    return ref.assign(value, name=name)
  return gen_state_ops.assign(
      ref, value, use_locking=use_locking, name=name,
      validate_shape=validate_shape)
229
230
@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable of type `int32` or `int64`; should come from a scalar
      `Variable` node.
    limit: An `int`. When incrementing `ref` would take it above this value,
      an 'OutOfRange' error is generated instead.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `ref`: a copy of the input taken before
    the increment. If nothing else modifies the input, all produced values
    are distinct.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: operate on the underlying handle.
    return gen_state_ops.resource_count_up_to(
        ref.handle, limit, T=ref.dtype, name=name)
  return gen_state_ops.count_up_to(ref, limit=limit, name=name)
253
254
@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  The computation performed is

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  The op returns `ref` once the update is done, which makes it convenient to
  chain operations that need the reset value.

  When `indices` contains duplicates, so that a value in `ref` would be
  written more than once, the order in which the writes happen for that value
  is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; the values to store in
      `ref`.
    use_locking: An optional `bool`, defaulting to `True`. When True the
      assignment is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    new_values = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_resource_variable_ops.resource_scatter_update(
            ref.handle, indices, new_values, name=name))
  return gen_state_ops.scatter_update(
      ref, indices, updates, use_locking=use_locking, name=name)
307
308
@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print sess.run(update)
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
369
370
@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the addition will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
421
422
@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  With `ref` a `Tensor` of rank `P` and `indices` a `Tensor` of rank `Q`:

  `indices` must be an integer tensor containing indices into `ref`, with
  shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (length `K`) indexes elements
  (if `K = P`) or slices (if `K < P`) along the `K`th dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print sess.run(add)
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor; should come from a Variable node.
    indices: A `Tensor` of type `int32` or `int64`; indices into `ref`.
    updates: A `Tensor` with the same type as `ref`; values to add to `ref`.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      assignment is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    delta = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_state_ops.resource_scatter_nd_add(
            ref.handle, indices, delta, name=name))
  return gen_state_ops.scatter_nd_add(
      ref, indices, updates, use_locking, name)
484
485
@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates to a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  The op returns `ref` once the update is done, which makes it convenient to
  chain operations that need the reset value.

  Duplicate entries are handled correctly: when several `indices` refer to
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should come from a `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; values to subtract
      from `ref`.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      subtraction is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    delta = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_resource_variable_ops.resource_scatter_sub(
            ref.handle, indices, delta, name=name))
  return gen_state_ops.scatter_sub(
      ref, indices, updates, use_locking=use_locking, name=name)
538
539
@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1] ,[7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print sess.run(op)
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
602
603
@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  The computation performed is

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  The op returns `ref` once the update is done, which makes it convenient to
  chain operations that need the reset value.

  Duplicate entries are handled correctly: when several `indices` refer to
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should come from a `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; values to multiply
      into `ref`.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      operation is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    factors = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_resource_variable_ops.resource_scatter_mul(
            ref.handle, indices, factors, name=name))
  return gen_state_ops.scatter_mul(
      ref, indices, updates, use_locking=use_locking, name=name)
654
655
@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  The computation performed is

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  The op returns `ref` once the update is done, which makes it convenient to
  chain operations that need the reset value.

  Duplicate entries are handled correctly: when several `indices` refer to
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should come from a `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; values that `ref` is
      divided by.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      operation is protected by a lock; otherwise the behavior is undefined
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    divisors = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_resource_variable_ops.resource_scatter_div(
            ref.handle, indices, divisors, name=name))
  return gen_state_ops.scatter_div(
      ref, indices, updates, use_locking=use_locking, name=name)
706
707
@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  The computation performed is

      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  The op returns `ref` once the update is done, which makes it convenient to
  chain operations that need the reset value.

  Duplicate entries are handled correctly: when several `indices` refer to
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should come from a
      `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; values to reduce
      into `ref`.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      update is protected by a lock; otherwise the behavior is undefined but
      may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    candidates = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_resource_variable_ops.resource_scatter_max(
            ref.handle, indices, candidates, name=name))
  return gen_state_ops.scatter_max(
      ref, indices, updates, use_locking=use_locking, name=name)
761
762
@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  The computation performed is

      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  The op returns `ref` once the update is done, which makes it convenient to
  chain operations that need the reset value.

  Duplicate entries are handled correctly: when several `indices` refer to
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should come from a
      `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; values to reduce
      into `ref`.
    use_locking: An optional `bool`, defaulting to `False`. When True the
      update is protected by a lock; otherwise the behavior is undefined but
      may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    candidates = ops.convert_to_tensor(updates, ref.dtype)
    return ref._lazy_read(  # pylint: disable=protected-access
        gen_resource_variable_ops.resource_scatter_min(
            ref.handle, indices, candidates, name=name))
  return gen_state_ops.scatter_min(
      ref, indices, updates, use_locking=use_locking, name=name)
816
817
@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
  """Generalization of `tf.compat.v1.scatter_update` to axis different than 0.

  Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
  have a series of leading dimensions that are the same for all of them, and the
  updates are performed on the last dimension of indices. In other words, the
  dimensions should be the following:

  `num_prefix_dims = indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `updates.shape = indices.shape + var.shape[batch_dim:]`

  where

  `updates.shape[:num_prefix_dims]`
  `== indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`

  When indices is a 1D tensor, this operation is equivalent to
  `tf.compat.v1.scatter_update`.

  To avoid this operation there would be 2 alternatives:
  1) Reshaping the variable by merging the first `ndims` dimensions. However,
     this is not possible because `tf.reshape` returns a Tensor, which we
     cannot use `tf.compat.v1.scatter_update` on.
  2) Looping over the first `ndims` of the variable and using
     `tf.compat.v1.scatter_update` on the subtensors that result of slicing the
     first
     dimension. This is a valid option for `ndims = 1`, but less efficient than
     this implementation.

  See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`.

  Args:
    ref: `Variable` to scatter onto.
    indices: Tensor containing indices as described above.
    updates: Tensor of updates to apply to `ref`.
    use_locking: Boolean indicating whether to lock the writing operation.
    name: Optional scope name string.

  Returns:
    Ref to `variable` after it has been modified.

  Raises:
    ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
        not the same.
  """
  with ops.name_scope(name):
    indices = ops.convert_to_tensor(indices, name="indices")
    indices_shape = array_ops.shape(indices)
    indices_dimensions = indices.get_shape().ndims

    # The static rank of `indices` determines how many leading-dimension
    # coordinate tensors to build, so it must be known at graph-build time.
    if indices_dimensions is None:
      raise ValueError("batch_scatter_update does not allow indices with "
                       "unknown shape.")

    nd_indices = array_ops.expand_dims(indices, axis=-1)
    nd_indices_list = []

    # Scatter ND requires indices to have an additional dimension, in which the
    # coordinates of the updated things are specified. For this to be adapted to
    # the scatter_update with several leading dimensions, we simply make use of
    # a tf.range for all the leading dimensions followed by concat of all the
    # coordinates we created with the original indices.

    # For example if indices.shape = [2, 3, 4], we should generate the following
    # indices for tf.compat.v1.scatter_nd_update:
    # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
    # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
    # nd_indices[:, :, 2] = indices
    for dimension in range(indices_dimensions - 1):
      # In this loop we generate the following for the example (one for each
      # iteration).
      # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
      # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
      # This is done at every iteration with a tf.range over the size of the
      # i-th dimension and using broadcasting over the desired shape.
      dimension_size = indices_shape[dimension]
      shape_to_broadcast = [1] * (indices_dimensions + 1)
      shape_to_broadcast[dimension] = dimension_size
      dimension_range = array_ops.reshape(
          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
      # The range dtype (int32 from array_ops.shape) may differ from the
      # user-supplied indices dtype (e.g. int64); align them before concat.
      if dimension_range.dtype.base_dtype != nd_indices.dtype:
        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
      # Broadcasting against ones_like(nd_indices) tiles the per-dimension
      # range across all the other leading dimensions.
      nd_indices_list.append(
          dimension_range * array_ops.ones_like(nd_indices))
    # Add the original indices at the end, as described above, and concat.
    nd_indices_list.append(nd_indices)
    final_indices = array_ops.concat(nd_indices_list, axis=-1)
    return scatter_nd_update(
        ref, final_indices, updates, use_locking=use_locking)
916