1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "DepthwiseConv2DOperationConverter.h"
18 
19 #include <vector>
20 
21 #include "OperationConverterResolver.h"
22 #include "SubGraphContext.h"
23 
24 namespace android {
25 namespace nn {
26 
convert(const Operation & operation,SubGraphContext * context) const27 Result<void> DepthwiseConv2DOperationConverter::convert(const Operation& operation,
28                                                         SubGraphContext* context) const {
29     const Model::Subgraph* subgraph = context->getSubgraph();
30 
31     // add opcode for DEPTHWISE_CONV_2D if not added yet
32     uint32_t opCodeIdx = context->addOpCode(OperationType::DEPTHWISE_CONV_2D);
33 
34     // if there are less than 9 inputs or the input at the 8th index is a BOOL, there is implicit
35     // padding
36     const bool isImplicitPadding =
37             (operation.inputs.size() < 9 ||
38              subgraph->operands[operation.inputs[8]].type == OperandType::BOOL);
39 
40     std::vector<int32_t> inputs = NN_TRY(getConv2DInputs(operation, context));
41     std::vector<int32_t> outputs = NN_TRY(getConv2DOutputs(operation, context));
42 
43     // if explicit padding, we need to decompose the operation to a separate padding op and a conv2d
44     // op
45     if (!isImplicitPadding) {
46         auto padOpIdx = NN_TRY(decomposeExplicitPadding(operation, context));
47         inputs[0] = padOpIdx;
48     }
49 
50     int baseOptionsIdx = 4;
51     tflite::Padding padding;
52     if (isImplicitPadding) {
53         const Operand& paddingTypeOperand = subgraph->operands[operation.inputs[3]];
54         NN_RET_CHECK(isOperandConstant(paddingTypeOperand));
55 
56         int32_t paddingType = context->getConstantScalar<int32_t>(paddingTypeOperand);
57         padding = getTFLitePadding(paddingType);
58     } else {
59         padding = tflite::Padding::Padding_VALID;
60         baseOptionsIdx = 7;
61     }
62 
63     // check if stride, depthwise multiplier, and activation Operands are constant
64     const Operand& strideWOperand =
65             subgraph->operands[operation.inputs[baseOptionsIdx + kStrideWOffset]];
66     const Operand& strideHOperand =
67             subgraph->operands[operation.inputs[baseOptionsIdx + kStrideHOffset]];
68     const Operand& activationOperand =
69             subgraph->operands[operation.inputs[baseOptionsIdx + kActivationOffset]];
70     const Operand& depthwiseMultiplierOperand =
71             subgraph->operands[operation.inputs[baseOptionsIdx + kDepthwiseMultiplier]];
72     NN_RET_CHECK(isOperandConstant(strideWOperand));
73     NN_RET_CHECK(isOperandConstant(strideHOperand));
74     NN_RET_CHECK(isOperandConstant(activationOperand));
75     NN_RET_CHECK(isOperandConstant(depthwiseMultiplierOperand));
76 
77     // get strides and activation
78     int32_t strideW = context->getConstantScalar<int32_t>(strideWOperand);
79     int32_t strideH = context->getConstantScalar<int32_t>(strideHOperand);
80     int32_t depthwiseMultiplier = context->getConstantScalar<int32_t>(depthwiseMultiplierOperand);
81     FusedActivationFunc activation = static_cast<FusedActivationFunc>(
82             context->getConstantScalar<int32_t>(activationOperand));
83 
84     // check for nchw
85     int isNchwIdx = baseOptionsIdx + kIsNchwOffset;
86     if (operation.inputs.size() > static_cast<uint32_t>(isNchwIdx)) {
87         const Operand& isNchwOperand = subgraph->operands[operation.inputs[isNchwIdx]];
88         NN_RET_CHECK(isOperandConstant(isNchwOperand));
89 
90         bool isNchw = context->getConstantScalar<bool>(isNchwOperand);
91         NN_RET_CHECK(!isNchw) << "TFLite does not support NCHW formatted input tensors";
92     }
93 
94     // dilations
95     int dilationWIdx = baseOptionsIdx + kDilationWOffset;
96     int dilationHIdx = baseOptionsIdx + kDilationHOffset;
97     // default dilation factors are 1
98     int32_t dilationW = 1;
99     int32_t dilationH = 1;
100     if (operation.inputs.size() > static_cast<uint32_t>(dilationWIdx)) {
101         const Operand& dilationWOperand = subgraph->operands[operation.inputs[dilationWIdx]];
102         NN_RET_CHECK(isOperandConstant(dilationWOperand));
103 
104         dilationW = context->getConstantScalar<int32_t>(dilationWOperand);
105     }
106     if (operation.inputs.size() > static_cast<uint32_t>(dilationHIdx)) {
107         const Operand& dilationHOperand = subgraph->operands[operation.inputs[dilationHIdx]];
108         NN_RET_CHECK(isOperandConstant(dilationHOperand));
109 
110         dilationH = context->getConstantScalar<int32_t>(dilationHOperand);
111     }
112 
113     flatbuffers::Offset<tflite::DepthwiseConv2DOptions> optionsFlatbuffer =
114             tflite::CreateDepthwiseConv2DOptions(
115                     context->getBuilder(), padding, strideW, strideH, depthwiseMultiplier,
116                     NN_TRY(getTfliteActivation(activation)) /* fused_activation_function */,
117                     dilationW, dilationH);
118     auto operatorFlatbuffer = tflite::CreateOperatorDirect(
119             context->getBuilder() /* builder */, opCodeIdx /* opcode_index */, &inputs /* inputs */,
120             &outputs /* outputs */,
121             tflite::BuiltinOptions::
122                     BuiltinOptions_DepthwiseConv2DOptions /* builtin_options_type */,
123             optionsFlatbuffer.Union() /* builtin_options */);
124     context->addOperatorFlatbuffer(operatorFlatbuffer);
125 
126     return {};
127 }
128 
129 NN_REGISTER_OPERATION_CONVERTER(DEPTHWISE_CONV_2D, DepthwiseConv2DOperationConverter);
130 
131 }  // namespace nn
132 }  // namespace android