1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for scan ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import gradient_checker 29from tensorflow.python.ops import math_ops 30from tensorflow.python.platform import test 31 32 33def numpy_reverse(x, axis): 34 length = len(x.shape) 35 if axis < 0: 36 axis = length + axis 37 38 ix = [ 39 slice(None, None, -1) if i == axis else slice(None) for i in range(length) 40 ] 41 return x[ix] 42 43 44def handle_options(func, x, axis, exclusive, reverse): 45 """Adds tf options to numpy scan ops.""" 46 length = len(x.shape) 47 if axis < 0: 48 axis = length + axis 49 50 if reverse: 51 x = numpy_reverse(x, axis) 52 53 if exclusive: 54 ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] 55 ix_init = [ 56 slice(0, -1) if i == axis else slice(None) for i in range(length) 57 ] 58 if func == np.cumsum: 59 init = np.zeros_like(x[ix_head]) 60 elif func == np.cumprod: 61 init = np.ones_like(x[ix_head]) 62 else: 63 raise ValueError("Unknown scan function.") 64 x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) 65 else: 66 x = func(x, axis=axis) 67 68 if reverse: 69 x = numpy_reverse(x, axis) 70 return x 71 72 73class CumsumTest(test.TestCase): 74 75 valid_dtypes = [ 76 np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64, 77 np.complex128 78 ] 79 80 def _compare(self, x, axis, exclusive, reverse): 81 np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) 82 with self.cached_session(use_gpu=True): 83 tf_out = math_ops.cumsum(x, axis, exclusive, reverse).eval() 84 85 self.assertAllClose(np_out, tf_out) 86 87 def _compareAll(self, x, axis): 88 for exclusive in [True, False]: 89 for reverse in [True, False]: 90 self._compare(x, axis, exclusive, reverse) 91 92 @test_util.run_deprecated_v1 93 def testEmpty(self): 94 for dtype in self.valid_dtypes: 95 x = np.zeros([0]).astype(dtype) 96 for axis in (-1, 0): 97 self._compareAll(x, axis) 98 99 @test_util.run_deprecated_v1 100 def testAxisType(self): 101 for dtype in self.valid_dtypes: 102 x = np.arange(1, 6).reshape([5]).astype(dtype) 103 for axis_dtype in [dtypes.int64, dtypes.int32]: 104 with self.cached_session(use_gpu=True): 105 axis = constant_op.constant(0, axis_dtype) 106 tf_out = math_ops.cumsum(x, axis).eval() 107 108 @test_util.run_deprecated_v1 109 def test1D(self): 110 for dtype in self.valid_dtypes: 111 x = np.arange(1, 6).reshape([5]).astype(dtype) 112 for axis in (-1, 0): 113 self._compareAll(x, axis) 114 115 @test_util.run_deprecated_v1 116 def test2D(self): 117 for dtype in self.valid_dtypes: 118 x = np.arange(0, 10).reshape([2, 5]).astype(dtype) 119 for axis in (-2, -1, 0, 1): 120 self._compareAll(x, axis) 121 122 @test_util.run_deprecated_v1 123 def test3D(self): 124 for dtype in self.valid_dtypes: 125 x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) 126 for axis in (-3, -2, -1, 0, 1, 2): 127 self._compareAll(x, axis) 128 129 @test_util.run_deprecated_v1 130 def test6D(self): 131 for dtype in self.valid_dtypes: 132 x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) 133 for axis in range(-6, 6, 3): 134 self._compareAll(x, axis) 135 136 @test_util.run_deprecated_v1 137 @test_util.disable_xla("b/123860949") # The computation is constant folded 138 def testLarge(self): 139 for dtype in self.valid_dtypes: 140 x = np.ones([1000000], dtype=dtype) / 1024 141 self._compareAll(x, 0) 142 143 def testInvalidAxis(self): 144 x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) 145 input_tensor = ops.convert_to_tensor(x) 146 with self.session(use_gpu=True): 147 with self.assertRaisesWithPredicateMatch( 148 errors_impl.InvalidArgumentError, 149 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 150 math_ops.cumsum(input_tensor, -3).eval() 151 with self.assertRaisesWithPredicateMatch( 152 errors_impl.InvalidArgumentError, 153 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 154 math_ops.cumsum(input_tensor, 2).eval() 155 with self.assertRaisesWithPredicateMatch( 156 errors_impl.InvalidArgumentError, 157 lambda e: "axis must be a scalar" in str(e)): 158 math_ops.cumsum(input_tensor, [0]).eval() 159 160 def _compareGradient(self, shape, axis, exclusive, reverse): 161 x = np.arange(0, 50).reshape(shape).astype(np.float64) 162 with self.cached_session(use_gpu=True): 163 t = ops.convert_to_tensor(x) 164 result = math_ops.cumsum(t, axis, exclusive, reverse) 165 jacob_t, jacob_n = gradient_checker.compute_gradient( 166 t, shape, result, shape, x_init_value=x, delta=1) 167 self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) 168 169 @test_util.run_deprecated_v1 170 def testGradient(self): 171 for axis in (-1, 0): 172 self._compareGradient([50], axis, False, False) 173 174 @test_util.run_deprecated_v1 175 def testGradientReverse(self): 176 for axis in (-1, 0): 177 self._compareGradient([50], axis, False, True) 178 179 @test_util.run_deprecated_v1 180 def testGradientExclusive(self): 181 for axis in (-1, 0): 182 self._compareGradient([50], axis, True, False) 183 184 @test_util.run_deprecated_v1 185 def testGradientExclusiveReverse(self): 186 for axis in (-1, 0): 187 self._compareGradient([50], axis, True, True) 188 189 @test_util.run_deprecated_v1 190 def testGradient2D(self): 191 for axis in (-1, 0, 1): 192 for exclusive in [True, False]: 193 for reverse in [True, False]: 194 self._compareGradient([5, 10], axis, exclusive, reverse) 195 196 197class CumprodTest(test.TestCase): 198 199 valid_dtypes = [ 200 np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64, 201 np.complex128 202 ] 203 204 def _compare(self, x, axis, exclusive, reverse): 205 np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) 206 with self.cached_session(use_gpu=True): 207 tf_out = math_ops.cumprod(x, axis, exclusive, reverse).eval() 208 209 self.assertAllClose(np_out, tf_out) 210 211 def _compareAll(self, x, axis): 212 for exclusive in [True, False]: 213 for reverse in [True, False]: 214 self._compare(x, axis, exclusive, reverse) 215 216 @test_util.run_deprecated_v1 217 def testEmpty(self): 218 for dtype in self.valid_dtypes: 219 x = np.zeros([0]).astype(dtype) 220 for axis in (-1, 0): 221 self._compareAll(x, axis) 222 223 @test_util.run_deprecated_v1 224 def testAxisType(self): 225 for dtype in self.valid_dtypes: 226 x = np.arange(1, 6).reshape([5]).astype(dtype) 227 for axis_dtype in [dtypes.int64, dtypes.int32]: 228 with self.cached_session(use_gpu=True): 229 axis = constant_op.constant(0, axis_dtype) 230 tf_out = math_ops.cumprod(x, axis).eval() 231 232 @test_util.run_deprecated_v1 233 def test1D(self): 234 for dtype in self.valid_dtypes: 235 x = np.arange(1, 6).reshape([5]).astype(dtype) 236 for axis in (-1, 0): 237 self._compareAll(x, axis) 238 239 @test_util.run_deprecated_v1 240 def test2D(self): 241 for dtype in self.valid_dtypes: 242 x = np.arange(1, 11).reshape([2, 5]).astype(dtype) 243 for axis in (-2, -1, 0, 1): 244 self._compareAll(x, axis) 245 246 @test_util.run_deprecated_v1 247 def test3D(self): 248 for dtype in self.valid_dtypes: 249 x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) 250 for axis in (-3, -2, -1, 0, 1, 2): 251 self._compareAll(x, axis) 252 253 @test_util.run_deprecated_v1 254 def test6D(self): 255 for dtype in self.valid_dtypes: 256 x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) 257 for axis in range(-6, 6, 3): 258 self._compareAll(x, axis) 259 260 def testInvalidAxis(self): 261 x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) 262 input_tensor = ops.convert_to_tensor(x) 263 with self.session(use_gpu=True): 264 with self.assertRaisesWithPredicateMatch( 265 errors_impl.InvalidArgumentError, 266 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 267 math_ops.cumprod(input_tensor, -3).eval() 268 with self.assertRaisesWithPredicateMatch( 269 errors_impl.InvalidArgumentError, 270 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 271 math_ops.cumprod(input_tensor, 2).eval() 272 with self.assertRaisesWithPredicateMatch( 273 errors_impl.InvalidArgumentError, 274 lambda e: "axis must be a scalar" in str(e)): 275 math_ops.cumprod(input_tensor, [0]).eval() 276 277 def _compareGradient(self, shape, axis, exclusive, reverse): 278 x = np.arange(1, 9).reshape(shape).astype(np.float64) 279 with self.cached_session(use_gpu=True): 280 t = ops.convert_to_tensor(x) 281 result = math_ops.cumprod(t, axis, exclusive, reverse) 282 jacob_t, jacob_n = gradient_checker.compute_gradient( 283 t, shape, result, shape, x_init_value=x, delta=1) 284 self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) 285 286 @test_util.run_deprecated_v1 287 def testGradient(self): 288 for axis in (-1, 0): 289 self._compareGradient([8], axis, False, False) 290 291 @test_util.run_deprecated_v1 292 def testGradientReverse(self): 293 for axis in (-1, 0): 294 self._compareGradient([8], axis, False, True) 295 296 @test_util.run_deprecated_v1 297 def testGradientExclusive(self): 298 for axis in (-1, 0): 299 self._compareGradient([8], axis, True, False) 300 301 @test_util.run_deprecated_v1 302 def testGradientExclusiveReverse(self): 303 for axis in (-1, 0): 304 self._compareGradient([8], axis, True, True) 305 306 @test_util.run_deprecated_v1 307 def testGradient2D(self): 308 for axis in (-2, -1, 0, 1): 309 for exclusive in [True, False]: 310 for reverse in [True, False]: 311 self._compareGradient([2, 4], axis, exclusive, reverse) 312 313 314if __name__ == "__main__": 315 test.main() 316