/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "DepthwiseConv2DOperationConverter.h"

#include <vector>

#include "OperationConverterResolver.h"
#include "SubGraphContext.h"

namespace android {
namespace nn {
Result<void> DepthwiseConv2DOperationConverter::convert(const Operation& operation,
                                                        SubGraphContext* context) const {
    const Model::Subgraph* subgraph = context->getSubgraph();

    // add opcode for DEPTHWISE_CONV_2D if not added yet
    uint32_t opCodeIdx = context->addOpCode(OperationType::DEPTHWISE_CONV_2D);

    // if there are fewer than 9 inputs or the input at index 8 is a BOOL, there is implicit
    // padding: with explicit padding, input 8 holds the INT32 stride_h scalar, while with
    // implicit padding input 8 (if present at all) is the optional BOOL data-layout flag
    const bool isImplicitPadding =
            (operation.inputs.size() < 9 ||
             subgraph->operands[operation.inputs[8]].type == OperandType::BOOL);

    std::vector<int32_t> inputs = NN_TRY(getConv2DInputs(operation, context));
    std::vector<int32_t> outputs = NN_TRY(getConv2DOutputs(operation, context));

    // if explicit padding, we need to decompose the operation into a separate TFLite Pad op
    // followed by the depthwise conv2d op, since TFLite has no explicit-padding option
    if (!isImplicitPadding) {
        auto padOpIdx = NN_TRY(decomposeExplicitPadding(operation, context));
        inputs[0] = padOpIdx;
    }

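    // The remaining scalar inputs (strides, depthwise multiplier, activation) start at input 4
    // for the implicit-padding signature (after the padding scheme at input 3) and at input 7
    // for the explicit-padding signature (after the four padding scalars at inputs 3-6).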
    int baseOptionsIdx = 4;
    tflite::Padding padding;
    if (isImplicitPadding) {
        const Operand& paddingTypeOperand = subgraph->operands[operation.inputs[3]];
        NN_RET_CHECK(isOperandConstant(paddingTypeOperand));

        int32_t paddingType = context->getConstantScalar<int32_t>(paddingTypeOperand);
        padding = getTFLitePadding(paddingType);
    } else {
        padding = tflite::Padding::Padding_VALID;
        baseOptionsIdx = 7;
    }

    // check if stride, depthwise multiplier, and activation Operands are constant
    const Operand& strideWOperand =
            subgraph->operands[operation.inputs[baseOptionsIdx + kStrideWOffset]];
    const Operand& strideHOperand =
            subgraph->operands[operation.inputs[baseOptionsIdx + kStrideHOffset]];
    const Operand& activationOperand =
            subgraph->operands[operation.inputs[baseOptionsIdx + kActivationOffset]];
    const Operand& depthwiseMultiplierOperand =
            subgraph->operands[operation.inputs[baseOptionsIdx + kDepthwiseMultiplier]];
    NN_RET_CHECK(isOperandConstant(strideWOperand));
    NN_RET_CHECK(isOperandConstant(strideHOperand));
    NN_RET_CHECK(isOperandConstant(activationOperand));
    NN_RET_CHECK(isOperandConstant(depthwiseMultiplierOperand));

    // get strides, depthwise multiplier, and activation
    int32_t strideW = context->getConstantScalar<int32_t>(strideWOperand);
    int32_t strideH = context->getConstantScalar<int32_t>(strideHOperand);
    int32_t depthwiseMultiplier = context->getConstantScalar<int32_t>(depthwiseMultiplierOperand);
    FusedActivationFunc activation = static_cast<FusedActivationFunc>(
            context->getConstantScalar<int32_t>(activationOperand));

    // check for the optional NCHW data-layout flag; TFLite only supports NHWC
    int isNchwIdx = baseOptionsIdx + kIsNchwOffset;
    if (operation.inputs.size() > static_cast<uint32_t>(isNchwIdx)) {
        const Operand& isNchwOperand = subgraph->operands[operation.inputs[isNchwIdx]];
        NN_RET_CHECK(isOperandConstant(isNchwOperand));

        bool isNchw = context->getConstantScalar<bool>(isNchwOperand);
        NN_RET_CHECK(!isNchw) << "TFLite does not support NCHW formatted input tensors";
    }

    // dilations
    int dilationWIdx = baseOptionsIdx + kDilationWOffset;
    int dilationHIdx = baseOptionsIdx + kDilationHOffset;
    // default dilation factors are 1
    int32_t dilationW = 1;
    int32_t dilationH = 1;
    if (operation.inputs.size() > static_cast<uint32_t>(dilationWIdx)) {
        const Operand& dilationWOperand = subgraph->operands[operation.inputs[dilationWIdx]];
        NN_RET_CHECK(isOperandConstant(dilationWOperand));

        dilationW = context->getConstantScalar<int32_t>(dilationWOperand);
    }
    if (operation.inputs.size() > static_cast<uint32_t>(dilationHIdx)) {
        const Operand& dilationHOperand = subgraph->operands[operation.inputs[dilationHIdx]];
        NN_RET_CHECK(isOperandConstant(dilationHOperand));

        dilationH = context->getConstantScalar<int32_t>(dilationHOperand);
    }

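    // serialize the DepthwiseConv2DOptions table, then the operator that references it, into
    // the flatbuffer under construction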
    flatbuffers::Offset<tflite::DepthwiseConv2DOptions> optionsFlatbuffer =
            tflite::CreateDepthwiseConv2DOptions(
                    context->getBuilder(), padding, strideW, strideH, depthwiseMultiplier,
                    NN_TRY(getTfliteActivation(activation)) /* fused_activation_function */,
                    dilationW, dilationH);
    auto operatorFlatbuffer = tflite::CreateOperatorDirect(
            context->getBuilder() /* builder */, opCodeIdx /* opcode_index */, &inputs /* inputs */,
            &outputs /* outputs */,
            tflite::BuiltinOptions::
                    BuiltinOptions_DepthwiseConv2DOptions /* builtin_options_type */,
            optionsFlatbuffer.Union() /* builtin_options */);
    context->addOperatorFlatbuffer(operatorFlatbuffer);

    return {};
}

NN_REGISTER_OPERATION_CONVERTER(DEPTHWISE_CONV_2D, DepthwiseConv2DOperationConverter);

}  // namespace nn
}  // namespace android