1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
17 #include "absl/strings/ascii.h"
18 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
19 #include "tensorflow/compiler/xla/tests/test_utils.h"
20 
21 namespace xla {
22 namespace {
23 class IndexedArrayAnalysisTest : public HloTestBase {
24  protected:
AssertArrayForRootExpressionIs(const string & hlo_text,const string & root_expression)25   void AssertArrayForRootExpressionIs(const string& hlo_text,
26                                       const string& root_expression) {
27     AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
28                                        /*print_constants=*/false);
29   }
30 
AssertArrayWithConstantsForRootExpressionIs(const string & hlo_text,const string & root_expression)31   void AssertArrayWithConstantsForRootExpressionIs(
32       const string& hlo_text, const string& root_expression) {
33     AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
34                                        /*print_constants=*/true);
35   }
36 
37  private:
38   // Replaces seqences of whitespace with a single space.  This makes the
39   // strings being matched against "whitespace insensitive" which lets us indent
40   // them for readability.
CanonicalizeWhitespace(const string & text)41   string CanonicalizeWhitespace(const string& text) {
42     string result;
43 
44     for (char c : text) {
45       if (!absl::ascii_isspace(c)) {
46         result.push_back(c);
47       } else if (!result.empty() && result.back() != ' ') {
48         result.push_back(' ');
49       }
50     }
51 
52     while (!result.empty() && result.back() == ' ') {
53       result.pop_back();
54     }
55 
56     return result;
57   }
58 
AssertArrayForRootExpressionIsImpl(const string & hlo_text,const string & root_expression,bool print_constants)59   void AssertArrayForRootExpressionIsImpl(const string& hlo_text,
60                                           const string& root_expression,
61                                           bool print_constants) {
62     IndexedArrayAnalysis indexed_tensor_analysis;
63     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
64                             ParseAndReturnVerifiedModule(hlo_text));
65 
66     TF_ASSERT_OK_AND_ASSIGN(IndexedArrayAnalysis::Array* const array_result,
67                             indexed_tensor_analysis.GetArrayFor(
68                                 m->entry_computation()->root_instruction()));
69     string string_result = CanonicalizeWhitespace(
70         indexed_tensor_analysis.ToString(array_result, print_constants));
71     LOG(INFO) << string_result;
72     ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression));
73   }
74 };
75 
// Gathering whole rows of a parameter with a rank-1 index vector is expected
// to be recognized as a scalar-indexed array over the parameter %operand.
TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
}
)";

  AssertArrayForRootExpressionIs(hlo_text,
                                 "(scalar-indexed %operand %indices 0->[0])");
}
95 
// Row gather whose operand is a constant: the result is expected to be a
// scalar-indexed-const array over that s32[3,3] constant.
TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
  indices = s32[5] parameter(0)
  ROOT gather = s32[5,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
}
)";

  AssertArrayForRootExpressionIs(
      hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])");
}
115 
// Each gathered element is addressed by a two-component index vector
// (start_index_map={0,1}).  The result is expected to print as the plain
// instruction reference %gather, i.e. no scalar-indexed form is produced.
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
  indices = s32[5,2] parameter(0)
  ROOT gather = s32[5] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1}
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%gather");
}
134 
// collapsed_slice_dims={0,2} also collapses dimension 2, which is not listed
// in start_index_map={0}.  The result is expected to print as %gather (no
// scalar-indexed form).
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3,1] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,2},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3,1}
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%gather");
}
153 
// The indexed dimension 0 is sliced with size 2 (slice_sizes={2,3,1}) and is
// not collapsed.  The result is expected to print as %gather (no
// scalar-indexed form).
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3,1] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,2,3] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={2},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={2,3,1}
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%gather");
}
172 
// The non-indexed dimension 1 (size 3) is only partially sliced
// (slice_sizes={1,2}).  The result is expected to print as %gather (no
// scalar-indexed form).
TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[5] parameter(1)
  ROOT gather = s32[5,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,2}
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%gather");
}
191 
// A row gather from a constant (gather_a) whose rows are gathered again
// (gather_b).  It is expected to fold into a single scalar-indexed-const over
// the constant, whose indices are %indices_a gathered by %indices_b.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
  indices_a = s32[5] parameter(0)
  indices_b = s32[2] parameter(1)
  gather_a = s32[5,3] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
  ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
}
)";

  AssertArrayForRootExpressionIs(
      hlo_text,
      "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a "
      "%indices_b 0->[0]) 0->[0])");
}
220 
// gather_a uses rank-2 indices (s32[5,7]); gather_b then gathers dimension 2
// of gather_a with rank-1 indices (s32[2]).  The two gathers are expected to
// fold into one scalar-indexed array over %operand, with %indices_a gathered
// by %indices_b as its index array.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,2] parameter(0)
  indices_a = s32[5,7] parameter(1)
  indices_b = s32[2] parameter(2)
  gather_a = s32[5,3,7] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3,1}
  ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
      offset_dims={0,1},
      collapsed_slice_dims={2},
      start_index_map={2},
      index_vector_dim=1,
      slice_sizes={5,3,1}
}
)";

  AssertArrayForRootExpressionIs(hlo_text,
                                 "(scalar-indexed %operand (scalar-indexed "
                                 "%indices_a %indices_b 1->[1]) 1->[0,2])");
}
248 
// gather_a uses rank-1 indices (s32[2]); gather_b then gathers its rows with
// rank-2 indices (s32[5,7]).  The two gathers are expected to fold into one
// scalar-indexed array over %operand, with %indices_a gathered by %indices_b
// as its index array.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,6] parameter(0)
  indices_a = s32[2] parameter(1)
  indices_b = s32[5,7] parameter(2)
  gather_a = s32[2,6] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,6}
  ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,6}
}
)";

  AssertArrayForRootExpressionIs(hlo_text,
                                 "(scalar-indexed %operand (scalar-indexed "
                                 "%indices_a %indices_b 0->[0,1]) 0->[0,2])");
}
276 
// Both gathers use rank-2 indices (s32[5,7] and s32[4,8]); gather_b gathers
// dimension 2 of gather_a.  The two gathers are expected to fold into one
// scalar-indexed array over %operand, with %indices_a gathered by %indices_b
// as its index array.
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) {
  string hlo_text = R"(
HloModule SimpleGather

ENTRY main {
  operand = s32[3,2] parameter(0)
  indices_a = s32[5,7] parameter(1)
  indices_b = s32[4,8] parameter(2)
  gather_a = s32[5,3,7] gather(operand, indices_a),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3,1}
  ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
      offset_dims={1,2},
      collapsed_slice_dims={2},
      start_index_map={2},
      index_vector_dim=2,
      slice_sizes={5,3,1}
}
)";

  AssertArrayForRootExpressionIs(
      hlo_text,
      "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b "
      "1->[0,2]) 1->[0,1,3])");
}
305 
// The reshape splits the gathered row dimension (size 4) into 2x2 and leaves
// the index dimension (size 5) alone.  It is expected to be folded into the
// constant: the result is a scalar-indexed-const over a s32[3,2,2] constant.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT reshape = s32[5,2,2] reshape(gather)
}
)";

  AssertArrayForRootExpressionIs(
      hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])");
}
326 
// Same as a row-splitting reshape, but with rank-2 indices (s32[5,7]): the
// gathered dimension (size 4) becomes 2x2 and both index dimensions are kept.
// Expected to fold into a s32[3,2,2] constant.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  indices = s32[5,7] parameter(0)
  gather = s32[5,4,7] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,4}
  ROOT reshape = s32[5,2,2,7] reshape(gather)
}
)";

  AssertArrayForRootExpressionIs(
      hlo_text,
      "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])");
}
348 
// The reshape regroups the two gathered dimensions (2x6) into 3x4 while
// keeping both index dimensions (5 and 7).  Expected to fold into a
// scalar-indexed-const over a s32[3,3,4] constant.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,2,6] constant({
      {{1,2,3,4,5,6},{1,2,3,4,5,6}},
      {{1,2,3,4,5,6},{1,2,3,4,5,6}},
      {{1,2,3,4,5,6},{1,2,3,4,5,6}}})
  indices = s32[5,7] parameter(0)
  gather = s32[5,2,6,7] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,2,6}
  ROOT reshape = s32[5,3,4,7] reshape(gather)
}
)";

  AssertArrayForRootExpressionIs(
      hlo_text,
      "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])");
}
373 
// A single-element index vector (s32[1]) and a reshape that only inserts
// degenerate dimensions ([1,6] -> [1,1,6]).  The expected result indexes a
// s32[2,1,1,6] constant with %indices reshaped to a scalar (0->[]).
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[2,6] constant({
      {1,2,3,4,5,6},{1,2,3,4,5,6}})
  indices = s32[1] parameter(0)
  gather = s32[1,6] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,6}
  ROOT reshape = s32[1,1,6] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(scalar-indexed-const
  (constant s32[2,1,1,6])
  (reshape %indices to s32[])
  0->[])
)";

  AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
401 
// Reshape of a gather-of-gather with s64 indices and explicit layouts.  The
// expected result indexes a s32[2,1,3] constant with the composed index array
// (%i.0 gathered by %i.1) reshaped to a scalar.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } })

  i.0 = s64[1,3]{1,0} parameter(0)
  g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
    collapsed_slice_dims={0}, start_index_map={0},
    index_vector_dim=2, slice_sizes={1,3}

  i.1 = s64[1] parameter(1)
  g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2},
    collapsed_slice_dims={1}, start_index_map={1},
    index_vector_dim=1, slice_sizes={1,1,3}

  ROOT reshape = s32[1,3]{1,0} reshape(g.1)
}
)";

  const char* expected_root_expression = R"(
(scalar-indexed-const
  (constant s32[2,1,3])
   (reshape
     (scalar-indexed %i.0 %i.1 1->[1])
     to s64[])
  0->[])
)";

  AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
434 
// Like a degenerate-dimension reshape of a single-index gather, but the
// constant itself has only one row (s32[1,6]).  Expected result: a
// s32[1,1,1,6] constant indexed by %indices reshaped to a scalar.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[1,6] constant({{1,2,3,4,5,6}})
  indices = s32[1] parameter(0)
  gather = s32[1,6] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,6}
  ROOT reshape = s32[1,1,6] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(scalar-indexed-const
  (constant s32[1,1,1,6])
  (reshape %indices to s32[])
  0->[])
)";

  AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
461 
// Gather along dimension 1 of a rank-3 constant followed by a reshape that
// inserts a degenerate dimension.  Constant values are printed to check that
// the folded s32[2,1,1,1,6] constant holds the expected data.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[1,2,6] constant({{
      {1,2,3,4,5,6},{1,2,3,4,5,6}}})
  indices = s32[1] parameter(0)
  gather = s32[1,1,6] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={1,1,6}
  ROOT reshape = s32[1,1,1,6] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(scalar-indexed-const
  (constant s32[2,1,1,1,6] s32[2,1,1,1,6] {
    { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } },
    { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } } })
  (reshape %indices to s32[])
  0->[])
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text,
                                              expected_root_expression);
}
492 
// Rank-2 indices with a leading unit dimension (s32[1,5]) and a reshape that
// inserts another unit dimension.  Expected result: the indices are reshaped
// to s32[5] and index a s32[2,1,1,6] constant (values printed).
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[2,6] constant({
      {1,2,3,4,5,6},{1,2,3,4,5,6}})
  indices = s32[1,5] parameter(0)
  gather = s32[1,5,6] gather(operand, indices),
      offset_dims={2},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,6}
  ROOT reshape = s32[1,1,5,6] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(scalar-indexed-const
  (constant s32[2,1,1,6] s32[2,1,1,6] {
    { /*i0=0*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } },
    { /*i0=1*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } } })
  (reshape %indices to s32[5])
  0->[2])
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text,
                                              expected_root_expression);
}
523 
// The reshape [5,4,6] -> [5,2,2,2,3] also splits the index dimension of size
// 6 into 2x3.  The reshape is expected NOT to be folded: it stays as a
// reshape wrapped around the scalar-indexed-const.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  indices = s32[5,6] parameter(0)
  gather = s32[5,4,6] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,4}
  ROOT reshape = s32[5,2,2,2,3] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(reshape
  (scalar-indexed-const
    (constant s32[3,4])
    %indices
    0->[0,2])
  to s32[5,2,2,2,3])
)";

  AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
552 
// Gather along the middle dimension of a s32[3,5,2] constant, followed by a
// reshape [3,2,7] -> [6,7] that merges the two gathered (non-index)
// dimensions.  The reshape is expected NOT to be folded into the constant.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,5,2] constant({
      {{1,2},{3,4},{5,6},{7,8},{9,10}},
      {{1,2},{3,4},{5,6},{7,8},{9,10}},
      {{1,2},{3,4},{5,6},{7,8},{9,10}}})
  indices = s32[7] parameter(0)
  gather = s32[3,2,7] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1,2}
  ROOT reshape = s32[6,7] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(reshape
  (scalar-indexed-const
    (constant s32[3,5,2])
    %indices
    1->[2])
  to s32[6,7])
)";

  AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
584 
// Like a reshape that splits the size-6 index dimension into 2x3, but with an
// extra trailing unit dimension on the operand.  The reshape is expected NOT
// to be folded into the constant.
TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) {
  string hlo_text = R"(
HloModule ReshapeOfGather

ENTRY main {
  operand = s32[3,4,1] constant({
    {{1},{2},{3},{4}},
    {{1},{2},{3},{4}},
    {{1},{2},{3},{4}}})
  indices = s32[5,6] parameter(0)
  gather = s32[5,4,6,1] gather(operand, indices),
      offset_dims={1,3},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,4,1}
  ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
}
)";

  const char* expected_root_expression = R"(
(reshape
  (scalar-indexed-const
    (constant s32[3,4,1])
    %indices
    0->[0,2])
  to s32[5,2,2,2,3,1])
)";

  AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
616 
// tanh applied to a gather from a constant is expected to be folded into the
// constant: the printed constant holds tanh of each original element.
TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {
  string hlo_text = R"(
HloModule UnaryOpOfGather

ENTRY main {
  operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  indices = s32[5] parameter(0)
  gather = f32[5,4] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT tanh = f32[5,4] tanh(gather)
}
)";

  // `1 +` skips the raw string's leading newline.
  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant f32[3,4] f32[3,4] {
  { 0.761594, 0.964028, 0.995055, 0.999329 },
  { 0.761594, 0.995055, 0.964028, 0.999329 },
  { 0.999329, 0.995055, 0.964028, 0.761594 }
}) %indices 0->[0]))");
}
641 
// Adding a broadcast scalar (5) to a gather from a constant is expected to be
// folded into the constant: every printed element is the original plus 5.
TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) {
  string hlo_text = R"(
HloModule AddBroadcastedScalarWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant = s32[] constant(5)
  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";

  // `1 +` skips the raw string's leading newline.
  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { 6, 7, 8, 9 },
  { 6, 8, 7, 9 },
  { 9, 8, 7, 6 }
}) %indices 0->[0]))");
}
668 
// gather - broadcast(5), with the gather as the left operand: expected to
// fold into the constant with every element decreased by 5.
TEST_F(IndexedArrayAnalysisTest,
       SubtractBroadcastedScalarWithGather_GatherIsLhs) {
  string hlo_text = R"(
HloModule SubtractBroadcastedScalarWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant = s32[] constant(5)
  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
}
)";

  // `1 +` skips the raw string's leading newline.
  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { -4, -3, -2, -1 },
  { -4, -2, -3, -1 },
  { -1, -2, -3, -4 }
}) %indices 0->[0]))");
}
696 
// broadcast(5) - gather, with the gather as the right operand: expected to
// fold into the constant with every element x replaced by 5 - x (operand
// order must be preserved).
TEST_F(IndexedArrayAnalysisTest,
       SubtractBroadcastedScalarWithGather_GatherIsRhs) {
  string hlo_text = R"(
HloModule SubtractBroadcastedScalarWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant = s32[] constant(5)
  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
}
)";

  // `1 +` skips the raw string's leading newline.
  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { 4, 3, 2, 1 },
  { 4, 2, 3, 1 },
  { 1, 2, 3, 4 }
}) %indices 0->[0]))");
}
724 
// Adding a vector broadcast along dimension 1 (the gathered row dimension,
// not the index dimension) is expected to fold into the constant: each row of
// the printed constant is the original row plus {10,11,12,13}.
TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) {
  string hlo_text = R"(
HloModule AddBroadcastedVectorWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant_vect = s32[4] constant({10,11,12,13})
  constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";

  // `1 +` skips the raw string's leading newline.
  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant s32[3,4] s32[3,4] {
  { 11, 13, 15, 17 },
  { 11, 14, 14, 17 },
  { 14, 14, 14, 14 }
}) %indices 0->[0]))");
}
751 
// The vector is broadcast along dimension 0, which is the gather's index
// dimension.  The add is expected NOT to be folded: the result prints as the
// plain instruction reference %add.
TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) {
  string hlo_text = R"(
HloModule AddBroadcastedVectorWithGather

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
  constant_vect = s32[5] constant({10,11,12,13,14})
  constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
  indices = s32[5] parameter(0)
  gather = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%add");
}
773 
// A unary op on a plain parameter (no gather involved) is expected to print
// as the plain instruction reference %tanh.
TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) {
  string hlo_text = R"(
HloModule RegularUnaryOp

ENTRY main {
  input = f32[100] parameter(0)
  ROOT tanh = f32[100] tanh(input)
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%tanh");
}
786 
// A binary op on two plain parameters (no gather involved) is expected to
// print as the plain instruction reference %add.
TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) {
  // The module name matches the test; it was previously copy-pasted as
  // "RegularUnaryOp".  It does not affect the analysis result.
  string hlo_text = R"(
HloModule RegularBinaryOp

ENTRY main {
  input0 = f32[100] parameter(0)
  input1 = f32[100] parameter(1)
  ROOT add = f32[100] add(input0, input1)
}
)";

  AssertArrayForRootExpressionIs(hlo_text, "%add");
}
800 
// dot(gather(C, indices), R), where the gathered lhs is contracted along its
// row dimension 1.  Expected to fold into a scalar-indexed-const over the
// s32[3,3] product C x R.
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) {
  string hlo_text = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  indices = s32[5] parameter(0)
  dot_lhs = s32[5,4] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,4}
  ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
(scalar-indexed-const
  (constant s32[3,3] s32[3,3] {
    { 70, 80, 90 },
    { 158, 184, 210 },
    { 246, 288, 330 } })
  %indices 0->[0]))");
}
827 
// The lhs gathers columns of the constant (s32[3,5]) and is contracted along
// dimension 0, which is not the gathered dimension.  Expected to fold into a
// scalar-indexed-const over the s32[4,3] product (transpose(C) x R).
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) {
  string hlo_text = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}})
  indices = s32[5] parameter(0)
  dot_lhs = s32[3,5] gather(gather_operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1}
  ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
(scalar-indexed-const
  (constant s32[4,3] s32[4,3] {
    { 84, 99, 114 },
    { 96, 114, 132 },
    { 108, 129, 150 },
    { 120, 144, 168 } })
   %indices 0->[1]))");
}
855 
// The gather is the rhs of the dot (columns of the constant, s32[3,5]) and is
// contracted along its non-gathered dimension 0.  Expected to fold into a
// scalar-indexed-const over the s32[4,4] product L x C.
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) {
  string hlo_text = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  indices = s32[5] parameter(0)
  dot_rhs = s32[3,5] gather(gather_operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1}
  ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
(scalar-indexed-const
  (constant s32[4,4] s32[4,4] {
    { 38, 44, 50, 56 },
    { 83, 98, 113, 128 },
    { 128, 152, 176, 200 },
    { 173, 206, 239, 272 } })
  %indices 1->[1])
)");
}
884 
// The gather is the rhs of the dot (rows of the constant, s32[5,3]) and is
// contracted along its non-gathered dimension 1.  Expected to fold into a
// scalar-indexed-const over the s32[4,4] product L x transpose(C).
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) {
  string hlo_text = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
  indices = s32[5] parameter(0)
  dot_rhs = s32[5,3] gather(gather_operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1,3}
  ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
(scalar-indexed-const
  (constant s32[4,4] s32[4,4] {
    { 14, 32, 50, 68 },
    { 32, 77, 122, 167 },
    { 50, 122, 194, 266 },
    { 68, 167, 266, 365 } })
  %indices 1->[0])
)");
}
913 
// A batched dot (batch dimension 0 on both sides) whose rhs gathers along
// dimension 2 of a s32[2,3,2] constant.  Expected to fold into a
// scalar-indexed-const over the s32[2,2,2] batched product (values printed).
TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) {
  string hlo_text = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}})
  dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
  indices = s32[4] parameter(0)
  dot_rhs = s32[2,3,4] gather(gather_operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={2},
      start_index_map={2},
      index_vector_dim=1,
      slice_sizes={2,3,1}
  ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
      lhs_contracting_dims={2}, rhs_contracting_dims={1},
      lhs_batch_dims={0}, rhs_batch_dims={0}
}
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
(scalar-indexed-const
  (constant s32[2,2,2] s32[2,2,2] {
    { { 22, 28 },
      { 49, 64 } },
    { { 220, 244 },
      { 301, 334 } } })
  %indices 3->[2])
)");
}
944 
// The lhs gathers columns of the constant (s32[3,2]) and is contracted along
// dimension 1, which is exactly the gathered (index) dimension.  The dot is
// expected NOT to be folded: the result prints as %dot.
TEST_F(IndexedArrayAnalysisTest, DotOpNegative) {
  string hlo_text = R"(
HloModule DotOp

ENTRY main {
  gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
  dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}})
  indices = s32[2] parameter(0)
  dot_lhs = s32[3,2] gather(gather_operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3,1}
  ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  AssertArrayWithConstantsForRootExpressionIs(hlo_text, "%dot");
}
965 
966 }  // namespace
967 }  // namespace xla
968