/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect to XLA dialect.

#include <cctype>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace mhlo {
namespace {

constexpr char kShardingAttr[] = "mhlo.sharding";

class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
                    shape::ShapeDialect, StandardOpsDialect>();
  }

 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF &) {}
  explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
                      llvm::Optional<StringRef> tf2xla_fallback_device_type) {
    allow_partial_conversion_ = allow_partial_conversion;
    legalize_chlo_ = legalize_chlo;
    use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
    if (tf2xla_fallback_device_type.hasValue()) {
      device_type_ = tf2xla_fallback_device_type.getValue().str();
    }
  }

  /// Performs the lowering to XLA dialect.
  void runOnFunction() override;

 private:
  Option<bool> allow_partial_conversion_{
      *this, "allow-partial-conversion",
      llvm::cl::desc("Allow operations that can't be legalized."),
      llvm::cl::init(false)};
  Option<bool> legalize_chlo_{
      *this, "legalize-chlo",
      llvm::cl::desc(
          "Also legalizes intermediate chlo ops to hlo (default true)"),
      llvm::cl::init(true)};
  Option<bool> use_tf2xla_fallback_{
      *this, "use-tf2xla-fallback",
      llvm::cl::desc(
          "Also use TF2XLA fallback for legalization (default false)"),
      llvm::cl::init(false)};
  Option<std::string> device_type_{
      *this, "device-type",
      llvm::cl::desc(
          "The device type used by TF2XLA fallback. Must be specified if "
          "use-tf2xla-fallback is true, otherwise not used."),
      llvm::cl::init("INVALID_DEVICE_TYPE")};
};

/// Returns if the given TF data format string is the default format.
static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; }

/// Returns the feature dimension for the given format and input type.
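/// For example, with the default "NHWC" format and a rank-4 input this returns
/// 3 (the trailing dimension); for any other format (e.g. "NCHW") it returns 1.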
static size_t GetFeatureDimension(StringRef format,
                                  RankedTensorType inputType) {
  return IsDefaultDataFormat(format) ? inputType.getRank() - 1 : 1;
}

// Gets all integer values from the given attribute and pushes them to `values`.
void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
  auto array_attr = attr.cast<ArrayAttr>();
  values->reserve(array_attr.getValue().size());
  for (Attribute val : array_attr.getValue())
    values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}

// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Converts an ArrayAttr to a 1D 64-bit dense elements attribute.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
  RankedTensorType ty =
      RankedTensorType::get(static_cast<int64_t>(attr.size()),
                            IntegerType::get(attr.getContext(), 64));
  return DenseIntElementsAttr::get(ty, attr.getValue());
}

// Returns 1D 32-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns a 1-d i64 elements attribute populated with numbers from `start` to
// `end` (exclusive of `end`).
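// For example, GetI64ElementsAttrForSeq(2, 5, &b) produces
// dense<[2, 3, 4]> : tensor<3xi64>.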
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Returns a 1-d i64 elements attribute populated with `val` repeated `size`
// times.
static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val,
                                                       Builder *builder) {
  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, val);
}

// Returns the corresponding type that should be used for performing sum
// accumulation over the given input type.
Type GetSumAccumulationType(Type input_type) {
  MLIRContext *ctx = input_type.getContext();
  if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
  if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
    return IntegerType::get(ctx, 32);
  return input_type;
}

// Returns the axis in HLO format, taken from a TF elements attr with exactly
// one element or from an IntegerAttr holding the axis in TensorFlow format.
// TensorFlow format supports negative indexing, unlike HLO.
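// For example, a TensorFlow axis of -1 on a rank-4 tensor maps to HLO axis 3.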
static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank,
                                        Builder *b) {
  IntegerAttr intAttr = attr.dyn_cast_or_null<IntegerAttr>();
  if (auto elementAttr = attr.dyn_cast_or_null<ElementsAttr>()) {
    SmallVector<uint64_t, 1> index(elementAttr.getType().getRank(), 0);
    intAttr = elementAttr.getValue<IntegerAttr>(index);
  }

  assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis");

  int64_t axis = intAttr.getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}

// If `value` is an IntegerAttr, returns the integer value for the HLO axis
// corresponding to the tensorflow axis. In particular, the tensorflow axis can
// be negative, in which case, the corresponding HLO axis is
// (axis + rank-of-the-tensor).
static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
                                                           int64_t rank) {
  DenseIntElementsAttr attrs;
  if (!matchPattern(value, m_Constant(&attrs)) ||
      attrs.getType().getRank() != 0) {
    return llvm::None;
  }
  int64_t axis = attrs.getValue<IntegerAttr>({}).getInt();
  return axis < 0 ? axis + rank : axis;
}

/// Returns a `ConvertOp` that casts the elements to an i64 type while
/// retaining the shape of the input value.
static ConvertOp CastValueToI64(Location loc, Value value,
                                PatternRewriter *rewriter) {
  return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}

// Creates an unpack op along the 0th dimension of the tensor. The `value` input
// must be a ranked tensor.
static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
                                             PatternRewriter *rewriter) {
  auto indices_type = value.getType().cast<RankedTensorType>();
  int num_outputs = indices_type.getShape().front();
  SmallVector<Type, 2> unpacked_indices_type(
      num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
  auto unpacked_indices = rewriter->create<TF::UnpackOp>(
      loc, unpacked_indices_type, value,
      IntegerAttr::get(rewriter->getIntegerType(64), 0));
  return unpacked_indices;
}

// Returns size of dimension at the specified index, if ranked tensor.
// Otherwise, returns -1.
//
// Aborts if the type is ranked but doesn't have the dimension.
int64_t GetDimSize(Type ty, int64_t index) {
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return -1;

  return ranked_ty.getDimSize(index);
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
  return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
      sizes.begin(), sizes.end()));
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(
    llvm::iterator_range<DenseElementsAttr::ElementIterator<T>> sizes) {
  return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
      sizes.begin(), sizes.end()));
}

// Returns int, float, or complex scalar DenseElementsAttr attribute with the
// given element type and the value.
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
                                    OpBuilder *builder) {
  return builder->create<ConstOp>(loc, hlo::GetScalarOfType(ty, raw_value));
}

// Returns a limit scalar const op for the given type.
// Requires FloatType or IntegerType
static ConstOp GetScalarLimitConstOfType(Type ty, Location loc,
                                         hlo::ScalarLimit limit,
                                         OpBuilder *builder) {
  return builder->create<ConstOp>(loc, hlo::GetScalarLimitOfType(ty, limit));
}

// Creates an mhlo::SliceOp where the major dimensions have full size, and
// the minor dimensions have the provided offsets and sizes.
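// For example, slicing a [8x16x16] value with minor_starts = [2, 4] and
// minor_limits = [10, 12] produces the slice [0:8, 2:10, 4:12].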
static Value SliceInMinorDims(Location loc, Value v,
                              ArrayRef<int64_t> minor_starts,
                              ArrayRef<int64_t> minor_limits,
                              OpBuilder *builder) {
  auto type = v.getType().cast<RankedTensorType>();
  llvm::SmallVector<int64_t, 4> slice_starts(type.getRank(), 0);
  int64_t major_dims = type.getRank() - minor_starts.size();
  std::copy(minor_starts.begin(), minor_starts.end(),
            slice_starts.begin() + major_dims);
  auto slice_limits = llvm::to_vector<4>(type.getShape());
  std::copy(minor_limits.begin(), minor_limits.end(),
            slice_limits.begin() + major_dims);
  llvm::SmallVector<int64_t, 4> slice_strides(type.getRank(), 1);
  return builder->create<SliceOp>(loc, v,
                                  GetI64ElementsAttr(slice_starts, builder),
                                  GetI64ElementsAttr(slice_limits, builder),
                                  GetI64ElementsAttr(slice_strides, builder));
}

// Creates a vector of index values:
//  [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]]
// with length `rank`.
static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
    Location loc, ArrayRef<Value> minor_indices, int64_t rank,
    OpBuilder *builder) {
  auto zero =
      GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()),
                           loc, 0, builder)
          .output();
  llvm::SmallVector<Value, 4> indices(rank, zero);
  std::copy(minor_indices.begin(), minor_indices.end(),
            indices.begin() + (rank - minor_indices.size()));
  return indices;
}

// Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
// and the minor dimensions have the provided offsets and sizes.
static Value DynamicSliceInMinorDims(Location loc, Value v,
                                     ArrayRef<Value> minor_starts,
                                     ArrayRef<int64_t> minor_sizes,
                                     OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto slice_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  int64_t major_dims = type.getRank() - minor_starts.size();
  auto slice_sizes = llvm::to_vector<4>(type.getShape());
  std::copy(minor_sizes.begin(), minor_sizes.end(),
            slice_sizes.begin() + major_dims);
  auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
  return builder->create<mhlo::DynamicSliceOp>(
      loc, slice_type, v, slice_starts,
      GetI64ElementsAttr(slice_sizes, builder));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided offsets.
static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
                                           ArrayRef<Value> minor_starts,
                                           OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto dus_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  return builder->create<DynamicUpdateSliceOp>(loc, type, v, update,
                                               llvm::makeArrayRef(dus_starts));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided static offsets.
static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
                                    ArrayRef<int64_t> minor_starts,
                                    OpBuilder *builder) {
  llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
  for (uint64_t i = 0; i < minor_starts.size(); ++i) {
    dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
                                         minor_starts[i], builder);
  }
  return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Gets the resulting type from a broadcast between two types for statically
// shaped types. This is to be used for legacy lowerings that both use non
// left-padded broadcasting and static shapes. Its use should not be permitted
// in new code.
// May return nullptr on invalid static broadcast dimensions.
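// For example, broadcasting tensor<4x3xf32> against tensor<3xf32> with
// broadcast_dimensions = [1] yields tensor<4x3xf32>.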
// ABSL_DEPRECATED()
static RankedTensorType GetStaticBroadcastType(
    RankedTensorType x, RankedTensorType y,
    DenseIntElementsAttr broadcast_dimensions_attr) {
  auto element_type = x.getElementType();
  auto shape_x = x.getShape();
  auto shape_y = y.getShape();

  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0; i < shape_x.size(); i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      out_shape[i] = std::max(x_val, y_val);
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  // Explicit broadcast dimensions.
  for (const APInt &int_value : broadcast_dimensions_attr) {
    broadcast_dimensions.push_back(int_value.getSExtValue());
  }
  if (broadcast_dimensions.size() != shape_small.size()) {
    return nullptr;
  }
  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    out_shape[index_pair.value()] = std::max(old_value, new_value);
  }
  return RankedTensorType::get(out_shape, element_type);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Applies static binary broadcasting to a binary elementwise op.
// This is a legacy helper to provide general broadcasting support in legacy,
// static shaped code that relies on non-left-padded broadcasting semantics.
template <typename BinaryOp>
static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
                                   DenseIntElementsAttr broadcast_dims,
                                   OpBuilder &builder) {
  auto x_type = x.getType().cast<RankedTensorType>();
  auto y_type = y.getType().cast<RankedTensorType>();
  auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
  if (!result_type) {
    emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
                   << " with broadcast_dims = " << broadcast_dims;
    return nullptr;
  }
  auto larger_broadcast_dims =
      GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
  if (x_type.getRank() < y_type.getRank()) {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y,
                                           larger_broadcast_dims);
    }
  } else {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x,
                                           larger_broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
    }
  }
  return builder.create<BinaryOp>(loc, x, y);
}

// Gets a 1D tensor type suitable for expressing extents of the given tensor
// value type. If the value type is ranked, the result will be statically
// shaped. Otherwise, it will have a dynamic dimension.
static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
  Builder b(value_type.getContext());
  int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
  return RankedTensorType::get({dim}, b.getIndexType());
}

// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value'
// by assuming that the shape of the lower ranked is a broadcast compatible
// prefix of the higher ranked.
// Values must be RankedTensorType (this restriction derives from the
// broadcast_dimensions attribute on DynamicBroadcastInDim).
//
// Example:
//   CommonPrefixBroadcast(tensor<4x3x256>, tensor<4, 3>) will broadcast the
//   lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style
//   implicit broadcasting).
static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value,
                                   Value lower_rank_value, OpBuilder &builder) {
  Value higher_rank_shape =
      builder.create<shape::ShapeOfOp>(loc, higher_rank_value);
  auto result_extents_type =
      GetExtentsTensorTypeFor(higher_rank_value.getType().cast<TensorType>());
  Value result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, higher_rank_shape);

  auto lower_rank_type = lower_rank_value.getType().cast<RankedTensorType>();
  auto lower_rank = lower_rank_type.getRank();
  auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, higher_rank_value.getType(), lower_rank_value, result_extents,
      prefix_dims);
}

// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
// value (broadcast_from) along that feature dimension. This is a shortcut
// for the cases where a 1D tensor must be broadcast along a specific feature
// dimension, which can vary based on data layout, etc.
//
// The extent of `broadcast_from` dim0 must be equal to the extent of the
// feature_dim of `broadcast_to`.
//
// Example:
//   [1x2x3x4], [2], 1 -> [1x2x3x4]
// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
// consistency. Possibly also rename for clarity.
static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
                                     Value broadcast_from, int64_t feature_dim,
                                     OpBuilder &builder) {
  auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
  auto to_type = broadcast_to.getType().cast<RankedTensorType>();
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, broadcast_from, result_extents, broadcast_dims);
}

// Broadcasts `input` to the shape of `broadcast_to` value following
// TF::BroadcastTo semantics.
//
// Requires that input is a ranked tensor.
//
// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp
// supports unranked inputs in the lowering.
static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
                                OpBuilder &builder) {
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto to_type = broadcast_to.getType().cast<TensorType>();
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  int64_t rank = input.getType().cast<RankedTensorType>().getRank();
  auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, input, result_extents, broadcast_dims);
}

// Creates a batch dot using mhlo::DotGeneralOp.
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
               bool transpose_rhs, int64_t num_batch_dims,
               ArrayAttr precision_config, OpBuilder *builder) {
  auto batch_dimensions = GetI64ElementsAttr(
      llvm::to_vector<4>(llvm::seq<int64_t>(0, num_batch_dims)), builder);
  auto lhs_contracting_dimensions = GetI64ElementsAttr(
      llvm::makeArrayRef({transpose_lhs ? num_batch_dims : num_batch_dims + 1}),
      builder);
  auto rhs_contracting_dimensions = GetI64ElementsAttr(
      llvm::makeArrayRef({transpose_rhs ? num_batch_dims + 1 : num_batch_dims}),
      builder);
  auto dimension_numbers = DotDimensionNumbers::get(
      /*lhs_batching_dimensions=*/batch_dimensions,
      /*rhs_batching_dimensions=*/batch_dimensions,
      /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
      /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
      builder->getContext());
  auto lhs_shape = lhs.getType().cast<RankedTensorType>().getShape();
  auto rhs_shape = rhs.getType().cast<RankedTensorType>().getShape();
  auto shape = llvm::to_vector<4>(lhs_shape);
  shape[shape.size() - 2] =
      transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2];
  shape[shape.size() - 1] =
      transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back();
  Type element_type = getElementTypeOrSelf(lhs.getType());
  return builder->create<DotGeneralOp>(
      loc, RankedTensorType::get(shape, element_type), lhs, rhs,
      dimension_numbers, precision_config);
}

// Builds body for reduce op by using the template binary op as the
// reducer op.
template <typename Op>
static void BuildReduceBody(Type element_type, Region *body,
                            OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  Block *block = builder->createBlock(body);

  // Block arguments are scalars of the given element type.
  Type type = RankedTensorType::get(/*shape=*/{}, element_type);
  block->addArguments({type, type});

  Location loc = body->getLoc();
  auto reducer =
      builder->create<Op>(loc, block->getArgument(0), block->getArgument(1));
  builder->create<ReturnOp>(loc, reducer.getResult());
}

// Builds region taking two arguments and returning second argument as the
// result. Corresponds to the function f(x, y) = y.
// Used in Scatter op's computation to update specific elements.
static void BuildBinaryAssignmentRegion(Type element_type, Region *region,
                                        OpBuilder *builder) {}

// Builds a set of operations for applying reduction on the input value. A
// tf.sum op is created and will be legalized to HLO ops automatically.
static Value ApplyReduction(Location loc, Value input,
                            DenseIntElementsAttr reduce_dims,
                            OpBuilder *builder) {
  auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
                                    builder->getBoolAttr(false));
}

// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements`
// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
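// For example, CreateRngUniform32(loc, 5, 0, 10, &b) produces a tensor<5xi32>
// whose elements are drawn uniformly from [0, 10).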
static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
                                             int lower_limit, int upper_limit,
                                             OpBuilder *builder) {
  auto i32_type = builder->getIntegerType(32);
  auto key_type = RankedTensorType::get({num_elements}, i32_type);
  auto shape_tensor = builder->create<mhlo::ConstOp>(
      loc, GetI64ElementsAttr({num_elements}, builder));

  auto lower = builder->create<mhlo::ConstOp>(
      loc, builder->getI32IntegerAttr(lower_limit));
  auto upper = builder->create<mhlo::ConstOp>(
      loc, builder->getI32IntegerAttr(upper_limit));

  return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
                                             shape_tensor);
}

using WhileBodyFnType = llvm::function_ref<void(
    Location loc, Value iteration, ArrayRef<Value> old_values,
    SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;

// Creates a mhlo.while op with `builder` to loop `num_iterations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
//
// This effectively does:
//
// ```c++
// SmallVector<Values, 4> old_values = init_values;
// SmallVector<Values, 4> new_values;
// for (int i = 0; i < num_iterations; ++i) {
//   body_fn(old_values, &new_values, ...);
//   old_values = new_values;
// }
// ```
//
// Under the hood an induction variable is prepended to values to control the
// number of iterations, but that is transparent to `body_fn`, which does not
// need to care about that.
static void CreateWhile32(Location loc, int num_iterations,
                          WhileBodyFnType body_fn, ArrayRef<Value> init_values,
                          SmallVectorImpl<Value> *final_values,
                          OpBuilder *builder) {
  int value_count = init_values.size() + 1;

  // Prepend a loop induction variable to the initial values.
  SmallVector<Value, 2> init_values_with_loop_iv;
  init_values_with_loop_iv.reserve(value_count);
  // The initial value for the loop induction variable is 0.
  init_values_with_loop_iv.push_back(
      builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
  init_values_with_loop_iv.append(init_values.begin(), init_values.end());

  // Prepare the initial tuple for the while op.
  auto init_tuple =
      builder->create<mhlo::TupleOp>(loc, init_values_with_loop_iv);
  auto tuple_type = init_tuple.getType();

  // Create the while op.
  auto while_op = builder->create<mhlo::WhileOp>(loc, init_tuple);

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the condition region. It should take one
    // argument of the loop's tuple type.
    Region &condition = while_op.cond();
    Block *block = builder->createBlock(&condition);
    BlockArgument arg = block->addArgument(tuple_type);

    // Get the loop induction variable and compare it against the upper limit.
    auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
    auto upper_limit = builder->create<mhlo::ConstOp>(
        loc, builder->getI32IntegerAttr(num_iterations));
    StringAttr compare_direction = StringAttr::get(builder->getContext(), "LT");
    Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
                                                     compare_direction);

    builder->create<mhlo::ReturnOp>(loc, compare);
  }

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the body region. It should take one
    // argument of the loop's tuple type.
    Region &body = while_op.body();
    Block *block = builder->createBlock(&body);
    BlockArgument arg = block->addArgument(tuple_type);

    SmallVector<Value, 4> old_values;  // From the previous iteration
    SmallVector<Value, 4> new_values;  // Generated by this iteration
    old_values.reserve(value_count);
    new_values.reserve(value_count);

    // Unpack the tuple value from the last iteration.
    for (int i = 0; i < value_count; ++i)
      old_values.push_back(builder->create<GetTupleElementOp>(loc, arg, i));

    // Feed all values excluding the loop induction variable to body_fn.
    body_fn(loc, old_values[0], llvm::makeArrayRef(old_values).drop_front(),
            &new_values, builder);

    // Increment the loop induction variable by one.
    auto one =
        builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
    auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
    auto plus_one = builder->create<chlo::BroadcastAddOp>(
        loc, old_values[0], one, scalar_broadcast_dims);
    // Prepend with the updated loop induction variable.
    new_values.insert(new_values.begin(), plus_one);

    Value updated_tuple = builder->create<mhlo::TupleOp>(loc, new_values);

    builder->create<mhlo::ReturnOp>(loc, updated_tuple);
  }

  final_values->reserve(init_values.size());
  for (int i = 0, e = init_values.size(); i < e; ++i)
    final_values->push_back(
        builder->create<GetTupleElementOp>(loc, while_op, i + 1));
}

//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//

static IntegerAttr getFeatureDimensionAttr(Builder &b, StringRef format,
                                           Value input) {
  return b.getI64IntegerAttr(
      GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// FFT op utilities.
//===----------------------------------------------------------------------===//
// Returns the 1D i64 elements attribute populated with the inner-most dim of
// the value.
static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
                                                 Builder *builder) {
  if (type.getRank() == 0) {
    return builder->getI64TensorAttr({});
  }
  return builder->getI64TensorAttr(type.getShape().back());
}

// Returns true if the inner-most dim is static.
bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
  if (!type.hasRank()) {
    return false;
  }
  return !type.isDynamicDim(type.getShape().size() - 1);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//

// If the 'transpose' attribute is true returns ElementsAttr to transpose 2D
// matrix. Otherwise, returns ElementsAttr for identity transpose.
static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
  if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
  return GetI64ElementsAttr({0, 1}, b);
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op utilities.
//===----------------------------------------------------------------------===//

// Gets the size of the dimension `dim_from_end` from the end of `input`.
// Requires that `input` is a tensor.
static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
  // Note: the verifier enforces that `input` is a ranked tensor.
  auto input_type = input.getType().cast<TensorType>();
  auto input_shape = input_type.getShape();
  int dim = (input_shape.size() - 1) - dim_from_end;
  return input_shape[dim];
}

// Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1`
// have the same size as the last two dimensions of `input` (the second-to-last
// dimension and last dimension, respectively). The element type of the
// outputted RankedTensorType will match the element type of `num_lower`.
// Requires that `input` is a tensor.
static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
  // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
  int dim_0 = GetDimensionSizeFromEnd(input, 1);
  int dim_1 = GetDimensionSizeFromEnd(input, 0);
  auto element_type = num_lower.getType().cast<TensorType>().getElementType();
  return RankedTensorType::get({dim_0, dim_1}, element_type);
}

// Creates a HLO ConvertOp, converting `input` to have the same element type as
// `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor.
static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
                             Value elem_type_tensor) {
  auto element_type =
      elem_type_tensor.getType().cast<TensorType>().getElementType();
  return builder->create<mhlo::ConvertOp>(loc, input, element_type);
}

//===----------------------------------------------------------------------===//
// Pad op utilities.
//===----------------------------------------------------------------------===//

// Slices input attribute of rank two and returns the specified column.
//
// Always returns 64 bit integer attribute regardless of bitwidth of the input
// attribute.
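// For example, for a [3x2] paddings attribute [[1, 2], [3, 4], [5, 6]],
// column 1 yields the attribute [2, 4, 6].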
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
    ElementsAttr input, int column) {
  auto int_attr = input.cast<DenseIntElementsAttr>();
  auto shaped_type = int_attr.getType();
  auto shape = shaped_type.getShape();

  if (shape.size() != 2) return DenseIntElementsAttr();

  llvm::SmallVector<int64_t, 4> values;
  values.reserve(shaped_type.getNumElements() / shape[1]);

  for (auto it : llvm::enumerate(int_attr.getIntValues())) {
    if (static_cast<int>(it.index() % shape[1]) == column) {
      values.push_back(it.value().getSExtValue());
    }
  }

  auto element_type = IntegerType::get(input.getContext(), 64);
  return DenseIntElementsAttr::get(
      RankedTensorType::get({shape[0]}, element_type), values);
}

// Returns interior padding to use in HLO Pad op based on the TensorFlow padding
// in TensorFlow PadV2 op.
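// PadV2 itself has no interior (dilation) padding, so this is always a zero
// vector with one entry per padded dimension.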
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
  auto length = tf_padding.getType().getShape()[0];
  auto element_type = IntegerType::get(tf_padding.getContext(), 64);
  return DenseIntElementsAttr::get<int64_t>(
      RankedTensorType::get({length}, element_type), 0);
}

//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; this broadcasts size-1 tensors up to any rank. Dynamic dimensions
// must be broadcast against a size-1 tensor or another dynamic dimension.
// Returns false if either value is unranked.
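// For example, shapes [1, 4] and [8, 1, 4] are compatible, while [?, 4] and
// [3, 4] are not, since a dynamic dimension may only pair with size 1 or
// another dynamic dimension.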
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
  if (!x_rankless || !y_rankless) {
    return false;
  }

  // Check that the shapes can be broadcasted.
  auto shape_x = x_rankless.getShape();
  auto shape_y = y_rankless.getShape();

  int rank_diff = shape_x.size() - shape_y.size();
  int offset_x = rank_diff > 0 ? rank_diff : 0;
  int offset_y = rank_diff < 0 ? -rank_diff : 0;
  for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
    int index_x = i + offset_x;
    int index_y = i + offset_y;
    if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
        (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
      return false;
    }
  }

  return true;
}

// Returns a new TensorType with the same rank and dimensions as the input with
// an updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
                                    Type element_type) {
  RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
  if (ranked_type) {
    return RankedTensorType::get(ranked_type.getShape(), element_type);
  }

  return UnrankedTensorType::get(element_type);
}

//===----------------------------------------------------------------------===//
// Softmax op utilities.
//===----------------------------------------------------------------------===//

// Returns the type to use for accumulating the given type.
static Type GetAccumulationType(Type ty) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//

static DenseElementsAttr GetEpsilonValue(Type ty) {
  auto element_ty = ty.cast<TensorType>().getElementType();
  auto scalar_ty = RankedTensorType::get({}, element_ty);
  if (element_ty.isF16()) {
    uint16_t raw_epsilon = Eigen::NumTraits<Eigen::half>::epsilon().x;
    auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isBF16()) {
    uint16_t raw_epsilon = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
    auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF32()) {
    auto value = APFloat(std::numeric_limits<float>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF64()) {
    auto value = APFloat(std::numeric_limits<double>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  }
  llvm_unreachable("unsupported element type for tf.SoftPlus");
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//

static void BuildArgMinMaxReductionBody(Type input_element_type,
                                        Type index_element_type,
                                        StringRef direction, Region *body,
                                        OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
  Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
  Block *block = builder->createBlock(body);
  block->addArguments({input_type, index_type, input_type, index_type});

  Location loc = body->getLoc();
  StringAttr compare_direction =
      StringAttr::get(builder->getContext(), direction);
  Value compare = builder->create<CompareOp>(
      loc, block->getArgument(0), block->getArgument(2), compare_direction);

  Value selected_input = builder->create<SelectOp>(
      loc, input_type, compare, block->getArgument(0), block->getArgument(2));
  Value selected_index = builder->create<SelectOp>(
      loc, index_type, compare, block->getArgument(1), block->getArgument(3));

  Value return_values[] = {selected_input, selected_index};
  builder->create<ReturnOp>(loc, return_values);
}

//===----------------------------------------------------------------------===//
// PartitionedCall op utilities.
//===----------------------------------------------------------------------===//

// Verify that the arguments to be passed into the function are the same types
// as the function parameter types.
static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args,
                                SymbolRefAttr func) {
  auto module = op->getParentOfType<ModuleOp>();
  auto function =
      dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
  FunctionType function_ty = function.getType();

  for (auto arg_in : llvm::zip(args, function_ty.getInputs())) {
    if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) {
      // Argument type and input type mismatch.
      return false;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Slice op utilities.
//===----------------------------------------------------------------------===//

static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                          DenseIntElementsAttr slice_sizes) {
  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  if (!input_ty) return false;
  auto start_indices_ty = start_indices.getType().dyn_cast<RankedTensorType>();
  if (!start_indices_ty) return false;

  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  DenseIntElementsAttr constant_start_indices;
  bool is_constant_start =
      matchPattern(start_indices, m_Constant(&constant_start_indices));

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    // A slice_size of -1 means "all elements from start_index to the end".
    // In order to support these semantics, we need to know both the start
    // index and the size of the input dimension.
    if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false;
  }
  return true;
}

// TF slice size can be -1, which represents all elements from start_index to
// the end. HLO slice size can't be -1. As such, we need to translate TF slice
// size -1 to HLO slice size.
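// For example, an input dimension of size 10 with start index 3 and TF slice
// size -1 becomes an HLO slice size of 7.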
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
    Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
    Builder *builder) {
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    return hlo::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
        .cast<DenseIntElementsAttr>();
  }

  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  SmallVector<int64_t, 4> normalized_sizes;

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
                                                : slice_size);
  }

  return GetI64ElementsAttr(normalized_sizes, builder);
}

//===----------------------------------------------------------------------===//
// Sort op utilities.
//===----------------------------------------------------------------------===//

// Builds the region `body` for mhlo.sort's comparator: for each type in
// `element_types`, creates two block arguments, one for lhs and one for rhs,
// and generates mhlo.compare op to compare them with the given `direction`.
//
// Note that this right now only does comparison on the first pair of block
// arguments.
BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,StringRef direction,llvm::Optional<StringRef> compare_type,Region * body,OpBuilder * builder)1047 static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
1048                                     StringRef direction,
1049                                     llvm::Optional<StringRef> compare_type,
1050                                     Region *body, OpBuilder *builder) {
1051   OpBuilder::InsertionGuard insertion_point_gurad(*builder);
1052 
1053   Block *block = builder->createBlock(body);
1054   // Add two arguments for each element type.
1055   for (Type element_type : element_types) {
1056     TensorType tensor_type = RankedTensorType::get({}, element_type);
1057     block->addArguments({tensor_type, tensor_type});
1058   }
1059 
1060   Location loc = body->getLoc();
1061   StringAttr compare_direction = builder->getStringAttr(direction);
1062   StringAttr type_attr;
1063   if (compare_type) type_attr = builder->getStringAttr(*compare_type);
1064   Value compare = builder->create<mhlo::CompareOp>(
1065       loc, block->getArgument(0), block->getArgument(1), compare_direction,
1066       type_attr);
1067 
1068   builder->create<mhlo::ReturnOp>(loc, compare);
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // XlaGather op utilities.
1073 //===----------------------------------------------------------------------===//
1074 
HasValidGatherDims(StringAttr attr)1075 bool HasValidGatherDims(StringAttr attr) {
1076   ::xla::GatherDimensionNumbers dims;
1077   return dims.ParseFromString(attr.getValue().str());
1078 }
1079 
GetGatherDimNumsAttr(StringAttr attr,Builder * builder)1080 GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) {
1081   ::xla::GatherDimensionNumbers dims;
1082   if (!dims.ParseFromString(attr.getValue().str())) return {};
1083   return ::xla::ConvertGatherDimensionNumbers(dims, builder);
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // Op converters.
1088 //===----------------------------------------------------------------------===//
1089 
GetConvDimensionNumbersAttr(ArrayRef<int64_t> spatial_dim_indices,tensorflow::TensorFormat format,Builder * builder)1090 NamedAttribute GetConvDimensionNumbersAttr(
1091     ArrayRef<int64_t> spatial_dim_indices, tensorflow::TensorFormat format,
1092     Builder *builder) {
1093   int64_t num_spatial_dims = spatial_dim_indices.size();
1094   int64_t num_dims = num_spatial_dims + 2;
1095 
1096   IntegerAttr batch_dim =
1097       builder->getI64IntegerAttr(GetTensorBatchDimIndex(num_dims, format));
1098   IntegerAttr feature_dim =
1099       builder->getI64IntegerAttr(GetTensorFeatureDimIndex(num_dims, format));
1100   DenseIntElementsAttr spatial_dims =
1101       GetI64ElementsAttr(spatial_dim_indices, builder);
1102 
1103   // The filter's data_format is always HWIO, so the input channels dimension
1104   // comes after all spatial dimensions.
1105   IntegerAttr kernel_input_feature_dim =
1106       builder->getI64IntegerAttr(num_spatial_dims);
1107   IntegerAttr kernel_output_feature_dim =
1108       builder->getI64IntegerAttr(num_spatial_dims + 1);
1109   DenseIntElementsAttr kernel_spatial_dimensions =
1110       GetI64ElementsAttrForSeq(0, num_spatial_dims, builder);
1111 
1112   return builder->getNamedAttr(
1113       "dimension_numbers",
1114       ConvDimensionNumbers::get(
1115           batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
1116           kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
1117           feature_dim, spatial_dims, builder->getContext()));
1118 }
1119 
1120 // Converts a TF::BiasAddOp to HLO.
1121 // This differs from a normal TF::AddOp with respect to how the data_format
1122 // is handled, which can optionally require a general broadcast of the
1123 // 'bias' term in a way that is not compatible with the standard left-padded
1124 // broadcast semantics (i.e. NCHW will broadcast into dimension 1).
1125 // The correct 'bias' broadcast will be synthesized manually.
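//
// For example (an illustrative sketch; exact attribute spelling may differ),
// with data_format = "NCHW" and a 4-D value the lowering is roughly:
//
//   %bcast = "mhlo.broadcast_in_dim"(%bias)
//       {broadcast_dimensions = dense<1> : tensor<1xi64>} : ...
//   %result = mhlo.add %value, %bcast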
1126 class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
1127  public:
1128   using OpRewritePattern::OpRewritePattern;
1129   LogicalResult matchAndRewrite(TF::BiasAddOp op,
1130                                 PatternRewriter &rewriter) const override {
1131     auto loc = op.getLoc();
1132     auto feature_dim = GetFeatureDimension(
1133         op.data_format(), op.value().getType().cast<RankedTensorType>());
1134     auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
1135                                                   feature_dim, rewriter);
1136     rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast);
1137     return success();
1138   }
1139 };
1140 
1141 // Converts the TensorFlow conv op in template to the generic HLO conv op by
1142 // converting TensorFlow op attributes to HLO op attributes.
1143 //
1144 // Sample result for Conv2D:
1145 //
1146 //   %conv = "mhlo.convolution"(%input, %filter) {
1147 //     strides = [1, 2],
1148 //     paddings = [[1, 0], [1, 1]],
1149 //     ...
1150 //   }
1151 //
1152 // This pattern is not defined using declarative rewrite rules, as computing
1153 // the paddings attribute requires multiple source op attributes and result op
1154 // attributes anyway. Defining it as a declarative rewrite rule would only
1155 // introduce duplication in the C++ helper methods.
1156 template <typename OpTy, int num_spatial_dims, bool depthwise_conv = false>
1157 class ConvertConvOp : public OpRewritePattern<OpTy> {
1158  public:
1159   using OpRewritePattern<OpTy>::OpRewritePattern;
1160 
1161   LogicalResult matchAndRewrite(OpTy op,
1162                                 PatternRewriter &rewriter) const override {
1163     tensorflow::TensorFormat data_format;
1164     if (!FormatFromString(op.data_format().str(), &data_format))
1165       return failure();
1166 
1167     tensorflow::Padding padding;
1168     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1169       return failure();
1170 
1171     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1172     auto filter_ty =
1173         op.filter().getType().template dyn_cast<RankedTensorType>();
1174     auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1175 
1176     // Input, filter and result need to have static shapes for the calculation
1177     // of the HLO paddings and feature group count attributes.
1178     for (RankedTensorType ty : {input_ty, filter_ty, result_ty})
1179       if (!ty || !ty.hasStaticShape()) return failure();
1180 
1181     ArrayRef<Attribute> dilations = op.dilations().getValue();
1182     ArrayRef<Attribute> strides = op.strides().getValue();
1183     ArrayRef<Attribute> explicit_paddings;
1184     if (padding == tensorflow::Padding::EXPLICIT) {
1185       // The EXPLICIT padding mode and its associated attribute are limited to
1186       // Conv2D, so fetch the attribute by identifier instead of through the
1187       // op.explicit_paddings() attribute getter.
1188       explicit_paddings =
1189           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1190     }
1191 
1192     SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1193     SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1194     SmallVector<int64_t, num_spatial_dims> window_strides;
1195     SmallVector<int64_t, num_spatial_dims * 2> paddings;
1196 
1197     auto get_int = [](Attribute attr) {
1198       return attr.template cast<IntegerAttr>().getInt();
1199     };
1200 
1201     constexpr int num_dims = num_spatial_dims + 2;
1202     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1203       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1204       spatial_dim_indices.push_back(dim);
1205 
1206       const int64_t dilation = get_int(dilations[dim]);
1207       rhs_dilations.push_back(dilation);
1208       const int64_t stride = get_int(strides[dim]);
1209       window_strides.push_back(stride);
1210 
1211       int64_t pad_low, pad_high;
1212       if (padding == tensorflow::Padding::EXPLICIT) {
1213         pad_low = get_int(explicit_paddings[2 * dim]);
1214         pad_high = get_int(explicit_paddings[2 * dim + 1]);
1215       } else {
1216         tensorflow::int64 output_size;
1217         tensorflow::int64 pad_low_int64;
1218         tensorflow::int64 pad_high_int64;
1219         tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1220             input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride,
1221             padding, &output_size, &pad_low_int64, &pad_high_int64);
1222         if (!status.ok()) return failure();
1223         pad_low = pad_low_int64;
1224         pad_high = pad_high_int64;
1225       }
1226       paddings.push_back(pad_low);
1227       paddings.push_back(pad_high);
1228     }
1229 
1230     auto rhs_dilations_attr = rewriter.getNamedAttr(
1231         "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1232 
1233     auto window_strides_attr = rewriter.getNamedAttr(
1234         "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1235 
1236     auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1237         spatial_dim_indices, data_format, &rewriter);
1238 
1239     const int64_t input_channels =
1240         GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1241     // The filter's data_format is always HWIO, so the input channels dimension
1242     // comes after all spatial dimensions.
1243     const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1244     // The TensorFlow convolution op verifies that the number of input channels
1245     // is divisible by the number of filter channels.
1246     // For depthwise convolution the feature_group_count argument is set to the
1247     // number of input channels (the size of the input feature dimension).
1248     const int64_t feature_group_count =
1249         depthwise_conv ? input_channels : input_channels / filter_channels;
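    // For example (illustrative): a grouped convolution with 8 input channels
    // and a filter with 2 input channels gives feature_group_count = 8 / 2 = 4,
    // while a depthwise convolution over 8 input channels gives
    // feature_group_count = 8.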
1250     auto feature_group_count_attr = rewriter.getNamedAttr(
1251         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1252 
1253     auto batch_group_count_attr = rewriter.getNamedAttr(
1254         "batch_group_count", rewriter.getI64IntegerAttr(1));
1255 
1256     RankedTensorType paddings_ty = RankedTensorType::get(
1257         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
1258     auto paddings_attr = rewriter.getNamedAttr(
1259         "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
1260 
1261     SmallVector<Value, 2> operands(op.getOperands());
1262     // Reshape the filter to {spatial_dims..., 1, in_channels *
1263     // channel_multiplier}.
1264     if (depthwise_conv) {
1265       ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1266       llvm::SmallVector<int64_t, num_dims> new_shape(
1267           filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1268       new_shape.push_back(1);
1269       new_shape.push_back(filter_shape[num_spatial_dims] *
1270                           filter_shape[num_spatial_dims + 1]);
1271       operands[1] = rewriter.create<mhlo::ReshapeOp>(
1272           op.getLoc(),
1273           RankedTensorType::get(new_shape, filter_ty.getElementType()),
1274           operands[1]);
1275     }
1276     NamedAttribute attrs[] = {rhs_dilations_attr,     window_strides_attr,
1277                               dimension_numbers_attr, feature_group_count_attr,
1278                               batch_group_count_attr, paddings_attr};
1279     rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
1280                                         llvm::makeArrayRef(attrs));
1281     return success();
1282   }
1283 };
1284 
1285 using ConvertConv2DOp = ConvertConvOp<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1286 using ConvertConv3DOp = ConvertConvOp<TF::Conv3DOp, /*num_spatial_dims=*/3>;
1287 using ConvertDepthConv2DOp =
1288     ConvertConvOp<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
1289                   /*depthwise_conv=*/true>;
1290 
1291 // Converts BF16 FloorDiv op to have casting operators on either end, as BF16
1292 // division can produce imprecise results.
1293 //
1294 //      floordiv = cast(floordiv(cast(left), cast(right)))
1295 //
1296 //   %left_cast = cast(%left)
1297 //   %right_cast = cast(%right)
1298 //   %div = div(%left_cast, %right_cast)
1299 //   %floored = floor(%div)
1300 //   %floored_cast = cast(%floored)
1301 //
1302 // The intermediate types must be specified manually.
1303 class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
1304  public:
1305   using OpRewritePattern::OpRewritePattern;
1306 
1307   LogicalResult matchAndRewrite(TF::FloorDivOp op,
1308                                 PatternRewriter &rewriter) const override {
1309     auto l = op.x();
1310     auto r = op.y();
1311     auto element_type = getElementTypeOrSelf(l.getType());
1312     if (!element_type.isBF16()) return failure();
1313 
1314     auto out_type = op.z().getType().cast<TensorType>();
1315 
1316     l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
1317     r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
1318 
1319     auto intermediate = rewriter.create<TF::FloorDivOp>(
1320         op.getLoc(),
1321         ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
1322         r);
1323 
1324     auto floor_op =
1325         rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
1326     rewriter.replaceOp(op, floor_op.getResult());
1327     return success();
1328   }
1329 };
1330 
1331 class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
1332  public:
1333   using OpRewritePattern::OpRewritePattern;
1334 
1335   LogicalResult matchAndRewrite(TF::BroadcastToOp op,
1336                                 PatternRewriter &rewriter) const override {
1337     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1338     auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
1339     if (!input_type || !output_type) {
1340       return rewriter.notifyMatchFailure(op, "requires ranked shape");
1341     }
1342     auto rank_diff = output_type.getRank() - input_type.getRank();
1343     // The tf.BroadcastTo op performs "right-aligned" numpy-style broadcasting.
1344     auto broadcast_dimensions = llvm::to_vector<4>(
1345         llvm::seq<int64_t>(rank_diff, output_type.getRank()));
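    // For example (illustrative): broadcasting a rank-2 input to a rank-4
    // output gives rank_diff = 2 and broadcast_dimensions = [2, 3], i.e. the
    // input dimensions line up with the trailing output dimensions.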
1346     rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
1347         op, output_type, op.input(), op.shape(),
1348         rewriter.getI64TensorAttr(broadcast_dimensions));
1349     return success();
1350   }
1351 };
1352 
1353 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
1354 // For a Rank-2 input, it creates the following ops:
1355 //   %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
1356 //   %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
1357 //   %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
1358 //   %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
1359 //   %5 = "mhlo.broadcast"(%4)
1360 //   %6 = "mhlo.select"(%3, %input, %5)
1361 //   %7 = "mhlo.reduce"(%6, %4) ( {
1362 //   ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
1363 //     %9 = mhlo.add %arg1, %arg2 : tensor<f32>
1364 //     "mhlo.return"(%9) : (tensor<f32>) -> ()
1365 //   }) {dimensions = dense<0> : tensor<1xi64>}
1366 //
1367 // If the input's rank N is greater than 2, we will reshape it to R2 first and
1368 // create the above ops, then reshape it back to rank N/2.
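// For example (illustrative), a rank-4 input of shape [2, 3, 2, 3] is reshaped
// to [6, 6], the diagonal is extracted with the masked reduction above, and the
// result is reshaped back to shape [2, 3].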
1369 class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
1370  public:
1371   using OpRewritePattern::OpRewritePattern;
1372 
1373   LogicalResult matchAndRewrite(TF::DiagPartOp op,
1374                                 PatternRewriter &rewriter) const override {
1375     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1376     if (!input_type || !input_type.hasStaticShape()) return failure();
1377     int64_t num_dims = input_type.getRank();
1378     if (num_dims < 2 || num_dims % 2 != 0) return failure();
1379     const int64_t out_dims = num_dims / 2;
1380 
1381     int64_t new_size = 1;
1382     llvm::SmallVector<int64_t, 4> new_dims;
1383     for (int i = 0; i < out_dims; i++) {
1384       if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims))
1385         return op.emitOpError("invalid dimensions size");
1386       new_size *= input_type.getDimSize(i);
1387       new_dims.push_back(input_type.getDimSize(i));
1388     }
1389     Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
1390         op.getLoc(),
1391         RankedTensorType::get({new_size, new_size},
1392                               input_type.getElementType()),
1393         op.input());
1394     auto iota_type = RankedTensorType::get({new_size, new_size},
1395                                            rewriter.getIntegerType(32));
1396     auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1397                                          rewriter.getI64IntegerAttr(0));
1398     auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1399                                          rewriter.getI64IntegerAttr(1));
1400     Value compare = rewriter.create<CompareOp>(
1401         op.getLoc(), iota0, iota1,
1402         StringAttr::get(rewriter.getContext(), "EQ"));
1403     Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
1404                                       0, &rewriter);
1405     Value zero_matrix = rewriter.create<BroadcastOp>(
1406         op.getLoc(), reshaped_input.getType(), zero,
1407         GetI64ElementsAttr({new_size, new_size}, &rewriter));
1408     Value masked =
1409         rewriter.create<SelectOp>(op.getLoc(), reshaped_input.getType(),
1410                                   compare, reshaped_input, zero_matrix);
1411     auto reduce = rewriter.create<ReduceOp>(op.getLoc(), masked, zero,
1412                                             GetI64ElementsAttr({0}, &rewriter));
1413     assert(!input_type.getElementType().isInteger(1) &&
1414            "data type should not be i1");
1415     BuildReduceBody<AddOp>(input_type.getElementType(), &reduce.body(),
1416                            &rewriter);
1417     rewriter.replaceOpWithNewOp<ReshapeOp>(
1418         op, RankedTensorType::get(new_dims, input_type.getElementType()),
1419         reduce.getResult(0));
1420     return success();
1421   }
1422 };
1423 
1424 // Converts TensorFlow MatrixDiagPartOp to HLO ops.
1425 class ConvertMatrixDiagPartV3Op
1426     : public OpRewritePattern<TF::MatrixDiagPartV3Op> {
1427   using Shape = llvm::SmallVector<int64_t, 4>;
1428 
1429   // Parse the "k" parameter. MatrixDiagPartV3 allows specifying the diagonal(s)
1430   // with k. This can be either a single value (for a single diagonal) or a
1431   // tuple of two values (starting and ending diagonal, for a band).
1432   LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const {
1433     DenseIntElementsAttr kattr;
1434     if (!matchPattern(op.k(), m_Constant(&kattr))) {
1435       return failure();
1436     }
1437     DenseIntElementsAttr::iterator it = kattr.begin();
1438     (*k)[0] = (*it).getSExtValue();
1439     it++;
1440     if (it == kattr.end()) {
1441       // Handle input like e.g. "k = 5", in which case we extract a single
1442       // diagonal.
1443       (*k)[1] = (*k)[0];
1444     } else {
1445       // Handle input like e.g. "k = [-1, 1]", in which case we extract a
1446       // band (multiple diagonals).
1447       (*k)[1] = (*it).getSExtValue();
1448     }
1449     return success();
1450   }
1451 
1452   // Utility method for broadcasting integer constants to a given shape.
1453   BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant,
1454                                 int int_size, PatternRewriter &rewriter) const {
1455     return rewriter.create<BroadcastOp>(
1456         loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)),
1457         GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant,
1458                              &rewriter),
1459         GetI64ElementsAttr(shape, &rewriter));
1460   }
1461 
1462  public:
1463   using OpRewritePattern::OpRewritePattern;
1464 
1465   LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op,
1466                                 PatternRewriter &rewriter) const override {
1467     Location loc = op.getLoc();
1468     ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
1469     auto element_type = input_type.getElementType();
1470 
1471     // Align is a string specifying how superdiagonals and subdiagonals should
1472     // be aligned/padded for diagonals that are shorter than max_diag_len. The
1473     // format is "{super}_{sub}", with {super} the superdiagonal alignment and
1474     // {sub} the subdiagonal alignment. "LEFT" means a diagonal is left-aligned
1475     // (padded on the right), "RIGHT" means it is right-aligned (padded on the
1476     // left). The default is "RIGHT_LEFT".
1477     StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
1478     enum Alignment { kLeft, kRight };
1479 
1480     // default is RIGHT_LEFT
1481     Alignment superdiagonal_align = kRight;
1482     Alignment subdiagonal_align = kLeft;
1483 
1484     if (align == "RIGHT_LEFT") {
1485       superdiagonal_align = kRight;
1486       subdiagonal_align = kLeft;
1487     } else if (align == "RIGHT_RIGHT") {
1488       superdiagonal_align = kRight;
1489       subdiagonal_align = kRight;
1490     } else if (align == "LEFT_RIGHT") {
1491       superdiagonal_align = kLeft;
1492       subdiagonal_align = kRight;
1493     } else if (align == "LEFT_LEFT") {
1494       superdiagonal_align = kLeft;
1495       subdiagonal_align = kLeft;
1496     } else {
1497       return failure();  // unsupported alignment
1498     }
1499 
1500     // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and
1501     // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L].
1502     if (!input_type || !input_type.hasStaticShape()) return failure();
1503     int64_t num_dims = input_type.getRank();
1504     if (num_dims < 2) return failure();
1505     int64_t rows = input_type.getDimSize(num_dims - 2);  // rows
1506     int64_t cols = input_type.getDimSize(num_dims - 1);  // cols
1507 
1508     // We extract the diagonals from k[0] up to and including k[1].
1509     // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract
1510     // the main diagonal). It's negative for subdiagonals (under and to the left
1511     // of the main diagonal) and positive for superdiagonals (above and to the
1512     // right of the main diagonal).
1513     int64_t k[2];
1514     if (failed(ExtractK(op, &k))) return failure();
1515     int num_diags = k[1] - k[0] + 1;
1516 
1517     // Shifting diagonals away from the main diagonal might shorten them. This
1518     // is the longest diagonal we will see. We make this the last dimension of
1519     // the output shape.
1520     int64_t max_diag_len =
1521         std::min(rows + std::min(k[1], static_cast<int64_t>(0)),
1522                  cols + std::min(-k[0], static_cast<int64_t>(0)));
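    // For example (illustrative): for a 3x4 matrix and k = [-1, 2],
    // max_diag_len = min(3 + min(2, 0), 4 + min(1, 0)) = 3.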
1523 
1524     // The first dimension is the index vector dimension we'll use for gather.
1525     // It's 1 here, but will be 2 once we glue x and y together.
1526     Shape indices_shape({1, num_diags, max_diag_len});
1527 
1528     RankedTensorType iota_type =
1529         RankedTensorType::get(indices_shape, rewriter.getIntegerType(32));
1530     Value iotaM =
1531         rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(1));
1532     Value iotaN =
1533         rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(2));
1534 
1535     // Broadcasted constants, of the same shape as iotaM and iotaN.
1536     Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter);
1537     Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter);
1538     Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter);
1539     Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter);
1540     Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter);
1541     Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter);
1542     Value b_max_diag_len =
1543         BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter);
1544 
1545     // d = k[1] - m
1546     // (A.k.a. the number of the diagonal, depending on m. Note that we
1547     //  subtract m here. This means we start with the superdiagonals and
1548     //  move downwards towards the subdiagonals. So the start indices will
1549     //  be decreasing.)
1550     Value d = rewriter.create<SubOp>(loc, b_k1, iotaM);
1551     Value neg_d = rewriter.create<NegOp>(loc, d);
1552 
1553     // diag_len_d = min(rows + min(d, 0), cols - max(d, 0))
1554     // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.)
1555     Value diag_len_d = rewriter.create<MinOp>(
1556         loc,
1557         rewriter.create<AddOp>(loc, b_rows,
1558                                rewriter.create<MinOp>(loc, d, b_zero)),
1559         rewriter.create<SubOp>(loc, b_cols,
1560                                rewriter.create<MaxOp>(loc, d, b_zero)));
1561 
1562     // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise.
1563     Value cmp;
1564     if (subdiagonal_align == kRight && superdiagonal_align == kRight) {
1565       cmp = b_true;
1566     } else if (superdiagonal_align == kRight) {
1567       // offset = d>=0 ? max_diag_len - diag_len_d : 0
1568       cmp = rewriter.create<TF::GreaterEqualOp>(loc, d, b_zero);
1569     } else if (subdiagonal_align == kRight) {
1570       // offset = d<=0 ? max_diag_len - diag_len_d : 0
1571       cmp = rewriter.create<TF::LessEqualOp>(loc, d, b_zero);
1572     } else {
1573       // offset = 0
1574       cmp = b_false;
1575     }
1576 
1577     // This offset shifts the diagonals to the "left" or "right", depending
1578     // on alignment.
1579     Value offset = rewriter.create<SelectOp>(
1580         loc, b_zero.getType(), cmp,
1581         rewriter.create<SubOp>(loc, b_max_diag_len, diag_len_d), b_zero);
1582 
1583     // x = max(d, 0) - offset
1584     // y = max(-d, 0) - offset
1585     Value x = rewriter.create<SubOp>(
1586         loc, rewriter.create<MaxOp>(loc, d, b_zero), offset);
1587     Value y = rewriter.create<SubOp>(
1588         loc, rewriter.create<MaxOp>(loc, neg_d, b_zero), offset);
1589 
1590     Value n_plus_x = rewriter.create<AddOp>(loc, iotaN, x);
1591     Value n_plus_y = rewriter.create<AddOp>(loc, iotaN, y);
1592 
1593     // GatherOp lets us index out-of-bounds values, but those values are
1594     // undefined, so we mask them later. Set up the boolean expression that
1595     // tells us which entries, in the output shape, are out of bounds and thus
1596     // become the padding_value.
1597     Value x_in_bounds = rewriter.create<AndOp>(
1598         loc,
1599         rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_x,
1600                                             b_zero),
1601         rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_x, b_cols));
1602     Value y_in_bounds = rewriter.create<AndOp>(
1603         loc,
1604         rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_y,
1605                                             b_zero),
1606         rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_y, b_rows));
1607     Value in_bounds = rewriter.create<ReshapeOp>(
1608         loc,
1609         RankedTensorType::get(Shape({num_diags, max_diag_len}),
1610                               rewriter.getIntegerType(1)),
1611         rewriter.create<AndOp>(loc, x_in_bounds, y_in_bounds));
1612 
1613     // Now combine x and y into the index data structure needed for gather.
1614     Shape concat_shape({2, num_diags, max_diag_len});
1615     Value start_indices = rewriter.create<ConcatenateOp>(
1616         loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)),
1617         mlir::ValueRange({n_plus_y, n_plus_x}),
1618         mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0));
1619 
1620     // Shape of the final output. (Except for dimension folding in the
1621     // single diagonal case.)
1622     Shape output_shape;
1623     for (int i = 0; i < num_dims - 2; i++) {
1624       output_shape.push_back(input_type.getDimSize(i));
1625     }
1626     output_shape.push_back(num_diags);
1627     output_shape.push_back(max_diag_len);
1628     auto output_type = RankedTensorType::get(output_shape, element_type);
1629 
1630     // A slice is the shape of what GatherOp copies per lookup. So the last
1631     // two dimensions (M, N in the matrix-diag-part docs) are where we go
1632     // through entry by entry.
1633     ArrayRef<int64_t> input_shape = input_type.getShape();
1634     Shape slice_sizes(input_shape.begin(), input_shape.end());
1635     int slice_dimensions = slice_sizes.size();
1636     slice_sizes[slice_dimensions - 2] = 1;
1637     slice_sizes[slice_dimensions - 1] = 1;
1638 
1639     // Dimensions of the input we won't see in the output (M and N).
1640     SmallVector<int64_t, 2> collapsed_dims(
1641         {slice_dimensions - 2, slice_dimensions - 1});
1642 
1643     // Which dimensions (in the input) the two offset "columns" map to.
1644     SmallVector<int64_t, 2> start_index_map({num_dims - 2, num_dims - 1});
1645 
1646     // Gather the diagonal entries.
1647     // TODO(kramm): For a single diagonal, this might be slower than the
1648     //              mask + sum approach. Special-case num_diags==1?
1649     auto dims_attr = GatherDimensionNumbers::get(
1650         /*offset_dims=*/GetI64ElementsAttrForSeq(0, num_dims - 2, &rewriter),
1651         /*collapsed_slice_dims=*/GetI64ElementsAttr(collapsed_dims, &rewriter),
1652         /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
1653         /*index_vector_dim=*/rewriter.getI64IntegerAttr(0),
1654         rewriter.getContext());
1655     Value gather = rewriter.create<mhlo::GatherOp>(
1656         loc, output_type, op.input(), start_indices, dims_attr,
1657         GetI64ElementsAttr(slice_sizes, &rewriter));
1658 
1659     // We now need to broadcast the "in_bounds" boolean expression, as well as
1660     // the padding value, to do the final select.
1661     Shape broadcast_bounds;
1662     for (int i = 0; i < output_shape.size() - 2; i++) {
1663       broadcast_bounds.push_back(output_shape[i]);
1664     }
1665     Value b_in_bounds = rewriter.create<BroadcastOp>(
1666         loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
1667         in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
1668     Value b_padding = rewriter.create<BroadcastOp>(
1669         loc, output_type, op.padding_value(),
1670         GetI64ElementsAttr(output_shape, &rewriter));
1671 
1672     // Replace all out-of-bounds values in the result with padding_value.
1673     Value result = rewriter.create<SelectOp>(loc, output_type, b_in_bounds,
1674                                              gather, b_padding);
1675 
1676     if (num_diags == 1) {
1677       // matrix_diag_part folds away the 1-sized band dimension if we only
1678       // extract a single diagonal.
1679       result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
1680     }
1681 
1682     rewriter.replaceOp(op, result);
1683     return success();
1684   }
1685 };
1686 
1687 // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
1688 // depending on arity of the op.
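//
// For example (illustrative): a two-operand equation such as "ab,bc->ac" maps
// to mhlo.einsum, while a single-operand equation such as "ab->ba" maps to
// mhlo.unary_einsum.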
1689 class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
1690  public:
1691   using OpRewritePattern::OpRewritePattern;
1692 
1693   LogicalResult matchAndRewrite(TF::EinsumOp op,
1694                                 PatternRewriter &rewriter) const override {
1695     StringAttr equation = op->getAttrOfType<StringAttr>("equation");
1696     if (op.N() == 1) {
1697       rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
1698           op, op.getType(), *op.inputs().begin(), equation);
1699     } else if (op.N() == 2) {
1700       ValueRange inputs = op.inputs();
1701       rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
1702                                             inputs[1], equation);
1703     } else {
1704       // The TensorFlow EinsumOp verifies that the number of operands is at most
1705       // two.
1706       return failure();
1707     }
1708     return success();
1709   }
1710 };
1711 
1712 // Bypasses IdentityN op.
1713 class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
1714  public:
1715   using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
1716   LogicalResult matchAndRewrite(TF::IdentityNOp op,
1717                                 PatternRewriter &rewriter) const override {
1718     rewriter.replaceOp(op, op.getOperands());
1719     return success();
1720   }
1721 };
1722 
1723 template <typename OpTy>
1724 class ConvertFFTOp : public OpRewritePattern<OpTy> {
1725  public:
1726   using OpRewritePattern<OpTy>::OpRewritePattern;
1727   LogicalResult matchAndRewrite(OpTy op,
1728                                 PatternRewriter &rewriter) const override {
1729     auto input_ty = op.input().getType().template cast<ShapedType>();
1730     if (!input_ty.hasRank()) {
1731       return failure();
1732     }
1733     auto input_shape = input_ty.getShape();
1734     DenseIntElementsAttr fft_length_attr;
1735     if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
1736       return failure();
1737     }
1738     int64_t fft_length;
1739     if (fft_length_attr.getNumElements() != 0) {
1740       fft_length = fft_length_attr.getValue<IntegerAttr>(0).getInt();
1741     } else {
1742       return failure();
1743     }
1744 
1745     std::string fft_string = "RFFT";
1746     if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
1747       fft_length = fft_length / 2 + 1;
1748       fft_string = "IRFFT";
1749     }
1750     auto loc = op.getLoc();
1751 
1752     // The inner-most dim cannot be dynamic.
1753     if (input_ty.isDynamicDim(input_shape.size() - 1)) {
1754       return failure();
1755     }
1756 
1757     auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
1758     expected_shape.push_back(fft_length);
1759 
1760     // Zero pad or truncate the last axis
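    // For example (RFFT case, illustrative): an input of shape [8, 10] with
    // fft_length = 16 is zero-padded on the last axis to shape [8, 16], while
    // with fft_length = 8 it is sliced down to shape [8, 8].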
1761     Value reshaped = op.input();
1762     SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
1763     SmallVector<int64_t, 4> strides(input_shape.size(), 1);
1764 
1765     // Last dim larger than fft_length, slice the input
1766     if (input_shape.back() > fft_length) {
1767       reshaped = rewriter.create<SliceOp>(
1768           op.getLoc(),
1769           RankedTensorType::get(expected_shape, input_ty.getElementType()),
1770           op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
1771           GetI64ElementsAttr(expected_shape, &rewriter),
1772           GetI64ElementsAttr(strides, &rewriter));
1773 
1774       // Last dim smaller than fft_length, zero-pad the input
1775     } else if (input_ty.getShape().back() < fft_length) {
1776       SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
1777       SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
1778       padding.push_back(fft_length - input_shape.back());
1779       Value zero =
1780           GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
1781       reshaped = rewriter.create<PadOp>(
1782           loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
1783           op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
1784           GetI64ElementsAttr(padding, &rewriter),
1785           GetI64ElementsAttr(no_padding, &rewriter));
1786     }
1787 
1788     rewriter.replaceOpWithNewOp<FftOp>(op, op.getType(), reshaped, fft_string,
1789                                        rewriter.getI64TensorAttr(fft_length));
1790     return success();
1791   }
1792 };
1793 
1794 using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
1795 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
1796 
1797 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
1798 // BatchNormGradOp for training and a sequence of binary ops for inference.
1799 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
1800 template <typename FusedBatchNormGradOpT>
1801 class ConvertFusedBatchNormGradBase
1802     : public OpRewritePattern<FusedBatchNormGradOpT> {
1803  public:
1804   using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
1805 
1806   LogicalResult matchAndRewrite(FusedBatchNormGradOpT op,
1807                                 PatternRewriter &rewriter) const override {
1808     Location loc = op.getLoc();
1809     Value grad = op.y_backprop();
1810     Value act = op.x();
1811     Value scale = op.scale();
1812     Value mean = op.reserve_space_1();
1813     Value var = op.reserve_space_2();
1814 
1815     // TODO(b/141785544): Update this to not require static shapes.
1816     // The activation shape needs to be static to convert negative indices in
1817     // TensorFlow to the absolute indices required by HLO.
1818     RankedTensorType act_type =
1819         act.getType().template dyn_cast<RankedTensorType>();
1820     if (!act_type) return failure();
1821     Type act_ele_type = act_type.getElementType();
1822     // To support mixed precision, the statistics type, which may be more
1823     // precise than the input types, is used for this op.
1824     Type kernel_type =
1825         scale.getType().template cast<TensorType>().getElementType();
1826     grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
1827     act = rewriter.create<ConvertOp>(loc, act, kernel_type);
1828 
1829     auto feature_dim_attr =
1830         getFeatureDimensionAttr(rewriter, op.data_format(), act);
1831     auto feature_dim = feature_dim_attr.getValue().getSExtValue();
1832 
1833     // Gets the result values.
1834     Value x_backprop, scale_backprop, offset_backprop;
1835     if (op.is_training()) {  // training
1836       // TODO(b/145536565): handle GPU logic separately.
1837       // Infers the output type with the converted `act`.
1838       Type feature_type = RankedTensorType::get(
1839           {GetDimSize(act_type, feature_dim)}, kernel_type);
1840       Type result_type = TupleType::get(
1841           rewriter.getContext(), {act.getType(), feature_type, feature_type});
1842 
1843       auto training_op = rewriter.create<BatchNormGradOp>(
1844           loc, result_type, act, scale, mean, var, grad, op.epsilon(),
1845           feature_dim);
1846 
1847       x_backprop =
1848           rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);
1849 
1850       scale_backprop =
1851           rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);
1852 
1853       offset_backprop =
1854           rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
1855     } else {  // inference
1856       SmallVector<int64_t, 4> non_feature_dims;
1857       for (int64_t i = 0; i < act_type.getRank(); ++i) {
1858         if (i == feature_dim) continue;
1859         non_feature_dims.push_back(i);
1860       }
1861       auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
1862       auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
1863 
1864       // scratch1 = rsqrt(var + epsilon)
1865       RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
1866       auto epsilon = rewriter.create<ConstOp>(
1867           loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
1868       auto add_op = rewriter.create<chlo::BroadcastAddOp>(
1869           loc, var, epsilon.getResult(), scalar_broadcast_dims);
1870 
1871       Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
1872 
1873       // scratch2 = sum(y_backprop * (x - mean))
1874       auto sub_op = rewriter.create<mhlo::SubOp>(
1875           loc, act,
1876           Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
1877       auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
1878       Value scratch2 =
1879           ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
1880 
1881       // x_backprop = y_backprop * (scale * scratch1)
1882       auto scaled_grad =
1883           rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
1884       x_backprop = rewriter.create<mhlo::MulOp>(
1885           loc, grad,
1886           Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
1887                                   rewriter));
1888 
1889       // scale_backprop = scratch2 * scratch1
1890       scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
1891 
1892       // offset_backprop = sum(y_backprop)
1893       offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
1894     }
1895 
1896     x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
1897     Value last_val[2];
1898     if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) {
1899       // It doesn't matter what values we provide for the last 2 results.
1900       last_val[0] = last_val[1] = op.x();
1901     } else {
1902       auto const_val = rewriter.create<ConstOp>(
1903           op.getLoc(),
1904           DenseElementsAttr::get<float>(
1905               RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))),
1906               0.0));
1907       auto maybe_cast = [&](Value val, Type t) -> Value {
1908         if (val.getType() == t) return val;
1909         return rewriter.create<tensor::CastOp>(op.getLoc(), t, val);
1910       };
1911       last_val[0] = maybe_cast(const_val, op.getResult(3).getType());
1912       last_val[1] = maybe_cast(const_val, op.getResult(4).getType());
1913     }
1914     rewriter.replaceOp(
1915         op, {/*x_backprop=*/x_backprop,
1916              /*scale_backprop=*/scale_backprop,
1917              /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]});
1918     return success();
1919   }
1920 };
1921 
1922 using ConvertFusedBatchNormGradOp =
1923     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
1924 using ConvertFusedBatchNormGradV2Op =
1925     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
1926 using ConvertFusedBatchNormGradV3Op =
1927     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
1928 
1929 // Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
1930 // HLO BatchNormInferenceOp, depending on the value of the 'is_training'
1931 // parameter.
1932 template <typename FusedBatchNormOpT>
1933 class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
1934  public:
1935   using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
1936 
1937   LogicalResult matchAndRewrite(FusedBatchNormOpT op,
1938                                 PatternRewriter &rewriter) const override {
1939     auto feature_dim =
1940         getFeatureDimensionAttr(rewriter, op.data_format(), op.x());
1941 
1942     auto input_type_tensor = op.x().getType().template cast<TensorType>();
1943     auto input_element_type = input_type_tensor.getElementType();
1944 
1945     auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
1946     auto scale_element_type = scale_type_tensor.getElementType();
1947 
1948     auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
1949     auto mean_element_type = mean_type_tensor.getElementType();
1950     // In the training case, dimensions of input tensors must be static.
1951     if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
1952                              !scale_type_tensor.hasStaticShape() ||
1953                              !mean_type_tensor.hasStaticShape()))
1954       return failure();
1955 
1956     // TODO(b/69928690): Support mixed precision in the XLA batch
1957     // normalization operators. As a workaround, create a new x with the same
1958     // element type as scale (which may be more precise than the input type).
1959     Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
1960                                                             scale_element_type);
1961     TensorType bn_train_input_type_tensor =
1962         bn_train_input.getType().template cast<TensorType>();
1963 
1964     if (op.is_training()) {
1965       // Training case.
1966       auto operand_shape = bn_train_input_type_tensor.getShape();
1967       // The mean and variance are each 1-dimensional arrays whose size equals
1968       // the feature dimension, with the same element type as the operand (x).
1969       // This shape must be constructed manually because the mean and variance
1970       // inputs are empty in the training case.
1971       Type mean_var_type = RankedTensorType::get(
1972           {operand_shape[feature_dim.getInt()]}, scale_element_type);
1973       // The op result type is a tuple of 3 values: output with the same shape
1974       // as the input, batch_mean, and batch_var.
1975       SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
1976                                             mean_var_type, mean_var_type};
1977       Type result_type = TupleType::get(rewriter.getContext(), operand_types);
1978 
1979       auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
1980           op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
1981           op.epsilon(), feature_dim.getInt());
1982       // HLO op outputs a tuple of tensors. Extract those results.
1983       auto bn_train_op_result = bn_train_op.getResult();
1984       Value y_out = rewriter.create<mhlo::GetTupleElementOp>(
1985           op.getLoc(), bn_train_op_result, 0);
1986       Value batch_mean = rewriter.create<mhlo::GetTupleElementOp>(
1987           op.getLoc(), bn_train_op_result, 1);
1988       Value reserve_space_1 = batch_mean;
1989       Value batch_variance = rewriter.create<mhlo::GetTupleElementOp>(
1990           op.getLoc(), bn_train_op_result, 2);
1991 
1992       // Apply Bessel's correction on the variance.
1993       int total_input_size = bn_train_input_type_tensor.getNumElements();
1994       int total_scale_size = scale_type_tensor.getNumElements();
1995       int sample_size = total_input_size / total_scale_size;
1996       int sample_size_minus_one = std::max(1, sample_size - 1);
1997       double factor = static_cast<double>(sample_size) /
1998                       static_cast<double>(sample_size_minus_one);
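      // For example (illustrative): an NHWC input of shape [2, 4, 4, 8] with 8
      // scale elements gives sample_size = 256 / 8 = 32, so the variance is
      // scaled by a factor of 32 / 31.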
1999       auto factor_const_op = rewriter.create<mhlo::ConstOp>(
2000           op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
2001 
2002       Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
2003           op.getLoc(), batch_variance.getType(), batch_variance,
2004           factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
2005 
2006       // Convert back to input type to stay aligned with expected output type
2007       // for TF op.
2008       y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
2009                                                input_element_type);
2010 
2011       float exponential_avg_factor =
2012           op.exponential_avg_factor().convertToFloat();
2013       if (exponential_avg_factor != 1.0f) {
2014         auto alpha = rewriter.create<mhlo::ConstOp>(
2015             op.getLoc(), rewriter.getFloatAttr(mean_element_type,
2016                                                1.0f - exponential_avg_factor));
2017         auto beta = rewriter.create<mhlo::ConstOp>(
2018             op.getLoc(),
2019             rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
2020 
2021         // new_running_mean = alpha * old_mean + beta * batch_mean.
2022         auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
2023             op.getLoc(), op.mean().getType(), alpha, op.mean(),
2024             /*broadcast_dimensions=*/DenseIntElementsAttr());
2025         auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
2026             op.getLoc(), batch_mean.getType(), beta, batch_mean,
2027             /*broadcast_dimensions=*/DenseIntElementsAttr());
2028         batch_mean = rewriter.create<chlo::BroadcastAddOp>(
2029             op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
2030             /*broadcast_dimensions=*/DenseIntElementsAttr());
2031 
2032         // new_running_variance = alpha * old_variance + beta * batch_variance.
2033         auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
2034             op.getLoc(), op.variance().getType(), alpha, op.variance(),
2035             /*broadcast_dimensions=*/DenseIntElementsAttr());
2036         auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
2037             op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
2038             /*broadcast_dimensions=*/DenseIntElementsAttr());
2039         corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
2040             op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
2041             /*broadcast_dimensions=*/DenseIntElementsAttr());
2042       }
2043 
2044       if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2045         // FusedBatchNormV2 expects 4 outputs.
2046         // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
2047         // They are used to pass the per-batch mean and variance to the
2048         // gradient. Here we maintain the same behavior by setting them to the
2049         // mean and variance calculated by BatchNormTraining.
2050         rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2051                                 /*batch_variance=*/corrected_variance,
2052                                 /*reserve_space_1=*/reserve_space_1,
2053                                 /*reserve_space_2=*/batch_variance});
2054       } else {  // TF::FusedBatchNormV3Op
2055         // For FusedBatchNormV3Op, also create a constant tensor to forward to
2056         // the last reserve_space_3 output.
2057         auto reserve_space_3_type =
2058             op.getResult(5).getType().template cast<TensorType>();
2059         int num_elements = reserve_space_3_type.hasStaticShape()
2060                                ? reserve_space_3_type.getNumElements()
2061                                : 0;
2062         auto const_attr_type = RankedTensorType::get(
2063             {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2064         Value dummy_const = rewriter.create<ConstOp>(
2065             op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2066         if (const_attr_type != reserve_space_3_type)
2067           dummy_const = rewriter.create<tensor::CastOp>(
2068               op.getLoc(), reserve_space_3_type, dummy_const);
2069         rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2070                                 /*batch_variance=*/corrected_variance,
2071                                 /*reserve_space_1=*/reserve_space_1,
2072                                 /*reserve_space_2=*/batch_variance,
2073                                 /*reserve_space_3=*/dummy_const});
2074       }
2075     } else {  // Inference case.
2076       auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
2077           op.getLoc(),
2078           /*result_type=*/bn_train_input_type_tensor, bn_train_input,
2079           op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
2080           feature_dim.getInt());
2081 
2082       // Convert back to input type to stay aligned with expected output type
2083       // for TF op.
2084       auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
2085                                                     input_element_type);
2086 
2087       // The mean, variance, and reserved space outputs of the batch norm op are
2088       // not used for inference. It doesn't matter what values we provide for
2089       // the last 5 results as long as they are of the same type. Forward
2090       // input mean and variance to output mean, variance, reserved_space_1 and
2091       // reserved_space_2.
2092       if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2093         rewriter.replaceOp(op, {/*y=*/y_out,
2094                                 /*batch_mean=*/op.mean(),
2095                                 /*batch_variance=*/op.variance(),
2096                                 /*reserve_space_1=*/op.mean(),
2097                                 /*reserve_space_2=*/op.variance()});
2098       } else {
2099         // For FusedBatchNormV3Op, also create a constant tensor to forward to
2100         // the last reserve_space_3 output.
2101         auto reserve_space_3_type =
2102             op.getResult(5).getType().template cast<TensorType>();
2103         int num_elements = reserve_space_3_type.hasStaticShape()
2104                                ? reserve_space_3_type.getNumElements()
2105                                : 0;
2106         auto const_attr_type = RankedTensorType::get(
2107             {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2108         Value dummy_const = rewriter.create<ConstOp>(
2109             op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2110         if (const_attr_type != reserve_space_3_type)
2111           dummy_const = rewriter.create<tensor::CastOp>(
2112               op.getLoc(), reserve_space_3_type, dummy_const);
2113         rewriter.replaceOp(op, {/*y=*/y_out,
2114                                 /*batch_mean=*/op.mean(),
2115                                 /*batch_variance=*/op.variance(),
2116                                 /*reserve_space_1=*/op.mean(),
2117                                 /*reserve_space_2=*/op.variance(),
2118                                 /*reserve_space_3=*/dummy_const});
2119       }
2120     }
2121     return success();
2122   }
2123 };
2124 
2125 using ConvertFusedBatchNormV2Op =
2126     ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
2127 using ConvertFusedBatchNormV3Op =
2128     ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
2129 
2130 using PaddingArray =
2131     std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
2132 
2133 // Returns padding values for ReduceWindow op as a vector of pairs.
2134 //
2135 // Requires padding to be either 'SAME' or 'VALID' and the number of input
2136 // dimensions to be equal to the size of window dimensions and window strides.
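//
// For example (illustrative, assuming XLA's SAME padding rule): an input
// dimension of 10 with a window of 3 and a stride of 2 produces
// ceil(10 / 2) = 5 output elements, so the total padding needed is
// (5 - 1) * 2 + 3 - 10 = 1, split as (pad_low, pad_high) = (0, 1).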
2137 template <int num_dims>
2138 static PaddingArray GetReduceWindowPaddingAsArray(
2139     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2140     ArrayAttr window_strides, StringRef padding, Builder *builder) {
2141   if (padding == "VALID") {
2142     return PaddingArray(num_dims, std::make_pair(0, 0));
2143   }
2144   assert(padding == "SAME");
2145   llvm::SmallVector<tensorflow::int64, num_dims> input_shape, window_shape,
2146       strides;
2147   input_shape.reserve(input_dims.size());
2148   window_shape.reserve(window_dims.size());
2149   strides.reserve(window_strides.size());
2150 
2151   for (const auto &dim : input_dims) input_shape.push_back(dim);
2152   for (Attribute attr : window_dims)
2153     window_shape.push_back(attr.cast<IntegerAttr>().getInt());
2154   for (Attribute attr : window_strides)
2155     strides.push_back(attr.cast<IntegerAttr>().getInt());
2156 
2157   PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
2158                                              ::xla::Padding::kSame);
2159   return paddings;
2160 }
2161 
2162 // Same as GetReduceWindowPaddingAsArray but returns padding as
2163 // DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
2164 template <int num_dims>
2165 static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
2166     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2167     ArrayAttr window_strides, StringRef padding, Builder *builder) {
2168   if (padding == "VALID") return {};
2169   assert(padding == "SAME");
2170   PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
2171       input_dims, window_dims, window_strides, padding, builder);
2172   int64_t rank = paddings.size();
2173   llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
2174   for (int i = 0; i < rank; i++) {
2175     flatten_paddings[2 * i] = paddings[i].first;
2176     flatten_paddings[2 * i + 1] = paddings[i].second;
2177   }
2178   return DenseIntElementsAttr::get(
2179       RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
2180       flatten_paddings);
2181 }
2182 
2183 // Helper function that divides each entry of `pooled` by the count of its
2184 // corresponding window, i.e., the number of non-padding entries of the window
2185 // that an `AvgPool` operation performed on an `input_shape`-shaped tensor
2186 // would map to this entry, depending on `ksize` and `strides`. This function
2187 // is used for the `AvgPool` and `AvgPoolGrad` legalizations.
2188 // `zero` is passed as a parameter because it can be reused at the caller level.
2189 // `pooled` must have `RankedTensorType`.
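//
// For example (illustrative): for a 1-D AvgPool with ksize = 3, stride = 1 and
// SAME padding, interior entries are divided by 3, while the first and last
// entries are divided by 2, since one of their window positions falls on
// padding.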
2190 template <typename OpTy, int num_dims>
2191 Operation *AvgPoolDivideByCount(
2192     Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
2193     const SmallVector<int64_t, num_dims> &ksize,
2194     const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
2195     PatternRewriter &rewriter) {
2196   Location loc = op.getLoc();
2197   RankedTensorType pooled_type =
2198       pooled.getType().template cast<RankedTensorType>();
2199   Type element_type = pooled_type.getElementType();
2200   Operation *result = nullptr;
2201   RankedTensorType orig_input_type =
2202       RankedTensorType::get(input_shape, element_type);
2203 
2204   if (op.padding() == "VALID") {
2205     // All window counts are equal here because we don't have padding
2206     // (each entry of `pooled` corresponds to a window that consists of
2207     //  original input entries only).
2208     int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
2209                                            std::multiplies<int64_t>());
2210     // Divide `pooled` by window counts.
2211     Value divisor =
2212         GetScalarConstOfType(element_type, loc, window_count, &rewriter);
2213     auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2214     result = rewriter.create<chlo::BroadcastDivOp>(
2215         loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
2216   } else {
2217     assert(op.padding() == "SAME");
2218     // For SAME padding, only original entries that contributed to a window
2219     // are counted for the average of this window, not padded entries.
2220 
2221     // Build all-ones tensor of same shape as the original input.
2222     ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
2223     auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
2224 
2225     // Get padding for the input.
2226     DenseIntElementsAttr input_padding_attr =
2227         GetReduceWindowPaddingAsAttr<num_dims>(
2228             input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2229 
2230     // Count the 1's in each window, using the same padding as for the input,
2231     // which gives us the window counts by which `pooled` needs to be divided.
2232     auto divisor = rewriter.create<ReduceWindowOp>(
2233         loc, pooled_type,
2234         /*operand=*/all_ones_tensor,
2235         /*init_value=*/zero,
2236         /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2237         /*window_strides=*/GetI64ElementsAttr(op.strides()),
2238         /*base_dilations=*/DenseIntElementsAttr(),
2239         /*window_dilations=*/DenseIntElementsAttr(),
2240         /*padding=*/input_padding_attr);
2241     BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
2242 
2243     // Divide `pooled` by window counts.
2244     result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled, divisor);
2245   }
2246   return result;
2247 }
2248 
2249 Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
2250 Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
2251 
2252 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
2253 // dimensions with add as the reduction function. The reduction result is
2254 // then divided by the number of elements in the window.
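//
// A rough sketch of the resulting IR for the VALID-padding case (types and
// attribute values below are illustrative only):
//
//   %init = mhlo.constant dense<0.0> : tensor<f32>
//   %sum = "mhlo.reduce_window"(%input, %init)
//            {window_dimensions = ..., window_strides = ...}  // add reduction
//   // ... followed by dividing %sum by the (constant) window element count.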
2255 template <typename OpTy, int num_dims>
2256 class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
2257  public:
2258   using OpRewritePattern<OpTy>::OpRewritePattern;
2259 
2260   LogicalResult matchAndRewrite(OpTy op,
2261                                 PatternRewriter &rewriter) const override {
2262     Value input_value = GetAvgPoolInput(op);
2263     auto input_type =
2264         input_value.getType().template dyn_cast<RankedTensorType>();
2265     if (!input_type) return failure();
2266 
2267     // We will do accumulation first; use a larger bitwidth if suitable.
2268     Type input_element_type = input_type.getElementType();
2269     Type sum_element_type = GetSumAccumulationType(input_element_type);
2270     Type result_type;
2271 
2272     // The result type for reduction and division with the proper element type.
2273     if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
2274       result_type =
2275           RankedTensorType::get(ranked_type.getShape(), sum_element_type);
2276     else
2277       result_type = UnrankedTensorType::get(sum_element_type);
2278 
2279     // Convert first if we need to enlarge the element type's bitwidth.
2280     if (input_element_type != sum_element_type)
2281       input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
2282                                                sum_element_type);
2283 
2284     // Create the mhlo.reduce_window op.
2285     Value init =
2286         GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
2287     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2288         input_type.getShape(), op.ksize(), op.strides(), op.padding(),
2289         &rewriter);
2290     auto reduce = rewriter.create<ReduceWindowOp>(
2291         op.getLoc(), result_type, input_value, init,
2292         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2293         /*base_dilations=*/DenseIntElementsAttr(),
2294         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2295     BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
2296 
2297     // Divide by the number of elements in each window. The per-window counts
2298     // depend on the padding mode and are handled in AvgPoolDivideByCount.
2299     SmallVector<int64_t, num_dims> input_shape(
2300         llvm::to_vector<num_dims>(input_type.getShape()));
2301     SmallVector<int64_t, num_dims> ksize, strides;
2302     GetI64ArrayAttrValues(op.ksize(), &ksize);
2303     GetI64ArrayAttrValues(op.strides(), &strides);
2304 
2305     Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
2306         reduce.getResult(), input_shape, ksize, strides, op, init, rewriter);
2307 
2308     // Convert back if we enlarged the element type's bitwidth.
2309     Value result = result_op->getOpResult(0);
2310     if (input_element_type != sum_element_type)
2311       result =
2312           rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
2313 
2314     rewriter.replaceOp(op, result);
2315     return success();
2316   }
2317 };
2318 
2319 using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
2320 using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
2321 
2322 // `AvgPoolGradOp` is converted to the following operations:
2323 // 1. Divide each entry of the output gradient (the gradient for the previous
2324 //    layer in backpropagation order) by the count of the corresponding window
2325 //    (i.e., the number of non-padding entries of the window which `AvgPool`
2326 //    has mapped to this entry in forward propagation).
2327 // 2. Add appropriate interior and exterior padding for step 3 (see example
2328 //    below).
2329 // 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
2330 //    as windows) and stride 1 in each dimension. This is implemented as a
2331 //    `ReduceWindowOp` with `AddOp` as body.
2332 //
2333 // Example:
2334 // Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
2335 // and SAME padding with 0's. It is defined by
2336 //    f(x) = [ (x_1 + x_2 + x_3) / 3 ]      ( x = (x_1, x_2, x_3, x_4) )
2337 //           [ (x_3 + x_4 + 0)   / 2 ]      (the 0 results from right padding)
2338 // Note that for SAME padding in `AvgPool` the padded entries are not counted
2339 // for the average; this is why the second denominator is 2 and not 3.
2340 // The Jacobian Df is
2341 //    [ 1/3  1/3  1/3  0   ]
2342 //    [ 0    0    1/2  1/2 ]
2343 //
2344 // Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
2345 // needs the original input shape and not the tensor as argument).
2346 // Let v = [ 4  6 ]^T  be the output gradient (^T = transposed). Then the
2347 // average pool gradient is given by
2348 //    Df^T * v = [ 4/3  4/3  13/3  3 ]^T
2349 // Instead of a matrix-vector-multiplication we can utilize the sparsity and
2350 // structure of Df by using the 3-step approach from above:
2351 // 1. Divide output gradient v by window counts: [ 4/3  6/2 ]^T
2352 // 2. Add appropriate padding: [ 0  0  4/3  0  3  0 ]^T
2353 // 3. Convolve with kernel [ 1  1  1 ]: [ 4/3  4/3  13/3  3 ]^T
2354 //
2355 // Note that the padding in step 2. is chosen in such a way that the subsequent
2356 // convolution produces the gradient. Higher dimensions, different padding, and
2357 // different windows/strides work in a similar way; the main difference is in
2358 // the computation of the paddings in step 2.
2359 //
2360 // For more details on backpropagation for convolution of which `AvgPoolGrad`
2361 // is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
2362 // `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
2363 // examples for different cases.
2364 template <typename OpTy, int num_dims>
2365 class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
2366   using DimVector = SmallVector<int64_t, num_dims>;
2367 
2368  public:
2369   using OpRewritePattern<OpTy>::OpRewritePattern;
2370 
2371   LogicalResult matchAndRewrite(OpTy op,
2372                                 PatternRewriter &rewriter) const override {
2373     Location loc = op.getLoc();
2374     tensorflow::TensorFormat data_format;
2375     if (!FormatFromString(op.data_format().str(), &data_format)) {
2376       return failure();
2377     }
2378     // `out_grad` is the gradient that was propagated via backpropagation from
2379     // the output layer.
2380     Value out_grad = op.grad();
2381     auto out_grad_type =
2382         out_grad.getType().template dyn_cast<RankedTensorType>();
2383     if (!out_grad_type) {
2384       return failure();
2385     }
2386     Type element_type = out_grad_type.getElementType();
2387     DenseIntElementsAttr orig_input_shape_attr;
2388     if (!matchPattern(op.orig_input_shape(),
2389                       m_Constant(&orig_input_shape_attr))) {
2390       return failure();
2391     }
2392     auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
2393     DimVector orig_input_shape(orig_input_shape_values.begin(),
2394                                orig_input_shape_values.end());
2395     DimVector ksize, strides;
2396     GetI64ArrayAttrValues(op.ksize(), &ksize);
2397     GetI64ArrayAttrValues(op.strides(), &strides);
2398     Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
2399 
2400     auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
2401         out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
2402 
2403     // Get same padding as for original input.
2404     PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
2405         orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2406 
2407     // Add padding around `out_grad_divided` values in such a way that the
2408     // subsequent `ReduceWindowOp` produces the gradient.
2409     DimVector out_grad_shape(
2410         llvm::to_vector<num_dims>(out_grad_type.getShape()));
2411     DimVector low_padding(num_dims, 0);
2412     DimVector high_padding(num_dims, 0);
2413     DimVector interior_padding(num_dims, 0);
2414     constexpr int num_spatial_dims = num_dims - 2;
2415     for (int i = 0; i < num_spatial_dims; ++i) {
2416       int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
2417       int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
2418                                            orig_padding[dim].first +
2419                                            orig_padding[dim].second;
2420       // Set interior padding such that neighboring entries from
2421       // `out_grad_divided` have distance `strides[dim]` from each other in
2422       // every dimension.
2423       interior_padding[dim] = strides[dim] - 1;
2424       // Set exterior padding in the same way as for convolution gradient
2425       // computation.
2426       auto status = ::xla::ConvGradExtractAndVerifyDimension(
2427           /*input_size=*/orig_input_shape_padded_in_dim,
2428           /*filter_size=*/ksize[dim],
2429           /*output_size=*/out_grad_shape[dim],
2430           /*dilation=*/1,
2431           /*stride=*/strides[dim],
2432           /*padding=*/::xla::Padding::kValid);
2433       if (!status.ok()) {
2434         return failure();
2435       }
2436       ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
2437           status.ValueOrDie();
2438       // Subtract the original exterior padding since it doesn't contribute to
2439       // the gradient. Note that we save one `PadOp` and some unnecessary kernel
2440       // computations, compared to the `xla::AvgPoolGrad` implementation, by
2441       // subtracting the original exterior padding before `ReduceWindowOp`
2442       // instead of trimming the result of `ReduceWindowOp` (the final result is
2443       // the same because all strides are 1).
2444       low_padding[dim] =
2445           conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
2446       high_padding[dim] =
2447           conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
2448 
2449       // Update `out_grad_shape` to result shape of following `PadOp`.
2450       out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
2451                             (out_grad_shape[dim] - 1) * strides[dim] + 1;
2452     }
2453     Value reduce_window_input = rewriter.create<PadOp>(
2454         loc, RankedTensorType::get(out_grad_shape, element_type),
2455         /*operand=*/out_grad_divided->getOpResult(0),
2456         /*padding_value=*/zero,
2457         /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
2458         /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
2459         /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
2460 
2461     // Compute result by convolving `reduce_window_input` with an all-ones
2462     // kernel, using `ReduceWindowOp` with `AddOp` body.
2463 
2464     Type sum_element_type = GetSumAccumulationType(element_type);
2465     if (element_type != sum_element_type) {
2466       // Convert to appropriate sum accumulation type to avoid precision loss.
2467       reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
2468                                                        sum_element_type);
2469       zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
2470     }
2471     auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
2472     auto reduce_window_op = rewriter.create<ReduceWindowOp>(
2473         loc, RankedTensorType::get(orig_input_shape, sum_element_type),
2474         /*operand=*/reduce_window_input,
2475         /*init_value=*/zero,
2476         /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2477         /*window_strides=*/ones,
2478         /*base_dilations=*/DenseIntElementsAttr(),
2479         /*window_dilations=*/DenseIntElementsAttr(),
2480         /*padding=*/DenseIntElementsAttr());
2481     BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
2482                            &rewriter);
2483     Value result = reduce_window_op.getResult();
2484 
2485     if (element_type != sum_element_type) {
2486       // Convert back to original element type.
2487       result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
2488     }
2489     rewriter.replaceOp(op, {result});
2490     return success();
2491   }
2492 };
2493 
2494 using ConvertAvgPool2DGradOp =
2495     ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
2496 using ConvertAvgPool3DGradOp =
2497     ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
2498 
2499 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
2500 // dimensions with max as the reduction function.
2501 //
2502 // Sample result for VALID padding mode:
2503 //
2504 //   %init = constant dense<...> : tensor<i32>
2505 //   %max_pool = "mhlo.reduce_window"(%inp, %init) ["mhlo.maximum"]
2506 //               {window_dimensions = ..., window_strides = ... }
2507 //
2508 template <typename OpTy, int num_dims>
2509 class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
2510  public:
2511   using OpRewritePattern<OpTy>::OpRewritePattern;
2512 
2513   LogicalResult matchAndRewrite(OpTy op,
2514                                 PatternRewriter &rewriter) const override {
2515     Type element_type =
2516         op.input().getType().template cast<TensorType>().getElementType();
2517     if (!element_type.isSignlessIntOrFloat()) return failure();
2518     tensorflow::Padding padding;
2519     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
2520       return failure();
2521     if (padding == tensorflow::Padding::EXPLICIT) {
2522       return failure();
2523     }
2524     Location loc = op.getLoc();
2525     ConstOp init = GetScalarLimitConstOfType(element_type, loc,
2526                                              hlo::kInfinityLowest, &rewriter);
2527 
2528     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
2529     if (!input_ty) return failure();
2530     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2531         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
2532     auto reduce = rewriter.create<ReduceWindowOp>(
2533         loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
2534         GetI64ElementsAttr(op.strides()),
2535         /*base_dilations=*/DenseIntElementsAttr(),
2536         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2537     BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
2538 
2539     rewriter.replaceOp(op, reduce.getResult());
2540     return success();
2541   }
2542 };
2543 
2544 using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
2545 using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
2546 
2547 // Converts SelectV2 to HLO Select op and necessary BroadcastInDim ops on
2548 // operands.
2549 //
2550 // For example, the following source IR:
2551 //
2552 //   %select = "tf.SelectV2"(%condition, %t, %e) :
2553 //               (tensor<1xi1>, tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
2554 //
2555 // will be converted into:
2556 //
2557 //   %pred = "mhlo.broadcast_in_dim"(%cond)
2558 //             {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
2559 //               (tensor<1xi1>) -> tensor<2xi1>
2560 //   %on_false = "mhlo.broadcast_in_dim"(%e)
2561 //                 {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
2562 //                   (tensor<1xi32>) -> tensor<2xi32>
2563 //   %select = "mhlo.select"(%pred, %t, %on_false) :
2564 //               (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
2565 class ConvertSelectV2Op : public OpRewritePattern<TF::SelectV2Op> {
2566  public:
2567   using OpRewritePattern::OpRewritePattern;
2568 
2569   LogicalResult matchAndRewrite(TF::SelectV2Op op,
2570                                 PatternRewriter &rewriter) const override {
2571     llvm::SmallVector<int64_t, 4> broadcast_then_else_shape;
2572     auto ranked_then_type = op.t().getType().dyn_cast<RankedTensorType>();
2573     auto ranked_else_type = op.e().getType().dyn_cast<RankedTensorType>();
2574     auto ranked_cond_type =
2575         op.condition().getType().dyn_cast<RankedTensorType>();
2576     if (!ranked_then_type || !ranked_then_type.hasStaticShape() ||
2577         !ranked_else_type || !ranked_else_type.hasStaticShape() ||
2578         !ranked_cond_type || !ranked_cond_type.hasStaticShape())
2579       return failure();
2580 
2581     if (!OpTrait::util::getBroadcastedShape(ranked_then_type.getShape(),
2582                                             ranked_else_type.getShape(),
2583                                             broadcast_then_else_shape))
2584       return failure();
2585 
2586     llvm::SmallVector<int64_t, 4> broadcast_shape;
2587     if (!OpTrait::util::getBroadcastedShape(broadcast_then_else_shape,
2588                                             ranked_cond_type.getShape(),
2589                                             broadcast_shape))
2590       return failure();
2591 
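    // Broadcasts `value` to `broadcast_shape` (aligning its dimensions with the
    // trailing dimensions of the broadcasted shape), or returns it unchanged if
    // it already has that shape.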
2592     auto broadcast_or_self = [&](Value value) {
2593       RankedTensorType type = value.getType().cast<RankedTensorType>();
2594       auto output_type =
2595           RankedTensorType::get(broadcast_shape, type.getElementType());
2596       if (output_type == type) return value;
2597 
2598       int64_t rank = type.getRank();
2599       SmallVector<int64_t, 4> broadcast_dimensions(rank);
2600       std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(),
2601                 broadcast_shape.size() - rank);
2602 
2603       return rewriter
2604           .create<BroadcastInDimOp>(
2605               op.getLoc(), output_type, value,
2606               GetI64ElementsAttr(broadcast_dimensions, &rewriter))
2607           .getResult();
2608     };
2609 
2610     // HLO SelectOp supports broadcasting for predicate/condition if
2611     // predicate/condition is a scalar.
2612     Value pred = ranked_cond_type.getRank() == 0
2613                      ? op.condition()
2614                      : broadcast_or_self(op.condition());
2615     Value on_true = broadcast_or_self(op.t());
2616     Value on_false = broadcast_or_self(op.e());
2617 
2618     rewriter.replaceOpWithNewOp<SelectOp>(op, on_true.getType(), pred, on_true,
2619                                           on_false);
2620 
2621     return success();
2622   };
2623 };
2624 
2625 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
2626 //
2627 //     sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
2628 //
2629 // Sample result for a 1-d f32 input with 2 elements.
2630 //
2631 //    // Create an array of 0.5 with the same shape as the input array.
2632 //    %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
2633 //    %half_array = "mhlo.broadcast"(%half)
2634 //                           {broadcast_sizes = dense<2> : tensor<1xi64>}
2635 //                           : (tensor<f32>) -> tensor<2xf32>
2636 //
2637 //    // Compute Tanh of half the logits of the values.
2638 //    %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
2639 //    %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
2640 //
2641 //    // Halve the result of Tanh and add 0.5.
2642 //    %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
2643 //    %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
2644 //
2645 class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
2646  public:
2647   using OpRewritePattern::OpRewritePattern;
2648 
2649   LogicalResult matchAndRewrite(TF::SigmoidOp op,
2650                                 PatternRewriter &rewriter) const override {
2651     Location loc = op.getLoc();
2652 
2653     // Create a constant 0.5 with the same shape and element type as the operand.
2654     Value operand = op.getOperand();
2655     auto operand_ty = operand.getType().cast<TensorType>();
2656     auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType());
2657     ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5);
2658     auto scalar_half = rewriter.create<ConstOp>(loc, attr);
2659     auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter);
2660 
2661     auto scaled_input = rewriter.create<MulOp>(loc, operand, half);
2662     auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input);
2663     auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half);
2664     auto add_op = rewriter.create<AddOp>(loc, mul_op, half);
2665 
2666     rewriter.replaceOp(op, add_op.getResult());
2667     return success();
2668   }
2669 };
2670 
2671 // Converts Softmax and LogSoftmax to HLO ops, computing softmax with the
2672 // following formulas:
2673 //
2674 //     softmax = div(exp(logits), sum(exp(logits)))
2675 //
2676 //     log_softmax = sub(logits, log(sum(exp(logits))))
2677 //
2678 // Sample result with 2-d f16 inputs with B batches of N elements each.
2679 //
2680 //    %reduce_dim = tf.Const dense<[1]> : tensor<1xi64>
2681 //
2682 //    // Subtract each element by their batches' max to improve numerical
2683 //    // stability.
2684 //    %max = "tf.Max"(%input, %reduce_dim)
2685 //           : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
2686 //    %sub = "mhlo.subtract"(%inp, %max) {broadcast_dimensions = 0}
2687 //            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
2688 //
2689 //    %exp = "mhlo.exponential"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
2690 //    %sum = "tf.Sum"(%exp, %reduce_dim)
2691 //            : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
2692 //
2693 //    // Softmax computation:
2694 //    %softmax = "mhlo.divide"(%exp, %sum) {broadcast_dimensions = 0}
2695 //            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
2696 template <typename OpTy, bool use_log = true>
2697 class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
2698  public:
2699   using OpRewritePattern<OpTy>::OpRewritePattern;
2700 
2701   LogicalResult matchAndRewrite(OpTy op,
2702                                 PatternRewriter &rewriter) const override {
2703     // Softmax converter requires a ranked type because the XLA reduce ops used
2704     // during lowering require a dimensions attribute to reduce along.
2705     // Note that the input and output shapes are identical, so we use 'logits'
2706     // and its type for shape calculations.
2707     Value logits = op.logits();
2708     RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
2709     if (!type) return failure();
2710     auto loc = op.getLoc();
2711     int rank = type.getRank();
2712 
2713     // Note that the TensorFlow Softmax op verifies that the input rank is
2714     // greater than or equal to one so the following sequence is valid.
2715     auto reduce_dim = rewriter.create<TF::ConstOp>(
2716         loc, GetI64ElementsAttr({rank - 1}, &rewriter));
2717 
2718     // The exponentials of the input values, and hence their sum, can be very
2719     // large here. Division by a large denominator is numerically unstable. To
2720     // improve numerical stability, subtract each batch's max element from all
2721     // elements in that batch so that the maximum input value is zero. It can be
2722     // shown that softmax computed after adding or subtracting a common value to
2723     // all inputs in a batch gives a mathematically equivalent result.
2724     auto max_logits =
2725         rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
2726                                    /*keep_dims=*/rewriter.getBoolAttr(false));
2727     auto max_logits_broadcast =
2728         CommonPrefixBroadcast(loc, logits, max_logits, rewriter);
2729     auto shifted_logits =
2730         rewriter.create<mhlo::SubOp>(loc, type, logits, max_logits_broadcast);
2731 
2732     // Exponentiate the inputs.
2733     Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
2734 
2735     // Compute summation of the exponentials.
2736     auto exp_sum =
2737         rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
2738                                    /*keep_dims=*/rewriter.getBoolAttr(false));
2739     Value sum = exp_sum.getResult();
2740 
2741     if (use_log) {
2742       Value log = rewriter.create<LogOp>(loc, sum);
2743       auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter);
2744       rewriter.replaceOpWithNewOp<mhlo::SubOp>(op, shifted_logits,
2745                                                log_broadcast);
2746     } else {
2747       auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter);
2748       rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, exp, sum_broadcast);
2749     }
2750     return success();
2751   }
2752 };
2753 
2754 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
2755                                            Value *out_lhs, Value *out_rhs,
2756                                            PatternRewriter *rewriter) {
2757   // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
2758   // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
2759   // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
2760   // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
2761   // To perform the matmul, we need to first broadcast lhs and rhs to a common
2762   // set of leading dimensions before doing the actual matmul.
2763   // That's what the code below does.
2764   // In particular, we populate out_lhs and out_rhs to have dimension structure:
2765   // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
2766   // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
2767   // To do this, we need to calculate those output shapes, which involves
2768   // slicing off the leading batch dims of each operand, broadcasting them,
2769   // then concatenating the broadcasted leading dims back to the row/col dims.
2770   // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
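  //
  // Purely illustrative example (shapes assumed): for lhs of shape
  // [1, 3, 4, 5] and rhs of shape [2, 1, 5, 6], the leading batch dims [1, 3]
  // and [2, 1] broadcast to [2, 3], so out_lhs becomes [2, 3, 4, 5] and
  // out_rhs becomes [2, 3, 5, 6]; the subsequent matmul yields [2, 3, 4, 6].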
2771 
2772   // TODO(silvasean): Reduce duplication across reified shape calculations and
2773   // the static computation of output types needed to create ops.
2774   Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
2775   Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
2776   Value const_neg2 =
2777       rewriter->create<ConstantOp>(loc, rewriter->getIndexAttr(-2));
2778   auto lhs_splitted =
2779       rewriter->create<shape::SplitAtOp>(loc, lhs_shape, const_neg2);
2780   auto rhs_splitted =
2781       rewriter->create<shape::SplitAtOp>(loc, rhs_shape, const_neg2);
2782   auto lhs_type = lhs.getType().cast<RankedTensorType>();
2783   auto rhs_type = rhs.getType().cast<RankedTensorType>();
2784   // The last two dimensions are the matrix row/col dimensions. Don't broadcast
2785   // them.
2786   SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
2787   OpTrait::util::getBroadcastedShape(lhs_type.getShape().drop_back(2),
2788                                      rhs_type.getShape().drop_back(2),
2789                                      result_batch_shape_compile_time_extents);
2790   auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
2791       loc, shape::ShapeType::get(rewriter->getContext()), lhs_splitted.head(),
2792       rhs_splitted.head(),
2793       /*error=*/nullptr);
2794   // Lambda which handles the broadcasting of one side to the common
2795   // leading-batch dimensions.
2796   auto broadcast_one_side = [&](Value side, RankedTensorType type,
2797                                 Value tail_shape, Value *out_side) {
2798     ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
2799     auto result_shape = result_batch_shape_compile_time_extents;
2800     result_shape.append(matrix_dims.begin(), matrix_dims.end());
2801     auto result_type =
2802         RankedTensorType::get(result_shape, type.getElementType());
2803     auto shape =
2804         rewriter->create<shape::ConcatOp>(loc, result_batch_shape, tail_shape);
2805     auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
2806         loc,
2807         RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
2808                               rewriter->getIndexType()),
2809         shape);
2810     *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
2811                                                     shape_tensor);
2812   };
2813   broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs);
2814   broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs);
2815 }
2816 
2817 class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
2818  public:
2819   // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved
2820   // to CHLO and it is missing legalization to MHLO. Once that is done, this
2821   // pattern's benefit can be changed back to one and the fallback lowering
2822   // pattern for the op can be removed.
2823   //
2824   // Set benefit of this pattern to zero to prefer the fallback pattern when
2825   // available and applicable. That pattern avoids broadcast on operands and is
2826   // therefore faster.
2827   explicit ConvertBatchMatMulV2Op(MLIRContext *context)
2828       : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
2829 
2830   LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
2831                                 PatternRewriter &rewriter) const override {
2832     Value lhs = op.x();
2833     Value rhs = op.y();
2834     auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
2835     auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
2836     if (!lhs_type || !rhs_type) return failure();
2837     if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
2838       lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
2839     }
2840     if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
2841       rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
2842     }
2843 
2844     // Broadcast both operands.
2845     BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
2846                                    &rewriter);
2847     lhs_type = lhs.getType().cast<RankedTensorType>();
2848     rhs_type = rhs.getType().cast<RankedTensorType>();
2849     assert(lhs_type.getRank() == rhs_type.getRank());
2850     int64_t rank = lhs_type.getRank();
2851     auto batch_dimensions = GetI64ElementsAttr(
2852         llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
2853     auto lhs_contracting_dimensions = GetI64ElementsAttr(
2854         llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
2855     auto rhs_contracting_dimensions = GetI64ElementsAttr(
2856         llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
2857     auto dimension_numbers = DotDimensionNumbers::get(
2858         /*lhs_batching_dimensions=*/batch_dimensions,
2859         /*rhs_batching_dimensions=*/batch_dimensions,
2860         /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
2861         /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
2862         rewriter.getContext());
2863     // TODO(silvasean): Emit shape checks for contracting dimensions.
2864     // (The batch dimensions are checked by the broadcasting logic)
2865     rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
2866                                               dimension_numbers,
2867                                               /*precision_config=*/nullptr);
2868     return success();
2869   }
2870 };
2871 
2872 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
2873 // split has fully static shape and the dimension to split is a constant.
2874 //
2875 // The main logic of this pattern is to calculate the start and end index range
2876 // for each slice. This only varies along the dimension being split; for all
2877 // other dimensions, every resultant slice's index range covers the input
2878 // tensor's full range. Strides for all resultant slices are one.
2879 //
2880 // For example, the following source IR:
2881 //
2882 //   %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
2883 //   %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
2884 //                (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
2885 //
2886 // will be converted into:
2887 //
2888 //   %0 = "mhlo.slice"(%input) {
2889 //             limit_indices = dense<[4, 2]> : tensor<2xi64>,
2890 //             start_indices = dense<0> : tensor<2xi64>,
2891 //             strides = dense<1> : tensor<2xi64>} :
2892 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
2893 //   %1 = "mhlo.slice"(%input) {
2894 //             limit_indices = dense<4> : tensor<2xi64>,
2895 //             start_indices = dense<[0, 2]> : tensor<2xi64>,
2896 //             strides = dense<1> : tensor<2xi64>} :
2897 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
2898 //   %2 = "mhlo.slice"(%input) {
2899 //             limit_indices = dense<[4, 6]> : tensor<2xi64>,
2900 //             start_indices = dense<[0, 4]> : tensor<2xi64>,
2901 //             strides = dense<1> : tensor<2xi64>} :
2902 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
2903 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
2904 // applicable.
2905 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
2906  public:
2907   using OpRewritePattern::OpRewritePattern;
2908 
2909   LogicalResult matchAndRewrite(TF::SplitOp op,
2910                                 PatternRewriter &rewriter) const override {
2911     // We can only split along static dimensions.
2912     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2913     if (!input_type) return failure();
2914 
2915     // We can only match when the split dimension is a constant scalar.
2916     DenseIntElementsAttr split_dim_attr;
2917     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
2918       return failure();
2919 
2920     // Get the dimension we are splitting at. Offset properly if it's negative.
2921     int64_t input_rank = input_type.getRank();
2922     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
2923     if (dim_index < 0) dim_index += input_rank;
2924 
2925     // Calculate the dimension size for each slice along the split dimension.
2926     int64_t input_dim_size = input_type.getDimSize(dim_index);
2927     // If we are splitting along the dynamic dimension then we cannot compute
2928     // the static dimension length.
2929     if (TensorType::isDynamic(input_dim_size)) return failure();
2930 
2931     int64_t num_splits = op.getNumResults();
2932     int64_t slice_size = input_dim_size / num_splits;
2933 
2934     // Get each slice's type.
2935     auto slice_shape = llvm::to_vector<4>(input_type.getShape());
2936     slice_shape[dim_index] = slice_size;
2937     Type slice_type =
2938         RankedTensorType::get(slice_shape, input_type.getElementType());
2939 
2940     // Parameters for constructing each slice.
2941     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
2942     auto end_indices = llvm::to_vector<4>(input_type.getShape());
2943     SmallVector<int64_t, 4> strides(input_rank, 1);
2944 
2945     // All HLO slice results used to replace the original tf.Split op.
2946     SmallVector<Value, 4> slices;
2947     slices.reserve(num_splits);
2948 
2949     for (int i = 0; i < num_splits; ++i) {
2950       begin_indices[dim_index] = i * slice_size;
2951       end_indices[dim_index] = (i + 1) * slice_size;
2952       slices.push_back(
2953           rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
2954                                    GetI64ElementsAttr(begin_indices, &rewriter),
2955                                    GetI64ElementsAttr(end_indices, &rewriter),
2956                                    GetI64ElementsAttr(strides, &rewriter)));
2957     }
2958 
2959     rewriter.replaceOp(op, slices);
2960     return success();
2961   }
2962 };
2963 
2964 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
2965 // be split has fully static shape and the dimension to split and split sizes
2966 // are constants.
2967 //
2968 // This is similar to the conversion for the tf.Split op, except that the size of
2969 // each chunk along the split dimension is explicitly given as an op operand,
2970 // and the chunk sizes are not necessarily equal.
2971 //
2972 // For example, given the following IR:
2973 //
2974 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
2975 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
2976 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
2977 //                   (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
2978 //                   (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
2979 //
2980 // We will generate the following slices:
2981 // %0 = "mhlo.slice"(%input) {
2982 //        limit_indices = dense<[4, 1]> : tensor<2xi64>,
2983 //        start_indices = dense<0> : tensor<2xi64>,
2984 //        strides = dense<1> : tensor<2xi64>} :
2985 //        (tensor<4x6xf32>) -> tensor<4x1xf32>
2986 // %1 = "mhlo.slice"(%input) {
2987 //        limit_indices = dense<[4, 3]> : tensor<2xi64>,
2988 //        start_indices = dense<[0, 1]> : tensor<2xi64>,
2989 //        strides = dense<1> : tensor<2xi64>} :
2990 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
2991 // %2 = "mhlo.slice"(%input) {
2992 //        limit_indices = dense<[4, 6]> : tensor<2xi64>,
2993 //        start_indices = dense<[0, 3]> : tensor<2xi64>,
2994 //        strides = dense<1> : tensor<2xi64>} :
2995 //        (tensor<4x6xf32>) -> tensor<4x3xf32>
2996 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
2997  public:
2998   using OpRewritePattern::OpRewritePattern;
2999 
3000   LogicalResult matchAndRewrite(TF::SplitVOp op,
3001                                 PatternRewriter &rewriter) const override {
3002     // We can only split along static dimensions.
3003     // TODO(b/145731001): enhance to support dynamic-shaped inputs.
3004     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3005     if (!input_type) return failure();
3006 
3007     // We can only match when the split dimension is a constant scalar.
3008     DenseIntElementsAttr split_dim_attr;
3009     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3010       return failure();
3011 
3012     // We can only match when the split sizes are a constant int vector.
3013     DenseIntElementsAttr split_sizes_attr;
3014     if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
3015       return failure();
3016 
3017     // Get each chunk's size along the dimension to split. It may contain a
3018     // dynamic size, which we will need to compute below.
3019     SmallVector<int64_t, 4> split_sizes;
3020     int64_t total_dim_size = 0;  // Total dimension size assigned to splits
3021     llvm::Optional<int> dynamic_dim_index;
3022     split_sizes.reserve(
3023         split_sizes_attr.getType().cast<ShapedType>().getNumElements());
3024     for (auto dim : llvm::enumerate(split_sizes_attr)) {
3025       int64_t dim_val = dim.value().getSExtValue();
3026       split_sizes.push_back(dim_val);
3027       if (dim_val == ShapedType::kDynamicSize) {
3028         // We cannot have more than one dynamic dimension.
3029         assert(!dynamic_dim_index && "invalid split sizes");
3030         dynamic_dim_index = dim.index();
3031       } else {
3032         total_dim_size += dim_val;
3033       }
3034     }
3035 
3036     // Get the dimension we are splitting at. Offset properly if it's negative.
3037     int64_t input_rank = input_type.getRank();
3038     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3039     if (dim_index < 0) dim_index += input_rank;
3040 
3041     int64_t input_dim_size = input_type.getDimSize(dim_index);
3042     if (TensorType::isDynamic(input_dim_size)) return failure();
3043 
3044     assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
3045             (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
3046            "invalid split sizes");
3047 
3048     // Update the dynamic dimension with calculated concrete size.
3049     if (dynamic_dim_index)
3050       split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
3051 
3052     // Parameters for constructing each slice.
3053     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3054     auto end_indices = llvm::to_vector<4>(input_type.getShape());
3055     SmallVector<int64_t, 4> strides(input_rank, 1);
3056 
3057     // All HLO slice results used to replace the original tf.SplitV op.
3058     SmallVector<Value, 4> slices;
3059     slices.reserve(op.getNumResults());
3060 
3061     for (int i = 0, end = op.getNumResults(); i < end; ++i) {
3062       end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
3063       slices.push_back(rewriter.create<mhlo::SliceOp>(
3064           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3065           GetI64ElementsAttr(end_indices, &rewriter),
3066           GetI64ElementsAttr(strides, &rewriter)));
3067       // Prepare the begin index for the next slice.
3068       begin_indices[dim_index] = end_indices[dim_index];
3069     }
3070 
3071     rewriter.replaceOp(op, slices);
3072     return success();
3073   }
3074 };
3075 
3076 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
3077 // negative strides and Reshape op to update the output shape. Indices and
3078 // strides operands are converted to attributes with non-negative indexing.
3079 //
3080 // If the begin input is not a compile time constant, the begin input needs to
3081 // be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
3082 // case, strides must have a known value of 1 (otherwise we have insufficient
3083 // information to conform to XLA's op semantics).
3084 //
3085 // For example with an op like following,
3086 //   tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
3087 //     : tensor<AxBxf32> -> tensor<Pxf32>
3088 //
3089 // If the %begin input is constant, output would be:
3090 //   %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
3091 //   %sliced = "mhlo.Slice" (%input)
3092 //             {start_indices = ..., limit_indices = ..., strides = ...}
3093 //   %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
3094 //
3095 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
3096  public:
3097   using OpRewritePattern::OpRewritePattern;
3098 
3099   LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
3100                                          ArrayRef<int64_t> begin_indices,
3101                                          ArrayRef<int64_t> end_indices,
3102                                          ArrayRef<int64_t> strides,
3103                                          RankedTensorType input_ty,
3104                                          PatternRewriter &rewriter) const {
3105     SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
3106         dims_to_reverse;
3107     int64_t input_rank = input_ty.getRank();
3108     ArrayRef<int64_t> input_shape = input_ty.getShape();
3109     hlo_begin_indices.reserve(input_rank);
3110     hlo_end_indices.reserve(input_rank);
3111     hlo_strides.reserve(input_rank);
3112 
3113     int64_t indices_elements = begin_indices.size();
3114     if (input_rank < indices_elements) return failure();
3115 
3116     // Convert from TensorFlow negative or out of range indices and strides
3117     // values to legal HLO Slice attributes.
3118     for (int i = 0, e = indices_elements; i != e; i++) {
3119       int64_t begin = begin_indices[i];
3120       int64_t end = end_indices[i];
3121       int64_t stride = strides[i];
3122 
3123       if (stride < 0) {
3124         // Negative stride means that the output values are computed starting
3125         // from end until begin. Mark the dimension for reversal before slice
3126         // and compute indices for the reversed input.
3127         dims_to_reverse.push_back(i);
3128         begin = (input_shape[i] - 1) - begin;
3129         end = (input_shape[i] - 1) - end;
3130         stride = -stride;
3131       }
3132 
3133       // Unlike TensorFlow, HLO requires begin and end values to be within
3134       // range.
3135       begin = std::max(int64_t(0), begin);
3136       end = std::max(begin, end);
3137       end = std::min(end, input_shape[i]);
3138 
3139       hlo_begin_indices.push_back(begin);
3140       hlo_end_indices.push_back(end);
3141       hlo_strides.push_back(stride);
3142     }
3143 
3144     Location loc = op.getLoc();
3145     Value input = op.input();
3146     if (!dims_to_reverse.empty())
3147       input = rewriter.create<ReverseOp>(
3148           loc, input_ty, op.input(),
3149           GetI64ElementsAttr(dims_to_reverse, &rewriter));
3150     auto sliced = rewriter.create<SliceOp>(
3151         loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
3152         GetI64ElementsAttr(hlo_end_indices, &rewriter),
3153         GetI64ElementsAttr(hlo_strides, &rewriter));
3154 
3155     // Reshape slice result so that the shape is updated depending on
3156     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3157     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3158     return success();
3159   }
3160 
3161   LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
3162                                         RankedTensorType input_ty,
3163                                         RankedTensorType result_ty,
3164                                         PatternRewriter &rewriter) const {
3165     // If begin and end values are dynamic, we can only support this lowering
3166     // if strides are a known value of 1.
3167     DenseIntElementsAttr sparse_strides_attr;
3168     if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
3169       return rewriter.notifyMatchFailure(
3170           op,
3171           "requires that strides are known when begin/end values are dynamic");
3172     }
3173     SmallVector<int64_t, 4> strides;
3174     int64_t stride_value;
3175     for (const APInt &stride : sparse_strides_attr) {
3176       if ((stride_value = stride.getSExtValue()) != 1) {
3177         return rewriter.notifyMatchFailure(op,
3178                                            "requires that strides are all 1 "
3179                                            "when begin/end values are dynamic");
3180       }
3181       strides.push_back(stride_value);
3182     }
3183 
3184     ArrayRef<int64_t> input_shape = input_ty.getShape();
3185     int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
3186 
3187     // When begin/end values are dynamic, we can only support shrinking a major
3188     // axis. For instance, if there are 4 dims, we can support a
3189     // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
3190     // other.
3191     bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
3192     if (!shrink_axis_mask_ok)
3193       return rewriter.notifyMatchFailure(
3194           op,
3195           "requires that shrink_axis_mask, if set, refer to a major axis "
3196           "dimension (when begin/end values are dynamic)");
3197 
3198     // When begin/end values are dynamic, the ellipsis mask, if set, must refer
3199     // to the last dimension.
3200     int ellipsis_mask = op.ellipsis_mask();
3201     if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
3202       return rewriter.notifyMatchFailure(
3203           op,
3204           "requires that ellipsis_mask, if set, refer to the last dimension of "
3205           "input (when begin/end values are dynamic)");
3206 
3207     uint64_t begin_mask = op.begin_mask();
3208     if (begin_mask)
3209       return rewriter.notifyMatchFailure(
3210           op,
3211           "requires that begin_mask is either set to 0 or not set when "
3212           "begin/end values are dynamic");
3213     uint64_t end_mask = op.end_mask();
3214     if (end_mask)
3215       return rewriter.notifyMatchFailure(
3216           op,
3217           "requires that end_mask is either set to 0 or not set when begin/end "
3218           "values are dynamic");
3219     uint64_t new_axis_mask = op.new_axis_mask();
3220     if (new_axis_mask)
3221       return rewriter.notifyMatchFailure(
3222           op,
3223           "requires that new_axis_mask is either set to 0 or not set when "
3224           "begin/end values are dynamic");
3225 
3226     // In this case where the begin and end values are dynamic, the number of
3227     // output elements has to be equal to the number of input elements that
3228     // are sliced.
3229     int output_elements = result_ty.getNumElements();
3230     int input_elements_sliced = 1;
3231 
3232     // Begin must be a ranked, 1-dimensional tensor: This is checked by the
3233     // verifier.
3234     int64_t slicing_dim_size =
3235         op.begin().getType().cast<RankedTensorType>().getShape()[0];
3236     const int input_rank = input_shape.size();
3237     for (int d = slicing_dim_size; d < input_rank; ++d) {
3238       // We only support slicing major dimensions, so minor dimensions after
3239       // slicing dimensions are all sliced with their full sizes.
3240       input_elements_sliced *= input_shape[d];
3241     }
3242     if (input_elements_sliced != output_elements) {
3243       return rewriter.notifyMatchFailure(
3244           op,
3245           "requires the number of output elements to be equal to the number of "
3246           "input elements sliced (when begin/end values are dynamic)");
3247     }
3248 
3249     SmallVector<Value, 4> slice_begin_indices;
3250     // For the dimensions that are to be sliced, all have slice sizes of 1.
3251     SmallVector<int64_t, 4> slice_sizes(slicing_dim_size, 1);
3252     auto begin_element_ty =
3253         op.begin().getType().cast<ShapedType>().getElementType();
3254     // Scalar tensor type.
3255     TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty);
3256     Location loc = op.getLoc();
3257     auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter);
3258     for (int d = 0; d < slicing_dim_size; ++d) {
3259       auto index = rewriter.create<SliceOp>(
3260           loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
3261           GetI64ElementsAttr({d + 1}, &rewriter),
3262           GetI64ElementsAttr({1}, &rewriter));
3263       // Convert index to scalar.
3264       auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
3265       // If the index is negative, wrap it around with dimension size.
3266       auto index_negative =
3267           rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
3268       auto input_val = GetScalarConstOfType(begin_element_ty, loc,
3269                                             input_shape[d], &rewriter);
3270       auto wrapped_index =
3271           rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
3272       auto final_index = rewriter.create<SelectOp>(
3273           loc, type, index_negative, wrapped_index, reshaped_index);
3274       slice_begin_indices.push_back(final_index);
3275     }
3276 
3277     // For non-slice dims, get the full slice of that dimension.
3278     for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) {
3279       slice_sizes.push_back(input_shape[d]);
3280       slice_begin_indices.push_back(zero);
3281     }
3282 
3283     auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
3284     // This must be an xla DynamicSlice op due to the inputs that aren't
3285     // constant.
3286     auto sliced = rewriter.create<DynamicSliceOp>(
3287         loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr);
3288 
3289     // Reshape slice result so that the shape is updated depending on
3290     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3291     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3292     return success();
3293   }
3294 
3295   LogicalResult matchAndRewrite(TF::StridedSliceOp op,
3296                                 PatternRewriter &rewriter) const override {
3297     // Input shape needs to be static to convert negative indices in TensorFlow
3298     // to absolute indices required by HLO.
3299     //
3300     // TODO(hinsu): Relax this constraint for ops without negative indices and
3301     // strides.
3302     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3303     if (!input_ty || !input_ty.hasStaticShape()) return failure();
3304 
3305     // Output shape needs to be static to apply 'new_axis_mask' or
3306     // 'shrink_axis_mask' by reshaping tensor after slice.
3307     //
3308     // TODO(hinsu): Relax this constraint for ops without the above masks.
3309     auto result_ty = op.getType().dyn_cast<RankedTensorType>();
3310     if (!result_ty || !result_ty.hasStaticShape()) return failure();
3311 
3312     DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
3313     if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
3314         !matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
3315       return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
3316     }
3317 
3318     SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3319     if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
3320       return failure();
3321     }
3322     return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
3323                                     input_ty, rewriter);
3324   }
3325 };
3326 
3327 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
3328 //
3329 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
3330 // the reverse: it propagates the gradient of the sliced tensor back to the
3331 // original input tensor by padding with zeros. The main logic is calculating the
3332 // indices and strides for padding.
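//
// For example (illustrative): slicing a tensor<4xf32> with begin = 0, end = 4
// and stride = 2 yields a tensor<2xf32>; the incoming gradient [g0, g1] is
// expanded back to [g0, 0, g1, 0] using interior padding of 1 and high padding
// of 1, so each gradient value lines up with the input element it came from.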
3333 class ConvertStridedSliceGradOp
3334     : public OpRewritePattern<TF::StridedSliceGradOp> {
3335  public:
3336   using OpRewritePattern::OpRewritePattern;
3337 
3338   LogicalResult matchAndRewrite(TF::StridedSliceGradOp op,
3339                                 PatternRewriter &rewriter) const override {
3340     // We need constant input shape to perform padding calculations later.
3341     DenseIntElementsAttr input_shape_attr;
3342     if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
3343       return failure();
3344 
3345     // We also need constant begin/end indices and strides to perform padding
3346     // calculations.
3347     // Bounded shape after performing strided slice
3348     SmallVector<int64_t, 4> shape;
3349     // Bounded begin, end, and strides for strided slice
3350     SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3351     if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
3352                                          &strides))
3353       return failure();
3354 
3355     Value grad = op.dy();
3356     Type element_type = grad.getType().cast<ShapedType>().getElementType();
3357 
3358     // Perform reshape to undo any new/shrink axes done by strided slice.
3359     grad = rewriter.create<mhlo::ReshapeOp>(
3360         op.getLoc(), RankedTensorType::get(shape, element_type), grad);
3361 
3362     SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
3363     SmallVector<int64_t, 4> dims_to_reverse;
3364     padding_low.reserve(shape.size());
3365     padding_high.reserve(shape.size());
3366     padding_interm.reserve(shape.size());
3367 
3368     // Prepare padding parameters for each dimension.
3369     for (int i = 0, e = shape.size(); i < e; ++i) {
3370       int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
3371       if (strides[i] > 0) {
3372         padding_low.push_back(begin_indices[i]);
3373         padding_interm.push_back(strides[i] - 1);
3374 
3375         // Pad the upper dimension up to the expected input shape. It's not
3376         // sufficient simply to use end_indices[i] to compute the padding in
3377         // cases where the stride does not divide evenly into the interval
3378         // between begin_indices[i] and end_indices[i].
3379         int64_t size =
3380             padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3381         padding_high.push_back(input_dim - size);
3382       } else {
3383         dims_to_reverse.push_back(i);
3384         padding_high.push_back(input_dim - begin_indices[i] - 1);
3385         padding_interm.push_back(-strides[i] - 1);
3386 
3387         // Pad the lower dimension up to the expected input shape.
3388         int64_t size =
3389             padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3390         padding_low.push_back(input_dim - size);
3391       }
3392     }
3393 
3394     if (!dims_to_reverse.empty()) {
3395       grad = rewriter.create<mhlo::ReverseOp>(
3396           op.getLoc(), grad.getType(), grad,
3397           GetI64ElementsAttr(dims_to_reverse, &rewriter));
3398     }
3399 
3400     auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
3401     rewriter.replaceOpWithNewOp<mhlo::PadOp>(
3402         op, op.getType(), grad, zero,
3403         GetI64ElementsAttr(padding_low, &rewriter),
3404         GetI64ElementsAttr(padding_high, &rewriter),
3405         GetI64ElementsAttr(padding_interm, &rewriter));
3406     return success();
3407   }
3408 };
3409 
3410 /// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
3411 /// offset applied to generate the range values. The output tensor needs to
3412 /// have a static shape.
3413 ///
3414 /// For example an op like the following:
3415 ///   %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
3416 ///      : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
3417 ///
3418 /// Output would be:
3419 ///   %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
3420 ///   %scaled = "mhlo.multiply"(%iota, %delta)
3421 ///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3422 ///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3423 ///   %result = "mhlo.add"(%scaled, %offset)
3424 ///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3425 ///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3426 ///
3427 /// Implementation is defined in C++ because the iota op has no type inference.
3428 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
3429   using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3430 
3431   LogicalResult matchAndRewrite(TF::RangeOp op,
3432                                 PatternRewriter &rewriter) const override {
3433     auto result = op.getResult();
3434     auto result_type = result.getType();
3435     if (!result_type.cast<ShapedType>().hasStaticShape()) {
3436       return failure();
3437     }
3438 
3439     auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
3440                                         rewriter.getI64IntegerAttr(0));
3441     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3442         op.getLoc(), result_type, iota, op.delta(),
3443         hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
3444     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3445         op, result_type, scaled, op.start(),
3446         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
3447     return success();
3448   }
3449 };
3450 
3451 // Converts RangeOp for cases where the length is a dynamic value. The shape of
3452 // the resulting tensor is computed, then the start and delta are used with the
3453 // dynamic_iota value to compute the final range value.
3454 //
3455 // For example, the resulting range op value:
3456 //   %range = "tf.range"(%start, %limit, %delta)
3457 //
3458 // Is converted to the following.
3459 //   %start + %delta * iota(ceil(abs((%limit - %start) / %delta))
3460 //
3461 // Implementation is defined in C++ due to the complicated type behavior.
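//
// For example (illustrative): start = 0, limit = 10, delta = 3 gives
// size = ceil(abs((10 - 0) / 3)) = 4, so the result is iota(4) * 3 + 0 =
// [0, 3, 6, 9].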
3462 class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
3463   using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3464 
3465   LogicalResult matchAndRewrite(TF::RangeOp op,
3466                                 PatternRewriter &rewriter) const override {
3467     auto result = op.getResult();
3468     auto result_type = result.getType().cast<ShapedType>();
3469     if (result_type.hasStaticShape()) {
3470       return failure();
3471     }
3472 
3473     Value start = op.start();
3474     Value delta = op.delta();
3475     Value limit = op.limit();
3476 
3477     // To compute the length we need to use floating point calculations so that
3478     // ceil can be computed for the number of steps.
3479     auto compute_element_type =
3480         getElementTypeOrSelf(start.getType()).isa<FloatType>()
3481             ? getElementTypeOrSelf(start.getType())
3482             : rewriter.getF64Type();
3483     auto compute_type = RankedTensorType::get(
3484         limit.getType().cast<ShapedType>().getShape(), compute_element_type);
3485 
3486     // Compute the length of the sequence we are going to need. This includes
3487     // some conversion to float for the operations.
3488     //
3489     // %size = ceil(abs((%limit - %start) / %delta))
3490     auto range = rewriter.create<mhlo::SubOp>(op.getLoc(), limit, start);
3491     auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
3492 
3493     // Delta is not necessarily the same type as start and limit.
3494     auto abs_cast =
3495         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
3496     auto delta_cast =
3497         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
3498 
3499     // Compute the total number of integer steps and convert to the HLO
3500     // dimension tensor.
3501     auto normalized =
3502         rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
3503     auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
3504     auto steps = rewriter.create<mhlo::ConvertOp>(
3505         op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
3506     auto reshape = rewriter.create<mhlo::ReshapeOp>(
3507         op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
3508 
3509     // Using the resulting length compute the correct range value:
3510     //
3511     // %range = %start + %delta * iota(%size)
3512     auto out_scalar_type =
3513         RankedTensorType::get({}, getElementTypeOrSelf(result_type));
3514     auto start_out_cast =
3515         rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
3516     auto delta_out_cast =
3517         rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
3518 
3519     auto iota = rewriter.create<DynamicIotaOp>(
3520         op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
3521     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3522         op.getLoc(), result_type, iota, delta_out_cast,
3523         hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
3524     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3525         op, result_type, scaled, start_out_cast,
3526         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
3527     return success();
3528   }
3529 };
3530 
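// Converts an axis attribute to an I64 tensor attribute, normalizing negative
// axis values with respect to the rank of `val` (e.g. axis -1 on a rank-4
// value becomes 3).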
3531 ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
3532   auto int_attr = attr.cast<DenseIntElementsAttr>();
3533   auto type = val.getType().cast<ShapedType>();
3534 
3535   SmallVector<int64_t, 6> axis;
3536   axis.reserve(int_attr.getNumElements());
3537 
3538   int64_t rank = type.getRank();
3539   for (auto val : int_attr.getValues<APInt>()) {
3540     axis.push_back((val.getSExtValue() + rank) % rank);
3541   }
3542 
3543   return builder->getI64TensorAttr(axis);
3544 }
3545 
3546 /// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
3547 /// and offset applied to generate the linspace values. The output tensor needs
3548 /// to have a static shape.  The implementation is defined in C++ because there
3549 /// is no type inference for the iota op.
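///
/// For example (illustrative): start = 0, stop = 10, num = 5 gives
/// step = (10 - 0) / (5 - 1) = 2.5, so the result is
/// iota(5) * 2.5 + 0 = [0.0, 2.5, 5.0, 7.5, 10.0].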
3550 class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
3551   using OpRewritePattern<TF::LinSpaceOp>::OpRewritePattern;
3552 
3553   LogicalResult matchAndRewrite(TF::LinSpaceOp op,
3554                                 PatternRewriter &rewriter) const override {
3555     auto result = op.getResult();
3556     auto result_type = result.getType().dyn_cast<ShapedType>();
3557     if (!result_type || !result_type.hasStaticShape()) {
3558       return failure();
3559     }
3560 
3561     DenseIntElementsAttr num_attr;
3562     if (!matchPattern(op.num(), m_Constant(&num_attr))) {
3563       return rewriter.notifyMatchFailure(op, "Num must be a constant scalar");
3564     }
3565 
3566     if (num_attr.begin() == num_attr.end()) {
3567       return rewriter.notifyMatchFailure(op, "Num must not be empty");
3568     }
3569     int64_t num = (*num_attr.begin()).getSExtValue();
3570 
3571     // Calculate the scaling that needs to be applied to the iota.
3572     auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
3573         op.getLoc(), op.start().getType(), op.stop(), op.start(),
3574         hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
3575     Value step_denominator = rewriter.create<ConvertOp>(
3576         op.getLoc(), op.num(), result_type.getElementType());
3577     if (num > 1) {
3578       Value one = GetScalarConstOfType(result_type.getElementType(),
3579                                        op.getLoc(), 1, &rewriter);
3580       step_denominator = rewriter.create<chlo::BroadcastSubOp>(
3581           op.getLoc(), step_denominator.getType(), step_denominator, one,
3582           hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
3583     }
3584     auto step = rewriter.create<chlo::BroadcastDivOp>(
3585         op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
3586         hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator,
3587                                         step_denominator));
3588 
3589     // Scale the iota and add the offset.
3590     auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
3591                                         rewriter.getI64IntegerAttr(0));
3592     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3593         op.getLoc(), result_type, iota, step,
3594         hlo::getBroadcastDimensionsAttr(&rewriter, iota, step));
3595     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3596         op, result_type, scaled, op.start(),
3597         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
3598     return success();
3599   }
3600 };
3601 
3602 /// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
3603 /// ReductionOp.
3604 /// `is_accumulation` controls whether the reduction uses a higher-precision
3605 /// accumulation type. This is set to false for ops like max where there are no
3606 /// precision concerns.
3607 //
3608 // The Derived class should have a static method to return the initial value to
3609 // use for reduction:
3610 //   static Value GetInitialValue(Type reduce_element_type, Location loc,
3611 //                                PatternRewriter *rewriter);
3612 // The reduce_element_type is guaranteed to be a float, int, or complex type
3613 // suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType.
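//
// For example (a sketch, assuming GetAccumulationType promotes low-precision
// floats): a tf.Sum over a bf16 input is computed with an f32 accumulator and
// the result is converted back to bf16 after the mhlo.reduce.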
3614 template <typename Derived, typename OpTy, typename ReductionOp,
3615           bool is_accumulation = true>
3616 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
3617   using OpRewritePattern<OpTy>::OpRewritePattern;
3618 
3619   LogicalResult matchAndRewrite(OpTy op,
3620                                 PatternRewriter &rewriter) const override {
3621     // TODO(b/141785544): Update this to not require static shapes.
3622     // Input shape needs to be static to convert negative indices in TensorFlow
3623     // to absolute indices required by HLO.
3624     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
3625     if (!input_ty) return failure();
3626     ArrayRef<int64_t> input_shape = input_ty.getShape();
3627 
3628     DenseIntElementsAttr dimensions;
3629     if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
3630       return failure();
3631 
3632     // Build the final shape from input_shape and dimensions using a bitmap
3633     // to mark the reduced dimensions.
3634     SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
3635     SmallVector<int64_t, 4> xla_dimensions;
3636     for (const APInt &index_raw : dimensions.getValues<APInt>()) {
3637       int64_t index = index_raw.getSExtValue();
3638       int64_t rank = input_shape.size();
3639       if ((index < -rank || index >= rank)) return failure();
3640       index = (index + rank) % rank;
3641       reduced_dimensions_bitmap[index] = true;
3642       xla_dimensions.push_back(index);
3643     }
3644 
3645     Location loc = op.getLoc();
3646     Type element_type = input_ty.getElementType();
3647 
3648     // Only float, int, and complex types are currently supported.
3649     if (!element_type.isa<FloatType>() && !element_type.isa<IntegerType>() &&
3650         !element_type.isa<ComplexType>()) {
3651       return rewriter.notifyMatchFailure(
3652           op, "element type must be float, int, or complex type");
3653     }
3654 
3655     // Convert to an accumulation type to not lose precision when doing
3656     // repeated arithmetic operations.
3657     Type reduce_element_type =
3658         is_accumulation ? GetAccumulationType(element_type) : element_type;
3659     auto casted_input =
3660         rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
3661 
3662     // Each reduction op can have a different initial value.
3663     Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
3664 
3665     auto reduction = rewriter.create<ReduceOp>(
3666         loc, casted_input.getResult(), init,
3667         GetI64ElementsAttr(xla_dimensions, &rewriter));
3668     BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
3669                                  &rewriter);
3670     Value result = reduction.getResult(0);
3671 
3672     // The mean op needs to divide by the product of the reduced dimensions.
3673     if (std::is_same<OpTy, TF::MeanOp>::value) {
3674       int64_t divisor_count = 1;
3675       for (size_t i = 0; i < input_shape.size(); ++i) {
3676         if (reduced_dimensions_bitmap[i]) {
3677           if (TensorType::isDynamic(input_shape[i])) {
3678             return failure();
3679           }
3680           divisor_count *= input_shape[i];
3681         }
3682       }
3683       auto divisor = GetScalarConstOfType(reduce_element_type, loc,
3684                                           divisor_count, &rewriter);
3685       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
3686       result = rewriter.create<chlo::BroadcastDivOp>(
3687           loc, result, divisor.getResult(), broadcast_dims);
3688     }
3689 
3690     result = rewriter.create<ConvertOp>(loc, result, element_type);
3691 
3692     // Need to reshape back after the reduction if we're keeping the reduced
3693     // dimensions.
3694     if (op.keep_dims()) {
3695       result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
3696     }
3697     rewriter.replaceOp(op, {result});
3698 
3699     return success();
3700   }
3701 };
3702 
3703 // Converts Mean op to HLO Reduce op.
3704 //
3705 //   %init = constant dense<...> : tensor<T>
3706 //   %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
3707 //               {dimensions = ...}
3708 //   %divisor = constant dense<...> : tensor<T>
3709 //   %mean = "mhlo.divide"(%sum, %divisor)
3710 class ConvertMeanOp
3711     : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
3712  public:
3713   using GenericConvertReductionOp::GenericConvertReductionOp;
3714   static Value GetInitialValue(Type reduce_element_type, Location loc,
3715                                PatternRewriter *rewriter) {
3716     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3717   }
3718 };
3719 
3720 // Converts Sum op to HLO Reduce op.
3721 //
3722 //   %init = constant dense<...> : tensor<T>
3723 //   %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
3724 //               {dimensions = ...}
3725 class ConvertSumOp
3726     : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
3727  public:
3728   using GenericConvertReductionOp::GenericConvertReductionOp;
3729 
3730   static Value GetInitialValue(Type reduce_element_type, Location loc,
3731                                PatternRewriter *rewriter) {
3732     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3733   }
3734 };
3735 
3736 // Converts Max op to HLO Reduce op.
3737 //
3738 //   %init = constant dense<...> : tensor<T>
3739 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
3740 //               {dimensions = ...}
3741 class ConvertMaxOp
3742     : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
3743                                        /* is_accumulation= */ false> {
3744  public:
3745   using GenericConvertReductionOp::GenericConvertReductionOp;
3746 
3747   static Value GetInitialValue(Type reduce_element_type, Location loc,
3748                                PatternRewriter *rewriter) {
3749     return GetScalarLimitConstOfType(reduce_element_type, loc,
3750                                      hlo::kInfinityLowest, rewriter);
3751   }
3752 };
3753 
3754 // Converts Min op to HLO Reduce op.
3755 //
3756 //   %init = constant dense<...> : tensor<T>
3757 //   %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
3758 //               {dimensions = ...}
3759 class ConvertMinOp
3760     : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
3761                                        /* is_accumulation= */ false> {
3762  public:
3763   using GenericConvertReductionOp::GenericConvertReductionOp;
3764 
3765   static Value GetInitialValue(Type reduce_element_type, Location loc,
3766                                PatternRewriter *rewriter) {
3767     return GetScalarLimitConstOfType(reduce_element_type, loc,
3768                                      hlo::kInfinityMax, rewriter);
3769   }
3770 };
3771 
3772 // Converts Prod op to HLO Reduce op.
3773 //
3774 //   %init = constant dense<...> : tensor<T>
3775 //   %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
3776 //               {dimensions = ...}
3777 class ConvertProdOp
3778     : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
3779  public:
3780   using GenericConvertReductionOp::GenericConvertReductionOp;
3781 
3782   static Value GetInitialValue(Type reduce_element_type, Location loc,
3783                                PatternRewriter *rewriter) {
3784     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
3785   }
3786 };
3787 
3788 // Converts All op to HLO Reduce op.
3789 //
3790 //   %init = constant dense<...> : tensor<T>
3791 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
3792 //               {dimensions = ...}
3793 class ConvertAllOp
3794     : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
3795  public:
3796   using GenericConvertReductionOp::GenericConvertReductionOp;
3797   static Value GetInitialValue(Type reduce_element_type, Location loc,
3798                                PatternRewriter *rewriter) {
3799     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
3800   }
3801 };
3802 
3803 // Converts Any op to HLO Reduce op.
3804 //
3805 //   %init = constant dense<...> : tensor<T>
3806 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
3807 //               {dimensions = ...}
3808 class ConvertAnyOp
3809     : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
3810  public:
3811   using GenericConvertReductionOp::GenericConvertReductionOp;
3812   static Value GetInitialValue(Type reduce_element_type, Location loc,
3813                                PatternRewriter *rewriter) {
3814     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3815   }
3816 };
3817 
3818 // Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
3819 // a reduction on the original input and the corresponding index. The reduction
3820 // sub-computation selects the max (or min) value and the index for the value.
3821 //   Derived: is the resulting derived class of this class.
3822 //   OpTy: is TF::ArgMaxOp or TF::ArgMinOp.
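//
// For example (illustrative): ArgMax over a tensor<3x7xf32> along axis 1
// reduces (value, iota-index) pairs and returns the second reduction result,
// a tensor<3xi64> of indices (assuming an i64 output_type).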
3823 template <typename Derived, typename OpTy>
3824 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
3825   using OpRewritePattern<OpTy>::OpRewritePattern;
3826 
3827   LogicalResult matchAndRewrite(OpTy op,
3828                                 PatternRewriter &rewriter) const override {
3829     RankedTensorType input_type =
3830         op.input().getType().template dyn_cast<RankedTensorType>();
3831     if (!input_type) {
3832       return failure();
3833     }
3834 
3835     Type input_element_type = input_type.getElementType();
3836     // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
3837     // tf.ArgMax doesn't support complex data types, this check can be removed.
3838     if (!input_element_type.isSignlessIntOrFloat()) return failure();
3839 
3840     Location loc = op.getLoc();
3841     Value init_value =
3842         Derived::GetInitialValue(input_element_type, loc, rewriter);
3843 
3844     RankedTensorType output_type =
3845         op.output().getType().template dyn_cast<RankedTensorType>();
3846     if (!output_type) {
3847       return failure();
3848     }
3849 
3850     Type index_element_type = output_type.getElementType();
3851     Value index_init_value =
3852         GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
3853 
3854     RankedTensorType index_type =
3855         RankedTensorType::get(input_type.getShape(), index_element_type);
3856 
3857     llvm::Optional<int64_t> optional_axis =
3858         GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
3859     if (!optional_axis.hasValue()) {
3860       return failure();
3861     }
3862     int64_t axis = optional_axis.getValue();
3863 
3864     IntegerAttr iota_dimension =
3865         IntegerAttr::get(rewriter.getIntegerType(64), axis);
3866     Value index_values =
3867         rewriter.create<IotaOp>(loc, index_type, iota_dimension);
3868 
3869     std::vector<int64_t> dimensions = input_type.getShape();
3870     dimensions.erase(dimensions.begin() + axis);
3871     ArrayRef<int64_t> reduction_result_shape(dimensions);
3872 
3873     Value operands[] = {op.input(), index_values};
3874     Value init_values[] = {init_value, index_init_value};
3875     DenseIntElementsAttr reduction_dimensions =
3876         GetI64ElementsAttr({axis}, &rewriter);
3877 
3878     auto reduction = rewriter.create<ReduceOp>(
3879         loc, llvm::ArrayRef<Value>(operands),
3880         llvm::ArrayRef<Value>(init_values), reduction_dimensions);
3881     StringRef direction = Derived::GetDirection();
3882     BuildArgMinMaxReductionBody(input_element_type, index_element_type,
3883                                 direction, &reduction.body(), &rewriter);
3884 
3885     rewriter.replaceOp(op, {reduction.getResult(1)});
3886     return success();
3887   }
3888 };
3889 
3890 // Converts tensorflow ArgMax op to mhlo operations. The actual
3891 // implementation is in class ConvertArgMinMaxOp:
3892 //
3893 //   %init_index = constant dense<...> : tensor<T>
3894 //   %init = constant dense<...> : tensor<T>
3895 //   %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
3896 //                              %init_index) ["mhlo.arg_max"]
3897 class ConvertArgMaxOp
3898     : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
3899  public:
3900   using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
3901 
3902   static Value GetInitialValue(Type reduce_element_type, Location loc,
3903                                PatternRewriter &rewriter) {
3904     return GetScalarLimitConstOfType(reduce_element_type, loc,
3905                                      hlo::kInfinityLowest, &rewriter);
3906   }
3907 
3908   static StringRef GetDirection() { return "GT"; }
3909 };
3910 
3911 // Converts TF TensorScatterUpdate op into Scatter Op with assignment:
3912 //
3913 //   %result = "mhlo.scatter"(%tensor, %indices, %updates)
3914 //     { dimensions = ... }
3915 //
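// For example (illustrative): for a rank-3 tensor and indices of type
// tensor<Nx2xi32> (so num_index_dims = 2) with rank-2 updates, the scatter
// dimension numbers would be update_window_dims = [1],
// inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1] and
// index_vector_dim = 1.
//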
3916 class ConvertTensorScatterUpdateOp
3917     : public OpRewritePattern<TF::TensorScatterUpdateOp> {
3918  public:
3919   using OpRewritePattern::OpRewritePattern;
3920 
3921   LogicalResult matchAndRewrite(TF::TensorScatterUpdateOp op,
3922                                 PatternRewriter &rewriter) const override {
3923     auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
3924     auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
3925     auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>();
3926 
3927     if (!tensor_ty || !indices_ty || !updates_ty) return failure();
3928     // The last dimension of the indices needs to be known at compile time for
3929     // computation of the 'update_window_dims' attribute in the dimensions
3930     // struct.
3931     int64_t num_index_dims = indices_ty.getShape().back();
3932     if (ShapedType::isDynamic(num_index_dims)) return failure();
3933 
3934     int64_t tensor_rank = tensor_ty.getRank();
3935     int64_t indices_rank = indices_ty.getRank();
3936     int64_t updates_rank = updates_ty.getRank();
3937 
3938     int64_t window_dims = tensor_rank - num_index_dims;
3939     auto dims_attr = ScatterDimensionNumbers::get(
3940         GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
3941                                  &rewriter),
3942         GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
3943         GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
3944         rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());
3945 
3946     Location loc = op.getLoc();
3947     auto scatter = rewriter.create<ScatterOp>(
3948         loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);
3949 
3950     // Build region to assign the new value.
3951     [&](Region *region) {
3952       OpBuilder::InsertionGuard guard(rewriter);
3953       Block *block = rewriter.createBlock(region);
3954 
3955       // Block arguments are scalars of the given element type.
3956       Type type =
3957           RankedTensorType::get(/*shape=*/{}, tensor_ty.getElementType());
3958       block->addArguments({type, type});
3959       rewriter.create<ReturnOp>(loc, block->getArgument(1));
3960     }(&scatter.update_computation());
3961 
3962     rewriter.replaceOp(op, scatter.getResult());
3963     return success();
3964   }
3965 };
3966 
3967 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
3968 //   For shape [S1, S2] and multiples [M1, M2],
3969 //     MS1 = M1 * S1; MS2 = M2 * S2
3970 //
3971 //   %broadcast = mhlo.broadcast_in_dim(%input) {
3972 //     broadcast_dimensions = [1, 3]
3973 //   }
3974 //   %result = "mhlo.reshape"(%broadcast) : (tensor<M1xS1xM2xS2xf32>)
3975 //      -> tensor<MS1xMS2xf32>
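//
// For example (illustrative): tiling a tensor<2x3xf32> with multiples [3, 2]
// broadcasts to tensor<3x2x2x3xf32> with broadcast_dimensions = [1, 3] and
// then reshapes to the final tensor<6x6xf32>.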
3976 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
3977  public:
3978   using OpRewritePattern::OpRewritePattern;
3979 
3980   LogicalResult matchAndRewrite(TF::TileOp op,
3981                                 PatternRewriter &rewriter) const override {
3982     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3983     if (!input_ty || !input_ty.hasStaticShape()) return failure();
3984     ArrayRef<int64_t> input_shape = input_ty.getShape();
3985     Type element_type = input_ty.getElementType();
3986 
3987     DenseIntElementsAttr multiples;
3988     if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
3989         multiples.getType().getRank() != 1)
3990       return failure();
3991 
3992     const int64_t input_shape_size = input_shape.size();
3993     if (multiples.getNumElements() != input_shape_size) return failure();
3994 
3995     SmallVector<int64_t, 8> broadcasted_shape;
3996     SmallVector<int64_t, 4> broadcast_dimensions;
3997     broadcasted_shape.reserve(input_shape.size() * 2);
3998     broadcast_dimensions.reserve(input_shape.size());
3999     for (auto multiple_and_input :
4000          llvm::zip(multiples.getValues<APInt>(), input_shape)) {
4001       int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
4002       int64_t input_size = std::get<1>(multiple_and_input);
4003 
4004       if (multiple < 0) return failure();
4005 
4006       // Line input up with the next dimension in broadcasted_shape
4007       // when broadcasting.
4008       int64_t broadcast_dim;
4009       int64_t output_size = input_size * multiple;
4010       if (input_size == 1 || multiple == 1) {
4011         // Special case for when normal broadcasting will just work.
4012         broadcast_dim = broadcasted_shape.size();
4013         broadcasted_shape.push_back(output_size);
4014       } else {
4015         // Tiling will happen for this dimension during the ReshapeOp below.
4016         broadcasted_shape.push_back(multiple);
4017         broadcast_dim = broadcasted_shape.size();
4018         broadcasted_shape.push_back(input_size);
4019       }
4020       broadcast_dimensions.push_back(broadcast_dim);
4021     }
4022     Location loc = op.getLoc();
4023     Type broadcasted_type =
4024         RankedTensorType::get(broadcasted_shape, element_type);
4025     Type output_type = op.getType();
4026 
4027     Value result = rewriter.create<BroadcastInDimOp>(
4028         loc, broadcasted_type, op.input(),
4029         GetI64ElementsAttr(broadcast_dimensions, &rewriter));
4030 
4031     if (output_type != broadcasted_type) {
4032       result = rewriter.create<ReshapeOp>(loc, output_type, result);
4033     }
4034 
4035     rewriter.replaceOp(op, {result});
4036 
4037     return success();
4038   }
4039 };
4040 
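// Converts tf.MaxPoolGrad/tf.MaxPool3DGrad to mhlo.select_and_scatter: the
// select region picks the window element that produced the max (compare "GE")
// and the scatter region accumulates the incoming gradient with an add.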
4041 template <typename OpTy, int num_dims>
4042 class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
4043  public:
4044   using OpRewritePattern<OpTy>::OpRewritePattern;
4045 
4046   LogicalResult matchAndRewrite(OpTy op,
4047                                 PatternRewriter &rewriter) const override {
4048     Location loc = op.getLoc();
4049 
4050     Type element_type =
4051         op.orig_input().getType().template cast<TensorType>().getElementType();
4052 
4053     // Compute paddings using the original input and kernel shape and strides.
4054     // The same padding computation as for ReduceWindow is used here, since the
4055     // MaxPool op is lowered to the ReduceWindow op.
4056     auto input_ty =
4057         op.orig_input().getType().template dyn_cast<RankedTensorType>();
4058     if (!input_ty) return failure();
4059     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
4060         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
4061 
4062     auto result = rewriter.create<SelectAndScatterOp>(
4063         loc, op.getType(), op.orig_input(), op.grad(),
4064         GetScalarConstOfType(element_type, loc, 0, &rewriter),
4065         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
4066         paddings_attr);
4067 
4068     BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
4069     {
4070       OpBuilder::InsertionGuard guard(rewriter);
4071       Block *block = rewriter.createBlock(&result.select());
4072 
4073       // Block arguments are scalars of the given element type.
4074       Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4075       block->addArguments({type, type});
4076 
4077       auto reducer = rewriter.create<CompareOp>(
4078           loc, block->getArgument(0), block->getArgument(1),
4079           StringAttr::get(rewriter.getContext(), "GE"));
4080       rewriter.create<ReturnOp>(loc, reducer.getResult());
4081     }
4082 
4083     rewriter.replaceOp(op, {result});
4084 
4085     return success();
4086   }
4087 };
4088 
4089 using ConvertMaxPool2DGradOp =
4090     ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
4091 using ConvertMaxPool3DGradOp =
4092     ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
4093 
4094 // Converts tf.Conv?DBackpropInputOp into:
4095 //   %rev_filter = "mhlo.reverse"(%filter)
4096 //   %result = "mhlo.convolution"(%out_backprop, %rev_filter)
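//
// Roughly, the input gradient is computed as a convolution of the (stride-
// dilated) output gradients with the spatially reversed filter, using padding
// derived from the forward convolution's padding.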
4097 template <typename OpTy, int num_spatial_dims>
4098 class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
4099  public:
4100   using OpRewritePattern<OpTy>::OpRewritePattern;
4101 
4102   LogicalResult matchAndRewrite(OpTy op,
4103                                 PatternRewriter &rewriter) const override {
4104     // Unpack all of the attributes.
4105     tensorflow::TensorFormat data_format;
4106     if (!FormatFromString(op.data_format().str(), &data_format))
4107       return failure();
4108 
4109     tensorflow::Padding padding;
4110     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
4111       return failure();
4112 
4113     auto out_backprop_ty =
4114         op.out_backprop().getType().template dyn_cast<RankedTensorType>();
4115     auto filter_ty =
4116         op.filter().getType().template dyn_cast<RankedTensorType>();
4117 
4118     for (RankedTensorType ty : {out_backprop_ty, filter_ty})
4119       if (!ty || !ty.hasStaticShape()) return failure();
4120 
4121     DenseIntElementsAttr input_shape_attr;
4122     if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
4123         input_shape_attr.getType().getRank() != 1)
4124       return failure();
4125 
4126     auto input_shape = input_shape_attr.getValues<int32_t>();
4127 
4128     auto dilations_attr = GetI64ElementsAttr(op.dilations());
4129     std::vector<int> dilations{
4130         dilations_attr.template getValues<int64_t>().begin(),
4131         dilations_attr.template getValues<int64_t>().end()};
4132     auto strides_attr = GetI64ElementsAttr(op.strides());
4133     std::vector<tensorflow::int32> strides{
4134         strides_attr.template getValues<int64_t>().begin(),
4135         strides_attr.template getValues<int64_t>().end()};
4136 
4137     std::vector<tensorflow::int64> explicit_paddings;
4138     if (padding == tensorflow::Padding::EXPLICIT) {
4139       // EXPLICIT padding mode and the associated attribute are limited to
4140       // Conv2DBackpropInput. So, fetch the attribute by identifier instead of
4141       // the op.explicit_paddings() attribute getter.
4142       ArrayRef<Attribute> explicit_paddings_attr =
4143           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
4144       explicit_paddings.reserve(explicit_paddings_attr.size());
4145       for (Attribute explicit_padding : explicit_paddings_attr)
4146         explicit_paddings.push_back(
4147             explicit_padding.cast<IntegerAttr>().getInt());
4148     }
4149 
4150     constexpr int num_dims = num_spatial_dims + 2;
4151     ArrayRef<int64_t> filter_shape = filter_ty.getShape();
4152 
4153     // Reuse dimension computation logic from conv_grad_shape_utils.cc.
4154     tensorflow::ConvBackpropDimensions dims;
4155     if (!tensorflow::ConvBackpropComputeDimensionsV2(
4156              /*label=*/"", num_spatial_dims,
4157              ToTensorShape<int32_t, num_dims>(input_shape),
4158              ToTensorShape<int64_t, num_dims>(filter_shape),
4159              ToTensorShape<int64_t, num_dims>(out_backprop_ty.getShape()),
4160              dilations, strides, padding, explicit_paddings, data_format, &dims)
4161              .ok()) {
4162       return failure();
4163     }
4164 
4165     // Compute ConvDimensionNumbers, dilation, and padding.
4166     SmallVector<int64_t, num_spatial_dims> spatial_dims;
4167     SmallVector<int64_t, num_spatial_dims> lhs_dilation;
4168     SmallVector<int64_t, num_spatial_dims> rhs_dilation;
4169     SmallVector<int64_t, num_spatial_dims * 2> paddings;
4170 
4171     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
4172       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
4173       spatial_dims.push_back(dim);
4174       const auto &spatial_dim_i = dims.spatial_dims[i];
4175       lhs_dilation.push_back(spatial_dim_i.stride);
4176       rhs_dilation.push_back(dilations[dim]);
4177       paddings.push_back(spatial_dim_i.pad_before);
4178       paddings.push_back(spatial_dim_i.pad_after);
4179     }
4180 
4181     RankedTensorType paddings_ty = RankedTensorType::get(
4182         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
4183     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
4184 
4185     auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
4186 
4187     Value filter = op.filter();
4188 
4189     const int feature_dim =
4190         tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
4191     const int64_t in_depth = *(input_shape.begin() + feature_dim);
4192     const int64_t filter_in_depth = filter_shape[num_spatial_dims];
4193     const int64_t feature_group_count = in_depth / filter_in_depth;
4194 
4195     if (feature_group_count != 1) {
4196       // 1. Reshape filter from
4197       //   [H, W, ..., filter_in_depth, out_depth] to
4198       //   [H, W, ..., filter_in_depth, G, out_depth / G].
4199       auto new_shape = llvm::to_vector<6>(filter_shape);
4200       new_shape.back() = feature_group_count;
4201       new_shape.push_back(filter_shape.back() / feature_group_count);
4202       Type filter_element_ty = filter_ty.getElementType();
4203       auto ty = RankedTensorType::get(new_shape, filter_element_ty);
4204       filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
4205 
4206       // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G].
4207       llvm::SmallVector<int64_t, 6> perm(num_dims + 1);
4208       std::iota(perm.begin(), perm.end(), 0);
4209       std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]);
4210       std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]);
4211       ty = RankedTensorType::get(new_shape, filter_element_ty);
4212       filter = rewriter.create<TransposeOp>(
4213           op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter));
4214 
4215       // 3. Reshape to [H, W, ..., in_depth, out_depth / G].
4216       new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1];
4217       new_shape[num_spatial_dims + 1] = new_shape.back();
4218       new_shape.pop_back();
4219       ty = RankedTensorType::get(new_shape, filter_element_ty);
4220       filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
4221     }
4222 
4223     auto kernel_spatial_dims_attr =
4224         GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter);
4225 
4226     // Mirror the filter in the spatial dimensions.
4227     filter = rewriter.create<ReverseOp>(op.getLoc(), filter,
4228                                         kernel_spatial_dims_attr);
4229 
4230     const int batch_dim =
4231         tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
4232     auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
4233     auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
4234 
4235     // activation gradients
4236     //   = gradients (with padding and dilation) <conv> mirrored_weights
4237     Value result = rewriter.create<ConvOp>(
4238         op.getLoc(), op.getType(), op.out_backprop(), filter,
4239         /*window_strides=*/
4240         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
4241                                    &rewriter),
4242         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
4243         GetI64ElementsAttr(rhs_dilation, &rewriter),
4244         /*window_reversal=*/nullptr,
4245         ConvDimensionNumbers::get(
4246             /*input_batch_dimension=*/batch_dim_attr,
4247             /*input_feature_dimension=*/feature_dim_attr,
4248             /*input_spatial_dimensions=*/spatial_dims_attr,
4249             // TF filter shape is [ H, W, ..., inC, outC ]
4250             // Transpose the input and output features for computing the
4251             // gradient.
4252             /*kernel_input_feature_dimension=*/
4253             rewriter.getI64IntegerAttr(num_spatial_dims + 1),
4254             /*kernel_output_feature_dimension=*/
4255             rewriter.getI64IntegerAttr(num_spatial_dims),
4256             /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
4257             /*output_batch_dimension=*/batch_dim_attr,
4258             /*output_feature_dimension=*/feature_dim_attr,
4259             /*output_spatial_dimensions=*/spatial_dims_attr,
4260             rewriter.getContext()),
4261         rewriter.getI64IntegerAttr(feature_group_count),
4262         /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
4263         /*precision_config=*/ArrayAttr());
4264 
4265     rewriter.replaceOp(op, {result});
4266 
4267     return success();
4268   }
4269 };
4270 
4271 using ConvertConv2DBackpropInputOp =
4272     ConvertConvBackpropInputOp<TF::Conv2DBackpropInputOp,
4273                                /*num_spatial_dims=*/2>;
4274 using ConvertConv3DBackpropInputOp =
4275     ConvertConvBackpropInputOp<TF::Conv3DBackpropInputV2Op,
4276                                /*num_spatial_dims=*/3>;
4277 
4278 // Converts tf.Conv?DBackpropFilterOp into:
4279 //   %result = "mhlo.convolution"(%input, %out_backprop)
4280 template <typename OpTy, int num_spatial_dims>
4281 class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
4282  public:
4283   using OpRewritePattern<OpTy>::OpRewritePattern;
4284 
4285   LogicalResult matchAndRewrite(OpTy op,
4286                                 PatternRewriter &rewriter) const override {
4287     // Unpack all of the attributes.
4288     tensorflow::TensorFormat data_format;
4289     if (!FormatFromString(op.data_format().str(), &data_format))
4290       return failure();
4291 
4292     tensorflow::Padding padding;
4293     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
4294       return failure();
4295 
4296     auto out_backprop_ty =
4297         op.out_backprop().getType().template dyn_cast<RankedTensorType>();
4298     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
4299 
4300     for (RankedTensorType ty : {out_backprop_ty, input_ty})
4301       if (!ty || !ty.hasStaticShape()) return failure();
4302 
4303     ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
4304     ArrayRef<int64_t> input_shape = input_ty.getShape();
4305 
4306     DenseIntElementsAttr filter_shape_attr;
4307     if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
4308         filter_shape_attr.getType().getRank() != 1)
4309       return failure();
4310 
4311     auto dilations_attr = GetI64ElementsAttr(op.dilations());
4312     std::vector<int> dilations{
4313         dilations_attr.template getValues<int64_t>().begin(),
4314         dilations_attr.template getValues<int64_t>().end()};
4315     auto strides_attr = GetI64ElementsAttr(op.strides());
4316     std::vector<tensorflow::int32> strides{
4317         strides_attr.template getValues<int64_t>().begin(),
4318         strides_attr.template getValues<int64_t>().end()};
4319 
4320     std::vector<tensorflow::int64> explicit_paddings;
4321     if (padding == tensorflow::Padding::EXPLICIT) {
4322       // EXPLICIT padding mode and the associated attribute are limited to
4323       // Conv2DBackpropFilter. So, fetch the attribute by identifier instead of
4324       // the op.explicit_paddings() attribute getter.
4325       ArrayRef<Attribute> explicit_paddings_attr =
4326           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
4327       explicit_paddings.reserve(explicit_paddings_attr.size());
4328       for (Attribute explicit_padding : explicit_paddings_attr)
4329         explicit_paddings.push_back(
4330             explicit_padding.cast<IntegerAttr>().getInt());
4331     }
4332 
4333     constexpr int num_dims = num_spatial_dims + 2;
4334     auto filter_shape = filter_shape_attr.getValues<int32_t>();
4335 
4336     // Reuse dimension computation logic from conv_grad_shape_utils.cc.
4337     tensorflow::ConvBackpropDimensions dims;
4338     if (!tensorflow::ConvBackpropComputeDimensionsV2(
4339              /*label=*/"", num_spatial_dims,
4340              ToTensorShape<int64_t, num_dims>(input_shape),
4341              ToTensorShape<int32_t, num_dims>(filter_shape),
4342              ToTensorShape<int64_t, num_dims>(out_backprop_shape), dilations,
4343              strides, padding, explicit_paddings, data_format, &dims)
4344              .ok()) {
4345       return failure();
4346     }
4347 
4348     // The activations (inputs) form the LHS of the convolution.
4349     // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
4350     // For the gradient computation, we need to:
4351     // 1. In the case of group convolution, move the num_groups dimension before
4352     // the batch dimension
4353     // 2. Swap the roles of the batch and feature dimensions.
4354     const int feature_dim =
4355         tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
4356     const int64_t in_depth = input_shape[feature_dim];
4357     const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims);
4358     const int64_t batch_group_count = in_depth / filter_in_depth;
4359 
4360     // Compute ConvDimensionNumbers, dilation, and padding.
4361     SmallVector<int64_t, num_spatial_dims> spatial_dims;
4362     SmallVector<int64_t, num_spatial_dims> kernel_spatial_dims;
4363     SmallVector<int64_t, num_spatial_dims> rhs_dilation;
4364     SmallVector<int64_t, num_spatial_dims * 2> paddings;
4365     SmallVector<int64_t, num_spatial_dims> window_strides;
4366 
4367     // The filter gradients are computed by a convolution of the input
4368     // activations and the output gradients, with some appropriate padding.
4369     // See the comment at the top of conv_grad_ops.h for details.
4370 
4371     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
4372       const int64_t dim =
4373           tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
4374       kernel_spatial_dims.push_back(dim);
4375       // Besides padding the input, we will also expand output_rows to
4376       //    expanded_out_rows = (output_rows - 1) * stride + 1
4377       // with zeros in between:
4378       //
4379       //      a . . . b . . . c . . . d . . . e
4380       //
4381       // This is done by specifying the window dilation factors in the
4382       // convolution HLO below.
4383       const auto &spatial_dim_i = dims.spatial_dims[i];
4384       rhs_dilation.push_back(spatial_dim_i.stride);
4385       window_strides.push_back(dilations[dim]);
4386 
4387       // We will also need to pad the input with zeros such that after the
4388       // convolution, we get the right size for the filter.
4389       // The padded_in_rows should be such that when we convolve this with the
4390       // expanded_out_rows as a filter, we should get filter_rows back.
4391 
4392       const int64_t padded_in_size =
4393           spatial_dim_i.expanded_output_size +
4394           (spatial_dim_i.filter_size - 1) * dilations[dim];
4395 
4396       // However it can be smaller than input_rows: in this
4397       // case it means some of the inputs are not used.
4398       //
4399       // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
4400       //
4401       // INPUT =  [ A  B  C ]
4402       //
4403       // FILTER = [ x y ]
4404       //
4405       // and the output will only have one column: a = A * x + B * y
4406       //
4407       // and input "C" is not used at all.
4408       //
4409       // We apply negative padding in this case.
4410       const int64_t pad_total = padded_in_size - spatial_dim_i.input_size;
4411 
4412       // + For the EXPLICIT padding, we pad the top/left side with the explicit
4413       //   padding and pad the bottom/right side with the remaining space.
4414       // + For the VALID padding, we don't pad anything on the top/left side
4415       //   and pad the bottom/right side with the remaining space.
4416       // + For the SAME padding, we pad top/left side the same as bottom/right
4417       //   side.
4418       //
4419       // In addition, if the padded input size is smaller than the input size,
4420       // we need to ignore some trailing elements of the input. We do this by
4421       // applying negative padding on the right/bottom.
4422       const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
4423                                      ? explicit_paddings[2 * dim]
4424                                      : padding == tensorflow::Padding::SAME
4425                                            ? std::max<int64_t>(pad_total / 2, 0)
4426                                            : 0;
4427       paddings.push_back(pad_before);
4428       paddings.push_back(pad_total - pad_before);
4429     }
4430 
4431     RankedTensorType paddings_ty = RankedTensorType::get(
4432         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
4433     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
4434     auto kernel_spatial_dims_attr =
4435         GetI64ElementsAttr(kernel_spatial_dims, &rewriter);
4436 
4437     const int batch_dim =
4438         tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
4439     auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
4440     auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
4441 
4442     Value result = rewriter.create<ConvOp>(
4443         op.getLoc(), op.getType(), op.input(), op.out_backprop(),
4444         /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
4445         /*padding=*/paddings_attr, /*lhs_dilation=*/
4446         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
4447                                    &rewriter),
4448         GetI64ElementsAttr(rhs_dilation, &rewriter),
4449         /*window_reversal=*/nullptr,
4450         ConvDimensionNumbers::get(
4451             // Swap batch_dim and feature_dim in the activations.
4452             /*input_batch_dimension=*/feature_dim_attr,
4453             /*input_feature_dimension=*/batch_dim_attr,
4454             /*input_spatial_dimensions=*/kernel_spatial_dims_attr,
4455             // The gradients become the RHS of the convolution.
4456             // The gradients have shape [batch, out_rows, out_cols, ...,
4457             // out_depth] where the batch becomes the input feature for the
4458             // convolution.
4459             /*kernel_input_feature_dimension=*/batch_dim_attr,
4460             /*kernel_output_feature_dimension=*/feature_dim_attr,
4461             /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
4462             /*output_batch_dimension=*/
4463             rewriter.getI64IntegerAttr(num_spatial_dims),
4464             /*output_feature_dimension=*/
4465             rewriter.getI64IntegerAttr(num_spatial_dims + 1),
4466             /*output_spatial_dimensions=*/
4467             GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter),
4468             rewriter.getContext()),
4469         /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
4470         rewriter.getI64IntegerAttr(batch_group_count),
4471         /*precision_config=*/ArrayAttr());
4472 
4473     rewriter.replaceOp(op, {result});
4474 
4475     return success();
4476   }
4477 };
4478 
4479 using ConvertConv2DBackpropFilterOp =
4480     ConvertConvBackpropFilterOp<TF::Conv2DBackpropFilterOp,
4481                                 /*num_spatial_dims=*/2>;
4482 using ConvertConv3DBackpropFilterOp =
4483     ConvertConvBackpropFilterOp<TF::Conv3DBackpropFilterV2Op,
4484                                 /*num_spatial_dims=*/3>;
4485 
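// Converts tf.OneHot to a combination of mhlo.iota, broadcast, compare and
// select ops when the indices have a static shape and the depth operand is a
// constant.
//
// As an illustrative sketch (attribute details abbreviated), the IR:
//
// %depth = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// %0 = "tf.OneHot"(%indices, %depth, %on, %off) {axis = -1 : i64} :
//      (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) ->
//      tensor<2x3xf32>
//
// would be lowered along the lines of:
//
// %iota = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3xi32>
// %idx = "mhlo.broadcast_in_dim"(%indices) {broadcast_dimensions = ...} :
//      (tensor<2xi32>) -> tensor<2x3xi32>
// %cmp = "mhlo.compare"(%idx, %iota) {comparison_direction = "EQ"} : ...
// %on_b = "mhlo.broadcast"(%on) ...
// %off_b = "mhlo.broadcast"(%off) ...
// %0 = "mhlo.select"(%cmp, %on_b, %off_b) : ...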
4486 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
4487  public:
4488   using OpRewritePattern::OpRewritePattern;
4489 
4490   LogicalResult matchAndRewrite(TF::OneHotOp op,
4491                                 PatternRewriter &rewriter) const override {
4492     auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
4493     if (!indices_ty || !indices_ty.hasStaticShape()) return failure();
4494     ArrayRef<int64_t> indices_shape = indices_ty.getShape();
4495     Type element_type = indices_ty.getElementType();
4496 
4497     DenseIntElementsAttr depth_attr;
4498     if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
4499       return failure();
4500     }
4501 
4502     int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
4503     int64_t axis = op.axis();
4504     if (axis == -1) axis = indices_shape.size();
4505 
4506     llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
4507     std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
4508     std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
4509 
4510     llvm::SmallVector<int64_t, 4> output_dims =
4511         llvm::to_vector<4>(indices_shape);
4512     output_dims.insert(output_dims.begin() + axis, depth);
4513 
4514     Location loc = op.getLoc();
4515 
4516     // The iota result is the effective output shape of the computation,
4517     // and indices must be broadcast into it. At this point, this computation
4518     // would need to be reworked quite a bit to support dynamic shapes, so
4519     // we just use static broadcasting.
4520     auto index_type = RankedTensorType::get(output_dims, element_type);
4521     auto iota = rewriter.create<IotaOp>(
4522         loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
4523     auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
4524         loc, index_type, op.indices(),
4525         GetI64ElementsAttr(broadcast_dims, &rewriter));
4526 
4527     Value compare = rewriter.create<mhlo::CompareOp>(
4528         loc, broadcast_indices, iota,
4529         StringAttr::get(rewriter.getContext(), "EQ"));
4530     Value on_value = rewriter.create<BroadcastOp>(
4531         loc, op.getType(), op.on_value(),
4532         GetI64ElementsAttr(output_dims, &rewriter));
4533     Value off_value = rewriter.create<BroadcastOp>(
4534         loc, op.getType(), op.off_value(),
4535         GetI64ElementsAttr(output_dims, &rewriter));
4536     Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
4537                                              on_value, off_value);
4538 
4539     rewriter.replaceOp(op, {result});
4540 
4541     return success();
4542   }
4543 };
4544 
4545 // Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
4546 // get_tuple_element ops.
4547 //
4548 // All HLO infeed ops expect a HLO token type operand and produce a tuple
4549 // containing a token. This HLO token type is used to order multiple infeed
4550 // operations within a computation. The token type can come from other
4551 // infeed/outfeed/send/recv ops or can be generated using create_token op with
4552 // no operands. Here we emit a create_token op to generate the token type
4553 // operand of infeed.
4554 //
4555 // For example the following IR:
4556 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
4557 //
4558 // would be lowered to
4559 //
4560 // %token = "mhlo.create_token"() : () -> !mhlo.token
4561 // %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
4562 //      (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
4563 //      !mhlo.token>
4564 // %data = "mhlo.get_tuple_element"(%data_and_token) {index = 0}
4565 // %0#0 = "mhlo.get_tuple_element"(%data) {index = 0}
4566 // %0#1 = "mhlo.get_tuple_element"(%data) {index = 1}
4567 //
4568 class ConvertInfeedDequeueTupleOp
4569     : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
4570  public:
4571   using OpRewritePattern::OpRewritePattern;
4572 
4573   Attribute GetLayout(const Type &type, PatternRewriter &rewriter) const {
4574     auto i64_type = rewriter.getIntegerType(64);
4575     if (type.isa<TupleType>()) {
4576       TupleType tuple_type = type.dyn_cast<TupleType>();
4577       std::vector<mlir::Attribute> v;
4578       for (const mlir::Type &t : tuple_type.getTypes()) {
4579         v.push_back(GetLayout(t, rewriter));
4580       }
4581       ArrayRef<Attribute> shape(v);
4582       return rewriter.getArrayAttr(shape);
4583     } else if (type.isa<RankedTensorType>()) {
4584       RankedTensorType t = type.dyn_cast<RankedTensorType>();
4585       std::vector<mlir::Attribute> attrs;
4586       std::vector<Attribute> elements;
4587       // Tuples are always serialized with an ascending layout. See
4588       // LiteralLinearizer::LinearizeToBuffers.
4589       for (int64_t i = 0; i < t.getRank(); i++) {
4590         elements.push_back(rewriter.getIntegerAttr(i64_type, i));
4591       }
4592       return rewriter.getArrayAttr(elements);
4593     } else {
4594       return rewriter.getUnitAttr();  // e.g. tokens
4595     }
4596   }
4597 
4598   LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
4599                                 PatternRewriter &rewriter) const override {
4600     std::vector<Type> result_types(op.outputs().size());
4601     for (auto idx_and_output : llvm::enumerate(op.outputs())) {
4602       result_types[idx_and_output.index()] = (idx_and_output.value().getType());
4603     }
4604     // Infeed takes a single token operand. Generate the token using
4605     // create_token op to pass to the infeed op.
4606     auto token = rewriter.create<CreateTokenOp>(
4607         op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
4608 
4609     // Emit infeed op.
4610     // The result type of infeed is a tuple(tuple(result types), token type).
4611     auto data_tuple_type =
4612         mlir::TupleType::get(rewriter.getContext(), result_types);
4613     auto data_and_token_type = mlir::TupleType::get(
4614         rewriter.getContext(), {data_tuple_type, token.getType()});
4615 
4616     ArrayAttr layout =
4617         GetLayout(data_and_token_type, rewriter).cast<ArrayAttr>();
4618     auto data_and_token =
4619         rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
4620                                   /*infeed_config=*/rewriter.getStringAttr(""),
4621                                   /*layout=*/layout);
4622 
4623     // TODO(b/171212005): Reenable layout.
4624     data_and_token.removeAttr("layout");
4625 
4626     if (op._XlaSharding().hasValue()) {
4627       // _XlaSharding attribute in TF is a serialized string of the OpSharding
4628       // proto, so convert to a text form here.
4629       ::xla::OpSharding sharding_proto;
4630       if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
4631         return failure();
4632 
4633       // Token is a control signal and not real data, so arbitrarily assign
4634       // the token to device 0.
4635       if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
4636         *sharding_proto.add_tuple_shardings() =
4637             ::xla::sharding_builder::AssignDevice(0);
4638         data_and_token->setAttr(
4639             kShardingAttr,
4640             rewriter.getStringAttr(sharding_proto.SerializeAsString()));
4641       } else {
4642         data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
4643       }
4644     }
4645 
4646     // The infeed instruction produces a tuple of the infeed data and a token
4647     // type. Emit get_tuple_element to get infeed data tuple.
4648     auto data_tuple = rewriter.create<GetTupleElementOp>(
4649         op.getLoc(), data_tuple_type, data_and_token,
4650         rewriter.getI32IntegerAttr(0));
4651 
4652     // Emit get_tuple_element for each result.
4653     std::vector<Value> results;
4654     for (auto idx_and_type : llvm::enumerate(result_types)) {
4655       auto tuple_element = rewriter.create<GetTupleElementOp>(
4656           op.getLoc(), idx_and_type.value(), data_tuple,
4657           rewriter.getI32IntegerAttr(idx_and_type.index()));
4658       results.push_back(tuple_element);
4659     }
4660     rewriter.replaceOp(op, ValueRange(results));
4661     return success();
4662   }
4663 };
4664 
4665 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
4666 // ops.
4667 //
4668 // XLA HLO outfeed op expects a token, which we generate by emitting an
4669 // create_token op.
4670 //
4671 // For example the following IR:
4672 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
4673 //      ()
4674 //
4675 // would be lowered to
4676 //
4677 // %tuple = "mhlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
4678 //      tuple<tensor<3xi32>, tensor<4xf32>>
4679 // %token = "mhlo.create_token"() : () -> !mhlo.token
4680 // %outfeed_token = "mhlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
4681 //      (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
4682 //
4683 class ConvertOutfeedEnqueueTupleOp
4684     : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
4685  public:
4686   using OpRewritePattern::OpRewritePattern;
4687 
4688   LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
4689                                 PatternRewriter &rewriter) const override {
4690     auto token_type = mhlo::TokenType::get(rewriter.getContext());
4691     auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
4692     auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
4693     rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
4694                                /*outfeed_config=*/rewriter.getStringAttr(""));
4695     rewriter.eraseOp(op);
4696     return success();
4697   }
4698 };
4699 
4700 // Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
4701 //
4702 // tf.TopKV2 sorts along last dimension of the input tensor and then returns
4703 // the top K components' values and indices. This is translated into a few
4704 // ops in XLA HLO: first generating an integer sequence for the indices,
4705 // then sorting both the original input tensor and the indices together, and
4706 // finally slicing out the top K components.
4707 //
4708 // For example, for the following IR:
4709 //
4710 // %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
4711 // %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
4712 //                                 (tensor<16x8xf32>, tensor<16x8xi32>)
4713 //
4714 // We will get:
4715 //
4716 // %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
4717 // %2 = "mhlo.sort"(%input, %1) ( {
4718 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
4719 //      %arg3: tensor<i32>, %arg4: tensor<i32>):
4720 //   %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
4721 //   "mhlo.return"(%7) : (tensor<i1>) -> ()
4722 // }) {dimension = 1 : i64, is_stable = true} : ...
4723 // %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
4724 // %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
4725 // %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
4726 //                           start_indices = dense<0> : tensor<2xi64>,
4727 //                           strides = dense<1> : tensor<2xi64>} :
4728 //                              (tensor<16x16xf32>) -> tensor<16x8xf32>
4729 // %6 = "mhlo.slice"(%4) ...
4730 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
4731  public:
4732   using OpRewritePattern::OpRewritePattern;
4733 
4734   LogicalResult matchAndRewrite(TF::TopKV2Op op,
4735                                 PatternRewriter &rewriter) const override {
4736     // We can only match when the `k` operand is a constant scalar.
4737     DenseIntElementsAttr k_attr;
4738     if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure();
4739 
4740     // The last dimension of the input tensor's shape should be known so we can
4741     // have clamped end_indices for slices.
4742     TensorType input_type = op.input().getType().cast<TensorType>();
4743     if (!input_type.hasRank()) return failure();
4744     int64_t input_rank = input_type.getRank();
4745     int64_t last_dim_index = input_rank - 1;
4746     int64_t last_dim_size = input_type.getDimSize(last_dim_index);
4747     if (last_dim_size == ShapedType::kDynamicSize) return failure();
4748 
4749     // Create an Iota op for indices.
4750     auto i32_type = rewriter.getIntegerType(32);
4751     Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
4752     Value iota_op = rewriter.create<mhlo::IotaOp>(
4753         op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
4754 
4755     // Create the sort op. It takes two inputs, one for the original input, the
4756     // other for the indices.
4757     auto sort_op = rewriter.create<mhlo::SortOp>(
4758         op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
4759         /*is_stable=*/true);
4760 
4761     // Use TOTALORDER comparison type instead of the default comparison if the
4762     // element type is of type float.
4763     llvm::Optional<StringRef> compare_type;
4764     if (input_type.getElementType().isa<FloatType>())
4765       compare_type.emplace("TOTALORDER");
4766     BuildSortComparisonBody({input_type.getElementType(), i32_type},
4767                             /*direction=*/"GT", compare_type,
4768                             &sort_op.comparator(), &rewriter);
4769 
4770     // Get the sorted input and index tuple element.
4771     auto tuple_first_element = sort_op.getResult(0);
4772     auto tuple_second_element = sort_op.getResult(1);
4773 
4774     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
4775     auto end_indices = llvm::to_vector<4>(input_type.getShape());
4776     end_indices.back() =
4777         std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
4778     SmallVector<int64_t, 4> strides(input_rank, 1);
4779 
4780     // Get the slice for the top K elements.
4781 
4782     Value values = rewriter.create<mhlo::SliceOp>(
4783         op.getLoc(), tuple_first_element,
4784         GetI64ElementsAttr(begin_indices, &rewriter),
4785         GetI64ElementsAttr(end_indices, &rewriter),
4786         GetI64ElementsAttr(strides, &rewriter));
4787 
4788     Value indices = rewriter.create<mhlo::SliceOp>(
4789         op.getLoc(), tuple_second_element,
4790         GetI64ElementsAttr(begin_indices, &rewriter),
4791         GetI64ElementsAttr(end_indices, &rewriter),
4792         GetI64ElementsAttr(strides, &rewriter));
4793 
4794     rewriter.replaceOp(op, {values, indices});
4795     return success();
4796   }
4797 };
4798 
4799 // Converts tf.Unpack to a series of XLA HLO slice ops.
4800 //
4801 // Each slice takes one element along the dimension to unpack and takes the full
4802 // range for all other dimensions. Each slice is then reshaped to drop the
4803 // dimension to unpack (which is always of size 1).
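// For example (an illustrative sketch), unpacking along axis 0:
//
// %0:2 = "tf.Unpack"(%input) {axis = 0 : i64} :
//      (tensor<2x4xf32>) -> (tensor<4xf32>, tensor<4xf32>)
//
// becomes, roughly, one mhlo.slice per result followed by a squeeze:
//
// %s0 = "mhlo.slice"(%input) {start_indices = dense<[0, 0]> : tensor<2xi64>,
//      limit_indices = dense<[1, 4]> : tensor<2xi64>,
//      strides = dense<1> : tensor<2xi64>} :
//      (tensor<2x4xf32>) -> tensor<1x4xf32>
// %0#0 = "tf.Squeeze"(%s0) {squeeze_dims = [0]} :
//      (tensor<1x4xf32>) -> tensor<4xf32>
//
// and similarly for the second row starting at index 1.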
4804 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
4805 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
4806  public:
4807   using OpRewritePattern::OpRewritePattern;
4808 
4809   LogicalResult matchAndRewrite(TF::UnpackOp op,
4810                                 PatternRewriter &rewriter) const override {
4811     auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
4812     if (!value_type) return failure();
4813 
4814     int64_t value_rank = value_type.getRank();
4815     int64_t axis = op.axis();
4816     if (axis < 0) axis += value_rank;
4817 
4818     // Parameters for constructing each slice.
4819     SmallVector<int64_t, 4> begin_indices(value_rank, 0);
4820     auto end_indices = llvm::to_vector<4>(value_type.getShape());
4821     SmallVector<int64_t, 4> strides(value_rank, 1);
4822 
4823     // All HLO slice+squeeze results used to replace the original tf.Unpack op.
4824     SmallVector<Value, 4> results;
4825     results.reserve(op.getNumResults());
4826 
4827     for (int i = 0, end = op.getNumResults(); i < end; ++i) {
4828       begin_indices[axis] = i;
4829       end_indices[axis] = i + 1;
4830 
4831       auto slice_op = rewriter.create<mhlo::SliceOp>(
4832           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
4833           GetI64ElementsAttr(end_indices, &rewriter),
4834           GetI64ElementsAttr(strides, &rewriter));
4835       // Reshape to drop the axis dimension.
4836       auto result =
4837           rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
4838                                          rewriter.getI64ArrayAttr(op.axis()));
4839       results.push_back(result);
4840     }
4841 
4842     rewriter.replaceOp(op, results);
4843     return success();
4844   }
4845 };
4846 
4847 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
4848 //
4849 // TF unsorted segment reduction op performs the following calculation:
4850 //
4851 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
4852 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
4853 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
4854 // Dm+1, ..., Dn]. Then
4855 //   output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
4856 //      <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
4857 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
4858 //
4859 // The op will be translated to XLA HLO scatter with the following parameters:
4860 // * Update window dims is [segment_id_rank, data_rank).
4861 // * Inserted window dims is {0}.
4862 // * Scatter dims to operand dims mapping is {0}.
4863 // * Index vector dim is segment_id_rank.
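// For example (an illustrative sketch), an UnsortedSegmentSum with segment_ids
// of shape [4], data of shape [4, 8] and num_segments = 3 becomes a scatter
// whose operand is a zero-initialized [3, 8] tensor, with
// update_window_dims = [1], inserted_window_dims = [0],
// scatter_dims_to_operand_dims = [0], index_vector_dim = 1, and an update
// computation that adds the scattered values.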
4864 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
4865 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
4866   using OpRewritePattern<OpTy>::OpRewritePattern;
4867 
4868   LogicalResult matchAndRewrite(OpTy op,
4869                                 PatternRewriter &rewriter) const override {
4870     auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
4871     if (!data_type) return failure();
4872     int64_t data_rank = data_type.getRank();
4873 
4874     auto segment_ids_type =
4875         op.segment_ids().getType().template dyn_cast<RankedTensorType>();
4876     if (!segment_ids_type) return failure();
4877     int64_t segment_ids_rank = segment_ids_type.getRank();
4878 
4879     DenseIntElementsAttr num_segments_attr;
4880     if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
4881       return failure();
4882 
4883     // The final shape for TF unsorted segment reduction op is [num_segments] +
4884     // data_shape[segment_ids_rank:].
4885     SmallVector<int64_t, 4> output_shape;
4886     output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
4887     auto suffix = data_type.getShape().drop_front(segment_ids_rank);
4888     output_shape.append(suffix.begin(), suffix.end());
4889     auto output_type =
4890         RankedTensorType::get(output_shape, data_type.getElementType());
4891 
4892     // Broadcast the initial value for reduction. This will become the
4893     // 'operand' parameter to scatter to for the final scatter op.
4894     Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
4895                                                 op.getLoc(), &rewriter);
4896     auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
4897         op.getLoc(), output_type, init,
4898         GetI64ElementsAttr(output_shape, &rewriter));
4899 
4900     // Parameters for the generated scatter op.
4901     SmallVector<int64_t, 1> inserted_window_dims(1, 0);
4902     SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
4903     int64_t index_vector_dim = segment_ids_rank;
4904 
4905     // Put all parameters in a StructAttr.
4906     auto dims_attr = ScatterDimensionNumbers::get(
4907         GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
4908         GetI64ElementsAttr(inserted_window_dims, &rewriter),
4909         GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
4910         rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());
4911 
4912     auto scatter =
4913         rewriter.create<ScatterOp>(op.getLoc(), op.getType(), broadcasted_init,
4914                                    op.segment_ids(), op.data(), dims_attr);
4915     BuildReduceBody<ReductionOp>(data_type.getElementType(),
4916                                  &scatter.update_computation(), &rewriter);
4917 
4918     rewriter.replaceOp(op, scatter.getResult());
4919     return success();
4920   }
4921 };
4922 
4923 class ConvertUnsortedSegmentMaxOp
4924     : public GenericConvertUnsortedSegmentReductionOp<
4925           ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
4926  public:
4927   using GenericConvertUnsortedSegmentReductionOp::
4928       GenericConvertUnsortedSegmentReductionOp;
4929 
4930   static Value GetInitialValue(Type reduce_element_type, Location loc,
4931                                PatternRewriter *rewriter) {
4932     return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
4933                                      rewriter);
4934   }
4935 };
4936 
4937 class ConvertUnsortedSegmentMinOp
4938     : public GenericConvertUnsortedSegmentReductionOp<
4939           ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
4940  public:
4941   using GenericConvertUnsortedSegmentReductionOp::
4942       GenericConvertUnsortedSegmentReductionOp;
4943 
4944   static Value GetInitialValue(Type reduce_element_type, Location loc,
4945                                PatternRewriter *rewriter) {
4946     return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
4947                                      rewriter);
4948   }
4949 };
4950 
4951 class ConvertUnsortedSegmentProdOp
4952     : public GenericConvertUnsortedSegmentReductionOp<
4953           ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
4954  public:
4955   using GenericConvertUnsortedSegmentReductionOp::
4956       GenericConvertUnsortedSegmentReductionOp;
4957 
4958   static Value GetInitialValue(Type reduce_element_type, Location loc,
4959                                PatternRewriter *rewriter) {
4960     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4961   }
4962 };
4963 
4964 class ConvertUnsortedSegmentSumOp
4965     : public GenericConvertUnsortedSegmentReductionOp<
4966           ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
4967  public:
4968   using GenericConvertUnsortedSegmentReductionOp::
4969       GenericConvertUnsortedSegmentReductionOp;
4970 
4971   static Value GetInitialValue(Type reduce_element_type, Location loc,
4972                                PatternRewriter *rewriter) {
4973     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4974   }
4975 };
4976 
4977 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
4978 //
4979 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
4980 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
4981 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
4982 // translated into an HLO while op to first emulate shuffling indices using
4983 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
4984 // with the shuffled indices.
4985 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
4986  public:
4987   using OpRewritePattern::OpRewritePattern;
4988 
4989   LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
4990                                 PatternRewriter &rewriter) const override {
4991     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
4992     if (!input_type) return failure();
4993 
4994     int64_t input_rank = input_type.getRank();
4995     int64_t first_dim_size = input_type.getDimSize(0);
4996     if (ShapedType::isDynamic(first_dim_size)) return failure();
4997 
4998     // We are shuffling along the first dimension. If its size is <= 1, then
4999     // shuffling is a no-op.
5000     if (first_dim_size <= 1) {
5001       rewriter.replaceOp(op, op.value());
5002       return success();
5003     }
5004 
5005     // For vectors, shuffle values by sorting instead of the obvious
5006     // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
5007     // but not easily parallelizable. For a sufficiently parallel architecture,
5008     // it is faster to sort many times than to Fisher-Yates shuffle once.
5009     if (input_rank == 1) {
5010       // Shuffle values by assigning each value a random key and sorting the
5011       // keys. Keys can collide, causing detectable patterns in the shuffled
5012       // output. Collisions translate into more ascending sub-sequences in the
5013       // shuffled output than would be expected by chance. To avoid collisions,
5014       // the number of possible key values must be sufficiently large.
5015 
5016       // How are more than 2^32 keys created? In each loop iteration, the
5017       // algorithm sorts by random keys. Conceptually, the earlier iterations
5018       // are sorting on the lower-order bits of larger keys that are never
5019       // actually assembled.
5020 
5021       // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
5022       // the number of possible keys and n is the number of values. If d = n^2,
5023       // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
5024       // as n goes to infinity is zero.
5025 
5026       // This implementation ensures that the key-space is greater than or equal
5027       // to the cube of the number of values. The risk of collisions can be
5028       // further reduced by increasing Exponent at the expense of
5029       // performance.
5030 
5031       // For Exponent = 2, the expected number of collisions per shuffle is
5032       // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
5033       // about 1/2.
5034 
5035       // For Exponent = 3, the expected number of collisions per shuffle is
5036       // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
5037       // about 1/3255.
5038 
5039       // For Exponent = 4, the expected number of collisions per shuffle is
5040       // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
5041       // about 1/132622.
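      // For example, with exponent = 3 and num_elements = 1000 this yields
      // rounds = ceil(3 * ln(1000) / ln(2^32 - 1)) = ceil(0.94) = 1, so a
      // single sort pass already provides enough key bits.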
5042       constexpr int exponent = 3;
5043       int64_t num_elements = input_type.getNumElements();
5044       uint32_t u32_max = std::numeric_limits<uint32_t>::max();
5045       int rounds =
5046           std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
5047 
5048       Value current = op.value();
5049       for (int i = 0; i < rounds; ++i) {
5050         auto keys =
5051             CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
5052                                /*upper_limit=*/u32_max, &rewriter);
5053         auto sorted = rewriter.create<mhlo::SortOp>(
5054             op.getLoc(), llvm::ArrayRef<Value>{keys, current});
5055         auto i32_type = rewriter.getIntegerType(32);
5056         BuildSortComparisonBody({i32_type, input_type.getElementType()},
5057                                 /*direction=*/"LT", llvm::None,
5058                                 &sorted.comparator(), &rewriter);
5059         current = sorted.getResult(1);
5060       }
5061       rewriter.replaceOp(op, current);
5062       return success();
5063     }
5064 
5065     // The Fisher-Yates algorithm.
5066 
5067     // Generate range(n) as the initial value for the indices to be swapped.
5068     auto indices_type =
5069         RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
5070     Value indices = rewriter.create<mhlo::IotaOp>(
5071         op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
5072 
5073     // Generate random numbers to be used as swaps for the indices.
5074     Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
5075                                      first_dim_size, &rewriter);
5076 
5077     // While loop body to perform index swaps.
5078     auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
5079                             SmallVectorImpl<Value> *new_values,
5080                             OpBuilder *builder) {
5081       Value swaps = old_values[0];
5082       Value indices = old_values[1];
5083 
5084       auto vec1_i32_type =
5085           RankedTensorType::get({1}, builder->getIntegerType(32));
5086       auto scalar_i32_type =
5087           RankedTensorType::get({}, builder->getIntegerType(32));
5088       auto scalar_i64_type =
5089           RankedTensorType::get({}, builder->getIntegerType(64));
5090 
5091       auto scalar_one =
5092           DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
5093 
5094       // We need to swap the indices[i] with indices[swaps[i]]. First get
5095       // these index values.
5096       Value source_index = builder->create<mhlo::DynamicSliceOp>(
5097           loc, vec1_i32_type, indices, i, scalar_one);
5098       Value swap_index = builder->create<mhlo::ReshapeOp>(
5099           loc, scalar_i32_type,
5100           builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
5101                                                 scalar_one));
5102       Value target_index = builder->create<mhlo::DynamicSliceOp>(
5103           loc, vec1_i32_type, indices, swap_index, scalar_one);
5104 
5105       // Then perform the swap.
5106       // indices[i] <- indices[swaps[i]]
5107       indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5108           loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
5109       // indices[swaps[i]] <- indices[i]
5110       indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5111           loc, indices.getType(), indices, source_index,
5112           llvm::makeArrayRef(swap_index));
5113 
5114       // Update new values.
5115       new_values->assign({swaps, indices});
5116     };
5117 
5118     // Create a while op to swap indices.
5119     SmallVector<Value, 2> while_output;
5120     CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
5121                   &while_output, &rewriter);
5122     Value swapped_indices = while_output[1];
5123 
5124     // Gather the data using the swapped indices as the shuffled order.
5125     ArrayRef<int64_t> input_shape = input_type.getShape();
5126     SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
5127     slice_sizes[0] = 1;
5128     auto dims_attr = GatherDimensionNumbers::get(
5129         /*offset_dims=*/GetI64ElementsAttrForSeq(1, input_rank, &rewriter),
5130         /*collapsed_slice_dims=*/GetI64ElementsAttr({0}, &rewriter),
5131         /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
5132         /*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
5133         rewriter.getContext());
5134     rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
5135         op, op.getType(), op.value(), swapped_indices, dims_attr,
5136         GetI64ElementsAttr(slice_sizes, &rewriter));
5137 
5138     return success();
5139   }
5140 };
5141 
5142 // Converts an XlaSharding op to a XLA HLO shard op with sharding attributes.
5143 class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
5144  public:
5145   using OpRewritePattern::OpRewritePattern;
5146 
5147   LogicalResult matchAndRewrite(TF::XlaShardingOp op,
5148                                 PatternRewriter &rewriter) const override {
5149     // TODO(b/148313088): define sharding attribute struct in MLIR instead of
5150     // using a string.
5151     if (!op._XlaSharding().hasValue()) return failure();
5152 
5153     auto custom_call = rewriter.create<mhlo::CustomCallOp>(
5154         op.getLoc(), op.getType(), op.input(),
5155         /*call_target_name=*/rewriter.getStringAttr("Sharding"),
5156         /*has_side_effect=*/rewriter.getBoolAttr(false),
5157         /*backend_config=*/rewriter.getStringAttr(""));
5158     custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
5159     rewriter.replaceOp(op, custom_call.getResult(0));
5160 
5161     return success();
5162   }
5163 };
5164 
5165 // Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
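// As an illustrative sketch, updating two rows of a [4, 8] tensor:
//
// %0 = "tf.InplaceUpdate"(%x, %i, %v) :
//      (tensor<4x8xf32>, tensor<2xi32>, tensor<2x8xf32>) -> tensor<4x8xf32>
//
// is decomposed into a tf.Unpack of the indices, a tf.Split of the updates
// into [1, 8] rows, and one DynamicUpdateSlice per row, each using the
// unpacked index for the first dimension and zero for the remaining
// dimensions.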
5166 class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
5167  public:
5168   using OpRewritePattern::OpRewritePattern;
5169 
5170   LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
5171                                 PatternRewriter &rewriter) const override {
5172     auto input = op.x();
5173     auto indices = op.i();
5174     auto updates = op.v();
5175 
5176     // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
5177     // on the contents of `x`.
5178     auto input_type = input.getType().cast<ShapedType>();
5179     auto updates_type = updates.getType().cast<ShapedType>();
5180     auto indices_type = indices.getType().cast<ShapedType>();
5181     if (!indices_type.hasStaticShape()) return failure();
5182 
5183     if (indices_type.getRank() != 1) return failure();
5184 
5185     SmallVector<Type, 4> unpacked_indices_type(
5186         indices_type.getDimSize(0),
5187         RankedTensorType::get({}, indices_type.getElementType()));
5188     // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are
5189     // required to have matching types. This rewrite rule creates
5190     // DynamicUpdateSlice ops where the first "start index" is always i32 and
5191     // subsequent ones are constructed based on zero_attr. Thus the type
5192     // for zero_attr needs to be i32 as well.
5193     auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0);
5194     auto unpacked_indices = rewriter.create<TF::UnpackOp>(
5195         op.getLoc(), unpacked_indices_type, indices, zero_attr);
5196 
5197     SmallVector<int64_t, 4> split_updates_shape;
5198     split_updates_shape.append(updates_type.getShape().begin(),
5199                                updates_type.getShape().end());
5200     split_updates_shape.front() = 1;
5201     SmallVector<Type, 4> split_updates_type;
5202     split_updates_type.resize(
5203         updates_type.getShape().front(),
5204         RankedTensorType::get(split_updates_shape,
5205                               updates_type.getElementType()));
5206 
5207     auto cst =
5208         rewriter.create<mhlo::ConstOp>(op.getLoc(), zero_attr).getResult();
5209     auto split_updates = rewriter.create<TF::SplitOp>(
5210         op.getLoc(), split_updates_type, cst, updates);
5211 
5212     SmallVector<Value, 6> input_indices;
5213     input_indices.resize(input_type.getRank(), cst);
5214 
5215     SmallVector<int64_t, 6> starts(updates_type.getRank(), 0);
5216     SmallVector<int64_t, 6> strides(updates_type.getRank(), 1);
5217     SmallVector<int64_t, 6> limits(updates_type.getShape().begin(),
5218                                    updates_type.getShape().end());
5219 
5220     for (auto pair :
5221          llvm::zip(unpacked_indices.output(), split_updates.output())) {
5222       input_indices.front() = std::get<0>(pair);
5223       input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
5224           op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
5225     }
5226 
5227     rewriter.replaceOp(op, input);
5228     return success();
5229   }
5230 };
5231 
5232 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
5233 class ConvertXlaDynamicUpdateSliceOp
5234     : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
5235  public:
5236   using OpRewritePattern::OpRewritePattern;
5237 
5238   LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op,
5239                                 PatternRewriter &rewriter) const override {
5240     auto indices_type = op.indices().getType().dyn_cast<RankedTensorType>();
5241     if (!indices_type || !indices_type.hasStaticShape() ||
5242         indices_type.getShape().size() != 1)
5243       return failure();
5244 
5245     SmallVector<Type, 4> unpacked_indices_type(
5246         indices_type.getDimSize(0),
5247         RankedTensorType::get({}, indices_type.getElementType()));
5248     auto unpacked_indices = rewriter.create<TF::UnpackOp>(
5249         op.getLoc(), unpacked_indices_type, op.indices(),
5250         IntegerAttr::get(rewriter.getIntegerType(64), 0));
5251     rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
5252         op, op.getType(), op.input(), op.update(), unpacked_indices.output());
5253     return success();
5254   }
5255 };
5256 
5257 // Converts a TF XlaAllReduce op to AllReduce HLO.
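// The group_assignment operand must be a constant 2D tensor; it becomes the
// replica_groups attribute of the HLO all-reduce. For reduce_op values "Add",
// "Mul", "Min" and "Max" the matching reduction body is emitted directly; for
// "Mean" an Add body is emitted and the result is then divided by the size of
// each replica group.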
5258 class ConvertXlaAllReduceOp : public OpRewritePattern<TF::XlaAllReduceOp> {
5259   using OpRewritePattern::OpRewritePattern;
5260 
5261   LogicalResult matchAndRewrite(TF::XlaAllReduceOp op,
5262                                 PatternRewriter &rewriter) const override {
5263     DenseIntElementsAttr group_assignment;
5264     if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment)))
5265       return failure();
5266     auto replica_groups =
5267         hlo::ConvertElementsAttr(group_assignment, rewriter.getIntegerType(64))
5268             .cast<DenseIntElementsAttr>();
5269     if (replica_groups.getType().getRank() != 2) return failure();
5270 
5271     Location loc = op.getLoc();
5272     Type element_type = getElementTypeOrSelf(op.input().getType());
5273 
5274     auto all_reduce = rewriter.create<AllReduceOp>(
5275         loc, op.getType(), op.input(), replica_groups, ChannelHandle());
5276     StringRef reduce_op = op.reduce_op();
5277     if (reduce_op == "Add") {
5278       BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
5279                              &rewriter);
5280     } else if (reduce_op == "Mul") {
5281       BuildReduceBody<MulOp>(element_type, &all_reduce.computation(),
5282                              &rewriter);
5283     } else if (reduce_op == "Min") {
5284       BuildReduceBody<MinOp>(element_type, &all_reduce.computation(),
5285                              &rewriter);
5286     } else if (reduce_op == "Max") {
5287       BuildReduceBody<MaxOp>(element_type, &all_reduce.computation(),
5288                              &rewriter);
5289     } else {
5290       // For mean, add replicas in the same group. Then divide the sum by the
5291       // number of replicas in each group below.
5292       assert(reduce_op == "Mean");
5293       BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
5294                              &rewriter);
5295     }
5296     Value result = all_reduce.getResult();
5297 
5298     // For mean, divide the merge result by group size.
5299     if (reduce_op == "Mean") {
5300       int64_t replica_group_size = replica_groups.getType().getDimSize(1);
5301       auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size,
5302                                           &rewriter);
5303       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
5304       result = rewriter.create<chlo::BroadcastDivOp>(
5305           loc, result, divisor.getResult(), broadcast_dims);
5306     }
5307 
5308     rewriter.replaceOp(op, {result});
5309     return success();
5310   }
5311 };
5312 
5313 // Converts ClipByValue to XLA's clamp operation. Includes the broadcasting
5314 // semantics for static and dynamic cases.
5315 class ConvertClipByValueOp : public OpRewritePattern<TF::ClipByValueOp> {
5316  public:
5317   using OpRewritePattern::OpRewritePattern;
5318 
5319   LogicalResult matchAndRewrite(TF::ClipByValueOp op,
5320                                 PatternRewriter &rewriter) const override {
5321     Value input = op.t();
5322     Value min = op.clip_value_min();
5323     Value max = op.clip_value_max();
5324 
5325     auto input_ty = input.getType().cast<ShapedType>();
5326     auto min_ty = min.getType().cast<ShapedType>();
5327     auto max_ty = max.getType().cast<ShapedType>();
5328 
5329     if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) {
5330       return failure();
5331     }
5332 
5333     auto shape = rewriter.create<TF::ShapeOp>(
5334         op.getLoc(),
5335         RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()),
5336         input);
5337 
5338     if (min_ty != input_ty) {
5339       min =
5340           rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, min, shape);
5341     }
5342 
5343     if (max_ty != input_ty) {
5344       max =
5345           rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, max, shape);
5346     }
5347 
5348     rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, input_ty, min, input, max);
5349     return success();
5350   }
5351 };
5352 
5353 // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
5354 // setting appropriate window dimensions, with the given aggregation op as the
5355 // reduction function. The input tensor needs to have a static shape, and 'axis'
5356 // must be const. The TableGen pattern is not used for this rewrite because it
5357 // involves regions.
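// As an illustrative sketch, a cumulative sum over a length-4 vector along
// axis 0 becomes a reduce_window with window_dimensions = [4],
// window_strides = [1] and padding = [[3, 0]], so that output element i
// reduces inputs 0..i with an Add body. The 'exclusive' attribute shifts the
// result right by one (padding with the init value), and 'reverse' reverses
// the input before and the result after the reduction.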
5358 template <typename OpT, typename AggregationOp>
5359 class ConvertCumOp : public OpRewritePattern<OpT> {
5360   using OpRewritePattern<OpT>::OpRewritePattern;
5361 
5362   LogicalResult matchAndRewrite(OpT op,
5363                                 PatternRewriter &rewriter) const override {
5364     auto input = op.x();
5365     auto input_type = input.getType().template dyn_cast<ShapedType>();
5366     if (!input_type || !input_type.hasStaticShape()) {
5367       return failure();
5368     }
5369 
5370     ArrayRef<int64_t> input_shape = input_type.getShape();
5371     int64_t rank = input_shape.size();
5372 
5373     // We can only match when the axis is a constant scalar.
5374     DenseIntElementsAttr axis_attr;
5375     if (!matchPattern(op.axis(), m_Constant(&axis_attr))) {
5376       return failure();
5377     }
5378 
5379     // Get the dimension to apply the reduction on, and offset properly if it is
5380     // negative.
5381     int64_t axis = (*axis_attr.begin()).getSExtValue();
5382     if (axis < 0) {
5383       axis += rank;
5384     }
5385 
5386     // If we're supposed to sum things up in the reverse direction, we reverse
5387     // the input and then later reverse the output.
5388     if (op.reverse()) {
5389       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
5390       input = rewriter.create<ReverseOp>(
5391           op.getLoc(), op.getType(), input,
5392           GetI64ElementsAttr(dims_to_reverse, &rewriter));
5393     }
5394 
5395     // Convert if we need to enlarge the element type's bitwidth to avoid
5396     // precision loss.
5397     Type input_element_type = input_type.getElementType();
5398 
5399     // TODO(hinsu): Handle complex element types.
5400     if (!input_element_type.isIntOrFloat()) return failure();
5401 
5402     Type sum_element_type = GetSumAccumulationType(input_element_type);
5403     input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
5404 
5405     SmallVector<int64_t, 4> window_dims(rank, 1);
5406     SmallVector<int64_t, 4> window_strides(rank, 1);
5407     window_dims[axis] = input_shape[axis];
5408 
5409     SmallVector<int64_t, 8> paddings(rank * 2, 0);
5410     paddings[axis * 2] =
5411         std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
5412     auto paddings_attr = DenseIntElementsAttr::get(
5413         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
5414         paddings);
5415 
5416     int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
5417     Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
5418                                       &rewriter);
5419 
5420     auto reduce = rewriter.create<ReduceWindowOp>(
5421         op.getLoc(), input_type, input, init,
5422         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)),
5423         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
5424         /*base_dilations=*/DenseIntElementsAttr(),
5425         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
5426     BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
5427     Value result = reduce.getResult();
5428 
5429     if (op.exclusive()) {
5430       // In "exclusive" operation, the output will start with the "init" (0)
5431       // values. There is no way to express that as a ReduceWindowOp, so run the
5432       // normal operation, and then use a PadOp to add the 0 "column" on the
5433       // left and cut away the last column on the right.
5434       llvm::SmallVector<int64_t, 4> low_padding(rank, 0);
5435       llvm::SmallVector<int64_t, 4> high_padding(rank, 0);
5436       llvm::SmallVector<int64_t, 4> interior_padding(rank, 0);
5437       low_padding[axis] = 1;
5438       high_padding[axis] = -1;
5439       result = rewriter.create<PadOp>(
5440           op.getLoc(), op.getType(), result, init,
5441           GetI64ElementsAttr(low_padding, &rewriter),
5442           GetI64ElementsAttr(high_padding, &rewriter),
5443           GetI64ElementsAttr(interior_padding, &rewriter));
5444     }
5445 
5446     // Convert back if we enlarged the element type's bitwidth.
5447     result =
5448         rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
5449 
5450     if (op.reverse()) {
5451       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
5452       result = rewriter.create<ReverseOp>(
5453           op.getLoc(), op.getType(), result,
5454           GetI64ElementsAttr(dims_to_reverse, &rewriter));
5455     }
5456 
5457     rewriter.replaceOp(op, result);
5458     return success();
5459   }
5460 };
5461 
5462 using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
5463 using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
5464 
5465 // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
5466 // dialect lowerings. This involves extracting the shape type, extracting and
5467 // converting each dimension to a known integer type, and repacking into a final
5468 // tensor.
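// As an illustrative sketch, a tf.Shape of a tensor<?x4xf32> producing a
// tensor<2xi32> becomes a shape.shape_of, a shape.to_extent_tensor yielding a
// tensor<2xindex>, and an index_cast to the requested i32 result type.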
5469 class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
5470  public:
5471   using OpRewritePattern::OpRewritePattern;
5472 
5473   LogicalResult matchAndRewrite(TF::ShapeOp op,
5474                                 PatternRewriter &rewriter) const override {
5475     Value input = op.input();
5476 
5477     auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
5478     auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
5479     if (!result_ty) {
5480       return failure();
5481     }
5482 
5483     auto index_tensor =
5484         RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
5485     auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
5486         op.getLoc(), index_tensor, shape_op);
5487 
5488     rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
5489     return success();
5490   }
5491 };
5492 
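// Converts a tf.Reshape whose result shape is dynamic to
// mhlo.dynamic_reshape, forwarding the target shape operand; statically
// shaped reshapes are left to other patterns.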
5493 class ConvertDynamicReshapeOp : public OpRewritePattern<TF::ReshapeOp> {
5494  public:
5495   using OpRewritePattern::OpRewritePattern;
5496 
5497   LogicalResult matchAndRewrite(TF::ReshapeOp op,
5498                                 PatternRewriter &rewriter) const override {
5499     auto tensor = op.tensor();
5500     auto shape = op.shape();
5501 
5502     auto tensor_ty = tensor.getType().cast<ShapedType>();
5503     auto shape_ty = shape.getType().cast<ShapedType>();
5504     auto result_ty = op.getType().cast<ShapedType>();
5505 
5506     if (!result_ty.hasRank() || !tensor_ty.hasRank() || !shape_ty.hasRank()) {
5507       return failure();
5508     }
5509 
5510     // Only the dynamic case is handled here; bail out on static result shapes.
5511     if (result_ty.hasStaticShape()) {
5512       return failure();
5513     }
5514 
5515     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, tensor,
5516                                                         shape);
5517     return success();
5518   }
5519 };
5520 
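// Converts a tf.ExpandDims with a dynamically shaped result to
// mhlo.dynamic_reshape: the input extents are read with shape.shape_of and
// shape.get_extent, a constant extent of 1 is inserted at the (constant)
// expansion dimension, and the assembled extent tensor drives the reshape.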
5521 class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
5522  public:
5523   using OpRewritePattern::OpRewritePattern;
5524 
5525   LogicalResult matchAndRewrite(TF::ExpandDimsOp op,
5526                                 PatternRewriter &rewriter) const override {
5527     auto input = op.input();
5528     auto input_ty = input.getType().cast<ShapedType>();
5529     auto result_ty = op.getType().cast<ShapedType>();
5530     if (!result_ty.hasRank() || !input_ty.hasRank() ||
5531         result_ty.hasStaticShape()) {
5532       return failure();
5533     }
5534 
5535     DenseIntElementsAttr expand_dims_attr;
5536     if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) {
5537       return failure();
5538     }
5539 
5540     auto shape = rewriter.create<shape::ShapeOfOp>(
5541         op.getLoc(),
5542         RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()),
5543         input);
5544     auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getIntValues());
5545 
5546     llvm::SmallVector<Value, 4> dims;
5547     dims.resize(result_ty.getRank());
5548 
5549     auto inserted_dim = expand_dims[0].getSExtValue();
5550 
5551     // Handle the negative value use case.
5552     if (inserted_dim < 0) {
5553       inserted_dim += result_ty.getRank();
5554       // If it is still negative, the expansion dim is invalid; bail out.
5555       if (inserted_dim < 0) {
5556         return failure();
5557       }
5558     }
5559 
5560     dims[inserted_dim] = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
5561 
5562     for (int i = 0; i < dims.size() - 1; i++) {
5563       // Add the extracted dim.
5564       auto index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
5565       auto dim = rewriter.create<shape::GetExtentOp>(
5566           op.getLoc(), rewriter.getIndexType(), shape, index);
5567 
5568       dims[i >= inserted_dim ? i + 1 : i] = dim;
5569     }
5570 
5571     auto from_extents = rewriter.create<shape::FromExtentsOp>(
5572         op.getLoc(), shape::ShapeType::get(op.getContext()), dims);
5573 
5574     auto to_extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
5575         op.getLoc(),
5576         RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()),
5577         from_extents);
5578 
5579     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
5580                                                         to_extent_tensor);
5581     return success();
5582   }
5583 };
5584 
5585 // Converts a TF QR op to HLO.
5586 class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
5587  public:
5588   using OpRewritePattern::OpRewritePattern;
5589 
5590   LogicalResult matchAndRewrite(TF::QrOp op,
5591                                 PatternRewriter &rewriter) const override {
5592     // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van
5593     // Loan. def qr_blocked(a, block_size):
5594     //   m = a.shape[0]
5595     //   n = a.shape[1]
5596     //   q = np.eye(m)
5597     //   for i in xrange(0, min(m, n), block_size):
5598     //     k = min(block_size, min(m, n) - i)
5599     //     (a, vs, taus) = qr(a[i:, i:i+k])
5600     //     y = vs
5601     //     w = ComputeWYRepresentation(vs, taus, m-i, k)
5602     //     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
5603     //     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
5604     //   return (q, a)
5605     auto type = op.input().getType().dyn_cast<RankedTensorType>();
5606     if (!type || !type.hasStaticShape()) return failure();
5607     // The block size is chosen to match old bridge lowering.
5608     constexpr int64_t kBlockSize = 128;
5609     Value a = op.input();
5610     int64_t m = type.getDimSize(type.getRank() - 2);
5611     int64_t n = type.getDimSize(type.getRank() - 1);
5612     int64_t p = std::min(m, n);
5613     auto batch_dims = type.getShape().drop_back(2);
5614     auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32));
5615     auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
5616                                          rewriter.getI64IntegerAttr(0));
5617     auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
5618                                          rewriter.getI64IntegerAttr(1));
5619     Value compare = rewriter.create<CompareOp>(
5620         op.getLoc(), iota0, iota1,
5621         StringAttr::get(rewriter.getContext(), "EQ"));
5622     Value identity_matrix =
5623         rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
5624     auto q_shape = llvm::to_vector<4>(type.getShape());
5625     q_shape.back() = m;
5626     Value q = rewriter.create<BroadcastOp>(
5627         op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()),
5628         identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter));
5629     auto precision_config = rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"});
5630     for (int64_t i = 0; i < p; i += kBlockSize) {
5631       int64_t k = std::min(kBlockSize, p - i);
5632       auto a_block =
5633           SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter);
5634       Value r_block;
5635       Value taus;
5636       Value vs;
5637       QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter);
5638       a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter);
5639 
5640       // Compute the I-WY block representation of a product of Householder
5641       // matrices.
5642       Value w =
5643           ComputeWYRepresentation(op.getLoc(), type.getElementType(),
5644                                   batch_dims, vs, taus, m - i, k, &rewriter);
5645       auto y = vs;
5646 
5647       // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
5648       Value a_panel =
5649           SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter);
5650       auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false,
5651                                batch_dims.size(), precision_config, &rewriter);
5652       a_update = BatchDot(op.getLoc(), y, false, a_update, false,
5653                           batch_dims.size(), precision_config, &rewriter);
5654       a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
5655       a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
5656                                  &rewriter);
5657 
5658       // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
5659       Value q_panel =
5660           SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter);
5661       Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false,
5662                                 batch_dims.size(), precision_config, &rewriter);
5663       q_update = BatchDot(op.getLoc(), q_update, false, y, true,
5664                           batch_dims.size(), precision_config, &rewriter);
5665       q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
5666       q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
5667     }
5668     // full_matrices is false when only a partial result is needed. Slice to the
5669     // needed dimensions here.
5670     if (!op.full_matrices()) {
5671       q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter);
5672       a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter);
5673     }
5674     rewriter.replaceOp(op, {q, a});
5675     return success();
5676   }
5677 
5678  private:
5679   // Computes a Householder reflection of the form:
5680   // H = I - tau v v.T.
5681   // such that
5682   // H . ( x1  ) = ( x1   )
5683   //     ( x2  ) = ( x2   )
5684   //     ( ... ) = ( ...  )
5685   //     ( xk  ) = ( beta )
5686   //     ( ... )   ( 0    )
5687   //     ( ... )   ( 0    )
5688   // Unlike the usual formulation, we allow the caller to supply 'k' rather than
5689   // only providing the relevant part of 'x' to maintain XLA's static shape
5690   // invariant. In addition, the implementation supports batching.
5691   // Pseudo-code, without batching:
5692   //   alpha = x[k]
5693   //   x_copy = np.copy(x)
5694   //   x_copy[:k+1] = 0
5695   //   xnorm = norm2(x_copy)
5696   //   if xnorm == 0:
5697   //     beta = alpha
5698   //     tau = 0
5699   //     v = np.zeros_like(x)
5700   //   else:
5701   //     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
5702   //     tau = (beta - alpha) / beta
5703   //     v = x / (alpha - beta)
5704   //   v[k] = 1
5705   //   return (v, tau, beta)
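  // Worked example (k = 0, no batching): for x = [3, 4], alpha = 3 and xnorm = 4,
  // so beta = -5, tau = 8/5 and v = [1, 0.5]; H = I - tau v v.T then maps x to
  // [beta, 0] = [-5, 0].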
5706   void House(Location loc, Value x, Value k, ArrayRef<int64_t> batch_dims,
5707              const int64_t m, OpBuilder *builder, Value *v, Value *tau,
5708              Value *beta) const {
5709     auto x_type = x.getType().cast<RankedTensorType>();
5710 
5711     llvm::SmallVector<int64_t, 4> batch_dim_ids(batch_dims.size());
5712     std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
5713     const int64_t minor_dim = batch_dims.size();
5714 
5715     Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder);
5716     Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder);
5717 
5718     // alpha = x[k]
5719     Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder);
5720     alpha = builder->create<ReshapeOp>(
5721         loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha);
5722 
5723     // Compute x[k+1:] (padded with zeros in elements 0..k)
5724     Value iota = builder->create<IotaOp>(
5725         loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
5726         builder->getI64IntegerAttr(0));
5727     Value gtk = builder->create<chlo::BroadcastCompareOp>(
5728         loc, iota, k, GetI64ElementsAttr({}, builder),
5729         StringAttr::get(builder->getContext(), "GT"));
5730     gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
5731     Value x_after_k = builder->create<chlo::BroadcastMulOp>(
5732         loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
5733     Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
5734     // sigma = np.dot(x[k+1:], x[k+1:])
5735     auto sigma = builder->create<ReduceOp>(
5736         loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
5737     BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
5738     // mu = np.sqrt(x[k]*x[k] + sigma)
5739     Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
5740     Value mu = builder->create<SqrtOp>(
5741         loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
5742 
5743     Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
5744         loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
5745         StringAttr::get(builder->getContext(), "EQ"));
5746     Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
5747         loc, alpha, zero, GetI64ElementsAttr({}, builder),
5748         StringAttr::get(builder->getContext(), "LT"));
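    // beta = -np.sign(alpha) * mu when sigma != 0; when sigma == 0 the select
    // below falls back to beta = alpha, matching the pseudo-code above.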
5749     auto batch_size_one = builder->create<BroadcastOp>(
5750         loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
5751     Value signed_mu = builder->create<chlo::BroadcastMulOp>(
5752         loc,
5753         builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
5754                                   batch_size_one,
5755                                   builder->create<NegOp>(loc, batch_size_one)),
5756         mu, GetI64ElementsAttr({}, builder));
5757     *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
5758                                       alpha, signed_mu);
5759     *tau = builder->create<DivOp>(
5760         loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
5761     Value zero_tau = builder->create<BroadcastOp>(
5762         loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
5763     *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
5764                                      zero_tau, *tau);
5765     Value divisor = builder->create<SubOp>(loc, alpha, *beta);
5766     divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
5767                                         batch_size_one, divisor);
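    // Build the standard basis vector e_k (1 at position k, 0 elsewhere),
    // shaped [1, ..., 1, m] so it broadcasts against the batch dimensions.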
5768 
5769     Value eqk = builder->create<chlo::BroadcastCompareOp>(
5770         loc, iota, k, GetI64ElementsAttr({}, builder),
5771         StringAttr::get(builder->getContext(), "EQ"));
5772     eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
5773     llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
5774     e_k_shape.push_back(m);
5775     auto e_k = builder->create<BroadcastOp>(
5776         loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk,
5777         GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
5778                            builder));
5779 
5780     // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
5781     // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
5782     // Note that the add performs a degenerate broadcast.
5783     *v = builder->create<chlo::BroadcastAddOp>(
5784         loc, e_k,
5785         StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
5786                                      GetI64ElementsAttr(batch_dim_ids, builder),
5787                                      *builder),
5788         /*broadcast_dimensions=*/nullptr);
5789   }
5790 
5791   // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
5792   // Loan "Matrix Computations", 4th Edition. This is an unblocked
5793   // implementation used as an inner routine of the blocked implementation.
5794   // Algorithm is adapted slightly so the shapes inside the loop are static, at
5795   // the cost of some redundant computation. Since this is used as an inner
5796   // block kernel, it accumulates the Householder transformations (vs, taus)
5797   // rather than the matrix q. Equivalent Python code, without batching:
  // def qr(a):
5798   //   m = a.shape[0]
5799   //   n = a.shape[1]
5800   //   vs = np.zeros([m, n])
5801   //   taus = np.zeros([n])
5802   //   for j in xrange(min(m, n)):
5803   //     v, tau, beta = house(a[:, j], j)
5804   //     # Unusually, we apply the Householder transformation to the entirety of
5805   //     # a, wasting FLOPs to maintain the static shape invariant that XLA
5806   //     # requires. For columns that precede j this has no effect.
5807   //     a[:, :] -= tau * np.dot(v[:, np.newaxis],
5808   //                              np.dot(v[np.newaxis, :], a[:, :]))
5809   //     # Form column j explicitly rather than relying on the precision of the
5810   //     # Householder update.
5811   //     a[j, j] = beta
5812   //     a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
5813   //     vs[:, j] = v
5814   //     taus[j] = tau
5815   //   return (a, vs, taus)
5816   void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs,
5817                PatternRewriter *rewriter) const {
5818     auto a_type = a.getType().cast<RankedTensorType>();
5819     const int num_dims = a_type.getRank();
5820     assert(num_dims >= 2 && "Argument to QR must have rank >= 2");
5821 
5822     const int64_t m = a_type.getDimSize(a_type.getRank() - 2);
5823     const int64_t n = a_type.getDimSize(a_type.getRank() - 1);
5824 
5825     const int64_t num_batch_dims = num_dims - 2;
5826     auto batch_dims = a_type.getShape().take_front(num_batch_dims);
5827     llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
5828     std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
5829 
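    // Body of the While loop created by CreateWhile32 below: each iteration
    // applies one Householder reflection and carries (a, vs, taus) as loop state.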
5830     auto qr_body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
5831                           SmallVectorImpl<Value> *new_values,
5832                           OpBuilder *builder) {
5833       auto a = old_values[0];
5834       auto vs = old_values[1];
5835       auto taus = old_values[2];
5836 
5837       // v, beta = house(a[:, j], j)
5838       auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder);
5839       auto x_collapsed_shape = llvm::to_vector<4>(batch_dims);
5840       x_collapsed_shape.push_back(m);
5841       auto x_collapsed = builder->create<ReshapeOp>(
5842           loc,
5843           RankedTensorType::get(x_collapsed_shape,
5844                                 getElementTypeOrSelf(x.getType())),
5845           x);
5846       Value v, tau, beta;
5847       House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta);
5848 
5849       auto shape = llvm::to_vector<4>(batch_dims);
5850       shape.append({1, m});
5851       auto v_broadcast = builder->create<ReshapeOp>(
5852           loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())),
5853           v);
5854       // a[:, :] -= tau * np.dot(v[:, np.newaxis],
5855       //                          np.dot(v[np.newaxis, :], a[:, :]))
5856       auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
5857       auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims,
5858                           precision, builder);
5859       vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
5860                      precision, builder);
5861       auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
5862           loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
5863           *builder);
5864       a = builder->create<SubOp>(loc, a, tau_x_vva);
5865 
5866       // It is more precise to populate column 'j' explicitly, rather than
5867       // computing it implicitly by applying the Householder transformation.
5868       // a[j, j] = beta
5869       // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
5870       auto iota = builder->create<IotaOp>(
5871           loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
5872           builder->getI64IntegerAttr(0));
5873       Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
5874           loc, iota, j, GetI64ElementsAttr({}, builder),
5875           StringAttr::get(builder->getContext(), "LT"));
5876       predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
5877                                                     a_type.getElementType());
5878       Value mask = builder->create<chlo::BroadcastCompareOp>(
5879           loc, iota, j, GetI64ElementsAttr({}, builder),
5880           StringAttr::get(builder->getContext(), "EQ"));
5881       mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
5882       llvm::SmallVector<int64_t, 4> broadcast_mask_shape(a_type.getRank(), 1);
5883       broadcast_mask_shape[a_type.getRank() - 2] = m;
5884       mask = builder->create<BroadcastOp>(
5885           loc,
5886           RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()),
5887           mask,
5888           GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
5889                              builder));
5890       Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
5891           loc, x, predecessor_mask,
5892           GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
5893       Value masked_beta = StaticBinaryBroadcast<MulOp>(
5894           loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
5895           *builder);
5896       Value new_x =
5897           builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
5898       // Update a[:,j]
5899       llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
5900       std::iota(dim_ids.begin(), dim_ids.end(), 0);
5901       new_x = builder->create<BroadcastInDimOp>(
5902           loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder));
5903       const int64_t minor_dim = num_batch_dims;
5904       auto iota_mn = builder->create<IotaOp>(
5905           loc,
5906           RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
5907           builder->getI64IntegerAttr(minor_dim + 1));
5908       Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
5909           loc, iota_mn, j, GetI64ElementsAttr({}, builder),
5910           StringAttr::get(builder->getContext(), "EQ"));
5911       a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
5912 
5913       // vs[:, j] = v
5914       llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
5915       std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
5916       Value vs_zeros =
5917           GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
5918       vs_zeros = builder->create<BroadcastOp>(
5919           loc, vs.getType(), vs_zeros,
5920           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
5921                              builder));
5922       auto vs_update = builder->create<SelectOp>(
5923           loc, vs.getType(), xa_mask,
5924           StaticBinaryBroadcast<AddOp>(
5925               loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
5926               *builder),
5927           vs_zeros);
5928       vs = builder->create<AddOp>(loc, vs, vs_update);
5929 
5930       // taus[j] = tau
5931       llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
5932       std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
5933 
5934       auto iota_shape = llvm::to_vector<4>(batch_dims);
5935       iota_shape.push_back(n);
5936       auto iota_n = builder->create<IotaOp>(
5937           loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
5938           builder->getI64IntegerAttr(minor_dim));
5939       Value taus_zeros =
5940           GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
5941       taus_zeros = builder->create<BroadcastOp>(
5942           loc, taus.getType(), taus_zeros,
5943           GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
5944                              builder));
5945       Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
5946           loc, iota_n, j, GetI64ElementsAttr({}, builder),
5947           StringAttr::get(builder->getContext(), "EQ"));
5948       auto taus_update = builder->create<SelectOp>(
5949           loc, taus.getType(), taus_mask,
5950           StaticBinaryBroadcast<AddOp>(
5951               loc, taus_zeros, tau,
5952               GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
5953           taus_zeros);
5954       taus = builder->create<AddOp>(loc, taus, taus_update);
5955       new_values->assign({a, vs, taus});
5956     };
5957 
5958     Value zero =
5959         GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
5960     *vs = rewriter->create<BroadcastOp>(
5961         loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
5962     auto taus_shape = llvm::to_vector<4>(batch_dims);
5963     taus_shape.push_back(n);
5964     *taus = rewriter->create<BroadcastOp>(
5965         loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero,
5966         GetI64ElementsAttr(taus_shape, rewriter));
5967 
5968     SmallVector<Value, 4> while_output;
5969     CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
5970                   &while_output, rewriter);
5971     *r = while_output[0];
5972     *vs = while_output[1];
5973     *taus = while_output[2];
5974   }
5975 
5976   // Computes W and Y such that I-WY is equivalent to the sequence of
5977   // Householder transformations given by vs and taus.
5979   // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
5980   // Y = np.zeros([m, n])
5981   // W = np.zeros([m, n])
5982   // Y[:, 0] = vs[:, 0]
5983   // W[:, 0] = -taus[0] * vs[:, 0]
5984   // for j in xrange(1, n):
5985   //   v = vs[:, j]
5986   //   z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
5987   //   W[:, j] = z
5988   //   Y[:, j] = v
5989   // return W
5990   // There is no need to return Y since at termination of the loop it is equal
5991   // to vs.
5992   Value ComputeWYRepresentation(Location loc, Type data_type,
5993                                 ArrayRef<int64_t> batch_dims, Value vs,
5994                                 Value taus, int64_t m, int64_t n,
5995                                 PatternRewriter *rewriter) const {
5996     int64_t n_index = batch_dims.size() + 1;
5997     llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
5998     std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
5999 
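    // Loop body: iteration j (offset by 1 below so j starts at 1) computes
    // column j of W from vs[:, j] and taus[j], carrying (w, vs, taus) as state.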
6000     auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6001                        SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
6002       // w has shape [..., m, n]
6003       auto w = old_values[0];
6004       const auto vs = old_values[1];
6005       const auto taus = old_values[2];
6006 
6007       // Want j values in range [1, ... n).
6008       j = builder->create<AddOp>(
6009           loc, j,
6010           GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
6011                                builder));
6012       // vs has shape [..., m, 1]
6013       auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
6014       // beta has shape [..., 1]
6015       auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder);
6016 
6017       auto iota_shape = llvm::to_vector<4>(batch_dims);
6018       iota_shape.append({m, n});
6019       auto iota_mn = builder->create<IotaOp>(
6020           loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
6021           builder->getI64IntegerAttr(n_index));
6022 
6023       // y has shape [..., m, n]
6024       Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
6025                                         0, builder);
6026       zero = builder->create<BroadcastOp>(
6027           loc, vs.getType(), zero,
6028           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
6029                              builder));
6030       auto compare = builder->create<chlo::BroadcastCompareOp>(
6031           loc, iota_mn, j, GetI64ElementsAttr({}, builder),
6032           StringAttr::get(builder->getContext(), "GE"));
6033       auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
6034 
6035       // yv has shape [..., n, 1]
6036       auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
6037       auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision,
6038                          builder);
6039       // wyv has shape [..., m, 1]
6040       auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(),
6041                           precision, builder);
6042 
6043       // z = -beta * (v + wyv)
6044       auto neg_beta = builder->create<NegOp>(loc, beta);
6045       auto v_wyv = builder->create<AddOp>(loc, v, wyv);
6046       auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6047       beta_broadcast_dims.push_back(n_index);
6048       auto z = StaticBinaryBroadcast<MulOp>(
6049           loc, neg_beta, v_wyv,
6050           GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
6051 
6052       w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
6053       new_values->assign({w, vs, taus});
6054     };
6055 
6056     Value w =
6057         GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter);
6058     auto w_shape = llvm::to_vector<4>(batch_dims);
6059     w_shape.append({m, n});
6060     w = rewriter->create<BroadcastOp>(loc,
6061                                       RankedTensorType::get(w_shape, data_type),
6062                                       w, GetI64ElementsAttr(w_shape, rewriter));
6063     auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
6064     auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
6065     auto neg_beta = rewriter->create<NegOp>(loc, beta);
6066     auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6067     beta_broadcast_dims.push_back(n_index);
6068     auto bv = StaticBinaryBroadcast<MulOp>(
6069         loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
6070         *rewriter);
6071     w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
6072 
6073     SmallVector<Value, 4> while_output;
6074     CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter);
6075     return while_output[0];
6076   }
6077 };
6078 
6079 // Emits debug information that includes the number of ops of each type that
6080 // failed to legalize.
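// For example, if two hypothetical `tf.Foo` ops and one `tf.Bar` op failed to
// legalize, the summary diagnostic would read roughly:
//   The following operations cannot be legalized: tf.Bar (count: 1); tf.Foo
//   (count: 2). These legalization failure(s) may be due to ...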
6081 void EmitLegalizationErrors(Operation *op,
6082                             const DenseSet<Operation *> &nonlegalized_ops) {
6083   // Track the legalization failures by mapping op name to information about
6084   // that failure: the number of unlegalized occurrences of the op, and one
6085   // example operation that failed.
6086   std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
6087   DenseSet<Operation *> error_ops;
6088   for (Operation *nonlegalized_op : nonlegalized_ops) {
6089     // Increment count of this legalization failure.
6090     StringRef op_name = nonlegalized_op->getName().getStringRef();
6091     // If this emplace is successful, it's the first time we've encountered
6092     // this op type. Initialize count to 0 so that after increment, it is 1.
6093     auto insertion_result = op_name_to_error_info.emplace(
6094         op_name, std::make_pair(0, nonlegalized_op));
6095     ++insertion_result.first->second.first;
6096   }
6097   std::vector<std::string> error_messages;
6098   error_messages.reserve(op_name_to_error_info.size());
6099   for (const auto &op_info : op_name_to_error_info) {
6100     error_messages.push_back(
6101         llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
6102   }
6103   Location loc = op->getLoc();
6104   emitError(loc) << "The following operations cannot be legalized: "
6105                  << llvm::join(error_messages, "; ")
6106                  << ". These legalization failure(s) may be due to missing TF "
6107                     "to HLO lowerings and/or unsupported attributes, etc.";
6108   // Emit more information about the missing ops. This error message
6109   // contains useful details beyond the op name (input and output shapes,
6110   // attributes, etc.).
6111   if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
6112     emitError(loc)
6113         << "Emitting more detail about one op that failed to legalize...";
6114   } else if (VLOG_IS_ON(1)) {
6115     emitError(loc) << "Emitting more detail about one of each type of op "
6116                       "that failed to legalize...";
6117   }
6118   for (const auto &op_info : op_name_to_error_info) {
6119     op_info.second.second->emitOpError() << "is not legalizable";
6120     if (!VLOG_IS_ON(1)) break;
6121   }
6122 }
6123 
6124 // Performs the lowering to XLA dialect.
6125 void LegalizeTF::runOnFunction() {
6126   llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
6127   if (use_tf2xla_fallback_) {
6128     tf2xla_fallback_device_type = device_type_;
6129   }
6130   if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
6131                         legalize_chlo_, tf2xla_fallback_device_type))) {
6132     signalPassFailure();
6133   }
6134 }
6135 
6136 static PassRegistration<LegalizeTF> pass(
6137     "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");
6138 
6139 }  // end namespace
6140 
6141 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
6142 
6143 LogicalResult legalizeTF(
6144     Operation *op, bool allow_partial_conversion, bool legalize_chlo,
6145     llvm::Optional<StringRef> tf2xla_fallback_device_type) {
6146   MLIRContext *context = op->getContext();
6147   OwningRewritePatternList patterns;
6148   // Note that the `OperationConverter` orders patterns lexicographically by:
6149   // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
6150   //    to arrive at the conversion target). This requires each pattern to
6151   //    specify the list of ops it may generate, which most patterns
6152   //    implemented in C++ do not, so this comparison does not work in those
6153   //    cases.
6154   // 2) Descending pattern benefit.
6155   // 3) Op specific patterns over patterns with MatchAnyOpTypeTag.
6156   // 4) Order of patterns in `OwningRewritePatternList`.
6157 
6158   // Add TF->HLO legalization patterns.
6159   PopulateLegalizeTfPatterns(context, &patterns);
6160 
6161   // Add TF->TF lowering patterns.
6162   TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
6163 
6164   // Add TF->HLO legalization patterns via TF2XLA fallback.
6165   if (tf2xla_fallback_device_type.hasValue()) {
6166     PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
6167                                          patterns);
6168   }
6169 
6170   // Populate with CHLO->HLO lowerings to account for TF ops legalized to
6171   // CHLO first.
6172   if (legalize_chlo) {
6173     chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
6174   }
6175   // The ConstantLike op is convenient for creating splat constants, but it is
6176   // canonicalized to a plain HLO constant if statically shaped. Add the
6177   // canonicalization pattern to the pattern list to enable multi-hop lowering.
6178   chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
6179 
6180   ConversionTarget target(*context);
6181   if (legalize_chlo) {
6182     target.addIllegalDialect<chlo::HloClientDialect>();
6183   } else {
6184     target.addLegalDialect<chlo::HloClientDialect>();
6185   }
6186   target.addLegalDialect<MhloDialect>();
6187   target.addLegalDialect<StandardOpsDialect>();
6188   target.addLegalDialect<tensor::TensorDialect>();
6189   target.addLegalDialect<shape::ShapeDialect>();
6190   target.addLegalOp<CallOp>();
6191 
6192   if (!allow_partial_conversion) {
6193     // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
6194     target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
6195     DenseSet<Operation *> nonlegalized_ops;
6196     LogicalResult result = applyPartialConversion(
6197         op, target, std::move(patterns), &nonlegalized_ops);
6198     // In order to enforce that the conversion result is fully converted,
6199     // fail if there are any nonlegalized ops in the set.
6200     if (failed(result) || !nonlegalized_ops.empty()) {
6201       EmitLegalizationErrors(op, nonlegalized_ops);
6202       return failure();
6203     }
6204     return result;
6205   }
6206 
6207   return applyPartialConversion(op, target, std::move(patterns));
6208 }
6209 
6210 void PopulateLegalizeTfPatterns(MLIRContext *context,
6211                                 OwningRewritePatternList *patterns) {
6212   populateWithGenerated(context, *patterns);
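  // `populateWithGenerated` registers the declarative (TableGen) rewrite
  // patterns from generated_legalize_tf.inc included above; the patterns that
  // need imperative C++ logic are added explicitly below.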
6213   patterns->insert<
6214       ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
6215       ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
6216       ConvertClipByValueOp, ConvertConv2DOp, ConvertConv3DOp,
6217       ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp,
6218       ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp,
6219       ConvertConv3DBackpropInputOp, ConvertCumprodOp, ConvertCumsumOp,
6220       ConvertDiagPartOp, ConvertDynamicExpandDimsOp, ConvertDynamicReshapeOp,
6221       ConvertEinsumOp, ConvertRFFTOp, ConvertIRFFTOp,
6222       ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
6223       ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
6224       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
6225       ConvertIdentityNOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
6226       ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp,
6227       ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp,
6228       ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
6229       ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
6230       ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
6231       ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
6232       ConvertSigmoidOp, ConvertShapeOp,
6233       ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
6234       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
6235       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
6236       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
6237       ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
6238       ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
6239       ConvertRandomShuffleOp, ConvertXlaShardingOp,
6240       ConvertXlaDynamicUpdateSliceOp, ConvertXlaAllReduceOp>(context);
6241 }
6242 
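// Illustrative use only (the actual pipeline setup lives elsewhere): the pass
// is typically added to a function-level pass manager, e.g.
//   mlir::PassManager pm(context);
//   pm.addNestedPass<FuncOp>(createLegalizeTFPass(
//       /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
//       /*tf2xla_fallback_device_type=*/llvm::None));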
6243 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
6244     bool allow_partial_conversion, bool legalize_chlo,
6245     llvm::Optional<StringRef> tf2xla_fallback_device_type) {
6246   return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
6247                                       tf2xla_fallback_device_type);
6248 }
6249 
6250 }  // end namespace mhlo
6251 }  // end namespace mlir
6252