# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Variables.

See the [Variables](https://www.tensorflow.org/guide/variables) guide.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_state_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead.

  Creates a legacy `Variable` op. When `set_shape` is False the op is created
  with an unknown shape; otherwise the op is created with `shape` and the
  returned tensor's static shape is also set to `shape`.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: Optional name to use for the variable op.
    set_shape: If True, use `shape` for the op and for the returned tensor's
      static shape. If False, the op is created with an unknown shape.
    container: Passed through to the op's `container` attribute.
    shared_name: Passed through to the op's `shared_name` attribute.

  Returns:
    A variable tensor.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  # wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret


def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor.
  """
  return gen_state_ops.variable_v2(
      shape=shape,
      dtype=dtype,
      name=name,
      container=container,
      shared_name=shared_name)


def init_variable(v, init, name="init"):
  """Initializes variable with "init".

  This op does the following:
    if callable(init): v = init(v.get_shape().as_list(), v.dtype.base_dtype)
    otherwise:         v = convert_to_tensor(init)

  The callable form requires `v` to have a fully defined static shape (this is
  checked with an `assert`, so it is skipped under `python -O`).

  Args:
    v: Variable to initialize
    init: Tensor to assign to v,
      Or an object convertible to Tensor e.g. nparray,
      Or an Initializer that generates a tensor given the shape and type of v.
      An "Initializer" is a callable that returns a tensor that "v" should be
      set to. It will be called as init(shape, dtype).
    name: Optional name for the op.

  Returns:
    The operation that initializes v.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      # Keep the initializer computation on the same device as the variable.
      with ops.colocate_with(v):
        if callable(init):
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and
          # accessor are reconciled (and all initializers support
          # tf.TensorShape objects).
          value = init(v.get_shape().as_list(), v.dtype.base_dtype)
          value = ops.convert_to_tensor(value, name="value")
          return gen_state_ops.assign(v, value, name=scope)
        else:
          init = ops.convert_to_tensor(init, name="init")
          return gen_state_ops.assign(v, init, name=scope)


def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Outputs boolean scalar indicating whether the tensor has been initialized.
  For non-ref (resource) variables this delegates to
  `ref.is_initialized(name=name)`.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.is_variable_initialized(ref=ref, name=name)
  # Handle resource variables.
  return ref.is_initialized(name=name)


@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value`
  must have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention. Not forwarded for resource
      variables.
    name: A name for the operation (optional).

  Returns:
    Same as "ref". Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Resource variables: delegate to the variable's own method, forwarding the
  # op name the same way `assign` does.
  return ref.assign_sub(value, name=name)


@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  This operation outputs "ref" after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have
  the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the addition
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as "ref". Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Resource variables: delegate to the variable's own method, forwarding the
  # op name the same way `assign` does.
  return ref.assign_add(value, name=name)


@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  This operation outputs a Tensor that holds the new value of `ref` after
  the value has been assigned. This makes it easier to chain operations that
  need to use the updated value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation will validate that the shape of 'value' matches the shape of the
      Tensor being assigned to. If false, 'ref' will take on the shape of
      'value'. Not forwarded for resource variables.
    use_locking: An optional `bool`. Defaults to `True`. If True, the assignment
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
      the assignment has completed.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  return ref.assign(value, name=name)


@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`.
      If incrementing ref would bring it above limit, instead generates an
      'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`.
    A copy of the input before increment. If nothing else modifies the
    input, the values produced will all be distinct.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.count_up_to(ref, limit=limit, name=name)
  # Resource variables: operate on the underlying handle.
  return gen_state_ops.resource_count_up_to(
      ref.handle, limit, T=ref.dtype, name=name)


@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  If values in `ref` are to be updated more than once, because there are
  duplicate entries in `indices`, the order at which the updates happen
  for each value is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_update(ref, indices, updates,
                                        use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock; otherwise the
      behavior is undefined, but may exhibit less contention. Not forwarded
      for resource variables.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking=use_locking, name=name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_add(
        ref, indices, updates, use_locking=use_locking, name=name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates from a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_sub(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock; otherwise the
      behavior is undefined, but may exhibit less contention. Not forwarded
      for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking=use_locking, name=name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to multiply to `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_mul(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_mul(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_div(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_div(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_max(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_min(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
  """Generalization of `tf.compat.v1.scatter_update` to axis different than 0.

  Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
  have a series of leading dimensions that are the same for all of them, and the
  updates are performed on the last dimension of indices. In other words, the
  dimensions should be the following:

  `num_prefix_dims = indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `updates.shape = indices.shape + var.shape[batch_dim:]`

  where

  `updates.shape[:num_prefix_dims]`
  `== indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`

  When indices is a 1D tensor, this operation is equivalent to
  `tf.compat.v1.scatter_update`.

  To avoid this operation there would be 2 alternatives:
  1) Reshaping the variable by merging the first `ndims` dimensions. However,
     this is not possible because `tf.reshape` returns a Tensor, which we
     cannot use `tf.compat.v1.scatter_update` on.
  2) Looping over the first `ndims` of the variable and using
     `tf.compat.v1.scatter_update` on the subtensors that result from slicing
     the first dimension. This is a valid option for `ndims = 1`, but less
     efficient than this implementation.

  See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`.

  Args:
    ref: `Variable` to scatter onto.
    indices: Tensor containing indices as described above.
    updates: Tensor of updates to apply to `ref`.
    use_locking: Boolean indicating whether to lock the writing operation.
    name: Optional scope name string.

  Returns:
    Ref to `variable` after it has been modified.

  Raises:
    ValueError: If the rank of `indices` is not statically known.
  """
  with ops.name_scope(name):
    indices = ops.convert_to_tensor(indices, name="indices")
    indices_shape = array_ops.shape(indices)
    indices_dimensions = indices.get_shape().ndims

    if indices_dimensions is None:
      raise ValueError("batch_scatter_update does not allow indices with "
                       "unknown shape.")

    nd_indices = array_ops.expand_dims(indices, axis=-1)
    nd_indices_list = []

    # Scatter ND requires indices to have an additional dimension, in which the
    # coordinates of the updated things are specified. For this to be adapted to
    # the scatter_update with several leading dimensions, we simply make use of
    # a tf.range for all the leading dimensions followed by concat of all the
    # coordinates we created with the original indices.

    # For example, if indices.shape = [2, 3, 4], the indices generated for
    # tf.compat.v1.scatter_nd_update have shape [2, 3, 4, 3] and satisfy
    #   final_indices[i, j, k] = [i, j, indices[i, j, k]]
    for dimension in range(indices_dimensions - 1):
      # Each iteration produces one leading coordinate (the `i`, then the `j`
      # in the example above): a tf.range over the size of this dimension,
      # reshaped so that it broadcasts across every other axis.
      dimension_size = indices_shape[dimension]
      shape_to_broadcast = [1] * (indices_dimensions + 1)
      shape_to_broadcast[dimension] = dimension_size
      dimension_range = array_ops.reshape(
          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
      # The range follows the dtype of `shape`; match it to the indices dtype
      # so the final concat is well-typed.
      if dimension_range.dtype.base_dtype != nd_indices.dtype:
        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
      nd_indices_list.append(
          dimension_range * array_ops.ones_like(nd_indices))
    # Add the original indices at the end, as described above, and concat.
    nd_indices_list.append(nd_indices)
    final_indices = array_ops.concat(nd_indices_list, axis=-1)
    return scatter_nd_update(
        ref, final_indices, updates, use_locking=use_locking)