1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.math_ops.matrix_inverse.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python import tf2 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import gradient_checker 28from tensorflow.python.ops import linalg_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import random_ops 31from tensorflow.python.platform import test 32 33 34def _AddTest(test_class, op_name, testcase_name, fn): 35 test_name = "_".join(["test", op_name, testcase_name]) 36 if hasattr(test_class, test_name): 37 raise RuntimeError("Test %s defined more than once" % test_name) 38 setattr(test_class, test_name, fn) 39 40 41class QrOpTest(test.TestCase): 42 43 @test_util.run_v1_only("b/120545219") 44 def testWrongDimensions(self): 45 # The input to qr should be a tensor of at least rank 2. 46 scalar = constant_op.constant(1.) 47 with self.assertRaisesRegexp(ValueError, 48 "Shape must be at least rank 2 but is rank 0"): 49 linalg_ops.qr(scalar) 50 vector = constant_op.constant([1., 2.]) 51 with self.assertRaisesRegexp(ValueError, 52 "Shape must be at least rank 2 but is rank 1"): 53 linalg_ops.qr(vector) 54 55 @test_util.run_deprecated_v1 56 def testConcurrentExecutesWithoutError(self): 57 with self.session(use_gpu=True) as sess: 58 all_ops = [] 59 for full_matrices_ in True, False: 60 for rows_ in 4, 5: 61 for cols_ in 4, 5: 62 matrix1 = random_ops.random_normal([rows_, cols_], seed=42) 63 matrix2 = random_ops.random_normal([rows_, cols_], seed=42) 64 q1, r1 = linalg_ops.qr(matrix1, full_matrices=full_matrices_) 65 q2, r2 = linalg_ops.qr(matrix2, full_matrices=full_matrices_) 66 all_ops += [q1, r1, q2, r2] 67 val = self.evaluate(all_ops) 68 for i in range(8): 69 q = 4 * i 70 self.assertAllClose(val[q], val[q + 2]) # q1 == q2 71 self.assertAllClose(val[q + 1], val[q + 3]) # r1 == r2 72 73 74def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): 75 76 is_complex = dtype_ in (np.complex64, np.complex128) 77 is_single = dtype_ in (np.float32, np.complex64) 78 79 def CompareOrthogonal(self, x, y, rank): 80 if is_single: 81 atol = 5e-4 82 else: 83 atol = 5e-14 84 # We only compare the first 'rank' orthogonal vectors since the 85 # remainder form an arbitrary orthonormal basis for the 86 # (row- or column-) null space, whose exact value depends on 87 # implementation details. Notice that since we check that the 88 # matrices of singular vectors are unitary elsewhere, we do 89 # implicitly test that the trailing vectors of x and y span the 90 # same space. 91 x = x[..., 0:rank] 92 y = y[..., 0:rank] 93 # Q is only unique up to sign (complex phase factor for complex matrices), 94 # so we normalize the sign first. 95 sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True) 96 phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios)) 97 x *= phases 98 self.assertAllClose(x, y, atol=atol) 99 100 def CheckApproximation(self, a, q, r): 101 if is_single: 102 tol = 1e-5 103 else: 104 tol = 1e-14 105 # Tests that a ~= q*r. 106 a_recon = math_ops.matmul(q, r) 107 self.assertAllClose(a_recon, a, rtol=tol, atol=tol) 108 109 def CheckUnitary(self, x): 110 # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. 111 xx = math_ops.matmul(x, x, adjoint_a=True) 112 identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) 113 if is_single: 114 tol = 1e-5 115 else: 116 tol = 1e-14 117 self.assertAllClose(identity, xx, atol=tol) 118 119 @test_util.run_v1_only("b/120545219") 120 def Test(self): 121 np.random.seed(1) 122 x_np = np.random.uniform( 123 low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_) 124 if is_complex: 125 x_np += 1j * np.random.uniform( 126 low=-1.0, high=1.0, 127 size=np.prod(shape_)).reshape(shape_).astype(dtype_) 128 129 with self.session(use_gpu=True) as sess: 130 if use_static_shape_: 131 x_tf = constant_op.constant(x_np) 132 else: 133 x_tf = array_ops.placeholder(dtype_) 134 q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_) 135 136 if use_static_shape_: 137 q_tf_val, r_tf_val = self.evaluate([q_tf, r_tf]) 138 else: 139 q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) 140 141 q_dims = q_tf_val.shape 142 np_q = np.ndarray(q_dims, dtype_) 143 np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1])) 144 new_first_dim = np_q_reshape.shape[0] 145 146 x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) 147 for i in range(new_first_dim): 148 if full_matrices_: 149 np_q_reshape[i, :, :], _ = np.linalg.qr( 150 x_reshape[i, :, :], mode="complete") 151 else: 152 np_q_reshape[i, :, :], _ = np.linalg.qr( 153 x_reshape[i, :, :], mode="reduced") 154 np_q = np.reshape(np_q_reshape, q_dims) 155 CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:])) 156 CheckApproximation(self, x_np, q_tf_val, r_tf_val) 157 CheckUnitary(self, q_tf_val) 158 159 return Test 160 161 162class QrGradOpTest(test.TestCase): 163 pass 164 165 166def _GetQrGradOpTest(dtype_, shape_, full_matrices_): 167 168 @test_util.run_v1_only("b/120545219") 169 def Test(self): 170 np.random.seed(42) 171 a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) 172 if dtype_ in [np.complex64, np.complex128]: 173 a += 1j * np.random.uniform( 174 low=-1.0, high=1.0, size=shape_).astype(dtype_) 175 # Optimal stepsize for central difference is O(epsilon^{1/3}). 176 epsilon = np.finfo(dtype_).eps 177 delta = 0.1 * epsilon**(1.0 / 3.0) 178 if dtype_ in [np.float32, np.complex64]: 179 tol = 3e-2 180 else: 181 tol = 1e-6 182 with self.session(use_gpu=True): 183 tf_a = constant_op.constant(a) 184 tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_) 185 for b in tf_b: 186 x_init = np.random.uniform( 187 low=-1.0, high=1.0, size=shape_).astype(dtype_) 188 if dtype_ in [np.complex64, np.complex128]: 189 x_init += 1j * np.random.uniform( 190 low=-1.0, high=1.0, size=shape_).astype(dtype_) 191 theoretical, numerical = gradient_checker.compute_gradient( 192 tf_a, 193 tf_a.get_shape().as_list(), 194 b, 195 b.get_shape().as_list(), 196 x_init_value=x_init, 197 delta=delta) 198 self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) 199 200 return Test 201 202 203if __name__ == "__main__": 204 for dtype in np.float32, np.float64, np.complex64, np.complex128: 205 for rows in 1, 2, 5, 10, 32, 100: 206 for cols in 1, 2, 5, 10, 32, 100: 207 for full_matrices in False, True: 208 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): 209 # TF2 does not support placeholders under eager so we skip it 210 for use_static_shape in set([True, tf2.enabled()]): 211 shape = batch_dims + (rows, cols) 212 name = "%s_%s_full_%s_static_%s" % (dtype.__name__, 213 "_".join(map(str, shape)), 214 full_matrices, 215 use_static_shape) 216 _AddTest(QrOpTest, "Qr", name, 217 _GetQrOpTest(dtype, shape, full_matrices, 218 use_static_shape)) 219 220 # TODO(pfau): Get working with complex types. 221 # TODO(pfau): Get working with full_matrices when rows != cols 222 # TODO(pfau): Get working when rows < cols 223 # TODO(pfau): Get working with shapeholders (dynamic shapes) 224 for full_matrices in False, True: 225 for dtype in np.float32, np.float64: 226 for rows in 1, 2, 5, 10: 227 for cols in 1, 2, 5, 10: 228 if rows == cols or (not full_matrices and rows > cols): 229 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): 230 shape = batch_dims + (rows, cols) 231 name = "%s_%s_full_%s" % (dtype.__name__, 232 "_".join(map(str, shape)), 233 full_matrices) 234 _AddTest(QrGradOpTest, "QrGrad", name, 235 _GetQrGradOpTest(dtype, shape, full_matrices)) 236 test.main() 237