/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace {
using Analysis = IndexedArrayAnalysis;
using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using absl::StrJoin;
}  // namespace

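// Pretty-prints `root` in a Lisp-like form.  For example, a gather of a
// constant is printed roughly as
//
//   (scalar-indexed-const (constant s32[3,4]) %indices 0->[0])
//
// (the shapes and dimension numbers above are illustrative only).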
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
  switch (root->kind()) {
    case Array::kUnknown: {
      auto* unknown_tensor = root->as<UnknownArray>();
      return absl::StrCat("%", unknown_tensor->instruction().name());
    }

    case Array::kConstant: {
      if (print_constants) {
        string contents = root->as<ConstantArray>()->literal()->ToString();
        return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                            " ", contents, ")");
      }
      return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                          ")");
    }

    case Array::kReshaped: {
      ReshapedArray* reshaped_array = root->as<ReshapedArray>();
      return absl::StrCat(
          "(reshape ", ToString(reshaped_array->operand(), print_constants),
          " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
    }

    case Array::kScalarIndexedConstant:
    case Array::kScalarIndexed: {
      auto* indexed_array = root->as<ScalarIndexedArray>();
      string name = root->kind() == Array::kScalarIndexedConstant
                        ? "scalar-indexed-const"
                        : "scalar-indexed";
      return absl::StrCat(
          "(", name, " ", ToString(indexed_array->source(), print_constants),
          " ", ToString(indexed_array->indices(), print_constants), " ",
          indexed_array->source_dim(), "->[",
          StrJoin(indexed_array->output_dims(), ","), "])");
    }
  }
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
    const HloInstruction* instr) {
  auto it = cache_.find(instr);
  if (it != cache_.end()) {
    return it->second;
  }

  TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
  return FindOrDie(cache_, instr);
}

Status IndexedArrayAnalysis::TraverseAndPopulateCache(
    const HloInstruction* root) {
  // Depth first search over the DAG, invoking ComputeArrayFor in post order.
  // The HLO instructions already in the cache are considered leaves.

  absl::InlinedVector<const HloInstruction*, 4> stack;

  enum DfsState { kDiscovered, kVisited };
  absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;

  stack.push_back(root);
  InsertOrDie(&dfs_state_map, root, kDiscovered);

  do {
    const HloInstruction* instr = stack.back();
    if (cache_.contains(instr)) {
      stack.pop_back();
      continue;
    }

    switch (FindOrDie(dfs_state_map, instr)) {
      case kDiscovered: {
        for (const HloInstruction* operand : instr->operands()) {
          if (!cache_.contains(operand)) {
            stack.push_back(operand);
            CHECK(!dfs_state_map.contains(operand) ||
                  dfs_state_map[operand] == kDiscovered);
            dfs_state_map[operand] = kDiscovered;
          }
        }
        dfs_state_map[instr] = kVisited;
        break;
      }

      case kVisited:
        stack.pop_back();
        TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
        InsertOrDie(&cache_, instr, array);
        break;
    }
  } while (!stack.empty());

  return Status::OK();
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
    const HloInstruction* instr) {
  Array* computed_array;
  if (instr->IsElementwise() && instr->operand_count() == 1) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseUnaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0))));
  } else if (instr->IsElementwise() && instr->operand_count() == 2) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseBinaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0)),
            FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(computed_array,
                        ComputeArrayForConstant(instr->literal()));
  } else if (instr->opcode() == HloOpcode::kGather) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
                              instr->gather_slice_sizes(),
                              FindOrDie(cache_, instr->operand(0)),
                              FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kReshape) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForReshape(instr->shape(),
                               FindOrDie(cache_, instr->operand(0))));
  } else if (instr->opcode() == HloOpcode::kDot) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
                           instr->precision_config(),
                           FindOrDie(cache_, instr->operand(0)),
                           FindOrDie(cache_, instr->operand(1))));
  } else {
    computed_array = nullptr;
  }

  if (!computed_array) {
    computed_array = Construct<UnknownArray>(instr);
  }

  return computed_array;
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
    const Literal& literal) {
  return Construct<ConstantArray>(&literal);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
    ScalarIndexedArray* source, Array* indices, int64 source_dim,
    absl::Span<const int64> output_dims, Shape shape) {
  // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
  // `source` is the inner Gather(A, X).

  Array* a = source->source();
  Array* x = source->indices();
  Array* y = indices;

  // This bit is slightly tricky, so we do a naive "simulation" of the two
  // consecutive gather operations to infer what the composed gather should look
  // like.
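  //
  // As a concrete (illustrative) example, take A = s32[3,4], X = s32[5] and
  // Y = s32[7], where the inner gather picks rows of A (source_dim 0,
  // output_dims {0}, producing s32[5,4]) and the outer gather picks rows of
  // that result (source_dim 0, output_dims {0}, producing s32[7,4]).  The
  // simulation below ends up with [GatheredSecond, Ungathered], so the
  // composed gather is Gather(A, Gather(X, Y)) with source_dim 0 and
  // output_dims {0}.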

  enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond };

  std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(),
                                              IndexComponent::Ungathered);

  // Simulate the first gather.
  EraseAt(&simulated_index, source->source_dim());
  for (int64 gather_dim : source->output_dims()) {
    simulated_index.insert(simulated_index.begin() + gather_dim,
                           IndexComponent::GatheredFirst);
  }

  // Simulate the second gather.
  EraseAt(&simulated_index, source_dim);
  for (int64 output_dim : output_dims) {
    simulated_index.insert(simulated_index.begin() + output_dim,
                           IndexComponent::GatheredSecond);
  }

  int64 source_dim_for_index_array =
      FindIndex(source->output_dims(), source_dim);
  CHECK_NE(source_dim_for_index_array, source->output_dims().size());

  std::vector<int64> output_dims_for_index_array;
  int64 gathered_index_components_seen = 0;
  for (IndexComponent simulation_dim : simulated_index) {
    if (simulation_dim == IndexComponent::GatheredSecond) {
      output_dims_for_index_array.push_back(gathered_index_components_seen);
    }
    if (simulation_dim != IndexComponent::Ungathered) {
      gathered_index_components_seen++;
    }
  }

  std::vector<int64> dim_sizes_for_composed_index;
  std::vector<int64> output_dims_for_new_gather;
  for (int64 i = 0, e = simulated_index.size(); i < e; i++) {
    if (simulated_index[i] != IndexComponent::Ungathered) {
      dim_sizes_for_composed_index.push_back(shape.dimensions(i));
      output_dims_for_new_gather.push_back(i);
    }
  }

  Array* inner_indices = ConstructScalarIndexedArray(
      x, y, source_dim_for_index_array, output_dims_for_index_array,
      ShapeUtil::MakeShape(x->shape().element_type(),
                           dim_sizes_for_composed_index));
  return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(),
                                     output_dims_for_new_gather,
                                     std::move(shape));
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
    const Shape& shape, const GatherDimensionNumbers& dim_numbers,
    absl::Span<const int64> slice_sizes, Array* source, Array* indices) {
  if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
    VLOG(3) << "ComputeArrayForGather: indices are not scalar";
    return nullptr;
  }

  CHECK_EQ(dim_numbers.start_index_map_size(), 1);

  // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
  // should it become relevant.

  if (dim_numbers.collapsed_slice_dims_size() != 1 ||
      dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
    VLOG(3) << "ComputeArrayForGather: gather operations must elide "
               "start_index_map[0] and "
               "start_index_map[0] only";
    return nullptr;
  }

  // ScalarIndexedArray cannot represent gathers that "slice" along some
  // dimensions -- for instance it cannot represent a gather that picks 5 [2,3]
  // arrays from an array of size [7,4,6].  We check that condition down below:

  for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
    if (i != dim_numbers.collapsed_slice_dims(0) &&
        source->shape().dimensions(i) != slice_sizes[i]) {
      VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
              << "] != source->shape().dimensions(" << i << ") -- "
              << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
              << " with dim_numbers.collapsed_slice_dims(0) = "
              << dim_numbers.collapsed_slice_dims(0);
      return nullptr;
    }
  }

  int64 source_dim = dim_numbers.start_index_map(0);
  std::vector<int64> output_dims;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
      output_dims.push_back(i);
    }
  }

  if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
    if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
      return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
                                shape);
    }
  } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) {
    return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim,
                                                 output_dims, shape);
  }

  return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims,
                                       shape);
}

namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`.  If there is no
// such index, returns -1.  All integers in `values` must be positive.
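//
// For example, FindSuffixWithProduct({2, 3, 4, 5}, 20) returns 2 (because
// 4 * 5 == 20), while FindSuffixWithProduct({2, 3, 4, 5}, 6) returns -1
// (no suffix of {2, 3, 4, 5} multiplies out to 6).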
int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) {
  DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));

  int64 current_product = 1;
  int64 i;
  for (i = values.size() - 1; i >= 0 && product > current_product; --i) {
    current_product *= values[i];
  }

  if (product == current_product) {
    return i + 1;
  }

  return -1;
}

struct ReshapePassthroughDimPair {
  int64 result_dim;
  int64 operand_dim;
};

// Returns a set of dimension pairs such that for all (result_dim, operand_dim)
// in the set:
//
// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim]
//
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
    absl::Span<const int64> operand_shape,
    absl::Span<const int64> result_shape) {
  // A reshape can be seen as an index mapping from output index to input index:
  //
  // (i_0, ..., i_n) = f(o_0, ..., o_m)
  //
  // This function returns the pairs (j, k) for which the following invariant
  // holds for all indices in the shape:
  //
  //   o_j == i_k
  //
  // And this occurs when:
  //
  //    O_{j+1} * ... * O_n == I_{k+1} * ...  * I_m
  //
  // (where O_x are the sizes of the output shape and I_x are the sizes of the
  // input shape) and the size of the dimension j of the result is the same as
  // the size of dimension k in the operand.
  //
  // These conditions are sufficient because the Reshape HLO is spec'ed such
  // that the rightmost dimensions are always minor in the flattening and refine
  // operation.
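  //
  // For example, for a reshape from operand shape [4,3,2] to result shape
  // [12,2] the only passthrough pair is {result_dim=1, operand_dim=2}:
  // result dimension 1 and operand dimension 2 both have size 2 and both have
  // a subarray of size 1 to their right, so they index identically.  Result
  // dimension 0 (size 12) does not pass through to any single operand
  // dimension.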

  std::vector<ReshapePassthroughDimPair> result;
  int64 result_subarray_size = 1;
  for (int64 result_dim = result_shape.size() - 1; result_dim >= 0;
       --result_dim) {
    int64 candidate_operand_dim =
        FindSuffixWithProduct(operand_shape, result_subarray_size);

    // result_subarray_size does not include the elements in the current
    // `result_dim` dimension (we multiply in result_shape[result_dim] at the
    // end of loop body) so candidate_operand_dim can never be zero.
    CHECK_NE(candidate_operand_dim, 0)
        << "result_dim = " << result_dim
        << ", result_subarray_size = " << result_subarray_size
        << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
        << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";

    if (candidate_operand_dim != -1 &&
        result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
      result.push_back({/*result_dim=*/result_dim,
                        /*operand_dim=*/candidate_operand_dim - 1});
    }
    result_subarray_size *= result_shape[result_dim];
  }

  absl::c_reverse(result);

  if (VLOG_IS_ON(3)) {
    std::vector<string> result_strings;
    absl::c_transform(result, std::back_inserter(result_strings),
                      [](ReshapePassthroughDimPair value) {
                        return absl::StrCat(value.result_dim, "->",
                                            value.operand_dim);
                      });
    VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
            << StrJoin(result_shape, ",") << "] passthrough indices are ["
            << StrJoin(result_strings, ",")
            << "] (legend: `result`->`operand`)";
  }

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.result_dim < rhs.result_dim;
      }));

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.operand_dim < rhs.operand_dim;
      }));

  return result;
}

// Returns true if `dim` occurs as a passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
  return absl::c_any_of(passthrough_dims,
                        [&](ReshapePassthroughDimPair passthrough_dim_pair) {
                          return passthrough_dim_pair.operand_dim == dim;
                        });
}

// Maps `operand_dim`, which must be a passthrough operand dimension, to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
    int64 operand_dim) {
  auto it = absl::c_find_if(
      passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
        return passthrough_dim_pair.operand_dim == operand_dim;
      });
  CHECK(it != passthrough_dims.end());
  return it->result_dim;
}

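// Returns a position in `result_shape` whose suffix (the dimensions to the
// right of that position) has the same total element count as the suffix of
// `operand_shape` to the right of `source_passthrough_dim`.  This is the
// position at which the passthrough source dimension can be re-inserted into
// the reshaped shape.  Returns -1 if no such position exists.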
int64 FindSourcePositionForPassthroughResultDim(
    absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape,
    int64 source_passthrough_dim) {
  VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
          << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
          << "], " << source_passthrough_dim << ")";

  int64 indexed_source_subarray_size =
      std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
                      operand_shape.end(), 1LL, std::multiplies<int64>());

  return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
}

Shape StripDegenerateDimensions(const Shape& shape) {
  DimensionVector new_dims;
  absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
                  [](int64 dim) { return dim != 1; });
  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}  // namespace

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
    ScalarIndexedArray* operand) {
  const Shape& shape = operand->shape();
  if (!ShapeUtil::HasDegenerateDimensions(shape)) {
    return operand;
  }

  // We only need to reshape out the degenerate dims from the indices and the
  // source (except the source dim).
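  //
  // For instance (illustrative shapes), if `operand` has source s32[7,1,2]
  // with source_dim 0, indices s32[5] and output_dims {0} (so its own shape
  // is s32[5,1,2]), the result below is a scalar-indexed array with source
  // s32[7,2], indices s32[5], source_dim 0, output_dims {0} and shape
  // s32[5,2].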

  const Shape& source_shape = operand->source()->shape();
  DimensionVector new_source_shape_dims;
  for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) {
    if (i == operand->source_dim() || source_shape.dimensions(i) != 1) {
      new_source_shape_dims.push_back(source_shape.dimensions(i));
    }
  }

  Shape new_source_shape =
      ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims);
  Shape new_indices_shape =
      StripDegenerateDimensions(operand->indices()->shape());

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  TF_ASSIGN_OR_RETURN(
      Array* const new_indices,
      ComputeArrayForReshape(new_indices_shape, operand->indices()));

  // Build the new output dims while keeping track of the degenerate dims that
  // will no longer be present.
  DimensionVector new_output_dims;
  int64 degenerate_dims_seen = 0;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_dims_seen++;
    } else if (absl::c_linear_search(operand->output_dims(), i)) {
      new_output_dims.push_back(i - degenerate_dims_seen);
    }
  }

  // Similarly, build the new source dim while keeping track of the degenerate
  // dims that will no longer be present.
  int64 degenerate_dims_before_source_dim =
      std::count(source_shape.dimensions().begin(),
                 source_shape.dimensions().begin() + operand->source_dim(), 1);
  int64 new_source_dim =
      operand->source_dim() - degenerate_dims_before_source_dim;

  return ConstructScalarIndexedArray(
      new_source, new_indices, new_source_dim,
      InlinedVectorToVector(new_output_dims),
      StripDegenerateDimensions(operand->shape()));
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
    ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) {
  if (degenerate_dims.empty()) {
    return operand;
  }

  CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape()));

  DimensionVector new_output_dims = [&]() {
    // To make things easy we use a "scratch" buffer of bools where the i'th
    // element is true iff the i'th component of the result index is an output
    // index.

    absl::InlinedVector<bool, 6> output_dims_bitvector(
        operand->shape().dimensions_size());
    for (int64 output_dim : operand->output_dims()) {
      output_dims_bitvector[output_dim] = true;
    }

    for (int64 degenerate_dim : degenerate_dims) {
      InsertAt(&output_dims_bitvector, degenerate_dim, false);
    }

    DimensionVector result;
    result.reserve(operand->output_dims().size());
    for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) {
      if (output_dims_bitvector[i]) {
        result.push_back(i);
      }
    }

    return result;
  }();

  DimensionVector new_result_shape_dims;
  absl::c_copy(operand->shape().dimensions(),
               std::back_inserter(new_result_shape_dims));
  for (int64 degenerate_dim : degenerate_dims) {
    InsertAt(&new_result_shape_dims, degenerate_dim, 1);
  }

  DimensionVector new_source_shape_dims = new_result_shape_dims;
  for (int64 output_dim : new_output_dims) {
    EraseAt(&new_source_shape_dims, output_dim);
  }

  int64 new_source_dim = [&]() {
    // Return the first position with exactly operand->source_dim()
    // non-degenerate dimensions to its left.
    int64 non_degenerate_dims_seen = 0;
    for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) {
      if (non_degenerate_dims_seen == operand->source_dim()) {
        return i;
      }
      if (new_source_shape_dims[i] != 1) {
        non_degenerate_dims_seen++;
      }
    }
    LOG(FATAL) << "Did not find source dim in " << ToString(operand);
  }();

  int64 source_dim_size =
      operand->source()->shape().dimensions(operand->source_dim());
  InsertAt(&new_source_shape_dims, /*index=*/new_source_dim,
           /*value=*/source_dim_size);

  Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_source_shape_dims);
  Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_result_shape_dims);

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  return ConstructScalarIndexedArray(
      new_source, operand->indices(), new_source_dim,
      InlinedVectorToVector(new_output_dims), new_result_shape);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather(
    const Shape& shape, ScalarIndexedConstantArray* operand) {
  VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")";

  // To make things easier on ourselves, instead of directly trying to fold the
  // reshape of `operand` to `shape`, we call
  // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and
  // handle the degenerate dimensions here by inserting reshapes.

  TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims,
                      ReshapeToRemoveDegenerateDims(operand));

  Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape);
  TF_ASSIGN_OR_RETURN(
      ScalarIndexedArray* const folded_reshape_without_degenerate_dims,
      FoldReshapeOfGatherNoDegenerateDims(
          output_shape_without_degenerate_dims,
          operand_without_degenerate_dims->as<ScalarIndexedConstantArray>()));

  if (folded_reshape_without_degenerate_dims == nullptr) {
    return nullptr;
  }

  DimensionVector degenerate_result_dims;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_result_dims.push_back(i);
    }
  }

  return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims,
                                    degenerate_result_dims);
}

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
    const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) {
  VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed)
          << ")";
  CHECK(!ShapeUtil::HasDegenerateDimensions(shape));
  CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape()));

  // Try to fold Reshape(ScalarIndexed(Const, Indices))
  //          => ScalarIndexed(Const', Indices)
  //
  // We can view the reshape and the scalar-indexed operations as functions that
  // map an output index (i.e. an index into the result) to an input index
  // (i.e. an index into the operand).  The key idea used here is that the
  // output-to-input mapping for some reshape operations may "pass through" some
  // output dimensions into the input space unchanged -- i.e. there may exist
  // output dimension "O" and input dimension "I" such that OutputIndex[O] is
  // always == InputIndexForReshape(OutputIndex)[I].  If these pass-through
  // dimensions in the input space of the reshape happen to include all the
  // output dimensions for the scalar-indexed node then, roughly, the following
  // holds:
  //
  //    SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx))
  // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs))
  //
  //      Where Ps are the set of the pass-through components of Idx that are
  //      also the output dims of the scalar-indexed node, and Qs are the rest.
  //      For brevity, we're playing fast and loose with the notation here -- we
  //      don't literally require Idx to be a concatenation of Ps and Qs, as
  //      suggested by the "++".
  //
  // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs))
  //
  //      Again, we're playing fast and loose with the notation around "++".
  //      Generally this ++ will be a different function than the ++ in the
  //      previous step.
  //
  // If the scalar-indexed node has a constant as the source then the
  // SourceIndexOfReshape function can be "folded into" the constant itself by
  // reshaping it, leaving us with:
  //
  // == SourceIndexOfScalarIndexed(Ps ++ Qs)
  // == SourceIndexOfScalarIndexed(Idx)
  //
  // which is just a scalar-indexed node (with parameters different from the
  // scalar-indexed node we started with) with a reshaped constant as the
  // source.
  //
  // We can't fold SourceIndexOfReshape into the constant without introducing
  // another precondition: since the new scalar-indexed node will have a
  // reshaped (constant) array as its source it will, in general, have a
  // different source dimension than the original scalar-indexed node.  This
  // source dimension will have to be a passthrough dimension of the
  // SourceIndexOfReshape indexing function that is folded into the source. And
  // such a dimension need not exist so this is a non-trivial precondition.

  std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()),
          /*result_shape=*/AsInt64Slice(shape.dimensions()));

  auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) {
    return IsReshapePassthroughOperandDim(reshape_passthrough_dims,
                                          operand_dim);
  };

  if (!absl::c_all_of(scalar_indexed->output_dims(),
                      is_reshape_passthrough_operand_dim)) {
    VLOG(3) << "Not all output dims are passthrough dims "
            << ToString(scalar_indexed);
    return nullptr;
  }

  // To compute the shape of the source for the new scalar-indexed node we're
  // going to create, we first "undo" the scalar-indexed operation.
  std::vector<int64> new_scalar_indexed_source_shape(shape.dimensions().begin(),
                                                     shape.dimensions().end());
  for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) {
    int64 output_dim = scalar_indexed->output_dims()[i];
    int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
        reshape_passthrough_dims, output_dim);
    EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
  }

  // After this, we need to add in the dimension that will be the source
  // dimension for the new scalar-indexed node.  A scalar-indexed node "removes"
  // the source dimensions and "adds" the output dimensions, so to get back to
  // the shape for the *source* of the scalar-indexed node we need to remove the
  // output dims (which we did above) and then add back the source dim (which we
  // are about to do below):

  const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape();

  int64 source_dim_for_new_scalar_indexed_node =
      FindSourcePositionForPassthroughResultDim(
          /*operand_shape=*/AsInt64Slice(
              scalar_indexed_source_shape.dimensions()),
          /*result_shape=*/new_scalar_indexed_source_shape,
          scalar_indexed->source_dim());

  // We may not be able to find a source dim for the new scalar-indexed node.
  // For instance consider:
  //
  //   operand = s32[3,5,2] constant({...})
  //   indices = s32[7] parameter(0)
  //   gather = s32[3,2,7] gather(operand, indices),
  //       offset_dims={0,1},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3,1,2}
  //   reshape = s32[6,7] reshape(gather)
  //
  // In this case the gather maps to:
  //    (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
  //
  // and the reshape passes through dimension 2 from its input into dimension 1
  // in its output.  However, we can't rewrite the reshape as a scalar-indexed
  // node because then we'd have to reshape the [3,5,2] `operand` array to
  // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently
  // (a.k.a. isn't pass-through) than the [3,5,2] array.

  if (source_dim_for_new_scalar_indexed_node == -1) {
    VLOG(3) << "Could not compute the source dim for the new scalar indexed "
               "node: scalar_indexed_source_shape = ["
            << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
            << "] and new_scalar_indexed_source_shape = ["
            << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
    return nullptr;
  }

  InsertAt(
      &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
      scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));

  CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
                              std::multiplies<int64>()),
           ShapeUtil::ElementsIn(scalar_indexed_source_shape));

  CHECK(IsReshapePassthroughOperandDim(
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/AsInt64Slice(
              scalar_indexed_source_shape.dimensions()),
          /*result_shape=*/new_scalar_indexed_source_shape),
      scalar_indexed->source_dim()));

  auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) {
    return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims,
                                               result_dim);
  };

  std::vector<int64> output_dims_for_new_scalar_indexed_node;
  absl::c_transform(scalar_indexed->output_dims(),
                    std::back_inserter(output_dims_for_new_scalar_indexed_node),
                    map_passthrough_operand_dim_to_result_dim);

  TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
                      TakeOwnership(scalar_indexed->literal().Reshape(
                          new_scalar_indexed_source_shape)));
  TF_ASSIGN_OR_RETURN(
      Array * new_scalar_indexed_source,
      ComputeArrayForConstant(*new_scalar_indexed_source_literal));

  return ConstructScalarIndexedArray(
      new_scalar_indexed_source, scalar_indexed->indices(),
      source_dim_for_new_scalar_indexed_node,
      output_dims_for_new_scalar_indexed_node, shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
    const Shape& shape, Array* operand) {
  if (ShapeUtil::Compatible(operand->shape(), shape)) {
    return operand;
  }

  if (auto* scalar_indexed =
          dynamic_cast<ScalarIndexedConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather,
                        FoldReshapeOfGather(shape, scalar_indexed));
    if (reshape_folded_into_gather) {
      return reshape_folded_into_gather;
    }
  }

  if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Literal* const new_literal,
                        TakeOwnership(constant_array->literal()->Reshape(
                            AsInt64Slice(shape.dimensions()))));
    return Construct<ConstantArray>(new_literal);
  }

  return Construct<ReshapedArray>(operand, shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                         Array* lhs,
                                                         Array* rhs) {
  // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
  //          => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
  //
  // We can do this if every output dimension from the scalar-indexed node is a
  // broadcasted dimension for the broadcast node.  Informally, the precondition
  // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
  // that are not output-dims for the scalar-indexed node. In other words, for
  // every assignment to the non-output dims in IDX we have a "constant" LHS to
  // the BinaryOp.  This transform propagates this "constant" to the source for
  // the scalar-indexed node.
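  //
  // As an illustrative example, suppose Const0 is f32[3] broadcast to [5,3]
  // along dimension 1, and the scalar-indexed node gathers rows of a [7,3]
  // constant Const1 (source_dim 0, output_dims {0}, shape [5,3]).  Then
  // Add(Broadcast(Const0), ScalarIndexed(Const1, I)) can be rewritten as
  // ScalarIndexed(Add(Broadcast'(Const0), Const1), I), where Broadcast'
  // broadcasts Const0 to [7,3].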

  ScalarIndexedConstantArray* lhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(lhs);
  ScalarIndexedConstantArray* rhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(rhs);

  bool lhs_is_indexed;

  // One of the operands must be scalar-indexed and the other must be a
  // broadcast of a constant.
  if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
    lhs_is_indexed = true;
  } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
    lhs_is_indexed = false;
  } else {
    return nullptr;
  }

  ScalarIndexedConstantArray* scalar_indexed_const =
      lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
  UnknownArray* candidate_broadcast_array =
      dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
  if (!candidate_broadcast_array ||
      candidate_broadcast_array->instruction().opcode() !=
          HloOpcode::kBroadcast) {
    return nullptr;
  }

  const HloInstruction* broadcast_instr =
      &candidate_broadcast_array->instruction();
  const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
  if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
    return nullptr;
  }

  absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions();
  auto is_broadcasted_dim = [&](int64 output_dim) {
    return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
  };

  // All of the output dims must be "broadcasted" dims for the other operand.
  if (!absl::c_all_of(scalar_indexed_const->output_dims(),
                      is_broadcasted_dim)) {
    return nullptr;
  }

  // To figure out the broadcast dimensions for the (constant) source for the
  // scalar-indexed node, we "simulate" the index transformation done by the
  // existing broadcast:
  enum class IndexComponent { Broadcasted, NotBroadcasted };
  std::vector<IndexComponent> simulated_index(
      broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
  for (int64 broadcast_dim : broadcast_dims) {
    simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
  }

  // The scalar-indexed node "removes" the source dim and "inserts" the output
  // dims.  We do the opposite here to undo the scalar-indexed operation.
  absl::Span<const int64> output_dims = scalar_indexed_const->output_dims();
  for (int64 i = output_dims.size() - 1; i >= 0; --i) {
    CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
    EraseAt(&simulated_index, output_dims[i]);
  }

  InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
           IndexComponent::Broadcasted);

  // new_inner_broadcast_dims holds the broadcast dimensions for the inner
  // BinaryOp(Broadcast'(Const0), Const1).  We now translate simulated_index to
  // new_inner_broadcast_dims.
  std::vector<int64> new_inner_broadcast_dims;
  for (int64 i = 0; i < simulated_index.size(); i++) {
    if (simulated_index[i] == IndexComponent::NotBroadcasted) {
      new_inner_broadcast_dims.push_back(i);
    }
  }

  // inner_broadcast_result is the Broadcast'(Const0) bit in
  // BinaryOp(Broadcast'(Const0), Const1)
  TF_ASSIGN_OR_RETURN(
      Literal inner_broadcast_result,
      broadcast_const_operand->literal().Broadcast(
          scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));

  // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
  const Literal* literal_for_new_source;
  if (lhs_is_indexed) {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
  } else {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, inner_broadcast_result, scalar_indexed_const->literal())));
  }

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      std::vector<int64>(scalar_indexed_const->output_dims().begin(),
                         scalar_indexed_const->output_dims().end()),
      scalar_indexed_const->shape());
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                        Array* operand) {
  auto* scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(operand);
  if (scalar_indexed_const == nullptr) {
    return nullptr;
  }

  // Fold UnaryOp(ScalarIndexed(Const, Indices))
  //   => ScalarIndexed(UnaryOp(Const), Indices)
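  //
  // For example, Negate(ScalarIndexed(Const, Indices)) becomes
  // ScalarIndexed(Negate(Const), Indices): negating the gathered rows is the
  // same as gathering from the negated constant.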

  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
                      TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
                          opcode, scalar_indexed_const->literal())));
  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      ArraySliceToVector(scalar_indexed_const->output_dims()),
      scalar_indexed_const->shape());
}

namespace {

// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
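// For example, with rank = 3, contracting_dims = {2} and batch_dims = {0},
// the only remaining dimension is 1, so this returns 1; with
// contracting_dims = {} and batch_dims = {0} both 1 and 2 remain, so this
// returns nullopt.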
absl::optional<int64> GetOnlyNonContractingNonBatchDim(
    int64 rank, absl::Span<const int64> contracting_dims,
    absl::Span<const int64> batch_dims) {
  absl::optional<int64> result;
  for (int64 dim = 0; dim < rank; dim++) {
    if (!absl::c_linear_search(contracting_dims, dim) &&
        !absl::c_linear_search(batch_dims, dim)) {
      if (result.has_value()) {
        return absl::nullopt;
      }
      result = dim;
    }
  }
  return result;
}

// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
// HLO, can be folded into the dot operation.  For now these conditions are both
// necessary and sufficient.
//
// `tag` describes the caller.  Used only for logging.
//
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
    absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
    absl::Span<const int64> contracting_dims,
    absl::Span<const int64> batch_dims) {
  absl::optional<int64> non_contracting_non_batch_dim =
      GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
                                       contracting_dims, batch_dims);
  if (!non_contracting_non_batch_dim.has_value()) {
    VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
    return false;
  }

  if (indexed_array->output_dims().size() != 1 ||
      indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
    VLOG(3) << tag << ": output dims != the non-contracting non-batch dim";
    return false;
  }

  int64 indexed_array_rank = indexed_array->shape().rank();
  if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
    // This restriction can be lifted by inserting reshape nodes.
    VLOG(3) << tag
            << ": source dim is not in the low two dims, won't be able to form "
               "a matmul";
    return false;
  }

  return true;
}

}  // namespace

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
    ConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
          AsInt64Slice(dim_numbers.lhs_contracting_dimensions()),
          /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) {
    return nullptr;
  }

  int64 lhs_rank = lhs->shape().rank();
  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_lhs_contracting_dimensions(
      0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting LHS
  // dimension "went".
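  // For a plain matrix-matrix dot with no batch dimensions this is dimension
  // 0 of the result: the gathered rows of the LHS become the rows of the
  // product.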
  int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                         dim_numbers.rhs_batch_dimensions_size();

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, lhs->indices(), new_source_dim,
      ArraySliceToVector(lhs->output_dims()), shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ConstantArray* lhs,
    ScalarIndexedConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
          AsInt64Slice(dim_numbers.rhs_contracting_dimensions()),
          /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) {
    return nullptr;
  }

  int64 rhs_rank = rhs->shape().rank();

  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_rhs_contracting_dimensions(
      0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting RHS
  // dimension "went".
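  // For a plain matrix-matrix dot with no batch dimensions this is dimension
  // 1 of the result: the gathered columns of the RHS become the columns of
  // the product.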
  int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                         dim_numbers.rhs_batch_dimensions_size() + 1;

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, rhs->indices(), new_source_dim,
      ArraySliceToVector(rhs->output_dims()), shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
  // Intuitively, if
  //
  //  - The LHS of a dot product is a gathered sequence of rows from a constant
  //    array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
  //
  //  OR
  //
  //  - If the RHS of a dot product is a gathered sequence of columns from a
  //    constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
  //    constant
  //
  // then the result of the dot product itself is a gather from a constant
  // array.  E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
  // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
  // J].
  //
  // We do a general version of this rewrite here.
  VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs)
          << ")";
  if (auto* lhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
    if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
      return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
                                              precision_config,
                                              lhs_indexed_array, rhs_constant);
    }
  }

  if (auto* rhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
    if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
      return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
                                              precision_config, lhs_constant,
                                              rhs_indexed_array);
    }
  }

  return nullptr;
}

absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
  return "indexed-array-analysis-printer-pass";
}

StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
  if (!VLOG_IS_ON(2)) {
    return false;
  }

  IndexedArrayAnalysis analysis;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instr : computation->instructions()) {
      TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
      if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
        VLOG(2) << instr->ToString() << "   ->   " << analysis.ToString(t);
      }
    }
  }

  return false;
}

}  // namespace xla