1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 
28 namespace mlir {
29 namespace mhlo {
30 
31 namespace {
32 
33 // Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
34 // 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
35 // a static broadcast.
BroadcastToFeatureDim(Location loc,RankedTensorType result_type,Value value_1d,Value shape_value,int64_t feature_dim,PatternRewriter & rewriter)36 Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
37                             Value value_1d, Value shape_value,
38                             int64_t feature_dim,
39                             PatternRewriter& rewriter) {  // NOLINT
40   Builder b(rewriter.getContext());
41   auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
42   auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
43   if (shape_value) {
44     return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
45         loc, result_type, value_1d, shape_value, dims);
46   }
47   assert(result_type.hasStaticShape());
48   return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
49                                                  dims);
50 }
51 
52 // Calculate the shape value of operand, assuming it is a dynamic shape with
53 // static rank.
CalculateShapeValue(Location loc,Value operand,PatternRewriter & rewriter)54 Value CalculateShapeValue(Location loc, Value operand,
55                           PatternRewriter& rewriter) {  // NOLINT
56   RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
57   llvm::SmallVector<Value, 4> shape_values;
58   int64_t rank = result_type.getRank();
59   shape_values.reserve(rank);
60   for (int64_t i = 0; i < rank; ++i) {
61     shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
62   }
63   return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
64 }
65 
MaterializeEpsilon(Operation * op,FloatAttr epsilon_attr,FloatType fp_type,Value variance,RankedTensorType broadcast_to_type,PatternRewriter & rewriter)66 Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
67                          FloatType fp_type, Value variance,
68                          RankedTensorType broadcast_to_type,
69                          PatternRewriter& rewriter) {  // NOLINT
70   Builder b(rewriter.getContext());
71   if (epsilon_attr.getType() != fp_type) {
72     // Need to convert.
73     bool loses_info;
74     APFloat epsilon_float = epsilon_attr.getValue();
75     auto status = epsilon_float.convert(
76         fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
77     if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
78       op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
79                            "type: opStatus = "
80                         << static_cast<int>(status);
81       return nullptr;
82     }
83     if (loses_info) {
84       op->emitWarning("Conversion of epsilon loses precision");
85     }
86     epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
87   }
88 
89   auto scalar_type = RankedTensorType::get({}, fp_type);
90   auto epsilon_tensor_attr =
91       DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
92   Value epsilon =
93       rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
94   auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
95   auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
96   if (broadcast_to_type.hasStaticShape()) {
97     return rewriter.create<mhlo::BroadcastInDimOp>(
98         op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
99   }
100   Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
101   return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
102       op->getLoc(), broadcast_to_type, epsilon, shape_value,
103       /*broadcast_dims=*/dims);
104 }
105 
106 class UnfuseBatchNormInferencePattern
107     : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
108  public:
109   using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
110 
matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,PatternRewriter & rewriter) const111   LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
112                                 PatternRewriter& rewriter) const override {
113     // Enforce type invariants.
114     // Note that we deduce the actual element type from the variance,
115     // which should not be subject to quantization at a higher level.
116     auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
117     auto variance_type =
118         bn_op.variance().getType().dyn_cast<RankedTensorType>();
119     if (!input_type || !variance_type) {
120       return failure();
121     }
122     auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
123     if (!fp_type) {
124       return failure();
125     }
126     int64_t feature_dim = bn_op.feature_index();
127 
128     // Add epsilon to the variance and sqrt to get stddev:
129     // stddev = sqrt(variance + epsilon)
130     auto epsilon =
131         MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
132                            bn_op.variance(), variance_type, rewriter);
133     if (!epsilon) {
134       return failure();
135     }
136     Value stddev =
137         rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
138     stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
139 
140     // Broadcast all terms.
141     Value shape_value;
142     if (!input_type.hasStaticShape()) {
143       shape_value =
144           CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
145     }
146     auto broadcast_scale =
147         BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
148                               shape_value, feature_dim, rewriter);
149     auto broadcast_offset =
150         BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
151                               shape_value, feature_dim, rewriter);
152     auto broadcast_mean =
153         BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
154                               shape_value, feature_dim, rewriter);
155     auto broadcast_stddev = BroadcastToFeatureDim(
156         bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);
157 
158     // Compute:
159     // scale * (input - mean) / stddev + offset
160     Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
161                                                 broadcast_mean);
162     result =
163         rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
164     result =
165         rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
166     rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
167 
168     return success();
169   }
170 };
171 
172 }  // namespace
173 
174 // Populates conversion patterns to unfuse batch normalization operations.
175 // In combination with marking such ops as illegal, this allows backends that
176 // do not have special support for fused batchnorm to use simpler arithmetic
177 // primitives.
PopulateUnfuseBatchNormPatterns(MLIRContext * context,OwningRewritePatternList * patterns)178 void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
179                                      OwningRewritePatternList* patterns) {
180   patterns->insert<UnfuseBatchNormInferencePattern>(context);
181 }
182 
183 }  // namespace mhlo
184 }  // namespace mlir
185