Home
last modified time | relevance | path

Searched refs:bn_op (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
Dunfuse_batch_norm.cc111 LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op, in matchAndRewrite() argument
116 auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>(); in matchAndRewrite()
118 bn_op.variance().getType().dyn_cast<RankedTensorType>(); in matchAndRewrite()
126 int64_t feature_dim = bn_op.feature_index(); in matchAndRewrite()
131 MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type, in matchAndRewrite()
132 bn_op.variance(), variance_type, rewriter); in matchAndRewrite()
137 rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon); in matchAndRewrite()
138 stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev); in matchAndRewrite()
144 CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter); in matchAndRewrite()
147 BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(), in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/lite/toco/graph_transformations/
Dresolve_batch_normalization.cc36 const auto* bn_op = in Run() local
39 auto& mean_array = model->GetArray(bn_op->inputs[1]); in Run()
40 const auto& multiplier_array = model->GetArray(bn_op->inputs[2]); in Run()
41 const auto& offset_array = model->GetArray(bn_op->inputs[3]); in Run()
50 CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && in Run()
51 IsConstantParameterArray(*model, bn_op->inputs[2]) && in Run()
52 IsConstantParameterArray(*model, bn_op->inputs[3])) in Run()
66 AvailableArrayName(*model, bn_op->outputs[0] + "_mul"); in Run()
68 AvailableArrayName(*model, bn_op->outputs[0] + "_add"); in Run()
73 mul_op->inputs = {bn_op->inputs[0], mul_param_name}; in Run()
[all …]