/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect to XLA dialect.

#include <cctype>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace mhlo {
namespace {

constexpr char kShardingAttr[] = "mhlo.sharding";

class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
                    shape::ShapeDialect, StandardOpsDialect>();
  }

 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF &) {}
  explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
                      llvm::Optional<StringRef> tf2xla_fallback_device_type) {
    allow_partial_conversion_ = allow_partial_conversion;
    legalize_chlo_ = legalize_chlo;
    use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
    if (tf2xla_fallback_device_type.hasValue()) {
      device_type_ = tf2xla_fallback_device_type.getValue().str();
    }
  }

  /// Performs the lowering to XLA dialect.
  void runOnFunction() override;

 private:
  Option<bool> allow_partial_conversion_{
      *this, "allow-partial-conversion",
      llvm::cl::desc("Allow operations that can't be legalized."),
      llvm::cl::init(false)};
  Option<bool> legalize_chlo_{
      *this, "legalize-chlo",
      llvm::cl::desc(
          "Also legalizes intermediate chlo ops to hlo (default true)"),
      llvm::cl::init(true)};
  Option<bool> use_tf2xla_fallback_{
      *this, "use-tf2xla-fallback",
      llvm::cl::desc(
          "Also use TF2XLA fallback for legalization (default false)"),
      llvm::cl::init(false)};
  Option<std::string> device_type_{
      *this, "device-type",
      llvm::cl::desc(
          "The device type used by TF2XLA fallback. Must be specified if "
          "use-tf2xla-fallback is true, otherwise not used."),
      llvm::cl::init("INVALID_DEVICE_TYPE")};
};

/// Returns if the given TF data format string is the default format.
static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; }

/// Returns the feature dimension for the given format and input type.
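/// For example, with format "NHWC" and a rank-4 input the feature dimension
/// is 3, while with "NCHW" it is 1.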
static size_t GetFeatureDimension(StringRef format,
                                  RankedTensorType inputType) {
  return IsDefaultDataFormat(format) ? inputType.getRank() - 1 : 1;
}

// Gets all integer values from the given attribute and pushes them to
// `values`.
void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
  auto array_attr = attr.cast<ArrayAttr>();
  values->reserve(array_attr.getValue().size());
  for (Attribute val : array_attr.getValue())
    values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}

// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Converts an ArrayAttr to a 1D 64-bit dense elements attribute.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
  RankedTensorType ty =
      RankedTensorType::get(static_cast<int64_t>(attr.size()),
                            IntegerType::get(attr.getContext(), 64));
  return DenseIntElementsAttr::get(ty, attr.getValue());
}

// Returns 1D 32-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns a 1-d i64 elements attribute populated with numbers from start
// (inclusive) to end (exclusive).
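// For example, GetI64ElementsAttrForSeq(1, 4, builder) yields
// dense<[1, 2, 3]> : tensor<3xi64>.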
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Returns a 1-d i64 elements attribute populated with `val` repeated `size`
// times.
static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val,
                                                       Builder *builder) {
  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, val);
}

// Returns the corresponding type that should be used for performing sum
// accumulation over the given input type.
Type GetSumAccumulationType(Type input_type) {
  MLIRContext *ctx = input_type.getContext();
  if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
  if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
    return IntegerType::get(ctx, 32);
  return input_type;
}

// Returns the axis in HLO format given a TF attribute that is either an
// elements attr with exactly one element or an IntegerAttr, containing the
// axis in the TensorFlow format. Unlike HLO, the TensorFlow format supports
// negative indexing.
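// For example, a TF axis of -1 on a rank-4 tensor maps to HLO axis 3.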
static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank,
                                        Builder *b) {
  IntegerAttr intAttr = attr.dyn_cast_or_null<IntegerAttr>();
  if (auto elementAttr = attr.dyn_cast_or_null<ElementsAttr>()) {
    SmallVector<uint64_t, 1> index(elementAttr.getType().getRank(), 0);
    intAttr = elementAttr.getValue<IntegerAttr>(index);
  }

  assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis");

  int64_t axis = intAttr.getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}

// If `value` is an IntegerAttr, returns the integer value for the HLO axis
// corresponding to the tensorflow axis. In particular, the tensorflow axis can
// be negative, in which case, the corresponding HLO axis is
// (axis + rank-of-the-tensor).
static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
                                                           int64_t rank) {
  DenseIntElementsAttr attrs;
  if (!matchPattern(value, m_Constant(&attrs)) ||
      attrs.getType().getRank() != 0) {
    return llvm::None;
  }
  int64_t axis = attrs.getValue<IntegerAttr>({}).getInt();
  return axis < 0 ? axis + rank : axis;
}

/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining
/// the shape of the input value.
static ConvertOp CastValueToI64(Location loc, Value value,
                                PatternRewriter *rewriter) {
  return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}

// Creates an unpack op along the 0th dimension of the tensor. The `value` input
// must be a ranked tensor.
static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
                                             PatternRewriter *rewriter) {
  auto indices_type = value.getType().cast<RankedTensorType>();
  int num_outputs = indices_type.getShape().front();
  SmallVector<Type, 2> unpacked_indices_type(
      num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
  auto unpacked_indices = rewriter->create<TF::UnpackOp>(
      loc, unpacked_indices_type, value,
      IntegerAttr::get(rewriter->getIntegerType(64), 0));
  return unpacked_indices;
}

// Returns size of dimension at the specified index, if ranked tensor.
// Otherwise, returns -1.
//
// Aborts if the type is ranked but doesn't have the dimension.
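// For example, GetDimSize(tensor<4x8xf32>, 1) returns 8, while an unranked
// tensor type returns -1.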
int64_t GetDimSize(Type ty, int64_t index) {
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return -1;

  return ranked_ty.getDimSize(index);
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
  return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
      sizes.begin(), sizes.end()));
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(
    llvm::iterator_range<DenseElementsAttr::ElementIterator<T>> sizes) {
  return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
      sizes.begin(), sizes.end()));
}

// Returns int, float, or complex scalar DenseElementsAttr attribute with the
// given element type and the value.
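// For example, GetScalarConstOfType(builder.getF32Type(), loc, 1, &builder)
// materializes a constant holding dense<1.0> : tensor<f32>.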
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
                                    OpBuilder *builder) {
  return builder->create<ConstOp>(loc, hlo::GetScalarOfType(ty, raw_value));
}

// Returns a limit scalar const op for the given type.
// Requires FloatType or IntegerType
static ConstOp GetScalarLimitConstOfType(Type ty, Location loc,
                                         hlo::ScalarLimit limit,
                                         OpBuilder *builder) {
  return builder->create<ConstOp>(loc, hlo::GetScalarLimitOfType(ty, limit));
}

// Creates an mhlo::SliceOp where the major dimensions have full size, and
// the minor dimensions have the provided offsets and sizes.
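// For example, slicing a 4x8x16 input with minor_starts = {0, 4} and
// minor_limits = {8, 12} produces a slice of shape 4x8x8.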
static Value SliceInMinorDims(Location loc, Value v,
                              ArrayRef<int64_t> minor_starts,
                              ArrayRef<int64_t> minor_limits,
                              OpBuilder *builder) {
  auto type = v.getType().cast<RankedTensorType>();
  llvm::SmallVector<int64_t, 4> slice_starts(type.getRank(), 0);
  int64_t major_dims = type.getRank() - minor_starts.size();
  std::copy(minor_starts.begin(), minor_starts.end(),
            slice_starts.begin() + major_dims);
  auto slice_limits = llvm::to_vector<4>(type.getShape());
  std::copy(minor_limits.begin(), minor_limits.end(),
            slice_limits.begin() + major_dims);
  llvm::SmallVector<int64_t, 4> slice_strides(type.getRank(), 1);
  return builder->create<SliceOp>(loc, v,
                                  GetI64ElementsAttr(slice_starts, builder),
                                  GetI64ElementsAttr(slice_limits, builder),
                                  GetI64ElementsAttr(slice_strides, builder));
}

// Creates a vector of index values:
// [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]]
// with length `rank`.
static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
    Location loc, ArrayRef<Value> minor_indices, int64_t rank,
    OpBuilder *builder) {
  auto zero =
      GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()),
                           loc, 0, builder)
          .output();
  llvm::SmallVector<Value, 4> indices(rank, zero);
  std::copy(minor_indices.begin(), minor_indices.end(),
            indices.begin() + (rank - minor_indices.size()));
  return indices;
}

// Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
// and the minor dimensions have the provided offsets and sizes.
static Value DynamicSliceInMinorDims(Location loc, Value v,
                                     ArrayRef<Value> minor_starts,
                                     ArrayRef<int64_t> minor_sizes,
                                     OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto slice_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  int64_t major_dims = type.getRank() - minor_starts.size();
  auto slice_sizes = llvm::to_vector<4>(type.getShape());
  std::copy(minor_sizes.begin(), minor_sizes.end(),
            slice_sizes.begin() + major_dims);
  auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
  return builder->create<mhlo::DynamicSliceOp>(
      loc, slice_type, v, slice_starts,
      GetI64ElementsAttr(slice_sizes, builder));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided offsets.
static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
                                           ArrayRef<Value> minor_starts,
                                           OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto dus_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  return builder->create<DynamicUpdateSliceOp>(loc, type, v, update,
                                               llvm::makeArrayRef(dus_starts));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided static offsets.
static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
                                    ArrayRef<int64_t> minor_starts,
                                    OpBuilder *builder) {
  llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
  for (uint64_t i = 0; i < minor_starts.size(); ++i) {
    dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
                                         minor_starts[i], builder);
  }
  return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Gets the resulting type from a broadcast between two types for statically
// shaped types. This is to be used for legacy lowerings that both use non
// left-padded broadcasting and static shapes. Its use should not be permitted
// in new code.
// May return nullptr on invalid static broadcast dimensions.
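// For example, broadcasting 4x1x16 with 8x16 along broadcast dims {1, 2}
// yields the static result type 4x8x16.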
// ABSL_DEPRECATED()
static RankedTensorType GetStaticBroadcastType(
    RankedTensorType x, RankedTensorType y,
    DenseIntElementsAttr broadcast_dimensions_attr) {
  auto element_type = x.getElementType();
  auto shape_x = x.getShape();
  auto shape_y = y.getShape();

  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0; i < shape_x.size(); i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      out_shape[i] = std::max(x_val, y_val);
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  // Explicit broadcast dimensions.
  for (const APInt &int_value : broadcast_dimensions_attr) {
    broadcast_dimensions.push_back(int_value.getSExtValue());
  }
  if (broadcast_dimensions.size() != shape_small.size()) {
    return nullptr;
  }
  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    out_shape[index_pair.value()] = std::max(old_value, new_value);
  }
  return RankedTensorType::get(out_shape, element_type);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Applies static binary broadcasting to a binary elementwise op.
// This is a legacy helper to provide general broadcasting support in legacy,
// static shaped code that relies on non-left-padded broadcasting semantics.
template <typename BinaryOp>
static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
                                   DenseIntElementsAttr broadcast_dims,
                                   OpBuilder &builder) {
  auto x_type = x.getType().cast<RankedTensorType>();
  auto y_type = y.getType().cast<RankedTensorType>();
  auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
  if (!result_type) {
    emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
                   << " with broadcast_dims = " << broadcast_dims;
    return nullptr;
  }
  auto larger_broadcast_dims =
      GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
  if (x_type.getRank() < y_type.getRank()) {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y,
                                           larger_broadcast_dims);
    }
  } else {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x,
                                           larger_broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
    }
  }
  return builder.create<BinaryOp>(loc, x, y);
}

// Gets a 1D tensor type suitable for expressing extents of the given tensor
// value type. If the value type is ranked, the result will be statically
// shaped. Otherwise, it will have a dynamic dimension.
static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
  Builder b(value_type.getContext());
  int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
  return RankedTensorType::get({dim}, b.getIndexType());
}

// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value'
// by assuming that the shape of the lower ranked is a broadcast compatible
// prefix of the higher ranked.
// Values must be RankedTensorType (this restriction derives from the
// broadcast_dimensions attribute on DynamicBroadcastInDim).
//
// Example:
// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4, 3>) will broadcast the
// lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style
// implicit broadcasting).
static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value,
                                   Value lower_rank_value, OpBuilder &builder) {
  Value higher_rank_shape =
      builder.create<shape::ShapeOfOp>(loc, higher_rank_value);
  auto result_extents_type =
      GetExtentsTensorTypeFor(higher_rank_value.getType().cast<TensorType>());
  Value result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, higher_rank_shape);

  auto lower_rank_type = lower_rank_value.getType().cast<RankedTensorType>();
  auto lower_rank = lower_rank_type.getRank();
  auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, higher_rank_value.getType(), lower_rank_value, result_extents,
      prefix_dims);
}

// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
// value (broadcast_from) along that feature dimension. This is a shortcut
// for the cases where a 1D tensor must be broadcast along a specific feature
// dimension, which can vary based on data layout, etc.
//
// The extent of `broadcast_from` dim0 must be equal to the extent of the
// feature_dim of `broadcast_to`.
//
// Example:
// [1x2x3x4], [2], 1 -> [1x2x3x4]
// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
// consistency. Possibly also rename for clarity.
static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
                                     Value broadcast_from, int64_t feature_dim,
                                     OpBuilder &builder) {
  auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
  auto to_type = broadcast_to.getType().cast<RankedTensorType>();
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, broadcast_from, result_extents, broadcast_dims);
}

// Broadcasts `input` to the shape of `broadcast_to` value following
// TF::BroadcastTo semantics.
//
// Requires that input is a ranked tensor.
//
// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp
// supports unranked inputs in the lowering.
static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
                                OpBuilder &builder) {
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto to_type = broadcast_to.getType().cast<TensorType>();
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  int64_t rank = input.getType().cast<RankedTensorType>().getRank();
  auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, input, result_extents, broadcast_dims);
}

// Creates a batch dot using mhlo::DotGeneralOp.
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
               bool transpose_rhs, int64_t num_batch_dims,
               ArrayAttr precision_config, OpBuilder *builder) {
  auto batch_dimensions = GetI64ElementsAttr(
      llvm::to_vector<4>(llvm::seq<int64_t>(0, num_batch_dims)), builder);
  auto lhs_contracting_dimensions = GetI64ElementsAttr(
      llvm::makeArrayRef({transpose_lhs ? num_batch_dims : num_batch_dims + 1}),
      builder);
  auto rhs_contracting_dimensions = GetI64ElementsAttr(
      llvm::makeArrayRef({transpose_rhs ? num_batch_dims + 1 : num_batch_dims}),
      builder);
  auto dimension_numbers = DotDimensionNumbers::get(
      /*lhs_batching_dimensions=*/batch_dimensions,
      /*rhs_batching_dimensions=*/batch_dimensions,
      /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
      /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
      builder->getContext());
  auto lhs_shape = lhs.getType().cast<RankedTensorType>().getShape();
  auto rhs_shape = rhs.getType().cast<RankedTensorType>().getShape();
  auto shape = llvm::to_vector<4>(lhs_shape);
  shape[shape.size() - 2] =
      transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2];
  shape[shape.size() - 1] =
      transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back();
  Type element_type = getElementTypeOrSelf(lhs.getType());
  return builder->create<DotGeneralOp>(
      loc, RankedTensorType::get(shape, element_type), lhs, rhs,
      dimension_numbers, precision_config);
}

// Builds body for reduce op by using the template binary op as the
// reducer op.
template <typename Op>
static void BuildReduceBody(Type element_type, Region *body,
                            OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  Block *block = builder->createBlock(body);

  // Block arguments are scalars of the given element type.
  Type type = RankedTensorType::get(/*shape=*/{}, element_type);
  block->addArguments({type, type});

  Location loc = body->getLoc();
  auto reducer =
      builder->create<Op>(loc, block->getArgument(0), block->getArgument(1));
  builder->create<ReturnOp>(loc, reducer.getResult());
}

// Builds region taking two arguments and returning second argument as the
// result. Corresponds to the function f(x, y) = y.
// Used in Scatter op's computation to update specific elements.
static void BuildBinaryAssignmentRegion(Type element_type, Region *region,
                                        OpBuilder *builder) {}

// Builds a set of operations for applying reduction on the input value. A
// tf.sum op is created and will be legalized to HLO ops automatically.
static Value ApplyReduction(Location loc, Value input,
                            DenseIntElementsAttr reduce_dims,
                            OpBuilder *builder) {
  auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
                                    builder->getBoolAttr(false));
}

// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements`
// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
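// For example, CreateRngUniform32(loc, 16, 0, 100, builder) yields a
// tensor<16xi32> of values drawn from [0, 100).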
static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
                                             int lower_limit, int upper_limit,
                                             OpBuilder *builder) {
  auto i32_type = builder->getIntegerType(32);
  auto key_type = RankedTensorType::get({num_elements}, i32_type);
  auto shape_tensor = builder->create<mhlo::ConstOp>(
      loc, GetI64ElementsAttr({num_elements}, builder));

  auto lower = builder->create<mhlo::ConstOp>(
      loc, builder->getI32IntegerAttr(lower_limit));
  auto upper = builder->create<mhlo::ConstOp>(
      loc, builder->getI32IntegerAttr(upper_limit));

  return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
                                             shape_tensor);
}

using WhileBodyFnType = llvm::function_ref<void(
    Location loc, Value iteration, ArrayRef<Value> old_values,
    SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;

// Creates an mhlo.while op with `builder` to loop `num_iterations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
//
// This effectively does:
//
// ```c++
// SmallVector<Values, 4> old_values = init_values;
// SmallVector<Values, 4> new_values;
// for (int i = 0; i < num_iterations; ++i) {
//   body_fn(old_values, &new_values, ...);
//   old_values = new_values;
// }
// ```
//
// Under the hood an induction variable is prepended to values to control the
// number of iterations, but that is transparent to `body_fn`, which does not
// need to care about that.
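// A minimal usage sketch (names are illustrative, not taken from this file):
// double a value ten times by adding it to itself each iteration:
//
//   SmallVector<Value, 4> results;
//   CreateWhile32(loc, /*num_iterations=*/10,
//                 [&](Location loc, Value iv, ArrayRef<Value> old_values,
//                     SmallVectorImpl<Value> *new_values, OpBuilder *b) {
//                   new_values->push_back(b->create<mhlo::AddOp>(
//                       loc, old_values[0], old_values[0]));
//                 },
//                 {init_value}, &results, &builder);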
static void CreateWhile32(Location loc, int num_iterations,
                          WhileBodyFnType body_fn, ArrayRef<Value> init_values,
                          SmallVectorImpl<Value> *final_values,
                          OpBuilder *builder) {
  int value_count = init_values.size() + 1;

  // Prepend a loop induction variable to the initial values.
  SmallVector<Value, 2> init_values_with_loop_iv;
  init_values_with_loop_iv.reserve(value_count);
  // The initial value for the loop induction variable is 0.
  init_values_with_loop_iv.push_back(
      builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
  init_values_with_loop_iv.append(init_values.begin(), init_values.end());

  // Prepare the initial tuple for the while op.
  auto init_tuple =
      builder->create<mhlo::TupleOp>(loc, init_values_with_loop_iv);
  auto tuple_type = init_tuple.getType();

  // Create the while op.
  auto while_op = builder->create<mhlo::WhileOp>(loc, init_tuple);

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the condition region. It should take one
    // argument of the loop's tuple type.
    Region &condition = while_op.cond();
    Block *block = builder->createBlock(&condition);
    BlockArgument arg = block->addArgument(tuple_type);

    // Get the loop induction variable and compare it against the upper limit.
    auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
    auto upper_limit = builder->create<mhlo::ConstOp>(
        loc, builder->getI32IntegerAttr(num_iterations));
    StringAttr compare_direction = StringAttr::get(builder->getContext(), "LT");
    Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
                                                     compare_direction);

    builder->create<mhlo::ReturnOp>(loc, compare);
  }

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the body region. It should take one
    // argument of the loop's tuple type.
    Region &body = while_op.body();
    Block *block = builder->createBlock(&body);
    BlockArgument arg = block->addArgument(tuple_type);

    SmallVector<Value, 4> old_values;  // From the previous iteration
    SmallVector<Value, 4> new_values;  // Generated by this iteration
    old_values.reserve(value_count);
    new_values.reserve(value_count);

    // Unpack the tuple value from the last iteration.
    for (int i = 0; i < value_count; ++i)
      old_values.push_back(builder->create<GetTupleElementOp>(loc, arg, i));

    // Feed all values excluding the loop induction variable to body_fn.
    body_fn(loc, old_values[0], llvm::makeArrayRef(old_values).drop_front(),
            &new_values, builder);

    // Increment the loop induction variable by one.
    auto one =
        builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
    auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
    auto plus_one = builder->create<chlo::BroadcastAddOp>(
        loc, old_values[0], one, scalar_broadcast_dims);
    // Prepend with the updated loop induction variable.
    new_values.insert(new_values.begin(), plus_one);

    Value updated_tuple = builder->create<mhlo::TupleOp>(loc, new_values);

    builder->create<mhlo::ReturnOp>(loc, updated_tuple);
  }

  final_values->reserve(init_values.size());
  for (int i = 0, e = init_values.size(); i < e; ++i)
    final_values->push_back(
        builder->create<GetTupleElementOp>(loc, while_op, i + 1));
}

//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//

static IntegerAttr getFeatureDimensionAttr(Builder &b, StringRef format,
                                           Value input) {
  return b.getI64IntegerAttr(
      GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// FFT op utilities.
//===----------------------------------------------------------------------===//
// Returns the 1D i64 elements attribute populated with the inner-most dim of
// the value.
static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
                                                 Builder *builder) {
  if (type.getRank() == 0) {
    return builder->getI64TensorAttr({});
  }
  return builder->getI64TensorAttr(type.getShape().back());
}

// Returns True if the inner-most dim is static.
bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
  if (!type.hasRank()) {
    return false;
  }
  return !type.isDynamicDim(type.getShape().size() - 1);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//

// If the 'transpose' attribute is true returns ElementsAttr to transpose 2D
// matrix. Otherwise, returns ElementsAttr for identity transpose.
static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
  if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
  return GetI64ElementsAttr({0, 1}, b);
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op utilities.
//===----------------------------------------------------------------------===//

// Gets the size of the dimension `dim_from_end` from the end of `input`.
// Requires that `input` is a tensor.
static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
  // Note: the verifier enforces that `input` is a ranked tensor.
  auto input_type = input.getType().cast<TensorType>();
  auto input_shape = input_type.getShape();
  int dim = (input_shape.size() - 1) - dim_from_end;
  return input_shape[dim];
}

// Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1`
// have the same size as the last two dimensions of `input` (the second-to-last
// dimension and last dimension, respectively). The element type of the
// outputted RankedTensorType will match the element type of `input`.
// Requires that `input` is a tensor.
static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
  // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
  int dim_0 = GetDimensionSizeFromEnd(input, 1);
  int dim_1 = GetDimensionSizeFromEnd(input, 0);
  auto element_type = num_lower.getType().cast<TensorType>().getElementType();
  return RankedTensorType::get({dim_0, dim_1}, element_type);
}

// Creates a HLO ConvertOp, converting `input` to have the same element type as
// `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor.
static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
                             Value elem_type_tensor) {
  auto element_type =
      elem_type_tensor.getType().cast<TensorType>().getElementType();
  return builder->create<mhlo::ConvertOp>(loc, input, element_type);
}

//===----------------------------------------------------------------------===//
// Pad op utilities.
//===----------------------------------------------------------------------===//

// Slices input attribute of rank two and returns the specified column.
//
// Always returns 64 bit integer attribute regardless of bitwidth of the input
// attribute.
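// For example, slicing column 1 of [[0, 1], [2, 3]] yields
// dense<[1, 3]> : tensor<2xi64>.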
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
    ElementsAttr input, int column) {
  auto int_attr = input.cast<DenseIntElementsAttr>();
  auto shaped_type = int_attr.getType();
  auto shape = shaped_type.getShape();

  if (shape.size() != 2) return DenseIntElementsAttr();

  llvm::SmallVector<int64_t, 4> values;
  values.reserve(shaped_type.getNumElements() / shape[1]);

  for (auto it : llvm::enumerate(int_attr.getIntValues())) {
    if (static_cast<int>(it.index() % shape[1]) == column) {
      values.push_back(it.value().getSExtValue());
    }
  }

  auto element_type = IntegerType::get(input.getContext(), 64);
  return DenseIntElementsAttr::get(
      RankedTensorType::get({shape[0]}, element_type), values);
}

// Returns interior padding to use in HLO Pad op based on the TensorFlow padding
// in TensorFlow PadV2 op.
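// TF PadV2 has no notion of interior padding, so this is always a zero vector
// with one entry per dimension of the padded operand.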
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
  auto length = tf_padding.getType().getShape()[0];
  auto element_type = IntegerType::get(tf_padding.getContext(), 64);
  return DenseIntElementsAttr::get<int64_t>(
      RankedTensorType::get({length}, element_type), 0);
}

//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; this broadcasts size-1 tensors up to any rank. Dynamic
// dimensions must be broadcast against a size-1 dimension or another dynamic
// dimension. Returns false if either value is unranked.
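// For example, tensor<4x1xf32> and tensor<4x8xf32> are compatible, while
// tensor<?x8xf32> and tensor<4x8xf32> are not, since a dynamic dimension may
// only pair with a size-1 or dynamic dimension.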
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
  if (!x_rankless || !y_rankless) {
    return false;
  }

  // Check that the shapes can be broadcasted.
  auto shape_x = x_rankless.getShape();
  auto shape_y = y_rankless.getShape();

  int rank_diff = shape_x.size() - shape_y.size();
  int offset_x = rank_diff > 0 ? rank_diff : 0;
  int offset_y = rank_diff < 0 ? -rank_diff : 0;
  for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
    int index_x = i + offset_x;
    int index_y = i + offset_y;
    if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
        (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
      return false;
    }
  }

  return true;
}

// Returns a new TensorType with the same rank and dimensions as the input but
// with an updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
                                    Type element_type) {
  RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
  if (ranked_type) {
    return RankedTensorType::get(ranked_type.getShape(), element_type);
  }

  return UnrankedTensorType::get(element_type);
}

//===----------------------------------------------------------------------===//
// Softmax op utilities.
//===----------------------------------------------------------------------===//

// Returns the type to use for accumulating the given type.
static Type GetAccumulationType(Type ty) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//

static DenseElementsAttr GetEpsilonValue(Type ty) {
  auto element_ty = ty.cast<TensorType>().getElementType();
  auto scalar_ty = RankedTensorType::get({}, element_ty);
  if (element_ty.isF16()) {
    uint16_t raw_epsilon = Eigen::NumTraits<Eigen::half>::epsilon().x;
    auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isBF16()) {
    uint16_t raw_epsilon = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
    auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF32()) {
    auto value = APFloat(std::numeric_limits<float>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF64()) {
    auto value = APFloat(std::numeric_limits<double>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  }
  llvm_unreachable("unsupported element type for tf.SoftPlus");
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//

static void BuildArgMinMaxReductionBody(Type input_element_type,
                                        Type index_element_type,
                                        StringRef direction, Region *body,
                                        OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
  Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
  Block *block = builder->createBlock(body);
  block->addArguments({input_type, index_type, input_type, index_type});

  Location loc = body->getLoc();
  StringAttr compare_direction =
      StringAttr::get(builder->getContext(), direction);
  Value compare = builder->create<CompareOp>(
      loc, block->getArgument(0), block->getArgument(2), compare_direction);

  Value selected_input = builder->create<SelectOp>(
      loc, input_type, compare, block->getArgument(0), block->getArgument(2));
  Value selected_index = builder->create<SelectOp>(
      loc, index_type, compare, block->getArgument(1), block->getArgument(3));

  Value return_values[] = {selected_input, selected_index};
  builder->create<ReturnOp>(loc, return_values);
}

//===----------------------------------------------------------------------===//
// PartitionedCall op utilities.
//===----------------------------------------------------------------------===//

// Verify that the arguments to be passed into the function are the same types
// as the function parameter types.
static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args,
                                SymbolRefAttr func) {
  auto module = op->getParentOfType<ModuleOp>();
  auto function =
      dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
  FunctionType function_ty = function.getType();

  for (auto arg_in : llvm::zip(args, function_ty.getInputs())) {
    if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) {
      // Argument type and input type mismatch.
      return false;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Slice op utilities.
//===----------------------------------------------------------------------===//

static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                          DenseIntElementsAttr slice_sizes) {
  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  if (!input_ty) return false;
  auto start_indices_ty = start_indices.getType().dyn_cast<RankedTensorType>();
  if (!start_indices_ty) return false;

  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  DenseIntElementsAttr constant_start_indices;
  bool is_constant_start =
      matchPattern(start_indices, m_Constant(&constant_start_indices));

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    // A slice_size of -1 means "all elements from start_index to the end".
    // In order to support these semantics, we need to know both the start
    // index and the shape of the input dimension.
    if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false;
  }
  return true;
}

// TF slice size can be -1, which represents all elements from start_index to
// the end. HLO slice size can't be -1. As such, we need to translate TF slice
// size -1 to HLO slice size.
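// For example, for an input of shape [4, 8] with start_indices = [0, 2], TF
// slice sizes [-1, 4] normalize to HLO slice sizes [4, 4].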
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
    Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
    Builder *builder) {
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    return hlo::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
        .cast<DenseIntElementsAttr>();
  }

  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  SmallVector<int64_t, 4> normalized_sizes;

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
                                                : slice_size);
  }

  return GetI64ElementsAttr(normalized_sizes, builder);
}

//===----------------------------------------------------------------------===//
// Sort op utilities.
//===----------------------------------------------------------------------===//

// Builds the region `body` for mhlo.sort's comparator: for each type in
// `element_types`, create two block arguments, one for lhs and one for rhs, and
// generates mhlo.compare op to compare them with the given `direction`.
//
// Note that this right now only does comparison on the first pair of block
// arguments.
static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
                                    StringRef direction,
                                    llvm::Optional<StringRef> compare_type,
                                    Region *body, OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Block *block = builder->createBlock(body);
  // Add two arguments for each element type.
  for (Type element_type : element_types) {
    TensorType tensor_type = RankedTensorType::get({}, element_type);
    block->addArguments({tensor_type, tensor_type});
  }

  Location loc = body->getLoc();
  StringAttr compare_direction = builder->getStringAttr(direction);
  StringAttr type_attr;
  if (compare_type) type_attr = builder->getStringAttr(*compare_type);
  Value compare = builder->create<mhlo::CompareOp>(
      loc, block->getArgument(0), block->getArgument(1), compare_direction,
      type_attr);

  builder->create<mhlo::ReturnOp>(loc, compare);
}

//===----------------------------------------------------------------------===//
// XlaGather op utilities.
//===----------------------------------------------------------------------===//

bool HasValidGatherDims(StringAttr attr) {
  ::xla::GatherDimensionNumbers dims;
  return dims.ParseFromString(attr.getValue().str());
}

GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) {
  ::xla::GatherDimensionNumbers dims;
  if (!dims.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertGatherDimensionNumbers(dims, builder);
}

//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//

NamedAttribute GetConvDimensionNumbersAttr(
    ArrayRef<int64_t> spatial_dim_indices, tensorflow::TensorFormat format,
    Builder *builder) {
  int64_t num_spatial_dims = spatial_dim_indices.size();
  int64_t num_dims = num_spatial_dims + 2;

  IntegerAttr batch_dim =
      builder->getI64IntegerAttr(GetTensorBatchDimIndex(num_dims, format));
  IntegerAttr feature_dim =
      builder->getI64IntegerAttr(GetTensorFeatureDimIndex(num_dims, format));
  DenseIntElementsAttr spatial_dims =
      GetI64ElementsAttr(spatial_dim_indices, builder);

  // Filters data_format is always HWIO so input channels dimension is after
  // all spatial dimensions.
  IntegerAttr kernel_input_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims);
  IntegerAttr kernel_output_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims + 1);
  DenseIntElementsAttr kernel_spatial_dimensions =
      GetI64ElementsAttrForSeq(0, num_spatial_dims, builder);

  return builder->getNamedAttr(
      "dimension_numbers",
      ConvDimensionNumbers::get(
          batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
          kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
          feature_dim, spatial_dims, builder->getContext()));
}

// Converts a TF::BiasAddOp to HLO.
// This differs from a normal TF::AddOp with respect to how the data_format
// is handled, which can optionally require a general broadcast of the
// 'bias' term in a way that is not compatible with the standard left-padded
// broadcast semantics (i.e. NCHW will broadcast into dimension 1).
// The correct 'bias' broadcast will be synthesized manually.
class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TF::BiasAddOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto feature_dim = GetFeatureDimension(
        op.data_format(), op.value().getType().cast<RankedTensorType>());
    auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
                                                  feature_dim, rewriter);
    rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast);
    return success();
  }
};

// Converts the TensorFlow conv op in template to the generic HLO conv op by
// converting TensorFlow op attributes to HLO op attributes.
//
// Sample result for Conv2D:
//
//   %conv = "mhlo.convolution"(%input, %filter) {
//     strides = [1, 2],
//     paddings = [[1, 0], [1, 1]],
//     ...
//   }
//
// This pattern is not defined using declarative rewrite rules as computation
// of the paddings attribute requires multiple source op attributes and result
// op attributes. Defining it as a declarative rewrite rule would introduce
// some duplication in the C++ helper methods.
1156 template <typename OpTy, int num_spatial_dims, bool depthwise_conv = false>
1157 class ConvertConvOp : public OpRewritePattern<OpTy> {
1158 public:
1159 using OpRewritePattern<OpTy>::OpRewritePattern;
1160
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const1161 LogicalResult matchAndRewrite(OpTy op,
1162 PatternRewriter &rewriter) const override {
1163 tensorflow::TensorFormat data_format;
1164 if (!FormatFromString(op.data_format().str(), &data_format))
1165 return failure();
1166
1167 tensorflow::Padding padding;
1168 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1169 return failure();
1170
1171 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1172 auto filter_ty =
1173 op.filter().getType().template dyn_cast<RankedTensorType>();
1174 auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1175
1176 // Input, filter and the result needs to have static shape for calculation
1177 // of HLO paddings and feature group count attributes.
1178 for (RankedTensorType ty : {input_ty, filter_ty, result_ty})
1179 if (!ty || !ty.hasStaticShape()) return failure();
1180
1181 ArrayRef<Attribute> dilations = op.dilations().getValue();
1182 ArrayRef<Attribute> strides = op.strides().getValue();
1183 ArrayRef<Attribute> explicit_paddings;
1184 if (padding == tensorflow::Padding::EXPLICIT) {
1185 // EXPLICIT padding mode and the associated attribute is limited to
1186 // Conv2D. So, fetch attribute by identifier instead of the
1187 // op.explicit_paddings() attribute getter.
1188 explicit_paddings =
1189 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1190 }
1191
1192 SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1193 SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1194 SmallVector<int64_t, num_spatial_dims> window_strides;
1195 SmallVector<int64_t, num_spatial_dims * 2> paddings;
1196
1197 auto get_int = [](Attribute attr) {
1198 return attr.template cast<IntegerAttr>().getInt();
1199 };
1200
1201 constexpr int num_dims = num_spatial_dims + 2;
1202 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1203 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1204 spatial_dim_indices.push_back(dim);
1205
1206 const int64_t dilation = get_int(dilations[dim]);
1207 rhs_dilations.push_back(dilation);
1208 const int64_t stride = get_int(strides[dim]);
1209 window_strides.push_back(stride);
1210
1211 int64_t pad_low, pad_high;
1212 if (padding == tensorflow::Padding::EXPLICIT) {
1213 pad_low = get_int(explicit_paddings[2 * dim]);
1214 pad_high = get_int(explicit_paddings[2 * dim + 1]);
1215 } else {
1216 tensorflow::int64 output_size;
1217 tensorflow::int64 pad_low_int64;
1218 tensorflow::int64 pad_high_int64;
1219 tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1220 input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride,
1221 padding, &output_size, &pad_low_int64, &pad_high_int64);
1222 if (!status.ok()) return failure();
1223 pad_low = pad_low_int64;
1224 pad_high = pad_high_int64;
1225 }
1226 paddings.push_back(pad_low);
1227 paddings.push_back(pad_high);
1228 }
1229
1230 auto rhs_dilations_attr = rewriter.getNamedAttr(
1231 "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1232
1233 auto window_strides_attr = rewriter.getNamedAttr(
1234 "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1235
1236 auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1237 spatial_dim_indices, data_format, &rewriter);
1238
1239 const int64_t input_channels =
1240 GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
    // Filter's data_format is always HWIO, so the input channels dimension is
    // after all spatial dimensions.
1243 const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1244 // TensorFlow convolution op verifies that the number of input channels is
1245 // divisible by the number of filter channels.
1246 // For depthwise convolution the feature_group_count argument would be set
1247 // to the input feature dimension.
1248 const int64_t feature_group_count =
1249 depthwise_conv ? input_channels : input_channels / filter_channels;
1250 auto feature_group_count_attr = rewriter.getNamedAttr(
1251 "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1252
1253 auto batch_group_count_attr = rewriter.getNamedAttr(
1254 "batch_group_count", rewriter.getI64IntegerAttr(1));
1255
1256 RankedTensorType paddings_ty = RankedTensorType::get(
1257 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
1258 auto paddings_attr = rewriter.getNamedAttr(
1259 "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
1260
1261 SmallVector<Value, 2> operands(op.getOperands());
    // Reshape the filter to {spatial_dims..., 1, in_channels *
    // channel_multiplier}.
1264 if (depthwise_conv) {
1265 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1266 llvm::SmallVector<int64_t, num_dims> new_shape(
1267 filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1268 new_shape.push_back(1);
1269 new_shape.push_back(filter_shape[num_spatial_dims] *
1270 filter_shape[num_spatial_dims + 1]);
1271 operands[1] = rewriter.create<mhlo::ReshapeOp>(
1272 op.getLoc(),
1273 RankedTensorType::get(new_shape, filter_ty.getElementType()),
1274 operands[1]);
1275 }
1276 NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1277 dimension_numbers_attr, feature_group_count_attr,
1278 batch_group_count_attr, paddings_attr};
1279 rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
1280 llvm::makeArrayRef(attrs));
1281 return success();
1282 }
1283 };
1284
1285 using ConvertConv2DOp = ConvertConvOp<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1286 using ConvertConv3DOp = ConvertConvOp<TF::Conv3DOp, /*num_spatial_dims=*/3>;
1287 using ConvertDepthConv2DOp =
1288 ConvertConvOp<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
1289 /*depthwise_conv=*/true>;
1290
// Converts a BF16 FloorDiv op to have F32 casts on either end, as performing
// the division directly in BF16 can give inaccurate results.
//
// floordiv = cast(floordiv(cast(left), cast(right)))
//
// %left_cast = cast(%left)
// %right_cast = cast(%right)
// %div = div(%left_cast, %right_cast)
// %floored = floor(%div)
// %floored_cast = cast(%floored)
//
// The intermediate types must be specified manually.
1303 class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
1304 public:
1305 using OpRewritePattern::OpRewritePattern;
1306
  LogicalResult matchAndRewrite(TF::FloorDivOp op,
1308 PatternRewriter &rewriter) const override {
1309 auto l = op.x();
1310 auto r = op.y();
1311 auto element_type = getElementTypeOrSelf(l.getType());
1312 if (!element_type.isBF16()) return failure();
1313
1314 auto out_type = op.z().getType().cast<TensorType>();
1315
1316 l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
1317 r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
1318
1319 auto intermediate = rewriter.create<TF::FloorDivOp>(
1320 op.getLoc(),
1321 ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
1322 r);
1323
1324 auto floor_op =
1325 rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
1326 rewriter.replaceOp(op, floor_op.getResult());
1327 return success();
1328 }
1329 };
1330
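// Converts a TF::BroadcastToOp to a mhlo.dynamic_broadcast_in_dim op. The
// tf.BroadcastTo broadcast is right-aligned (numpy-style), so the broadcast
// dimensions map the input dimensions to the trailing dimensions of the
// output. For example (illustrative), broadcasting a tensor<3xf32> to a
// tensor<2x3xf32> output uses broadcast_dimensions = [1].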
1331 class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
1332 public:
1333 using OpRewritePattern::OpRewritePattern;
1334
  LogicalResult matchAndRewrite(TF::BroadcastToOp op,
1336 PatternRewriter &rewriter) const override {
1337 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1338 auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
1339 if (!input_type || !output_type) {
1340 return rewriter.notifyMatchFailure(op, "requires ranked shape");
1341 }
1342 auto rank_diff = output_type.getRank() - input_type.getRank();
1343 // The tf.BroadcastTo op performs "right-aligned" numpy-style broadcasting.
1344 auto broadcast_dimensions = llvm::to_vector<4>(
1345 llvm::seq<int64_t>(rank_diff, output_type.getRank()));
1346 rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
1347 op, output_type, op.input(), op.shape(),
1348 rewriter.getI64TensorAttr(broadcast_dimensions));
1349 return success();
1350 }
1351 };
1352
1353 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
1354 // For a Rank-2 input, it creates the following ops:
1355 // %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
1356 // %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
1357 // %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
1358 // %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
1359 // %5 = "mhlo.broadcast"(%4)
1360 // %6 = "mhlo.select"(%3, %input, %5)
1361 // %7 = "mhlo.reduce"(%6, %4) ( {
1362 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
1363 // %9 = mhlo.add %arg1, %arg2 : tensor<f32>
1364 // "mhlo.return"(%9) : (tensor<f32>) -> ()
1365 // }) {dimensions = dense<0> : tensor<1xi64>}
1366 //
1367 // If the input's rank N is greater than 2, we will reshape it to R2 first and
1368 // create the above ops, then reshape it back to rank N/2.
1369 class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
1370 public:
1371 using OpRewritePattern::OpRewritePattern;
1372
  LogicalResult matchAndRewrite(TF::DiagPartOp op,
1374 PatternRewriter &rewriter) const override {
1375 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1376 if (!input_type || !input_type.hasStaticShape()) return failure();
1377 int64_t num_dims = input_type.getRank();
1378 if (num_dims < 2 || num_dims % 2 != 0) return failure();
1379 const int64_t out_dims = num_dims / 2;
1380
1381 int64_t new_size = 1;
1382 llvm::SmallVector<int64_t, 4> new_dims;
1383 for (int i = 0; i < out_dims; i++) {
1384 if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims))
1385 return op.emitOpError("invalid dimensions size");
1386 new_size *= input_type.getDimSize(i);
1387 new_dims.push_back(input_type.getDimSize(i));
1388 }
1389 Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
1390 op.getLoc(),
1391 RankedTensorType::get({new_size, new_size},
1392 input_type.getElementType()),
1393 op.input());
1394 auto iota_type = RankedTensorType::get({new_size, new_size},
1395 rewriter.getIntegerType(32));
1396 auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1397 rewriter.getI64IntegerAttr(0));
1398 auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1399 rewriter.getI64IntegerAttr(1));
1400 Value compare = rewriter.create<CompareOp>(
1401 op.getLoc(), iota0, iota1,
1402 StringAttr::get(rewriter.getContext(), "EQ"));
1403 Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
1404 0, &rewriter);
1405 Value zero_matrix = rewriter.create<BroadcastOp>(
1406 op.getLoc(), reshaped_input.getType(), zero,
1407 GetI64ElementsAttr({new_size, new_size}, &rewriter));
1408 Value masked =
1409 rewriter.create<SelectOp>(op.getLoc(), reshaped_input.getType(),
1410 compare, reshaped_input, zero_matrix);
1411 auto reduce = rewriter.create<ReduceOp>(op.getLoc(), masked, zero,
1412 GetI64ElementsAttr({0}, &rewriter));
1413 assert(!input_type.getElementType().isInteger(1) &&
1414 "data type should not be i1");
1415 BuildReduceBody<AddOp>(input_type.getElementType(), &reduce.body(),
1416 &rewriter);
1417 rewriter.replaceOpWithNewOp<ReshapeOp>(
1418 op, RankedTensorType::get(new_dims, input_type.getElementType()),
1419 reduce.getResult(0));
1420 return success();
1421 }
1422 };
1423
1424 // Converts TensorFlow MatrixDiagPartOp to HLO ops.
1425 class ConvertMatrixDiagPartV3Op
1426 : public OpRewritePattern<TF::MatrixDiagPartV3Op> {
1427 using Shape = llvm::SmallVector<int64_t, 4>;
1428
  // Parse the "k" parameter. MatrixDiagPartV3 allows specifying the
  // diagonal(s) with k. This can be either a single value (for a single
  // diagonal) or a tuple of two values (starting and ending diagonal, for a
  // band).
  LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const {
1433 DenseIntElementsAttr kattr;
1434 if (!matchPattern(op.k(), m_Constant(&kattr))) {
1435 return failure();
1436 }
1437 DenseIntElementsAttr::iterator it = kattr.begin();
1438 (*k)[0] = (*it).getSExtValue();
1439 it++;
1440 if (it == kattr.end()) {
1441 // Handle input like e.g. "k = 5", in which case we extract a single
1442 // diagonal.
1443 (*k)[1] = (*k)[0];
1444 } else {
1445 // Handle input like e.g. "k = [-1, 1]", in which case we extract a
1446 // band (multiple diagonals).
1447 (*k)[1] = (*it).getSExtValue();
1448 }
1449 return success();
1450 }
1451
1452 // Utility method for broadcasting integer constants to a given shape.
  BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant,
1454 int int_size, PatternRewriter &rewriter) const {
1455 return rewriter.create<BroadcastOp>(
1456 loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)),
1457 GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant,
1458 &rewriter),
1459 GetI64ElementsAttr(shape, &rewriter));
1460 }
1461
1462 public:
1463 using OpRewritePattern::OpRewritePattern;
1464
  LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op,
1466 PatternRewriter &rewriter) const override {
1467 Location loc = op.getLoc();
1468 ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
1469 auto element_type = input_type.getElementType();
1470
1471 // Align is a string specifying how superdiagonals and subdiagonals should
1472 // be aligned/padded for diagonals that are shorter than max_diag_len. The
1473 // format is "{super}_{sub}", with {super} the superdiagonal alignment and
1474 // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the
    // left, "RIGHT" means rows will be padded to the right. The default is
1476 // "RIGHT_LEFT".
1477 StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
1478 enum Alignment { kLeft, kRight };
1479
1480 // default is RIGHT_LEFT
1481 Alignment superdiagonal_align = kRight;
1482 Alignment subdiagonal_align = kLeft;
1483
1484 if (align == "RIGHT_LEFT") {
1485 superdiagonal_align = kRight;
1486 subdiagonal_align = kLeft;
1487 } else if (align == "RIGHT_RIGHT") {
1488 superdiagonal_align = kRight;
1489 subdiagonal_align = kRight;
1490 } else if (align == "LEFT_RIGHT") {
1491 superdiagonal_align = kLeft;
1492 subdiagonal_align = kRight;
1493 } else if (align == "LEFT_LEFT") {
1494 superdiagonal_align = kLeft;
1495 subdiagonal_align = kLeft;
1496 } else {
1497 return failure(); // unsupported alignment
1498 }
1499
1500 // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and
1501 // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L].
1502 if (!input_type || !input_type.hasStaticShape()) return failure();
1503 int64_t num_dims = input_type.getRank();
1504 if (num_dims < 2) return failure();
1505 int64_t rows = input_type.getDimSize(num_dims - 2); // rows
1506 int64_t cols = input_type.getDimSize(num_dims - 1); // cols
1507
1508 // We extract the diagonals from k[0] up to and including k[1].
1509 // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract
1510 // the main diagonal). It's negative for subdiagonals (under and to the left
1511 // of the main diagonal) and positive for superdiagonals (above and to the
1512 // right of the main diagonal).
1513 int64_t k[2];
1514 if (failed(ExtractK(op, &k))) return failure();
1515 int num_diags = k[1] - k[0] + 1;
1516
1517 // Shifting diagonals away from the main diagonal might shorten them. This
1518 // is the longest diagonal we will see. We make this the last dimension of
1519 // the output shape.
1520 int64_t max_diag_len =
1521 std::min(rows + std::min(k[1], static_cast<int64_t>(0)),
1522 cols + std::min(-k[0], static_cast<int64_t>(0)));
1523
1524 // The first dimension is the index vector dimension we'll use for gather.
1525 // It's 1 here, but will be 2 once we glue x and y together.
1526 Shape indices_shape({1, num_diags, max_diag_len});
1527
1528 RankedTensorType iota_type =
1529 RankedTensorType::get(indices_shape, rewriter.getIntegerType(32));
1530 Value iotaM =
1531 rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(1));
1532 Value iotaN =
1533 rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(2));
1534
    // Broadcasted constants, of the same shape as iotaM and iotaN.
1536 Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter);
1537 Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter);
1538 Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter);
1539 Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter);
1540 Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter);
1541 Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter);
1542 Value b_max_diag_len =
1543 BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter);
1544
1545 // d = k[1] - m
1546 // (A.k.a. the number of the diagonal, depending on m. Note that we
1547 // subtract m here. This means we start with the superdiagonals and
1548 // move downwards towards the subdiagonals. So the start indices will
1549 // be decreasing.)
1550 Value d = rewriter.create<SubOp>(loc, b_k1, iotaM);
1551 Value neg_d = rewriter.create<NegOp>(loc, d);
1552
1553 // diag_len_d = min(rows + min(d, 0), cols - max(d, 0))
1554 // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.)
1555 Value diag_len_d = rewriter.create<MinOp>(
1556 loc,
1557 rewriter.create<AddOp>(loc, b_rows,
1558 rewriter.create<MinOp>(loc, d, b_zero)),
1559 rewriter.create<SubOp>(loc, b_cols,
1560 rewriter.create<MaxOp>(loc, d, b_zero)));
1561
1562 // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise.
1563 Value cmp;
1564 if (subdiagonal_align == kRight && superdiagonal_align == kRight) {
1565 cmp = b_true;
1566 } else if (superdiagonal_align == kRight) {
1567 // offset = d>=0 ? max_diag_len - diag_len_d : 0
1568 cmp = rewriter.create<TF::GreaterEqualOp>(loc, d, b_zero);
1569 } else if (subdiagonal_align == kRight) {
1570 // offset = d<=0 ? max_diag_len - diag_len_d : 0
1571 cmp = rewriter.create<TF::LessEqualOp>(loc, d, b_zero);
1572 } else {
1573 // offset = 0
1574 cmp = b_false;
1575 }
1576
1577 // This offset shifts the diagonals to the "left" or "right", depending
1578 // on alignment.
1579 Value offset = rewriter.create<SelectOp>(
1580 loc, b_zero.getType(), cmp,
1581 rewriter.create<SubOp>(loc, b_max_diag_len, diag_len_d), b_zero);
1582
1583 // x = max(d, 0) - offset
1584 // y = max(-d, 0) - offset
1585 Value x = rewriter.create<SubOp>(
1586 loc, rewriter.create<MaxOp>(loc, d, b_zero), offset);
1587 Value y = rewriter.create<SubOp>(
1588 loc, rewriter.create<MaxOp>(loc, neg_d, b_zero), offset);
1589
1590 Value n_plus_x = rewriter.create<AddOp>(loc, iotaN, x);
1591 Value n_plus_y = rewriter.create<AddOp>(loc, iotaN, y);
1592
    // GatherOp will happily let us index out-of-bounds values, but those
    // values will be undefined. So we mask them later. Set up the boolean
1595 // expression that tells us which entries, in the output shape, are out of
1596 // bounds and thus become the padding_value.
1597 Value x_in_bounds = rewriter.create<AndOp>(
1598 loc,
1599 rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_x,
1600 b_zero),
1601 rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_x, b_cols));
1602 Value y_in_bounds = rewriter.create<AndOp>(
1603 loc,
1604 rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_y,
1605 b_zero),
1606 rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_y, b_rows));
1607 Value in_bounds = rewriter.create<ReshapeOp>(
1608 loc,
1609 RankedTensorType::get(Shape({num_diags, max_diag_len}),
1610 rewriter.getIntegerType(1)),
1611 rewriter.create<AndOp>(loc, x_in_bounds, y_in_bounds));
1612
1613 // Now combine x and y into the index data structure needed for gather.
1614 Shape concat_shape({2, num_diags, max_diag_len});
1615 Value start_indices = rewriter.create<ConcatenateOp>(
1616 loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)),
1617 mlir::ValueRange({n_plus_y, n_plus_x}),
1618 mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0));
1619
1620 // Shape of the final output. (Except for dimension folding in the
1621 // single diagonal case.)
1622 Shape output_shape;
1623 for (int i = 0; i < num_dims - 2; i++) {
1624 output_shape.push_back(input_type.getDimSize(i));
1625 }
1626 output_shape.push_back(num_diags);
1627 output_shape.push_back(max_diag_len);
1628 auto output_type = RankedTensorType::get(output_shape, element_type);
1629
1630 // A slice is the shape of what GatherOp copies per lookup. So the last
1631 // two dimensions (M, N in the matrix-diag-part docs) are where we go
1632 // through entry by entry.
1633 ArrayRef<int64_t> input_shape = input_type.getShape();
1634 Shape slice_sizes(input_shape.begin(), input_shape.end());
1635 int slice_dimensions = slice_sizes.size();
1636 slice_sizes[slice_dimensions - 2] = 1;
1637 slice_sizes[slice_dimensions - 1] = 1;
1638
1639 // Dimensions of the input we won't see in the output (M and N).
1640 SmallVector<int64_t, 2> collapsed_dims(
1641 {slice_dimensions - 2, slice_dimensions - 1});
1642
1643 // Which dimensions (in the input) the two offset "columns" map to.
1644 SmallVector<int64_t, 2> start_index_map({num_dims - 2, num_dims - 1});
1645
1646 // Gather the diagonal entries.
1647 // TODO(kramm): For a single diagonal, this might be slower than the
1648 // mask + sum approach. Special-case num_diags==1?
1649 auto dims_attr = GatherDimensionNumbers::get(
1650 /*offset_dims=*/GetI64ElementsAttrForSeq(0, num_dims - 2, &rewriter),
1651 /*collapsed_slice_dims=*/GetI64ElementsAttr(collapsed_dims, &rewriter),
1652 /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
1653 /*index_vector_dim=*/rewriter.getI64IntegerAttr(0),
1654 rewriter.getContext());
1655 Value gather = rewriter.create<mhlo::GatherOp>(
1656 loc, output_type, op.input(), start_indices, dims_attr,
1657 GetI64ElementsAttr(slice_sizes, &rewriter));
1658
1659 // We now need to broadcast the "in_bounds" boolean expression, as well as
1660 // the padding value, to do the final select.
1661 Shape broadcast_bounds;
1662 for (int i = 0; i < output_shape.size() - 2; i++) {
1663 broadcast_bounds.push_back(output_shape[i]);
1664 }
1665 Value b_in_bounds = rewriter.create<BroadcastOp>(
1666 loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
1667 in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
1668 Value b_padding = rewriter.create<BroadcastOp>(
1669 loc, output_type, op.padding_value(),
1670 GetI64ElementsAttr(output_shape, &rewriter));
1671
1672 // Replace all out-of-bounds values in the result with padding_value.
1673 Value result = rewriter.create<SelectOp>(loc, output_type, b_in_bounds,
1674 gather, b_padding);
1675
1676 if (num_diags == 1) {
1677 // matrix_diag_part folds away the 1-sized band dimension if we only
1678 // extract a single diagonal.
1679 result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
1680 }
1681
1682 rewriter.replaceOp(op, result);
1683 return success();
1684 }
1685 };
1686
1687 // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
1688 // depending on arity of the op.
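//
// For example (an illustrative sketch; the operand names are hypothetical), a
// two-operand einsum such as
//   %0 = "tf.Einsum"(%a, %b) {equation = "ab,bc->ac"}
// is rewritten to the binary HLO einsum op carrying the same equation string,
// while a single-operand einsum is rewritten to the unary variant.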
1689 class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
1690 public:
1691 using OpRewritePattern::OpRewritePattern;
1692
  LogicalResult matchAndRewrite(TF::EinsumOp op,
1694 PatternRewriter &rewriter) const override {
1695 StringAttr equation = op->getAttrOfType<StringAttr>("equation");
1696 if (op.N() == 1) {
1697 rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
1698 op, op.getType(), *op.inputs().begin(), equation);
1699 } else if (op.N() == 2) {
1700 ValueRange inputs = op.inputs();
1701 rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
1702 inputs[1], equation);
1703 } else {
      // TensorFlow EinsumOp verifies that the number of operands is at most
      // two.
1706 return failure();
1707 }
1708 return success();
1709 }
1710 };
1711
1712 // Bypasses IdentityN op.
1713 class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
1714 public:
1715 using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TF::IdentityNOp op,
1717 PatternRewriter &rewriter) const override {
1718 rewriter.replaceOp(op, op.getOperands());
1719 return success();
1720 }
1721 };
1722
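// Converts TF RFFT/IRFFT ops to the HLO FFT op. The innermost input dimension
// must be static; it is sliced (if larger) or zero-padded (if smaller) so that
// it matches the expected FFT length (fft_length, or fft_length / 2 + 1 for
// IRFFT).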
1723 template <typename OpTy>
1724 class ConvertFFTOp : public OpRewritePattern<OpTy> {
1725 public:
1726 using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
1728 PatternRewriter &rewriter) const override {
1729 auto input_ty = op.input().getType().template cast<ShapedType>();
1730 if (!input_ty.hasRank()) {
1731 return failure();
1732 }
1733 auto input_shape = input_ty.getShape();
1734 DenseIntElementsAttr fft_length_attr;
1735 if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
1736 return failure();
1737 }
1738 int64_t fft_length;
1739 if (fft_length_attr.getNumElements() != 0) {
1740 fft_length = fft_length_attr.getValue<IntegerAttr>(0).getInt();
1741 } else {
1742 return failure();
1743 }
1744
1745 std::string fft_string = "RFFT";
1746 if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
1747 fft_length = fft_length / 2 + 1;
1748 fft_string = "IRFFT";
1749 }
1750 auto loc = op.getLoc();
1751
1752 // The inner-most dim cannot be dynamic.
1753 if (input_ty.isDynamicDim(input_shape.size() - 1)) {
1754 return failure();
1755 }
1756
1757 auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
1758 expected_shape.push_back(fft_length);
1759
1760 // Zero pad or truncate the last axis
1761 Value reshaped = op.input();
1762 SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
1763 SmallVector<int64_t, 4> strides(input_shape.size(), 1);
1764
1765 // Last dim larger than fft_length, slice the input
1766 if (input_shape.back() > fft_length) {
1767 reshaped = rewriter.create<SliceOp>(
1768 op.getLoc(),
1769 RankedTensorType::get(expected_shape, input_ty.getElementType()),
1770 op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
1771 GetI64ElementsAttr(expected_shape, &rewriter),
1772 GetI64ElementsAttr(strides, &rewriter));
1773
1774 // Last dim smaller than fft_length, zero-pad the input
1775 } else if (input_ty.getShape().back() < fft_length) {
1776 SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
1777 SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
1778 padding.push_back(fft_length - input_shape.back());
1779 Value zero =
1780 GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
1781 reshaped = rewriter.create<PadOp>(
1782 loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
1783 op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
1784 GetI64ElementsAttr(padding, &rewriter),
1785 GetI64ElementsAttr(no_padding, &rewriter));
1786 }
1787
1788 rewriter.replaceOpWithNewOp<FftOp>(op, op.getType(), reshaped, fft_string,
1789 rewriter.getI64TensorAttr(fft_length));
1790 return success();
1791 }
1792 };
1793
1794 using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
1795 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
1796
1797 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
1798 // BatchNormGradOp for training and a sequence of binary ops for inference.
1799 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
1800 template <typename FusedBatchNormGradOpT>
1801 class ConvertFusedBatchNormGradBase
1802 : public OpRewritePattern<FusedBatchNormGradOpT> {
1803 public:
1804 using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
1805
  LogicalResult matchAndRewrite(FusedBatchNormGradOpT op,
1807 PatternRewriter &rewriter) const override {
1808 Location loc = op.getLoc();
1809 Value grad = op.y_backprop();
1810 Value act = op.x();
1811 Value scale = op.scale();
1812 Value mean = op.reserve_space_1();
1813 Value var = op.reserve_space_2();
1814
1815 // TODO(b/141785544): Update this to not require static shapes.
1816 // activation shape needs to be static to convert negative indices in
1817 // TensorFlow to absolute indices required by HLO.
1818 RankedTensorType act_type =
1819 act.getType().template dyn_cast<RankedTensorType>();
1820 if (!act_type) return failure();
1821 Type act_ele_type = act_type.getElementType();
    // To support mixed precision, the statistics type, which may be more
    // precise than the input types, is used for this op.
1824 Type kernel_type =
1825 scale.getType().template cast<TensorType>().getElementType();
1826 grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
1827 act = rewriter.create<ConvertOp>(loc, act, kernel_type);
1828
1829 auto feature_dim_attr =
1830 getFeatureDimensionAttr(rewriter, op.data_format(), act);
1831 auto feature_dim = feature_dim_attr.getValue().getSExtValue();
1832
1833 // Gets the result values.
1834 Value x_backprop, scale_backprop, offset_backprop;
1835 if (op.is_training()) { // training
1836 // TODO(b/145536565): handle GPU logic separately.
1837 // Infers the output type with the converted `act`.
1838 Type feature_type = RankedTensorType::get(
1839 {GetDimSize(act_type, feature_dim)}, kernel_type);
1840 Type result_type = TupleType::get(
1841 rewriter.getContext(), {act.getType(), feature_type, feature_type});
1842
1843 auto training_op = rewriter.create<BatchNormGradOp>(
1844 loc, result_type, act, scale, mean, var, grad, op.epsilon(),
1845 feature_dim);
1846
1847 x_backprop =
1848 rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);
1849
1850 scale_backprop =
1851 rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);
1852
1853 offset_backprop =
1854 rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
1855 } else { // inference
1856 SmallVector<int64_t, 4> non_feature_dims;
1857 for (int64_t i = 0; i < act_type.getRank(); ++i) {
1858 if (i == feature_dim) continue;
1859 non_feature_dims.push_back(i);
1860 }
1861 auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
1862 auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
1863
1864 // scratch1 = rsqrt(var + epsilon)
1865 RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
1866 auto epsilon = rewriter.create<ConstOp>(
1867 loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
1868 auto add_op = rewriter.create<chlo::BroadcastAddOp>(
1869 loc, var, epsilon.getResult(), scalar_broadcast_dims);
1870
1871 Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
1872
1873 // scratch2 = sum(y_backprop * (x - mean))
1874 auto sub_op = rewriter.create<mhlo::SubOp>(
1875 loc, act,
1876 Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
1877 auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
1878 Value scratch2 =
1879 ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
1880
1881 // x_backprop = y_backprop * (scale * scratch1)
1882 auto scaled_grad =
1883 rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
1884 x_backprop = rewriter.create<mhlo::MulOp>(
1885 loc, grad,
1886 Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
1887 rewriter));
1888
1889 // scale_backprop = scratch2 * scratch1
1890 scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
1891
1892 // offset_backprop = sum(y_backprop)
1893 offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
1894 }
1895
1896 x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
1897 Value last_val[2];
1898 if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) {
1899 // It doesn't matter what values we provide for the last 2 results.
1900 last_val[0] = last_val[1] = op.x();
1901 } else {
1902 auto const_val = rewriter.create<ConstOp>(
1903 op.getLoc(),
1904 DenseElementsAttr::get<float>(
1905 RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))),
1906 0.0));
1907 auto maybe_cast = [&](Value val, Type t) -> Value {
1908 if (val.getType() == t) return val;
1909 return rewriter.create<tensor::CastOp>(op.getLoc(), t, val);
1910 };
1911 last_val[0] = maybe_cast(const_val, op.getResult(3).getType());
1912 last_val[1] = maybe_cast(const_val, op.getResult(4).getType());
1913 }
1914 rewriter.replaceOp(
1915 op, {/*x_backprop=*/x_backprop,
1916 /*scale_backprop=*/scale_backprop,
1917 /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]});
1918 return success();
1919 }
1920 };
1921
1922 using ConvertFusedBatchNormGradOp =
1923 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
1924 using ConvertFusedBatchNormGradV2Op =
1925 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
1926 using ConvertFusedBatchNormGradV3Op =
1927 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
1928
1929 // Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
1930 // HLO BatchNormInferenceOp, depending on the value of the 'is_training'
1931 // parameter.
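//
// As a rough sketch of the training case (illustrative; type conversions and
// the running-statistics update are elided):
//
//   %bn = "mhlo.batch_norm_training"(%x, %scale, %offset) {epsilon, feature_index}
//   // y, batch_mean and batch_variance are extracted via mhlo.get_tuple_element,
//   // and the variance gets Bessel's correction: batch_variance * N / (N - 1).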
1932 template <typename FusedBatchNormOpT>
1933 class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
1934 public:
1935 using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
1936
  LogicalResult matchAndRewrite(FusedBatchNormOpT op,
1938 PatternRewriter &rewriter) const override {
1939 auto feature_dim =
1940 getFeatureDimensionAttr(rewriter, op.data_format(), op.x());
1941
1942 auto input_type_tensor = op.x().getType().template cast<TensorType>();
1943 auto input_element_type = input_type_tensor.getElementType();
1944
1945 auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
1946 auto scale_element_type = scale_type_tensor.getElementType();
1947
1948 auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
1949 auto mean_element_type = mean_type_tensor.getElementType();
1950 // In the training case, dimensions of input tensors must be static.
1951 if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
1952 !scale_type_tensor.hasStaticShape() ||
1953 !mean_type_tensor.hasStaticShape()))
1954 return failure();
1955
1956 // TODO(b/69928690): Support mixed precision in the XLA batch
1957 // normalization operators. As a workaround, create a new x with the same
1958 // element type as scale (which may be more precise than the input type).
1959 Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
1960 scale_element_type);
1961 TensorType bn_train_input_type_tensor =
1962 bn_train_input.getType().template cast<TensorType>();
1963
1964 if (op.is_training()) {
1965 // Training case.
1966 auto operand_shape = bn_train_input_type_tensor.getShape();
1967 // The mean and variance are each 1 dimensional arrays the size of the
1968 // feature dimension, with the same element type as the operand (x).
1969 // This shape must be constructed manually because the mean and variance
1970 // inputs are empty in the training case.
1971 Type mean_var_type = RankedTensorType::get(
1972 {operand_shape[feature_dim.getInt()]}, scale_element_type);
      // Op result type is a tuple of 3 values: output with the same shape as
      // the input, batch_mean, and batch_var.
1975 SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
1976 mean_var_type, mean_var_type};
1977 Type result_type = TupleType::get(rewriter.getContext(), operand_types);
1978
1979 auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
1980 op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
1981 op.epsilon(), feature_dim.getInt());
1982 // HLO op outputs a tuple of tensors. Extract those results.
1983 auto bn_train_op_result = bn_train_op.getResult();
1984 Value y_out = rewriter.create<mhlo::GetTupleElementOp>(
1985 op.getLoc(), bn_train_op_result, 0);
1986 Value batch_mean = rewriter.create<mhlo::GetTupleElementOp>(
1987 op.getLoc(), bn_train_op_result, 1);
1988 Value reserve_space_1 = batch_mean;
1989 Value batch_variance = rewriter.create<mhlo::GetTupleElementOp>(
1990 op.getLoc(), bn_train_op_result, 2);
1991
1992 // Apply Bessel's correction on the variance.
1993 int total_input_size = bn_train_input_type_tensor.getNumElements();
1994 int total_scale_size = scale_type_tensor.getNumElements();
1995 int sample_size = total_input_size / total_scale_size;
1996 int sample_size_minus_one = std::max(1, sample_size - 1);
1997 double factor = static_cast<double>(sample_size) /
1998 static_cast<double>(sample_size_minus_one);
1999 auto factor_const_op = rewriter.create<mhlo::ConstOp>(
2000 op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
2001
2002 Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
2003 op.getLoc(), batch_variance.getType(), batch_variance,
2004 factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
2005
2006 // Convert back to input type to stay aligned with expected output type
2007 // for TF op.
2008 y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
2009 input_element_type);
2010
2011 float exponential_avg_factor =
2012 op.exponential_avg_factor().convertToFloat();
2013 if (exponential_avg_factor != 1.0f) {
2014 auto alpha = rewriter.create<mhlo::ConstOp>(
2015 op.getLoc(), rewriter.getFloatAttr(mean_element_type,
2016 1.0f - exponential_avg_factor));
2017 auto beta = rewriter.create<mhlo::ConstOp>(
2018 op.getLoc(),
2019 rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
2020
2021 // new_running_mean = alpha * old_mean + beta * batch_mean.
2022 auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
2023 op.getLoc(), op.mean().getType(), alpha, op.mean(),
2024 /*broadcast_dimensions=*/DenseIntElementsAttr());
2025 auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
2026 op.getLoc(), batch_mean.getType(), beta, batch_mean,
2027 /*broadcast_dimensions=*/DenseIntElementsAttr());
2028 batch_mean = rewriter.create<chlo::BroadcastAddOp>(
2029 op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
2030 /*broadcast_dimensions=*/DenseIntElementsAttr());
2031
2032 // new_running_variance = alpha * old_variance + beta * batch_variance.
2033 auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
2034 op.getLoc(), op.variance().getType(), alpha, op.variance(),
2035 /*broadcast_dimensions=*/DenseIntElementsAttr());
2036 auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
2037 op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
2038 /*broadcast_dimensions=*/DenseIntElementsAttr());
2039 corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
2040 op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
2041 /*broadcast_dimensions=*/DenseIntElementsAttr());
2042 }
2043
2044 if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2045 // FusedBatchNormV2 expects 4 outputs.
2046 // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
2047 // They are used to pass the per-batch mean and variance to the
        // gradient. Here we maintain the same behavior by setting them to the
2049 // mean and variance calculated by BatchNormTraining.
2050 rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2051 /*batch_variance=*/corrected_variance,
2052 /*reserve_space_1=*/reserve_space_1,
2053 /*reserve_space_2=*/batch_variance});
2054 } else { // TF::FusedBatchNormV3Op
2055 // For FusedBatchNormV3Op, also create a constant tensor to forward to
2056 // last reserve_space_3 output.
2057 auto reserve_space_3_type =
2058 op.getResult(5).getType().template cast<TensorType>();
2059 int num_elements = reserve_space_3_type.hasStaticShape()
2060 ? reserve_space_3_type.getNumElements()
2061 : 0;
2062 auto const_attr_type = RankedTensorType::get(
2063 {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2064 Value dummy_const = rewriter.create<ConstOp>(
2065 op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2066 if (const_attr_type != reserve_space_3_type)
2067 dummy_const = rewriter.create<tensor::CastOp>(
2068 op.getLoc(), reserve_space_3_type, dummy_const);
2069 rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2070 /*batch_variance=*/corrected_variance,
2071 /*reserve_space_1=*/reserve_space_1,
2072 /*reserve_space_2=*/batch_variance,
2073 /*reserve_space_3=*/dummy_const});
2074 }
2075 } else { // Inference case.
2076 auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
2077 op.getLoc(),
2078 /*result_type=*/bn_train_input_type_tensor, bn_train_input,
2079 op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
2080 feature_dim.getInt());
2081
2082 // Convert back to input type to stay aligned with expected output type
2083 // for TF op.
2084 auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
2085 input_element_type);
2086
2087 // The mean, variance, and reserved space outputs of the batch norm op are
2088 // not used for inference. It doesn't matter what values we provide for
2089 // the last 5 results as long as they are of the same type. Forward
2090 // input mean and variance to output mean, variance, reserved_space_1 and
2091 // reserved_space_2.
2092 if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2093 rewriter.replaceOp(op, {/*y=*/y_out,
2094 /*batch_mean=*/op.mean(),
2095 /*batch_variance=*/op.variance(),
2096 /*reserve_space_1=*/op.mean(),
2097 /*reserve_space_2=*/op.variance()});
2098 } else {
2099 // For FusedBatchNormV3Op, also create a constant tensor to forward to
2100 // last reserve_space_3 output.
2101 auto reserve_space_3_type =
2102 op.getResult(5).getType().template cast<TensorType>();
2103 int num_elements = reserve_space_3_type.hasStaticShape()
2104 ? reserve_space_3_type.getNumElements()
2105 : 0;
2106 auto const_attr_type = RankedTensorType::get(
2107 {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2108 Value dummy_const = rewriter.create<ConstOp>(
2109 op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2110 if (const_attr_type != reserve_space_3_type)
2111 dummy_const = rewriter.create<tensor::CastOp>(
2112 op.getLoc(), reserve_space_3_type, dummy_const);
2113 rewriter.replaceOp(op, {/*y=*/y_out,
2114 /*batch_mean=*/op.mean(),
2115 /*batch_variance=*/op.variance(),
2116 /*reserve_space_1=*/op.mean(),
2117 /*reserve_space_2=*/op.variance(),
2118 /*reserve_space_3=*/dummy_const});
2119 }
2120 }
2121 return success();
2122 }
2123 };
2124
2125 using ConvertFusedBatchNormV2Op =
2126 ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
2127 using ConvertFusedBatchNormV3Op =
2128 ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
2129
2130 using PaddingArray =
2131 std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
2132
2133 // Returns padding values for ReduceWindow op as a vector of pairs.
2134 //
2135 // Requires padding to be either 'SAME' or 'VALID' and the number of input
2136 // dimensions to be equal to the size of window dimensions and window strides.
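//
// For example (illustrative): with input_dims = [4], window_dims = [3],
// window_strides = [2] and 'SAME' padding, the output has ceil(4 / 2) = 2
// entries, the padded input needs (2 - 1) * 2 + 3 = 5 elements, so one element
// of padding is required and the result is {{0, 1}} (low, high).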
2137 template <int num_dims>
static PaddingArray GetReduceWindowPaddingAsArray(
2139 llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2140 ArrayAttr window_strides, StringRef padding, Builder *builder) {
2141 if (padding == "VALID") {
2142 return PaddingArray(num_dims, std::make_pair(0, 0));
2143 }
2144 assert(padding == "SAME");
2145 llvm::SmallVector<tensorflow::int64, num_dims> input_shape, window_shape,
2146 strides;
2147 input_shape.reserve(input_dims.size());
  window_shape.reserve(window_dims.size());
2149 strides.reserve(window_strides.size());
2150
2151 for (const auto &dim : input_dims) input_shape.push_back(dim);
2152 for (Attribute attr : window_dims)
2153 window_shape.push_back(attr.cast<IntegerAttr>().getInt());
2154 for (Attribute attr : window_strides)
2155 strides.push_back(attr.cast<IntegerAttr>().getInt());
2156
2157 PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
2158 ::xla::Padding::kSame);
2159 return paddings;
2160 }
2161
2162 // Same as GetReduceWindowPaddingAsArray but returns padding as
2163 // DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
2164 template <int num_dims>
static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
2166 llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2167 ArrayAttr window_strides, StringRef padding, Builder *builder) {
2168 if (padding == "VALID") return {};
2169 assert(padding == "SAME");
2170 PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
2171 input_dims, window_dims, window_strides, padding, builder);
2172 int64_t rank = paddings.size();
2173 llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
2174 for (int i = 0; i < rank; i++) {
2175 flatten_paddings[2 * i] = paddings[i].first;
2176 flatten_paddings[2 * i + 1] = paddings[i].second;
2177 }
2178 return DenseIntElementsAttr::get(
2179 RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
2180 flatten_paddings);
2181 }
2182
2183 // Helper function for dividing each entry of `pooled` by the count of its
2184 // corresponding window, i.e., the number of non-padding entries of the window
2185 // which an `AvgPool` operation performed on an `input_shape`-tensor would map
2186 // to this entry, depending on `ksize` and `strides`. This function is used for
2187 // `AvgPool` and `AvgPoolGrad` legalizations.
2188 // `zero` is passed as a parameter because it can be reused from caller level.
2189 // `pooled` must have `RankedTensorType`.
2190 template <typename OpTy, int num_dims>
Operation *AvgPoolDivideByCount(
2192 Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
2193 const SmallVector<int64_t, num_dims> &ksize,
2194 const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
2195 PatternRewriter &rewriter) {
2196 Location loc = op.getLoc();
2197 RankedTensorType pooled_type =
2198 pooled.getType().template cast<RankedTensorType>();
2199 Type element_type = pooled_type.getElementType();
2200 Operation *result = nullptr;
2201 RankedTensorType orig_input_type =
2202 RankedTensorType::get(input_shape, element_type);
2203
2204 if (op.padding() == "VALID") {
2205 // All window counts are equal here because we don't have padding
2206 // (each entry of `pooled` corresponds to a window that consists of
2207 // original input entries only).
2208 int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
2209 std::multiplies<int64_t>());
2210 // Divide `pooled` by window counts.
2211 Value divisor =
2212 GetScalarConstOfType(element_type, loc, window_count, &rewriter);
2213 auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2214 result = rewriter.create<chlo::BroadcastDivOp>(
2215 loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
2216 } else {
2217 assert(op.padding() == "SAME");
2218 // For SAME padding, only original entries that contributed to a window
2219 // are counted for the average of this window, not padded entries.
2220
2221 // Build all-ones tensor of same shape as the original input.
2222 ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
2223 auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
2224
2225 // Get padding for the input.
2226 DenseIntElementsAttr input_padding_attr =
2227 GetReduceWindowPaddingAsAttr<num_dims>(
2228 input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2229
2230 // Count the 1's in each window, using the same padding as for the input,
2231 // which gives us the window counts by which `pooled` needs to be divided.
2232 auto divisor = rewriter.create<ReduceWindowOp>(
2233 loc, pooled_type,
2234 /*operand=*/all_ones_tensor,
2235 /*init_value=*/zero,
2236 /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2237 /*window_strides=*/GetI64ElementsAttr(op.strides()),
2238 /*base_dilations=*/DenseIntElementsAttr(),
2239 /*window_dilations=*/DenseIntElementsAttr(),
2240 /*padding=*/input_padding_attr);
2241 BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
2242
2243 // Divide `pooled` by window counts.
2244 result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled, divisor);
2245 }
2246 return result;
2247 }
2248
Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
2251
2252 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
2253 // dimensions with add as the reduction function. The reduction result is
2254 // then divided by the number of elements in the window.
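//
// A rough sketch of the lowering (illustrative; the optional widening to the
// sum accumulation type is elided):
//
//   %sum = "mhlo.reduce_window"(%input, %zero) { ... mhlo.add body ... }
//   %avg = %sum divided by the per-window element count (a constant for
//          'VALID' padding, a ReduceWindow-computed count tensor for 'SAME').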
2255 template <typename OpTy, int num_dims>
2256 class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
2257 public:
2258 using OpRewritePattern<OpTy>::OpRewritePattern;
2259
  LogicalResult matchAndRewrite(OpTy op,
2261 PatternRewriter &rewriter) const override {
2262 Value input_value = GetAvgPoolInput(op);
2263 auto input_type =
2264 input_value.getType().template dyn_cast<RankedTensorType>();
2265 if (!input_type) return failure();
2266
2267 // We will do accumulation first; use a larger bitwidth if suitable.
2268 Type input_element_type = input_type.getElementType();
2269 Type sum_element_type = GetSumAccumulationType(input_element_type);
2270 Type result_type;
2271
2272 // The result type for reduction and division with the proper element type.
2273 if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
2274 result_type =
2275 RankedTensorType::get(ranked_type.getShape(), sum_element_type);
2276 else
2277 result_type = UnrankedTensorType::get(sum_element_type);
2278
    // Convert if we need to enlarge the element type's bitwidth.
2280 if (input_element_type != sum_element_type)
2281 input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
2282 sum_element_type);
2283
    // Create the mhlo ReduceWindow op.
2285 Value init =
2286 GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
2287 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2288 input_type.getShape(), op.ksize(), op.strides(), op.padding(),
2289 &rewriter);
2290 auto reduce = rewriter.create<ReduceWindowOp>(
2291 op.getLoc(), result_type, input_value, init,
2292 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2293 /*base_dilations=*/DenseIntElementsAttr(),
2294 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2295 BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
2296
    // Count the number of elements in the window. The following calculation
    // is only valid when there is no padding.
2299 SmallVector<int64_t, num_dims> input_shape(
2300 llvm::to_vector<num_dims>(input_type.getShape()));
2301 SmallVector<int64_t, num_dims> ksize, strides;
2302 GetI64ArrayAttrValues(op.ksize(), &ksize);
2303 GetI64ArrayAttrValues(op.strides(), &strides);
2304
2305 Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
2306 reduce.getResult(), input_shape, ksize, strides, op, init, rewriter);
2307
2308 // Convert back if we enlarged the element type's bitwidth.
2309 Value result = result_op->getOpResult(0);
2310 if (input_element_type != sum_element_type)
2311 result =
2312 rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
2313
2314 rewriter.replaceOp(op, result);
2315 return success();
2316 }
2317 };
2318
2319 using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
2320 using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
2321
2322 // `AvgPoolGradOp` is converted to the following operations:
2323 // 1. Divide each entry of the output gradient (the gradient for the previous
2324 // layer in backpropagation order) by the count of the corresponding window
2325 // (i.e., the number of non-padding entries of the window which `AvgPool`
2326 // has mapped to this entry in forward propagation).
2327 // 2. Add appropriate interior and exterior padding for step 3 (see example
2328 // below).
2329 // 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
2330 // as windows) and stride 1 in each dimension. This is implemented as a
2331 // `ReduceWindowOp` with `AddOp` as body.
2332 //
2333 // Example:
2334 // Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
2335 // and SAME padding with 0's. It is defined by
2336 // f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) )
2337 // [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding)
2338 // Note that for SAME padding in `AvgPool` the padded entries are not counted
2339 // for the average, this is why the second denominator is 2 and not 3.
2340 // The Jacobian Df is
2341 // [ 1/3 1/3 1/3 0 ]
2342 // [ 0 0 1/2 1/2 ]
2343 //
2344 // Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
2345 // needs the original input shape and not the tensor as argument).
2346 // Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the
2347 // average pool gradient is given by
2348 // Df^T * v = [ 4/3 4/3 13/3 3 ]^T
2349 // Instead of a matrix-vector-multiplication we can utilize the sparsity and
2350 // structure of Df by using the 3-step approach from above:
2351 // 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T
2352 // 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T
// 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 13/3 3 ]^T
2354 //
2355 // Note that the padding in step 2. is chosen in such a way that the subsequent
2356 // convolution produces the gradient. Higher dimensions, different padding, and
2357 // different windows/strides work in a similar way, the main difference is in
2358 // the computation of the paddings in step 2.
2359 //
2360 // For more details on backpropagation for convolution of which `AvgPoolGrad`
2361 // is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
2362 // `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
2363 // examples for different cases.
2364 template <typename OpTy, int num_dims>
2365 class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
2366 using DimVector = SmallVector<int64_t, num_dims>;
2367
2368 public:
2369 using OpRewritePattern<OpTy>::OpRewritePattern;
2370
  LogicalResult matchAndRewrite(OpTy op,
2372 PatternRewriter &rewriter) const override {
2373 Location loc = op.getLoc();
2374 tensorflow::TensorFormat data_format;
2375 if (!FormatFromString(op.data_format().str(), &data_format)) {
2376 return failure();
2377 }
2378 // `out_grad` is the gradient that was propagated via backpropagation from
2379 // the output layer.
2380 Value out_grad = op.grad();
2381 auto out_grad_type =
2382 out_grad.getType().template dyn_cast<RankedTensorType>();
2383 if (!out_grad_type) {
2384 return failure();
2385 }
2386 Type element_type = out_grad_type.getElementType();
2387 DenseIntElementsAttr orig_input_shape_attr;
2388 if (!matchPattern(op.orig_input_shape(),
2389 m_Constant(&orig_input_shape_attr))) {
2390 return failure();
2391 }
2392 auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
2393 DimVector orig_input_shape(orig_input_shape_values.begin(),
2394 orig_input_shape_values.end());
2395 DimVector ksize, strides;
2396 GetI64ArrayAttrValues(op.ksize(), &ksize);
2397 GetI64ArrayAttrValues(op.strides(), &strides);
2398 Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
2399
2400 auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
2401 out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
2402
2403 // Get same padding as for original input.
2404 PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
2405 orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2406
2407 // Add padding around `out_grad_divided` values in such a way that the
2408 // subsequent `ReduceWindowOp` produces the gradient.
2409 DimVector out_grad_shape(
2410 llvm::to_vector<num_dims>(out_grad_type.getShape()));
2411 DimVector low_padding(num_dims, 0);
2412 DimVector high_padding(num_dims, 0);
2413 DimVector interior_padding(num_dims, 0);
2414 constexpr int num_spatial_dims = num_dims - 2;
2415 for (int i = 0; i < num_spatial_dims; ++i) {
2416 int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
2417 int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
2418 orig_padding[dim].first +
2419 orig_padding[dim].second;
2420 // Set interior padding such that neighboring entries from
2421 // `out_grad_divided` have distance `strides[dim]` from each other in
2422 // every dimension.
2423 interior_padding[dim] = strides[dim] - 1;
2424 // Set exterior padding in the same way as for convolution gradient
2425 // computation.
2426 auto status = ::xla::ConvGradExtractAndVerifyDimension(
2427 /*input_size=*/orig_input_shape_padded_in_dim,
2428 /*filter_size=*/ksize[dim],
2429 /*output_size=*/out_grad_shape[dim],
2430 /*dilation=*/1,
2431 /*stride=*/strides[dim],
2432 /*padding=*/::xla::Padding::kValid);
2433 if (!status.ok()) {
2434 return failure();
2435 }
2436 ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
2437 status.ValueOrDie();
2438 // Subtract the original exterior padding since it doesn't contribute to
2439 // the gradient. Note that we save one `PadOp` and some unnecessary kernel
2440 // computations, compared to the `xla::AvgPoolGrad` implementation, by
2441 // subtracting the original exterior padding before `ReduceWindowOp`
2442 // instead of trimming the result of `ReduceWindowOp` (the final result is
2443 // the same because all strides are 1).
2444 low_padding[dim] =
2445 conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
2446 high_padding[dim] =
2447 conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
2448
2449 // Update `out_grad_shape` to result shape of following `PadOp`.
2450 out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
2451 (out_grad_shape[dim] - 1) * strides[dim] + 1;
2452 }
2453 Value reduce_window_input = rewriter.create<PadOp>(
2454 loc, RankedTensorType::get(out_grad_shape, element_type),
2455 /*operand=*/out_grad_divided->getOpResult(0),
2456 /*padding_value=*/zero,
2457 /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
2458 /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
2459 /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
2460
2461 // Compute result by convolving `reduce_window_input` with an all-ones
2462 // kernel, using `ReduceWindowOp` with `AddOp` body.
2463
2464 Type sum_element_type = GetSumAccumulationType(element_type);
2465 if (element_type != sum_element_type) {
2466 // Convert to appropriate sum accumulation type to avoid precision loss.
2467 reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
2468 sum_element_type);
2469 zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
2470 }
2471 auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
2472 auto reduce_window_op = rewriter.create<ReduceWindowOp>(
2473 loc, RankedTensorType::get(orig_input_shape, sum_element_type),
2474 /*operand=*/reduce_window_input,
2475 /*init_value=*/zero,
2476 /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2477 /*window_strides=*/ones,
2478 /*base_dilations=*/DenseIntElementsAttr(),
2479 /*window_dilations=*/DenseIntElementsAttr(),
2480 /*padding=*/DenseIntElementsAttr());
2481 BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
2482 &rewriter);
2483 Value result = reduce_window_op.getResult();
2484
2485 if (element_type != sum_element_type) {
2486 // Convert back to original element type.
2487 result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
2488 }
2489 rewriter.replaceOp(op, {result});
2490 return success();
2491 }
2492 };
2493
2494 using ConvertAvgPool2DGradOp =
2495 ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
2496 using ConvertAvgPool3DGradOp =
2497 ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
2498
2499 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
2500 // dimensions with max as the reduction function.
2501 //
2502 // Sample result for VALID padding mode:
2503 //
2504 // %init = constant dense<...> : tensor<i32>
2505 //   %max_pool = "mhlo.reduce_window"(%inp, %init) ["mhlo.maximum"]
2506 // {window_dimensions = ..., window_strides = ... }
2507 //
2508 template <typename OpTy, int num_dims>
2509 class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
2510 public:
2511 using OpRewritePattern<OpTy>::OpRewritePattern;
2512
2513   LogicalResult matchAndRewrite(OpTy op,
2514 PatternRewriter &rewriter) const override {
2515 Type element_type =
2516 op.input().getType().template cast<TensorType>().getElementType();
2517 if (!element_type.isSignlessIntOrFloat()) return failure();
2518 tensorflow::Padding padding;
2519 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
2520 return failure();
2521 if (padding == tensorflow::Padding::EXPLICIT) {
2522 return failure();
2523 }
2524 Location loc = op.getLoc();
2525 ConstOp init = GetScalarLimitConstOfType(element_type, loc,
2526 hlo::kInfinityLowest, &rewriter);
2527
2528 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
2529 if (!input_ty) return failure();
2530 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2531 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
2532 auto reduce = rewriter.create<ReduceWindowOp>(
2533 loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
2534 GetI64ElementsAttr(op.strides()),
2535 /*base_dilations=*/DenseIntElementsAttr(),
2536 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2537 BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
2538
2539 rewriter.replaceOp(op, reduce.getResult());
2540 return success();
2541 }
2542 };
2543
2544 using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
2545 using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
2546
2547 // Converts SelectV2 to HLO Select op and necessary BroadcastInDim ops on
2548 // operands.
2549 //
2550 // For example, the following source IR:
2551 //
2552 // %select = "tf.SelectV2"(%condition, %t, %e) :
2553 // (tensor<1xi1>, tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
2554 //
2555 // will be converted into:
2556 //
2557 // %pred = "mhlo.broadcast_in_dim"(%cond)
2558 // {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
2559 // (tensor<1xi1>) -> tensor<2xi1>
2560 // %on_false = "mhlo.broadcast_in_dim"(%e)
2561 // {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
2562 // (tensor<1xi32>) -> tensor<2xi32>
2563 // %select = "mhlo.select"(%pred, %t, %on_false) :
2564 // (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
2565 class ConvertSelectV2Op : public OpRewritePattern<TF::SelectV2Op> {
2566 public:
2567 using OpRewritePattern::OpRewritePattern;
2568
2569   LogicalResult matchAndRewrite(TF::SelectV2Op op,
2570 PatternRewriter &rewriter) const override {
2571 llvm::SmallVector<int64_t, 4> broadcast_then_else_shape;
2572 auto ranked_then_type = op.t().getType().dyn_cast<RankedTensorType>();
2573 auto ranked_else_type = op.e().getType().dyn_cast<RankedTensorType>();
2574 auto ranked_cond_type =
2575 op.condition().getType().dyn_cast<RankedTensorType>();
2576 if (!ranked_then_type || !ranked_then_type.hasStaticShape() ||
2577 !ranked_else_type || !ranked_else_type.hasStaticShape() ||
2578 !ranked_cond_type || !ranked_cond_type.hasStaticShape())
2579 return failure();
2580
2581 if (!OpTrait::util::getBroadcastedShape(ranked_then_type.getShape(),
2582 ranked_else_type.getShape(),
2583 broadcast_then_else_shape))
2584 return failure();
2585
2586 llvm::SmallVector<int64_t, 4> broadcast_shape;
2587 if (!OpTrait::util::getBroadcastedShape(broadcast_then_else_shape,
2588 ranked_cond_type.getShape(),
2589 broadcast_shape))
2590 return failure();
2591
2592 auto broadcast_or_self = [&](Value value) {
2593 RankedTensorType type = value.getType().cast<RankedTensorType>();
2594 auto output_type =
2595 RankedTensorType::get(broadcast_shape, type.getElementType());
2596 if (output_type == type) return value;
2597
2598 int64_t rank = type.getRank();
2599 SmallVector<int64_t, 4> broadcast_dimensions(rank);
2600 std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(),
2601 broadcast_shape.size() - rank);
2602
2603 return rewriter
2604 .create<BroadcastInDimOp>(
2605 op.getLoc(), output_type, value,
2606 GetI64ElementsAttr(broadcast_dimensions, &rewriter))
2607 .getResult();
2608 };
2609
2610 // HLO SelectOp supports broadcasting for predicate/condition if
2611 // predicate/condition is a scalar.
2612 Value pred = ranked_cond_type.getRank() == 0
2613 ? op.condition()
2614 : broadcast_or_self(op.condition());
2615 Value on_true = broadcast_or_self(op.t());
2616 Value on_false = broadcast_or_self(op.e());
2617
2618 rewriter.replaceOpWithNewOp<SelectOp>(op, on_true.getType(), pred, on_true,
2619 on_false);
2620
2621 return success();
2622   }
2623 };
2624
2625 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
2626 //
2627 // sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
2628 //
2629 // Sample result with 2-d f16 inputs with B batches of N elements each.
2630 //
2631 // // Create an array of 0.5 the shape of the input array.
2632 // %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
2633 // %half_array = "mhlo.broadcast"(half)
2634 // {broadcast_sizes = dense<2> : tensor<1xi64>}
2635 // : (tensor<f32>) -> tensor<2xf32>
2636 //
2637 // // Compute Tanh of half the logits of the values.
2638 // %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
2639 // %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
2640 //
2641 //   // Halve the result of the Tanh and add 0.5.
2642 // %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
2643 // %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
2644 //
2645 class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
2646 public:
2647 using OpRewritePattern::OpRewritePattern;
2648
2649   LogicalResult matchAndRewrite(TF::SigmoidOp op,
2650 PatternRewriter &rewriter) const override {
2651 Location loc = op.getLoc();
2652
2653 // Create constant half with shape and element type same as the operand.
2654 Value operand = op.getOperand();
2655 auto operand_ty = operand.getType().cast<TensorType>();
2656 auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType());
2657 ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5);
2658 auto scalar_half = rewriter.create<ConstOp>(loc, attr);
2659 auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter);
2660
2661 auto scaled_input = rewriter.create<MulOp>(loc, operand, half);
2662 auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input);
2663 auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half);
2664 auto add_op = rewriter.create<AddOp>(loc, mul_op, half);
2665
2666 rewriter.replaceOp(op, add_op.getResult());
2667 return success();
2668 }
2669 };
2670
2671 // Converts Softmax and LogSoftmax to HLO ops, computing softmax with the
2672 // following formulas:
2673 //
2674 // softmax = div(exp(logits), sum(exp(logits)))
2675 //
2676 // log_softmax = sub(logits, log(sum(exp(logits))))
2677 //
2678 // Sample result with 2-d f16 inputs with B batches of N elements each.
2679 //
2680 // %reduce_dim = tf.Const dense<[1]> : tensor<1xi64>
2681 //
2682 // // Subtract each element by their batches' max to improve numerical
2683 // // stability.
2684 // %max = "tf.Max"(%input, %reduce_dim)
2685 // : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
2686 // %sub = "mhlo.subtract"(%inp, %max) {broadcast_dimensions = 0}
2687 // : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
2688 //
2689 // %exp = "mhlo.exponential"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
2690 //   %sum = "tf.Sum"(%exp, %reduce_dim)
2691 //          : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
2692 //
2693 //   // Softmax computation:
2694 //   %softmax = "mhlo.divide"(%exp, %sum) {broadcast_dimensions = 0}
2695 //          : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
2696 template <typename OpTy, bool use_log = true>
2697 class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
2698 public:
2699 using OpRewritePattern<OpTy>::OpRewritePattern;
2700
2701   LogicalResult matchAndRewrite(OpTy op,
2702 PatternRewriter &rewriter) const override {
2703     // The Softmax converter requires a ranked type because the XLA reduce ops
2704     // used during lowering require a dimensions attribute to reduce along.
2705     // Note that the input and output shapes are equivalent, so we use 'logits'
2706     // and its type for shape calculations.
2707 Value logits = op.logits();
2708 RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
2709 if (!type) return failure();
2710 auto loc = op.getLoc();
2711 int rank = type.getRank();
2712
2713 // Note that the TensorFlow Softmax op verifies that the input rank is
2714 // greater than or equal to one so the following sequence is valid.
2715 auto reduce_dim = rewriter.create<TF::ConstOp>(
2716 loc, GetI64ElementsAttr({rank - 1}, &rewriter));
2717
2718 // Exponential of input values and then their sum can be very large here.
2719 // Division with large denominator is numerically unstable. To improve
2720 // numerical stability, subtract each batch with their max element so that
2721 // the maximum input value is zero. It can be shown that softmax computed
2722 // after adding or subtracting all inputs in a batch using a common value
2723 // gives mathematically equivalent result.
2724 auto max_logits =
2725 rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
2726 /*keep_dims=*/rewriter.getBoolAttr(false));
2727 auto max_logits_broadcast =
2728 CommonPrefixBroadcast(loc, logits, max_logits, rewriter);
2729 auto shifted_logits =
2730 rewriter.create<mhlo::SubOp>(loc, type, logits, max_logits_broadcast);
2731
2732 // Exponentiate the inputs.
2733 Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
2734
2735 // Compute summation of the exponentials.
2736 auto exp_sum =
2737 rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
2738 /*keep_dims=*/rewriter.getBoolAttr(false));
2739 Value sum = exp_sum.getResult();
2740
2741 if (use_log) {
2742 Value log = rewriter.create<LogOp>(loc, sum);
2743 auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter);
2744 rewriter.replaceOpWithNewOp<mhlo::SubOp>(op, shifted_logits,
2745 log_broadcast);
2746 } else {
2747 auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter);
2748 rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, exp, sum_broadcast);
2749 }
2750 return success();
2751 }
2752 };
2753
2754 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
2755 Value *out_lhs, Value *out_rhs,
2756 PatternRewriter *rewriter) {
2757 // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
2758 // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
2759 // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
2760 // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
2761 // To perform the matmul, we need to first broadcast lhs and rhs to a common
2762 // set of leading dimensions before doing the actual matmul.
2763 // That's what the code below does.
2764 // In particular, we populate out_lhs and out_rhs to have dimension structure:
2765 // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
2766 // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
2767 // To do this, we need to calculate those output shapes, which involves
2768 // slicing off the leading batch dims of each operand, broadcasting them,
2769 // then concatenating the broadcasted leading dims back to the row/col dims.
2770 // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
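  // For illustration only (shapes are hypothetical): with
  //   lhs : tensor<2x1x3x4xf32> and rhs : tensor<5x4x6xf32>
  // the broadcasted batch dims are [2, 5], so this populates
  //   out_lhs : tensor<2x5x3x4xf32> and out_rhs : tensor<2x5x4x6xf32>.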
2771
2772 // TODO(silvasean): Reduce duplication across reified shape calculations and
2773 // the static computation of output types needed to create ops.
2774 Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
2775 Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
2776 Value const_neg2 =
2777 rewriter->create<ConstantOp>(loc, rewriter->getIndexAttr(-2));
2778 auto lhs_splitted =
2779 rewriter->create<shape::SplitAtOp>(loc, lhs_shape, const_neg2);
2780 auto rhs_splitted =
2781 rewriter->create<shape::SplitAtOp>(loc, rhs_shape, const_neg2);
2782 auto lhs_type = lhs.getType().cast<RankedTensorType>();
2783 auto rhs_type = rhs.getType().cast<RankedTensorType>();
2784 // The last two dimensions are the matrix row/col dimensions. Don't broadcast
2785 // them.
2786 SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
2787 OpTrait::util::getBroadcastedShape(lhs_type.getShape().drop_back(2),
2788 rhs_type.getShape().drop_back(2),
2789 result_batch_shape_compile_time_extents);
2790 auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
2791 loc, shape::ShapeType::get(rewriter->getContext()), lhs_splitted.head(),
2792 rhs_splitted.head(),
2793 /*error=*/nullptr);
2794 // Lambda which handles the broadcasting of one side to the common
2795 // leading-batch dimensions.
2796 auto broadcast_one_side = [&](Value side, RankedTensorType type,
2797 Value tail_shape, Value *out_side) {
2798 ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
2799 auto result_shape = result_batch_shape_compile_time_extents;
2800 result_shape.append(matrix_dims.begin(), matrix_dims.end());
2801 auto result_type =
2802 RankedTensorType::get(result_shape, type.getElementType());
2803 auto shape =
2804 rewriter->create<shape::ConcatOp>(loc, result_batch_shape, tail_shape);
2805 auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
2806 loc,
2807 RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
2808 rewriter->getIndexType()),
2809 shape);
2810 *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
2811 shape_tensor);
2812 };
2813 broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs);
2814 broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs);
2815 }
2816
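// Converts tf.BatchMatMulV2 to mhlo.dot_general after broadcasting both
// operands to a common set of leading batch dimensions (see
// BroadcastBatchMatMulV2Operands above). As a hedged sketch only, assuming
// adj_x = adj_y = false and rank-3 operands after broadcasting, the result is
// roughly:
//
//   %dot = "mhlo.dot_general"(%broadcasted_lhs, %broadcasted_rhs)
//       {dot_dimension_numbers = {lhs_batching_dimensions = [0],
//                                 rhs_batching_dimensions = [0],
//                                 lhs_contracting_dimensions = [2],
//                                 rhs_contracting_dimensions = [1]}}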
2817 class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
2818 public:
2819   // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved
2820   // to CHLO and it is missing legalization to MHLO. Once that is done, this
2821   // pattern's benefit can be changed back to one, and the fallback lowering
2822   // pattern for the op can be removed.
2823 //
2824 // Set benefit of this pattern to zero to prefer the fallback pattern when
2825 // available and applicable. That pattern avoids broadcast on operands and is
2826 // therefore faster.
2827   explicit ConvertBatchMatMulV2Op(MLIRContext *context)
2828 : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
2829
2830   LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
2831 PatternRewriter &rewriter) const override {
2832 Value lhs = op.x();
2833 Value rhs = op.y();
2834 auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
2835 auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
2836 if (!lhs_type || !rhs_type) return failure();
2837 if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
2838 lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
2839 }
2840 if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
2841 rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
2842 }
2843
2844 // Broadcast both operands.
2845 BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
2846 &rewriter);
2847 lhs_type = lhs.getType().cast<RankedTensorType>();
2848 rhs_type = rhs.getType().cast<RankedTensorType>();
2849 assert(lhs_type.getRank() == rhs_type.getRank());
2850 int64_t rank = lhs_type.getRank();
2851 auto batch_dimensions = GetI64ElementsAttr(
2852 llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
2853 auto lhs_contracting_dimensions = GetI64ElementsAttr(
2854 llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
2855 auto rhs_contracting_dimensions = GetI64ElementsAttr(
2856 llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
2857 auto dimension_numbers = DotDimensionNumbers::get(
2858 /*lhs_batching_dimensions=*/batch_dimensions,
2859 /*rhs_batching_dimensions=*/batch_dimensions,
2860 /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
2861 /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
2862 rewriter.getContext());
2863 // TODO(silvasean): Emit shape checks for contracting dimensions.
2864 // (The batch dimensions are checked by the broadcasting logic)
2865 rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
2866 dimension_numbers,
2867 /*precision_config=*/nullptr);
2868 return success();
2869 }
2870 };
2871
2872 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
2873 // split has fully static shape and the dimension to split is a constant.
2874 //
2875 // The main logic of this pattern is to calculate the index start and end range
2876 // for each slice. This happens only along the dimension to be split; for all
2877 // other dimensions, every resultant slice's index range covers the input
2878 // tensor's full range. Strides for all resultant slices are one.
2879 //
2880 // For example, the following source IR:
2881 //
2882 // %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
2883 // %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
2884 // (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
2885 //
2886 // will be converted into:
2887 //
2888 // %0 = "mhlo.slice"(%input) {
2889 // limit_indices = dense<[4, 2]> : tensor<2xi64>,
2890 // start_indices = dense<0> : tensor<2xi64>,
2891 // strides = dense<1> : tensor<2xi64>} :
2892 // (tensor<4x6xf32>) -> tensor<4x2xf32>
2893 // %1 = "mhlo.slice"(%input) {
2894 // limit_indices = dense<4> : tensor<2xi64>,
2895 // start_indices = dense<[0, 2]> : tensor<2xi64>,
2896 // strides = dense<1> : tensor<2xi64>} :
2897 // (tensor<4x6xf32>) -> tensor<4x2xf32>
2898 // %2 = "mhlo.slice"(%input) {
2899 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
2900 // start_indices = dense<[0, 4]> : tensor<2xi64>,
2901 // strides = dense<1> : tensor<2xi64>} :
2902 // (tensor<4x6xf32>) -> tensor<4x2xf32>
2903 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
2904 // applicable.
2905 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
2906 public:
2907 using OpRewritePattern::OpRewritePattern;
2908
2909   LogicalResult matchAndRewrite(TF::SplitOp op,
2910 PatternRewriter &rewriter) const override {
2911 // We can only split along static dimensions.
2912 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2913 if (!input_type) return failure();
2914
2915 // We can only match when the split dimension is a constant scalar.
2916 DenseIntElementsAttr split_dim_attr;
2917 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
2918 return failure();
2919
2920 // Get the dimension we are splitting at. Offset properly if it's negative.
2921 int64_t input_rank = input_type.getRank();
2922 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
2923 if (dim_index < 0) dim_index += input_rank;
2924
2925 // Calculate the dimension size for each slice along the split dimension.
2926 int64_t input_dim_size = input_type.getDimSize(dim_index);
2927 // If we are splitting along the dynamic dimension then we cannot compute
2928 // the static dimension length.
2929 if (TensorType::isDynamic(input_dim_size)) return failure();
2930
2931 int64_t num_splits = op.getNumResults();
2932 int64_t slice_size = input_dim_size / num_splits;
2933
2934 // Get each slice's type.
2935 auto slice_shape = llvm::to_vector<4>(input_type.getShape());
2936 slice_shape[dim_index] = slice_size;
2937 Type slice_type =
2938 RankedTensorType::get(slice_shape, input_type.getElementType());
2939
2940 // Parameters for constructing each slice.
2941 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
2942 auto end_indices = llvm::to_vector<4>(input_type.getShape());
2943 SmallVector<int64_t, 4> strides(input_rank, 1);
2944
2945 // All HLO slice results used to replace the original tf.Split op.
2946 SmallVector<Value, 4> slices;
2947 slices.reserve(num_splits);
2948
2949 for (int i = 0; i < num_splits; ++i) {
2950 begin_indices[dim_index] = i * slice_size;
2951 end_indices[dim_index] = (i + 1) * slice_size;
2952 slices.push_back(
2953 rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
2954 GetI64ElementsAttr(begin_indices, &rewriter),
2955 GetI64ElementsAttr(end_indices, &rewriter),
2956 GetI64ElementsAttr(strides, &rewriter)));
2957 }
2958
2959 rewriter.replaceOp(op, slices);
2960 return success();
2961 }
2962 };
2963
2964 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
2965 // be split has fully static shape and the dimension to split and split sizes
2966 // are constants.
2967 //
2968 // This is similar to the conversion for tf.Split op other than that the size of
2969 // each chunk on the dimension to split is explicitly given as an op operand
2970 // and they are not necessarily the same.
2971 //
2972 // For example, given the following IR:
2973 //
2974 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
2975 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
2976 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
2977 // (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
2978 // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
2979 //
2980 // We will generate the following slices:
2981 // %0 = "mhlo.slice"(%input) {
2982 // limit_indices = dense<[4, 1]> : tensor<2xi64>,
2983 // start_indices = dense<0> : tensor<2xi64>,
2984 // strides = dense<1> : tensor<2xi64>} :
2985 // (tensor<4x6xf32>) -> tensor<4x1xf32>
2986 // %1 = "mhlo.slice"(%input) {
2987 // limit_indices = dense<[4, 3]> : tensor<2xi64>,
2988 // start_indices = dense<[0, 1]> : tensor<2xi64>,
2989 // strides = dense<1> : tensor<2xi64>} :
2990 // (tensor<4x6xf32>) -> tensor<4x2xf32>
2991 // %2 = "mhlo.slice"(%input) {
2992 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
2993 // start_indices = dense<[0, 3]> : tensor<2xi64>,
2994 // strides = dense<1> : tensor<2xi64>} :
2995 // (tensor<4x6xf32>) -> tensor<4x3xf32>
2996 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
2997 public:
2998 using OpRewritePattern::OpRewritePattern;
2999
3000   LogicalResult matchAndRewrite(TF::SplitVOp op,
3001 PatternRewriter &rewriter) const override {
3002 // We can only split along static dimensions.
3003 // TODO(b/145731001): enhance to support dynamic-shaped inputs.
3004 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3005 if (!input_type) return failure();
3006
3007 // We can only match when the split dimension is a constant scalar.
3008 DenseIntElementsAttr split_dim_attr;
3009 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3010 return failure();
3011
3012 // We can only match when the split sizes is a constant int vector.
3013 DenseIntElementsAttr split_sizes_attr;
3014 if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
3015 return failure();
3016
3017     // Get each chunk's size along the dimension to split. It may contain a
3018     // dynamic size, and we need to update it if so.
3019 SmallVector<int64_t, 4> split_sizes;
3020 int64_t total_dim_size = 0; // Total dimension size assigned to splits
3021 llvm::Optional<int> dynamic_dim_index;
3022 split_sizes.reserve(
3023 split_sizes_attr.getType().cast<ShapedType>().getNumElements());
3024 for (auto dim : llvm::enumerate(split_sizes_attr)) {
3025 int64_t dim_val = dim.value().getSExtValue();
3026 split_sizes.push_back(dim_val);
3027 if (dim_val == ShapedType::kDynamicSize) {
3028 // We cannot have more than one dynamic dimension.
3029 assert(!dynamic_dim_index && "invalid split sizes");
3030 dynamic_dim_index = dim.index();
3031 } else {
3032 total_dim_size += dim_val;
3033 }
3034 }
3035
3036 // Get the dimension we are splitting at. Offset properly if it's negative.
3037 int64_t input_rank = input_type.getRank();
3038 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3039 if (dim_index < 0) dim_index += input_rank;
3040
3041 int64_t input_dim_size = input_type.getDimSize(dim_index);
3042 if (TensorType::isDynamic(input_dim_size)) return failure();
3043
3044 assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
3045 (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
3046 "invalid split sizes");
3047
3048 // Update the dynamic dimension with calculated concrete size.
3049 if (dynamic_dim_index)
3050 split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
3051
3052 // Parameters for constructing each slice.
3053 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3054 auto end_indices = llvm::to_vector<4>(input_type.getShape());
3055 SmallVector<int64_t, 4> strides(input_rank, 1);
3056
3057     // All HLO slice results used to replace the original tf.SplitV op.
3058 SmallVector<Value, 4> slices;
3059 slices.reserve(op.getNumResults());
3060
3061 for (int i = 0, end = op.getNumResults(); i < end; ++i) {
3062 end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
3063 slices.push_back(rewriter.create<mhlo::SliceOp>(
3064 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3065 GetI64ElementsAttr(end_indices, &rewriter),
3066 GetI64ElementsAttr(strides, &rewriter)));
3067       // Prepare the begin index for the next slice.
3068 begin_indices[dim_index] = end_indices[dim_index];
3069 }
3070
3071 rewriter.replaceOp(op, slices);
3072 return success();
3073 }
3074 };
3075
3076 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
3077 // negative strides and Reshape op to update the output shape. Indices and
3078 // strides operands are converted to attributes with non-negative indexing.
3079 //
3080 // If the begin input is not a compile time constant, the begin input needs to
3081 // be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
3082 // case, strides must have a known value of 1 (otherwise we have insufficient
3083 // information to conform to XLA's op semantics).
3084 //
3085 // For example with an op like following,
3086 // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
3087 // : tensor<AxBxf32> -> tensor<Pxf32>
3088 //
3089 // If the %begin input is constant, output would be:
3090 // %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
3091 // %sliced = "mhlo.Slice" (%input)
3092 // {start_indices = ..., limit_indices = ..., strides = ...}
3093 // %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
3094 //
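// If the %begin input is not constant (and all strides are 1), a hedged sketch
// of the output is (value names are illustrative only):
//   %begin_d = ...   // scalar begin index per sliced major dimension,
//                    // wrapped around when negative
//   %sliced = "mhlo.dynamic-slice"(%input, %begin_0, ..., %begin_k)
//       {slice_sizes = ...}
//   %output = "mhlo.reshape"(%sliced) : tensor<1x...xPxf32> -> tensor<Pxf32>
//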
3095 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
3096 public:
3097 using OpRewritePattern::OpRewritePattern;
3098
3099   LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
3100 ArrayRef<int64_t> begin_indices,
3101 ArrayRef<int64_t> end_indices,
3102 ArrayRef<int64_t> strides,
3103 RankedTensorType input_ty,
3104 PatternRewriter &rewriter) const {
3105 SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
3106 dims_to_reverse;
3107 int64_t input_rank = input_ty.getRank();
3108 ArrayRef<int64_t> input_shape = input_ty.getShape();
3109 hlo_begin_indices.reserve(input_rank);
3110 hlo_end_indices.reserve(input_rank);
3111 hlo_strides.reserve(input_rank);
3112
3113 int64_t indices_elements = begin_indices.size();
3114 if (input_rank < indices_elements) return failure();
3115
3116 // Convert from TensorFlow negative or out of range indices and strides
3117 // values to legal HLO Slice attributes.
3118 for (int i = 0, e = indices_elements; i != e; i++) {
3119 int64_t begin = begin_indices[i];
3120 int64_t end = end_indices[i];
3121 int64_t stride = strides[i];
3122
3123 if (stride < 0) {
3124 // Negative stride means that the output values are computed starting
3125 // from end until begin. Mark the dimension for reversal before slice
3126 // and compute indices for the reversed input.
3127 dims_to_reverse.push_back(i);
3128 begin = (input_shape[i] - 1) - begin;
3129 end = (input_shape[i] - 1) - end;
3130 stride = -stride;
3131 }
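      // For illustration (hypothetical values): with input_shape[i] = 5,
      // begin = 4, end = 0, stride = -1, the dimension is marked for reversal
      // and the indices become begin = 0, end = 4, stride = 1, selecting the
      // same elements in reversed order.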
3132
3133 // Unlike TensorFlow, HLO requires begin and end values to be within
3134 // range.
3135 begin = std::max(int64_t(0), begin);
3136 end = std::max(begin, end);
3137 end = std::min(end, input_shape[i]);
3138
3139 hlo_begin_indices.push_back(begin);
3140 hlo_end_indices.push_back(end);
3141 hlo_strides.push_back(stride);
3142 }
3143
3144 Location loc = op.getLoc();
3145 Value input = op.input();
3146 if (!dims_to_reverse.empty())
3147 input = rewriter.create<ReverseOp>(
3148 loc, input_ty, op.input(),
3149 GetI64ElementsAttr(dims_to_reverse, &rewriter));
3150 auto sliced = rewriter.create<SliceOp>(
3151 loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
3152 GetI64ElementsAttr(hlo_end_indices, &rewriter),
3153 GetI64ElementsAttr(hlo_strides, &rewriter));
3154
3155 // Reshape slice result so that the shape is updated depending on
3156 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3157 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3158 return success();
3159 }
3160
3161   LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
3162 RankedTensorType input_ty,
3163 RankedTensorType result_ty,
3164 PatternRewriter &rewriter) const {
3165 // If begin and end values are dynamic, we can only support this lowering
3166 // if strides are a known value of 1.
3167 DenseIntElementsAttr sparse_strides_attr;
3168 if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
3169 return rewriter.notifyMatchFailure(
3170 op,
3171 "requires that strides are known when begin/end values are dynamic");
3172 }
3173 SmallVector<int64_t, 4> strides;
3174 int64_t stride_value;
3175 for (const APInt &stride : sparse_strides_attr) {
3176 if ((stride_value = stride.getSExtValue()) != 1) {
3177 return rewriter.notifyMatchFailure(op,
3178 "requires that strides are all 1 "
3179 "when begin/end values are dynamic");
3180 }
3181 strides.push_back(stride_value);
3182 }
3183
3184 ArrayRef<int64_t> input_shape = input_ty.getShape();
3185 int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
3186
3187 // When begin/end values are dynamic, we can only support shrinking a major
3188 // axis. For instance, if there are 4 dims, we can support a
3189 // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
3190 // other.
3191 bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
3192 if (!shrink_axis_mask_ok)
3193 return rewriter.notifyMatchFailure(
3194 op,
3195 "requires that shrink_axis_mask, if set, refer to a major axis "
3196 "dimension (when begin/end values are dynamic)");
3197
3198 // When begin/end values are dynamic, the ellipsis mask, if set, must refer
3199 // to the last dimension.
3200 int ellipsis_mask = op.ellipsis_mask();
3201 if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
3202 return rewriter.notifyMatchFailure(
3203 op,
3204 "requires that ellipsis_mask, if set, refer to the last dimension of "
3205 "input (when begin/end values are dynamic)");
3206
3207 uint64_t begin_mask = op.begin_mask();
3208 if (begin_mask)
3209 return rewriter.notifyMatchFailure(
3210 op,
3211 "requires that begin_mask is either set to 0 or not set when "
3212 "begin/end values are dynamic");
3213 uint64_t end_mask = op.end_mask();
3214 if (end_mask)
3215 return rewriter.notifyMatchFailure(
3216 op,
3217 "requires that end_mask is either set to 0 or not set when begin/end "
3218 "values are dynamic");
3219 uint64_t new_axis_mask = op.new_axis_mask();
3220 if (new_axis_mask)
3221 return rewriter.notifyMatchFailure(
3222 op,
3223 "requires that new_axis_mask is either set to 0 or not set when "
3224 "begin/end values are dynamic");
3225
3226 // In this case where the begin and end values are dynamic, the number of
3227 // output elements has to be equal to the number of input elements that
3228 // are sliced.
3229 int output_elements = result_ty.getNumElements();
3230 int input_elements_sliced = 1;
3231
3232 // Begin must be a ranked, 1-dimensional tensor: This is checked by the
3233 // verifier.
3234 int64_t slicing_dim_size =
3235 op.begin().getType().cast<RankedTensorType>().getShape()[0];
3236 const int input_rank = input_shape.size();
3237 for (int d = slicing_dim_size; d < input_rank; ++d) {
3238 // We only support slicing major dimensions, so minor dimensions after
3239 // slicing dimensions are all sliced with their full sizes.
3240 input_elements_sliced *= input_shape[d];
3241 }
3242 if (input_elements_sliced != output_elements) {
3243 return rewriter.notifyMatchFailure(
3244 op,
3245 "requires the number of output elements to be equal to the number of "
3246 "input elements sliced (when begin/end values are dynamic)");
3247 }
3248
3249 SmallVector<Value, 4> slice_begin_indices;
3250 // For the dimensions that are to be sliced, all have slice sizes of 1.
3251 SmallVector<int64_t, 4> slice_sizes(slicing_dim_size, 1);
3252 auto begin_element_ty =
3253 op.begin().getType().cast<ShapedType>().getElementType();
3254 // Scalar tensor type.
3255 TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty);
3256 Location loc = op.getLoc();
3257 auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter);
3258 for (int d = 0; d < slicing_dim_size; ++d) {
3259 auto index = rewriter.create<SliceOp>(
3260 loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
3261 GetI64ElementsAttr({d + 1}, &rewriter),
3262 GetI64ElementsAttr({1}, &rewriter));
3263 // Convert index to scalar.
3264 auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
3265 // If the index is negative, wrap it around with dimension size.
3266 auto index_negative =
3267 rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
3268 auto input_val = GetScalarConstOfType(begin_element_ty, loc,
3269 input_shape[d], &rewriter);
3270 auto wrapped_index =
3271 rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
3272 auto final_index = rewriter.create<SelectOp>(
3273 loc, type, index_negative, wrapped_index, reshaped_index);
3274 slice_begin_indices.push_back(final_index);
3275 }
3276
3277 // For non-slice dims, get the full slice of that dimension.
3278 for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) {
3279 slice_sizes.push_back(input_shape[d]);
3280 slice_begin_indices.push_back(zero);
3281 }
3282
3283 auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
3284 // This must be an xla DynamicSlice op due to the inputs that aren't
3285 // constant.
3286 auto sliced = rewriter.create<DynamicSliceOp>(
3287 loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr);
3288
3289 // Reshape slice result so that the shape is updated depending on
3290 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3291 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3292 return success();
3293 }
3294
3295   LogicalResult matchAndRewrite(TF::StridedSliceOp op,
3296 PatternRewriter &rewriter) const override {
3297 // Input shape needs to be static to convert negative indices in TensorFlow
3298 // to absolute indices required by HLO.
3299 //
3300 // TODO(hinsu): Relax this constraint for ops without negative indices and
3301 // strides.
3302 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3303 if (!input_ty || !input_ty.hasStaticShape()) return failure();
3304
3305 // Output shape needs to be static to apply 'new_axis_mask' or
3306 // 'shrink_axis_mask' by reshaping tensor after slice.
3307 //
3308 // TODO(hinsu): Relax this constraint for ops without the above masks.
3309 auto result_ty = op.getType().dyn_cast<RankedTensorType>();
3310 if (!result_ty || !result_ty.hasStaticShape()) return failure();
3311
3312 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
3313 if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
3314 !matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
3315 return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
3316 }
3317
3318 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3319 if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
3320 return failure();
3321 }
3322 return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
3323 input_ty, rewriter);
3324 }
3325 };
3326
3327 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
3328 //
3329 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
3330 // the reverse: it propagates the gradient of the sliced tensor back to the
3331 // original input tensor by padding with zeros. The main logic is calculating the
3332 // indices and strides for padding.
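//
// As a 1-D illustration (values hypothetical): if the forward tf.StridedSlice
// took elements [1, 3) of a tensor<4xf32> with stride 1, then for an incoming
// gradient %dy = [g0, g1] this pattern emits roughly
//
//   %grad = "mhlo.pad"(%dy, %zero)
//       {edge_padding_low = dense<1>, edge_padding_high = dense<1>,
//        interior_padding = dense<0>} : tensor<2xf32> -> tensor<4xf32>
//
// which places [0, g0, g1, 0] back into the original input shape.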
3333 class ConvertStridedSliceGradOp
3334 : public OpRewritePattern<TF::StridedSliceGradOp> {
3335 public:
3336 using OpRewritePattern::OpRewritePattern;
3337
3338   LogicalResult matchAndRewrite(TF::StridedSliceGradOp op,
3339 PatternRewriter &rewriter) const override {
3340 // We need constant input shape to perform padding calculations later.
3341 DenseIntElementsAttr input_shape_attr;
3342 if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
3343 return failure();
3344
3345 // We also need constant begin/end indices and strides to perform padding
3346 // calculations.
3347 // Bounded shape after performing strided slice
3348 SmallVector<int64_t, 4> shape;
3349 // Bounded begin, end, and strides for strided slice
3350 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3351 if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
3352 &strides))
3353 return failure();
3354
3355 Value grad = op.dy();
3356 Type element_type = grad.getType().cast<ShapedType>().getElementType();
3357
3358 // Perform reshape to undo any new/shrink axes done by strided slice.
3359 grad = rewriter.create<mhlo::ReshapeOp>(
3360 op.getLoc(), RankedTensorType::get(shape, element_type), grad);
3361
3362 SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
3363 SmallVector<int64_t, 4> dims_to_reverse;
3364 padding_low.reserve(shape.size());
3365 padding_high.reserve(shape.size());
3366 padding_interm.reserve(shape.size());
3367
3368 // Prepare padding parameters for each dimension.
3369 for (int i = 0, e = shape.size(); i < e; ++i) {
3370 int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
3371 if (strides[i] > 0) {
3372 padding_low.push_back(begin_indices[i]);
3373 padding_interm.push_back(strides[i] - 1);
3374
3375 // Pad the upper dimension up to the expected input shape. It's not
3376 // sufficient simply to use end_indices[i] to compute the padding in
3377 // cases where the stride does not divide evenly into the interval
3378 // between begin_indices[i] and end_indices[i].
3379 int64_t size =
3380 padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3381 padding_high.push_back(input_dim - size);
3382 } else {
3383 dims_to_reverse.push_back(i);
3384 padding_high.push_back(input_dim - begin_indices[i] - 1);
3385 padding_interm.push_back(-strides[i] - 1);
3386
3387 // Pad the lower dimension up to the expected input shape.
3388 int64_t size =
3389 padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3390 padding_low.push_back(input_dim - size);
3391 }
3392 }
3393
3394 if (!dims_to_reverse.empty()) {
3395 grad = rewriter.create<mhlo::ReverseOp>(
3396 op.getLoc(), grad.getType(), grad,
3397 GetI64ElementsAttr(dims_to_reverse, &rewriter));
3398 }
3399
3400 auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
3401 rewriter.replaceOpWithNewOp<mhlo::PadOp>(
3402 op, op.getType(), grad, zero,
3403 GetI64ElementsAttr(padding_low, &rewriter),
3404 GetI64ElementsAttr(padding_high, &rewriter),
3405 GetI64ElementsAttr(padding_interm, &rewriter));
3406 return success();
3407 }
3408 };
3409
3410 /// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
3411 /// offset applied to generate the range values. The output tensor needs to
3412 /// have a static shape.
3413 ///
3414 /// For example an op like the following:
3415 /// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
3416 /// : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
3417 ///
3418 /// Output would be:
3419 /// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
3420 /// %scaled = "mhlo.multiply"(%iota, %delta)
3421 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3422 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3423 /// %result = "mhlo.add"(%scaled, %offset)
3424 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3425 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3426 ///
3427 /// Implementation is defined in C++ because there is no type inference for the iota op.
3428 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
3429 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3430
3431   LogicalResult matchAndRewrite(TF::RangeOp op,
3432 PatternRewriter &rewriter) const override {
3433 auto result = op.getResult();
3434 auto result_type = result.getType();
3435 if (!result_type.cast<ShapedType>().hasStaticShape()) {
3436 return failure();
3437 }
3438
3439 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
3440 rewriter.getI64IntegerAttr(0));
3441 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3442 op.getLoc(), result_type, iota, op.delta(),
3443 hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
3444 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3445 op, result_type, scaled, op.start(),
3446 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
3447 return success();
3448 }
3449 };
3450
3451 // Converts RangeOp for cases where the length is a dynamic value. The shape of
3452 // the resulting tensor is computed, then the start and delta are used with the
3453 // dynamic_iota value to compute the final range values.
3454 //
3455 // For example, the resulting range op value:
3456 // %range = "tf.range"(%start, %limit, %delta)
3457 //
3458 // Is converted to the following.
3459 //   %start + %delta * iota(ceil(abs((%limit - %start) / %delta)))
3460 //
3461 // Implementation is defined in C++ due to the complicated type behavior.
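//
// A hedged sketch of the emitted ops (value names illustrative only):
//   %size_f = ceil(convert(abs(%limit - %start)) / convert(%delta))
//   %size = "mhlo.reshape"(convert(%size_f)) : ... -> tensor<1xi64>
//   %iota = "mhlo.dynamic_iota"(%size) {iota_dimension = 0 : i64}
//   %range = %start + %delta * %iota    // via chlo broadcasting ops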
3462 class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
3463 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3464
3465   LogicalResult matchAndRewrite(TF::RangeOp op,
3466 PatternRewriter &rewriter) const override {
3467 auto result = op.getResult();
3468 auto result_type = result.getType().cast<ShapedType>();
3469 if (result_type.hasStaticShape()) {
3470 return failure();
3471 }
3472
3473 Value start = op.start();
3474 Value delta = op.delta();
3475 Value limit = op.limit();
3476
3477 // To compute the length we need to use floating point calculations so that
3478 // ceil can be computed for the number of steps.
3479 auto compute_element_type =
3480 getElementTypeOrSelf(start.getType()).isa<FloatType>()
3481 ? getElementTypeOrSelf(start.getType())
3482 : rewriter.getF64Type();
3483 auto compute_type = RankedTensorType::get(
3484 limit.getType().cast<ShapedType>().getShape(), compute_element_type);
3485
3486 // Compute the length of the sequence we are going to need. This includes
3487 // some conversion to float for the operations.
3488 //
3489 // %size = ceil(abs((%limit - %start) / %delta))
3490 auto range = rewriter.create<mhlo::SubOp>(op.getLoc(), limit, start);
3491 auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
3492
3493 // Delta is not necessarily the same type as start and limit.
3494 auto abs_cast =
3495 rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
3496 auto delta_cast =
3497 rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
3498
3499 // Compute the total number of integer steps and convert to the HLO
3500 // dimension tensor.
3501 auto normalized =
3502 rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
3503 auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
3504 auto steps = rewriter.create<mhlo::ConvertOp>(
3505 op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
3506 auto reshape = rewriter.create<mhlo::ReshapeOp>(
3507 op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
3508
3509 // Using the resulting length compute the correct range value:
3510 //
3511 // %range = %start + %delta * iota(%size)
3512 auto out_scalar_type =
3513 RankedTensorType::get({}, getElementTypeOrSelf(result_type));
3514 auto start_out_cast =
3515 rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
3516 auto delta_out_cast =
3517 rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
3518
3519 auto iota = rewriter.create<DynamicIotaOp>(
3520 op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
3521 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3522 op.getLoc(), result_type, iota, delta_out_cast,
3523 hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
3524 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3525 op, result_type, scaled, start_out_cast,
3526 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
3527 return success();
3528 }
3529 };
3530
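// Converts an axis attribute that may contain negative axis values into an
// I64 tensor attribute of equivalent non-negative axes for `val`'s rank. For
// example (hypothetical values), for a rank-4 `val` the axes [-1, 1] become
// [3, 1].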
3531 ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
3532 auto int_attr = attr.cast<DenseIntElementsAttr>();
3533 auto type = val.getType().cast<ShapedType>();
3534
3535 SmallVector<int64_t, 6> axis;
3536 axis.reserve(int_attr.getNumElements());
3537
3538 int64_t rank = type.getRank();
3539 for (auto val : int_attr.getValues<APInt>()) {
3540 axis.push_back((val.getSExtValue() + rank) % rank);
3541 }
3542
3543 return builder->getI64TensorAttr(axis);
3544 }
3545
3546 /// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
3547 /// and offset applied to generate the linspace values. The output tensor needs
3548 /// to have a static shape. The implementation is defined in C++ because there
3549 /// is no type inference for the iota op.
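///
/// A hedged sketch of the generated IR (types and value names illustrative
/// only), assuming %num > 1:
///   %step = (%stop - %start) / (convert(%num) - 1)  // via chlo broadcast ops
///   %iota = "mhlo.iota"() {iota_dimension = 0 : i64}
///   %result = %start + %iota * %step                // via chlo broadcast ops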
3550 class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
3551 using OpRewritePattern<TF::LinSpaceOp>::OpRewritePattern;
3552
3553   LogicalResult matchAndRewrite(TF::LinSpaceOp op,
3554 PatternRewriter &rewriter) const override {
3555 auto result = op.getResult();
3556 auto result_type = result.getType().dyn_cast<ShapedType>();
3557 if (!result_type || !result_type.hasStaticShape()) {
3558 return failure();
3559 }
3560
3561 DenseIntElementsAttr num_attr;
3562 if (!matchPattern(op.num(), m_Constant(&num_attr))) {
3563 return rewriter.notifyMatchFailure(op, "Num must be a constant scalar");
3564 }
3565
3566 if (num_attr.begin() == num_attr.end()) {
3567 return rewriter.notifyMatchFailure(op, "Num must not be empty");
3568 }
3569 int64_t num = (*num_attr.begin()).getSExtValue();
3570
3571 // Calculate the scaling that needs to be applied to the iota.
3572 auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
3573 op.getLoc(), op.start().getType(), op.stop(), op.start(),
3574 hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
3575 Value step_denominator = rewriter.create<ConvertOp>(
3576 op.getLoc(), op.num(), result_type.getElementType());
3577 if (num > 1) {
3578 Value one = GetScalarConstOfType(result_type.getElementType(),
3579 op.getLoc(), 1, &rewriter);
3580 step_denominator = rewriter.create<chlo::BroadcastSubOp>(
3581 op.getLoc(), step_denominator.getType(), step_denominator, one,
3582 hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
3583 }
3584 auto step = rewriter.create<chlo::BroadcastDivOp>(
3585 op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
3586 hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator,
3587 step_denominator));
3588
3589 // Scale the iota and add the offset.
3590 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
3591 rewriter.getI64IntegerAttr(0));
3592 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3593 op.getLoc(), result_type, iota, step,
3594 hlo::getBroadcastDimensionsAttr(&rewriter, iota, step));
3595 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3596 op, result_type, scaled, op.start(),
3597 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
3598 return success();
3599 }
3600 };
3601
3602 /// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
3603 /// ReductionOp.
3604 /// `is_accumulation` controls whether it uses higher precision for the actual
3605 /// reduction. This is set to false for ops like max where there are no
3606 /// precision concerns.
3607 //
3608 // The Derived class should have a static method to return the initial value to
3609 // use for reduction:
3610 // static Value GetInitialValue(Type reduce_element_type, Location loc,
3611 // PatternRewriter *rewriter);
3612 // The reduce_element_type is guaranteed to be a float, int, or complex type
3613 // suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType.
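//
// A rough sketch of the IR this produces (illustrative only):
//   %casted = "mhlo.convert"(%input)     // to the accumulation type, if any
//   %reduced = "mhlo.reduce"(%casted, %init) [ReductionOp] {dimensions = ...}
//   %result = "mhlo.convert"(%reduced)   // back to the input element type
// plus a divide for Mean and a reshape when keep_dims is true.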
3614 template <typename Derived, typename OpTy, typename ReductionOp,
3615 bool is_accumulation = true>
3616 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
3617 using OpRewritePattern<OpTy>::OpRewritePattern;
3618
3619   LogicalResult matchAndRewrite(OpTy op,
3620 PatternRewriter &rewriter) const override {
3621 // TODO(b/141785544): Update this to not require static shapes.
3622 // Input shape needs to be static to convert negative indices in TensorFlow
3623 // to absolute indices required by HLO.
3624 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
3625 if (!input_ty) return failure();
3626 ArrayRef<int64_t> input_shape = input_ty.getShape();
3627
3628 DenseIntElementsAttr dimensions;
3629 if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
3630 return failure();
3631
3632 // Build the final shape from input_shape and dimensions using a bitmap
3633 // to mark the reduced dimensions.
3634 SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
3635 SmallVector<int64_t, 4> xla_dimensions;
3636 for (const APInt &index_raw : dimensions.getValues<APInt>()) {
3637 int64_t index = index_raw.getSExtValue();
3638 int64_t rank = input_shape.size();
3639 if ((index < -rank || index >= rank)) return failure();
3640 index = (index + rank) % rank;
3641 reduced_dimensions_bitmap[index] = true;
3642 xla_dimensions.push_back(index);
3643 }
3644
3645 Location loc = op.getLoc();
3646 Type element_type = input_ty.getElementType();
3647
3648 // Only float, int, and complex types are currently supported.
3649 if (!element_type.isa<FloatType>() && !element_type.isa<IntegerType>() &&
3650 !element_type.isa<ComplexType>()) {
3651 return rewriter.notifyMatchFailure(
3652 op, "element type must be float, int, or complex type");
3653 }
3654
3655 // Convert to an accumulation type to not lose precision when doing
3656 // repeated arithmetic operations.
3657 Type reduce_element_type =
3658 is_accumulation ? GetAccumulationType(element_type) : element_type;
3659 auto casted_input =
3660 rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
3661
3662 // Each reduction op can have a different initial value.
3663 Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
3664
3665 auto reduction = rewriter.create<ReduceOp>(
3666 loc, casted_input.getResult(), init,
3667 GetI64ElementsAttr(xla_dimensions, &rewriter));
3668 BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
3669 &rewriter);
3670 Value result = reduction.getResult(0);
3671
3672 // The mean op needs to divide by the product of the reduced dimensions.
3673 if (std::is_same<OpTy, TF::MeanOp>::value) {
3674 int64_t divisor_count = 1;
3675 for (size_t i = 0; i < input_shape.size(); ++i) {
3676 if (reduced_dimensions_bitmap[i]) {
3677 if (TensorType::isDynamic(input_shape[i])) {
3678 return failure();
3679 }
3680 divisor_count *= input_shape[i];
3681 }
3682 }
3683 auto divisor = GetScalarConstOfType(reduce_element_type, loc,
3684 divisor_count, &rewriter);
3685 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
3686 result = rewriter.create<chlo::BroadcastDivOp>(
3687 loc, result, divisor.getResult(), broadcast_dims);
3688 }
3689
3690 result = rewriter.create<ConvertOp>(loc, result, element_type);
3691
3692 // Need to reshape back after the reduction if we're keeping the reduced
3693 // dimensions.
3694 if (op.keep_dims()) {
3695 result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
3696 }
3697 rewriter.replaceOp(op, {result});
3698
3699 return success();
3700 }
3701 };
3702
3703 // Converts Mean op to HLO Reduce op.
3704 //
3705 // %init = constant dense<...> : tensor<T>
3706 // %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
3707 // {dimensions = ...}
3708 // %divisor = constant dense<...> : tensor<T>
3709 // %mean = "mhlo.divide"(%sum, %divisor)
3710 class ConvertMeanOp
3711 : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
3712 public:
3713 using GenericConvertReductionOp::GenericConvertReductionOp;
3714   static Value GetInitialValue(Type reduce_element_type, Location loc,
3715 PatternRewriter *rewriter) {
3716 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3717 }
3718 };
3719
3720 // Converts Sum op to HLO Reduce op.
3721 //
3722 // %init = constant dense<...> : tensor<T>
3723 // %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
3724 // {dimensions = ...}
3725 class ConvertSumOp
3726 : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
3727 public:
3728 using GenericConvertReductionOp::GenericConvertReductionOp;
3729
3730   static Value GetInitialValue(Type reduce_element_type, Location loc,
3731 PatternRewriter *rewriter) {
3732 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3733 }
3734 };
3735
3736 // Converts Max op to HLO Reduce op.
3737 //
3738 // %init = constant dense<...> : tensor<T>
3739 // %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
3740 // {dimensions = ...}
3741 class ConvertMaxOp
3742 : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
3743 /* is_accumulation= */ false> {
3744 public:
3745 using GenericConvertReductionOp::GenericConvertReductionOp;
3746
3747   static Value GetInitialValue(Type reduce_element_type, Location loc,
3748 PatternRewriter *rewriter) {
3749 return GetScalarLimitConstOfType(reduce_element_type, loc,
3750 hlo::kInfinityLowest, rewriter);
3751 }
3752 };
3753
3754 // Converts Min op to HLO Reduce op.
3755 //
3756 // %init = constant dense<...> : tensor<T>
3757 // %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
3758 // {dimensions = ...}
3759 class ConvertMinOp
3760 : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
3761 /* is_accumulation= */ false> {
3762 public:
3763 using GenericConvertReductionOp::GenericConvertReductionOp;
3764
3765   static Value GetInitialValue(Type reduce_element_type, Location loc,
3766 PatternRewriter *rewriter) {
3767 return GetScalarLimitConstOfType(reduce_element_type, loc,
3768 hlo::kInfinityMax, rewriter);
3769 }
3770 };
3771
3772 // Converts Prod op to HLO Reduce op.
3773 //
3774 // %init = constant dense<...> : tensor<T>
3775 // %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
3776 // {dimensions = ...}
3777 class ConvertProdOp
3778 : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
3779 public:
3780 using GenericConvertReductionOp::GenericConvertReductionOp;
3781
3782   static Value GetInitialValue(Type reduce_element_type, Location loc,
3783 PatternRewriter *rewriter) {
3784 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
3785 }
3786 };
3787
3788 // Converts All op to HLO Reduce op.
3789 //
3790 // %init = constant dense<...> : tensor<T>
3791 // %all = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
3792 // {dimensions = ...}
3793 class ConvertAllOp
3794 : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
3795 public:
3796 using GenericConvertReductionOp::GenericConvertReductionOp;
3797   static Value GetInitialValue(Type reduce_element_type, Location loc,
3798 PatternRewriter *rewriter) {
3799 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
3800 }
3801 };
3802
3803 // Converts Any op to HLO Reduce op.
3804 //
3805 // %init = constant dense<...> : tensor<T>
3806 // %any = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
3807 // {dimensions = ...}
3808 class ConvertAnyOp
3809 : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
3810 public:
3811 using GenericConvertReductionOp::GenericConvertReductionOp;
3812   static Value GetInitialValue(Type reduce_element_type, Location loc,
3813 PatternRewriter *rewriter) {
3814 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3815 }
3816 };
3817
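// Initial values used by the reduction patterns above: Sum, Mean, and Any
// start from 0; Prod and All start from 1; Max starts from the lowest limit
// (hlo::kInfinityLowest) and Min from the largest (hlo::kInfinityMax), so the
// first reduced element always wins the comparison.
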
3818 // Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
3819 // a reduction on the original input and the corresponding index. The reduction
3820 // sub-computation selects the max (or min) value and the index for the value.
3821 // Derived: the resulting derived class of this class.
3822 // OpTy: TF::ArgMaxOp or TF::ArgMinOp.
3823 template <typename Derived, typename OpTy>
3824 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
3825 using OpRewritePattern<OpTy>::OpRewritePattern;
3826
3827   LogicalResult matchAndRewrite(OpTy op,
3828 PatternRewriter &rewriter) const override {
3829 RankedTensorType input_type =
3830 op.input().getType().template dyn_cast<RankedTensorType>();
3831 if (!input_type) {
3832 return failure();
3833 }
3834
3835 Type input_element_type = input_type.getElementType();
3836 // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
3837 // tf.ArgMax doesn't support complex data types, this check can be removed.
3838 if (!input_element_type.isSignlessIntOrFloat()) return failure();
3839
3840 Location loc = op.getLoc();
3841 Value init_value =
3842 Derived::GetInitialValue(input_element_type, loc, rewriter);
3843
3844 RankedTensorType output_type =
3845 op.output().getType().template dyn_cast<RankedTensorType>();
3846 if (!output_type) {
3847 return failure();
3848 }
3849
3850 Type index_element_type = output_type.getElementType();
3851 Value index_init_value =
3852 GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
3853
3854 RankedTensorType index_type =
3855 RankedTensorType::get(input_type.getShape(), index_element_type);
3856
3857 llvm::Optional<int64_t> optional_axis =
3858 GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
3859 if (!optional_axis.hasValue()) {
3860 return failure();
3861 }
3862 int64_t axis = optional_axis.getValue();
3863
3864 IntegerAttr iota_dimension =
3865 IntegerAttr::get(rewriter.getIntegerType(64), axis);
3866 Value index_values =
3867 rewriter.create<IotaOp>(loc, index_type, iota_dimension);
3868
3869 std::vector<int64_t> dimensions = input_type.getShape();
3870 dimensions.erase(dimensions.begin() + axis);
3871 ArrayRef<int64_t> reduction_result_shape(dimensions);
3872
3873 Value operands[] = {op.input(), index_values};
3874 Value init_values[] = {init_value, index_init_value};
3875 DenseIntElementsAttr reduction_dimensions =
3876 GetI64ElementsAttr({axis}, &rewriter);
3877
3878 auto reduction = rewriter.create<ReduceOp>(
3879 loc, llvm::ArrayRef<Value>(operands),
3880 llvm::ArrayRef<Value>(init_values), reduction_dimensions);
3881 StringRef direction = Derived::GetDirection();
3882 BuildArgMinMaxReductionBody(input_element_type, index_element_type,
3883 direction, &reduction.body(), &rewriter);
3884
3885 rewriter.replaceOp(op, {reduction.getResult(1)});
3886 return success();
3887 }
3888 };
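
// Illustrative example (hypothetical shapes): for an input of type
// tensor<3x7xf32> with dimension = 1 and i32 output indices, the pattern above
// emits an iota of type tensor<3x7xi32> along dimension 1 and a two-operand
// mhlo.reduce over dimension 1 producing the selected values (result #0,
// tensor<3xf32>) and their indices (result #1, tensor<3xi32>); only result #1
// replaces the original op.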
3889
3890 // Converts tensorflow ArgMax op to mhlo operations. The actual
3891 // implementation is in class ConvertArgMinMaxOp:
3892 //
3893 // %init_index = constant dense<...> : tensor<T>
3894 // %init = constant dense<...> : tensor<T>
3895 // %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
3896 // %init_index) ["mhlo.arg_max"]
3897 class ConvertArgMaxOp
3898 : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
3899 public:
3900 using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
3901
3902   static Value GetInitialValue(Type reduce_element_type, Location loc,
3903 PatternRewriter &rewriter) {
3904 return GetScalarLimitConstOfType(reduce_element_type, loc,
3905 hlo::kInfinityLowest, &rewriter);
3906 }
3907
3908   static StringRef GetDirection() { return "GT"; }
3909 };
3910
3911 // Converts TF TensorScatterUpdate op into Scatter Op with assignment:
3912 //
3913 // %result = "mhlo.scatter"(%tensor, %indices, %updates)
3914 // { dimensions = ... }
3915 //
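// Illustrative example (hypothetical shapes): for %tensor : tensor<4x5x6xf32>,
// %indices : tensor<2x2xi32> and %updates : tensor<2x6xf32>, num_index_dims is
// 2, so the scatter dimension numbers resolve to update_window_dims = [1],
// inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], and
// index_vector_dim = 1; the generated update region simply returns its second
// argument (the new value).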
3916 class ConvertTensorScatterUpdateOp
3917 : public OpRewritePattern<TF::TensorScatterUpdateOp> {
3918 public:
3919 using OpRewritePattern::OpRewritePattern;
3920
3921   LogicalResult matchAndRewrite(TF::TensorScatterUpdateOp op,
3922 PatternRewriter &rewriter) const override {
3923 auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
3924 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
3925 auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>();
3926
3927 if (!tensor_ty || !indices_ty || !updates_ty) return failure();
3928     // The last dimension of the indices needs to be known at compile time for
3929 // computation of the 'update_window_dims' attribute in the dimensions
3930 // struct.
3931 int64_t num_index_dims = indices_ty.getShape().back();
3932 if (ShapedType::isDynamic(num_index_dims)) return failure();
3933
3934 int64_t tensor_rank = tensor_ty.getRank();
3935 int64_t indices_rank = indices_ty.getRank();
3936 int64_t updates_rank = updates_ty.getRank();
3937
3938 int64_t window_dims = tensor_rank - num_index_dims;
3939 auto dims_attr = ScatterDimensionNumbers::get(
3940 GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
3941 &rewriter),
3942 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
3943 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
3944 rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());
3945
3946 Location loc = op.getLoc();
3947 auto scatter = rewriter.create<ScatterOp>(
3948 loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);
3949
3950 // Build region to assign the new value.
3951 [&](Region *region) {
3952 OpBuilder::InsertionGuard guard(rewriter);
3953 Block *block = rewriter.createBlock(region);
3954
3955 // Block arguments are scalars of the given element type.
3956 Type type =
3957 RankedTensorType::get(/*shape=*/{}, tensor_ty.getElementType());
3958 block->addArguments({type, type});
3959 rewriter.create<ReturnOp>(loc, block->getArgument(1));
3960 }(&scatter.update_computation());
3961
3962 rewriter.replaceOp(op, scatter.getResult());
3963 return success();
3964 }
3965 };
3966
3967 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
3968 // For shape [S1, S2] and multiples [M1, M2],
3969 // MS1 = M1 * S1; MS2 = M2 * S2
3970 //
3971 // %broadcast = mhlo.broadcast_in_dim(%input) {
3972 // broadcast_dimensions = [0, 2]
3973 // }
3974 // %result = "mhlo.reshape"(%broadcast) : (tensor<S1xM1xS2xM2xf32>)
3975 // -> tensor<MS1xMS2xf32>
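//
// Illustrative example (hypothetical shapes): for %input : tensor<2x3xf32> and
// multiples = [3, 1], dimension 0 takes the tiling path (the broadcasted shape
// picks up [3, 2]) while dimension 1 takes the trivial path (appending 3), so
// broadcast_dimensions = [1, 2], the intermediate type is tensor<3x2x3xf32>,
// and the final reshape produces tensor<6x3xf32>.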
3976 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
3977 public:
3978 using OpRewritePattern::OpRewritePattern;
3979
3980   LogicalResult matchAndRewrite(TF::TileOp op,
3981 PatternRewriter &rewriter) const override {
3982 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3983 if (!input_ty || !input_ty.hasStaticShape()) return failure();
3984 ArrayRef<int64_t> input_shape = input_ty.getShape();
3985 Type element_type = input_ty.getElementType();
3986
3987 DenseIntElementsAttr multiples;
3988 if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
3989 multiples.getType().getRank() != 1)
3990 return failure();
3991
3992 const int64_t input_shape_size = input_shape.size();
3993 if (multiples.getNumElements() != input_shape_size) return failure();
3994
3995 SmallVector<int64_t, 8> broadcasted_shape;
3996 SmallVector<int64_t, 4> broadcast_dimensions;
3997 broadcasted_shape.reserve(input_shape.size() * 2);
3998 broadcast_dimensions.reserve(input_shape.size());
3999 for (auto multiple_and_input :
4000 llvm::zip(multiples.getValues<APInt>(), input_shape)) {
4001 int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
4002 int64_t input_size = std::get<1>(multiple_and_input);
4003
4004 if (multiple < 0) return failure();
4005
4006       // Line the input up with the next dimension in broadcasted_shape
4007 // when broadcasting.
4008 int64_t broadcast_dim;
4009 int64_t output_size = input_size * multiple;
4010 if (input_size == 1 || multiple == 1) {
4011 // Special case for when normal broadcasting will just work.
4012 broadcast_dim = broadcasted_shape.size();
4013 broadcasted_shape.push_back(output_size);
4014 } else {
4015 // Tiling will happen for this dimension during the ReshapeOp below.
4016 broadcasted_shape.push_back(multiple);
4017 broadcast_dim = broadcasted_shape.size();
4018 broadcasted_shape.push_back(input_size);
4019 }
4020 broadcast_dimensions.push_back(broadcast_dim);
4021 }
4022 Location loc = op.getLoc();
4023 Type broadcasted_type =
4024 RankedTensorType::get(broadcasted_shape, element_type);
4025 Type output_type = op.getType();
4026
4027 Value result = rewriter.create<BroadcastInDimOp>(
4028 loc, broadcasted_type, op.input(),
4029 GetI64ElementsAttr(broadcast_dimensions, &rewriter));
4030
4031 if (output_type != broadcasted_type) {
4032 result = rewriter.create<ReshapeOp>(loc, output_type, result);
4033 }
4034
4035 rewriter.replaceOp(op, {result});
4036
4037 return success();
4038 }
4039 };
4040
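// Converts tf.MaxPoolGrad and tf.MaxPool3DGrad to mhlo.select_and_scatter.
// Rough sketch of the lowering (paddings and attribute details omitted):
//
//   %result = "mhlo.select_and_scatter"(%orig_input, %grad, %zero)
//   // select region: "GE" compare, picking the position of the max element
//   // scatter region: add, routing the incoming gradient to that position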
4041 template <typename OpTy, int num_dims>
4042 class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
4043 public:
4044 using OpRewritePattern<OpTy>::OpRewritePattern;
4045
4046   LogicalResult matchAndRewrite(OpTy op,
4047 PatternRewriter &rewriter) const override {
4048 Location loc = op.getLoc();
4049
4050 Type element_type =
4051 op.orig_input().getType().template cast<TensorType>().getElementType();
4052
4053 // Compute paddings using the original input and kernel shape and strides.
4054     // Here, the paddings are computed as for the ReduceWindow op, since the
4055     // MaxPool op is lowered to the ReduceWindow op.
4056 auto input_ty =
4057 op.orig_input().getType().template dyn_cast<RankedTensorType>();
4058 if (!input_ty) return failure();
4059 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
4060 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
4061
4062 auto result = rewriter.create<SelectAndScatterOp>(
4063 loc, op.getType(), op.orig_input(), op.grad(),
4064 GetScalarConstOfType(element_type, loc, 0, &rewriter),
4065 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
4066 paddings_attr);
4067
4068 BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
4069 {
4070 OpBuilder::InsertionGuard guard(rewriter);
4071 Block *block = rewriter.createBlock(&result.select());
4072
4073 // Block arguments are scalars of the given element type.
4074 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4075 block->addArguments({type, type});
4076
4077 auto reducer = rewriter.create<CompareOp>(
4078 loc, block->getArgument(0), block->getArgument(1),
4079 StringAttr::get(rewriter.getContext(), "GE"));
4080 rewriter.create<ReturnOp>(loc, reducer.getResult());
4081 }
4082
4083 rewriter.replaceOp(op, {result});
4084
4085 return success();
4086 }
4087 };
4088
4089 using ConvertMaxPool2DGradOp =
4090 ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
4091 using ConvertMaxPool3DGradOp =
4092 ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
4093
4094 // Converts tf.Conv?DBackpropInputOp into:
4095 // %rev_filter = "mhlo.reverse"(%filter)
4096 // %result = "mhlo.convolution"(%out_backprop, %rev_filter)
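//
// For example (illustrative; the actual numbers come from data_format), with
// NHWC Conv2DBackpropInput (num_spatial_dims = 2) the dimension numbers below
// resolve to batch_dim = 0, spatial_dims = [1, 2] and feature_dim = 3, while
// the kernel input/output feature dimensions are swapped (3 and 2 in the HWIO
// filter layout) so that the convolution computes the activation gradient.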
4097 template <typename OpTy, int num_spatial_dims>
4098 class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
4099 public:
4100 using OpRewritePattern<OpTy>::OpRewritePattern;
4101
4102   LogicalResult matchAndRewrite(OpTy op,
4103 PatternRewriter &rewriter) const override {
4104 // Unpack all of the attributes.
4105 tensorflow::TensorFormat data_format;
4106 if (!FormatFromString(op.data_format().str(), &data_format))
4107 return failure();
4108
4109 tensorflow::Padding padding;
4110 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
4111 return failure();
4112
4113 auto out_backprop_ty =
4114 op.out_backprop().getType().template dyn_cast<RankedTensorType>();
4115 auto filter_ty =
4116 op.filter().getType().template dyn_cast<RankedTensorType>();
4117
4118 for (RankedTensorType ty : {out_backprop_ty, filter_ty})
4119 if (!ty || !ty.hasStaticShape()) return failure();
4120
4121 DenseIntElementsAttr input_shape_attr;
4122 if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
4123 input_shape_attr.getType().getRank() != 1)
4124 return failure();
4125
4126 auto input_shape = input_shape_attr.getValues<int32_t>();
4127
4128 auto dilations_attr = GetI64ElementsAttr(op.dilations());
4129 std::vector<int> dilations{
4130 dilations_attr.template getValues<int64_t>().begin(),
4131 dilations_attr.template getValues<int64_t>().end()};
4132 auto strides_attr = GetI64ElementsAttr(op.strides());
4133 std::vector<tensorflow::int32> strides{
4134 strides_attr.template getValues<int64_t>().begin(),
4135 strides_attr.template getValues<int64_t>().end()};
4136
4137 std::vector<tensorflow::int64> explicit_paddings;
4138 if (padding == tensorflow::Padding::EXPLICIT) {
4139       // EXPLICIT padding mode and the associated attribute are limited to
4140 // Conv2DBackpropInput. So, fetch attribute by identifier instead of the
4141 // op.explicit_paddings() attribute getter.
4142 ArrayRef<Attribute> explicit_paddings_attr =
4143 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
4144 explicit_paddings.reserve(explicit_paddings_attr.size());
4145 for (Attribute explicit_padding : explicit_paddings_attr)
4146 explicit_paddings.push_back(
4147 explicit_padding.cast<IntegerAttr>().getInt());
4148 }
4149
4150 constexpr int num_dims = num_spatial_dims + 2;
4151 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
4152
4153 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
4154 tensorflow::ConvBackpropDimensions dims;
4155 if (!tensorflow::ConvBackpropComputeDimensionsV2(
4156 /*label=*/"", num_spatial_dims,
4157 ToTensorShape<int32_t, num_dims>(input_shape),
4158 ToTensorShape<int64_t, num_dims>(filter_shape),
4159 ToTensorShape<int64_t, num_dims>(out_backprop_ty.getShape()),
4160 dilations, strides, padding, explicit_paddings, data_format, &dims)
4161 .ok()) {
4162 return failure();
4163 }
4164
4165 // Compute ConvDimensionNumbers, dilation, and padding.
4166 SmallVector<int64_t, num_spatial_dims> spatial_dims;
4167 SmallVector<int64_t, num_spatial_dims> lhs_dilation;
4168 SmallVector<int64_t, num_spatial_dims> rhs_dilation;
4169 SmallVector<int64_t, num_spatial_dims * 2> paddings;
4170
4171 for (int i : llvm::seq<int>(0, num_spatial_dims)) {
4172 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
4173 spatial_dims.push_back(dim);
4174 const auto &spatial_dim_i = dims.spatial_dims[i];
4175 lhs_dilation.push_back(spatial_dim_i.stride);
4176 rhs_dilation.push_back(dilations[dim]);
4177 paddings.push_back(spatial_dim_i.pad_before);
4178 paddings.push_back(spatial_dim_i.pad_after);
4179 }
4180
4181 RankedTensorType paddings_ty = RankedTensorType::get(
4182 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
4183 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
4184
4185 auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
4186
4187 Value filter = op.filter();
4188
4189 const int feature_dim =
4190 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
4191 const int64_t in_depth = *(input_shape.begin() + feature_dim);
4192 const int64_t filter_in_depth = filter_shape[num_spatial_dims];
4193 const int64_t feature_group_count = in_depth / filter_in_depth;
4194
4195 if (feature_group_count != 1) {
4196 // 1. Reshape filter from
4197 // [H, W, ..., filter_in_depth, out_depth] to
4198 // [H, W, ..., filter_in_depth, G, out_depth / G].
4199 auto new_shape = llvm::to_vector<6>(filter_shape);
4200 new_shape.back() = feature_group_count;
4201 new_shape.push_back(filter_shape.back() / feature_group_count);
4202 Type filter_element_ty = filter_ty.getElementType();
4203 auto ty = RankedTensorType::get(new_shape, filter_element_ty);
4204 filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
4205
4206 // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G].
4207 llvm::SmallVector<int64_t, 6> perm(num_dims + 1);
4208 std::iota(perm.begin(), perm.end(), 0);
4209 std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]);
4210 std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]);
4211 ty = RankedTensorType::get(new_shape, filter_element_ty);
4212 filter = rewriter.create<TransposeOp>(
4213 op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter));
4214
4215 // 3. Reshape to [H, W, ..., in_depth, out_depth / G].
4216 new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1];
4217 new_shape[num_spatial_dims + 1] = new_shape.back();
4218 new_shape.pop_back();
4219 ty = RankedTensorType::get(new_shape, filter_element_ty);
4220 filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
4221 }
4222
4223 auto kernel_spatial_dims_attr =
4224 GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter);
4225
4226 // Mirror the filter in the spatial dimensions.
4227 filter = rewriter.create<ReverseOp>(op.getLoc(), filter,
4228 kernel_spatial_dims_attr);
4229
4230 const int batch_dim =
4231 tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
4232 auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
4233 auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
4234
4235 // activation gradients
4236 // = gradients (with padding and dilation) <conv> mirrored_weights
4237 Value result = rewriter.create<ConvOp>(
4238 op.getLoc(), op.getType(), op.out_backprop(), filter,
4239 /*window_strides=*/
4240 GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
4241 &rewriter),
4242 /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
4243 GetI64ElementsAttr(rhs_dilation, &rewriter),
4244 /*window_reversal=*/nullptr,
4245 ConvDimensionNumbers::get(
4246 /*input_batch_dimension=*/batch_dim_attr,
4247 /*input_feature_dimension=*/feature_dim_attr,
4248 /*input_spatial_dimensions=*/spatial_dims_attr,
4249 // TF filter shape is [ H, W, ..., inC, outC ]
4250 // Transpose the input and output features for computing the
4251 // gradient.
4252 /*kernel_input_feature_dimension=*/
4253 rewriter.getI64IntegerAttr(num_spatial_dims + 1),
4254 /*kernel_output_feature_dimension=*/
4255 rewriter.getI64IntegerAttr(num_spatial_dims),
4256 /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
4257 /*output_batch_dimension=*/batch_dim_attr,
4258 /*output_feature_dimension=*/feature_dim_attr,
4259 /*output_spatial_dimensions=*/spatial_dims_attr,
4260 rewriter.getContext()),
4261 rewriter.getI64IntegerAttr(feature_group_count),
4262 /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
4263 /*precision_config=*/ArrayAttr());
4264
4265 rewriter.replaceOp(op, {result});
4266
4267 return success();
4268 }
4269 };
4270
4271 using ConvertConv2DBackpropInputOp =
4272 ConvertConvBackpropInputOp<TF::Conv2DBackpropInputOp,
4273 /*num_spatial_dims=*/2>;
4274 using ConvertConv3DBackpropInputOp =
4275 ConvertConvBackpropInputOp<TF::Conv3DBackpropInputV2Op,
4276 /*num_spatial_dims=*/3>;
4277
4278 // Converts tf.Conv?DBackpropFilterOp into:
4279 // %result = "mhlo.convolution"(%input, %out_backprop)
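//
// For example (illustrative; the actual numbers come from data_format), with
// NHWC Conv2DBackpropFilter the activations form the LHS with their batch and
// feature roles swapped (input_batch_dimension = 3, input_feature_dimension =
// 0), the output gradients form the RHS, and the result uses output spatial
// dimensions [0, 1] to match the HWIO filter layout.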
4280 template <typename OpTy, int num_spatial_dims>
4281 class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
4282 public:
4283 using OpRewritePattern<OpTy>::OpRewritePattern;
4284
4285   LogicalResult matchAndRewrite(OpTy op,
4286 PatternRewriter &rewriter) const override {
4287 // Unpack all of the attributes.
4288 tensorflow::TensorFormat data_format;
4289 if (!FormatFromString(op.data_format().str(), &data_format))
4290 return failure();
4291
4292 tensorflow::Padding padding;
4293 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
4294 return failure();
4295
4296 auto out_backprop_ty =
4297 op.out_backprop().getType().template dyn_cast<RankedTensorType>();
4298 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
4299
4300 for (RankedTensorType ty : {out_backprop_ty, input_ty})
4301 if (!ty || !ty.hasStaticShape()) return failure();
4302
4303 ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
4304 ArrayRef<int64_t> input_shape = input_ty.getShape();
4305
4306 DenseIntElementsAttr filter_shape_attr;
4307 if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
4308 filter_shape_attr.getType().getRank() != 1)
4309 return failure();
4310
4311 auto dilations_attr = GetI64ElementsAttr(op.dilations());
4312 std::vector<int> dilations{
4313 dilations_attr.template getValues<int64_t>().begin(),
4314 dilations_attr.template getValues<int64_t>().end()};
4315 auto strides_attr = GetI64ElementsAttr(op.strides());
4316 std::vector<tensorflow::int32> strides{
4317 strides_attr.template getValues<int64_t>().begin(),
4318 strides_attr.template getValues<int64_t>().end()};
4319
4320 std::vector<tensorflow::int64> explicit_paddings;
4321 if (padding == tensorflow::Padding::EXPLICIT) {
4322       // EXPLICIT padding mode and the associated attribute are limited to
4323 // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the
4324 // op.explicit_paddings() attribute getter.
4325 ArrayRef<Attribute> explicit_paddings_attr =
4326 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
4327 explicit_paddings.reserve(explicit_paddings_attr.size());
4328 for (Attribute explicit_padding : explicit_paddings_attr)
4329 explicit_paddings.push_back(
4330 explicit_padding.cast<IntegerAttr>().getInt());
4331 }
4332
4333 constexpr int num_dims = num_spatial_dims + 2;
4334 auto filter_shape = filter_shape_attr.getValues<int32_t>();
4335
4336 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
4337 tensorflow::ConvBackpropDimensions dims;
4338 if (!tensorflow::ConvBackpropComputeDimensionsV2(
4339 /*label=*/"", num_spatial_dims,
4340 ToTensorShape<int64_t, num_dims>(input_shape),
4341 ToTensorShape<int32_t, num_dims>(filter_shape),
4342 ToTensorShape<int64_t, num_dims>(out_backprop_shape), dilations,
4343 strides, padding, explicit_paddings, data_format, &dims)
4344 .ok()) {
4345 return failure();
4346 }
4347
4348 // The activations (inputs) form the LHS of the convolution.
4349 // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
4350 // For the gradient computation, we need to:
4351 // 1. In the case of group convolution, move the num_groups dimension before
4352 // the batch dimension
4353 // 2. Swap the roles of the batch and feature dimensions.
4354 const int feature_dim =
4355 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
4356 const int64_t in_depth = input_shape[feature_dim];
4357 const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims);
4358 const int64_t batch_group_count = in_depth / filter_in_depth;
4359
4360 // Compute ConvDimensionNumbers, dilation, and padding.
4361 SmallVector<int64_t, num_spatial_dims> spatial_dims;
4362 SmallVector<int64_t, num_spatial_dims> kernel_spatial_dims;
4363 SmallVector<int64_t, num_spatial_dims> rhs_dilation;
4364 SmallVector<int64_t, num_spatial_dims * 2> paddings;
4365 SmallVector<int64_t, num_spatial_dims> window_strides;
4366
4367 // The filter gradients are computed by a convolution of the input
4368 // activations and the output gradients, with some appropriate padding.
4369 // See the comment at the top of conv_grad_ops.h for details.
4370
4371 for (int i : llvm::seq<int>(0, num_spatial_dims)) {
4372 const int64_t dim =
4373 tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
4374 kernel_spatial_dims.push_back(dim);
4375 // Besides padding the input, we will also expand output_rows to
4376 // expanded_out_rows = (output_rows - 1) * stride + 1
4377 // with zeros in between:
4378 //
4379 // a . . . b . . . c . . . d . . . e
4380 //
4381 // This is done by specifying the window dilation factors in the
4382 // convolution HLO below.
4383 const auto &spatial_dim_i = dims.spatial_dims[i];
4384 rhs_dilation.push_back(spatial_dim_i.stride);
4385 window_strides.push_back(dilations[dim]);
4386
4387 // We will also need to pad the input with zeros such that after the
4388 // convolution, we get the right size for the filter.
4389 // The padded_in_rows should be such that when we convolve this with the
4390 // expanded_out_rows as a filter, we should get filter_rows back.
4391
4392 const int64_t padded_in_size =
4393 spatial_dim_i.expanded_output_size +
4394 (spatial_dim_i.filter_size - 1) * dilations[dim];
4395
4396 // However it can be smaller than input_rows: in this
4397 // case it means some of the inputs are not used.
4398 //
4399 // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
4400 //
4401 // INPUT = [ A B C ]
4402 //
4403 // FILTER = [ x y ]
4404 //
4405 // and the output will only have one column: a = A * x + B * y
4406 //
4407 // and input "C" is not used at all.
4408 //
4409 // We apply negative padding in this case.
4410 const int64_t pad_total = padded_in_size - spatial_dim_i.input_size;
4411
4412 // + For the EXPLICIT padding, we pad the top/left side with the explicit
4413 // padding and pad the bottom/right side with the remaining space.
4414 // + For the VALID padding, we don't pad anything on the top/left side
4415 // and pad the bottom/right side with the remaining space.
4416 // + For the SAME padding, we pad top/left side the same as bottom/right
4417 // side.
4418 //
4419 // In addition, if the padded input size is smaller than the input size,
4420       // we need to ignore some trailing elements of the input. We do this by
4421 // applying negative padding on the right/bottom.
4422 const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
4423 ? explicit_paddings[2 * dim]
4424 : padding == tensorflow::Padding::SAME
4425 ? std::max<int64_t>(pad_total / 2, 0)
4426 : 0;
4427 paddings.push_back(pad_before);
4428 paddings.push_back(pad_total - pad_before);
4429 }
4430
4431 RankedTensorType paddings_ty = RankedTensorType::get(
4432 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
4433 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
4434 auto kernel_spatial_dims_attr =
4435 GetI64ElementsAttr(kernel_spatial_dims, &rewriter);
4436
4437 const int batch_dim =
4438 tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
4439 auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
4440 auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
4441
4442 Value result = rewriter.create<ConvOp>(
4443 op.getLoc(), op.getType(), op.input(), op.out_backprop(),
4444 /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
4445 /*padding=*/paddings_attr, /*lhs_dilation=*/
4446 GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
4447 &rewriter),
4448 GetI64ElementsAttr(rhs_dilation, &rewriter),
4449 /*window_reversal=*/nullptr,
4450 ConvDimensionNumbers::get(
4451 // Swap batch_dim and feature_dim in the activations.
4452 /*input_batch_dimension=*/feature_dim_attr,
4453 /*input_feature_dimension=*/batch_dim_attr,
4454 /*input_spatial_dimensions=*/kernel_spatial_dims_attr,
4455 // The gradients become the RHS of the convolution.
4456 // The gradients have shape [batch, out_rows, out_cols, ...,
4457 // out_depth] where the batch becomes the input feature for the
4458 // convolution.
4459 /*kernel_input_feature_dimension=*/batch_dim_attr,
4460 /*kernel_output_feature_dimension=*/feature_dim_attr,
4461 /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
4462 /*output_batch_dimension=*/
4463 rewriter.getI64IntegerAttr(num_spatial_dims),
4464 /*output_feature_dimension=*/
4465 rewriter.getI64IntegerAttr(num_spatial_dims + 1),
4466 /*output_spatial_dimensions=*/
4467 GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter),
4468 rewriter.getContext()),
4469 /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
4470 rewriter.getI64IntegerAttr(batch_group_count),
4471 /*precision_config=*/ArrayAttr());
4472
4473 rewriter.replaceOp(op, {result});
4474
4475 return success();
4476 }
4477 };
4478
4479 using ConvertConv2DBackpropFilterOp =
4480 ConvertConvBackpropFilterOp<TF::Conv2DBackpropFilterOp,
4481 /*num_spatial_dims=*/2>;
4482 using ConvertConv3DBackpropFilterOp =
4483 ConvertConvBackpropFilterOp<TF::Conv3DBackpropFilterV2Op,
4484 /*num_spatial_dims=*/3>;
4485
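// Converts tf.OneHot to a sequence of mhlo.iota, mhlo.broadcast_in_dim,
// mhlo.compare, and mhlo.select ops.
//
// Illustrative example (hypothetical shapes): for indices : tensor<3xi32>,
// depth = 4, and axis = -1, the axis normalizes to 1, an iota of type
// tensor<3x4xi32> is generated along dimension 1, the indices are broadcast
// with broadcast_dimensions = [0], and an "EQ" comparison selects on_value
// where they match and off_value elsewhere, producing the tensor<3x4xT>
// result.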
4486 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
4487 public:
4488 using OpRewritePattern::OpRewritePattern;
4489
4490   LogicalResult matchAndRewrite(TF::OneHotOp op,
4491 PatternRewriter &rewriter) const override {
4492 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
4493 if (!indices_ty || !indices_ty.hasStaticShape()) return failure();
4494 ArrayRef<int64_t> indices_shape = indices_ty.getShape();
4495 Type element_type = indices_ty.getElementType();
4496
4497 DenseIntElementsAttr depth_attr;
4498 if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
4499 return failure();
4500 }
4501
4502 int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
4503 int64_t axis = op.axis();
4504 if (axis == -1) axis = indices_shape.size();
4505
4506 llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
4507 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
4508 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
4509
4510 llvm::SmallVector<int64_t, 4> output_dims =
4511 llvm::to_vector<4>(indices_shape);
4512 output_dims.insert(output_dims.begin() + axis, depth);
4513
4514 Location loc = op.getLoc();
4515
4516 // The iota result is the effective output shape of the computation,
4517 // and indices must be broadcast into it. At this point, this computation
4518 // would need to be reworked quite a bit to support dynamic shapes, so
4519 // just using static broadcasting.
4520 auto index_type = RankedTensorType::get(output_dims, element_type);
4521 auto iota = rewriter.create<IotaOp>(
4522 loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
4523 auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
4524 loc, index_type, op.indices(),
4525 GetI64ElementsAttr(broadcast_dims, &rewriter));
4526
4527 Value compare = rewriter.create<mhlo::CompareOp>(
4528 loc, broadcast_indices, iota,
4529 StringAttr::get(rewriter.getContext(), "EQ"));
4530 Value on_value = rewriter.create<BroadcastOp>(
4531 loc, op.getType(), op.on_value(),
4532 GetI64ElementsAttr(output_dims, &rewriter));
4533 Value off_value = rewriter.create<BroadcastOp>(
4534 loc, op.getType(), op.off_value(),
4535 GetI64ElementsAttr(output_dims, &rewriter));
4536 Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
4537 on_value, off_value);
4538
4539 rewriter.replaceOp(op, {result});
4540
4541 return success();
4542 }
4543 };
4544
4545 // Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
4546 // get_tuple_element ops.
4547 //
4548 // All HLO infeed ops expect an HLO token type operand and produce a tuple
4549 // containing a token. This HLO token type is used to order multiple infeed
4550 // operations within a computation. The token type can come from other
4551 // infeed/outfeed/send/recv ops or can be generated using create_token op with
4552 // no operands. Here we emit a create_token op to generate the token type
4553 // operand of infeed.
4554 //
4555 // For example the following IR:
4556 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
4557 //
4558 // would be lowered to
4559 //
4560 // %token = "mhlo.create_token"() : () -> !mhlo.token
4561 // %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
4562 // (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
4563 // !mhlo.token>
4564 // %data = "mhlo.get_tuple_element"(%data_and_token) {index = 0}
4565 // %0#0 = "mhlo.get_tuple_element"(%data) {index = 0}
4566 // %0#1 = "mhlo.get_tuple_element"(%data) {index = 1}
4567 //
4568 class ConvertInfeedDequeueTupleOp
4569 : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
4570 public:
4571 using OpRewritePattern::OpRewritePattern;
4572
4573   Attribute GetLayout(const Type &type, PatternRewriter &rewriter) const {
4574 auto i64_type = rewriter.getIntegerType(64);
4575 if (type.isa<TupleType>()) {
4576 TupleType tuple_type = type.dyn_cast<TupleType>();
4577 std::vector<mlir::Attribute> v;
4578 for (const mlir::Type &t : tuple_type.getTypes()) {
4579 v.push_back(GetLayout(t, rewriter));
4580 }
4581 ArrayRef<Attribute> shape(v);
4582 return rewriter.getArrayAttr(shape);
4583 } else if (type.isa<RankedTensorType>()) {
4584 RankedTensorType t = type.dyn_cast<RankedTensorType>();
4585 std::vector<mlir::Attribute> attrs;
4586 std::vector<Attribute> elements;
4587 // Tuples are always serialized with an ascending layout. See
4588 // LiteralLinearizer::LinearizeToBuffers.
4589 for (int64_t i = 0; i < t.getRank(); i++) {
4590 elements.push_back(rewriter.getIntegerAttr(i64_type, i));
4591 }
4592 return rewriter.getArrayAttr(elements);
4593 } else {
4594 return rewriter.getUnitAttr(); // e.g. tokens
4595 }
4596 }
4597
4598   LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
4599 PatternRewriter &rewriter) const override {
4600 std::vector<Type> result_types(op.outputs().size());
4601 for (auto idx_and_output : llvm::enumerate(op.outputs())) {
4602 result_types[idx_and_output.index()] = (idx_and_output.value().getType());
4603 }
4604 // Infeed takes a single token operand. Generate the token using
4605 // create_token op to pass to the infeed op.
4606 auto token = rewriter.create<CreateTokenOp>(
4607 op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
4608
4609 // Emit infeed op.
4610 // The result type of infeed is a tuple(tuple(result types), token type).
4611 auto data_tuple_type =
4612 mlir::TupleType::get(rewriter.getContext(), result_types);
4613 auto data_and_token_type = mlir::TupleType::get(
4614 rewriter.getContext(), {data_tuple_type, token.getType()});
4615
4616 ArrayAttr layout =
4617 GetLayout(data_and_token_type, rewriter).cast<ArrayAttr>();
4618 auto data_and_token =
4619 rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
4620 /*infeed_config=*/rewriter.getStringAttr(""),
4621 /*layout=*/layout);
4622
4623 // TODO(b/171212005): Reenable layout.
4624 data_and_token.removeAttr("layout");
4625
4626 if (op._XlaSharding().hasValue()) {
4627 // _XlaSharding attribute in TF is a serialized string of the OpSharding
4628 // proto, so convert to a text form here.
4629 ::xla::OpSharding sharding_proto;
4630 if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
4631 return failure();
4632
4633       // Token is a control signal and not real data, so arbitrarily assign
4634 // the token to device 0.
4635 if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
4636 *sharding_proto.add_tuple_shardings() =
4637 ::xla::sharding_builder::AssignDevice(0);
4638 data_and_token->setAttr(
4639 kShardingAttr,
4640 rewriter.getStringAttr(sharding_proto.SerializeAsString()));
4641 } else {
4642 data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
4643 }
4644 }
4645
4646 // The infeed instruction produces a tuple of the infeed data and a token
4647 // type. Emit get_tuple_element to get infeed data tuple.
4648 auto data_tuple = rewriter.create<GetTupleElementOp>(
4649 op.getLoc(), data_tuple_type, data_and_token,
4650 rewriter.getI32IntegerAttr(0));
4651
4652 // Emit get_tuple_element for each result.
4653 std::vector<Value> results;
4654 for (auto idx_and_type : llvm::enumerate(result_types)) {
4655 auto tuple_element = rewriter.create<GetTupleElementOp>(
4656 op.getLoc(), idx_and_type.value(), data_tuple,
4657 rewriter.getI32IntegerAttr(idx_and_type.index()));
4658 results.push_back(tuple_element);
4659 }
4660 rewriter.replaceOp(op, ValueRange(results));
4661 return success();
4662 }
4663 };
4664
4665 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
4666 // ops.
4667 //
4668 // The XLA HLO outfeed op expects a token, which we generate by emitting a
4669 // create_token op.
4670 //
4671 // For example the following IR:
4672 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
4673 // ()
4674 //
4675 // would be lowered to
4676 //
4677 // %tuple = "mhlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
4678 // tuple<tensor<3xi32>, tensor<4xf32>>
4679 // %token = "mhlo.create_token"() : () -> !mhlo.token
4680 // %outfeed_token = "mhlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
4681 // (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
4682 //
4683 class ConvertOutfeedEnqueueTupleOp
4684 : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
4685 public:
4686 using OpRewritePattern::OpRewritePattern;
4687
4688   LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
4689 PatternRewriter &rewriter) const override {
4690 auto token_type = mhlo::TokenType::get(rewriter.getContext());
4691 auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
4692 auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
4693 rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
4694 /*outfeed_config=*/rewriter.getStringAttr(""));
4695 rewriter.eraseOp(op);
4696 return success();
4697 }
4698 };
4699
4700 // Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
4701 //
4702 // tf.TopKV2 sorts along last dimension of the input tensor and then returns
4703 // the top K components' values and indices. This is translated into a few
4704 // ops in XLA HLO: first generating an integer sequence for the indices,
4705 // then sorting both the original input tensor and the indices together, and
4706 // finally slicing out the top K components.
4707 //
4708 // For example, for the following IR:
4709 //
4710 // %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
4711 // %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
4712 // (tensor<16x8xf32>, tensor<16x8xi32>)
4713 //
4714 // We will get:
4715 //
4716 // %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
4717 // %2 = "mhlo.sort"(%input, %1) ( {
4718 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
4719 // %arg3: tensor<i32>, %arg4: tensor<i32>):
4720 // %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
4721 // "mhlo.return"(%7) : (tensor<i1>) -> ()
4722 // }) {dimension = 1 : i64, is_stable = true} : ...
4723 // %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
4724 // %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
4725 // %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
4726 // start_indices dense<0> : tensor<2xi64>,
4727 // strides = dense<1> : tensor<2xi64>} :
4728 // (tensor<16x16xf32>) -> tensor<16x8xf32>
4729 // %6 = "mhlo.slice"(%4) ...
4730 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
4731 public:
4732 using OpRewritePattern::OpRewritePattern;
4733
4734   LogicalResult matchAndRewrite(TF::TopKV2Op op,
4735 PatternRewriter &rewriter) const override {
4736 // We can only match when the `k` operand is a constant scalar.
4737 DenseIntElementsAttr k_attr;
4738 if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure();
4739
4740 // The last dimension of the input tensor's shape should be known so we can
4741 // have clamped end_indices for slices.
4742 TensorType input_type = op.input().getType().cast<TensorType>();
4743 if (!input_type.hasRank()) return failure();
4744 int64_t input_rank = input_type.getRank();
4745 int64_t last_dim_index = input_rank - 1;
4746 int64_t last_dim_size = input_type.getDimSize(last_dim_index);
4747 if (last_dim_size == ShapedType::kDynamicSize) return failure();
4748
4749     // Create an Iota op for indices.
4750 auto i32_type = rewriter.getIntegerType(32);
4751 Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
4752 Value iota_op = rewriter.create<mhlo::IotaOp>(
4753 op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
4754
4755 // Create the sort op. It takes two inputs, one for the original input, the
4756 // other for the indices.
4757 auto sort_op = rewriter.create<mhlo::SortOp>(
4758 op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
4759 /*is_stable=*/true);
4760
4761 // Use TOTALORDER comparison type instead of the default comparison if the
4762 // element type is of type float.
4763 llvm::Optional<StringRef> compare_type;
4764 if (input_type.getElementType().isa<FloatType>())
4765 compare_type.emplace("TOTALORDER");
4766 BuildSortComparisonBody({input_type.getElementType(), i32_type},
4767 /*direction=*/"GT", compare_type,
4768 &sort_op.comparator(), &rewriter);
4769
4770 // Get the sorted input and index tuple element.
4771 auto tuple_first_element = sort_op.getResult(0);
4772 auto tuple_second_element = sort_op.getResult(1);
4773
4774 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
4775 auto end_indices = llvm::to_vector<4>(input_type.getShape());
4776 end_indices.back() =
4777 std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
4778 SmallVector<int64_t, 4> strides(input_rank, 1);
4779
4780 // Get the slice for the top K elements.
4781
4782 Value values = rewriter.create<mhlo::SliceOp>(
4783 op.getLoc(), tuple_first_element,
4784 GetI64ElementsAttr(begin_indices, &rewriter),
4785 GetI64ElementsAttr(end_indices, &rewriter),
4786 GetI64ElementsAttr(strides, &rewriter));
4787
4788 Value indices = rewriter.create<mhlo::SliceOp>(
4789 op.getLoc(), tuple_second_element,
4790 GetI64ElementsAttr(begin_indices, &rewriter),
4791 GetI64ElementsAttr(end_indices, &rewriter),
4792 GetI64ElementsAttr(strides, &rewriter));
4793
4794 rewriter.replaceOp(op, {values, indices});
4795 return success();
4796 }
4797 };
4798
4799 // Converts tf.Unpack to a series of XLA HLO slice ops.
4800 //
4801 // Each slice takes one element along the dimension to unpack and takes the full
4802 // range for all other dimensions. Each slice is then reshaped to drop the
4803 // dimension to unpack (which is always of size 1).
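//
// Illustrative example (hypothetical shapes): unpacking %value :
// tensor<2x3xf32> along axis = 0 emits two slices, [0:1, 0:3] and [1:2, 0:3],
// each of type tensor<1x3xf32>, followed by tf.Squeeze ops that drop the
// unpacked dimension and yield two tensor<3xf32> results.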
4804 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
4805 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
4806 public:
4807 using OpRewritePattern::OpRewritePattern;
4808
4809   LogicalResult matchAndRewrite(TF::UnpackOp op,
4810 PatternRewriter &rewriter) const override {
4811 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
4812 if (!value_type) return failure();
4813
4814 int64_t value_rank = value_type.getRank();
4815 int64_t axis = op.axis();
4816 if (axis < 0) axis += value_rank;
4817
4818 // Parameters for constructing each slice.
4819 SmallVector<int64_t, 4> begin_indices(value_rank, 0);
4820 auto end_indices = llvm::to_vector<4>(value_type.getShape());
4821 SmallVector<int64_t, 4> strides(value_rank, 1);
4822
4823 // All HLO slice+squeeze results used to replace the original tf.Unpack op.
4824 SmallVector<Value, 4> results;
4825 results.reserve(op.getNumResults());
4826
4827 for (int i = 0, end = op.getNumResults(); i < end; ++i) {
4828 begin_indices[axis] = i;
4829 end_indices[axis] = i + 1;
4830
4831 auto slice_op = rewriter.create<mhlo::SliceOp>(
4832 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
4833 GetI64ElementsAttr(end_indices, &rewriter),
4834 GetI64ElementsAttr(strides, &rewriter));
4835 // Reshape to drop the axis dimension.
4836 auto result =
4837 rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
4838 rewriter.getI64ArrayAttr(op.axis()));
4839 results.push_back(result);
4840 }
4841
4842 rewriter.replaceOp(op, results);
4843 return success();
4844 }
4845 };
4846
4847 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
4848 //
4849 // TF unsorted segment reduction op performs the following calculation:
4850 //
4851 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
4852 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
4853 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
4854 // Dm+1, ..., Dn]. Then
4855 // output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
4856 // <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
4857 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
4858 //
4859 // The op will be translated to XLA HLO scatter with the following parameters:
4860 // * Update window dims is [segment_id_rank, data_rank).
4861 // * Inserted window dims is {0}.
4862 // * Scatter dims to operand dims mapping is {0}.
4863 // * Index vector dim is segment_id_rank.
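//
// Illustrative example (hypothetical shapes): for data : tensor<4x8xf32>,
// segment_ids : tensor<4xi32> and num_segments = 3, the initial value is
// broadcast to tensor<3x8xf32> and the scatter uses update_window_dims = [1],
// inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], and
// index_vector_dim = 1.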
4864 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
4865 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
4866 using OpRewritePattern<OpTy>::OpRewritePattern;
4867
4868   LogicalResult matchAndRewrite(OpTy op,
4869 PatternRewriter &rewriter) const override {
4870 auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
4871 if (!data_type) return failure();
4872 int64_t data_rank = data_type.getRank();
4873
4874 auto segment_ids_type =
4875 op.segment_ids().getType().template dyn_cast<RankedTensorType>();
4876 if (!segment_ids_type) return failure();
4877 int64_t segment_ids_rank = segment_ids_type.getRank();
4878
4879 DenseIntElementsAttr num_segments_attr;
4880 if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
4881 return failure();
4882
4883 // The final shape for TF unsorted segment reduction op is [num_segments] +
4884 // data_shape[segment_ids_rank:].
4885 SmallVector<int64_t, 4> output_shape;
4886 output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
4887 auto suffix = data_type.getShape().drop_front(segment_ids_rank);
4888 output_shape.append(suffix.begin(), suffix.end());
4889 auto output_type =
4890 RankedTensorType::get(output_shape, data_type.getElementType());
4891
4892 // Broadcast the initial value for reduction. This will become the
4893     // 'operand' parameter of the final scatter op.
4894 Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
4895 op.getLoc(), &rewriter);
4896 auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
4897 op.getLoc(), output_type, init,
4898 GetI64ElementsAttr(output_shape, &rewriter));
4899
4900 // Parameters for the generated scatter op.
4901 SmallVector<int64_t, 1> inserted_window_dims(1, 0);
4902 SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
4903 int64_t index_vector_dim = segment_ids_rank;
4904
4905 // Put all parameters in a StructAttr.
4906 auto dims_attr = ScatterDimensionNumbers::get(
4907 GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
4908 GetI64ElementsAttr(inserted_window_dims, &rewriter),
4909 GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
4910 rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());
4911
4912 auto scatter =
4913 rewriter.create<ScatterOp>(op.getLoc(), op.getType(), broadcasted_init,
4914 op.segment_ids(), op.data(), dims_attr);
4915 BuildReduceBody<ReductionOp>(data_type.getElementType(),
4916 &scatter.update_computation(), &rewriter);
4917
4918 rewriter.replaceOp(op, scatter.getResult());
4919 return success();
4920 }
4921 };
4922
4923 class ConvertUnsortedSegmentMaxOp
4924 : public GenericConvertUnsortedSegmentReductionOp<
4925 ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
4926 public:
4927 using GenericConvertUnsortedSegmentReductionOp::
4928 GenericConvertUnsortedSegmentReductionOp;
4929
4930   static Value GetInitialValue(Type reduce_element_type, Location loc,
4931 PatternRewriter *rewriter) {
4932 return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
4933 rewriter);
4934 }
4935 };
4936
4937 class ConvertUnsortedSegmentMinOp
4938 : public GenericConvertUnsortedSegmentReductionOp<
4939 ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
4940 public:
4941 using GenericConvertUnsortedSegmentReductionOp::
4942 GenericConvertUnsortedSegmentReductionOp;
4943
4944   static Value GetInitialValue(Type reduce_element_type, Location loc,
4945 PatternRewriter *rewriter) {
4946 return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
4947 rewriter);
4948 }
4949 };
4950
4951 class ConvertUnsortedSegmentProdOp
4952 : public GenericConvertUnsortedSegmentReductionOp<
4953 ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
4954 public:
4955 using GenericConvertUnsortedSegmentReductionOp::
4956 GenericConvertUnsortedSegmentReductionOp;
4957
4958   static Value GetInitialValue(Type reduce_element_type, Location loc,
4959 PatternRewriter *rewriter) {
4960 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4961 }
4962 };
4963
4964 class ConvertUnsortedSegmentSumOp
4965 : public GenericConvertUnsortedSegmentReductionOp<
4966 ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
4967 public:
4968 using GenericConvertUnsortedSegmentReductionOp::
4969 GenericConvertUnsortedSegmentReductionOp;
4970
4971   static Value GetInitialValue(Type reduce_element_type, Location loc,
4972 PatternRewriter *rewriter) {
4973 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4974 }
4975 };
4976
4977 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
4978 //
4979 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
4980 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
4981 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
4982 // translated into an HLO while op to first emulate shuffling indices using
4983 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
4984 // with the shuffled indices.
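//
// Illustrative example (hypothetical size): for a rank-1 input with 1024
// elements and exponent = 3, rounds = ceil(3 * ln(1024) / ln(2^32 - 1)) = 1,
// so a single sort by freshly generated random u32 keys already keeps the key
// space at least the cube of the number of values.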
4985 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
4986 public:
4987 using OpRewritePattern::OpRewritePattern;
4988
4989   LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
4990 PatternRewriter &rewriter) const override {
4991 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
4992 if (!input_type) return failure();
4993
4994 int64_t input_rank = input_type.getRank();
4995 int64_t first_dim_size = input_type.getDimSize(0);
4996 if (ShapedType::isDynamic(first_dim_size)) return failure();
4997
4998 // We are shuffling along the first dimension. If its size is <= 1, then
4999 // shuffling is a no-op.
5000 if (first_dim_size <= 1) {
5001 rewriter.replaceOp(op, op.value());
5002 return success();
5003 }
5004
5005 // For vectors, shuffle values by sorting instead of the obvious
5006 // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
5007 // but not easily parallelizable. For a sufficiently parallel architecture,
5008     // it is faster to sort many times than to run a Fisher-Yates shuffle once.
5009 if (input_rank == 1) {
5010 // Shuffle values by assigning each value a random key and sorting the
5011 // keys. Keys can collide causing detectable patterns in the shuffled
5012       // keys. Keys can collide, causing detectable patterns in the shuffled
5013       // output. Collisions translate into more ascending sub-sequences in the
5014 // the number of possible key values must be sufficiently large.
5015
5016 // How are more than 2^32 keys created? In each loop iteration, the
5017 // algorithm sorts by random keys. Conceptually, the earlier iterations
5018 // are sorting on the lower-order bits of larger keys that are never
5019 // actually assembled.
5020
5021 // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
5022 // the number of possible keys and n is the number of values. If d = n^2,
5023 // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
5024 // as n goes to infinity is zero.
5025
5026 // This implementation ensures that the key-space is greater than or equal
5027 // to the cube of the number of values. The risk of collisions can be
5028 // further reduced by increasing Exponent at the expense of
5029 // performance.
5030
5031 // For Exponent = 2, the expected number of collisions per shuffle is
5032 // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
5033 // about 1/2.
5034
5035 // For Exponent = 3, the expected number of collisions per shuffle is
5036 // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
5037 // about 1/3255.
5038
5039 // For Exponent = 4, the expected number of collisions per shuffle is
5040 // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
5041 // about 1/132622.
5042 constexpr int exponent = 3;
5043 int64_t num_elements = input_type.getNumElements();
5044 uint32_t u32_max = std::numeric_limits<uint32_t>::max();
5045 int rounds =
5046 std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
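      // For example, with num_elements = 1,000,000 and exponent = 3,
      // rounds = ceil(3 * ln(1e6) / ln(2^32 - 1)) = ceil(~41.4 / ~22.2) = 2,
      // i.e. two sort passes give an effective key space of roughly 2^64,
      // which is at least (10^6)^3.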
5047
5048 Value current = op.value();
5049 for (int i = 0; i < rounds; ++i) {
5050 auto keys =
5051 CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
5052 /*upper_limit=*/u32_max, &rewriter);
5053 auto sorted = rewriter.create<mhlo::SortOp>(
5054 op.getLoc(), llvm::ArrayRef<Value>{keys, current});
5055 auto i32_type = rewriter.getIntegerType(32);
5056 BuildSortComparisonBody({i32_type, input_type.getElementType()},
5057 /*direction=*/"LT", llvm::None,
5058 &sorted.comparator(), &rewriter);
5059 current = sorted.getResult(1);
5060 }
5061 rewriter.replaceOp(op, current);
5062 return success();
5063 }
5064
5065 // The Fisher-Yates algorithm.
5066
5067 // Generate range(n) as the initial value for the indices to be swapped.
5068 auto indices_type =
5069 RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
5070 Value indices = rewriter.create<mhlo::IotaOp>(
5071 op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
5072
5073 // Generate random numbers to be used as swaps for the indices.
5074 Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
5075 first_dim_size, &rewriter);
5076
5077 // While loop body to perform index swaps.
5078 auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
5079 SmallVectorImpl<Value> *new_values,
5080 OpBuilder *builder) {
5081 Value swaps = old_values[0];
5082 Value indices = old_values[1];
5083
5084 auto vec1_i32_type =
5085 RankedTensorType::get({1}, builder->getIntegerType(32));
5086 auto scalar_i32_type =
5087 RankedTensorType::get({}, builder->getIntegerType(32));
5088 auto scalar_i64_type =
5089 RankedTensorType::get({}, builder->getIntegerType(64));
5090
5091 auto scalar_one =
5092 DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
5093
5094 // We need to swap the indices[i] with indices[swaps[i]]. First get
5095 // these index values.
5096 Value source_index = builder->create<mhlo::DynamicSliceOp>(
5097 loc, vec1_i32_type, indices, i, scalar_one);
5098 Value swap_index = builder->create<mhlo::ReshapeOp>(
5099 loc, scalar_i32_type,
5100 builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
5101 scalar_one));
5102 Value target_index = builder->create<mhlo::DynamicSliceOp>(
5103 loc, vec1_i32_type, indices, swap_index, scalar_one);
5104
5105 // Then perform the swap.
5106 // indices[i] <- indices[swaps[i]]
5107 indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5108 loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
5109 // indices[swaps[i]] <- indices[i]
5110 indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5111 loc, indices.getType(), indices, source_index,
5112 llvm::makeArrayRef(swap_index));
5113
5114 // Update new values.
5115 new_values->assign({swaps, indices});
5116 };
5117
5118 // Create a while op to swap indices.
5119 SmallVector<Value, 2> while_output;
5120 CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
5121 &while_output, &rewriter);
5122     Value swapped_indices = while_output[1];
5123
5124 // Gather the data using the swapped indices as the shuffled order.
5125 ArrayRef<int64_t> input_shape = input_type.getShape();
5126 SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
5127 slice_sizes[0] = 1;
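    // For example, for a [5, 3] input the gather below uses slice_sizes = [1, 3],
    // offset_dims = [1], collapsed_slice_dims = [0], and start_index_map = [0],
    // so each of the 5 shuffled indices selects one whole row of the input.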
5128 auto dims_attr = GatherDimensionNumbers::get(
5129 /*offset_dims=*/GetI64ElementsAttrForSeq(1, input_rank, &rewriter),
5130 /*collapsed_slice_dims=*/GetI64ElementsAttr({0}, &rewriter),
5131 /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
5132 /*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
5133 rewriter.getContext());
5134 rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
5135         op, op.getType(), op.value(), swapped_indices, dims_attr,
5136 GetI64ElementsAttr(slice_sizes, &rewriter));
5137
5138 return success();
5139 }
5140 };
5141
5142 // Converts an XlaSharding op to an XLA HLO shard op with sharding attributes.
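// For example, a tf.XlaSharding op carrying an _XlaSharding attribute becomes an
// mhlo.custom_call with call_target_name = "Sharding", and the serialized
// sharding is copied onto the call as the mhlo.sharding attribute.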
5143 class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
5144 public:
5145 using OpRewritePattern::OpRewritePattern;
5146
5147   LogicalResult matchAndRewrite(TF::XlaShardingOp op,
5148 PatternRewriter &rewriter) const override {
5149     // TODO(b/148313088): define sharding attribute struct in MLIR instead of
5150 // using a string.
5151 if (!op._XlaSharding().hasValue()) return failure();
5152
5153 auto custom_call = rewriter.create<mhlo::CustomCallOp>(
5154 op.getLoc(), op.getType(), op.input(),
5155 /*call_target_name=*/rewriter.getStringAttr("Sharding"),
5156 /*has_side_effect=*/rewriter.getBoolAttr(false),
5157 /*backend_config=*/rewriter.getStringAttr(""));
5158 custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
5159 rewriter.replaceOp(op, custom_call.getResult(0));
5160
5161 return success();
5162 }
5163 };
5164
5165 // Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
5166 class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
5167 public:
5168 using OpRewritePattern::OpRewritePattern;
5169
5170   LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
5171 PatternRewriter &rewriter) const override {
5172 auto input = op.x();
5173 auto indices = op.i();
5174 auto updates = op.v();
5175
5176 // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
5177 // on the contents of `x`.
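    // For example, for x: [4, 2], i = [1, 3], and v: [2, 2], `i` is unpacked into
    // two scalars, `v` is split into two [1, 2] slices, and each slice is written
    // with a DynamicUpdateSlice whose start indices are (i[r], 0).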
5178 auto input_type = input.getType().cast<ShapedType>();
5179 auto updates_type = updates.getType().cast<ShapedType>();
5180 auto indices_type = indices.getType().cast<ShapedType>();
5181 if (!indices_type.hasStaticShape()) return failure();
5182
5183 if (indices_type.getRank() != 1) return failure();
5184
5185 SmallVector<Type, 4> unpacked_indices_type(
5186 indices_type.getDimSize(0),
5187 RankedTensorType::get({}, indices_type.getElementType()));
5188 // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are
5189 // required to have matching types. This rewrite rule creates
5190 // DynamicUpdateSlice ops where the first "start index" is always i32 and
5191 // subsequent ones are constructed based on zero_attr. Thus the type
5192 // for zero_attr needs to be i32 as well.
5193 auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0);
5194 auto unpacked_indices = rewriter.create<TF::UnpackOp>(
5195 op.getLoc(), unpacked_indices_type, indices, zero_attr);
5196
5197 SmallVector<int64_t, 4> split_updates_shape;
5198 split_updates_shape.append(updates_type.getShape().begin(),
5199 updates_type.getShape().end());
5200 split_updates_shape.front() = 1;
5201 SmallVector<Type, 4> split_updates_type;
5202 split_updates_type.resize(
5203 updates_type.getShape().front(),
5204 RankedTensorType::get(split_updates_shape,
5205 updates_type.getElementType()));
5206
5207 auto cst =
5208 rewriter.create<mhlo::ConstOp>(op.getLoc(), zero_attr).getResult();
5209 auto split_updates = rewriter.create<TF::SplitOp>(
5210 op.getLoc(), split_updates_type, cst, updates);
5211
5212 SmallVector<Value, 6> input_indices;
5213 input_indices.resize(input_type.getRank(), cst);
5214
5215 SmallVector<int64_t, 6> starts(updates_type.getRank(), 0);
5216 SmallVector<int64_t, 6> strides(updates_type.getRank(), 1);
5217 SmallVector<int64_t, 6> limits(updates_type.getShape().begin(),
5218 updates_type.getShape().end());
5219
5220 for (auto pair :
5221 llvm::zip(unpacked_indices.output(), split_updates.output())) {
5222 input_indices.front() = std::get<0>(pair);
5223 input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
5224 op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
5225 }
5226
5227 rewriter.replaceOp(op, input);
5228 return success();
5229 }
5230 };
5231
5232 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
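// For example, with indices = [i, j] the indices vector is unpacked into two
// scalar tensors and the op becomes mhlo.dynamic-update-slice(input, update, i, j).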
5233 class ConvertXlaDynamicUpdateSliceOp
5234 : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
5235 public:
5236 using OpRewritePattern::OpRewritePattern;
5237
5238   LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op,
5239 PatternRewriter &rewriter) const override {
5240 auto indices_type = op.indices().getType().dyn_cast<RankedTensorType>();
5241 if (!indices_type || !indices_type.hasStaticShape() ||
5242 indices_type.getShape().size() != 1)
5243 return failure();
5244
5245 SmallVector<Type, 4> unpacked_indices_type(
5246 indices_type.getDimSize(0),
5247 RankedTensorType::get({}, indices_type.getElementType()));
5248 auto unpacked_indices = rewriter.create<TF::UnpackOp>(
5249 op.getLoc(), unpacked_indices_type, op.indices(),
5250 IntegerAttr::get(rewriter.getIntegerType(64), 0));
5251 rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
5252 op, op.getType(), op.input(), op.update(), unpacked_indices.output());
5253 return success();
5254 }
5255 };
5256
5257 // Converts a TF XlaAllReduce op to AllReduce HLO.
5258 class ConvertXlaAllReduceOp : public OpRewritePattern<TF::XlaAllReduceOp> {
5259 using OpRewritePattern::OpRewritePattern;
5260
5261   LogicalResult matchAndRewrite(TF::XlaAllReduceOp op,
5262 PatternRewriter &rewriter) const override {
5263 DenseIntElementsAttr group_assignment;
5264 if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment)))
5265 return failure();
5266 auto replica_groups =
5267 hlo::ConvertElementsAttr(group_assignment, rewriter.getIntegerType(64))
5268 .cast<DenseIntElementsAttr>();
5269 if (replica_groups.getType().getRank() != 2) return failure();
5270
5271 Location loc = op.getLoc();
5272 Type element_type = getElementTypeOrSelf(op.input().getType());
5273
5274 auto all_reduce = rewriter.create<AllReduceOp>(
5275 loc, op.getType(), op.input(), replica_groups, ChannelHandle());
5276 StringRef reduce_op = op.reduce_op();
5277 if (reduce_op == "Add") {
5278 BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
5279 &rewriter);
5280 } else if (reduce_op == "Mul") {
5281 BuildReduceBody<MulOp>(element_type, &all_reduce.computation(),
5282 &rewriter);
5283 } else if (reduce_op == "Min") {
5284 BuildReduceBody<MinOp>(element_type, &all_reduce.computation(),
5285 &rewriter);
5286 } else if (reduce_op == "Max") {
5287 BuildReduceBody<MaxOp>(element_type, &all_reduce.computation(),
5288 &rewriter);
5289 } else {
5290       // For Mean, add the values from replicas in the same group. The sum is
5291       // then divided by the number of replicas in each group below.
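      // For example, with group_assignment = [[0, 2], [1, 3]] each group holds
      // two replicas, so the Add reduction here is followed by a division by 2.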
5292 assert(reduce_op == "Mean");
5293 BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
5294 &rewriter);
5295 }
5296 Value result = all_reduce.getResult();
5297
5298 // For mean, divide the merge result by group size.
5299 if (reduce_op == "Mean") {
5300 int64_t replica_group_size = replica_groups.getType().getDimSize(1);
5301 auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size,
5302 &rewriter);
5303 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
5304 result = rewriter.create<chlo::BroadcastDivOp>(
5305 loc, result, divisor.getResult(), broadcast_dims);
5306 }
5307
5308 rewriter.replaceOp(op, {result});
5309 return success();
5310 }
5311 };
5312
5313 // Converts ClipByValue to XLA's clamp operation. Includes the broadcasting
5314 // semantics for static and dynamic cases.
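// For example, clipping a [2, 3] tensor with scalar bounds first broadcasts the
// bounds to [2, 3] using tf.BroadcastTo on the runtime shape of the input, and
// then emits mhlo.clamp(min, t, max).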
5315 class ConvertClipByValueOp : public OpRewritePattern<TF::ClipByValueOp> {
5316 public:
5317 using OpRewritePattern::OpRewritePattern;
5318
5319   LogicalResult matchAndRewrite(TF::ClipByValueOp op,
5320 PatternRewriter &rewriter) const override {
5321 Value input = op.t();
5322 Value min = op.clip_value_min();
5323 Value max = op.clip_value_max();
5324
5325 auto input_ty = input.getType().cast<ShapedType>();
5326 auto min_ty = min.getType().cast<ShapedType>();
5327 auto max_ty = max.getType().cast<ShapedType>();
5328
5329 if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) {
5330 return failure();
5331 }
5332
5333 auto shape = rewriter.create<TF::ShapeOp>(
5334 op.getLoc(),
5335 RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()),
5336 input);
5337
5338 if (min_ty != input_ty) {
5339 min =
5340 rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, min, shape);
5341 }
5342
5343 if (max_ty != input_ty) {
5344 max =
5345 rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, max, shape);
5346 }
5347
5348 rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, input_ty, min, input, max);
5349 return success();
5350 }
5351 };
5352
5353 // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
5354 // setting appropriate window dimensions, with the given aggregation op as the
5355 // reduction function. The input tensor needs to have a static shape, and 'axis'
5356 // must be const. The TableGen pattern is not used for this rewrite because it
5357 // involves regions.
5358 template <typename OpT, typename AggregationOp>
5359 class ConvertCumOp : public OpRewritePattern<OpT> {
5360 using OpRewritePattern<OpT>::OpRewritePattern;
5361
5362   LogicalResult matchAndRewrite(OpT op,
5363 PatternRewriter &rewriter) const override {
5364 auto input = op.x();
5365 auto input_type = input.getType().template dyn_cast<ShapedType>();
5366 if (!input_type || !input_type.hasStaticShape()) {
5367 return failure();
5368 }
5369
5370 ArrayRef<int64_t> input_shape = input_type.getShape();
5371 int64_t rank = input_shape.size();
5372
5373 // We can only match when the axis is a constant scalar.
5374 DenseIntElementsAttr axis_attr;
5375 if (!matchPattern(op.axis(), m_Constant(&axis_attr))) {
5376 return failure();
5377 }
5378
5379 // Get the dimension to apply the reduction on, and offset properly if it is
5380 // negative.
5381 int64_t axis = (*axis_attr.begin()).getSExtValue();
5382 if (axis < 0) {
5383 axis += rank;
5384 }
5385
5386 // If we're supposed to sum things up in the reverse direction, we reverse
5387 // the input and then later reverse the output.
5388 if (op.reverse()) {
5389 llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
5390 input = rewriter.create<ReverseOp>(
5391 op.getLoc(), op.getType(), input,
5392 GetI64ElementsAttr(dims_to_reverse, &rewriter));
5393 }
5394
5395 // Convert if we need to enlarge the element type's bitwidth to avoid
5396 // precision loss.
5397 Type input_element_type = input_type.getElementType();
5398
5399 // TODO(hinsu): Handle complex element types.
5400 if (!input_element_type.isIntOrFloat()) return failure();
5401
5402 Type sum_element_type = GetSumAccumulationType(input_element_type);
5403 input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
5404
5405 SmallVector<int64_t, 4> window_dims(rank, 1);
5406 SmallVector<int64_t, 4> window_strides(rank, 1);
5407 window_dims[axis] = input_shape[axis];
5408
5409 SmallVector<int64_t, 8> paddings(rank * 2, 0);
5410 paddings[axis * 2] =
5411 std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
5412 auto paddings_attr = DenseIntElementsAttr::get(
5413 RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
5414 paddings);
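    // For example, for a 1-D input of length 4 reduced along axis 0, this yields
    // window_dims = [4], window_strides = [1], and paddings = [[3, 0]], so output
    // element i reduces over input elements 0..i (an inclusive cumulative op).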
5415
5416 int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
5417 Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
5418 &rewriter);
5419
5420 auto reduce = rewriter.create<ReduceWindowOp>(
5421 op.getLoc(), input_type, input, init,
5422 GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)),
5423 GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
5424 /*base_dilations=*/DenseIntElementsAttr(),
5425 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
5426 BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
5427 Value result = reduce.getResult();
5428
5429 if (op.exclusive()) {
5430       // In an "exclusive" operation, the output will start with the "init" (0)
5431 // values. There is no way to express that as a ReduceWindowOp, so run the
5432 // normal operation, and then use a PadOp to add the 0 "column" on the
5433 // left and cut away the last column on the right.
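      // For example, an exclusive cumsum of [1, 2, 3] first computes [1, 3, 6],
      // and the padding below turns it into [0, 1, 3].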
5434 llvm::SmallVector<int64_t, 4> low_padding(rank, 0);
5435 llvm::SmallVector<int64_t, 4> high_padding(rank, 0);
5436 llvm::SmallVector<int64_t, 4> interior_padding(rank, 0);
5437 low_padding[axis] = 1;
5438 high_padding[axis] = -1;
5439 result = rewriter.create<PadOp>(
5440 op.getLoc(), op.getType(), result, init,
5441 GetI64ElementsAttr(low_padding, &rewriter),
5442 GetI64ElementsAttr(high_padding, &rewriter),
5443 GetI64ElementsAttr(interior_padding, &rewriter));
5444 }
5445
5446 // Convert back if we enlarged the element type's bitwidth.
5447 result =
5448 rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
5449
5450 if (op.reverse()) {
5451 llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
5452 result = rewriter.create<ReverseOp>(
5453 op.getLoc(), op.getType(), result,
5454 GetI64ElementsAttr(dims_to_reverse, &rewriter));
5455 }
5456
5457 rewriter.replaceOp(op, result);
5458 return success();
5459 }
5460 };
5461
5462 using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
5463 using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
5464
5465 // Converts the TensorFlow ShapeOp to a sequence of Shape dialect and Standard
5466 // dialect ops. This involves extracting the shape, extracting and
5467 // converting each dimension to a known integer type, and repacking into a final
5468 // tensor.
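// For example, tf.Shape of a tensor<?x4xf32> becomes shape.shape_of followed by
// shape.to_extent_tensor producing a tensor<2xindex>, which is then index_cast
// to the requested i32/i64 result type.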
5469 class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
5470 public:
5471 using OpRewritePattern::OpRewritePattern;
5472
5473   LogicalResult matchAndRewrite(TF::ShapeOp op,
5474 PatternRewriter &rewriter) const override {
5475 Value input = op.input();
5476
5477 auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
5478 auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
5479 if (!result_ty) {
5480 return failure();
5481 }
5482
5483 auto index_tensor =
5484 RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
5485 auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
5486 op.getLoc(), index_tensor, shape_op);
5487
5488 rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
5489 return success();
5490 }
5491 };
5492
5493 class ConvertDynamicReshapeOp : public OpRewritePattern<TF::ReshapeOp> {
5494 public:
5495 using OpRewritePattern::OpRewritePattern;
5496
5497   LogicalResult matchAndRewrite(TF::ReshapeOp op,
5498 PatternRewriter &rewriter) const override {
5499 auto tensor = op.tensor();
5500 auto shape = op.shape();
5501
5502 auto tensor_ty = tensor.getType().cast<ShapedType>();
5503 auto shape_ty = shape.getType().cast<ShapedType>();
5504 auto result_ty = op.getType().cast<ShapedType>();
5505
5506 if (!result_ty.hasRank() || !tensor_ty.hasRank() || !shape_ty.hasRank()) {
5507 return failure();
5508 }
5509
5510     // Statically-shaped results are handled elsewhere; only rewrite the dynamic case.
5511 if (result_ty.hasStaticShape()) {
5512 return failure();
5513 }
5514
5515 rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, tensor,
5516 shape);
5517 return success();
5518 }
5519 };
5520
5521 class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
5522 public:
5523 using OpRewritePattern::OpRewritePattern;
5524
5525   LogicalResult matchAndRewrite(TF::ExpandDimsOp op,
5526 PatternRewriter &rewriter) const override {
5527 auto input = op.input();
5528 auto input_ty = input.getType().cast<ShapedType>();
5529 auto result_ty = op.getType().cast<ShapedType>();
5530 if (!result_ty.hasRank() || !input_ty.hasRank() ||
5531 result_ty.hasStaticShape()) {
5532 return failure();
5533 }
5534
5535 DenseIntElementsAttr expand_dims_attr;
5536 if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) {
5537 return failure();
5538 }
5539
5540 auto shape = rewriter.create<shape::ShapeOfOp>(
5541 op.getLoc(),
5542 RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()),
5543 input);
5544 auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getIntValues());
5545
5546 llvm::SmallVector<Value, 4> dims;
5547 dims.resize(result_ty.getRank());
5548
5549 auto inserted_dim = expand_dims[0].getSExtValue();
5550
5551 // Handle the negative value use case.
5552 if (inserted_dim < 0) {
5553 inserted_dim += result_ty.getRank();
5554 // This means the value is completely incorrect, just return.
5555 if (inserted_dim < 0) {
5556 return failure();
5557 }
5558 }
5559
5560 dims[inserted_dim] = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
5561
5562 for (int i = 0; i < dims.size() - 1; i++) {
5563 // Add the extracted dim.
5564 auto index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
5565 auto dim = rewriter.create<shape::GetExtentOp>(
5566 op.getLoc(), rewriter.getIndexType(), shape, index);
5567
5568 dims[i >= inserted_dim ? i + 1 : i] = dim;
5569 }
5570
5571 auto from_extents = rewriter.create<shape::FromExtentsOp>(
5572 op.getLoc(), shape::ShapeType::get(op.getContext()), dims);
5573
5574 auto to_extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
5575 op.getLoc(),
5576 RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()),
5577 from_extents);
5578
5579 rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
5580 to_extent_tensor);
5581 return success();
5582 }
5583 };
5584
5585 // Converts a TF QR op to HLO.
5586 class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
5587 public:
5588 using OpRewritePattern::OpRewritePattern;
5589
5590   LogicalResult matchAndRewrite(TF::QrOp op,
5591 PatternRewriter &rewriter) const override {
5592     // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
5593     // def qr_blocked(a, block_size):
5594 // m = a.shape[0]
5595 // n = a.shape[1]
5596 // q = np.eye(m)
5597 // for i in xrange(0, min(m, n), block_size):
5598     //     k = min(block_size, min(m, n) - i)
5599 // (a, vs, taus) = qr(a[i:, i:i+k])
5600 // y = vs
5601 // w = ComputeWYRepresentation(vs, taus, m-i, k)
5602     //     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
5603 // q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
5604 // return (q, a)
5605 auto type = op.input().getType().dyn_cast<RankedTensorType>();
5606 if (!type || !type.hasStaticShape()) return failure();
5607 // The block size is chosen to match old bridge lowering.
5608 constexpr int64_t kBlockSize = 128;
5609 Value a = op.input();
5610 int64_t m = type.getDimSize(type.getRank() - 2);
5611 int64_t n = type.getDimSize(type.getRank() - 1);
5612 int64_t p = std::min(m, n);
5613 auto batch_dims = type.getShape().drop_back(2);
5614 auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32));
5615 auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
5616 rewriter.getI64IntegerAttr(0));
5617 auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
5618 rewriter.getI64IntegerAttr(1));
5619 Value compare = rewriter.create<CompareOp>(
5620 op.getLoc(), iota0, iota1,
5621 StringAttr::get(rewriter.getContext(), "EQ"));
5622 Value identity_matrix =
5623 rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
5624 auto q_shape = llvm::to_vector<4>(type.getShape());
5625 q_shape.back() = m;
5626 Value q = rewriter.create<BroadcastOp>(
5627 op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()),
5628 identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter));
5629 auto precision_config = rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"});
5630 for (int64_t i = 0; i < p; i += kBlockSize) {
5631 int64_t k = std::min(kBlockSize, p - i);
5632 auto a_block =
5633 SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter);
5634 Value r_block;
5635 Value taus;
5636 Value vs;
5637 QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter);
5638 a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter);
5639
5640 // Compute the I-WY block representation of a product of Householder
5641 // matrices.
5642 Value w =
5643 ComputeWYRepresentation(op.getLoc(), type.getElementType(),
5644 batch_dims, vs, taus, m - i, k, &rewriter);
5645 auto y = vs;
5646
5647 // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
5648 Value a_panel =
5649 SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter);
5650 auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false,
5651 batch_dims.size(), precision_config, &rewriter);
5652 a_update = BatchDot(op.getLoc(), y, false, a_update, false,
5653 batch_dims.size(), precision_config, &rewriter);
5654 a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
5655 a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
5656 &rewriter);
5657
5658       // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
5659 Value q_panel =
5660 SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter);
5661 Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false,
5662 batch_dims.size(), precision_config, &rewriter);
5663 q_update = BatchDot(op.getLoc(), q_update, false, y, true,
5664 batch_dims.size(), precision_config, &rewriter);
5665 q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
5666 q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
5667 }
5668     // full_matrices is false when only a partial result is needed. Slice to the
5669 // needed dimensions here.
5670 if (!op.full_matrices()) {
5671 q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter);
5672 a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter);
5673 }
5674 rewriter.replaceOp(op, {q, a});
5675 return success();
5676 }
5677
5678 private:
5679 // Computes a Householder reflection of the form:
5680 // H = I - tau v v.T.
5681 // such that
5682 // H . ( x1 ) = ( x1 )
5683 // ( x2 ) = ( x2 )
5684 // ( ... ) = ( ... )
5685 // ( xk ) = ( beta )
5686 // ( ... ) ( 0 )
5687 // ( ... ) ( 0 )
5688 // Unlike the usual formulation, we allow the caller to supply 'k' rather than
5689 // only providing the relevant part of 'x' to maintain XLA's static shape
5690 // invariant. In addition, the implementation supports batching.
5691 // Pseudo-code, without batching:
5692 // alpha = x[k]
5693 // x_copy = np.copy(x)
5694 // x_copy[:k+1] = 0
5695 // xnorm = norm2(x_copy)
5696 // if xnorm == 0:
5697 // beta = alpha
5698 // tau = 0
5699 // v = np.zeros_like(x)
5700 // else:
5701 // beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
5702 // tau = (beta - alpha) / beta
5703 // v = x / (alpha - beta)
5704 // v[k] = 1
5705 // return (v, tau, beta)
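  // For example (without batching), for x = [3, 4] and k = 0: alpha = 3,
  // xnorm = 4, beta = -sqrt(3^2 + 4^2) = -5, tau = (-5 - 3) / -5 = 1.6, and
  // v = [1, 0.5], so H.x = x - tau * v * dot(v, x) = [3, 4] - 1.6 * 5 * [1, 0.5]
  // = [-5, 0] = [beta, 0].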
5706   void House(Location loc, Value x, Value k, ArrayRef<int64_t> batch_dims,
5707 const int64_t m, OpBuilder *builder, Value *v, Value *tau,
5708 Value *beta) const {
5709 auto x_type = x.getType().cast<RankedTensorType>();
5710
5711 llvm::SmallVector<int64_t, 4> batch_dim_ids(batch_dims.size());
5712 std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
5713 const int64_t minor_dim = batch_dims.size();
5714
5715 Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder);
5716 Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder);
5717
5718 // alpha = x[k]
5719 Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder);
5720 alpha = builder->create<ReshapeOp>(
5721 loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha);
5722
5723 // Compute x[k+1:] (padded with zeros in elements 0..k)
5724 Value iota = builder->create<IotaOp>(
5725 loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
5726 builder->getI64IntegerAttr(0));
5727 Value gtk = builder->create<chlo::BroadcastCompareOp>(
5728 loc, iota, k, GetI64ElementsAttr({}, builder),
5729 StringAttr::get(builder->getContext(), "GT"));
5730 gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
5731 Value x_after_k = builder->create<chlo::BroadcastMulOp>(
5732 loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
5733 Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
5734 // sigma = np.dot(x[k+1:], x[k+1:])
5735 auto sigma = builder->create<ReduceOp>(
5736 loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
5737 BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
5738 // mu = np.sqrt(x[k]*x[k] + sigma)
5739 Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
5740 Value mu = builder->create<SqrtOp>(
5741 loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
5742
5743 Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
5744 loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
5745 StringAttr::get(builder->getContext(), "EQ"));
5746 Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
5747 loc, alpha, zero, GetI64ElementsAttr({}, builder),
5748 StringAttr::get(builder->getContext(), "LT"));
5749 auto batch_size_one = builder->create<BroadcastOp>(
5750 loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
5751 Value signed_mu = builder->create<chlo::BroadcastMulOp>(
5752 loc,
5753 builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
5754 batch_size_one,
5755 builder->create<NegOp>(loc, batch_size_one)),
5756 mu, GetI64ElementsAttr({}, builder));
5757 *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
5758 alpha, signed_mu);
5759 *tau = builder->create<DivOp>(
5760 loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
5761 Value zero_tau = builder->create<BroadcastOp>(
5762 loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
5763 *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
5764 zero_tau, *tau);
5765 Value divisor = builder->create<SubOp>(loc, alpha, *beta);
5766 divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
5767 batch_size_one, divisor);
5768
5769 Value eqk = builder->create<chlo::BroadcastCompareOp>(
5770 loc, iota, k, GetI64ElementsAttr({}, builder),
5771 StringAttr::get(builder->getContext(), "EQ"));
5772 eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
5773 llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
5774 e_k_shape.push_back(m);
5775 auto e_k = builder->create<BroadcastOp>(
5776 loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk,
5777 GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
5778 builder));
5779
5780 // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
5781 // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
5782 // Note that the add performs a degenerate broadcast.
5783 *v = builder->create<chlo::BroadcastAddOp>(
5784 loc, e_k,
5785 StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
5786 GetI64ElementsAttr(batch_dim_ids, builder),
5787 *builder),
5788 /*broadcast_dimensions=*/nullptr);
5789 }
5790
5791 // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
5792 // Loan "Matrix Computations", 4th Edition. This is an unblocked
5793 // implementation used as an inner routine of the blocked implementation.
5794 // Algorithm is adapted slightly so the shapes inside the loop are static, at
5795 // the cost of some redundant computation. Since this is used as an inner
5796 // block kernel, accumulates the Householder transformations (vs, taus) rather
5797 // than the matrix q. Equivalent Python code, without batching: def qr(a):
5798 // m = a.shape[0]
5799 // n = a.shape[1]
5800 // vs = np.zeros([m, n])
5801 // taus = np.zeros([n])
5802 // for j in xrange(min(m, n)):
5803 // v, tau, beta = house(a[:, j], j)
5804 // # Unusually, we apply the Householder transformation to the entirety of
5805 // # a, wasting FLOPs to maintain the static shape invariant that XLA
5806 // # requires. For columns that precede j this has no effect.
5807 // a[:, :] -= tau * np.dot(v[:, np.newaxis],
5808 // np.dot(v[np.newaxis, :], a[:, :]))
5809 // # Form column j explicitly rather than relying on the precision of the
5810 // # Householder update.
5811 // a[j, j] = beta
5812 // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
5813 // vs[:, j] = v
5814 // taus[j] = tau
5815 // return (q, vs, taus)
5816   void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs,
5817 PatternRewriter *rewriter) const {
5818 auto a_type = a.getType().cast<RankedTensorType>();
5819 const int num_dims = a_type.getRank();
5820 assert(num_dims >= 2 && "Argument to QR must have rank >= 2");
5821
5822 const int64_t m = a_type.getDimSize(a_type.getRank() - 2);
5823 const int64_t n = a_type.getDimSize(a_type.getRank() - 1);
5824
5825 const int64_t num_batch_dims = num_dims - 2;
5826 auto batch_dims = a_type.getShape().take_front(num_batch_dims);
5827 llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
5828 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
5829
5830 auto qr_body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
5831 SmallVectorImpl<Value> *new_values,
5832 OpBuilder *builder) {
5833 auto a = old_values[0];
5834 auto vs = old_values[1];
5835 auto taus = old_values[2];
5836
5837 // v, beta = house(a[:, j], j)
5838 auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder);
5839 auto x_collapsed_shape = llvm::to_vector<4>(batch_dims);
5840 x_collapsed_shape.push_back(m);
5841 auto x_collapsed = builder->create<ReshapeOp>(
5842 loc,
5843 RankedTensorType::get(x_collapsed_shape,
5844 getElementTypeOrSelf(x.getType())),
5845 x);
5846 Value v, tau, beta;
5847 House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta);
5848
5849 auto shape = llvm::to_vector<4>(batch_dims);
5850 shape.append({1, m});
5851 auto v_broadcast = builder->create<ReshapeOp>(
5852 loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())),
5853 v);
5854 // a[:, :] -= tau * np.dot(v[:, np.newaxis],
5855 // np.dot(v[np.newaxis, :], a[:, :]))
5856 auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
5857 auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims,
5858 precision, builder);
5859 vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
5860 precision, builder);
5861 auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
5862 loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
5863 *builder);
5864 a = builder->create<SubOp>(loc, a, tau_x_vva);
5865
5866 // It is more precise to populate column 'k' explicitly, rather than
5867 // computing it implicitly by applying the Householder transformation.
5868 // a[k,k] = beta
5869 // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
5870 auto iota = builder->create<IotaOp>(
5871 loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
5872 builder->getI64IntegerAttr(0));
5873 Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
5874 loc, iota, j, GetI64ElementsAttr({}, builder),
5875 StringAttr::get(builder->getContext(), "LT"));
5876 predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
5877 a_type.getElementType());
5878 Value mask = builder->create<chlo::BroadcastCompareOp>(
5879 loc, iota, j, GetI64ElementsAttr({}, builder),
5880 StringAttr::get(builder->getContext(), "EQ"));
5881 mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
5882 llvm::SmallVector<int64_t, 4> broadcast_mask_shape(a_type.getRank(), 1);
5883 broadcast_mask_shape[a_type.getRank() - 2] = m;
5884 mask = builder->create<BroadcastOp>(
5885 loc,
5886 RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()),
5887 mask,
5888 GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
5889 builder));
5890 Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
5891 loc, x, predecessor_mask,
5892 GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
5893 Value masked_beta = StaticBinaryBroadcast<MulOp>(
5894 loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
5895 *builder);
5896 Value new_x =
5897 builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
5898 // Update a[:,j]
5899 llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
5900 std::iota(dim_ids.begin(), dim_ids.end(), 0);
5901 new_x = builder->create<BroadcastInDimOp>(
5902 loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder));
5903 const int64_t minor_dim = num_batch_dims;
5904 auto iota_mn = builder->create<IotaOp>(
5905 loc,
5906 RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
5907 builder->getI64IntegerAttr(minor_dim + 1));
5908 Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
5909 loc, iota_mn, j, GetI64ElementsAttr({}, builder),
5910 StringAttr::get(builder->getContext(), "EQ"));
5911 a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
5912
5913 // vs[:, j] = v
5914 llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
5915 std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
5916 Value vs_zeros =
5917 GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
5918 vs_zeros = builder->create<BroadcastOp>(
5919 loc, vs.getType(), vs_zeros,
5920 GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
5921 builder));
5922 auto vs_update = builder->create<SelectOp>(
5923 loc, vs.getType(), xa_mask,
5924 StaticBinaryBroadcast<AddOp>(
5925 loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
5926 *builder),
5927 vs_zeros);
5928 vs = builder->create<AddOp>(loc, vs, vs_update);
5929
5930 // taus[j] = tau
5931 llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
5932 std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
5933
5934 auto iota_shape = llvm::to_vector<4>(batch_dims);
5935 iota_shape.push_back(n);
5936 auto iota_n = builder->create<IotaOp>(
5937 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
5938 builder->getI64IntegerAttr(minor_dim));
5939 Value taus_zeros =
5940 GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
5941 taus_zeros = builder->create<BroadcastOp>(
5942 loc, taus.getType(), taus_zeros,
5943 GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
5944 builder));
5945 Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
5946 loc, iota_n, j, GetI64ElementsAttr({}, builder),
5947 StringAttr::get(builder->getContext(), "EQ"));
5948 auto taus_update = builder->create<SelectOp>(
5949 loc, taus.getType(), taus_mask,
5950 StaticBinaryBroadcast<AddOp>(
5951 loc, taus_zeros, tau,
5952 GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
5953 taus_zeros);
5954 taus = builder->create<AddOp>(loc, taus, taus_update);
5955 new_values->assign({a, vs, taus});
5956 };
5957
5958 Value zero =
5959 GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
5960 *vs = rewriter->create<BroadcastOp>(
5961 loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
5962 auto taus_shape = llvm::to_vector<4>(batch_dims);
5963 taus_shape.push_back(n);
5964 *taus = rewriter->create<BroadcastOp>(
5965 loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero,
5966 GetI64ElementsAttr(taus_shape, rewriter));
5967
5968 SmallVector<Value, 4> while_output;
5969 CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
5970 &while_output, rewriter);
5971 *r = while_output[0];
5972 *vs = while_output[1];
5973 *taus = while_output[2];
5974 }
5975
5976 // Computes W and Y such that I-WY is equivalent to the sequence of
5977   // Householder transformations given by vs and taus.
5979 // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
5980 // Y = np.zeros([m, n])
5981 // W = np.zeros([m, n])
5982 // Y[:, 0] = vs[:, 0]
5983 // W[:, 0] = -taus[0] * vs[:, 0]
5984 // for j in xrange(1, n):
5985 // v = vs[:, j]
5986 // z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
5987 // W[:, j] = z
5988 // Y[:, j] = v
5989 // return W
5990 // There is no need to return Y since at termination of the loop it is equal
5991 // to vs.
5992   Value ComputeWYRepresentation(Location loc, Type data_type,
5993 ArrayRef<int64_t> batch_dims, Value vs,
5994 Value taus, int64_t m, int64_t n,
5995 PatternRewriter *rewriter) const {
5996 int64_t n_index = batch_dims.size() + 1;
5997 llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
5998 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
5999
6000 auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6001 SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
6002 // w has shape [..., m, n]
6003 auto w = old_values[0];
6004 const auto vs = old_values[1];
6005 const auto taus = old_values[2];
6006
6007 // Want j values in range [1, ... n).
6008 j = builder->create<AddOp>(
6009 loc, j,
6010 GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
6011 builder));
6012 // vs has shape [..., m, 1]
6013 auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
6014 // beta has shape [..., 1]
6015 auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder);
6016
6017 auto iota_shape = llvm::to_vector<4>(batch_dims);
6018 iota_shape.append({m, n});
6019 auto iota_mn = builder->create<IotaOp>(
6020 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
6021 builder->getI64IntegerAttr(n_index));
6022
6023 // y has shape [..., m, n]
6024 Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
6025 0, builder);
6026 zero = builder->create<BroadcastOp>(
6027 loc, vs.getType(), zero,
6028 GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
6029 builder));
6030 auto compare = builder->create<chlo::BroadcastCompareOp>(
6031 loc, iota_mn, j, GetI64ElementsAttr({}, builder),
6032 StringAttr::get(builder->getContext(), "GE"));
6033 auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
6034
6035 // yv has shape [..., n, 1]
6036 auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
6037 auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision,
6038 builder);
6039 // wyv has shape [..., m, 1]
6040 auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(),
6041 precision, builder);
6042
6043 // z = -beta * (v + wyv)
6044 auto neg_beta = builder->create<NegOp>(loc, beta);
6045 auto v_wyv = builder->create<AddOp>(loc, v, wyv);
6046 auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6047 beta_broadcast_dims.push_back(n_index);
6048 auto z = StaticBinaryBroadcast<MulOp>(
6049 loc, neg_beta, v_wyv,
6050 GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
6051
6052 w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
6053 new_values->assign({w, vs, taus});
6054 };
6055
6056 Value w =
6057 GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter);
6058 auto w_shape = llvm::to_vector<4>(batch_dims);
6059 w_shape.append({m, n});
6060 w = rewriter->create<BroadcastOp>(loc,
6061 RankedTensorType::get(w_shape, data_type),
6062 w, GetI64ElementsAttr(w_shape, rewriter));
6063 auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
6064 auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
6065 auto neg_beta = rewriter->create<NegOp>(loc, beta);
6066 auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6067 beta_broadcast_dims.push_back(n_index);
6068 auto bv = StaticBinaryBroadcast<MulOp>(
6069 loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
6070 *rewriter);
6071 w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
6072
6073 SmallVector<Value, 4> while_output;
6074 CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter);
6075 return while_output[0];
6076 }
6077 };
6078
6079 // Emits debug information which includes the number of ops of each type which
6080 // failed to legalize.
6081 void EmitLegalizationErrors(Operation *op,
6082 const DenseSet<Operation *> &nonlegalized_ops) {
6083 // Track the legalization failures by mapping op name to information about
6084 // that failure: the number of unlegalized occurrences of the op, and one
6085 // example operation that failed.
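  // For example, failing to lower two tf.Foo ops and one tf.Bar op (hypothetical
  // op names) produces a message like "The following operations cannot be
  // legalized: tf.Bar (count: 1); tf.Foo (count: 2). ..."; the std::map keeps the
  // op names sorted.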
6086 std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
6087 DenseSet<Operation *> error_ops;
6088 for (Operation *nonlegalized_op : nonlegalized_ops) {
6089 // Increment count of this legalization failure.
6090 StringRef op_name = nonlegalized_op->getName().getStringRef();
6091 // If this emplace is successful, it's the first time we've encountered
6092 // this op type. Initialize count to 0 so that after increment, it is 1.
6093 auto insertion_result = op_name_to_error_info.emplace(
6094 op_name, std::make_pair(0, nonlegalized_op));
6095 ++insertion_result.first->second.first;
6096 }
6097 std::vector<std::string> error_messages;
6098 error_messages.reserve(op_name_to_error_info.size());
6099 for (const auto &op_info : op_name_to_error_info) {
6100 error_messages.push_back(
6101 llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
6102 }
6103 Location loc = op->getLoc();
6104 emitError(loc) << "The following operations cannot be legalized: "
6105 << llvm::join(error_messages, "; ")
6106 << ". These legalization failure(s) may be due to missing TF "
6107 "to HLO lowerings and/or unsupported attributes, etc.";
6108 // Emit more information about the missing ops. This error message
6109 // contains useful details beyond the op name (input and output shapes,
6110 // attributes, etc.).
6111 if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
6112 emitError(loc)
6113 << "Emitting more detail about one op that failed to legalize...";
6114 } else if (VLOG_IS_ON(1)) {
6115 emitError(loc) << "Emitting more detail about one of each type of op "
6116 "that failed to legalize...";
6117 }
6118 for (const auto &op_info : op_name_to_error_info) {
6119 op_info.second.second->emitOpError() << "is not legalizable";
6120 if (!VLOG_IS_ON(1)) break;
6121 }
6122 }
6123
6124 // Performs the lowering to XLA dialect.
6125 void LegalizeTF::runOnFunction() {
6126 llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
6127 if (use_tf2xla_fallback_) {
6128 tf2xla_fallback_device_type = device_type_;
6129 }
6130 if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
6131 legalize_chlo_, tf2xla_fallback_device_type))) {
6132 signalPassFailure();
6133 }
6134 }
6135
6136 static PassRegistration<LegalizeTF> pass(
6137 "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");
6138
6139 } // end namespace
6140
6141 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
6142
6143 LogicalResult legalizeTF(
6144 Operation *op, bool allow_partial_conversion, bool legalize_chlo,
6145 llvm::Optional<StringRef> tf2xla_fallback_device_type) {
6146 MLIRContext *context = op->getContext();
6147 OwningRewritePatternList patterns;
6148 // Note that the `OperationConverter` orders patterns lexicographically by:
6149 // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
6150 //    to arrive at the conversion target). This requires each pattern to
6151 //    specify the list of ops it generates, which most patterns implemented
6152 //    in C++ do not do, so this comparison does not work for those
6153 //    patterns.
6154 // 2) Descending pattern benefit.
6155 // 3) Op specific patterns over patterns with MatchAnyOpTypeTag.
6156 // 4) Order of patterns in `OwningRewritePatternList`.
6157
6158 // Add TF->HLO legalization patterns.
6159 PopulateLegalizeTfPatterns(context, &patterns);
6160
6161 // Add TF->TF lowering patterns.
6162 TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
6163
6164 // Add TF->HLO legalization patterns via TF2XLA fallback.
6165 if (tf2xla_fallback_device_type.hasValue()) {
6166 PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
6167 patterns);
6168 }
6169
6170 // Populate with CHLO->HLO lowerings to account for TF ops legalized to
6171 // CHLO first.
6172 if (legalize_chlo) {
6173 chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
6174 }
6175 // ConstantLike op is convenient to create splat constants, but is
6176 // canonicalized to plain HLO constant if statically shaped. Add the
6177 // canonicalization pattern to pattern list to enable multi-hop lowering.
6178 chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
6179
6180 ConversionTarget target(*context);
6181 if (legalize_chlo) {
6182 target.addIllegalDialect<chlo::HloClientDialect>();
6183 } else {
6184 target.addLegalDialect<chlo::HloClientDialect>();
6185 }
6186 target.addLegalDialect<MhloDialect>();
6187 target.addLegalDialect<StandardOpsDialect>();
6188 target.addLegalDialect<tensor::TensorDialect>();
6189 target.addLegalDialect<shape::ShapeDialect>();
6190 target.addLegalOp<CallOp>();
6191
6192 if (!allow_partial_conversion) {
6193 // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
6194 target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
6195 DenseSet<Operation *> nonlegalized_ops;
6196 LogicalResult result = applyPartialConversion(
6197 op, target, std::move(patterns), &nonlegalized_ops);
6198 // In order to enforce that the conversion result is fully converted,
6199 // fail if there are any nonlegalized ops in the set.
6200 if (failed(result) || !nonlegalized_ops.empty()) {
6201 EmitLegalizationErrors(op, nonlegalized_ops);
6202 return failure();
6203 }
6204 return result;
6205 }
6206
6207 return applyPartialConversion(op, target, std::move(patterns));
6208 }
6209
6210 void PopulateLegalizeTfPatterns(MLIRContext *context,
6211 OwningRewritePatternList *patterns) {
6212 populateWithGenerated(context, *patterns);
6213 patterns->insert<
6214 ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
6215 ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
6216 ConvertClipByValueOp, ConvertConv2DOp, ConvertConv3DOp,
6217 ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp,
6218 ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp,
6219 ConvertConv3DBackpropInputOp, ConvertCumprodOp, ConvertCumsumOp,
6220 ConvertDiagPartOp, ConvertDynamicExpandDimsOp, ConvertDynamicReshapeOp,
6221 ConvertEinsumOp, ConvertRFFTOp, ConvertIRFFTOp,
6222 ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
6223 ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
6224 ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
6225 ConvertIdentityNOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
6226 ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp,
6227 ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp,
6228 ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
6229 ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
6230 ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
6231 ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
6232 ConvertSigmoidOp, ConvertShapeOp,
6233 ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
6234 ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
6235 ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
6236 ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
6237 ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
6238 ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
6239 ConvertRandomShuffleOp, ConvertXlaShardingOp,
6240 ConvertXlaDynamicUpdateSliceOp, ConvertXlaAllReduceOp>(context);
6241 }
6242
6243 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
6244 bool allow_partial_conversion, bool legalize_chlo,
6245 llvm::Optional<StringRef> tf2xla_fallback_device_type) {
6246 return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
6247 tf2xla_fallback_device_type);
6248 }
6249
6250 } // end namespace mhlo
6251 } // end namespace mlir
6252