1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <limits>
#include <map>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
33 
34 namespace xla {
35 namespace {
36 
37 class MatrixTest : public ClientLibraryTestBase {
38  protected:
39   template <typename T>
40   void TestMatrixDiagonal();
41   template <typename T>
42   void TestMatrixDiagonal4D();
43   template <typename T>
44   void TestSetMatrixDiagonal();
45 
46   template <typename T>
k_and_expected() const47   std::map<int, Array2D<T>> k_and_expected() const {
48     return std::map<int, Array2D<T>>{
49         {0, {{0, 5, 10}, {12, 17, 22}}},
50         {1, {{1, 6, 11}, {13, 18, 23}}},
51         {2, {{2, 7}, {14, 19}}},
52         {3, {{3}, {15}}},
53         {4, {{}, {}}},
54         {-1, {{4, 9}, {16, 21}}},
55         {-2, {{8}, {20}}},
56         {-3, {{}, {}}},
57         {-4, {{}, {}}},
58     };
59   }
60 };
61 
// LowerTriangle on a batch of two 3x4 matrices: entries strictly above the
// main diagonal become zero; the diagonal and everything below it are kept.
XLA_TEST_F(MatrixTest, Triangle) {
  XlaBuilder builder(TestName());

  // Two 3x4 matrices holding 0, 1, ..., 23 in row-major order.
  Array3D<int32> input(2, 3, 4);
  input.FillIota(0);

  XlaOp operand;
  auto operand_data =
      CreateR3Parameter<int32>(input, 0, "a", &builder, &operand);
  LowerTriangle(operand);

  const Array3D<int32> expected(
      {{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
       {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});
  ComputeAndCompareR3<int32>(&builder, expected, {operand_data.get()});
}
75 
76 template <typename T>
TestMatrixDiagonal()77 void MatrixTest::TestMatrixDiagonal() {
78   XlaBuilder builder("SetMatrixDiagonal");
79   Array3D<T> input(2, 3, 4);
80   input.FillIota(0);
81   for (const auto& kv : k_and_expected<T>()) {
82     XlaOp a;
83     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
84     GetMatrixDiagonal(a, kv.first);
85 
86     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
87   }
88 }
89 
90 template <typename T>
TestSetMatrixDiagonal()91 void MatrixTest::TestSetMatrixDiagonal() {
92   XlaBuilder builder("GetMatrixDiagonal");
93   Array3D<T> input(2, 3, 4);
94   input.FillIota(0);
95   for (const auto& kv : k_and_expected<T>()) {
96     XlaOp a;
97     XlaOp b;
98     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
99     auto new_diag =
100         CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
101 
102     GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
103                       kv.first) -
104         ScalarLike(b, 1);
105 
106     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
107   }
108 }
109 
// SetMatrixDiagonal read-back check, instantiated per element type.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
  TestSetMatrixDiagonal<int32>();
}

XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
  TestSetMatrixDiagonal<int64>();
}

XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
  TestSetMatrixDiagonal<float>();
}
119 
// GetMatrixDiagonal on a rank-3 input, instantiated per element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  TestMatrixDiagonal<int32>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  TestMatrixDiagonal<int64>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  TestMatrixDiagonal<float>();
}
125 
126 template <typename T>
TestMatrixDiagonal4D()127 void MatrixTest::TestMatrixDiagonal4D() {
128   XlaBuilder builder("GetMatrixDiagonal");
129   Array4D<T> input(2, 2, 4, 3);
130   input.FillIota(0);
131   std::map<int, Array3D<T>> k_and_expected = {
132       {0, {{{0, 4, 8}, {12, 16, 20}}, {{24, 28, 32}, {36, 40, 44}}}},
133       {1, {{{1, 5}, {13, 17}}, {{25, 29}, {37, 41}}}},
134       {2, {{{2}, {14}}, {{26}, {38}}}},
135       {3, {{{}, {}}, {{}, {}}}},
136       {4, {{{}, {}}, {{}, {}}}},
137       {-1, {{{3, 7, 11}, {15, 19, 23}}, {{27, 31, 35}, {39, 43, 47}}}},
138       {-2, {{{6, 10}, {18, 22}}, {{30, 34}, {42, 46}}}},
139       {-3, {{{9}, {21}}, {{33}, {45}}}},
140       {-4, {{{}, {}}, {{}, {}}}},
141   };
142   for (const auto& kv : k_and_expected) {
143     XlaOp a;
144     auto a_data = CreateR4Parameter<T>(input, 0, "a", &builder, &a);
145     GetMatrixDiagonal(a, kv.first);
146 
147     ComputeAndCompareR3<T>(&builder, kv.second, {a_data.get()});
148   }
149 }
150 
// GetMatrixDiagonal on a rank-4 input, instantiated per element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) {
  TestMatrixDiagonal4D<int32>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) {
  TestMatrixDiagonal4D<int64>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) {
  TestMatrixDiagonal4D<float>();
}
162 
BatchedAValsFull()163 Array3D<float> BatchedAValsFull() {
164   return {{
165               {2, 0, 1, 2},
166               {3, 6, 0, 1},
167               {4, 7, 9, 0},
168               {5, 8, 10, 11},
169           },
170           {
171               {16, 24, 8, 12},
172               {24, 61, 82, 48},
173               {8, 82, 456, 106},
174               {12, 48, 106, 62},
175           }};
176 }
177 
// BatchDot of one dynamically selected row of each batch matrix with the
// transpose of a per-batch 1x4 row vector gives a 1x1 result per batch.
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int n = 4;  // Columns per matrix in BatchedAValsFull().

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // Row `index` (= 1) of each batch: {3, 6, 0, 1} and {24, 61, 82, 48}.
  auto selected_row = DynamicSliceInMinorDims(
      a, {index, ConstantR0<int32>(&builder, 0)}, {1, n});
  BatchDot(selected_row, TransposeInMinorDims(row));

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
197 
// Einsum "abc,adc->abd" over a dynamically selected row of each batch matrix
// and a per-batch 1x4 row vector: the shared minor dimension `c` is summed
// out, leaving a 1x1 result per batch.
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  const int n = 4;  // Columns per matrix in BatchedAValsFull().

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // Row `index` (= 1) of each batch: {3, 6, 0, 1} and {24, 61, 82, 48}.
  auto selected_row = DynamicSliceInMinorDims(
      a, {index, ConstantR0<int32>(&builder, 0)}, {1, n});
  Einsum(selected_row, row, "abc,adc->abd");

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
218 
// ParseEinsumString splits "x,y->o" into three label vectors (x, y, output).
// Letters are expected as their character codes; ellipsis dimensions are
// expected as negative labels.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Builds the label vector expected for one operand spec. Each '.' becomes
  // the next of -3, -2, -1. Every rank passed to ParseEinsumString below is the
  // operand string's length, so a "..." always stands for three dimensions.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64> v;
    v.reserve(s.size());
    int e = -3;
    for (auto c : s) {
      v.push_back(c == '.' ? e++ : int64{c});
    }
    return v;
  };

  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Each case is {x spec, y spec, output spec}. The cases after the blank line
  // are not meaningful contractions (output labels absent from the inputs, or
  // repeated labels), but ParseEinsumString is still expected to parse them.
  std::vector<std::vector<string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"},
      {"...ab", "...bc", "...ac"},
      {"a...bc", "...abd", "cbd..."},
      {"...ab", "...bc", "ac"},
      {"...b", "...bc", "...c"},
      {"...abz", "...bc", "...ac"},
      {"...ab", "...bcz", "...ac"},
      {"abz", "bc", "ac"},
      {"ab", "bcz", "ac"},

      {"a", "b", "c"},
      {"...a", "...b", "...c"},
      {"abb", "bcc", "ac"},
      {"ab", "bc", "ad"},
  };
  for (const auto& test_case : good_test_cases) {
    const string einsum_string =
        to_string(test_case[0], test_case[1], test_case[2]);
    SCOPED_TRACE(einsum_string);
    auto parse_result_or_status = ParseEinsumString(
        einsum_string, test_case[0].size(), test_case[1].size());
    // ASSERT rather than EXPECT: calling ValueOrDie() on an error status would
    // abort the whole test binary instead of reporting a test failure.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << parse_result_or_status.status().ToString();
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
    }
  }

  // Malformed specs: empty, missing operands or "->", too many operands, or
  // more than one ellipsis in an operand.
  std::vector<string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
  };
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
    EXPECT_FALSE(parse_result_or_status.status().ok())
        << "unexpectedly parsed: \"" << test_case << "\"";
  }
}
273 
// NormalizeEinsumString cases, each paired with its expected result. A spec
// that already spells out "->" normalizes to "". Otherwise an output is
// appended: in these cases it lists the labels that occur exactly once, in
// alphabetical order, preceded by "..." when the inputs use an ellipsis.
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
  struct NormalizeCase {
    const char* einsum;
    const char* normalized;
  };
  const NormalizeCase kCases[] = {
      {"a,b->ab", ""},
      {"ba", "ba->ab"},
      {"ab,dc", "ab,dc->abcd"},
      {"a,b", "a,b->ab"},
      {"...ba,ca...", "...ba,ca...->...bc"},
  };
  for (const NormalizeCase& c : kCases) {
    EXPECT_EQ(NormalizeEinsumString(c.einsum), c.normalized);
  }
}
281 
282 }  // namespace
283 }  // namespace xla
284