1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27
28 namespace mlir {
29 namespace mhlo {
30
31 namespace {
32
33 // Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
34 // 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
35 // a static broadcast.
BroadcastToFeatureDim(Location loc,RankedTensorType result_type,Value value_1d,Value shape_value,int64_t feature_dim,PatternRewriter & rewriter)36 Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
37 Value value_1d, Value shape_value,
38 int64_t feature_dim,
39 PatternRewriter& rewriter) { // NOLINT
40 Builder b(rewriter.getContext());
41 auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
42 auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
43 if (shape_value) {
44 return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
45 loc, result_type, value_1d, shape_value, dims);
46 }
47 assert(result_type.hasStaticShape());
48 return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
49 dims);
50 }
51
52 // Calculate the shape value of operand, assuming it is a dynamic shape with
53 // static rank.
CalculateShapeValue(Location loc,Value operand,PatternRewriter & rewriter)54 Value CalculateShapeValue(Location loc, Value operand,
55 PatternRewriter& rewriter) { // NOLINT
56 RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
57 llvm::SmallVector<Value, 4> shape_values;
58 int64_t rank = result_type.getRank();
59 shape_values.reserve(rank);
60 for (int64_t i = 0; i < rank; ++i) {
61 shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
62 }
63 return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
64 }
65
// Materializes the batch-norm epsilon as a constant broadcast to
// 'broadcast_to_type'. 'epsilon_attr' is first converted to 'fp_type' if the
// types differ (warning on precision loss). Returns nullptr if that
// conversion fails. For a dynamically shaped result, the broadcast shape is
// computed from 'variance' at runtime.
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
                         FloatType fp_type, Value variance,
                         RankedTensorType broadcast_to_type,
                         PatternRewriter& rewriter) {  // NOLINT
  Builder b(rewriter.getContext());
  if (epsilon_attr.getType() != fp_type) {
    // Need to convert.
    bool loses_info;
    APFloat epsilon_float = epsilon_attr.getValue();
    auto status = epsilon_float.convert(
        fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
    // An inexact result is tolerated (warned about via loses_info below);
    // any other non-OK status bit means the conversion failed outright.
    if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
      op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
                           "type: opStatus = "
                        << static_cast<int>(status);
      return nullptr;
    }
    if (loses_info) {
      op->emitWarning("Conversion of epsilon loses precision");
    }
    epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
  }

  // Materialize epsilon as a 0-D (scalar) constant tensor.
  auto scalar_type = RankedTensorType::get({}, fp_type);
  auto epsilon_tensor_attr =
      DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
  Value epsilon =
      rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
  // A scalar-to-tensor broadcast uses an empty broadcast_dimensions list.
  auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
  if (broadcast_to_type.hasStaticShape()) {
    return rewriter.create<mhlo::BroadcastInDimOp>(
        op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
  }
  // Dynamic shape: derive the target shape from 'variance' at runtime.
  Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
  return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
      op->getLoc(), broadcast_to_type, epsilon, shape_value,
      /*broadcast_dims=*/dims);
}
105
// Rewrites mhlo.batch_norm_inference into primitive arithmetic:
//   (operand - mean) * scale / sqrt(variance + epsilon) + offset
// with each 1-D per-feature operand broadcast along 'feature_index'.
// Bails out (failure) on unranked operands or non-float variance element
// types, leaving the fused op in place.
class UnfuseBatchNormInferencePattern
    : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
 public:
  using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
                                PatternRewriter& rewriter) const override {
    // Enforce type invariants.
    // Note that we deduce the actual element type from the variance,
    // which should not be subject to quantization at a higher level.
    auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
    auto variance_type =
        bn_op.variance().getType().dyn_cast<RankedTensorType>();
    if (!input_type || !variance_type) {
      // Unranked operands are not supported by this lowering.
      return failure();
    }
    auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
    if (!fp_type) {
      return failure();
    }
    int64_t feature_dim = bn_op.feature_index();

    // Add epsilon to the variance and sqrt to get stddev:
    // stddev = sqrt(variance + epsilon)
    auto epsilon =
        MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
                           bn_op.variance(), variance_type, rewriter);
    if (!epsilon) {
      // Epsilon could not be converted to the variance element type.
      return failure();
    }
    Value stddev =
        rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
    stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);

    // Broadcast all terms.
    Value shape_value;
    if (!input_type.hasStaticShape()) {
      // Dynamic input: materialize the runtime shape once and reuse it for
      // every broadcast below.
      shape_value =
          CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
    }
    auto broadcast_scale =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_offset =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_mean =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_stddev = BroadcastToFeatureDim(
        bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);

    // Compute:
    // scale * (input - mean) / stddev + offset
    Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
                                                broadcast_mean);
    result =
        rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
    result =
        rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
    // Replace the fused op; the final add produces the batch-norm result.
    rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);

    return success();
  }
};
171
172 } // namespace
173
// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  // Only the inference form is registered here; training-mode batch-norm ops
  // are not handled by this file.
  patterns->insert<UnfuseBatchNormInferencePattern>(context);
}
182
183 } // namespace mhlo
184 } // namespace mlir
185