/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "Tracing.h"

#include <cmath>
#include <vector>

namespace android {
namespace nn {
namespace instance_normalization {

constexpr char kOperationName[] = "INSTANCE_NORMALIZATION";

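// Operand indices: the 4-D input tensor, the gamma and beta scaling scalars,
// the epsilon added to the variance, and the NCHW/NHWC layout flag.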
constexpr uint32_t kNumInputs = 5;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kGammaScalar = 1;
constexpr uint32_t kBetaScalar = 2;
constexpr uint32_t kEpsilonScalar = 3;
constexpr uint32_t kLayoutScalar = 4;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

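// Normalizes each (batch, channel) plane of an NHWC tensor independently:
// output = (input - mean) * gamma / sqrt(variance + epsilon) + beta, where mean
// and variance are computed over the height and width of that plane.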
template <typename T>
inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta,
                             T epsilon, T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("InstanceNormalizationNhwc");
    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t height = getSizeOfDimension(inputShape, 1);
    uint32_t width = getSizeOfDimension(inputShape, 2);
    uint32_t depth = getSizeOfDimension(inputShape, 3);
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t d = 0; d < depth; d++) {
            uint32_t indexBase = b * height * width * depth + d;
            // Accumulate the sum and the sum of squares over the HxW plane.
            T mean = 0, var = 0;
            for (uint32_t h = 0; h < height; h++) {
                for (uint32_t w = 0; w < width; w++) {
                    T val = inputData[indexBase + (h * width + w) * depth];
                    mean += val;
                    var += val * val;
                }
            }
            mean /= static_cast<T>(height * width);
            // variance = E[x^2] - mean^2; the normalizer is sqrt(variance + epsilon).
            var = var / static_cast<T>(height * width) - mean * mean;
            var = std::sqrt(static_cast<float>(var) + epsilon);
            for (uint32_t h = 0; h < height; h++) {
                for (uint32_t w = 0; w < width; w++) {
                    uint32_t ind = indexBase + (h * width + w) * depth;
                    outputData[ind] = (inputData[ind] - mean) * gamma / var + beta;
                }
            }
        }
    }
    return true;
}

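// Layout-aware wrapper: converts an NCHW input to NHWC if requested, runs the
// NHWC kernel above, and commits the result back in the caller's layout.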
template <typename T>
inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon,
                         bool useNchw, T* outputData, const Shape& outputShape) {
    InputWithLayout<T> input(useNchw);
    OutputWithLayout<T> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon,
                                  output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

}  // namespace

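// Checks operand counts and types (FP16 or FP32 tensors with matching scalar
// types) and requires HAL version 1.2.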
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    std::vector<OperandType> inExpectedTypes;
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32,
                           OperandType::FLOAT32, OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16,
                           OperandType::FLOAT16, OperandType::BOOL};
    } else {
        LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

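// Requires a 4-D input; the output shape matches the input shape exactly.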
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    return context->setOutputShape(kOutputTensor, input);
}

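// Dispatches on the input tensor's element type (FP16 or FP32).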
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor),
                                context->getInputShape(kInputTensor),
                                context->getInputValue<_Float16>(kGammaScalar),
                                context->getInputValue<_Float16>(kBetaScalar),
                                context->getInputValue<_Float16>(kEpsilonScalar),
                                context->getInputValue<bool>(kLayoutScalar),
                                context->getOutputBuffer<_Float16>(kOutputTensor),
                                context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return instanceNorm(context->getInputBuffer<float>(kInputTensor),
                                context->getInputShape(kInputTensor),
                                context->getInputValue<float>(kGammaScalar),
                                context->getInputValue<float>(kBetaScalar),
                                context->getInputValue<float>(kEpsilonScalar),
                                context->getInputValue<bool>(kLayoutScalar),
                                context->getOutputBuffer<float>(kOutputTensor),
                                context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace instance_normalization

NN_REGISTER_OPERATION(INSTANCE_NORMALIZATION, instance_normalization::kOperationName,
                      instance_normalization::validate, instance_normalization::prepare,
                      instance_normalization::execute);

}  // namespace nn
}  // namespace android