1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "CpuOperationUtils.h"
20 #include "HalInterfaces.h"
21 #include "OperationResolver.h"
22 #include "Tracing.h"
23
24 #include <cmath>
25 #include <vector>
26
27 namespace android {
28 namespace nn {
29 namespace instance_normalization {
30
// Name used when reporting errors for, and when registering, this operation.
constexpr char kOperationName[] = "INSTANCE_NORMALIZATION";

// Input operands: the total count, followed by the index of each operand.
constexpr uint32_t kNumInputs = 5;
constexpr uint32_t kInputTensor = 0;    // tensor to normalize
constexpr uint32_t kGammaScalar = 1;    // scale factor
constexpr uint32_t kBetaScalar = 2;     // offset
constexpr uint32_t kEpsilonScalar = 3;  // term added under the square root
constexpr uint32_t kLayoutScalar = 4;   // bool passed on as `useNchw`

// Output operands: the total count, followed by the index of the only output.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
42
43 namespace {
44
45 template <typename T>
instanceNormNhwc(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,T * outputData,const Shape & outputShape)46 inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta,
47 T epsilon, T* outputData, const Shape& outputShape) {
48 NNTRACE_TRANS("InstanceNormalizationNhwc");
49 uint32_t numBatches = getSizeOfDimension(inputShape, 0);
50 uint32_t height = getSizeOfDimension(inputShape, 1);
51 uint32_t width = getSizeOfDimension(inputShape, 2);
52 uint32_t depth = getSizeOfDimension(inputShape, 3);
53 for (uint32_t b = 0; b < numBatches; b++) {
54 for (uint32_t d = 0; d < depth; d++) {
55 uint32_t indexBase = b * height * width * depth + d;
56 T mean = 0, var = 0;
57 for (uint32_t h = 0; h < height; h++) {
58 for (uint32_t w = 0; w < width; w++) {
59 T val = inputData[indexBase + (h * width + w) * depth];
60 mean += val;
61 var += val * val;
62 }
63 }
64 mean /= static_cast<T>(height * width);
65 var = std::sqrt(static_cast<float>(var / static_cast<T>(height * width)) + epsilon);
66 for (uint32_t h = 0; h < height; h++) {
67 for (uint32_t w = 0; w < width; w++) {
68 uint32_t ind = indexBase + (h * width + w) * depth;
69 outputData[ind] = (inputData[ind] - mean) * gamma / var + beta;
70 }
71 }
72 }
73 }
74 return true;
75 }
76
77 template <typename T>
instanceNorm(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,bool useNchw,T * outputData,const Shape & outputShape)78 inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon,
79 bool useNchw, T* outputData, const Shape& outputShape) {
80 InputWithLayout<T> input(useNchw);
81 OutputWithLayout<T> output(useNchw);
82 NN_RET_CHECK(input.initialize(inputData, inputShape));
83 NN_RET_CHECK(output.initialize(outputData, outputShape));
84 NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon,
85 output.getNhwcBuffer(), output.getNhwcShape()));
86 NN_RET_CHECK(output.commit());
87 return true;
88 }
89
90 } // namespace
91
validate(const IOperationValidationContext * context)92 bool validate(const IOperationValidationContext* context) {
93 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
94 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
95 std::vector<OperandType> inExpectedTypes;
96 auto inputType = context->getInputType(kInputTensor);
97 if (inputType == OperandType::TENSOR_FLOAT32) {
98 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32,
99 OperandType::FLOAT32, OperandType::BOOL};
100 } else if (inputType == OperandType::TENSOR_FLOAT16) {
101 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16,
102 OperandType::FLOAT16, OperandType::BOOL};
103 } else {
104 LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
105 return false;
106 }
107 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
108 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
109 return validateHalVersion(context, HalVersion::V1_2);
110 }
111
prepare(IOperationExecutionContext * context)112 bool prepare(IOperationExecutionContext* context) {
113 Shape input = context->getInputShape(kInputTensor);
114 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
115 return context->setOutputShape(kOutputTensor, input);
116 }
117
execute(IOperationExecutionContext * context)118 bool execute(IOperationExecutionContext* context) {
119 switch (context->getInputType(kInputTensor)) {
120 case OperandType::TENSOR_FLOAT16:
121 return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor),
122 context->getInputShape(kInputTensor),
123 context->getInputValue<_Float16>(kGammaScalar),
124 context->getInputValue<_Float16>(kBetaScalar),
125 context->getInputValue<_Float16>(kEpsilonScalar),
126 context->getInputValue<bool>(kLayoutScalar),
127 context->getOutputBuffer<_Float16>(kOutputTensor),
128 context->getOutputShape(kOutputTensor));
129 case OperandType::TENSOR_FLOAT32:
130 return instanceNorm(context->getInputBuffer<float>(kInputTensor),
131 context->getInputShape(kInputTensor),
132 context->getInputValue<float>(kGammaScalar),
133 context->getInputValue<float>(kBetaScalar),
134 context->getInputValue<float>(kEpsilonScalar),
135 context->getInputValue<bool>(kLayoutScalar),
136 context->getOutputBuffer<float>(kOutputTensor),
137 context->getOutputShape(kOutputTensor));
138 default:
139 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
140 }
141 }
142
143 } // namespace instance_normalization
144
145 NN_REGISTER_OPERATION(INSTANCE_NORMALIZATION, instance_normalization::kOperationName,
146 instance_normalization::validate, instance_normalization::prepare,
147 instance_normalization::execute);
148
149 } // namespace nn
150 } // namespace android
151