1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.tf.matmul."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
29
30
def RandMatrix(rows, cols, tr, round_bfloat=False):
  """Returns a random float32 matrix whose entries lie in [-0.5, 0.5].

  Samples are drawn from [-256, 256), clipped to [-64, 64] and divided by 128.

  Args:
    rows: Number of rows of the untransposed matrix.
    cols: Number of columns of the untransposed matrix.
    tr: If truthy, the result has shape [cols, rows] instead of [rows, cols].
    round_bfloat: If truthy, samples come from np.random.randint rather than
      np.random.uniform, so every entry is an integer multiple of 1/128.

  Returns:
    A 2-D np.float32 array.
  """
  shape = [cols, rows] if tr else [rows, cols]
  if round_bfloat:
    sampler = np.random.randint
  else:
    sampler = np.random.uniform
  samples = sampler(low=-256.0, high=256.0, size=rows * cols)
  scaled = np.clip(samples, -64, 64) / 128.0
  return scaled.reshape(shape).astype(np.float32)
39
40
class SparseMatMulTest(test.TestCase):
  """Checks math_ops.matmul with sparsity hints against a NumPy reference."""

  # Every test is repeated for each combination of these operand dtypes.
  _DTYPES = (dtypes.float32, dtypes.bfloat16)

  def _testCpuMatmul(self,
                     x,
                     y,
                     tr_a=False,
                     tr_b=False,
                     sp_a=True,
                     sp_b=False,
                     x_dtype=dtypes.float32,
                     y_dtype=dtypes.float32):
    """Runs matmul on CPU and compares the result with the NumPy product.

    Args:
      x: Left operand as a NumPy array, in the layout handed to matmul (it is
        transposed for the reference computation when tr_a is set).
      y: Right operand as a NumPy array, in the layout handed to matmul (it is
        transposed for the reference computation when tr_b is set).
      tr_a: Forwarded to matmul as transpose_a.
      tr_b: Forwarded to matmul as transpose_b.
      sp_a: Forwarded to matmul as a_is_sparse.
      sp_b: Forwarded to matmul as b_is_sparse.
      x_dtype: dtype x is cast to before the multiplication.
      y_dtype: dtype y is cast to before the multiplication.
    """
    with self.cached_session(use_gpu=False):
      tf_x = math_ops.cast(x, x_dtype)
      tf_y = math_ops.cast(y, y_dtype)
      tf_ans = math_ops.matmul(
          tf_x,
          tf_y,
          transpose_a=tr_a,
          transpose_b=tr_b,
          a_is_sparse=sp_a,
          b_is_sparse=sp_b)
      out = self.evaluate(tf_ans)
      # Cast the operands back to float32 so the NumPy reference is computed
      # from the values obtained after the cast to x_dtype / y_dtype.
      np_x = self.evaluate(math_ops.cast(tf_x, dtypes.float32))
      np_y = self.evaluate(math_ops.cast(tf_y, dtypes.float32))

    if tr_a:
      np_x = np.transpose(np_x)
    if tr_b:
      np_y = np.transpose(np_y)

    # Plain ndarrays with np.matmul; np.matrix is no longer recommended.
    np_ans = np.matmul(np_x, np_y)
    self.assertShapeEqual(np_ans, tf_ans)
    self.assertAllCloseAccordingToType(np_ans, out, rtol=1e-4, atol=1e-4)

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Multiplies a fixed [4, 1] column by a fixed [1, 2] row."""
    x = np.arange(0., 4.).reshape([4, 1]).astype(np.float32)
    y = np.arange(-1., 1.).reshape([1, 2]).astype(np.float32)
    for x_dtype, y_dtype in itertools.product(self._DTYPES, repeat=2):
      self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)

  @test_util.run_deprecated_v1
  def testZeroDim(self):
    """Multiplies operands whose shared inner dimension is zero."""
    x = np.ones((4, 0)).astype(np.float32)
    y = np.ones((0, 3)).astype(np.float32)
    for x_dtype, y_dtype in itertools.product(self._DTYPES, repeat=2):
      self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)

  @test_util.run_deprecated_v1
  def testEmpty(self):
    """Multiplies two [0, 0] operands."""
    x = np.ones((0, 0)).astype(np.float32)
    y = np.ones((0, 0)).astype(np.float32)
    for x_dtype, y_dtype in itertools.product(self._DTYPES, repeat=2):
      self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)

  # Tests setting one dimension to be a high value.
  @test_util.run_deprecated_v1
  def testLarge(self):
    """Makes each of the three matmul dimensions large in turn."""
    r1 = np.random.randint(6000, 20000)
    r2 = np.random.randint(1, 10)
    r3 = np.random.randint(1, 10)
    shapes = [(r1, r2, r3), (r2, r1, r3), (r2, r3, r1)]
    # product() iterates in the same order as the equivalent nested loops.
    for (m, k, n), x_dtype, y_dtype in itertools.product(
        shapes, self._DTYPES, self._DTYPES):
      x = RandMatrix(m, k, False)
      y = RandMatrix(k, n, False)
      self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)

  # Tests random sized matrices.
  @test_util.run_deprecated_v1
  def testRandom(self):
    """Covers every transpose / sparsity-hint / dtype combination."""
    flags = [True, False]
    for tr_a, tr_b, sp_a, sp_b in itertools.product(flags, repeat=4):
      for x_dtype, y_dtype in itertools.product(self._DTYPES, repeat=2):
        n, k, m = np.random.randint(1, 100, size=3)
        x = RandMatrix(n, k, tr_a)
        y = RandMatrix(k, m, tr_b)
        self._testCpuMatmul(
            x,
            y,
            tr_a,
            tr_b,
            sp_a,
            sp_b,
            x_dtype=x_dtype,
            y_dtype=y_dtype)
133
134
class MatMulGradientTest(test.TestCase):
  """Checks the gradients of math_ops.matmul with gradient_checker."""

  def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, delta,
                     name):
    """Asserts the summed gradient error for both operands is small.

    Builds a [3, 2] x [2, 4] product (operands are created pre-transposed when
    tr_a / tr_b are set) and checks that the sum of the gradient errors
    reported by gradient_checker for `a` and `b` is at most delta / 2.

    Args:
      tr_a: Forwarded to matmul as transpose_a.
      tr_b: Forwarded to matmul as transpose_b.
      sp_a: Forwarded to matmul as a_is_sparse.
      sp_b: Forwarded to matmul as b_is_sparse.
      a_dtype: dtype the left operand is cast to before the multiplication.
      b_dtype: dtype the right operand is cast to before the multiplication.
      delta: Perturbation passed to gradient_checker.compute_gradient_error;
        also determines the error bound (delta / 2).
      name: Name given to the matmul op.
    """
    with self.cached_session():
      a = constant_op.constant(
          RandMatrix(3, 2, tr_a, round_bfloat=True), dtype=dtypes.float32)
      b = constant_op.constant(
          RandMatrix(2, 4, tr_b, round_bfloat=True), dtype=dtypes.float32)
      # The gradient is checked with respect to the float32 constants; the
      # cast (if any) is part of the graph under test.
      tf_a = math_ops.cast(a, a_dtype) if a_dtype != dtypes.float32 else a
      tf_b = math_ops.cast(b, b_dtype) if b_dtype != dtypes.float32 else b

      m = math_ops.matmul(
          tf_a,
          tf_b,
          name=name,
          transpose_a=tr_a,
          transpose_b=tr_b,
          a_is_sparse=sp_a,
          b_is_sparse=sp_b)

      a_shape = [2, 3] if tr_a else [3, 2]
      b_shape = [4, 2] if tr_b else [2, 4]
      err_a = gradient_checker.compute_gradient_error(
          a,
          a_shape,
          m, [3, 4],
          x_init_value=self.evaluate(a),
          delta=delta)
      err_b = gradient_checker.compute_gradient_error(
          b,
          b_shape,
          m, [3, 4],
          x_init_value=self.evaluate(b),
          delta=delta)
      err = err_a + err_b
    self.assertLessEqual(err, delta / 2.)

  @test_util.run_deprecated_v1
  def testGradientInput(self):
    """Covers every transpose / sparsity-hint / dtype combination."""
    flags = [True, False]
    all_dtypes = (dtypes.float32, dtypes.bfloat16)
    # product() iterates in the same order as the equivalent nested loops.
    for tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype in itertools.product(
        flags, flags, flags, flags, all_dtypes, all_dtypes):
      # Note: bfloat16 keeps only 7 explicit mantissa bits (float32 keeps 23).
      # Hence, when both operands are bfloat16 the perturbation is made 4x
      # coarser (shifted by 2 bits) to pass the test.
      if a_dtype == dtypes.bfloat16 and b_dtype == dtypes.bfloat16:
        delta = 1 / 16.
      else:
        delta = 1 / 64.
      name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b)
      self._testGradients(tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, delta,
                          name)
185
186
# When executed as a script, hand control to the test runner from
# tensorflow.python.platform.test.
if __name__ == "__main__":
  test.main()
189