/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace {
using Analysis = IndexedArrayAnalysis;
using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using absl::StrJoin;
}  // namespace

string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
  switch (root->kind()) {
    case Array::kUnknown: {
      auto* unknown_tensor = root->as<UnknownArray>();
      return absl::StrCat("%", unknown_tensor->instruction().name());
    }

    case Array::kConstant: {
      if (print_constants) {
        string contents = root->as<ConstantArray>()->literal()->ToString();
        return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                            " ", contents, ")");
      }
      return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                          ")");
    }

    case Array::kReshaped: {
      ReshapedArray* reshaped_array = root->as<ReshapedArray>();
      return absl::StrCat(
          "(reshape ", ToString(reshaped_array->operand(), print_constants),
          " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
    }

    case Array::kScalarIndexedConstant:
    case Array::kScalarIndexed: {
      auto* indexed_array = root->as<ScalarIndexedArray>();
      string name = root->kind() == Array::kScalarIndexedConstant
                        ? "scalar-indexed-const"
                        : "scalar-indexed";
      return absl::StrCat(
          "(", name, " ", ToString(indexed_array->source(), print_constants),
          " ", ToString(indexed_array->indices(), print_constants), " ",
          indexed_array->source_dim(), "->[",
          StrJoin(indexed_array->output_dims(), ","), "])");
    }
  }
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
    const HloInstruction* instr) {
  auto it = cache_.find(instr);
  if (it != cache_.end()) {
    return it->second;
  }

  TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
  return FindOrDie(cache_, instr);
}

Status IndexedArrayAnalysis::TraverseAndPopulateCache(
    const HloInstruction* root) {
  // Depth first search over the DAG, invoking ComputeArrayFor in post order.
  // The HLO instructions already in the cache are considered leaves.

  absl::InlinedVector<const HloInstruction*, 4> stack;

  enum DfsState { kDiscovered, kVisited };
  absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;

  stack.push_back(root);
  InsertOrDie(&dfs_state_map, root, kDiscovered);

  do {
    const HloInstruction* instr = stack.back();
    if (cache_.contains(instr)) {
      stack.pop_back();
      continue;
    }

    switch (FindOrDie(dfs_state_map, instr)) {
      case kDiscovered: {
        for (const HloInstruction* operand : instr->operands()) {
          if (!cache_.contains(operand)) {
            stack.push_back(operand);
            CHECK(!dfs_state_map.contains(operand) ||
                  dfs_state_map[operand] == kDiscovered);
            dfs_state_map[operand] = kDiscovered;
          }
        }
        dfs_state_map[instr] = kVisited;
        break;
      }

      case kVisited:
        stack.pop_back();
        TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
        InsertOrDie(&cache_, instr, array);
        break;
    }
  } while (!stack.empty());

  return Status::OK();
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
    const HloInstruction* instr) {
  Array* computed_array;
  if (instr->IsElementwise() && instr->operand_count() == 1) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseUnaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0))));
  } else if (instr->IsElementwise() && instr->operand_count() == 2) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseBinaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0)),
            FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(computed_array,
                        ComputeArrayForConstant(instr->literal()));
  } else if (instr->opcode() == HloOpcode::kGather) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
                              instr->gather_slice_sizes(),
                              FindOrDie(cache_, instr->operand(0)),
                              FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kReshape) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForReshape(instr->shape(),
                               FindOrDie(cache_, instr->operand(0))));
  } else if (instr->opcode() == HloOpcode::kDot) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
                           instr->precision_config(),
                           FindOrDie(cache_, instr->operand(0)),
                           FindOrDie(cache_, instr->operand(1))));
  } else {
    computed_array = nullptr;
  }

  if (!computed_array) {
    computed_array = Construct<UnknownArray>(instr);
  }

  return computed_array;
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
    const Literal& literal) {
  return Construct<ConstantArray>(&literal);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
    ScalarIndexedArray* source, Array* indices, int64 source_dim,
    absl::Span<const int64> output_dims, Shape shape) {
  // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
  // `source` is the inner Gather(A, X).
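  //
  // For example (an illustration with made-up shapes): if A is f32[10,6], X
  // is an s32[3] index vector picking rows of A (so Gather(A, X) has shape
  // f32[3,6]), and Y is an s32[5] index vector picking rows of that result,
  // then element [i, j] of the final result is A[X[Y[i]], j] -- i.e. the
  // composition is a single gather of A with indices Gather(X, Y), of shape
  // f32[5,6].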

  Array* a = source->source();
  Array* x = source->indices();
  Array* y = indices;

  // This bit is slightly tricky, so we do a naive "simulation" of the two
  // consecutive gather operations to infer what the composed gather should
  // look like.
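  //
  // Each entry of `simulated_index` tracks where the corresponding dimension
  // of the final result comes from: passed through directly from A
  // (Ungathered), introduced by the first gather's indices (GatheredFirst),
  // or introduced by the second gather's indices (GatheredSecond).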

  enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond };

  std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(),
                                              IndexComponent::Ungathered);

  // Simulate the first gather.
  EraseAt(&simulated_index, source->source_dim());
  for (int64 gather_dim : source->output_dims()) {
    simulated_index.insert(simulated_index.begin() + gather_dim,
                           IndexComponent::GatheredFirst);
  }

  // Simulate the second gather.
  EraseAt(&simulated_index, source_dim);
  for (int64 output_dim : output_dims) {
    simulated_index.insert(simulated_index.begin() + output_dim,
                           IndexComponent::GatheredSecond);
  }

  int64 source_dim_for_index_array =
      FindIndex(source->output_dims(), source_dim);
  CHECK_NE(source_dim_for_index_array, source->output_dims().size());

  std::vector<int64> output_dims_for_index_array;
  int64 gathered_index_components_seen = 0;
  for (IndexComponent simulation_dim : simulated_index) {
    if (simulation_dim == IndexComponent::GatheredSecond) {
      output_dims_for_index_array.push_back(gathered_index_components_seen);
    }
    if (simulation_dim != IndexComponent::Ungathered) {
      gathered_index_components_seen++;
    }
  }

  std::vector<int64> dim_sizes_for_composed_index;
  std::vector<int64> output_dims_for_new_gather;
  for (int64 i = 0, e = simulated_index.size(); i < e; i++) {
    if (simulated_index[i] != IndexComponent::Ungathered) {
      dim_sizes_for_composed_index.push_back(shape.dimensions(i));
      output_dims_for_new_gather.push_back(i);
    }
  }

  Array* inner_indices = ConstructScalarIndexedArray(
      x, y, source_dim_for_index_array, output_dims_for_index_array,
      ShapeUtil::MakeShape(x->shape().element_type(),
                           dim_sizes_for_composed_index));
  return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(),
                                     output_dims_for_new_gather,
                                     std::move(shape));
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
    const Shape& shape, const GatherDimensionNumbers& dim_numbers,
    absl::Span<const int64> slice_sizes, Array* source, Array* indices) {
  if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
    VLOG(3) << "ComputeArrayForGather: indices are not scalar";
    return nullptr;
  }

  CHECK_EQ(dim_numbers.start_index_map_size(), 1);

  // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
  // should it become relevant.

  if (dim_numbers.collapsed_slice_dims_size() != 1 ||
      dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
    VLOG(3) << "ComputeArrayForGather: gather operations must elide "
               "start_index_map[0] and "
               "start_index_map[0] only";
    return nullptr;
  }

  // ScalarIndexedArray cannot represent gathers that "slice" along some
  // dimensions -- for instance it cannot represent a gather that picks 5 [2,3]
  // arrays from an array of size [7,4,6]. We check that condition down below:

  for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
    if (i != dim_numbers.collapsed_slice_dims(0) &&
        source->shape().dimensions(i) != slice_sizes[i]) {
      VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
              << "] != source->shape().dimensions(" << i << ") -- "
              << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
              << " with dim_numbers.collapsed_slice_dims(0) = "
              << dim_numbers.collapsed_slice_dims(0);
      return nullptr;
    }
  }

  int64 source_dim = dim_numbers.start_index_map(0);
  std::vector<int64> output_dims;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
      output_dims.push_back(i);
    }
  }

  if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
    if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
      return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
                                shape);
    }
  } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) {
    return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim,
                                                 output_dims, shape);
  }

  return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims,
                                       shape);
}

namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`. If there is no
// such index, return -1. All integers in `values` must be positive.
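// For example, FindSuffixWithProduct({2, 3, 4, 5}, 20) returns 2 (since
// 4 * 5 == 20), while FindSuffixWithProduct({2, 3, 4, 5}, 30) returns -1.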
int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) {
  DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));

  int64 current_product = 1;
  int64 i;
  for (i = values.size() - 1; i >= 0 && product > current_product; --i) {
    current_product *= values[i];
  }

  if (product == current_product) {
    return i + 1;
  }

  return -1;
}

struct ReshapePassthroughDimPair {
  int64 result_dim;
  int64 operand_dim;
};

// Returns a set of dimension pairs such that for all (result_dim,
// operand_dim) in the set:
//
// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim]
//
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
    absl::Span<const int64> operand_shape,
    absl::Span<const int64> result_shape) {
  // A reshape can be seen as an index mapping from output index to input
  // index:
  //
  // (i_0, ..., i_n) = f(o_0, ..., o_m)
  //
  // This function returns the pairs (j, k) for which the following invariant
  // holds for all indices in the shape:
  //
  // o_j == i_k
  //
  // And this occurs when:
  //
  // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m
  //
  // (where O_x are the sizes of the output shape and I_x are the sizes of the
  // input shape) and the size of the dimension j of the result is the same as
  // the size of dimension k in the operand.
  //
  // These conditions are sufficient because the Reshape HLO is spec'ed such
  // that the rightmost dimensions are always minor in the flatten-and-refine
  // operation it performs.
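  //
  // For example (made-up shapes), for a reshape from [6,4] to [2,3,4] the
  // only passthrough pair is {result_dim=2, operand_dim=1}: both dimensions
  // have size 4 and the product of the dimensions to their right is 1 in both
  // shapes. Dimension 0 of the operand (size 6) does not pass through since
  // it is split across result dimensions 0 and 1.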

  std::vector<ReshapePassthroughDimPair> result;
  int64 result_subarray_size = 1;
  for (int64 result_dim = result_shape.size() - 1; result_dim >= 0;
       --result_dim) {
    int64 candidate_operand_dim =
        FindSuffixWithProduct(operand_shape, result_subarray_size);

    // result_subarray_size does not include the elements in the current
    // `result_dim` dimension (we multiply in result_shape[result_dim] at the
    // end of loop body) so candidate_operand_dim can never be zero.
    CHECK_NE(candidate_operand_dim, 0)
        << "result_dim = " << result_dim
        << ", result_subarray_size = " << result_subarray_size
        << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
        << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";

    if (candidate_operand_dim != -1 &&
        result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
      result.push_back({/*result_dim=*/result_dim,
                        /*operand_dim=*/candidate_operand_dim - 1});
    }
    result_subarray_size *= result_shape[result_dim];
  }

  absl::c_reverse(result);

  if (VLOG_IS_ON(3)) {
    std::vector<string> result_strings;
    absl::c_transform(result, std::back_inserter(result_strings),
                      [](ReshapePassthroughDimPair value) {
                        return absl::StrCat(value.result_dim, "->",
                                            value.operand_dim);
                      });
    VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
            << StrJoin(result_shape, ",") << "] passthrough indices are ["
            << StrJoin(result_strings, ",")
            << "] (legend: `result`->`operand`)";
  }

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.result_dim < rhs.result_dim;
      }));

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.operand_dim < rhs.operand_dim;
      }));

  return result;
}

// Returns true if `dim` is stated as a passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
  return absl::c_any_of(passthrough_dims,
                        [&](ReshapePassthroughDimPair passthrough_dim_pair) {
                          return passthrough_dim_pair.operand_dim == dim;
                        });
}

// Maps `operand_dim`, which must be a passthrough operand dimension, to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
    int64 operand_dim) {
  auto it = absl::c_find_if(
      passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
        return passthrough_dim_pair.operand_dim == operand_dim;
      });
  CHECK(it != passthrough_dims.end());
  return it->result_dim;
}

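// Returns the index into `result_shape` such that the product of the
// dimensions in result_shape[index, end) equals the product of the dimensions
// of `operand_shape` to the right of `source_passthrough_dim`, or -1 if no
// such index exists.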
int64 FindSourcePositionForPassthroughResultDim(
    absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape,
    int64 source_passthrough_dim) {
  VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
          << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
          << "], " << source_passthrough_dim << ")";

  int64 indexed_source_subarray_size =
      std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
                      operand_shape.end(), 1LL, std::multiplies<int64>());

  return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
}

Shape StripDegenerateDimensions(const Shape& shape) {
  DimensionVector new_dims;
  absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
                  [](int64 dim) { return dim != 1; });
  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}  // namespace

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
    ScalarIndexedArray* operand) {
  const Shape& shape = operand->shape();
  if (!ShapeUtil::HasDegenerateDimensions(shape)) {
    return operand;
  }

  // We only need to reshape out the degenerate dims from the indices and the
  // source (except the source dim).

  const Shape& source_shape = operand->source()->shape();
  DimensionVector new_source_shape_dims;
  for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) {
    if (i == operand->source_dim() || source_shape.dimensions(i) != 1) {
      new_source_shape_dims.push_back(source_shape.dimensions(i));
    }
  }

  Shape new_source_shape =
      ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims);
  Shape new_indices_shape =
      StripDegenerateDimensions(operand->indices()->shape());

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  TF_ASSIGN_OR_RETURN(
      Array* const new_indices,
      ComputeArrayForReshape(new_indices_shape, operand->indices()));

  // Build the new output dims while keeping track of the degenerate dims that
  // will no longer be present.
  DimensionVector new_output_dims;
  int64 degenerate_dims_seen = 0;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_dims_seen++;
    } else if (absl::c_linear_search(operand->output_dims(), i)) {
      new_output_dims.push_back(i - degenerate_dims_seen);
    }
  }

  // Similarly, build the new source dim while keeping track of the degenerate
  // dims that will no longer be present.
  int64 degenerate_dims_before_source_dim =
      std::count(source_shape.dimensions().begin(),
                 source_shape.dimensions().begin() + operand->source_dim(), 1);
  int64 new_source_dim =
      operand->source_dim() - degenerate_dims_before_source_dim;

  return ConstructScalarIndexedArray(
      new_source, new_indices, new_source_dim,
      InlinedVectorToVector(new_output_dims),
      StripDegenerateDimensions(operand->shape()));
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
    ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) {
  if (degenerate_dims.empty()) {
    return operand;
  }

  CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape()));

  DimensionVector new_output_dims = [&]() {
    // To make things easy we use a "scratch" buffer of bools where the i'th
    // element is true iff the i'th component of the result index is an output
    // index.

    absl::InlinedVector<bool, 6> output_dims_bitvector(
        operand->shape().dimensions_size());
    for (int64 output_dim : operand->output_dims()) {
      output_dims_bitvector[output_dim] = true;
    }

    for (int64 degenerate_dim : degenerate_dims) {
      InsertAt(&output_dims_bitvector, degenerate_dim, false);
    }

    DimensionVector result;
    result.reserve(operand->output_dims().size());
    for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) {
      if (output_dims_bitvector[i]) {
        result.push_back(i);
      }
    }

    return result;
  }();

  DimensionVector new_result_shape_dims;
  absl::c_copy(operand->shape().dimensions(),
               std::back_inserter(new_result_shape_dims));
  for (int64 degenerate_dim : degenerate_dims) {
    InsertAt(&new_result_shape_dims, degenerate_dim, 1);
  }

  DimensionVector new_source_shape_dims = new_result_shape_dims;
  for (int64 output_dim : new_output_dims) {
    EraseAt(&new_source_shape_dims, output_dim);
  }

  // The new source dim is the position in the new source shape that has
  // exactly operand->source_dim() non-degenerate dimensions to its left.
  int64 new_source_dim = [&]() {
    int64 non_degenerate_dims_seen = 0;
    for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) {
      if (non_degenerate_dims_seen == operand->source_dim()) {
        return i;
      }
      if (new_source_shape_dims[i] != 1) {
        non_degenerate_dims_seen++;
      }
    }
    LOG(FATAL) << "Did not find source dim in " << ToString(operand);
  }();

  int64 source_dim_size =
      operand->source()->shape().dimensions(operand->source_dim());
  InsertAt(&new_source_shape_dims, /*index=*/new_source_dim,
           /*value=*/source_dim_size);

  Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_source_shape_dims);
  Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_result_shape_dims);

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  return ConstructScalarIndexedArray(
      new_source, operand->indices(), new_source_dim,
      InlinedVectorToVector(new_output_dims), new_result_shape);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather(
    const Shape& shape, ScalarIndexedConstantArray* operand) {
  VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")";

  // To make things easier on ourselves, instead of directly trying to fold the
  // reshape of `operand` to `shape`, we call
  // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims
  // and handle the degenerate dimensions here by inserting reshapes.

  TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims,
                      ReshapeToRemoveDegenerateDims(operand));

  Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape);
  TF_ASSIGN_OR_RETURN(
      ScalarIndexedArray* const folded_reshape_without_degenerate_dims,
      FoldReshapeOfGatherNoDegenerateDims(
          output_shape_without_degenerate_dims,
          operand_without_degenerate_dims->as<ScalarIndexedConstantArray>()));

  if (folded_reshape_without_degenerate_dims == nullptr) {
    return nullptr;
  }

  DimensionVector degenerate_result_dims;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_result_dims.push_back(i);
    }
  }

  return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims,
                                    degenerate_result_dims);
}

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
    const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) {
  VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed)
          << ")";
  CHECK(!ShapeUtil::HasDegenerateDimensions(shape));
  CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape()));

  // Try to fold Reshape(ScalarIndexed(Const, Indices))
  // => ScalarIndexed(Const', Indices)
  //
  // We can view the reshape and the scalar-indexed operations as functions
  // that map an output index (i.e. an index into the result) to an input
  // index (i.e. an index into the operand). The key idea used here is that
  // the output-to-input mapping for some reshape operations may "pass
  // through" some output dimensions into the input space unchanged -- i.e.
  // there may exist output dimension "O" and input dimension "I" such that
  // OutputIndex[O] is always == InputIndexForReshape(OutputIndex)[I]. If
  // these pass-through dimensions in the input space of the reshape happen to
  // include all the output dimensions for the scalar-indexed node then,
  // roughly, the following holds:
  //
  // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx))
  // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs))
  //
  // Where Ps are the set of the pass-through components of Idx that are
  // also the output dims of the scalar-indexed node, and Qs are the rest.
  // For brevity, we're playing fast and loose with the notation here -- we
  // don't literally require Idx to be a concatenation of Ps and Qs, as
  // suggested by the "++".
  //
  // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs))
  //
  // Again, we're playing fast and loose with the notation around "++".
  // Generally this ++ will be a different function than the ++ in the
  // previous step.
  //
  // If the scalar-indexed node has a constant as the source then the
  // SourceIndexOfReshape function can be "folded into" the constant itself by
  // reshaping it, leaving us with:
  //
  // == SourceIndexOfScalarIndexed(Ps ++ Qs)
  // == SourceIndexOfScalarIndexed(Idx)
  //
  // which is just a scalar-indexed node (with parameters different from the
  // scalar-indexed node we started with) with a reshaped constant as the
  // source.
  //
  // We can't fold SourceIndexOfReshape into the constant without introducing
  // another precondition: since the new scalar-indexed node will have a
  // reshaped (constant) array as its source it will, in general, have a
  // different source dimension than the original scalar-indexed node. This
  // source dimension will have to be a passthrough dimension of the
  // SourceIndexOfReshape indexing function that is folded into the source. And
  // such a dimension need not exist so this is a non-trivial precondition.
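  //
  // As a (made-up) example of a successful fold: a scalar-indexed node of
  // shape s32[5,4] that picks whole rows of a s32[3,4] constant (output dims
  // [0]), followed by a reshape to s32[5,2,2], can be rewritten as a
  // scalar-indexed node that picks [2,2] sub-arrays of the constant reshaped
  // to s32[3,2,2], using the same indices.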

  std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()),
          /*result_shape=*/AsInt64Slice(shape.dimensions()));

  auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) {
    return IsReshapePassthroughOperandDim(reshape_passthrough_dims,
                                          operand_dim);
  };

  if (!absl::c_all_of(scalar_indexed->output_dims(),
                      is_reshape_passthrough_operand_dim)) {
    VLOG(3) << "Not all output dims are passthrough dims "
            << ToString(scalar_indexed);
    return nullptr;
  }

  // To compute the shape of the source for the new scalar-indexed node we're
  // going to create, we first "undo" the scalar-indexed operation.
  std::vector<int64> new_scalar_indexed_source_shape(shape.dimensions().begin(),
                                                     shape.dimensions().end());
  for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) {
    int64 output_dim = scalar_indexed->output_dims()[i];
    int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
        reshape_passthrough_dims, output_dim);
    EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
  }

  // After this, we need to add in the dimension that will be the source
  // dimension for the new scalar-indexed node. A scalar-indexed node "removes"
  // the source dimensions and "adds" the output dimensions, so to get back to
  // the shape for the *source* of the scalar-indexed node we need to remove the
  // output dims (which we did above) and then add back the source dim (which we
  // are about to do below):

  const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape();

  int64 source_dim_for_new_scalar_indexed_node =
      FindSourcePositionForPassthroughResultDim(
          /*operand_shape=*/AsInt64Slice(
              scalar_indexed_source_shape.dimensions()),
          /*result_shape=*/new_scalar_indexed_source_shape,
          scalar_indexed->source_dim());

  // We may not be able to find a source dim for the new scalar-indexed node.
  // For instance consider:
  //
  //   operand = s32[3,5,2] constant({...})
  //   indices = s32[7] parameter(0)
  //   gather = s32[3,2,7] gather(operand, indices),
  //       offset_dims={0,1},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3,1,2}
  //   reshape = s32[6,7] reshape(gather)
  //
  // In this case the gather maps to:
  //   (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
  //
  // and the reshape passes through dimension 2 from its input into dimension 1
  // in its output. However, we can't rewrite the reshape as a scalar-indexed
  // node because then we'd have to reshape the [3,5,2] `operand` array to
  // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently
  // (a.k.a. isn't pass-through) than the [3,5,2] array.

  if (source_dim_for_new_scalar_indexed_node == -1) {
    VLOG(3) << "Could not compute the source dim for the new scalar indexed "
               "node: scalar_indexed_source_shape = ["
            << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
            << "] and new_scalar_indexed_source_shape = ["
            << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
    return nullptr;
  }

  InsertAt(
      &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
      scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));

  CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
                              std::multiplies<int64>()),
           ShapeUtil::ElementsIn(scalar_indexed_source_shape));

  CHECK(IsReshapePassthroughOperandDim(
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/AsInt64Slice(
              scalar_indexed_source_shape.dimensions()),
          /*result_shape=*/new_scalar_indexed_source_shape),
      scalar_indexed->source_dim()));

  auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) {
    return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims,
                                               result_dim);
  };

  std::vector<int64> output_dims_for_new_scalar_indexed_node;
  absl::c_transform(scalar_indexed->output_dims(),
                    std::back_inserter(output_dims_for_new_scalar_indexed_node),
                    map_passthrough_operand_dim_to_result_dim);

  TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
                      TakeOwnership(scalar_indexed->literal().Reshape(
                          new_scalar_indexed_source_shape)));
  TF_ASSIGN_OR_RETURN(
      Array * new_scalar_indexed_source,
      ComputeArrayForConstant(*new_scalar_indexed_source_literal));

  return ConstructScalarIndexedArray(
      new_scalar_indexed_source, scalar_indexed->indices(),
      source_dim_for_new_scalar_indexed_node,
      output_dims_for_new_scalar_indexed_node, shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
    const Shape& shape, Array* operand) {
  if (ShapeUtil::Compatible(operand->shape(), shape)) {
    return operand;
  }

  if (auto* scalar_indexed =
          dynamic_cast<ScalarIndexedConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather,
                        FoldReshapeOfGather(shape, scalar_indexed));
    if (reshape_folded_into_gather) {
      return reshape_folded_into_gather;
    }
  }

  if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Literal* const new_literal,
                        TakeOwnership(constant_array->literal()->Reshape(
                            AsInt64Slice(shape.dimensions()))));
    return Construct<ConstantArray>(new_literal);
  }

  return Construct<ReshapedArray>(operand, shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                         Array* lhs,
                                                         Array* rhs) {
  // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
  // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
  //
  // We can do this if every output dimension from the scalar-indexed node is a
  // broadcasted dimension for the broadcast node. Informally, the precondition
  // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
  // that are not output-dims for the scalar-indexed node. In other words, for
  // every assignment to the non-output dims in IDX we have a "constant" LHS to
  // the BinaryOp. This transform propagates this "constant" to the source for
  // the scalar-indexed node.
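  //
  // As a (made-up) example: if Const0 is s32[4] broadcast to s32[7,4] along
  // dimension 1, and the scalar-indexed node picks rows of a s32[3,4]
  // constant Const1 using s32[7] indices (output dims [0]), then the only
  // output dim (0) is a broadcasted dim of the broadcast, so we can instead
  // broadcast Const0 to s32[3,4], apply the BinaryOp against Const1, and
  // gather rows of the result with the same indices.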

  ScalarIndexedConstantArray* lhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(lhs);
  ScalarIndexedConstantArray* rhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(rhs);

  bool lhs_is_indexed;

  // One of the operands must be scalar-indexed and the other must be a
  // broadcast of a constant.
  if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
    lhs_is_indexed = true;
  } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
    lhs_is_indexed = false;
  } else {
    return nullptr;
  }

  ScalarIndexedConstantArray* scalar_indexed_const =
      lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
  UnknownArray* candidate_broadcast_array =
      dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
  if (!candidate_broadcast_array ||
      candidate_broadcast_array->instruction().opcode() !=
          HloOpcode::kBroadcast) {
    return nullptr;
  }

  const HloInstruction* broadcast_instr =
      &candidate_broadcast_array->instruction();
  const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
  if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
    return nullptr;
  }

  absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions();
  auto is_broadcasted_dim = [&](int64 output_dim) {
    return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
  };

  // All of the output dims must be "broadcasted" dims for the other operand.
  if (!absl::c_all_of(scalar_indexed_const->output_dims(),
                      is_broadcasted_dim)) {
    return nullptr;
  }

  // To figure out the broadcast dimensions for the (constant) source for the
  // scalar-indexed node, we "simulate" the index transformation done by the
  // existing broadcast:
  enum class IndexComponent { Broadcasted, NotBroadcasted };
  std::vector<IndexComponent> simulated_index(
      broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
  for (int64 broadcast_dim : broadcast_dims) {
    simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
  }

  // The scalar-indexed node "removes" the source dim and "inserts" the output
  // dims. We do the opposite here to undo the scalar-indexed operation.
  absl::Span<const int64> output_dims = scalar_indexed_const->output_dims();
  for (int64 i = output_dims.size() - 1; i >= 0; --i) {
    CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
    EraseAt(&simulated_index, output_dims[i]);
  }

  InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
           IndexComponent::Broadcasted);

  // new_inner_broadcast_dims holds the broadcast dimensions for the inner
  // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to
  // new_inner_broadcast_dims.
  std::vector<int64> new_inner_broadcast_dims;
  for (int64 i = 0; i < simulated_index.size(); i++) {
    if (simulated_index[i] == IndexComponent::NotBroadcasted) {
      new_inner_broadcast_dims.push_back(i);
    }
  }

  // inner_broadcast_result is the Broadcast'(Const0) bit in
  // BinaryOp(Broadcast'(Const0), Const1)
  TF_ASSIGN_OR_RETURN(
      Literal inner_broadcast_result,
      broadcast_const_operand->literal().Broadcast(
          scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));

  // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
  const Literal* literal_for_new_source;
  if (lhs_is_indexed) {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
  } else {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, inner_broadcast_result, scalar_indexed_const->literal())));
  }

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      std::vector<int64>(scalar_indexed_const->output_dims().begin(),
                         scalar_indexed_const->output_dims().end()),
      scalar_indexed_const->shape());
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                        Array* operand) {
  auto* scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(operand);
  if (scalar_indexed_const == nullptr) {
    return nullptr;
  }

  // Fold UnaryOp(ScalarIndexed(Const, Indices))
  // => ScalarIndexed(UnaryOp(Const), Indices)

  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
                      TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
                          opcode, scalar_indexed_const->literal())));
  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      ArraySliceToVector(scalar_indexed_const->output_dims()),
      scalar_indexed_const->shape());
}

namespace {

// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
absl::optional<int64> GetOnlyNonContractingNonBatchDim(
    int64 rank, absl::Span<const int64> contracting_dims,
    absl::Span<const int64> batch_dims) {
  absl::optional<int64> result;
  for (int64 dim = 0; dim < rank; dim++) {
    if (!absl::c_linear_search(contracting_dims, dim) &&
        !absl::c_linear_search(batch_dims, dim)) {
      if (result.has_value()) {
        return absl::nullopt;
      }
      result = dim;
    }
  }
  return result;
}

// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
// HLO, can be folded into the dot operation. For now these conditions are both
// necessary and sufficient.
//
// `tag` describes the caller. Used only for logging.
//
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
    absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
    absl::Span<const int64> contracting_dims,
    absl::Span<const int64> batch_dims) {
  absl::optional<int64> non_contracting_non_batch_dim =
      GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
                                       contracting_dims, batch_dims);
  if (!non_contracting_non_batch_dim.has_value()) {
    VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
    return false;
  }

  if (indexed_array->output_dims().size() != 1 ||
      indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
    VLOG(3) << tag << ": output dims != the lhs non-contracting non-batch dim";
    return false;
  }

  int64 indexed_array_rank = indexed_array->shape().rank();
  if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
    // This restriction can be lifted by inserting reshape nodes.
    VLOG(3) << tag
            << ": source dim is not in the low two dims, won't be able to form "
               "a matmul";
    return false;
  }

  return true;
}

}  // namespace

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
    ConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
          AsInt64Slice(dim_numbers.lhs_contracting_dimensions()),
          /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) {
    return nullptr;
  }

  int64 lhs_rank = lhs->shape().rank();
  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_lhs_contracting_dimensions(
      0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting LHS
  // dimension "went".
  int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                         dim_numbers.rhs_batch_dimensions_size();

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, lhs->indices(), new_source_dim,
      ArraySliceToVector(lhs->output_dims()), shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ConstantArray* lhs,
    ScalarIndexedConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
          AsInt64Slice(dim_numbers.rhs_contracting_dimensions()),
          /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) {
    return nullptr;
  }

  int64 rhs_rank = rhs->shape().rank();

  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_rhs_contracting_dimensions(
      0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting RHS
  // dimension "went".
  int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                         dim_numbers.rhs_batch_dimensions_size() + 1;

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, rhs->indices(), new_source_dim,
      ArraySliceToVector(rhs->output_dims()), shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
  // Intuitively, if
  //
  //  - The LHS of a dot product is a gathered sequence of rows from a constant
  //    array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
  //
  //  OR
  //
  //  - The RHS of a dot product is a gathered sequence of columns from a
  //    constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
  //    constant
  //
  // then the result of the dot product itself is a gather from a constant
  // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
  // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
  // J].
  //
  // We do a general version of this rewrite here.
  VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs)
          << ")";
  if (auto* lhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
    if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
      return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
                                              precision_config,
                                              lhs_indexed_array, rhs_constant);
    }
  }

  if (auto* rhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
    if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
      return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
                                              precision_config, lhs_constant,
                                              rhs_indexed_array);
    }
  }

  return nullptr;
}

absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
  return "indexed-array-analysis-printer-pass";
}

StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
  if (!VLOG_IS_ON(2)) {
    return false;
  }

  IndexedArrayAnalysis analysis;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instr : computation->instructions()) {
      TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
      if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
        VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t);
      }
    }
  }

  return false;
}

}  // namespace xla