Searched refs:bn_op (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ |
D | unfuse_batch_norm.cc | 111 LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op, in matchAndRewrite() argument 116 auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>(); in matchAndRewrite() 118 bn_op.variance().getType().dyn_cast<RankedTensorType>(); in matchAndRewrite() 126 int64_t feature_dim = bn_op.feature_index(); in matchAndRewrite() 131 MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type, in matchAndRewrite() 132 bn_op.variance(), variance_type, rewriter); in matchAndRewrite() 137 rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon); in matchAndRewrite() 138 stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev); in matchAndRewrite() 144 CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter); in matchAndRewrite() 147 BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(), in matchAndRewrite() [all …]
|
/external/tensorflow/tensorflow/lite/toco/graph_transformations/ |
D | resolve_batch_normalization.cc | 36 const auto* bn_op = in Run() local 39 auto& mean_array = model->GetArray(bn_op->inputs[1]); in Run() 40 const auto& multiplier_array = model->GetArray(bn_op->inputs[2]); in Run() 41 const auto& offset_array = model->GetArray(bn_op->inputs[3]); in Run() 50 CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && in Run() 51 IsConstantParameterArray(*model, bn_op->inputs[2]) && in Run() 52 IsConstantParameterArray(*model, bn_op->inputs[3])) in Run() 66 AvailableArrayName(*model, bn_op->outputs[0] + "_mul"); in Run() 68 AvailableArrayName(*model, bn_op->outputs[0] + "_add"); in Run() 73 mul_op->inputs = {bn_op->inputs[0], mul_param_name}; in Run() [all …]
|