1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Inplace operations.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gen_array_ops
26from tensorflow.python.ops import math_ops
27
28
def _inplace_helper(x, i, v, op):
  """Dispatches the inplace kernel `op` over (x, i, v).

  `op` is invoked as op(x, i, v). The public wrappers in this module pass
  gen_array_ops.inplace_update, gen_array_ops.inplace_add, or
  gen_array_ops.inplace_sub.

  Indexing modes:
    * i is None: x and v must be the same shape. Both are flattened to a
      single row, `op` is applied at row 0, and the result is reshaped back
      to the shape of x. Computes
        x op v;
    * i is a scalar: x has a rank 1 higher than v's. i is reshaped to a
      one-element vector and v gets a leading dimension of size 1. Computes
        x[i, :] op v;
    * otherwise: x and v must have the same rank. Computes
        x[i, :] op v;

  Args:
    x: A Tensor (or a value convertible to one).
    i: None, a scalar or a vector. When not None it is cast to int32.
    v: A Tensor (or a value convertible to one using x's dtype).
    op: The inplace kernel to apply, called as op(x, i, v).

  Returns:
    The output of `op`, reshaped to the shape of x when i is None.

  """
  x = ops.convert_to_tensor(x)
  v = ops.convert_to_tensor(v, x.dtype)

  if i is None:
    # Whole-tensor case: view both operands as a single row and update row 0.
    row_of_x = array_ops.reshape(x, [1, -1])
    row_of_v = array_ops.reshape(v, [1, -1])
    updated = op(row_of_x, [0], row_of_v)
    return array_ops.reshape(updated, array_ops.shape(x))

  index = math_ops.cast(i, dtypes.int32)
  if index.get_shape().ndims != 0:
    # Vector of indices (or statically unknown rank): pass through unchanged.
    return op(x, index, v)

  # Scalar index: promote the index to a 1-element vector and give v a
  # matching leading dimension.
  row_index = array_ops.reshape(index, [1])
  row_value = array_ops.expand_dims(v, 0)
  return op(x, row_index, row_value)
64
65
def alias_inplace_update(x, i, v):
  """Overwrites the rows of x selected by i with v. The result aliases x.

  Supported forms of i:
    * None: x and v must be the same shape. Computes
        x = v;
    * a scalar: x has a rank 1 higher than v's. Computes
        x[i, :] = v;
    * a vector: x and v must have the same rank. Computes
        x[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  update_kernel = gen_array_ops.inplace_update
  return _inplace_helper(x, i, v, op=update_kernel)
86
87
def alias_inplace_add(x, i, v):
  """Adds v onto the rows of x selected by i. The result aliases x.

  Supported forms of i:
    * None: x and v must be the same shape. Computes
        x += v;
    * a scalar: x has a rank 1 higher than v's. Computes
        x[i, :] += v;
    * a vector: x and v must have the same rank. Computes
        x[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  add_kernel = gen_array_ops.inplace_add
  return _inplace_helper(x, i, v, op=add_kernel)
108
109
def alias_inplace_sub(x, i, v):
  """Subtracts v from the rows of x selected by i. The result aliases x.

  Supported forms of i:
    * None: x and v must be the same shape. Computes
        x -= v;
    * a scalar: x has a rank 1 higher than v's. Computes
        x[i, :] -= v;
    * a vector: x and v must have the same rank. Computes
        x[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  sub_kernel = gen_array_ops.inplace_sub
  return _inplace_helper(x, i, v, op=sub_kernel)
130
131
def empty_like(x, init=None):
  """Creates a fresh tensor matching the shape and dtype of x.

  Args:
    x: A Tensor (or a value convertible to one).
    init: Forwarded to gen_array_ops.empty. If True, the returned tensor is
      initialized with the default value of x.dtype(). Otherwise, it is not
      initialized. Defaults to None.

  Returns:
    A tensor y, whose dtype and shape are the same as those of x.
    y is guaranteed not to be an alias of x. Upon return, y may contain
    arbitrary data.

  """
  x_tensor = ops.convert_to_tensor(x)
  target_shape = array_ops.shape(x_tensor)
  return gen_array_ops.empty(target_shape, x_tensor.dtype, init=init)
149
150
def inplace_update(x, i, v):
  """Returns a copy of x whose rows selected by i are replaced with v.

  Despite the name, this is not an inplace operation: x is deep-copied
  first and the update is applied to the copy. It is useful for expressing
  a sparse update, not for avoiding memory copies.

  Supported forms of i:
    * None: x and v must be the same shape. Computes
        y = x; y = v;
    * a scalar: x has a rank 1 higher than v's. Computes
        y = x; y[i, :] = v;
    * a vector: x and v must have the same rank. Computes
        y = x; y[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  x_copy = gen_array_ops.deep_copy(x)
  return alias_inplace_update(x_copy, i, v)
175
176
def inplace_add(x, i, v):
  """Returns a copy of x with v added onto the rows selected by i.

  Despite the name, this is not an inplace operation: x is deep-copied
  first and the addition is applied to the copy. It is useful for
  expressing a sparse update, not for avoiding memory copies.

  Supported forms of i:
    * None: x and v must be the same shape. Computes
        y = x; y += v;
    * a scalar: x has a rank 1 higher than v's. Computes
        y = x; y[i, :] += v;
    * a vector: x and v must have the same rank. Computes
        y = x; y[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  x_copy = gen_array_ops.deep_copy(x)
  return alias_inplace_add(x_copy, i, v)
201
202
def inplace_sub(x, i, v):
  """Returns a copy of x with v subtracted from the rows selected by i.

  Despite the name, this is not an inplace operation: x is deep-copied
  first and the subtraction is applied to the copy. It is useful for
  expressing a sparse update, not for avoiding memory copies.

  Supported forms of i:
    * None: x and v must be the same shape. Computes
        y = x; y -= v;
    * a scalar: x has a rank 1 higher than v's. Computes
        y = x; y[i, :] -= v;
    * a vector: x and v must have the same rank. Computes
        y = x; y[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  x_copy = gen_array_ops.deep_copy(x)
  return alias_inplace_sub(x_copy, i, v)
227
# Expose the generated `empty` op under this module's namespace, alongside
# empty_like (which calls the same op with the shape and dtype of a tensor).
empty = gen_array_ops.empty
229