1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.tf.matmul.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops import gradient_checker 27from tensorflow.python.ops import math_ops 28from tensorflow.python.platform import test 29 30 31def RandMatrix(rows, cols, tr, round_bfloat=False): 32 if tr: 33 rows, cols = cols, rows 34 rand_func = np.random.randint if round_bfloat else np.random.uniform 35 return (np.clip( 36 rand_func( 37 low=-256.0, high=256.0, size=rows * cols), -64, 38 64) / 128.0).reshape([rows, cols]).astype(np.float32) 39 40 41class SparseMatMulTest(test.TestCase): 42 43 def _testCpuMatmul(self, 44 x, 45 y, 46 tr_a=False, 47 tr_b=False, 48 sp_a=True, 49 sp_b=False, 50 x_dtype=dtypes.float32, 51 y_dtype=dtypes.float32): 52 with self.cached_session(use_gpu=False): 53 tf_x = math_ops.cast(x, x_dtype) 54 tf_y = math_ops.cast(y, y_dtype) 55 tf_ans = math_ops.matmul( 56 tf_x, 57 tf_y, 58 transpose_a=tr_a, 59 transpose_b=tr_b, 60 a_is_sparse=sp_a, 61 b_is_sparse=sp_b) 62 out = self.evaluate(tf_ans) 63 np_x = math_ops.cast(tf_x, dtypes.float32).eval() 64 np_y = math_ops.cast(tf_y, dtypes.float32).eval() 65 66 if tr_a: 67 np_x = np.transpose(np_x) 68 if tr_b: 69 np_y = np.transpose(np_y) 70 71 np_ans = np.matrix(np_x) * np.matrix(np_y) 72 self.assertShapeEqual(np_ans, tf_ans) 73 self.assertAllCloseAccordingToType(np_ans, out, rtol=1e-4, atol=1e-4) 74 75 @test_util.run_deprecated_v1 76 def testBasic(self): 77 x = np.arange(0., 4.).reshape([4, 1]).astype(np.float32) 78 y = np.arange(-1., 1.).reshape([1, 2]).astype(np.float32) 79 for x_dtype in (dtypes.float32, dtypes.bfloat16): 80 for y_dtype in (dtypes.float32, dtypes.bfloat16): 81 self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype) 82 83 @test_util.run_deprecated_v1 84 def testZeroDim(self): 85 x = np.ones((4, 0)).astype(np.float32) 86 y = np.ones((0, 3)).astype(np.float32) 87 for x_dtype in (dtypes.float32, dtypes.bfloat16): 88 for y_dtype in (dtypes.float32, dtypes.bfloat16): 89 self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype) 90 91 @test_util.run_deprecated_v1 92 def testEmpty(self): 93 x = np.ones((0, 0)).astype(np.float32) 94 y = np.ones((0, 0)).astype(np.float32) 95 for x_dtype in (dtypes.float32, dtypes.bfloat16): 96 for y_dtype in (dtypes.float32, dtypes.bfloat16): 97 self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype) 98 99 # Tests setting one dimension to be a high value. 100 @test_util.run_deprecated_v1 101 def testLarge(self): 102 r1 = np.random.randint(6000, 20000) 103 r2 = np.random.randint(1, 10) 104 r3 = np.random.randint(1, 10) 105 for m, k, n in [(r1, r2, r3), (r2, r1, r3), (r2, r3, r1)]: 106 for x_dtype in (dtypes.float32, dtypes.bfloat16): 107 for y_dtype in (dtypes.float32, dtypes.bfloat16): 108 x = RandMatrix(m, k, False) 109 y = RandMatrix(k, n, False) 110 self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype) 111 112 # Tests random sized matrices. 113 @test_util.run_deprecated_v1 114 def testRandom(self): 115 for tr_a in [True, False]: 116 for tr_b in [True, False]: 117 for sp_a in [True, False]: 118 for sp_b in [True, False]: 119 for x_dtype in (dtypes.float32, dtypes.bfloat16): 120 for y_dtype in (dtypes.float32, dtypes.bfloat16): 121 n, k, m = np.random.randint(1, 100, size=3) 122 x = RandMatrix(n, k, tr_a) 123 y = RandMatrix(k, m, tr_b) 124 self._testCpuMatmul( 125 x, 126 y, 127 tr_a, 128 tr_b, 129 sp_a, 130 sp_b, 131 x_dtype=x_dtype, 132 y_dtype=y_dtype) 133 134 135class MatMulGradientTest(test.TestCase): 136 137 def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, delta, 138 name): 139 with self.cached_session(): 140 a = constant_op.constant( 141 RandMatrix( 142 3, 2, tr_a, round_bfloat=True), dtype=dtypes.float32) 143 b = constant_op.constant( 144 RandMatrix( 145 2, 4, tr_b, round_bfloat=True), dtype=dtypes.float32) 146 tf_a = math_ops.cast(a, a_dtype) if a_dtype != dtypes.float32 else a 147 tf_b = math_ops.cast(b, b_dtype) if b_dtype != dtypes.float32 else b 148 149 m = math_ops.matmul( 150 tf_a, 151 tf_b, 152 name=name, 153 transpose_a=tr_a, 154 transpose_b=tr_b, 155 a_is_sparse=sp_a, 156 b_is_sparse=sp_b) 157 err = (gradient_checker.compute_gradient_error( 158 a, [2, 3] if tr_a else [3, 2], 159 m, [3, 4], 160 x_init_value=a.eval(), 161 delta=delta) + gradient_checker.compute_gradient_error( 162 b, [4, 2] if tr_b else [2, 4], 163 m, [3, 4], 164 x_init_value=b.eval(), 165 delta=delta)) 166 self.assertLessEqual(err, delta / 2.) 167 168 @test_util.run_deprecated_v1 169 def testGradientInput(self): 170 for tr_a in [True, False]: 171 for tr_b in [True, False]: 172 for sp_a in [True, False]: 173 for sp_b in [True, False]: 174 for a_dtype in (dtypes.float32, dtypes.bfloat16): 175 for b_dtype in (dtypes.float32, dtypes.bfloat16): 176 # Note: bfloat16 only has 7 mantissa bits, versus float32 with 177 # 10. Hence, we shift by 2 bits to pass the test. 178 if a_dtype == dtypes.bfloat16 and b_dtype == dtypes.bfloat16: 179 delta = 1 / 16. 180 else: 181 delta = 1 / 64. 182 name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b) 183 self._testGradients(tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, 184 delta, name) 185 186 187if __name__ == "__main__": 188 test.main() 189