1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
17 #include "absl/strings/ascii.h"
18 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
19 #include "tensorflow/compiler/xla/tests/test_utils.h"
20
21 namespace xla {
22 namespace {
23 class IndexedArrayAnalysisTest : public HloTestBase {
24 protected:
AssertArrayForRootExpressionIs(const string & hlo_text,const string & root_expression)25 void AssertArrayForRootExpressionIs(const string& hlo_text,
26 const string& root_expression) {
27 AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
28 /*print_constants=*/false);
29 }
30
AssertArrayWithConstantsForRootExpressionIs(const string & hlo_text,const string & root_expression)31 void AssertArrayWithConstantsForRootExpressionIs(
32 const string& hlo_text, const string& root_expression) {
33 AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
34 /*print_constants=*/true);
35 }
36
37 private:
38 // Replaces seqences of whitespace with a single space. This makes the
39 // strings being matched against "whitespace insensitive" which lets us indent
40 // them for readability.
CanonicalizeWhitespace(const string & text)41 string CanonicalizeWhitespace(const string& text) {
42 string result;
43
44 for (char c : text) {
45 if (!absl::ascii_isspace(c)) {
46 result.push_back(c);
47 } else if (!result.empty() && result.back() != ' ') {
48 result.push_back(' ');
49 }
50 }
51
52 while (!result.empty() && result.back() == ' ') {
53 result.pop_back();
54 }
55
56 return result;
57 }
58
AssertArrayForRootExpressionIsImpl(const string & hlo_text,const string & root_expression,bool print_constants)59 void AssertArrayForRootExpressionIsImpl(const string& hlo_text,
60 const string& root_expression,
61 bool print_constants) {
62 IndexedArrayAnalysis indexed_tensor_analysis;
63 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
64 ParseAndReturnVerifiedModule(hlo_text));
65
66 TF_ASSERT_OK_AND_ASSIGN(IndexedArrayAnalysis::Array* const array_result,
67 indexed_tensor_analysis.GetArrayFor(
68 m->entry_computation()->root_instruction()));
69 string string_result = CanonicalizeWhitespace(
70 indexed_tensor_analysis.ToString(array_result, print_constants));
71 LOG(INFO) << string_result;
72 ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression));
73 }
74 };
75
// Gathering whole rows of a parameter with rank-1 indices is modeled as a
// scalar-indexed array over that parameter.
TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
}
)";

  constexpr char kExpectedRoot[] = "(scalar-indexed %operand %indices 0->[0])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
95
// Same row gather as SimpleOneToOneGather, but over a constant operand, which
// yields a scalar-indexed-const array.
TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
  indices = s32[5] parameter(0)
  ROOT gather = s32[5,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
115
// Each index here has two components (s32[5,2] with index_vector_dim=1), so
// the gather addresses two operand dimensions at once and is expected to stay
// opaque.
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
  indices = s32[5,2] parameter(0)
  ROOT gather = s32[5] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1}
}
)";

  constexpr char kExpectedRoot[] = "%gather";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
134
// Operand dimension 2 is collapsed even though it is not in start_index_map;
// the gather is expected to stay opaque.
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3,1] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,2},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3,1}
}
)";

  constexpr char kExpectedRoot[] = "%gather";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
153
// The slice along the indexed dimension 0 has size 2 and is not collapsed;
// the gather is expected to stay opaque.
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3,1] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,2,3] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={2},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={2,3,1}
}
)";

  constexpr char kExpectedRoot[] = "%gather";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
172
// The slice takes only 2 of the 3 elements of operand dimension 1, so whole
// rows are not gathered; the gather is expected to stay opaque.
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,2}
}
)";

  constexpr char kExpectedRoot[] = "%gather";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
191
// A gather of a constant gather composes: the outer indices become a
// scalar-indexed array over the inner indices.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
  indices_a = s32[5] parameter(0)
  indices_b = s32[2] parameter(1)
  gather_a = s32[5,3] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
  ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a "
      "%indices_b 0->[0]) 0->[0])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
220
// Inner gather uses rank-2 indices, outer gather uses rank-1 indices; the two
// index arrays compose into a nested scalar-indexed expression.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,2] parameter(0)
  indices_a = s32[5,7] parameter(1)
  indices_b = s32[2] parameter(2)
  gather_a = s32[5,3,7] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3,1}
  ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
      offset_dims={0,1},
      collapsed_slice_dims={2},
      start_index_map={2},
      index_vector_dim=1,
      slice_sizes={5,3,1}
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed %operand (scalar-indexed "
      "%indices_a %indices_b 1->[1]) 1->[0,2])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
248
// Inner gather uses rank-1 indices, outer gather uses rank-2 indices; the two
// index arrays compose into a nested scalar-indexed expression.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,6] parameter(0)
  indices_a = s32[2] parameter(1)
  indices_b = s32[5,7] parameter(2)
  gather_a = s32[2,6] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,6}
  ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,6}
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed %operand (scalar-indexed "
      "%indices_a %indices_b 0->[0,1]) 0->[0,2])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
276
// Both the inner and the outer gather use rank-2 indices; the index arrays
// still compose into a nested scalar-indexed expression.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) {
  constexpr char kHloText[] = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,2] parameter(0)
  indices_a = s32[5,7] parameter(1)
  indices_b = s32[4,8] parameter(2)
  gather_a = s32[5,3,7] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3,1}
  ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
      offset_dims={1,2},
      collapsed_slice_dims={2},
      start_index_map={2},
      index_vector_dim=2,
      slice_sizes={5,3,1}
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b "
      "1->[0,2]) 1->[0,1,3])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
305
// Reshaping the non-indexed dimension of a constant gather (4 -> 2x2) is
// folded into the constant operand, which becomes s32[3,2,2].
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT reshape = s32[5,2,2] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
326
// Same fold as ReshapeOfGather0, with rank-2 indices whose dimensions end up
// at result positions 0 and 3.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  indices = s32[5,7] parameter(0)
  gather = s32[5,4,7] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,4}
  ROOT reshape = s32[5,2,2,7] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
348
// Two non-indexed dimensions (2x6) are reshaped to 3x4; the reshape is folded
// into the constant operand, which becomes s32[3,3,4].
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,2,6] constant({
      {{1,2,3,4,5,6},{1,2,3,4,5,6}},
      {{1,2,3,4,5,6},{1,2,3,4,5,6}},
      {{1,2,3,4,5,6},{1,2,3,4,5,6}}})
  indices = s32[5,7] parameter(0)
  gather = s32[5,2,6,7] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,2,6}
  ROOT reshape = s32[5,3,4,7] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] =
      "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
373
// With a single index (s32[1]), the reshape is folded into the constant and
// the indices are reshaped to a scalar.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[2,6] constant({
      {1,2,3,4,5,6},{1,2,3,4,5,6}})
  indices = s32[1] parameter(0)
  gather = s32[1,6] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,6}
  ROOT reshape = s32[1,1,6] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[2,1,1,6])
  (reshape %indices to s32[])
  0->[])
)";

  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
401
// The reshape sits on top of a gather-of-gather. The composed s64 indices are
// reshaped to a scalar and the reshape is folded into the constant.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } })

  i.0 = s64[1,3]{1,0} parameter(0)
  g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
    collapsed_slice_dims={0}, start_index_map={0},
    index_vector_dim=2, slice_sizes={1,3}

  i.1 = s64[1] parameter(1)
  g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2},
    collapsed_slice_dims={1}, start_index_map={1},
    index_vector_dim=1, slice_sizes={1,1,3}

  ROOT reshape = s32[1,3]{1,0} reshape(g.1)
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[2,1,3])
  (reshape
    (scalar-indexed %i.0 %i.1 1->[1])
    to s64[])
  0->[])
)";

  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
434
// Like ReshapeOfGather3, but the constant operand itself has a leading
// dimension of size 1.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[1,6] constant({{1,2,3,4,5,6}})
  indices = s32[1] parameter(0)
  gather = s32[1,6] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,6}
  ROOT reshape = s32[1,1,6] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[1,1,1,6])
  (reshape %indices to s32[])
  0->[])
)";

  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
461
// Gathers along operand dimension 1 and reshapes; the folded constant's
// contents are checked as well.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[1,2,6] constant({{
      {1,2,3,4,5,6},{1,2,3,4,5,6}}})
  indices = s32[1] parameter(0)
  gather = s32[1,1,6] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={1,1,6}
  ROOT reshape = s32[1,1,1,6] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[2,1,1,1,6] s32[2,1,1,1,6] {
    { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } },
    { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } } })
  (reshape %indices to s32[])
  0->[])
)";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
492
// The s32[1,5] indices are reshaped to s32[5] when the reshape is folded into
// the constant; the constant contents are checked too.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[2,6] constant({
      {1,2,3,4,5,6},{1,2,3,4,5,6}})
  indices = s32[1,5] parameter(0)
  gather = s32[1,5,6] gather(operand, indices),
      offset_dims={2},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,6}
  ROOT reshape = s32[1,1,5,6] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[2,1,1,6] s32[2,1,1,6] {
    { /*i0=0*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } },
    { /*i0=1*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } } })
  (reshape %indices to s32[5])
  0->[2])
)";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
523
// The reshape also splits the index-produced dimension of size 6 (into 2x3).
// The expected result keeps the reshape outside the scalar-indexed-const.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  indices = s32[5,6] parameter(0)
  gather = s32[5,4,6] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,4}
  ROOT reshape = s32[5,2,2,2,3] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(reshape
  (scalar-indexed-const
    (constant s32[3,4])
    %indices
    0->[0,2])
  to s32[5,2,2,2,3])
)";

  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
552
// The reshape merges the 3x2 gathered dimensions into 6. The expected result
// keeps the reshape outside the scalar-indexed-const.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,5,2] constant({
      {{1,2},{3,4},{5,6},{7,8},{9,10}},
      {{1,2},{3,4},{5,6},{7,8},{9,10}},
      {{1,2},{3,4},{5,6},{7,8},{9,10}}})
  indices = s32[7] parameter(0)
  gather = s32[3,2,7] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1,2}
  ROOT reshape = s32[6,7] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(reshape
  (scalar-indexed-const
    (constant s32[3,5,2])
    %indices
    1->[2])
  to s32[6,7])
)";

  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
584
// Like ReshapeOfGatherNoFold0, with an extra trailing dimension of size 1.
// The reshape is again expected to stay outside.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) {
  constexpr char kHloText[] = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4,1] constant({
    {{1},{2},{3},{4}},
    {{1},{2},{3},{4}},
    {{1},{2},{3},{4}}})
  indices = s32[5,6] parameter(0)
  gather = s32[5,4,6,1] gather(operand, indices),
      offset_dims={1,3},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,4,1}
  ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
}
)";

  constexpr char kExpectedRoot[] = R"(
(reshape
  (scalar-indexed-const
    (constant s32[3,4,1])
    %indices
    0->[0,2])
  to s32[5,2,2,2,3,1])
)";

  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
616
// tanh applied to a constant gather is folded by applying tanh to the constant
// operand.
TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {
  constexpr char kHloText[] = R"(
HloModule UnaryOpOfGather

ENTRY main {
  operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  indices = s32[5] parameter(0)
  gather = f32[5,4] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT tanh = f32[5,4] tanh(gather)
}
)";

  // `1 +` skips the newline that opens the raw literal.
  const char* const expected_root = 1 + R"(
(scalar-indexed-const (constant f32[3,4] f32[3,4] {
  { 0.761594, 0.964028, 0.995055, 0.999329 },
  { 0.761594, 0.995055, 0.964028, 0.999329 },
  { 0.999329, 0.995055, 0.964028, 0.761594 }
}) %indices 0->[0]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, expected_root);
}
641
// Adding a broadcast scalar constant (5) to a constant gather folds into the
// gathered constant.
TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) {
  constexpr char kHloText[] = R"(
HloModule AddBroadcastedScalarWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant = s32[] constant(5)
  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";

  // `1 +` skips the newline that opens the raw literal.
  const char* const expected_root = 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { 6, 7, 8, 9 },
  { 6, 8, 7, 9 },
  { 9, 8, 7, 6 }
}) %indices 0->[0]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, expected_root);
}
668
// gather - broadcast(5) folds into the gathered constant; operand order is
// preserved (results are element - 5).
TEST_F(IndexedArrayAnalysisTest,
       SubtractBroadcastedScalarWithGather_GatherIsLhs) {
  constexpr char kHloText[] = R"(
HloModule SubtractBroadcastedScalarWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant = s32[] constant(5)
  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
}
)";

  // `1 +` skips the newline that opens the raw literal.
  const char* const expected_root = 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { -4, -3, -2, -1 },
  { -4, -2, -3, -1 },
  { -1, -2, -3, -4 }
}) %indices 0->[0]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, expected_root);
}
696
// broadcast(5) - gather folds into the gathered constant; operand order is
// preserved (results are 5 - element).
TEST_F(IndexedArrayAnalysisTest,
       SubtractBroadcastedScalarWithGather_GatherIsRhs) {
  constexpr char kHloText[] = R"(
HloModule SubtractBroadcastedScalarWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant = s32[] constant(5)
  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
}
)";

  // `1 +` skips the newline that opens the raw literal.
  const char* const expected_root = 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { 4, 3, 2, 1 },
  { 4, 2, 3, 1 },
  { 1, 2, 3, 4 }
}) %indices 0->[0]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, expected_root);
}
724
// A vector constant broadcast along dimension 1 (the non-indexed dimension) is
// folded into the gathered constant.
TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) {
  constexpr char kHloText[] = R"(
HloModule AddBroadcastedVectorWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant_vect = s32[4] constant({10,11,12,13})
  constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";

  // `1 +` skips the newline that opens the raw literal.
  const char* const expected_root = 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { 11, 13, 15, 17 },
  { 11, 14, 14, 17 },
  { 14, 14, 14, 14 }
}) %indices 0->[0]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, expected_root);
}
751
// The vector is broadcast along dimension 0, the dimension produced by the
// indices, so the add is expected to stay opaque.
TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) {
  constexpr char kHloText[] = R"(
HloModule AddBroadcastedVectorWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant_vect = s32[5] constant({10,11,12,13,14})
  constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";

  constexpr char kExpectedRoot[] = "%add";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
773
// tanh of a plain parameter involves no gather and stays opaque.
TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) {
  constexpr char kHloText[] = R"(
HloModule RegularUnaryOp

ENTRY main {
  input = f32[100] parameter(0)
  ROOT tanh = f32[100] tanh(input)
}
)";

  constexpr char kExpectedRoot[] = "%tanh";
  AssertArrayForRootExpressionIs(kHloText, kExpectedRoot);
}
786
// add of two plain parameters involves no gather and stays opaque.
TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) {
  // The module was previously named "RegularUnaryOp", a copy-paste leftover
  // from the test above; name it after this test.
  string hlo_text = R"(
HloModule RegularBinaryOp

ENTRY main {
  input0 = f32[100] parameter(0)
  input1 = f32[100] parameter(1)
  ROOT add = f32[100] add(input0, input1)
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%add");
}
800
// dot(constant gather, constant) with the gather on the lhs folds into a
// precomputed constant that is then indexed.
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) {
  constexpr char kHloText[] = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  indices = s32[5] parameter(0)
  dot_lhs = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[3,3] s32[3,3] {
    { 70, 80, 90 },
    { 158, 184, 210 },
    { 246, 288, 330 } })
  %indices 0->[0]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
827
// The gathered lhs indexes operand dimension 1 and contracts over its
// dimension 0; the dot is folded into a precomputed constant.
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) {
  constexpr char kHloText[] = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}})
  indices = s32[5] parameter(0)
  dot_lhs = s32[3,5] gather(gather_operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1}
  ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[4,3] s32[4,3] {
    { 84, 99, 114 },
    { 96, 114, 132 },
    { 108, 129, 150 },
    { 120, 144, 168 } })
  %indices 0->[1]))";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
855
// Constant lhs times a gathered rhs; the dot is folded into a precomputed
// constant.
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) {
  constexpr char kHloText[] = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  indices = s32[5] parameter(0)
  dot_rhs = s32[3,5] gather(gather_operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1}
  ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[4,4] s32[4,4] {
    { 38, 44, 50, 56 },
    { 83, 98, 113, 128 },
    { 128, 152, 176, 200 },
    { 173, 206, 239, 272 } })
  %indices 1->[1])
)";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
884
// Constant lhs times a row-gathered rhs contracted over the rhs's dimension 1;
// the dot is folded into a precomputed constant.
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) {
  constexpr char kHloText[] = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  indices = s32[5] parameter(0)
  dot_rhs = s32[5,3] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
  ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[4,4] s32[4,4] {
    { 14, 32, 50, 68 },
    { 32, 77, 122, 167 },
    { 50, 122, 194, 266 },
    { 68, 167, 266, 365 } })
  %indices 1->[0])
)";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
913
// A batched dot (batch dimension 0 on both sides) with a gathered rhs is
// folded into a precomputed constant.
TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) {
  constexpr char kHloText[] = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}})
  dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
  indices = s32[4] parameter(0)
  dot_rhs = s32[2,3,4] gather(gather_operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={2},
      start_index_map={2},
      index_vector_dim=1,
      slice_sizes={2,3,1}
  ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
      lhs_contracting_dims={2}, rhs_contracting_dims={1},
      lhs_batch_dims={0}, rhs_batch_dims={0}
}
)";

  constexpr char kExpectedRoot[] = R"(
(scalar-indexed-const
  (constant s32[2,2,2] s32[2,2,2] {
    { { 22, 28 },
      { 49, 64 } },
    { { 220, 244 },
      { 301, 334 } } })
  %indices 3->[2])
)";

  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
944
// The lhs contracts over dimension 1, which is the dimension produced by the
// indices, so the dot is expected to stay opaque.
TEST_F(IndexedArrayAnalysisTest, DotOpNegative) {
  constexpr char kHloText[] = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}})
  indices = s32[2] parameter(0)
  dot_lhs = s32[3,2] gather(gather_operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1}
  ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  constexpr char kExpectedRoot[] = "%dot";
  AssertArrayWithConstantsForRootExpressionIs(kHloText, kExpectedRoot);
}
965
966 } // namespace
967 } // namespace xla
968