1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/math/math_util.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 
41 namespace xla {
42 namespace {
43 
44 using absl::StrFormat;
45 using absl::StrJoin;
46 
47 // Returns true if no element is present in slice more than once.
AllUnique(absl::Span<const int64> slice)48 bool AllUnique(absl::Span<const int64> slice) {
49   return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
50 }
51 
ExpectArray(const Shape & shape,absl::string_view op_type)52 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
53   if (!shape.IsArray()) {
54     return InvalidArgument("Expected array argument for %s, but got %s.",
55                            string(op_type), ShapeUtil::HumanString(shape));
56   }
57   return Status::OK();
58 }
59 
VerifyReducerShape(const ProgramShape & reducer_shape,absl::Span<const Shape * const> init_value_shapes,absl::Span<const PrimitiveType> input_element_types,int64 inputs)60 Status VerifyReducerShape(const ProgramShape& reducer_shape,
61                           absl::Span<const Shape* const> init_value_shapes,
62                           absl::Span<const PrimitiveType> input_element_types,
63                           int64 inputs) {
64   if (reducer_shape.parameters_size() != inputs * 2) {
65     return InvalidArgument(
66         "Reduction function must take %d parameters, but "
67         "takes %d parameter(s).",
68         inputs * 2, reducer_shape.parameters_size());
69   }
70 
71   const Shape& accumulator_shape = reducer_shape.result();
72   std::vector<const Shape*> accumulator_subshapes;
73   if (accumulator_shape.IsArray()) {
74     if (inputs != 1) {
75       return InvalidArgument(
76           "Reduction function must produce a tuple with %d elements, but "
77           "produces a scalar",
78           inputs);
79     }
80     accumulator_subshapes.push_back(&accumulator_shape);
81   } else if (accumulator_shape.IsTuple()) {
82     if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
83       return InvalidArgument(
84           "Reduction function must produce a tuple with %d elements, but has "
85           "%d elements",
86           inputs, ShapeUtil::TupleElementCount(accumulator_shape));
87     }
88     for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
89       accumulator_subshapes.push_back(&element_shape);
90     }
91   } else {
92     return InvalidArgument(
93         "Reduction function must produce a scalar or tuple of scalars, but has "
94         "shape: %s",
95         ShapeUtil::HumanString(accumulator_shape));
96   }
97 
98   for (const Shape* element_shape : accumulator_subshapes) {
99     if (element_shape->rank() != 0) {
100       return InvalidArgument(
101           "Reduction function must return a scalar or tuple of scalars but "
102           "returns shape: %s",
103           ShapeUtil::HumanString(accumulator_shape));
104     }
105   }
106 
107   for (int64 i = 0; i < inputs; ++i) {
108     // Check that the accumulator can be passed in as the first argument.
109     // Note: comparing here and below with Compatible since we don't care about
110     // layout in scalars - see b/26668201 for a longer-term vision.
111     if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
112                                reducer_shape.parameters(i))) {
113       return InvalidArgument(
114           "Reduction function's %d-th parameter shape differs from the "
115           "result shape: %s vs %s",
116           i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
117           ShapeUtil::HumanString(*accumulator_subshapes[i]));
118     }
119     // Check that init_value's shapes are suitable for reducer_shape.
120     if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
121                                                   *init_value_shapes[i])) {
122       return InvalidArgument(
123           "Reduction function's accumulator shape at index %d differs from "
124           "the init_value shape: %s vs %s",
125           i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
126           ShapeUtil::HumanString(*init_value_shapes[i]));
127     }
128     // Check that the inputs can be passed in as the non-accumulator arguments.
129     const Shape input_element_shape =
130         ShapeUtil::MakeShape(input_element_types[i], {});
131     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
132             input_element_shape, reducer_shape.parameters(inputs + i))) {
133       return InvalidArgument(
134           "Reduction function's %d-th parameter shape differs from the "
135           "input type element type: %s vs %s",
136           inputs + i,
137           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
138           ShapeUtil::HumanString(input_element_shape));
139     }
140     // Check that the accumulator and inputs to the reducer function match.
141     // If the accumulator is scalar, it must have the same type as the inputs
142     // (up to fp precision). If it is a tuple, then the k-th element of the
143     // tuple must have the same type as the K-th input (again, up to fp
144     // precision.)
145     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
146             *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
147       return InvalidArgument(
148           "Reduction function's %d-th parameter shape must "
149           "match the result shape, but got %s vs %s.",
150           inputs + i,
151           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
152           ShapeUtil::HumanString(*accumulator_subshapes[i]));
153     }
154   }
155 
156   return Status::OK();
157 }
158 
InferWindowOutputShape(const Shape & base_shape,const Window & window,PrimitiveType element_type,bool allow_negative_padding)159 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
160                                        const Window& window,
161                                        PrimitiveType element_type,
162                                        bool allow_negative_padding) {
163   if (window.dimensions_size() != base_shape.rank()) {
164     return InvalidArgument(
165         "Window has dimension %d but base shape has dimension %d.",
166         window.dimensions_size(), base_shape.rank());
167   }
168 
169   std::vector<int64> output_dimensions(window.dimensions_size());
170   std::vector<bool> output_is_dynamic(window.dimensions_size());
171   for (int64 i = 0; i < window.dimensions_size(); ++i) {
172     const auto& dim = window.dimensions(i);
173     if (dim.size() <= 0) {
174       return InvalidArgument("Window %s has a non-positive dimension.",
175                              window.DebugString());
176     }
177     if (dim.stride() <= 0) {
178       return InvalidArgument("Window %s has a non-positive stride.",
179                              window.DebugString());
180     }
181     if (!allow_negative_padding && dim.padding_low() < 0) {
182       return InvalidArgument("Window %s has a negative low padding.",
183                              window.DebugString());
184     }
185     if (!allow_negative_padding && dim.padding_high() < 0) {
186       return InvalidArgument("Window %s has a negative high padding.",
187                              window.DebugString());
188     }
189     if (dim.base_dilation() < 1) {
190       return InvalidArgument(
191           "Window %s has a non-positive base area dilation factor.",
192           window.DebugString());
193     }
194     if (dim.window_dilation() < 1) {
195       return InvalidArgument(
196           "Window %s has a non-positive window dilation factor.",
197           window.DebugString());
198     }
199 
200     if (base_shape.is_dynamic_dimension(i) &&
201         !window_util::IsTrivialWindowDimension(dim)) {
202       return Unimplemented(
203           "Dynamic shape is not supported for non trivial window: %s",
204           window_util::ToString(window));
205     }
206 
207     const int64 dilated_base = window_util::DilatedBound(
208         ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
209     const int64 padded_dilated_base =
210         dim.padding_low() + dilated_base + dim.padding_high();
211     const int64 dilated_window =
212         window_util::DilatedBound(dim.size(), dim.window_dilation());
213 
214     output_dimensions[i] = window_util::StridedBound(
215         padded_dilated_base, dilated_window, dim.stride());
216     output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
217   }
218 
219   return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
220                                        output_is_dynamic);
221 }
222 
223 }  // namespace
224 
InferUnaryOpShape(HloOpcode opcode,const HloInstruction * operand)225 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
226     HloOpcode opcode, const HloInstruction* operand) {
227   return InferUnaryOpShape(opcode, operand->shape());
228 }
229 
InferUnaryOpShape(HloOpcode opcode,const Shape & shape)230 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
231     HloOpcode opcode, const Shape& shape) {
232   // There is no copy operation at the proto level, so handle copy explicitly.
233   // A domain shape is the same as the input one.
234   if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
235     return shape;
236   }
237 
238   TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
239 
240   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
241   switch (opcode) {
242     case HloOpcode::kFloor:
243     case HloOpcode::kCeil:
244     case HloOpcode::kRoundNearestAfz:
245       if (!ShapeUtil::ElementIsFloating(shape)) {
246         return InvalidArgument(
247             "Expected element type in shape to be floating for %s operation; "
248             "got %s.",
249             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
250       }
251       return shape;
252     case HloOpcode::kCos:
253     case HloOpcode::kSin:
254     case HloOpcode::kExp:
255     case HloOpcode::kExpm1:
256     case HloOpcode::kLog:
257     case HloOpcode::kLog1p:
258     case HloOpcode::kRsqrt:
259     case HloOpcode::kSqrt:
260     case HloOpcode::kTanh:
261       if (!ShapeUtil::ElementIsFloating(shape) &&
262           !ShapeUtil::ElementIsComplex(shape)) {
263         return InvalidArgument(
264             "Expected element type in shape to be floating or complex for %s "
265             "operation; got %s.",
266             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
267       }
268       return shape;
269     case HloOpcode::kReal:
270     case HloOpcode::kImag:
271       if (ShapeUtil::ElementIsComplex(shape)) {
272         return ShapeUtil::ComplexComponentShape(shape);
273       } else if (ShapeUtil::ElementIsFloating(shape)) {
274         return shape;
275       } else {
276         return InvalidArgument(
277             "Expected element type in shape to be floating or complex for "
278             "%s operation; got %s.",
279             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
280       }
281     case HloOpcode::kAbs:
282       if (ShapeUtil::ElementIsComplex(shape)) {
283         return ShapeUtil::ChangeElementType(
284             shape, primitive_util::ComplexComponentType(shape.element_type()));
285       } else if (ShapeUtil::ElementIsSigned(shape)) {
286         return shape;
287       } else {
288         return InvalidArgument(
289             "Expected element type in shape to be floating or complex for "
290             "%s operation; got %s.",
291             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
292       }
293     case HloOpcode::kClz:
294       if (!ShapeUtil::ElementIsIntegral(shape)) {
295         return InvalidArgument(
296             "Expected an integral element type in argument to Clz "
297             "operation; got %s.",
298             PrimitiveType_Name(shape.element_type()));
299       }
300       return shape;
301     case HloOpcode::kNegate:
302       if (!ShapeUtil::ElementIsIntegral(shape) &&
303           !ShapeUtil::ElementIsFloating(shape) &&
304           !ShapeUtil::ElementIsComplex(shape)) {
305         return InvalidArgument(
306             "Expected element type in shape to be integral, floating or "
307             "complex for %s operation; got %s.",
308             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
309       }
310       return shape;
311     case HloOpcode::kSign:
312       if (!ShapeUtil::ElementIsSigned(shape) &&
313           !ShapeUtil::ElementIsComplex(shape)) {
314         return InvalidArgument(
315             "Expected element type in shape to be signed or complex for "
316             "%s operation; got %s.",
317             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
318       }
319       return shape;
320 
321     case HloOpcode::kNot:
322       if (shape.element_type() != PRED &&
323           !primitive_util::IsIntegralType(shape.element_type())) {
324         return InvalidArgument(
325             "Expected pred or an integral element type in argument to Not "
326             "operation; got %s.",
327             PrimitiveType_Name(shape.element_type()));
328       }
329       return shape;
330 
331     case HloOpcode::kIsFinite:
332       if (!ShapeUtil::ElementIsFloating(shape)) {
333         return InvalidArgument(
334             "Expected element type in shape to be floating "
335             "point for IsFinite "
336             "operation; got %s.",
337             PrimitiveType_Name(shape.element_type()));
338       }
339       return ShapeUtil::ChangeElementType(shape, PRED);
340 
341     default:
342       return InvalidArgument(
343           "Unknown operation for unary shape inference: \"%s\".",
344           HloOpcodeString(opcode));
345   }
346 }
347 
InferConcatOpShape(absl::Span<const Shape * const> arg_shapes,const int64 dimension)348 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
349     absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
350   if (arg_shapes.empty()) {
351     return InvalidArgument("Concatenate expects at least one argument.");
352   }
353   if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
354     return InvalidArgument("Concatenate dimension out of bounds: %d.",
355                            dimension);
356   }
357   const Shape* arg_shape = nullptr;
358   PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
359   for (const Shape* shape : arg_shapes) {
360     TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
361     if (!arg_shape) {
362       arg_shape = shape;
363       element_type = arg_shape->element_type();
364       continue;
365     }
366     if (arg_shape->rank() != shape->rank()) {
367       return InvalidArgument(
368           "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
369           "(%s).",
370           arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
371           ShapeUtil::HumanString(*shape));
372     }
373     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
374       return InvalidArgument(
375           "Cannot concatenate arrays with different element types: %s vs %s.",
376           PrimitiveType_Name(arg_shape->element_type()),
377           PrimitiveType_Name(shape->element_type()));
378     }
379     for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
380          ++dimension_number) {
381       if (arg_shape->dimensions(dimension_number) !=
382           shape->dimensions(dimension_number)) {
383         if (dimension_number == dimension) {
384           continue;  // It's okay to differ in the dimension we're
385                      // concatenating.
386         }
387         return InvalidArgument(
388             "Cannot concatenate arrays that differ in dimensions other than "
389             "the one being concatenated (the other array dimensions must be "
390             "the same): %s vs %s in dimension %d.",
391             ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
392             dimension);
393       }
394     }
395     element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
396   }
397 
398   std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
399                                     arg_shape->dimensions().end());
400   for (size_t i = 1; i < arg_shapes.size(); ++i) {
401     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
402   }
403   return ShapeUtil::MakeShape(element_type, new_dimensions);
404 }
405 
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)406 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
407     const Shape& operand_shape, PrimitiveType new_element_type) {
408   auto old_element_type = operand_shape.element_type();
409   if (primitive_util::IsComplexType(old_element_type) &&
410       !primitive_util::IsComplexType(new_element_type)) {
411     return Unimplemented(
412         "Conversion from complex to real type %s => %s is not implemented.",
413         ShapeUtil::HumanString(operand_shape),
414         PrimitiveType_Name(new_element_type));
415   }
416   if (!operand_shape.IsArray() ||
417       !primitive_util::IsArrayType(new_element_type)) {
418     // Note: we may want to support tuple conversions via this operation in the
419     // future, by recursing into the tuple elements to check all sub-conversions
420     // are valid. For now we just reject them, though.
421     return InvalidArgument(
422         "Convert does not allow non-arrays, so cannot convert from %s to %s.",
423         ShapeUtil::HumanString(operand_shape),
424         PrimitiveType_Name(new_element_type));
425   }
426 
427   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
428 }
429 
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)430 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
431     const Shape& operand_shape, PrimitiveType new_element_type) {
432   auto old_element_type = operand_shape.element_type();
433   if (primitive_util::IsComplexType(old_element_type) !=
434       primitive_util::IsComplexType(new_element_type)) {
435     return InvalidArgument("Conversion from complex to real type %s => %s.",
436                            ShapeUtil::HumanString(operand_shape),
437                            PrimitiveType_Name(new_element_type));
438   }
439   if (!operand_shape.IsArray() ||
440       !primitive_util::IsArrayType(new_element_type)) {
441     // Note: we may want to support tuple conversions via this operation in the
442     // future, by recursing into the tuple elements to check all sub-conversions
443     // are valid. For now we just reject them, though.
444     return InvalidArgument(
445         "Cannot convert from or to tuple type; requested conversion: %s => %s.",
446         ShapeUtil::HumanString(operand_shape),
447         PrimitiveType_Name(new_element_type));
448   }
449   if (primitive_util::BitWidth(old_element_type) !=
450       primitive_util::BitWidth(new_element_type)) {
451     return InvalidArgument(
452         "Cannot bitcast types with different bit-widths: %s => %s.",
453         PrimitiveType_Name(old_element_type),
454         PrimitiveType_Name(new_element_type));
455   }
456 
457   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
458 }
459 
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)460 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
461     const Shape& operand_shape, const int exponent_bits,
462     const int mantissa_bits) {
463   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
464     return InvalidArgument(
465         "Expected element type in shape to be floating point for "
466         "ReducePrecision operation; got %s.",
467         PrimitiveType_Name(operand_shape.element_type()));
468   }
469   if (exponent_bits < 1) {
470     // One exponent bit is necessary to distinguish 0 from infinity.  Having
471     // no exponent bits doesn't produce a sensible number, so we require at
472     // least one.
473     return InvalidArgument("Expected exponent_bits >= 1; got %d.",
474                            exponent_bits);
475   }
476   if (mantissa_bits < 0) {
477     // A number with no mantissa bits is still meaningful, however.
478     return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
479                            mantissa_bits);
480   }
481   return operand_shape;
482 }
483 
InferPadShape(const Shape & operand_shape,const Shape & padding_value_shape,const PaddingConfig & padding_config)484 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
485     const Shape& operand_shape, const Shape& padding_value_shape,
486     const PaddingConfig& padding_config) {
487   if (!operand_shape.IsArray()) {
488     return InvalidArgument(
489         "Pad operation does not support tuple-shape operands.");
490   }
491   if (!ShapeUtil::IsScalar(padding_value_shape)) {
492     return InvalidArgument(
493         "Pad operation does not support non-scalar padding values.");
494   }
495   if (operand_shape.rank() != padding_config.dimensions_size()) {
496     return InvalidArgument(
497         "The rank of the operand and the padding configuration do not match: "
498         "%s vs %s.",
499         ShapeUtil::HumanString(operand_shape),
500         padding_config.ShortDebugString());
501   }
502   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
503                                                      padding_value_shape)) {
504     return InvalidArgument(
505         "The element types of the operands to Pad do not match.");
506   }
507   if (absl::c_any_of(padding_config.dimensions(),
508                      [](const PaddingConfig::PaddingConfigDimension& p) {
509                        return p.interior_padding() < 0;
510                      })) {
511     return InvalidArgument("Interior padding cannot be negative: %s",
512                            padding_config.ShortDebugString());
513   }
514 
515   if (!padding_value_shape.is_static()) {
516     return InvalidArgument("Dynamic padding value is not supported");
517   }
518 
519   std::vector<int64> dimensions(operand_shape.rank());
520   std::vector<bool> is_dynamic(operand_shape.rank());
521   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
522     const auto& p = padding_config.dimensions(i);
523     if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 &&
524         p.edge_padding_low() != 0 && p.interior_padding() != 0) {
525       return InvalidArgument(
526           "Dynamic dimension on padding dimension is not supported.");
527     }
528     dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
529                     p.edge_padding_high() +
530                     std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
531                         p.interior_padding();
532     if (dimensions[i] < 0) {
533       return InvalidArgument("Padding result in negative size for dimension %d",
534                              i);
535     }
536     is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
537   }
538 
539   return ShapeUtil::MakeShape(
540       ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
541       dimensions, is_dynamic);
542 }
543 
544 // Current DotDimensionNumbers Requirements:
545 //
546 // Contracting Dimensions:
547 // *) Same number of contracting dimensions on both lhs and rhs.
548 // *) Contracting dimension size must be the same on both lhs and rhs.
549 //
550 // Batch Dimensions:
551 // *) Same number of batch dimensions on both lhs and rhs.
552 // *) Same batch dimension sizes on both lhs and rhs.
553 //
554 
555 namespace {
556 
ValidateDotDimensionNumbers(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)557 Status ValidateDotDimensionNumbers(
558     const Shape& lhs, const Shape& rhs,
559     const DotDimensionNumbers& dimension_numbers) {
560   // Check that dimension numbers are in range.
561   auto dims_in_range = [](const int64 rank,
562                           absl::Span<const int64> contracting_dims,
563                           absl::Span<const int64> batch_dims) -> bool {
564     auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
565     return absl::c_all_of(contracting_dims, in_range) &&
566            absl::c_all_of(batch_dims, in_range);
567   };
568 
569   absl::Span<const int64> lhs_contracting_dimensions =
570       AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
571   absl::Span<const int64> rhs_contracting_dimensions =
572       AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
573   absl::Span<const int64> lhs_batch_dimensions =
574       AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
575   absl::Span<const int64> rhs_batch_dimensions =
576       AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
577 
578   if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
579                      lhs_batch_dimensions) ||
580       !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
581                      rhs_batch_dimensions)) {
582     return InvalidArgument("A dimension number is out of range in Dot: %s.",
583                            dimension_numbers.DebugString());
584   }
585 
586   // Check that dimension numbers are unique.
587   auto dims_unique = [](absl::Span<const int64> contracting_dims,
588                         absl::Span<const int64> batch_dims) -> bool {
589     absl::flat_hash_set<int64> dim_set;
590     auto is_unique = [&dim_set](int64 i) -> bool {
591       return dim_set.insert(i).second;
592     };
593     return absl::c_all_of(contracting_dims, is_unique) &&
594            absl::c_all_of(batch_dims, is_unique);
595   };
596 
597   if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
598       !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
599     return InvalidArgument("A dimension number is not unique in Dot: %s.",
600                            dimension_numbers.DebugString());
601   }
602 
603   return Status::OK();
604 }
605 
606 }  // namespace
607 
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)608 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
609     const Shape& lhs, const Shape& rhs,
610     const DotDimensionNumbers& dimension_numbers) {
611   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
612   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
613 
614   auto fail = [lhs, rhs](const string& addendum) -> Status {
615     string message =
616         StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
617                   ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
618     if (!addendum.empty()) {
619       message += " " + addendum;
620     }
621     return InvalidArgument("%s", message);
622   };
623 
624   // Check if both element types are the same.
625   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
626     return fail("Element types do not match.");
627   }
628 
629   if ((lhs.rank() < 1) || (rhs.rank() < 1)) {
630     return fail("Dot only supports rank 1 or above.");
631   }
632 
633   // Validate basic properties of dot dimension numbers.
634   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
635 
636   // Check that number of contracting dimensions match.
637   if (dimension_numbers.lhs_contracting_dimensions_size() !=
638       dimension_numbers.rhs_contracting_dimensions_size()) {
639     return fail(
640         "Must specify the same number of contracting dimensions for lhs and "
641         "rhs.");
642   }
643   // Check that contracting dimension sizes match.
644   for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
645        ++i) {
646     const int64 lhs_contracting_dimension =
647         dimension_numbers.lhs_contracting_dimensions(i);
648     const int64 rhs_contracting_dimension =
649         dimension_numbers.rhs_contracting_dimensions(i);
650     if (lhs.dimensions(lhs_contracting_dimension) !=
651             rhs.dimensions(rhs_contracting_dimension) ||
652         lhs.is_dynamic_dimension(lhs_contracting_dimension) !=
653             rhs.is_dynamic_dimension(rhs_contracting_dimension)) {
654       return fail("Contracting dimension sizes do not match.");
655     }
656   }
657 
658   // Check that number of batch dimensions match.
659   if (dimension_numbers.lhs_batch_dimensions_size() !=
660       dimension_numbers.rhs_batch_dimensions_size()) {
661     return fail("Must the same number of batch dimensions for lhs and rhs.");
662   }
663 
664   // Check that batch dimension numbers and sizes match.
665   for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
666     if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
667             rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) ||
668         lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) !=
669             rhs.is_dynamic_dimension(
670                 dimension_numbers.rhs_batch_dimensions(i))) {
671       return fail("Batch dimension sizes must match for lhs/rhs.");
672     }
673   }
674 
675   // The ranks of lhs and rhs are decremented by 1 respectively due to the
676   // contraction, and added for the rank of the result. When an input tensor is
677   // a scalar, its contribution to the rank of the result is 0.
678   // Generate the result dimensions in order, rhs dimensions followed by lhs
679   // dimensions except the contracted and batch dimensions.
680   std::vector<int64> dimensions;
681   std::vector<bool> is_dynamic;
682   for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
683     dimensions.push_back(lhs.dimensions(lhs_dim));
684     is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
685   }
686   for (int64 i = 0; i < lhs.rank(); i++) {
687     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
688                                i) &&
689         !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
690       dimensions.push_back(lhs.dimensions(i));
691       is_dynamic.push_back(lhs.is_dynamic_dimension(i));
692     }
693   }
694   for (int64 i = 0; i < rhs.rank(); i++) {
695     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
696                                i) &&
697         !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
698       dimensions.push_back(rhs.dimensions(i));
699       is_dynamic.push_back(rhs.is_dynamic_dimension(i));
700     }
701   }
702   Shape result = ShapeUtil::MakeShape(
703       ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
704 
705   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
706   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
707   return result;
708 }
709 
710 /* static */ StatusOr<Shape>
InferDegenerateDimensionBroadcastShape(HloOpcode operation,const Shape & lhs,const Shape & rhs)711 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
712                                                        const Shape& lhs,
713                                                        const Shape& rhs) {
714   TF_RET_CHECK(lhs.rank() == rhs.rank());
715 
716   // The shapes have to be compatible. That is, if some dimension d has a
717   // different size in the two shapes, one of them has to be 1 (a "degenerate"
718   // dimension). In that case, the output shape has the non-1 dimension size
719   // from the lhs/rhs pair in every index.
720   std::vector<int64> output_dimensions(lhs.rank());
721   std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
722   for (int64 i = 0; i < lhs.rank(); ++i) {
723     if (lhs.dimensions(i) == rhs.dimensions(i)) {
724       output_dimensions[i] = lhs.dimensions(i);
725       output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
726     } else if (lhs.dimensions(i) == 1) {
727       output_dimensions[i] = rhs.dimensions(i);
728       output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i);
729     } else if (rhs.dimensions(i) == 1) {
730       output_dimensions[i] = lhs.dimensions(i);
731       output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
732     } else {
733       return InvalidArgument(
734           "Binary op %s with incompatible shapes: %s and %s.",
735           HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
736           ShapeUtil::HumanString(rhs));
737     }
738   }
739   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
740                               output_dimensions, output_dimensions_is_dynamic);
741 }
742 
InferInDimBroadcastShape(const Shape & smaller_shape,const Shape & larger_shape,absl::Span<const int64> broadcast_dimensions)743 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
744     const Shape& smaller_shape, const Shape& larger_shape,
745     absl::Span<const int64> broadcast_dimensions) {
746   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
747     // Reject "magic" inference for binops on different shapes, requiring
748     // the user to provide an explicit broadcast dimension in this case.
749     // See b/25177275 for more details.
750     return InvalidArgument("Automatic shape inference not supported: %s and %s",
751                            ShapeUtil::HumanString(smaller_shape),
752                            ShapeUtil::HumanString(larger_shape));
753   } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
754     return InvalidArgument(
755         "Size of broadcast_dimensions has to match lower-rank operand's "
756         "rank; "
757         " lower-rank operand's rank is %d, size of broadcast_dimensions is "
758         "%u.",
759         smaller_shape.rank(), broadcast_dimensions.size());
760   }
761 
762   // broadcast_dimensions is a sequence of dimensions; its length is equal to
763   // the rank of the lower-rank operand. The lower-rank operand's dimensions
764   // have to be compatible with the higher-rank operand's dimensions at indices
765   // specified by broadcast_dimensions. Here compatible means the dimension
766   // sizes are equal or in one of the shapes the dimension size is
767   // one. Examples:
768   //
769   // smaller_shape   larger_shape   broadcast_dimensions   output_shape
770   //   []              [2, 3]          {}                    [2, 3]
771   //   [3]             [4, 3]          {1}                   [4, 3]
772   //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
773   //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
774   //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
775   //
776   // The column output_shape may not be the final shape of the XLA
777   // operation. After the "InDim" broadcasting implemented in this function
778   // expands the rank, degenerate-dimension broadcasting (implemented in
779   // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
780   // up to match the dimension size of the other operand. For example, consider
781   // the row in the table above with a smaller_shape of [2, 1]. The shape
782   // returned by this function is [2, 3, 1] (output_shape) however, the result
783   // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
784   // broadcasting.
785   //
786   // Invalid broadcasts:
787   //
788   // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
789   // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
790   //   dimension zero of smaller_shape(size 3). **Zero here comes from the value
791   //   in broadcast_dimensions.
792   //
793   // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
794   // Reason: Dimension one of larger_shape (size 3) is not compatible with
795   //   dimension zero of smaller_shape(size 2)
796 
797   // The output shape is initially the larger_shape. Sizes of dimensions
798   // specified in broadcast_dimensions are then changed to match the
799   // corresponding dimension size in smaller_shape.
800   Shape output_shape(larger_shape);
801   output_shape.set_element_type(
802       ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
803 
804   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
805     int64 dimension_to_match = broadcast_dimensions.at(i);
806     if (dimension_to_match < 0) {
807       return InvalidArgument(
808           "Broadcast dimension number (%d) cannot be negative.",
809           dimension_to_match);
810     }
811     if (dimension_to_match >= larger_shape.dimensions_size()) {
812       return InvalidArgument(
813           "Broadcast dimension number (%d) too large; higher-rank "
814           "operand has rank %d.",
815           dimension_to_match, larger_shape.dimensions_size());
816     }
817     int64 small_dimension_size = smaller_shape.dimensions(i);
818     int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
819     bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
820     bool large_is_dynamic =
821         larger_shape.is_dynamic_dimension(dimension_to_match);
822     // Dimension sizes must be compatible: match or be degenerate (degenerate
823     // case is handled by degenerate dimension broadcasting which occurs after
824     // InDim broadcasting).
825     if (small_dimension_size != large_dimension_size &&
826         small_dimension_size != 1 && large_dimension_size != 1) {
827       return InvalidArgument(
828           "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
829           small_dimension_size, large_dimension_size,
830           ShapeUtil::HumanString(smaller_shape),
831           ShapeUtil::HumanString(larger_shape));
832     }
833     if (small_is_dynamic != large_is_dynamic) {
834       if (small_dimension_size == large_dimension_size ||
835           (small_dimension_size == 1 && !small_is_dynamic) ||
836           (large_dimension_size == 1 && !large_is_dynamic)) {
837         // Do nothing. It's OK when the size-1 dimension is not static.
838       } else {
839         return InvalidArgument(
840             "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
841             ShapeUtil::HumanString(smaller_shape),
842             ShapeUtil::HumanString(larger_shape));
843       }
844     }
845     // Make sure the broadcast dimensions are listed in a strictly increasing
846     // order.
847     if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
848       return InvalidArgument(
849           "Broadcast dimensions order is wrong: %d comes after %d.",
850           dimension_to_match, broadcast_dimensions.at(i - 1));
851     }
852 
853     output_shape.set_dimensions(dimension_to_match, small_dimension_size);
854     output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
855   }
856 
857   return output_shape;
858 }
859 
InferElementwiseBinaryOpShape(HloOpcode operation,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)860 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
861     HloOpcode operation, const Shape& lhs, const Shape& rhs,
862     absl::Span<const int64> broadcast_dimensions) {
863   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
864   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
865 
866   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
867     return InvalidArgument(
868         "Binary op %s with different element types: %s and %s.",
869         HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
870         ShapeUtil::HumanString(rhs));
871   }
872 
873   if (lhs.rank() == rhs.rank()) {
874     std::vector<int64> identity_dims(lhs.rank());
875     std::iota(identity_dims.begin(), identity_dims.end(), 0);
876     if (!broadcast_dimensions.empty() &&
877         broadcast_dimensions != identity_dims) {
878       return InvalidArgument(
879           "Broadcast dimensions field must either be not set or be the "
880           "identity on binary operations with operands of the same rank.");
881     }
882   }
883 
884   if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
885     // If the shapes are the same other than layout, the output shape is the
886     // same (elementwise op).
887     return ShapeUtil::ChangeElementType(
888         lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
889   }
890 
891   if (lhs.rank() == rhs.rank()) {
892     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
893   } else {
894     // Ranks do not match, so perform InDim broadcasting using
895     // broadcast_dimensions. Scalar broadcasting is a special case of this.
896     const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
897     const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
898 
899     // After InDim broadcasting, perform degenerate dimensions broadcasting.
900     TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
901                         InferInDimBroadcastShape(smaller_shape, larger_shape,
902                                                  broadcast_dimensions));
903 
904     return InferDegenerateDimensionBroadcastShape(
905         operation, indim_broadcast_shape, larger_shape);
906   }
907 }
908 
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)909 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
910     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
911   return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
912                             /*broadcast_dimensions=*/{});
913 }
914 
InferBinaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)915 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
916     HloOpcode opcode, const Shape& lhs, const Shape& rhs,
917     absl::Span<const int64> broadcast_dimensions) {
918   VLOG(2) << StrFormat(
919       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
920       HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
921       ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
922   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
923   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
924 
925   TF_RETURN_IF_ERROR(ExpectArray(
926       lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
927   TF_RETURN_IF_ERROR(ExpectArray(
928       rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
929   switch (opcode) {
930     case HloOpcode::kMaximum:
931     case HloOpcode::kMinimum:
932       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
933                                            broadcast_dimensions);
934 
935     case HloOpcode::kSubtract:
936     case HloOpcode::kAdd:
937     case HloOpcode::kAtan2:
938     case HloOpcode::kPower:
939     case HloOpcode::kDivide:
940     case HloOpcode::kRemainder:
941     case HloOpcode::kMultiply:
942     case HloOpcode::kShiftLeft:
943     case HloOpcode::kShiftRightArithmetic:
944     case HloOpcode::kShiftRightLogical:
945       if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
946         return InvalidArgument(
947             "Expected element type in shape to be arithmetic type for "
948             "operation %s; got PRED.",
949             HloOpcodeString(opcode));
950       }
951       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
952                                            broadcast_dimensions);
953 
954     case HloOpcode::kComplex: {
955       if (!ShapeUtil::ElementIsFloating(lhs)) {
956         return InvalidArgument(
957             "Expected element type in shape to be floating for complex compose "
958             "operation; got %s.",
959             PrimitiveType_Name(lhs.element_type()));
960       }
961       TF_ASSIGN_OR_RETURN(const Shape& shape,
962                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
963                                                         broadcast_dimensions));
964       if (lhs.element_type() == F32 && rhs.element_type() == F32) {
965         return ShapeUtil::ChangeElementType(shape, C64);
966       } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
967         return ShapeUtil::ChangeElementType(shape, C128);
968       } else {
969         return Unimplemented("Complex component type is not implemented.");
970       }
971     }
972     case HloOpcode::kAnd:
973     case HloOpcode::kOr:
974     case HloOpcode::kXor:
975       if (lhs.element_type() != PRED &&
976           !primitive_util::IsIntegralType(lhs.element_type())) {
977         return InvalidArgument(
978             "Expected pred or integral type in argument to and/or operation; "
979             "got %s.",
980             PrimitiveType_Name(lhs.element_type()));
981       }
982       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
983                                            broadcast_dimensions);
984     case HloOpcode::kCompare: {
985       TF_ASSIGN_OR_RETURN(const Shape& shape,
986                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
987                                                         broadcast_dimensions));
988       return ShapeUtil::ChangeElementType(shape, PRED);
989     }
990     default:
991       return Unimplemented(
992           "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
993           HloOpcodeString(opcode), lhs.ShortDebugString(),
994           rhs.ShortDebugString());
995   }
996 }
997 
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)998 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
999     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1000     const HloInstruction* ehs) {
1001   return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1002 }
1003 
InferTernaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,const Shape & ehs)1004 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1005     HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1006   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1007   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1008   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1009   switch (opcode) {
1010     case HloOpcode::kClamp:
1011       return InferClampShape(lhs, rhs, ehs);
1012     case HloOpcode::kSelect:
1013       return InferSelectShape(lhs, rhs, ehs);
1014     case HloOpcode::kTupleSelect:
1015       return InferTupleSelectShape(lhs, rhs, ehs);
1016     default:
1017       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1018   }
1019 }
1020 
InferVariadicOpShape(HloOpcode opcode,absl::Span<const HloInstruction * const> operands)1021 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1022     HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1023   std::vector<const Shape*> operand_shapes;
1024   operand_shapes.reserve(operands.size());
1025   for (const HloInstruction* operand : operands) {
1026     operand_shapes.push_back(&operand->shape());
1027   }
1028   return InferVariadicOpShape(opcode, operand_shapes);
1029 }
1030 
InferVariadicOpShape(HloOpcode opcode,absl::Span<const Shape * const> operand_shapes)1031 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1032     HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1033   for (const Shape* shape : operand_shapes) {
1034     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1035   }
1036   switch (opcode) {
1037     case HloOpcode::kTuple: {
1038       Shape result = ShapeUtil::MakeTupleShape({});
1039       result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1040       for (const Shape* shape : operand_shapes) {
1041         ShapeUtil::AppendShapeToTuple(*shape, &result);
1042       }
1043       return result;
1044     }
1045     case HloOpcode::kSort: {
1046       if (operand_shapes.size() == 1) {
1047         return *operand_shapes[0];
1048       } else {
1049         for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
1050           if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1051                                          *operand_shapes[operand])) {
1052             return InvalidArgument(
1053                 "Sort keys and values dimensions must match. "
1054                 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1055                 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1056                 ShapeUtil::HumanString(*operand_shapes[operand]));
1057           }
1058         }
1059         std::vector<Shape> operand_shape_values;
1060         for (const Shape* operand_shape : operand_shapes) {
1061           operand_shape_values.push_back(*operand_shape);
1062         }
1063         return ShapeUtil::MakeTupleShape(operand_shape_values);
1064       }
1065       return InvalidArgument("Unexpected number of operands for sort");
1066     }
1067     default:
1068       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1069   }
1070 }
1071 
InferMapShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply,absl::Span<const int64> dimensions)1072 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1073     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1074     absl::Span<const int64> dimensions) {
1075   if (arg_shapes.empty()) {
1076     return InvalidArgument("Map expects at least one argument.");
1077   }
1078 
1079   // All arguments must have the same shape.
1080   const Shape* arg_shape = arg_shapes[0];
1081   for (size_t i = 1; i < arg_shapes.size(); ++i) {
1082     TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1083 
1084     if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1085       continue;
1086     }
1087     if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1088                                                       *arg_shape)) {
1089       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1090         continue;
1091       }
1092       if (ShapeUtil::IsScalar(*arg_shape)) {
1093         arg_shape = arg_shapes[i];
1094         continue;
1095       }
1096     }
1097 
1098     std::vector<string> pieces;
1099     for (const Shape* shape : arg_shapes) {
1100       pieces.push_back(ShapeUtil::HumanString(*shape));
1101     }
1102     return InvalidArgument(
1103         "Map operation requires all operands to have the same shape; got: "
1104         "%s.",
1105         StrJoin(pieces, ", "));
1106   }
1107 
1108   // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1109   // only support mapping across all dimensions: i.e. scalar map functions).
1110   if (dimensions.size() != arg_shape->dimensions_size()) {
1111     return InvalidArgument(
1112         "Map applied to a subset of dimensions currently not supported: "
1113         "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1114         arg_shape->dimensions_size(), dimensions.size());
1115   }
1116 
1117   // Check that requested map dimensions numbers are monotonically increasing.
1118   for (int i = 0; i < dimensions.size(); ++i) {
1119     if (dimensions[i] != i) {
1120       return InvalidArgument(
1121           "Map requires monotonically increasing dimension numbers; got: %s.",
1122           StrJoin(dimensions, ", "));
1123     }
1124   }
1125 
1126   // The applied function's arity equals the number of arguments.
1127   if (arg_shapes.size() != to_apply.parameters_size()) {
1128     return InvalidArgument(
1129         "Map applied function arity must match number of arguments; got: "
1130         "arity: %d, arguments: %u.",
1131         to_apply.parameters_size(), arg_shapes.size());
1132   }
1133 
1134   // The parameters should all be scalars, and the output too.
1135   const Shape& output_shape = to_apply.result();
1136   if (!ShapeUtil::IsScalar(output_shape)) {
1137     return InvalidArgument(
1138         "Mapped computation's result has to be a scalar; got: %s.",
1139         ShapeUtil::HumanString(output_shape));
1140   }
1141 
1142   for (int i = 0; i < to_apply.parameters_size(); ++i) {
1143     const Shape& parameter_shape = to_apply.parameters(i);
1144 
1145     if (!ShapeUtil::IsScalar(parameter_shape)) {
1146       return InvalidArgument(
1147           "Mapped computation's parameter has to be a scalar; "
1148           "got parameter %d shape: %s.",
1149           i, ShapeUtil::HumanString(parameter_shape));
1150     }
1151 
1152     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1153                                                        *arg_shape)) {
1154       return InvalidArgument(
1155           "Mapped computation's parameter type has to match argument element "
1156           "type; got parameter %d shape: %s, argument shape: %s.",
1157           i, ShapeUtil::HumanString(parameter_shape),
1158           ShapeUtil::HumanString(*arg_shape));
1159     }
1160   }
1161 
1162   return ShapeUtil::MakeShape(output_shape.element_type(),
1163                               AsInt64Slice(arg_shape->dimensions()));
1164 }
1165 
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64 feature_index)1166 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1167     const Shape& operand_shape, const Shape& scale_shape,
1168     const Shape& offset_shape, int64 feature_index) {
1169   TF_RETURN_IF_ERROR(
1170       ExpectArray(operand_shape, "operand of batch norm training"));
1171   TF_RETURN_IF_ERROR(
1172       ExpectArray(offset_shape, "offset input of batch norm training"));
1173   TF_RETURN_IF_ERROR(
1174       ExpectArray(scale_shape, "scale input of batch norm training"));
1175 
1176   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1177                Status::OK());
1178   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1179                Status::OK());
1180   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1181                Status::OK());
1182 
1183   if (feature_index >= operand_shape.rank()) {
1184     return InvalidArgument(
1185         "Expected feature_index of batch-norm-training to be "
1186         "smaller than the rank of operand_shape; "
1187         "got feature_index %d, and rank %d.",
1188         feature_index, operand_shape.rank());
1189   }
1190 
1191   if (feature_index < 0) {
1192     return InvalidArgument(
1193         "Expected feature_index of batch-norm-training to "
1194         "be a non-negative number, got %d.",
1195         feature_index);
1196   }
1197 
1198   if (operand_shape.rank() < 1) {
1199     return InvalidArgument(
1200         "Expected the rank of operand to "
1201         "batch-norm-training to be at least 1; got %d.",
1202         operand_shape.rank());
1203   }
1204 
1205   if (offset_shape.rank() != 1) {
1206     return InvalidArgument(
1207         "Offset input of batch-norm-training must have"
1208         " rank 1, but has rank %d.",
1209         offset_shape.rank());
1210   }
1211 
1212   if (scale_shape.rank() != 1) {
1213     return InvalidArgument(
1214         "Scale input of batch-norm-training must have"
1215         " rank 1, but has rank %d.",
1216         scale_shape.rank());
1217   }
1218 
1219   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1220     return InvalidArgument(
1221         "The operand to batch-norm-training must have a floating point "
1222         "element type, but the shape is %s.",
1223         PrimitiveType_Name(operand_shape.element_type()));
1224   }
1225 
1226   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1227                                                      operand_shape)) {
1228     return InvalidArgument(
1229         "The inputs should have the same element type for batch-norm-training, "
1230         "but the shape of offset factor is %s "
1231         "and the shape of operand is %s.",
1232         PrimitiveType_Name(offset_shape.element_type()),
1233         PrimitiveType_Name(operand_shape.element_type()));
1234   }
1235 
1236   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1237                                                      operand_shape)) {
1238     return InvalidArgument(
1239         "The inputs should have the same element type for batch-norm-training, "
1240         "but the shape of scale factor is %s "
1241         "and the shape of operand is %s.",
1242         PrimitiveType_Name(scale_shape.element_type()),
1243         PrimitiveType_Name(operand_shape.element_type()));
1244   }
1245 
1246   const int64 feature_count = operand_shape.dimensions(feature_index);
1247   Shape output_shape_for_mean_and_var =
1248       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1249 
1250   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1251     return InvalidArgument(
1252         "The size of offset factor should be the same as feature count,"
1253         "but the size of offset factor is %d "
1254         "and the feature count is %d.",
1255         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1256   }
1257 
1258   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1259     return InvalidArgument(
1260         "The size of scale factor should be the same as feature count,"
1261         "but the size of scale factor is %d "
1262         "and the feature count is %d.",
1263         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1264   }
1265 
1266   return ShapeUtil::MakeTupleShape({operand_shape,
1267                                     output_shape_for_mean_and_var,
1268                                     output_shape_for_mean_and_var});
1269 }
1270 
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64 feature_index)1271 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1272     const Shape& operand_shape, const Shape& scale_shape,
1273     const Shape& offset_shape, const Shape& mean_shape,
1274     const Shape& variance_shape, int64 feature_index) {
1275   TF_RETURN_IF_ERROR(
1276       ExpectArray(operand_shape, "operand of batch norm inference"));
1277   TF_RETURN_IF_ERROR(
1278       ExpectArray(offset_shape, "offset input of batch norm inference"));
1279   TF_RETURN_IF_ERROR(
1280       ExpectArray(scale_shape, "scale input of batch norm inference"));
1281 
1282   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1283   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1284   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1285   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1286   TF_RETURN_IF_ERROR(
1287       ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1288 
1289   if (feature_index >= operand_shape.rank()) {
1290     return InvalidArgument(
1291         "Expected feature_index of batch-norm-inference to be "
1292         "smaller than the rank of operand_shape; "
1293         "got feature_index %d, and rank %d.",
1294         feature_index, operand_shape.rank());
1295   }
1296 
1297   if (feature_index < 0) {
1298     return InvalidArgument(
1299         "Expected feature_index of batch-norm-inference to "
1300         "be a non-negative number, got %d.",
1301         feature_index);
1302   }
1303 
1304   if (operand_shape.rank() < 1) {
1305     return InvalidArgument(
1306         "Expected the rank of operand to "
1307         "batch-norm-inference to be at least 1; got %d.",
1308         operand_shape.rank());
1309   }
1310 
1311   if (offset_shape.rank() != 1) {
1312     return InvalidArgument(
1313         "Offset input of batch-norm-inference must have"
1314         " rank 1, but has rank %d.",
1315         offset_shape.rank());
1316   }
1317 
1318   if (scale_shape.rank() != 1) {
1319     return InvalidArgument(
1320         "Scale input of batch-norm-inference must have"
1321         " rank 1, but has rank %d.",
1322         scale_shape.rank());
1323   }
1324 
1325   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1326     return InvalidArgument(
1327         "The operand to batch-norm-inference must have a floating point "
1328         "element type, but the shape is %s.",
1329         PrimitiveType_Name(operand_shape.element_type()));
1330   }
1331 
1332   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1333                                                      operand_shape)) {
1334     return InvalidArgument(
1335         "The inputs should have the same element type for "
1336         "batch-norm-inference, "
1337         "but the shape of offset factor is %s "
1338         "and the shape of operand is %s.",
1339         PrimitiveType_Name(offset_shape.element_type()),
1340         PrimitiveType_Name(operand_shape.element_type()));
1341   }
1342 
1343   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1344                                                      operand_shape)) {
1345     return InvalidArgument(
1346         "The inputs should have the same element type for "
1347         "batch-norm-inference, "
1348         "but the shape of scale factor is %s "
1349         "and the shape of operand is %s.",
1350         PrimitiveType_Name(scale_shape.element_type()),
1351         PrimitiveType_Name(operand_shape.element_type()));
1352   }
1353 
1354   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1355                                                      operand_shape)) {
1356     return InvalidArgument(
1357         "The inputs should have the same element type for "
1358         "batch-norm-inference, "
1359         "but the shape of mean is %s "
1360         "and the shape of operand is %s.",
1361         PrimitiveType_Name(mean_shape.element_type()),
1362         PrimitiveType_Name(operand_shape.element_type()));
1363   }
1364 
1365   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1366                                                      operand_shape)) {
1367     return InvalidArgument(
1368         "The inputs should have the same element type for "
1369         "batch-norm-inference, "
1370         "but the shape of variance is %s "
1371         "and the shape of operand is %s.",
1372         PrimitiveType_Name(mean_shape.element_type()),
1373         PrimitiveType_Name(variance_shape.element_type()));
1374   }
1375 
1376   const int64 feature_count = operand_shape.dimensions(feature_index);
1377   Shape output_shape_for_mean_and_var =
1378       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1379 
1380   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1381     return InvalidArgument(
1382         "The size of offset factor should be the same as feature count,"
1383         "but the size of offset factor is %d "
1384         "and the feature count is %d.",
1385         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1386   }
1387 
1388   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1389     return InvalidArgument(
1390         "The size of scale factor should be the same as feature count,"
1391         "but the size of scale factor is %d "
1392         "and the feature count is %d.",
1393         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1394   }
1395 
1396   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1397     return InvalidArgument(
1398         "The size of mean should be the same as feature count,"
1399         "but the size of mean is %d "
1400         "and the feature count is %d.",
1401         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1402   }
1403 
1404   if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1405     return InvalidArgument(
1406         "The size of variance should be the same as feature count,"
1407         "but the size of variance is %d "
1408         "and the feature count is %d.",
1409         ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1410   }
1411 
1412   return operand_shape;
1413 }
1414 
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64 feature_index)1415 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1416     const Shape& operand_shape, const Shape& scale_shape,
1417     const Shape& mean_shape, const Shape& var_shape,
1418     const Shape& output_grad_shape, int64 feature_index) {
1419   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1420   TF_RETURN_IF_ERROR(
1421       ExpectArray(scale_shape, "scale input of batch norm grad"));
1422   TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1423   TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1424   TF_RETURN_IF_ERROR(
1425       ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1426 
1427   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1428   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1429   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1430   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1431   TF_RETURN_IF_ERROR(
1432       ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1433 
1434   if (feature_index >= operand_shape.rank()) {
1435     return InvalidArgument(
1436         "Expected feature_index of batch-norm-grad to be "
1437         "smaller than the rank of operand_shape; "
1438         "got feature_index %d, and rank %d.",
1439         feature_index, operand_shape.rank());
1440   }
1441 
1442   if (operand_shape.rank() != output_grad_shape.rank()) {
1443     return InvalidArgument(
1444         "Expected operand_shape of batch-norm-grad to have the same rank as"
1445         " output_grad_shape; got rank(oprand_shape) %d, and"
1446         " rank(output_grad_shape) %d.",
1447         operand_shape.rank(), output_grad_shape.rank());
1448   }
1449 
1450   if (mean_shape.rank() != 1) {
1451     return InvalidArgument(
1452         "Mean input of batch-norm-grad must have"
1453         " rank 1, but has rank %d.",
1454         mean_shape.rank());
1455   }
1456 
1457   if (scale_shape.rank() != 1) {
1458     return InvalidArgument(
1459         "Scale input of batch-norm-grad must have"
1460         " rank 1, but has rank %d.",
1461         scale_shape.rank());
1462   }
1463 
1464   if (var_shape.rank() != 1) {
1465     return InvalidArgument(
1466         "Var input of batch-norm-grad must have"
1467         " rank 1, but has rank %d.",
1468         var_shape.rank());
1469   }
1470 
1471   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1472     return InvalidArgument(
1473         "The operand to batch-norm-grad must have a floating point "
1474         "element type, but the shape is %s.",
1475         PrimitiveType_Name(operand_shape.element_type()));
1476   }
1477 
1478   if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1479     return InvalidArgument(
1480         "The output_grad to batch-norm-grad must have a floating point "
1481         "element type, but the shape is %s.",
1482         PrimitiveType_Name(output_grad_shape.element_type()));
1483   }
1484 
1485   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1486                                                      operand_shape)) {
1487     return InvalidArgument(
1488         "The inputs should have the same element type for batch-norm-grad, "
1489         "but the element type of output_grad is %s "
1490         "and the element type of operand is %s.",
1491         PrimitiveType_Name(output_grad_shape.element_type()),
1492         PrimitiveType_Name(operand_shape.element_type()));
1493   }
1494 
1495   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1496                                                      operand_shape)) {
1497     return InvalidArgument(
1498         "The inputs should have the same element type for batch-norm-grad, "
1499         "but the element type of scale factor is %s "
1500         "and the element type of operand is %s.",
1501         PrimitiveType_Name(scale_shape.element_type()),
1502         PrimitiveType_Name(operand_shape.element_type()));
1503   }
1504 
1505   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1506                                                      operand_shape)) {
1507     return InvalidArgument(
1508         "The inputs should have the same element type for batch-norm-grad, "
1509         "but the element type of mean is %s "
1510         "and the element type of operand is %s.",
1511         PrimitiveType_Name(mean_shape.element_type()),
1512         PrimitiveType_Name(operand_shape.element_type()));
1513   }
1514 
1515   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1516                                                      operand_shape)) {
1517     return InvalidArgument(
1518         "The inputs should have the same element type for batch-norm-grad, "
1519         "but the element type of mean is %s "
1520         "and the element type of operand is %s.",
1521         PrimitiveType_Name(mean_shape.element_type()),
1522         PrimitiveType_Name(operand_shape.element_type()));
1523   }
1524 
1525   const int64 feature_count = operand_shape.dimensions(feature_index);
1526 
1527   Shape feature_shape =
1528       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1529 
1530   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1531     return InvalidArgument(
1532         "The size of mean should be the same as feature count,"
1533         "but the size of offset factor is %d "
1534         "and the feature count is %d.",
1535         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1536   }
1537 
1538   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1539     return InvalidArgument(
1540         "The size of scale factor should be the same as feature count,"
1541         "but the size of scale factor is %d "
1542         "and the feature count is %d.",
1543         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1544   }
1545 
1546   if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1547     return InvalidArgument(
1548         "The size of variance should be the same as feature count,"
1549         "but the size of variance is %d "
1550         "and the feature count is %d.",
1551         ShapeUtil::GetDimension(var_shape, 0), feature_count);
1552   }
1553 
1554   // Verify operand_shape and output_grad_shape have same bounds.
1555   for (int64 i = 0; i < operand_shape.rank(); ++i) {
1556     if (ShapeUtil::GetDimension(operand_shape, i) !=
1557         ShapeUtil::GetDimension(output_grad_shape, i)) {
1558       return InvalidArgument(
1559           "The bounds of operand shape should be the same as output_grad's,"
1560           "but the bound of operand_shape at dimension %d is %d "
1561           "and the bound of output_grad_shape is %d.",
1562           i, ShapeUtil::GetDimension(operand_shape, i),
1563           ShapeUtil::GetDimension(output_grad_shape, i));
1564     }
1565   }
1566 
1567   return ShapeUtil::MakeTupleShape(
1568       {operand_shape, feature_shape, feature_shape});
1569 }
1570 
InferConvolveShape(const Shape & lhs,const Shape & rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dnums)1571 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1572     const Shape& lhs, const Shape& rhs, int64 feature_group_count,
1573     int64 batch_group_count, const Window& window,
1574     const ConvolutionDimensionNumbers& dnums) {
1575   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1576   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1577 
1578   if (feature_group_count <= 0) {
1579     return InvalidArgument(
1580         "feature_group_count must be a positive number, got %d",
1581         feature_group_count);
1582   }
1583 
1584   if (batch_group_count <= 0) {
1585     return InvalidArgument(
1586         "batch_group_count must be a positive number, got %d",
1587         batch_group_count);
1588   }
1589 
1590   if (batch_group_count > 1 && feature_group_count > 1) {
1591     return InvalidArgument(
1592         "both batch_group_count %d and feature_group_count %d cannot be "
1593         "greater than 1",
1594         batch_group_count, feature_group_count);
1595   }
1596 
1597   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
1598     return InvalidArgument(
1599         "Convolution with different element types: %s and %s.",
1600         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
1601   }
1602   if (dnums.input_spatial_dimensions_size() !=
1603       dnums.kernel_spatial_dimensions_size()) {
1604     return InvalidArgument(
1605         "Both arguments to convolution must have same number of dimensions.\n"
1606         "Numbers: %s",
1607         dnums.DebugString());
1608   }
1609 
1610   if (dnums.input_spatial_dimensions_size() !=
1611       dnums.output_spatial_dimensions_size()) {
1612     return InvalidArgument(
1613         "Both input and output of convolution must have same number of "
1614         "dimensions.\nNumbers: %s",
1615         dnums.DebugString());
1616   }
1617 
1618   const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1619   if (window.dimensions_size() != num_spatial_dims) {
1620     return InvalidArgument(
1621         "Window must have same number of dimensions as dimension numbers.\n"
1622         "Window: %s\nDimension numbers: %s.",
1623         window.DebugString(), dnums.DebugString());
1624   }
1625 
1626   const int num_dims = num_spatial_dims + 2;
1627   if (lhs.rank() != num_dims) {
1628     return InvalidArgument(
1629         "The LHS argument to a convolution should have rank %d; lhs: %s.",
1630         num_dims, ShapeUtil::HumanString(lhs));
1631   }
1632   if (rhs.rank() != num_dims) {
1633     return InvalidArgument(
1634         "The RHS argument to a convolution should have rank %d; rhs: %s.",
1635         num_dims, ShapeUtil::HumanString(rhs));
1636   }
1637   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1638   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1639 
1640   // Verifies that the input and window dimensions are a permutation of
1641   // the dimension numbers.
1642   std::vector<int64> input_dnums(num_dims);
1643   input_dnums[0] = dnums.input_batch_dimension();
1644   input_dnums[1] = dnums.input_feature_dimension();
1645   std::copy(dnums.input_spatial_dimensions().begin(),
1646             dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1647   absl::c_sort(input_dnums);
1648 
1649   std::vector<int64> window_dnums(num_dims);
1650   window_dnums[0] = dnums.kernel_input_feature_dimension();
1651   window_dnums[1] = dnums.kernel_output_feature_dimension();
1652   std::copy(dnums.kernel_spatial_dimensions().begin(),
1653             dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1654   absl::c_sort(window_dnums);
1655 
1656   std::vector<int64> output_dnums(num_dims);
1657   output_dnums[0] = dnums.output_batch_dimension();
1658   output_dnums[1] = dnums.output_feature_dimension();
1659   std::copy(dnums.output_spatial_dimensions().begin(),
1660             dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1661   absl::c_sort(output_dnums);
1662 
1663   std::vector<int64> expected_dnums(num_dims);
1664   std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1665 
1666   const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
1667   if (!absl::c_all_of(input_dnums, in_range) ||
1668       !absl::c_all_of(window_dnums, in_range) ||
1669       !absl::c_all_of(output_dnums, in_range)) {
1670     return InvalidArgument(
1671         "A dimension number is out of range in convolution: %s.",
1672         dnums.DebugString());
1673   }
1674 
1675   if (input_dnums != expected_dnums) {
1676     return InvalidArgument(
1677         "Input dimensions of convolution must contain each dimension exactly "
1678         "once: %s.",
1679         dnums.DebugString());
1680   }
1681   if (window_dnums != expected_dnums) {
1682     return InvalidArgument(
1683         "Window dimensions of convolution must contain each dimension exactly "
1684         "once: %s.",
1685         dnums.DebugString());
1686   }
1687   if (output_dnums != expected_dnums) {
1688     return InvalidArgument(
1689         "Output dimensions of convolution must contain each dimension exactly "
1690         "once: %s.",
1691         dnums.DebugString());
1692   }
1693 
1694   std::vector<int64> input_spatial_dims(num_spatial_dims);
1695   for (int i = 0; i < num_spatial_dims; ++i) {
1696     input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1697   }
1698   const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
1699   const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
1700 
1701   std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1702   for (int i = 0; i < num_spatial_dims; ++i) {
1703     kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1704   }
1705   const int64 kernel_input_features =
1706       rhs.dimensions(dnums.kernel_input_feature_dimension());
1707   const int64 kernel_output_features =
1708       rhs.dimensions(dnums.kernel_output_feature_dimension());
1709 
1710   if (batch_group_count > 1 && input_batch % kernel_output_features != 0) {
1711     return InvalidArgument(
1712         "Expected input batch (value %d) to be divisible by output feature "
1713         "dimension size (value %d) for batch group count %d; "
1714         "got <conv>(%s, %s)\n"
1715         "Dimension numbers: {%s}.",
1716         input_batch, kernel_output_features, batch_group_count,
1717         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1718         dnums.DebugString());
1719   }
1720 
1721   if (input_features % feature_group_count != 0 ||
1722       input_features / feature_group_count != kernel_input_features) {
1723     return InvalidArgument(
1724         "Expected LHS feature dimension (value %d) to be a multiple of "
1725         "feature_group_count (value %d), and LHS feature dimension / "
1726         "feature_group_count = RHS feature dimension (value %d); "
1727         "got <conv>(%s, %s)\n"
1728         "Dimension numbers: {%s}.",
1729         input_features, feature_group_count, kernel_input_features,
1730         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1731         dnums.DebugString());
1732   }
1733 
1734   if (kernel_output_features % feature_group_count > 0) {
1735     // A depthwise/grouped filter has the shape
1736     // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1737     // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1738     // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1739     // feature count (which is equal to kernel output features) has to be a
1740     // multiple of feature_group_count.
1741     return InvalidArgument(
1742         "Expected output feature dimension (value %d) to be divisible by "
1743         "feature_group_count (value %d); "
1744         "got <conv>(%s, %s)\n"
1745         "Dimension numbers: {%s}.",
1746         kernel_output_features, feature_group_count,
1747         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1748         dnums.DebugString());
1749   }
1750 
1751   if (input_batch % batch_group_count > 0) {
1752     return InvalidArgument(
1753         "Expected input batch dimension (value %d) to be divisible by "
1754         "batch_group_count (value %d); "
1755         "got <conv>(%s, %s)\n"
1756         "Dimension numbers: {%s}.",
1757         input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1758         ShapeUtil::HumanString(rhs), dnums.DebugString());
1759   }
1760 
1761   std::vector<int64> window_dims(num_spatial_dims);
1762   for (int i = 0; i < num_spatial_dims; ++i) {
1763     window_dims[i] = window.dimensions(i).size();
1764   }
1765   if (kernel_spatial_dims != window_dims) {
1766     return InvalidArgument(
1767         "Window dimensions do not match RHS shape:\n\t"
1768         "RHS shape: %s\n\t"
1769         "Window: {%s}\n\t"
1770         "Dimension numbers: {%s}.",
1771         ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1772         dnums.ShortDebugString());
1773   }
1774 
1775   Shape base_shape =
1776       ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1777   TF_ASSIGN_OR_RETURN(
1778       Shape window_output_shape,
1779       InferWindowOutputShape(base_shape, window, lhs.element_type(),
1780                              /*allow_negative_padding=*/true));
1781 
1782   std::vector<int64> dimensions(num_dims);
1783   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1784   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1785   for (int i = 0; i < num_spatial_dims; ++i) {
1786     dimensions[dnums.output_spatial_dimensions(i)] =
1787         window_output_shape.dimensions(i);
1788   }
1789   std::vector<bool> is_dynamic(num_dims);
1790   for (int i = 0; i < num_dims; i++) {
1791     if (lhs.is_dynamic_dimension(i)) {
1792       if (i == dnums.input_batch_dimension()) {
1793         is_dynamic[dnums.output_batch_dimension()] = true;
1794       } else if (i == dnums.input_feature_dimension()) {
1795         // Input feature dimension is a contracting dimension, which does not
1796         // affect the output dimension size. So we need to do nothing.
1797       } else {
1798         return InvalidArgument(
1799             "Dynamic Spatial Convolution is not supported: lhs shape is %s ",
1800             lhs.ToString());
1801       }
1802     }
1803     if (rhs.is_dynamic_dimension(i)) {
1804       if (i == dnums.kernel_input_feature_dimension()) {
1805         // Kernel feature dimension does not affect the output dimension size.
1806         // So we need to do nothing.
1807       } else {
1808         return InvalidArgument(
1809             "Dynamic Spatial Convolution is not supported: rhs shape is %s ",
1810             rhs.ToString());
1811       }
1812     }
1813   }
1814   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1815                               dimensions, is_dynamic);
1816 }
1817 
// Infers the result shape of an FFT of kind `fft_type` applied to `in`.
//
// `fft_length` gives one transform length per transformed dimension. It
// applies to the innermost `fft_length.size()` dimensions of `in`, and only
// 1 to 3 transformed dimensions are accepted.
//   * FFT / IFFT: `in` must be C64; the shape is returned unchanged.
//   * RFFT: `in` must be F32 and its innermost dimensions must equal
//     `fft_length`. The result is `in` retyped to C64, with the last dimension
//     set to fft_length[last] / 2 + 1.
//   * IRFFT: `in` must be C64. All but the last transformed dimension must
//     equal `fft_length`, and the last must equal fft_length[last] / 2 + 1.
//     The result is ShapeUtil::ComplexComponentShape(in) (presumably F32 —
//     confirm) with the last dimension set to fft_length[last].
// Returns InvalidArgument on any mismatch; an unknown `fft_type` is fatal.
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
    const Shape& in, const FftType fft_type,
    const absl::Span<const int64> fft_length) {
  const int64 fft_rank = fft_length.size();
  if (fft_rank < 1 || fft_rank > 3) {
    return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
  }
// RET_CHECK_RANK(x) returns InvalidArgument from this function when `x` has
// fewer dimensions than the number of transformed dimensions. It is local to
// this function and is #undef'd after the switch.
#define RET_CHECK_RANK(x)                            \
  if (x.dimensions_size() < fft_rank) {              \
    return InvalidArgument(                          \
        "FFT of rank %d requires input of at least " \
        "same rank; got input of rank %d",           \
        fft_rank, x.dimensions_size());              \
  }
  switch (fft_type) {
    case FFT:
    case IFFT:
      if (in.element_type() != C64) {
        return InvalidArgument("%s requires complex input type, found %s.",
                               FftType_Name(fft_type),
                               PrimitiveType_Name(in.element_type()));
      }
      RET_CHECK_RANK(in);
      return in;
    case RFFT: {
      if (in.element_type() != F32) {
        return InvalidArgument("RFFT requires F32 input type, found %s.",
                               PrimitiveType_Name(in.element_type()));
      }
      RET_CHECK_RANK(in);
      // Every transformed (innermost) dimension must match its fft_length.
      for (int i = 0; i < fft_rank; i++) {
        if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
            fft_length[i]) {
          return InvalidArgument(
              "RFFT requires innermost dimensions match fft_length but "
              "dimension %d is %d and should be %d.",
              in.dimensions_size() - fft_rank + i,
              in.dimensions(in.dimensions_size() - fft_rank + i),
              fft_length[i]);
        }
      }
      // NOTE(review): for a zero-element input this returns `in` unchanged,
      // i.e. with element type F32 rather than C64 — confirm this is intended.
      if (ShapeUtil::IsZeroElementArray(in)) {
        return in;
      }
      // The real-to-complex transform keeps only the non-redundant half of
      // the last transformed dimension: n / 2 + 1 entries.
      Shape result = ShapeUtil::ChangeElementType(in, C64);
      result.set_dimensions(result.dimensions_size() - 1,
                            fft_length[fft_rank - 1] / 2 + 1);
      return result;
    }
    case IRFFT: {
      if (in.element_type() != C64) {
        return InvalidArgument("IRFFT requires C64 input type, found %s.",
                               PrimitiveType_Name(in.element_type()));
      }
      RET_CHECK_RANK(in);
      Shape result = ShapeUtil::ComplexComponentShape(in);
      // All transformed dimensions except the last must match fft_length.
      for (int i = 0; i < fft_rank - 1; i++) {
        if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
            fft_length[i]) {
          return InvalidArgument(
              "IRFFT requires all but one innermost dimensions match "
              "fft_length, but dimension %d is %d and should be %d.",
              in.dimensions_size() - fft_rank + i,
              in.dimensions(in.dimensions_size() - fft_rank + i),
              fft_length[i]);
        }
      }
      // The last dimension holds the half-spectrum produced by an RFFT of
      // length fft_length[last].
      if (in.dimensions(in.dimensions_size() - 1) !=
          fft_length[fft_rank - 1] / 2 + 1) {
        return InvalidArgument(
            "IRFFT requires innermost dimension matches fft_length/2+1, but "
            "dimension %d is %d and should be %d.",
            in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
            fft_length[fft_rank - 1] / 2 + 1);
      }
      result.set_dimensions(result.dimensions_size() - 1,
                            fft_length[fft_rank - 1]);
      return result;
    }
    default:
      // Unknown FFT kinds abort the process; there is no error status path.
      LOG(FATAL) << "Unexpected fft_type: " << fft_type;
  }
#undef RET_CHECK_RANK
}
1902 
InferTriangularSolveShape(const Shape & a,const Shape & b,const TriangularSolveOptions & options)1903 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1904     const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1905   if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1906       a.element_type() != b.element_type()) {
1907     return InvalidArgument(
1908         "Expected element types in shape to be floating or complex and "
1909         "identical for TriangularSolve; got %s and %s.",
1910         PrimitiveType_Name(a.element_type()),
1911         PrimitiveType_Name(b.element_type()));
1912   }
1913   if (a.rank() < 2) {
1914     return InvalidArgument(
1915         "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1916         a.ToString());
1917   }
1918   if (b.rank() != a.rank()) {
1919     return InvalidArgument(
1920         "Arguments to triangular solve must have equal rank; got %s and %s.",
1921         b.ToString(), a.ToString());
1922   }
1923   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1924     return InvalidArgument(
1925         "The two minor dimensions of 'a' must have equal size, got %s.",
1926         a.ToString());
1927   }
1928   if (a.dimensions(a.rank() - 1) !=
1929       b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1930     return InvalidArgument(
1931         "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1932         "%s",
1933         a.ToString(), b.ToString());
1934   }
1935   absl::Span<const int64> a_batch_dims(a.dimensions());
1936   absl::Span<const int64> b_batch_dims(b.dimensions());
1937   a_batch_dims.remove_suffix(2);
1938   b_batch_dims.remove_suffix(2);
1939   if (a_batch_dims != b_batch_dims) {
1940     return InvalidArgument(
1941         "The leading batch dimensions of the arguments to triangular solve "
1942         "must be equal; got %s and %s.",
1943         b.ToString(), a.ToString());
1944   }
1945   if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
1946       options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
1947     return InvalidArgument(
1948         "Invalid transpose option value for triangular solve (%d).\n",
1949         options.transpose_a());
1950   }
1951   return b;
1952 }
1953 
InferCholeskyShape(const Shape & a)1954 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
1955     const Shape& a) {
1956   if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
1957     return InvalidArgument(
1958         "Expected element type in shape to be floating or complex for "
1959         "Cholesky; got %s.",
1960         PrimitiveType_Name(a.element_type()));
1961   }
1962   if (a.rank() < 2) {
1963     return InvalidArgument(
1964         "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
1965         a.ToString());
1966   }
1967   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1968     return InvalidArgument(
1969         "The two minor dimensions of 'a' must have equal size, got %s.",
1970         a.ToString());
1971   }
1972   return a;
1973 }
1974 
InferAllReduceShape(absl::Span<const Shape * const> operand_shapes)1975 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
1976     absl::Span<const Shape* const> operand_shapes) {
1977   for (const Shape* operand_shape : operand_shapes) {
1978     TF_RETURN_IF_ERROR(
1979         ExpectArray(*operand_shape, "operand of cross replica sum"));
1980   }
1981   if (operand_shapes.size() == 1) {
1982     return *operand_shapes[0];
1983   }
1984   std::vector<Shape> operand_shape_values;
1985   for (const Shape* operand_shape : operand_shapes) {
1986     operand_shape_values.push_back(*operand_shape);
1987   }
1988   return ShapeUtil::MakeTupleShape(operand_shape_values);
1989 }
1990 
InferAllToAllShape(const Shape & shape,int64 split_dimension,int64 concat_dimension,int64 split_count)1991 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
1992     const Shape& shape, int64 split_dimension, int64 concat_dimension,
1993     int64 split_count) {
1994   TF_RET_CHECK(split_count > 0);
1995   if (split_dimension >= shape.rank() || split_dimension < 0) {
1996     return InvalidArgument(
1997         "AllToAll split_dimension %d is out-of-bounds in shape %s.",
1998         split_dimension, ShapeUtil::HumanString(shape));
1999   }
2000   if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2001     return InvalidArgument(
2002         "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2003         concat_dimension, ShapeUtil::HumanString(shape));
2004   }
2005   if (shape.dimensions(split_dimension) % split_count != 0) {
2006     return InvalidArgument(
2007         "AllToAll split dimension size %d must be dividable by split_count "
2008         "%d.",
2009         shape.dimensions(split_dimension), split_count);
2010   }
2011   std::vector<int64> new_dimensions(shape.dimensions().begin(),
2012                                     shape.dimensions().end());
2013   new_dimensions[split_dimension] /= split_count;
2014   new_dimensions[concat_dimension] *= split_count;
2015   return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2016 }
2017 
InferAllToAllTupleShape(absl::Span<const Shape * const> operand_shapes)2018 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2019     absl::Span<const Shape* const> operand_shapes) {
2020   // An Alltoall HLO instruction receives N operands (with the same shape) and
2021   // returns a tuple that contains N array shapes.
2022   TF_RET_CHECK(!operand_shapes.empty());
2023   for (int i = 0; i < operand_shapes.size(); i++) {
2024     if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) {
2025       return InvalidArgument(
2026           "HLO all-to-all has operands with different shapes: the 0th "
2027           "operand shape %s, but the %dth operand has shape %s.",
2028           ShapeUtil::HumanString(*operand_shapes[0]), i,
2029           ShapeUtil::HumanString(*operand_shapes[i]));
2030     }
2031   }
2032 
2033   return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2034 }
2035 
// Infers the result shape of a collective-permute: the operand shape is
// returned unchanged. A non-array operand fails the TF_RET_CHECK below, which
// returns an error status.
/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
    const Shape& shape) {
  TF_RET_CHECK(shape.IsArray());
  return shape;
}
2041 
InferReduceShape(absl::Span<const Shape * const> arg_shapes,absl::Span<const int64> dimensions_to_reduce,const ProgramShape & to_apply)2042 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2043     absl::Span<const Shape* const> arg_shapes,
2044     absl::Span<const int64> dimensions_to_reduce,
2045     const ProgramShape& to_apply) {
2046   if (arg_shapes.empty()) {
2047     return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2048   }
2049   if (arg_shapes.size() % 2) {
2050     return InvalidArgument(
2051         "Reduce must have an even number of arguments, has %lu",
2052         arg_shapes.size());
2053   }
2054   int64 num_reduced_args = arg_shapes.size() / 2;
2055 
2056   auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2057   // Check that all of the reduced tensors have the same dimensions. The element
2058   // types may be different.
2059   for (int64 i = 1; i < num_reduced_args; ++i) {
2060     if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2061       return InvalidArgument(
2062           "All reduced tensors must have the same dimension. Tensor 0 has "
2063           "shape %s, Tensor %d has shape %s",
2064           ShapeUtil::HumanString(*reduced_args[0]), i,
2065           ShapeUtil::HumanString(*reduced_args[i]));
2066     }
2067   }
2068 
2069   // Check that the dimensions to reduce are in-bounds for the given shape.
2070   // We've already verified all reduced tensors have the same dimensions, so it
2071   // doesn't matter which one we choose.
2072   const Shape& arg = *reduced_args[0];
2073   for (int64 dimension : dimensions_to_reduce) {
2074     if (dimension >= arg.rank() || dimension < 0) {
2075       return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2076                              dimension, ShapeUtil::HumanString(arg));
2077     }
2078   }
2079 
2080   auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2081   std::vector<PrimitiveType> element_types;
2082   for (const Shape* arg : reduced_args) {
2083     element_types.push_back(arg->element_type());
2084   }
2085   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2086                                         num_reduced_args));
2087 
2088   std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
2089                                            dimensions_to_reduce.end());
2090   std::vector<int64> new_dimensions;
2091   std::vector<bool> new_is_dynamic;
2092   for (int i = 0; i < arg.rank(); ++i) {
2093     if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2094       new_dimensions.push_back(arg.dimensions(i));
2095       new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2096     }
2097   }
2098 
2099   if (ShapeUtil::IsScalar(to_apply.result())) {
2100     return ShapeUtil::MakeShape(to_apply.result().element_type(),
2101                                 new_dimensions, new_is_dynamic);
2102   } else {
2103     std::vector<Shape> result_subshapes;
2104     for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2105       result_subshapes.push_back(ShapeUtil::MakeShape(
2106           subshape.element_type(), new_dimensions, new_is_dynamic));
2107     }
2108     return ShapeUtil::MakeTupleShape(result_subshapes);
2109   }
2110 }
2111 
InferReduceWindowShape(const Shape & operand_shape,const Shape & init_value_shape,const Window & window,const ProgramShape & to_apply_shape)2112 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2113     const Shape& operand_shape, const Shape& init_value_shape,
2114     const Window& window, const ProgramShape& to_apply_shape) {
2115   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2116   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2117                                         {operand_shape.element_type()},
2118                                         /*inputs=*/1));
2119   return InferWindowOutputShape(operand_shape, window,
2120                                 init_value_shape.element_type(),
2121                                 /*allow_negative_padding=*/false);
2122 }
2123 
InferSelectAndScatterShape(const Shape & operand_shape,const ProgramShape & select_shape,const Window & window,const Shape & source_shape,const Shape & init_value_shape,const ProgramShape & scatter_shape)2124 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2125     const Shape& operand_shape, const ProgramShape& select_shape,
2126     const Window& window, const Shape& source_shape,
2127     const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2128   TF_RETURN_IF_ERROR(
2129       ExpectArray(operand_shape, "operand of select-and-scatter"));
2130 
2131   // Check if the select function has a proper shape of (T,T) -> PRED.
2132   if (select_shape.parameters_size() != 2) {
2133     return InvalidArgument(
2134         "Select function must take 2 parameters, but "
2135         "takes %d parameter(s).",
2136         select_shape.parameters_size());
2137   }
2138   const Shape& select_result_shape = select_shape.result();
2139   if (!ShapeUtil::Compatible(select_result_shape,
2140                              ShapeUtil::MakeShape(PRED, {}))) {
2141     return InvalidArgument("Select function must have rank-0 PRED result.");
2142   }
2143   const Shape& operand_element_shape =
2144       ShapeUtil::MakeShape(operand_shape.element_type(), {});
2145   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2146                                                 select_shape.parameters(0))) {
2147     return InvalidArgument(
2148         "Select function's first parameter shape currently must "
2149         "match the operand element shape, but got %s vs %s.",
2150         ShapeUtil::HumanString(select_shape.parameters(0)),
2151         ShapeUtil::HumanString(operand_element_shape));
2152   }
2153   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2154                                                 select_shape.parameters(1))) {
2155     return InvalidArgument(
2156         "Select function's second parameter shape currently must "
2157         "match the operand element shape, but got %s vs %s.",
2158         ShapeUtil::HumanString(select_shape.parameters(1)),
2159         ShapeUtil::HumanString(operand_element_shape));
2160   }
2161 
2162   // Check if the scatter function has a proper shape as a reduction.
2163   TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2164                                         {source_shape.element_type()},
2165                                         /*inputs=*/1));
2166 
2167   // Check if the result shape of window operation matches the source shape.
2168   TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2169                       InferWindowOutputShape(operand_shape, window,
2170                                              operand_shape.element_type(),
2171                                              /*allow_negative_padding=*/false));
2172   if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2173                                                 window_result_shape)) {
2174     return InvalidArgument(
2175         "Source shape does not match the shape of window-reduced operand: "
2176         "source(%s), window-reduced operand(%s).",
2177         ShapeUtil::HumanString(source_shape),
2178         ShapeUtil::HumanString(window_result_shape));
2179   }
2180 
2181   return operand_shape;
2182 }
2183 
InferGetDimensionSizeShape(const Shape & shape,int64 dimension)2184 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2185     const Shape& shape, int64 dimension) {
2186   if (dimension < 0 || dimension >= shape.rank()) {
2187     return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2188                            dimension);
2189   }
2190 
2191   // TODO(b/119580730): Remove this restriction when very large dimension size
2192   // is needed.
2193   if (shape.dimensions(dimension) > std::numeric_limits<uint32>::max()) {
2194     return InvalidArgument(
2195         "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2196         "UINT_MAX limit.",
2197         ShapeUtil::HumanString(shape), dimension);
2198   }
2199 
2200   return ShapeUtil::MakeShape(U32, {});
2201 }
2202 
InferSliceShape(const Shape & arg,absl::Span<const int64> starts,absl::Span<const int64> limits,absl::Span<const int64> strides)2203 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2204     const Shape& arg, absl::Span<const int64> starts,
2205     absl::Span<const int64> limits, absl::Span<const int64> strides) {
2206   auto error = [&](const string& message) {
2207     return InvalidArgument(
2208         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2209         "{%s}; strides: {%s}.",
2210         message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2211         StrJoin(limits, ","), StrJoin(strides, ","));
2212   };
2213   TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2214   VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2215                        ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2216                        StrJoin(limits, ", "));
2217 
2218   if (starts.size() != limits.size()) {
2219     return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2220                            starts.size(), limits.size()));
2221   }
2222 
2223   if (starts.size() != strides.size()) {
2224     return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2225                            starts.size(), strides.size()));
2226   }
2227 
2228   if (starts.size() != arg.rank()) {
2229     return InvalidArgument(
2230         "Slice index count does not match argument rank: %u vs %d.",
2231         starts.size(), arg.rank());
2232   }
2233 
2234   std::vector<int64> sizes;
2235   for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
2236     int64 start_index = starts[dimension];
2237     int64 limit_index = limits[dimension];
2238     int64 stride = strides[dimension];
2239     if (start_index < 0) {
2240       return InvalidArgument("Negative start index to slice: %d.", start_index);
2241     }
2242     if (limit_index > arg.dimensions(dimension)) {
2243       return error(
2244           StrFormat("limit index (%d) must be less than or equal to dimension "
2245                     "size (%d)",
2246                     limit_index, arg.dimensions(dimension)));
2247     }
2248     VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2249     VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2250     if (start_index > limit_index) {
2251       return error(
2252           StrFormat("limit index (%d) must be greater or equal to "
2253                     "start index (%d) in slice with positive stride",
2254                     limit_index, start_index));
2255     }
2256     if (stride <= 0) {
2257       return InvalidArgument("Stride (%d) must be positive.", stride);
2258     }
2259     sizes.push_back((limit_index - start_index + stride - 1) / stride);
2260   }
2261 
2262   return ShapeUtil::MakeShape(arg.element_type(), sizes);
2263 }
2264 
InferDynamicSliceShape(const Shape & operand_shape,absl::Span<const Shape> start_index_shapes,absl::Span<const int64> slice_sizes,bool allow_scalar_indices)2265 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2266     const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2267     absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2268   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2269   auto number_of_indices = start_index_shapes.size();
2270   // TODO(b/118437727): Remove this path.
2271   if (!allow_scalar_indices ||
2272       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2273     if (number_of_indices != 1) {
2274       return InvalidArgument(
2275           "Dynamic slice should have exactly 1 index operand, has %d.",
2276           number_of_indices);
2277     }
2278 
2279     const Shape& start_indices_shape = start_index_shapes[0];
2280     VLOG(2) << StrFormat(
2281         "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2282         ShapeUtil::HumanString(operand_shape),
2283         ShapeUtil::HumanString(start_indices_shape),
2284         StrJoin(slice_sizes, ", "));
2285 
2286     TF_RETURN_IF_ERROR(
2287         ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2288 
2289     if (start_indices_shape.rank() != 1) {
2290       return InvalidArgument(
2291           "Dynamic slice start indices of rank %d must be rank1.",
2292           start_indices_shape.rank());
2293     }
2294 
2295     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2296       return InvalidArgument(
2297           "Dynamic slice start indices must be of integral type.");
2298     }
2299 
2300     const int64 start_num_dims = start_indices_shape.dimensions(0);
2301     if (operand_shape.rank() != start_num_dims) {
2302       return InvalidArgument(
2303           "Dynamic slice start number of dimensions %d (%s) must match rank "
2304           "%d of slice input (%s).",
2305           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2306           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2307     }
2308   } else {
2309     VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
2310                          ShapeUtil::HumanString(operand_shape),
2311                          StrJoin(slice_sizes, ", "));
2312 
2313     if (operand_shape.rank() != number_of_indices) {
2314       return InvalidArgument(
2315           "Dynamic slice start number of dimensions %d must match rank "
2316           "%d of slice input (%s).",
2317           number_of_indices, operand_shape.rank(),
2318           ShapeUtil::HumanString(operand_shape));
2319     }
2320 
2321     if (number_of_indices > 0) {
2322       const Shape& first_index_shape = start_index_shapes[0];
2323       if (!ShapeUtil::IsScalar(first_index_shape)) {
2324         return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2325                                ShapeUtil::HumanString(first_index_shape));
2326       }
2327       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2328         return InvalidArgument(
2329             "Dynamic slice start indices must be of integral type.");
2330       }
2331       for (const Shape& index_shape : start_index_shapes) {
2332         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2333           return InvalidArgument(
2334               "Dynamic slice start indices must all have the same shape, got "
2335               "mismatching indices with shapes %s and %s.",
2336               ShapeUtil::HumanString(first_index_shape),
2337               ShapeUtil::HumanString(index_shape));
2338         }
2339       }
2340     }
2341   }
2342 
2343   if (slice_sizes.size() != operand_shape.rank()) {
2344     return InvalidArgument(
2345         "Dynamic slice index count does not match argument rank: %u vs %d.",
2346         slice_sizes.size(), operand_shape.rank());
2347   }
2348 
2349   for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2350     const int64 input_dim_size = operand_shape.dimensions(dim);
2351     const int64 slice_dim_size = slice_sizes[dim];
2352     if (slice_dim_size < 0) {
2353       return InvalidArgument("Negative size index to dynamic slice: %d.",
2354                              slice_dim_size);
2355     }
2356     if (slice_dim_size > input_dim_size) {
2357       return InvalidArgument(
2358           "Slice dim size %d greater than dynamic slice dimension: %d.",
2359           slice_dim_size, input_dim_size);
2360     }
2361     VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2362   }
2363 
2364   return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2365 }
2366 
InferDynamicUpdateSliceShape(const Shape & operand_shape,const Shape & update_shape,absl::Span<const Shape> start_index_shapes,bool allow_scalar_indices)2367 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2368     const Shape& operand_shape, const Shape& update_shape,
2369     absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2370   TF_RETURN_IF_ERROR(
2371       ExpectArray(operand_shape, "operand of dynamic update slice"));
2372   TF_RETURN_IF_ERROR(
2373       ExpectArray(update_shape, "update of dynamic update slice"));
2374 
2375   auto number_of_indices = start_index_shapes.size();
2376   // TODO(b/118437727): Remove this path.
2377   if (!allow_scalar_indices ||
2378       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2379     if (number_of_indices != 1) {
2380       return InvalidArgument(
2381           "Dynamic update slice should have exactly 1 index operand, has %d.",
2382           number_of_indices);
2383     }
2384     const Shape& start_indices_shape = start_index_shapes[0];
2385     TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2386                                    "start indices of dynamic update slice"));
2387 
2388     VLOG(2) << StrFormat(
2389         "updating slice of shape %s at dynamic start_indices %s with update "
2390         "shape %s",
2391         ShapeUtil::HumanString(operand_shape),
2392         ShapeUtil::HumanString(start_indices_shape),
2393         ShapeUtil::HumanString(update_shape));
2394 
2395     if (start_indices_shape.rank() != 1) {
2396       return InvalidArgument(
2397           "Dynamic update slice start indices of rank %d must be rank1.",
2398           start_indices_shape.rank());
2399     }
2400 
2401     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2402       return InvalidArgument(
2403           "Dynamic update slice start indices must be of integral type.");
2404     }
2405 
2406     const int64 start_num_dims = start_indices_shape.dimensions(0);
2407     if (operand_shape.rank() != start_num_dims) {
2408       return InvalidArgument(
2409           "Dynamic update slice start number of dimensions %d (%s) must match "
2410           "rank %d of slice input (%s).",
2411           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2412           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2413     }
2414   } else {
2415     VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2416                          ShapeUtil::HumanString(operand_shape),
2417                          ShapeUtil::HumanString(update_shape));
2418 
2419     if (operand_shape.rank() != number_of_indices) {
2420       return InvalidArgument(
2421           "Dynamic update slice start number of dimensions %d must match "
2422           "rank %d of slice input (%s).",
2423           number_of_indices, operand_shape.rank(),
2424           ShapeUtil::HumanString(operand_shape));
2425     }
2426 
2427     if (number_of_indices > 0) {
2428       const Shape& first_index_shape = start_index_shapes[0];
2429       if (!ShapeUtil::IsScalar(first_index_shape)) {
2430         return InvalidArgument(
2431             "Dynamic update slice indices must be scalar, not %s.",
2432             ShapeUtil::HumanString(first_index_shape));
2433       }
2434       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2435         return InvalidArgument(
2436             "Dynamic update slice start indices must be of integral type.");
2437       }
2438       for (const Shape& index_shape : start_index_shapes) {
2439         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2440           return InvalidArgument(
2441               "Dynamic update slice start indices must all have the same "
2442               "shape, got mismatching indices with shapes %s and %s.",
2443               ShapeUtil::HumanString(first_index_shape),
2444               ShapeUtil::HumanString(index_shape));
2445         }
2446       }
2447     }
2448   }
2449 
2450   if (update_shape.rank() != operand_shape.rank()) {
2451     return InvalidArgument(
2452         "Dynamic update slice update rank does not match argument rank: "
2453         "%d vs %d.",
2454         update_shape.rank(), operand_shape.rank());
2455   }
2456 
2457   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2458                                                      update_shape)) {
2459     return InvalidArgument(
2460         "Dynamic update slice update element type does not match argument. "
2461         "operand.element_type: %s vs update.element_type: %s.",
2462         PrimitiveType_Name(operand_shape.element_type()),
2463         PrimitiveType_Name(update_shape.element_type()));
2464   }
2465 
2466   for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
2467     const int64 input_dim_size = operand_shape.dimensions(dim);
2468     const int64 update_dim_size = update_shape.dimensions(dim);
2469     if (update_dim_size < 0) {
2470       return InvalidArgument(
2471           "Size index %d to dynamic update slice must be >= 0.",
2472           update_dim_size);
2473     }
2474     if (update_dim_size > input_dim_size) {
2475       return InvalidArgument(
2476           "Update dim size %d greater than dynamic slice dimension: %d.",
2477           update_dim_size, input_dim_size);
2478     }
2479     VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2480   }
2481 
2482   return operand_shape;
2483 }
2484 
InferReverseShape(const Shape & operand_shape,absl::Span<const int64> dimensions)2485 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2486     const Shape& operand_shape, absl::Span<const int64> dimensions) {
2487   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2488   if (!AllUnique(dimensions)) {
2489     return InvalidArgument("a dimension number is duplicated in reverse");
2490   }
2491   for (int64 dimension : dimensions) {
2492     if (dimension >= operand_shape.rank() || dimension < 0) {
2493       return InvalidArgument(
2494           "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2495           dimension, ShapeUtil::HumanString(operand_shape));
2496     }
2497   }
2498   return operand_shape;
2499 }
2500 
InferGetTupleElementShape(const Shape & arg,int64 index)2501 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2502     const Shape& arg, int64 index) {
2503   if (!arg.IsTuple()) {
2504     return InvalidArgument(
2505         "Cannot infer shape: attempting to index into non-tuple: %s.",
2506         ShapeUtil::HumanString(arg));
2507   }
2508 
2509   if (index < 0 || index >= arg.tuple_shapes_size()) {
2510     return InvalidArgument(
2511         "Cannot infer shape: attempt to index out of tuple bounds: %d "
2512         ">= %d in shape %s.",
2513         index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2514   }
2515 
2516   return arg.tuple_shapes(index);
2517 }
2518 
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2519 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2520     const ProgramShape& condition, const ProgramShape& body,
2521     const Shape& init) {
2522   // Check the number of parameters for given computations.
2523   if (condition.parameters_size() != 1) {
2524     return InvalidArgument("Condition must take 1 arguments; got %d.",
2525                            condition.parameters_size());
2526   }
2527   if (body.parameters_size() != 1) {
2528     return InvalidArgument("Body must take 1 arguments; got %d.",
2529                            body.parameters_size());
2530   }
2531 
2532   auto shape_string = [&]() {
2533     return StrFormat(
2534         "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2535         ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2536   };
2537 
2538   // Check the shapes of computation parameters and return types.
2539   if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) {
2540     return InvalidArgument("Condition must return a boolean; got %s.",
2541                            shape_string());
2542   }
2543   if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2544       !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2545       !ShapeUtil::Compatible(body.result(), init)) {
2546     return InvalidArgument(
2547         "The parameter of condition and body, the result of the body, and init "
2548         "must all have the same shape; got %s.",
2549         shape_string());
2550   }
2551 
2552   return init;
2553 }
2554 
InferConditionalShape(const Shape & branch_index,absl::Span<const ProgramShape> branch_computations,absl::Span<const Shape> branch_operands)2555 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2556     const Shape& branch_index,
2557     absl::Span<const ProgramShape> branch_computations,
2558     absl::Span<const Shape> branch_operands) {
2559   if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2560       !ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2561     return InvalidArgument("branch_index must be bool or int32; got %s.",
2562                            ShapeUtil::HumanString(branch_index));
2563   }
2564   if (branch_index.element_type() == PRED) {
2565     TF_RET_CHECK(2 == branch_computations.size());
2566   } else {
2567     TF_RET_CHECK(!branch_computations.empty());
2568   }
2569   TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2570 
2571   for (int j = 0; j < branch_computations.size(); ++j) {
2572     if (branch_computations[j].parameters_size() != 1) {
2573       return InvalidArgument(
2574           "branch computation %d must take 1 argument; got %d.", j,
2575           branch_computations[j].parameters_size());
2576     }
2577     if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2578                                branch_operands[j])) {
2579       auto shape_string = [&]() {
2580         return StrFormat("operand: %s; computation: %s",
2581                          ShapeUtil::HumanString(branch_operands[j]),
2582                          ShapeUtil::HumanString(branch_computations[j]));
2583       };
2584       return InvalidArgument(
2585           "branch operand %d must match the shape of the only parameter of "
2586           "branch computation %d: got %s.",
2587           j, j, shape_string());
2588     }
2589 
2590     if (!ShapeUtil::Compatible(branch_computations[0].result(),
2591                                branch_computations[j].result())) {
2592       auto shape_string = [&]() {
2593         return StrFormat(
2594             "branch 0 computation result: %s; branch %d computation result: %s",
2595             ShapeUtil::HumanString(branch_computations[0].result()), j,
2596             ShapeUtil::HumanString(branch_computations[j].result()));
2597       };
2598       return InvalidArgument(
2599           "the result of branch 0 computation and branch %d computation must "
2600           "have the same shape: got %s.",
2601           j, shape_string());
2602     }
2603   }
2604   return branch_computations[0].result();
2605 }
2606 
InferBroadcastShape(const Shape & operand,absl::Span<const int64> broadcast_sizes)2607 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2608     const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2609   TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2610   for (int64 size : broadcast_sizes) {
2611     if (size < 0) {
2612       return InvalidArgument("Broadcast with negative dimension size %d.",
2613                              size);
2614     }
2615   }
2616 
2617   std::vector<int64> dimensions(operand.dimensions_size() +
2618                                 broadcast_sizes.size());
2619   std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2620   std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2621             dimensions.begin() + broadcast_sizes.size());
2622   return ShapeUtil::MakeShape(operand.element_type(), dimensions);
2623 }
2624 
InferBroadcastShape(const Shape & operand_shape,const Shape & output_shape,absl::Span<const int64> broadcast_dimensions)2625 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2626     const Shape& operand_shape, const Shape& output_shape,
2627     absl::Span<const int64> broadcast_dimensions) {
2628   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
2629   TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
2630   const int64 operand_rank = operand_shape.rank();
2631   const int64 output_rank = output_shape.rank();
2632   if (operand_rank > output_rank) {
2633     return InvalidArgument(
2634         "InDim style broadcast must be to an equal or higher ranked shape; "
2635         "operand rank: %lld; output rank: %lld",
2636         operand_rank, output_rank);
2637   }
2638   if (operand_rank != broadcast_dimensions.size()) {
2639     return InvalidArgument(
2640         "Size of broadcast_dimensions has to match operand's rank; operand "
2641         "rank: %lld, size of broadcast_dimensions %u.",
2642         operand_rank, broadcast_dimensions.size());
2643   }
2644   for (int64 i = 0; i < operand_rank; i++) {
2645     if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
2646       return InvalidArgument("Broadcast dimension %lld is out of bound",
2647                              broadcast_dimensions[i]);
2648     }
2649     if (operand_shape.dimensions(i) !=
2650             output_shape.dimensions(broadcast_dimensions[i]) &&
2651         operand_shape.dimensions(i) != 1) {
2652       return InvalidArgument(
2653           "Input dimension should be either 1 or equal to the output dimension "
2654           "it is broadcasting into; the %lldth operand dimension is %lld, the "
2655           "%lldth output dimension is %lld.",
2656           i, operand_shape.dimensions(i), broadcast_dimensions[i],
2657           output_shape.dimensions(broadcast_dimensions[i]));
2658     }
2659     if (operand_shape.is_dynamic_dimension(i) !=
2660         output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
2661       return InvalidArgument(
2662           "Broadcast input and output dynamism mismatch: %s and %s",
2663           operand_shape.ToString(), output_shape.ToString());
2664     }
2665     // Make sure the broadcast dimensions are listed in a strictly increasing
2666     // order.
2667     if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
2668       return InvalidArgument(
2669           "Broadcast dimensions order is wrong: %d comes after %d.",
2670           broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
2671     }
2672   }
2673 
2674   return output_shape;
2675 }
2676 
InferReshapeShape(const Shape & operand,absl::Span<const int64> dimensions,absl::Span<const int64> new_sizes)2677 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2678     const Shape& operand, absl::Span<const int64> dimensions,
2679     absl::Span<const int64> new_sizes) {
2680   TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
2681 
2682   Shape inferred_shape =
2683       ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2684   VLOG(3) << "Reshape inferred shape: "
2685           << ShapeUtil::HumanString(inferred_shape);
2686 
2687   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2688     return InvalidArgument(
2689         "Reshape operation has mismatched element counts: from=%d (%s) "
2690         "to=%d (%s).",
2691         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2692         ShapeUtil::ElementsIn(inferred_shape),
2693         ShapeUtil::HumanString(inferred_shape));
2694   }
2695 
2696   std::vector<int64> indices(operand.rank());
2697   std::iota(indices.begin(), indices.end(), 0);
2698   if (dimensions.size() != operand.rank() ||
2699       !std::is_permutation(dimensions.begin(), dimensions.end(),
2700                            indices.begin())) {
2701     return InvalidArgument(
2702         "Reshape dimensions [%s] are not a permutation of the operand "
2703         "dimensions (operand shape is %s).",
2704         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2705   }
2706 
2707   std::vector<std::pair<int64, int64>> unmodified_dims =
2708       ShapeUtil::DimensionsUnmodifiedByReshape(operand, inferred_shape);
2709   for (auto& unmodified : unmodified_dims) {
2710     if (operand.is_dynamic_dimension(unmodified.first)) {
2711       inferred_shape.set_dynamic_dimension(unmodified.second, true);
2712     }
2713   }
2714 
2715   return inferred_shape;
2716 }
2717 
InferTransposeShape(const Shape & operand,absl::Span<const int64> dimensions)2718 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
2719     const Shape& operand, absl::Span<const int64> dimensions) {
2720   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
2721 
2722   if (!IsPermutation(dimensions, operand.rank())) {
2723     return InvalidArgument(
2724         "Transpose dimensions [%s] are not a permutation of the operand "
2725         "dimensions (operand shape is %s).",
2726         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2727   }
2728 
2729   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
2730   // we need output[i]=input[dimensions[i]] which is
2731   // Permute(Inverse(dimensions),input).
2732   return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
2733 }
2734 
2735 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2736 // "degenerate" cases, as with binary elementwise ops.
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)2737 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
2738     const Shape& min, const Shape& operand, const Shape& max) {
2739   TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
2740   TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
2741   TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
2742   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
2743       !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
2744     return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
2745                            ShapeUtil::HumanString(min),
2746                            ShapeUtil::HumanString(operand),
2747                            ShapeUtil::HumanString(max));
2748   }
2749   if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
2750         ShapeUtil::IsScalar(min)) &&
2751        (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
2752         ShapeUtil::IsScalar(max)))) {
2753     return operand;
2754   }
2755   if (ShapeUtil::IsScalar(operand)) {
2756     if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
2757       return ShapeUtil::ChangeElementType(min, operand.element_type());
2758     } else if (ShapeUtil::IsScalar(min)) {
2759       return ShapeUtil::ChangeElementType(max, operand.element_type());
2760     } else if (ShapeUtil::IsScalar(max)) {
2761       return ShapeUtil::ChangeElementType(min, operand.element_type());
2762     }
2763   }
2764   return Unimplemented("%s, %s <clamp> %s is not implemented.",
2765                        min.ShortDebugString(), max.ShortDebugString(),
2766                        operand.ShortDebugString());
2767 }
2768 
2769 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2770 // "degenerate" cases, as with binary elementwise ops, as well as scalar
2771 // broadcast from all operands, not just the predicate.
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)2772 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
2773     const Shape& pred, const Shape& on_true, const Shape& on_false) {
2774   if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
2775     return InvalidArgument(
2776         "Operands to select must be the same shape; got %s and %s.",
2777         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
2778   }
2779   if (pred.element_type() != PRED) {
2780     return InvalidArgument(
2781         "Select's pred operand must have PRED element type; got %s.",
2782         ShapeUtil::HumanString(pred));
2783   }
2784   if (Shape::Equal()
2785           .IgnoreElementType()
2786           .IgnoreLayout()
2787           .IgnoreDynamicDimension()(pred, on_true) ||
2788       ShapeUtil::IsScalar(pred)) {
2789     // By this stage we know that pred's element type is PRED. Therefore, this
2790     // check restricts pred to be a PRED scalar, or a PRED array with the same
2791     // dimensions as on_true and on_false.
2792     Shape inferred_shape = ShapeUtil::ChangeElementType(
2793         on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
2794 
2795     // Propagate dynamic dimensions if pred is not a scalar.
2796     if (!ShapeUtil::IsScalar(pred)) {
2797       for (int i = 0; i < inferred_shape.rank(); i++) {
2798         if (pred.is_dynamic_dimension(i)) {
2799           inferred_shape.set_dynamic_dimension(i, true);
2800         }
2801       }
2802     }
2803     return inferred_shape;
2804   }
2805   return InvalidArgument(
2806       "Select operation with non-scalar predicate with dimensionality "
2807       "different from the other operands: %s.",
2808       ShapeUtil::HumanString(pred));
2809 }
2810 
InferTupleSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)2811 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
2812     const Shape& pred, const Shape& on_true, const Shape& on_false) {
2813   // Select only defines the top-level buffer, so if it's a tuple, the two
2814   // input must match exactly.
2815   if (!ShapeUtil::Compatible(on_true, on_false)) {
2816     return InvalidArgument(
2817         "Operands to tuple-select must be the same shape; got %s and %s.",
2818         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
2819   }
2820   if (pred.element_type() != PRED) {
2821     return InvalidArgument(
2822         "TupleSelect's pred operand must have PRED element type; got %s.",
2823         ShapeUtil::HumanString(pred));
2824   }
2825   if (!ShapeUtil::IsScalar(pred)) {
2826     return InvalidArgument(
2827         "TupleSelect operation with non-scalar predicate: %s.",
2828         ShapeUtil::HumanString(pred));
2829   }
2830   return on_true;
2831 }
2832 
InferCallShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply)2833 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
2834     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
2835   // The applied function's arity equals the number of arguments.
2836   if (arg_shapes.size() != to_apply.parameters_size()) {
2837     string computation_signature = ShapeUtil::HumanString(to_apply);
2838     string argument_shapes =
2839         StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
2840           absl::StrAppend(out, ShapeUtil::HumanString(*shape));
2841         });
2842     return InvalidArgument(
2843         "Call applied function arity must match number of arguments; got: "
2844         "arity: %d, arguments: %u; computation signature: %s; argument "
2845         "shapes: [%s].",
2846         to_apply.parameters_size(), arg_shapes.size(), computation_signature,
2847         argument_shapes);
2848   }
2849 
2850   // All arguments must be compatible with the program shape.
2851   for (int i = 0; i < arg_shapes.size(); ++i) {
2852     const Shape& arg_shape = *arg_shapes[i];
2853     const Shape& param_shape = to_apply.parameters(i);
2854     if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
2855       return InvalidArgument(
2856           "Call parameter must match argument; got parameter %d shape: %s, "
2857           "argument shape: %s.",
2858           i, ShapeUtil::HumanString(param_shape),
2859           ShapeUtil::HumanString(arg_shape));
2860     }
2861   }
2862 
2863   return to_apply.result();
2864 }
2865 
ValidateGatherDimensionNumbers(const Shape & input_shape,absl::Span<const int64> start_indices_shape,const GatherDimensionNumbers & dim_numbers)2866 static Status ValidateGatherDimensionNumbers(
2867     const Shape& input_shape, absl::Span<const int64> start_indices_shape,
2868     const GatherDimensionNumbers& dim_numbers) {
2869   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
2870     return InvalidArgument(
2871         "Output window dimensions in gather op must be ascending; got: %s.",
2872         StrJoin(dim_numbers.offset_dims(), ", "));
2873   }
2874 
2875   if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
2876       dim_numbers.offset_dims().end()) {
2877     return InvalidArgument(
2878         "Output window dimensions in gather op must not repeat; got: %s.",
2879         StrJoin(dim_numbers.offset_dims(), ", "));
2880   }
2881 
2882   const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
2883   const int64 output_shape_rank =
2884       output_offset_dim_count + start_indices_shape.size() - 1;
2885 
2886   for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
2887     int64 offset_dim = dim_numbers.offset_dims(i);
2888     if (offset_dim < 0 || offset_dim >= output_shape_rank) {
2889       return InvalidArgument(
2890           "Offset dimension %d in gather op is out of bounds; got %d, but "
2891           "should "
2892           "have been in [0,%d).",
2893           i, offset_dim, output_shape_rank);
2894     }
2895   }
2896 
2897   if (dim_numbers.start_index_map_size() !=
2898       start_indices_shape[dim_numbers.index_vector_dim()]) {
2899     return InvalidArgument(
2900         "Gather op has %d elements in start_index_map and the "
2901         "bound of dimension index_vector_dim=%d of start_indices is "
2902         "%d. These two numbers must be equal.",
2903         dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
2904         start_indices_shape[dim_numbers.index_vector_dim()]);
2905   }
2906 
2907   for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
2908     int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
2909     if (operand_dim_for_start_index_i < 0 ||
2910         operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
2911       return InvalidArgument(
2912           "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
2913           input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
2914     }
2915   }
2916 
2917   std::vector<int64> sorted_start_index_map(
2918       dim_numbers.start_index_map().begin(),
2919       dim_numbers.start_index_map().end());
2920 
2921   absl::c_sort(sorted_start_index_map);
2922 
2923   if (absl::c_adjacent_find(sorted_start_index_map) !=
2924       sorted_start_index_map.end()) {
2925     return InvalidArgument(
2926         "Repeated dimensions are not allowed in start_index_map; "
2927         "got: %s.",
2928         StrJoin(dim_numbers.start_index_map(), ", "));
2929   }
2930 
2931   for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
2932     if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
2933       return InvalidArgument(
2934           "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
2935           "%d), got: %d.",
2936           input_shape.dimensions_size(), collapsed_dim);
2937     }
2938   }
2939 
2940   if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
2941     return InvalidArgument(
2942         "collapsed_slice_dims in gather op must be sorted; got: %s",
2943         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
2944   }
2945 
2946   if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
2947       dim_numbers.collapsed_slice_dims().end()) {
2948     return InvalidArgument(
2949         "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
2950         "got: %s.",
2951         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
2952   }
2953 
2954   return Status::OK();
2955 }
2956 
InferGatherShape(const Shape & input_shape,const Shape & start_indices_shape,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes)2957 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
2958     const Shape& input_shape, const Shape& start_indices_shape,
2959     const GatherDimensionNumbers& gather_dim_numbers,
2960     absl::Span<const int64> slice_sizes) {
2961   TF_RETURN_IF_ERROR(
2962       ExpectArray(input_shape, "input tensor operand gather op"));
2963   TF_RETURN_IF_ERROR(
2964       ExpectArray(start_indices_shape, "gather indices operand of gather op"));
2965 
2966   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2967     return InvalidArgument(
2968         "Gather indices parameter must be an integral tensor; got %s.",
2969         ShapeUtil::HumanString(start_indices_shape));
2970   }
2971 
2972   // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
2973   // index_vector_dim is rank(P).  The bounds of this expanded shape is
2974   // stored in expanded_start_indices_shape.
2975 
2976   if (start_indices_shape.dimensions_size() <
2977           gather_dim_numbers.index_vector_dim() ||
2978       gather_dim_numbers.index_vector_dim() < 0) {
2979     return InvalidArgument(
2980         "Gather index leaf dimension must be within [0, rank(start_indices) + "
2981         "1). rank(start_indices) is %d and gather index leaf dimension is "
2982         "%d.",
2983         start_indices_shape.dimensions_size(),
2984         gather_dim_numbers.index_vector_dim());
2985   }
2986 
2987   std::vector<int64> expanded_start_indices_shape;
2988   expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
2989   absl::c_copy(start_indices_shape.dimensions(),
2990                std::back_inserter(expanded_start_indices_shape));
2991   if (expanded_start_indices_shape.size() ==
2992       gather_dim_numbers.index_vector_dim()) {
2993     expanded_start_indices_shape.push_back(1);
2994   }
2995 
2996   TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
2997       input_shape, expanded_start_indices_shape, gather_dim_numbers));
2998 
2999   if (slice_sizes.size() != input_shape.dimensions_size()) {
3000     return InvalidArgument(
3001         "Gather op must have one slice size for every input dimension; got: "
3002         "len(slice_sizes)=%lu, input_shape.rank=%d.",
3003         slice_sizes.size(), input_shape.dimensions_size());
3004   }
3005 
3006   if (slice_sizes.size() !=
3007       gather_dim_numbers.offset_dims_size() +
3008           gather_dim_numbers.collapsed_slice_dims_size()) {
3009     return InvalidArgument(
3010         "All components of the offset index in a gather op must either be a "
3011         "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3012         "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3013         slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3014         StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3015   }
3016 
3017   for (int i = 0; i < slice_sizes.size(); i++) {
3018     int64 slice_size = slice_sizes[i];
3019     int64 corresponding_input_size = input_shape.dimensions(i);
3020     if (slice_size < 0 || slice_size > corresponding_input_size) {
3021       return InvalidArgument(
3022           "Slice size at index %d in gather op is out of range, must be "
3023           "within [0, %d), got %d.",
3024           i, corresponding_input_size + 1, slice_size);
3025     }
3026   }
3027 
3028   for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3029     if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
3030       return InvalidArgument(
3031           "Gather op can only collapse slice dims with bound 1, but bound is "
3032           "%d for index %d at position %d.",
3033           slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3034           gather_dim_numbers.collapsed_slice_dims(i), i);
3035     }
3036   }
3037 
3038   int64 result_rank = gather_dim_numbers.offset_dims_size() +
3039                       (expanded_start_indices_shape.size() - 1);
3040   int64 offset_dims_seen = 0;
3041   int64 gather_dims_seen = 0;
3042   std::vector<int64> output_dim_bounds;
3043   output_dim_bounds.reserve(result_rank);
3044   for (int64 i = 0; i < result_rank; i++) {
3045     int64 current_bound;
3046     bool is_window_index =
3047         absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3048     if (is_window_index) {
3049       while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3050                                    offset_dims_seen)) {
3051         offset_dims_seen++;
3052       }
3053       current_bound = slice_sizes[offset_dims_seen++];
3054     } else {
3055       if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3056         gather_dims_seen++;
3057       }
3058       current_bound = expanded_start_indices_shape[gather_dims_seen++];
3059     }
3060 
3061     output_dim_bounds.push_back(current_bound);
3062   }
3063 
3064   return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
3065 }
3066 
3067 namespace {
3068 
ValidateScatterDimensionNumbers(const Shape & operand_shape,absl::Span<const int64> scatter_indices_shape,const Shape & updates_shape,const ScatterDimensionNumbers & dim_numbers)3069 Status ValidateScatterDimensionNumbers(
3070     const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3071     const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3072   // Validate update_window_dims in ScatterDimensionNumbers.
3073   if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3074     return InvalidArgument(
3075         "update_window_dims in scatter op must be sorted; got: %s.",
3076         StrJoin(dim_numbers.update_window_dims(), ", "));
3077   }
3078   if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3079       dim_numbers.update_window_dims().end()) {
3080     return InvalidArgument(
3081         "update_window_dims in scatter op must not repeat; got: %s.",
3082         StrJoin(dim_numbers.update_window_dims(), ", "));
3083   }
3084   const int64 updates_rank = updates_shape.rank();
3085   for (int64 window_dim : dim_numbers.update_window_dims()) {
3086     if (window_dim < 0 || window_dim >= updates_rank) {
3087       return InvalidArgument(
3088           "Invalid update_window_dims set in scatter op; valid range is [0, "
3089           "%d). got: %d.",
3090           updates_rank, window_dim);
3091     }
3092   }
3093 
3094   // Validate inserted_window_dims in ScatterDimensionNumbers.
3095   if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3096     return InvalidArgument(
3097         "inserted_window_dims in scatter op must be sorted; got: %s.",
3098         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3099   }
3100   if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3101       dim_numbers.inserted_window_dims().end()) {
3102     return InvalidArgument(
3103         "inserted_window_dims in scatter op must not repeat; got: %s.",
3104         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3105   }
3106   for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
3107     if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3108       return InvalidArgument(
3109           "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3110           "%d), got: %d.",
3111           operand_shape.dimensions_size(), inserted_dim);
3112     }
3113   }
3114 
3115   // Validate window size.
3116   auto window_size = dim_numbers.update_window_dims_size() +
3117                      dim_numbers.inserted_window_dims_size();
3118   if (window_size != operand_shape.rank()) {
3119     return InvalidArgument(
3120         "Scatter op has window of size %d; doesn't match operand of rank %d.",
3121         window_size, operand_shape.rank());
3122   }
3123 
3124   // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
3125   if (dim_numbers.scatter_dims_to_operand_dims_size() !=
3126       scatter_indices_shape[dim_numbers.index_vector_dim()]) {
3127     return InvalidArgument(
3128         "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
3129         "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
3130         "These two numbers must be equal.",
3131         dim_numbers.scatter_dims_to_operand_dims_size(),
3132         dim_numbers.index_vector_dim(),
3133         scatter_indices_shape[dim_numbers.index_vector_dim()]);
3134   }
3135   for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
3136     int64 scatter_dim_to_operand_dim =
3137         dim_numbers.scatter_dims_to_operand_dims(i);
3138     if (scatter_dim_to_operand_dim < 0 ||
3139         scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
3140       return InvalidArgument(
3141           "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
3142           "got: %d->%d.",
3143           operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
3144     }
3145   }
3146   std::vector<int64> sorted_scatter_dims_to_operand_dims(
3147       dim_numbers.scatter_dims_to_operand_dims().begin(),
3148       dim_numbers.scatter_dims_to_operand_dims().end());
3149   absl::c_sort(sorted_scatter_dims_to_operand_dims);
3150   if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
3151       sorted_scatter_dims_to_operand_dims.end()) {
3152     return InvalidArgument(
3153         "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
3154         "got: %s.",
3155         StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
3156   }
3157 
3158   return Status::OK();
3159 }
3160 
3161 }  // namespace
3162 
InferScatterShape(const Shape & operand_shape,const Shape & scatter_indices_shape,const Shape & updates_shape,const ProgramShape & to_apply_shape,const ScatterDimensionNumbers & scatter_dim_numbers)3163 /*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
3164     const Shape& operand_shape, const Shape& scatter_indices_shape,
3165     const Shape& updates_shape, const ProgramShape& to_apply_shape,
3166     const ScatterDimensionNumbers& scatter_dim_numbers) {
3167   TF_RETURN_IF_ERROR(
3168       ExpectArray(operand_shape, "operand tensor of scatter op"));
3169   TF_RETURN_IF_ERROR(
3170       ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
3171   TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
3172 
3173   if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
3174     return InvalidArgument(
3175         "Scatter indices parameter must be an integral tensor; got %s.",
3176         ShapeUtil::HumanString(scatter_indices_shape));
3177   }
3178 
3179   if (scatter_indices_shape.dimensions_size() <
3180           scatter_dim_numbers.index_vector_dim() ||
3181       scatter_dim_numbers.index_vector_dim() < 0) {
3182     return InvalidArgument(
3183         "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
3184         " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
3185         "is %d.",
3186         scatter_indices_shape.dimensions_size(),
3187         scatter_dim_numbers.index_vector_dim());
3188   }
3189 
3190   // Check if the update computation has a proper shape as a reduction.
3191   const Shape init_value_shape =
3192       ShapeUtil::MakeShape(operand_shape.element_type(), {});
3193   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
3194                                         {updates_shape.element_type()},
3195                                         /*inputs=*/1));
3196 
3197   std::vector<int64> expanded_scatter_indices_shape =
3198       ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
3199   if (expanded_scatter_indices_shape.size() ==
3200       scatter_dim_numbers.index_vector_dim()) {
3201     expanded_scatter_indices_shape.push_back(1);
3202   }
3203 
3204   int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
3205                                 scatter_dim_numbers.update_window_dims_size();
3206   if (updates_shape.rank() != expected_updates_rank) {
3207     return InvalidArgument("Updates tensor must be of rank %d; got %d.",
3208                            expected_updates_rank, updates_shape.rank());
3209   }
3210 
3211   TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
3212       operand_shape, expanded_scatter_indices_shape, updates_shape,
3213       scatter_dim_numbers));
3214 
3215   int64 inserted_dims_seen = 0;
3216   std::vector<int64> max_update_slice_sizes;
3217   for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
3218     if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
3219         scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
3220       ++inserted_dims_seen;
3221     } else {
3222       max_update_slice_sizes.push_back(operand_shape.dimensions(i));
3223     }
3224   }
3225   for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
3226     auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
3227     if (updates_shape.dimensions(update_window_dim) >
3228         max_update_slice_sizes[i]) {
3229       return InvalidArgument(
3230           "Bounds of the window dimensions of updates must not exceed the "
3231           "bounds of the corresponding dimensions of operand. For dimension "
3232           "%d, updates bound is %d, operand bound is %d.",
3233           update_window_dim, updates_shape.dimensions(update_window_dim),
3234           max_update_slice_sizes[i]);
3235     }
3236   }
3237 
3238   int64 scatter_dims_seen = 0;
3239   for (int64 i = 0; i < updates_shape.rank(); ++i) {
3240     bool is_update_window_dim =
3241         absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
3242     if (is_update_window_dim) {
3243       continue;
3244     }
3245     if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
3246       ++scatter_dims_seen;
3247     }
3248     if (updates_shape.dimensions(i) !=
3249         expanded_scatter_indices_shape[scatter_dims_seen]) {
3250       return InvalidArgument(
3251           "Bounds of the scatter dimensions of updates must be same as the "
3252           "bounds of the corresponding dimensions of scatter indices. For "
3253           "scatter dimension %d, updates bound is %d, scatter_indices "
3254           "bound is %d.",
3255           i, updates_shape.dimensions(i),
3256           expanded_scatter_indices_shape[scatter_dims_seen]);
3257     }
3258     ++scatter_dims_seen;
3259   }
3260 
3261   return operand_shape;
3262 }
3263 
3264 }  // namespace xla
3265