1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for scan ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import gradient_checker
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import test
31
32
def numpy_reverse(x, axis):
  """Reverses a numpy array along a single axis.

  Args:
    x: numpy array to reverse.
    axis: axis to reverse along; negative values count back from the last
      axis.

  Returns:
    `x` indexed so that the elements along `axis` appear in reverse order.
  """
  length = len(x.shape)
  if axis < 0:
    axis = length + axis

  # Multidimensional indexing must use a tuple: indexing with a list of
  # slices is deprecated and rejected by recent numpy releases.
  ix = tuple(
      slice(None, None, -1) if i == axis else slice(None)
      for i in range(length))
  return x[ix]
42
43
def handle_options(func, x, axis, exclusive, reverse):
  """Adds tf options to numpy scan ops.

  Emulates the `exclusive` and `reverse` arguments of the TF scan ops on top
  of a numpy scan function, producing a reference result for comparison.

  Args:
    func: numpy scan function; `np.cumsum` or `np.cumprod`.
    x: numpy array to scan.
    axis: axis to scan along; negative values count back from the last axis.
    exclusive: if True, each output position excludes its own input element,
      and the first position along `axis` holds the identity (0 for
      `np.cumsum`, 1 for `np.cumprod`).
    reverse: if True, the scan runs from the end of `axis` towards the start.

  Returns:
    The scanned numpy array.

  Raises:
    ValueError: if `exclusive` is True and `func` is neither `np.cumsum` nor
      `np.cumprod`.
  """
  length = len(x.shape)
  if axis < 0:
    axis = length + axis

  # A reverse scan is a forward scan of the reversed input, reversed back.
  if reverse:
    x = numpy_reverse(x, axis)

  if exclusive:
    # Multidimensional indexing must use tuples: lists of slices are
    # deprecated and rejected by recent numpy releases.
    ix_head = tuple(
        slice(0, 1) if i == axis else slice(None) for i in range(length))
    ix_init = tuple(
        slice(0, -1) if i == axis else slice(None) for i in range(length))
    if func is np.cumsum:
      init = np.zeros_like(x[ix_head])
    elif func is np.cumprod:
      init = np.ones_like(x[ix_head])
    else:
      raise ValueError("Unknown scan function.")
    # Shift the inclusive scan of all but the last element right by one
    # position, seeding the first position with the identity element.
    x = np.concatenate([init, func(x[ix_init], axis)], axis=axis)
  else:
    x = func(x, axis=axis)

  if reverse:
    x = numpy_reverse(x, axis)
  return x
71
72
class CumsumTest(test.TestCase):
  """Checks `math_ops.cumsum` against a numpy reference and its gradients."""

  # Element types exercised by the forward-pass tests.
  valid_dtypes = [
      np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64,
      np.complex128
  ]

  def _compare(self, x, axis, exclusive, reverse):
    """Asserts TF's cumsum of `x` matches the numpy reference result."""
    np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
    with self.cached_session(use_gpu=True):
      tf_out = math_ops.cumsum(x, axis, exclusive, reverse).eval()

    self.assertAllClose(np_out, tf_out)

  def _compareAll(self, x, axis):
    """Runs `_compare` for every combination of `exclusive` and `reverse`."""
    for exclusive in [True, False]:
      for reverse in [True, False]:
        self._compare(x, axis, exclusive, reverse)

  @test_util.run_deprecated_v1
  def testEmpty(self):
    for dtype in self.valid_dtypes:
      x = np.zeros([0]).astype(dtype)
      for axis in (-1, 0):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def testAxisType(self):
    # The scan axis may be supplied as an int64 or an int32 tensor.
    for dtype in self.valid_dtypes:
      x = np.arange(1, 6).reshape([5]).astype(dtype)
      for axis_dtype in [dtypes.int64, dtypes.int32]:
        with self.cached_session(use_gpu=True):
          axis = constant_op.constant(0, axis_dtype)
          tf_out = math_ops.cumsum(x, axis).eval()
        # Verify the result itself, not merely that the op accepts the axis.
        self.assertAllClose(np.cumsum(x, axis=0), tf_out)

  @test_util.run_deprecated_v1
  def test1D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(1, 6).reshape([5]).astype(dtype)
      for axis in (-1, 0):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def test2D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(0, 10).reshape([2, 5]).astype(dtype)
      for axis in (-2, -1, 0, 1):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def test3D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype)
      for axis in (-3, -2, -1, 0, 1, 2):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def test6D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype)
      # Samples axes -6, -3, 0 and 3.
      for axis in range(-6, 6, 3):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/123860949")  # The computation is constant folded
  def testLarge(self):
    for dtype in self.valid_dtypes:
      # NOTE(review): true division promotes integer dtypes to float64 here,
      # so the int32/int64 iterations actually exercise float64.
      x = np.ones([1000000], dtype=dtype) / 1024
      self._compareAll(x, 0)

  def testInvalidAxis(self):
    x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    input_tensor = ops.convert_to_tensor(x)
    with self.session(use_gpu=True):
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
        math_ops.cumsum(input_tensor, -3).eval()
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
        math_ops.cumsum(input_tensor, 2).eval()
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "axis must be a scalar" in str(e)):
        math_ops.cumsum(input_tensor, [0]).eval()

  def _compareGradient(self, shape, axis, exclusive, reverse):
    """Asserts the analytic and numeric Jacobians of cumsum agree.

    `shape` must contain exactly 50 elements, since the input is built from
    `np.arange(0, 50)`.
    """
    x = np.arange(0, 50).reshape(shape).astype(np.float64)
    with self.cached_session(use_gpu=True):
      t = ops.convert_to_tensor(x)
      result = math_ops.cumsum(t, axis, exclusive, reverse)
      # cumsum is linear in its input, so a large finite-difference step
      # (delta=1) introduces no truncation error.
      jacob_t, jacob_n = gradient_checker.compute_gradient(
          t, shape, result, shape, x_init_value=x, delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  @test_util.run_deprecated_v1
  def testGradient(self):
    for axis in (-1, 0):
      self._compareGradient([50], axis, False, False)

  @test_util.run_deprecated_v1
  def testGradientReverse(self):
    for axis in (-1, 0):
      self._compareGradient([50], axis, False, True)

  @test_util.run_deprecated_v1
  def testGradientExclusive(self):
    for axis in (-1, 0):
      self._compareGradient([50], axis, True, False)

  @test_util.run_deprecated_v1
  def testGradientExclusiveReverse(self):
    for axis in (-1, 0):
      self._compareGradient([50], axis, True, True)

  @test_util.run_deprecated_v1
  def testGradient2D(self):
    # Covers every valid axis of a rank-2 input, including -2.
    for axis in (-2, -1, 0, 1):
      for exclusive in [True, False]:
        for reverse in [True, False]:
          self._compareGradient([5, 10], axis, exclusive, reverse)
195
196
class CumprodTest(test.TestCase):
  """Checks `math_ops.cumprod` against a numpy reference and its gradients."""

  # Element types exercised by the forward-pass tests.
  valid_dtypes = [
      np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64,
      np.complex128
  ]

  def _compare(self, x, axis, exclusive, reverse):
    """Asserts TF's cumprod of `x` matches the numpy reference result."""
    np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
    with self.cached_session(use_gpu=True):
      tf_out = math_ops.cumprod(x, axis, exclusive, reverse).eval()

    self.assertAllClose(np_out, tf_out)

  def _compareAll(self, x, axis):
    """Runs `_compare` for every combination of `exclusive` and `reverse`."""
    for exclusive in [True, False]:
      for reverse in [True, False]:
        self._compare(x, axis, exclusive, reverse)

  @test_util.run_deprecated_v1
  def testEmpty(self):
    for dtype in self.valid_dtypes:
      x = np.zeros([0]).astype(dtype)
      for axis in (-1, 0):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def testAxisType(self):
    # The scan axis may be supplied as an int64 or an int32 tensor.
    for dtype in self.valid_dtypes:
      x = np.arange(1, 6).reshape([5]).astype(dtype)
      for axis_dtype in [dtypes.int64, dtypes.int32]:
        with self.cached_session(use_gpu=True):
          axis = constant_op.constant(0, axis_dtype)
          tf_out = math_ops.cumprod(x, axis).eval()
        # Verify the result itself, not merely that the op accepts the axis.
        self.assertAllClose(np.cumprod(x, axis=0), tf_out)

  @test_util.run_deprecated_v1
  def test1D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(1, 6).reshape([5]).astype(dtype)
      for axis in (-1, 0):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def test2D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(1, 11).reshape([2, 5]).astype(dtype)
      for axis in (-2, -1, 0, 1):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def test3D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype)
      for axis in (-3, -2, -1, 0, 1, 2):
        self._compareAll(x, axis)

  @test_util.run_deprecated_v1
  def test6D(self):
    for dtype in self.valid_dtypes:
      x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype)
      # Samples axes -6, -3, 0 and 3.
      for axis in range(-6, 6, 3):
        self._compareAll(x, axis)

  def testInvalidAxis(self):
    x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    input_tensor = ops.convert_to_tensor(x)
    with self.session(use_gpu=True):
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
        math_ops.cumprod(input_tensor, -3).eval()
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
        math_ops.cumprod(input_tensor, 2).eval()
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "axis must be a scalar" in str(e)):
        math_ops.cumprod(input_tensor, [0]).eval()

  def _compareGradient(self, shape, axis, exclusive, reverse):
    """Asserts the analytic and numeric Jacobians of cumprod agree.

    `shape` must contain exactly 8 elements, since the input is built from
    `np.arange(1, 9)`.
    """
    # Inputs start at 1, so no element is zero (presumably to keep the
    # gradient away from the zero-input special case).
    x = np.arange(1, 9).reshape(shape).astype(np.float64)
    with self.cached_session(use_gpu=True):
      t = ops.convert_to_tensor(x)
      result = math_ops.cumprod(t, axis, exclusive, reverse)
      # Each output is affine in any single input element, so delta=1 should
      # be exact, assuming the checker perturbs one element at a time.
      jacob_t, jacob_n = gradient_checker.compute_gradient(
          t, shape, result, shape, x_init_value=x, delta=1)
    self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

  @test_util.run_deprecated_v1
  def testGradient(self):
    for axis in (-1, 0):
      self._compareGradient([8], axis, False, False)

  @test_util.run_deprecated_v1
  def testGradientReverse(self):
    for axis in (-1, 0):
      self._compareGradient([8], axis, False, True)

  @test_util.run_deprecated_v1
  def testGradientExclusive(self):
    for axis in (-1, 0):
      self._compareGradient([8], axis, True, False)

  @test_util.run_deprecated_v1
  def testGradientExclusiveReverse(self):
    for axis in (-1, 0):
      self._compareGradient([8], axis, True, True)

  @test_util.run_deprecated_v1
  def testGradient2D(self):
    # Covers every valid axis of a rank-2 input.
    for axis in (-2, -1, 0, 1):
      for exclusive in [True, False]:
        for reverse in [True, False]:
          self._compareGradient([2, 4], axis, exclusive, reverse)
312
313
if __name__ == "__main__":
  # When executed as a script, hand control to TensorFlow's test runner.
  test.main()
316