1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python import tf2
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gradient_checker
28from tensorflow.python.ops import linalg_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import random_ops
31from tensorflow.python.platform import test
32
33
def _AddTest(test_class, op_name, testcase_name, fn):
  """Attaches `fn` to `test_class` as method "test_<op_name>_<testcase_name>".

  Args:
    test_class: the class that receives the new test method.
    op_name: op label used as the middle part of the method name.
    testcase_name: case label used as the last part of the method name.
    fn: the function to install as the test method.

  Raises:
    RuntimeError: if `test_class` already has an attribute with that name.
  """
  method_name = "_".join(("test", op_name, testcase_name))
  if not hasattr(test_class, method_name):
    setattr(test_class, method_name, fn)
    return
  raise RuntimeError("Test %s defined more than once" % method_name)
39
40
class QrOpTest(test.TestCase):
  """Hand-written tests for linalg_ops.qr.

  Parameterized numerical tests are attached to this class with _AddTest.
  """

  @test_util.run_v1_only("b/120545219")
  def testWrongDimensions(self):
    # The input to qr should be a tensor of at least rank 2.
    scalar = constant_op.constant(1.)
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 0"):
      linalg_ops.qr(scalar)
    vector = constant_op.constant([1., 2.])
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 1"):
      linalg_ops.qr(vector)

  @test_util.run_deprecated_v1
  def testConcurrentExecutesWithoutError(self):
    # For each configuration, builds two QR ops on random inputs drawn with
    # the same op seed, evaluates everything in a single call, and checks
    # that each pair of results agrees.
    with self.session(use_gpu=True):
      all_ops = []
      for full_matrices_ in True, False:
        for rows_ in 4, 5:
          for cols_ in 4, 5:
            matrix1 = random_ops.random_normal([rows_, cols_], seed=42)
            matrix2 = random_ops.random_normal([rows_, cols_], seed=42)
            q1, r1 = linalg_ops.qr(matrix1, full_matrices=full_matrices_)
            q2, r2 = linalg_ops.qr(matrix2, full_matrices=full_matrices_)
            all_ops += [q1, r1, q2, r2]
      val = self.evaluate(all_ops)
      # Results come in groups of four (q1, r1, q2, r2). The number of
      # groups is derived from the list so it cannot drift from the loops
      # above.
      for q in range(0, len(val), 4):
        self.assertAllClose(val[q], val[q + 2])  # q1 == q2
        self.assertAllClose(val[q + 1], val[q + 3])  # r1 == r2
72
73
def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
  """Returns a test method comparing linalg_ops.qr against np.linalg.qr.

  Args:
    dtype_: numpy dtype of the input (np.float32, np.float64, np.complex64 or
      np.complex128 in this file).
    shape_: shape of the input; the last two dims are (rows, cols) and any
      leading dims are batch dims.
    full_matrices_: passed to linalg_ops.qr; also selects numpy's "complete"
      (True) or "reduced" (False) mode for the reference Q.
    use_static_shape_: if True the input is a constant, otherwise it is fed
      through a placeholder created without a shape.

  Returns:
    A function to be attached to a test.TestCase subclass.
  """

  is_complex = dtype_ in (np.complex64, np.complex128)
  is_single = dtype_ in (np.float32, np.complex64)

  def CompareOrthogonal(self, x, y, rank):
    """Asserts that the first `rank` columns of x and y agree up to phase."""
    if is_single:
      atol = 5e-4
    else:
      atol = 5e-14
    # We only compare the first 'rank' orthogonal vectors since the
    # remainder form an arbitrary orthonormal basis for the
    # (row- or column-) null space, whose exact value depends on
    # implementation details. Notice that since we check that Q is
    # unitary elsewhere, we do implicitly test that the trailing
    # vectors of x and y span the same space.
    x = x[..., 0:rank]
    y = y[..., 0:rank]
    # Q is only unique up to sign (complex phase factor for complex matrices),
    # so we normalize the sign first. The product is taken out of place
    # because x is a view of the caller's array.
    sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
    phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
    x = x * phases
    self.assertAllClose(x, y, atol=atol)

  def CheckApproximation(self, a, q, r):
    """Asserts that q @ r reconstructs a."""
    if is_single:
      tol = 1e-5
    else:
      tol = 1e-14
    a_recon = math_ops.matmul(q, r)
    self.assertAllClose(a_recon, a, rtol=tol, atol=tol)

  def CheckUnitary(self, x):
    """Asserts that x[..., :, :]^H * x[..., :, :] is close to the identity."""
    xx = math_ops.matmul(x, x, adjoint_a=True)
    identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
    if is_single:
      tol = 1e-5
    else:
      tol = 1e-14
    self.assertAllClose(identity, xx, atol=tol)

  @test_util.run_v1_only("b/120545219")
  def Test(self):
    np.random.seed(1)
    x_np = np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    if is_complex:
      x_np += 1j * np.random.uniform(
          low=-1.0, high=1.0,
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)

    with self.session(use_gpu=True) as sess:
      if use_static_shape_:
        x_tf = constant_op.constant(x_np)
      else:
        x_tf = array_ops.placeholder(dtype_)
      q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_)

      if use_static_shape_:
        q_tf_val, r_tf_val = self.evaluate([q_tf, r_tf])
      else:
        q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})

      # Reference Q from numpy, computed one matrix at a time over the
      # flattened batch dimensions, then reshaped back to TF's output shape.
      q_dims = q_tf_val.shape
      np_mode = "complete" if full_matrices_ else "reduced"
      x_batch = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
      np_q_batch = np.empty((x_batch.shape[0], q_dims[-2], q_dims[-1]), dtype_)
      for i, x_mat in enumerate(x_batch):
        np_q_batch[i], _ = np.linalg.qr(x_mat, mode=np_mode)
      np_q = np.reshape(np_q_batch, q_dims)

      CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
      CheckApproximation(self, x_np, q_tf_val, r_tf_val)
      CheckUnitary(self, q_tf_val)

  return Test
160
161
class QrGradOpTest(test.TestCase):
  """Holder class for the generated QR gradient tests.

  Its test methods are attached at runtime with _AddTest.
  """
164
165
def _GetQrGradOpTest(dtype_, shape_, full_matrices_):
  """Returns a test method comparing QR gradients to finite differences.

  Args:
    dtype_: numpy dtype of the input matrix.
    shape_: shape of the input (any batch dims followed by rows, cols).
    full_matrices_: passed through to linalg_ops.qr.

  Returns:
    A function to be attached to a test.TestCase subclass.
  """
  complex_input = dtype_ in [np.complex64, np.complex128]

  def _RandomMatrix():
    # Entries uniform in [-1, 1); complex dtypes additionally get a uniform
    # imaginary part, drawn after the real part.
    m = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
    if complex_input:
      m += 1j * np.random.uniform(
          low=-1.0, high=1.0, size=shape_).astype(dtype_)
    return m

  @test_util.run_v1_only("b/120545219")
  def Test(self):
    np.random.seed(42)
    a = _RandomMatrix()
    # Optimal stepsize for central difference is O(epsilon^{1/3}).
    delta = 0.1 * np.finfo(dtype_).eps**(1.0 / 3.0)
    single_precision = dtype_ in [np.float32, np.complex64]
    tol = 3e-2 if single_precision else 1e-6
    with self.session(use_gpu=True):
      tf_a = constant_op.constant(a)
      # Check the Jacobian of each QR output (q, then r) separately.
      for output in linalg_ops.qr(tf_a, full_matrices=full_matrices_):
        x_init = _RandomMatrix()
        theoretical, numerical = gradient_checker.compute_gradient(
            tf_a,
            tf_a.get_shape().as_list(),
            output,
            output.get_shape().as_list(),
            x_init_value=x_init,
            delta=delta)
        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test
201
202
if __name__ == "__main__":
  # Forward tests: every dtype x (rows, cols) x full_matrices x batch shape
  # x static/dynamic input combination.
  for dtype in np.float32, np.float64, np.complex64, np.complex128:
    for rows in 1, 2, 5, 10, 32, 100:
      for cols in 1, 2, 5, 10, 32, 100:
        for full_matrices in False, True:
          # A second batch dimension is only exercised for small matrices.
          batch_shapes = [(), (3,)]
          if max(rows, cols) < 10:
            batch_shapes.append((3, 2))
          for batch_dims in batch_shapes:
            # When TF2 is enabled only the static-shape variant is generated,
            # because placeholders are not available under eager execution.
            for use_static_shape in {True, tf2.enabled()}:
              shape = batch_dims + (rows, cols)
              shape_str = "_".join(map(str, shape))
              name = "%s_%s_full_%s_static_%s" % (
                  dtype.__name__, shape_str, full_matrices, use_static_shape)
              test_fn = _GetQrOpTest(dtype, shape, full_matrices,
                                     use_static_shape)
              _AddTest(QrOpTest, "Qr", name, test_fn)

  # Gradient tests.
  # TODO(pfau): Get working with complex types.
  # TODO(pfau): Get working with full_matrices when rows != cols
  # TODO(pfau): Get working when rows < cols
  # TODO(pfau): Get working with placeholders (dynamic shapes)
  for full_matrices in False, True:
    for dtype in np.float32, np.float64:
      for rows in 1, 2, 5, 10:
        for cols in 1, 2, 5, 10:
          if rows != cols and (full_matrices or rows <= cols):
            continue
          batch_shapes = [(), (3,)]
          if max(rows, cols) < 10:
            batch_shapes.append((3, 2))
          for batch_dims in batch_shapes:
            shape = batch_dims + (rows, cols)
            shape_str = "_".join(map(str, shape))
            name = "%s_%s_full_%s" % (dtype.__name__, shape_str, full_matrices)
            test_fn = _GetQrGradOpTest(dtype, shape, full_matrices)
            _AddTest(QrGradOpTest, "QrGrad", name, test_fn)
  test.main()
237