1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "OperationResolver.h"
20 #include "RNN.h"
21
22 namespace android {
23 namespace nn {
24 namespace unidirectional_sequence_rnn {
25
// Operand layout of the operation: the number of inputs followed by the index
// of each input operand as consumed by validate(), prepare() and execute().
constexpr uint32_t kNumInputs{7};
constexpr uint32_t kInputTensor{0};
constexpr uint32_t kWeightsTensor{1};
constexpr uint32_t kRecurrentWeightsTensor{2};
constexpr uint32_t kBiasTensor{3};
constexpr uint32_t kHiddenStateTensor{4};
// Scalar parameters; both are read as int32_t values.
constexpr uint32_t kActivationParam{5};
constexpr uint32_t kTimeMajorParam{6};

// The operation produces a single output tensor.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
37
38 namespace {
39
40 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)41 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
42 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
43 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
44 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
45 for (int f = 0; f < firstDimSize; ++f) {
46 for (int s = 0; s < secondDimSize; ++s) {
47 for (int i = 0; i < inputSize; ++i) {
48 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
49 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
50 output[outputIndex] = input[inputIndex];
51 }
52 }
53 }
54 }
55
56 template <typename T>
executeTyped(IOperationExecutionContext * context)57 bool executeTyped(IOperationExecutionContext* context) {
58 const T* input = context->getInputBuffer<T>(kInputTensor);
59 Shape inputShape = context->getInputShape(kInputTensor);
60 const T* weights = context->getInputBuffer<T>(kWeightsTensor);
61 Shape weightsShape = context->getInputShape(kWeightsTensor);
62 const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
63 Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
64 const T* bias = context->getInputBuffer<T>(kBiasTensor);
65 const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
66 int32_t activation = context->getInputValue<int32_t>(kActivationParam);
67
68 T* output = context->getOutputBuffer<T>(kOutputTensor);
69 Shape outputShape = context->getOutputShape(kOutputTensor);
70
71 int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
72 // If the input tensors are not in time major format, we transpose the first
73 // two dimensions, and set input and output pointers to temporary vectors
74 // which are transposed back after the RNN is applied.
75 std::vector<T> inputTransposed;
76 std::vector<T> outputTransposed;
77 if (!timeMajor) {
78 // Convert input and output to time major format.
79 inputTransposed.resize(getNumberOfElements(inputShape));
80 outputTransposed.resize(getNumberOfElements(outputShape));
81 transposeFirstTwoDims(input, inputShape, inputTransposed.data());
82 input = inputTransposed.data();
83 output = outputTransposed.data();
84 std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
85 std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
86 }
87
88 const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
89 const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
90 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
91 const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);
92
93 // A shape at a fixed step (removed time dimension).
94 Shape fixedTimeInputShape = inputShape;
95 fixedTimeInputShape.dimensions.resize(2);
96 fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
97 fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];
98
99 for (int i = 0; i < maxTime; ++i) {
100 RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
101 recurrentWeights, recurrentWeightsShape, activation, output);
102 input += batchSize * inputSize;
103 hiddenState = output;
104 output += batchSize * numUnits;
105 }
106
107 if (!timeMajor) {
108 transposeFirstTwoDims(outputTransposed.data(), outputShape,
109 context->getOutputBuffer<T>(kOutputTensor));
110 }
111 return true;
112 }
113
114 } // namespace
115
validate(const IOperationValidationContext * context)116 bool validate(const IOperationValidationContext* context) {
117 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
118 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
119 OperandType inputType = context->getInputType(kInputTensor);
120 if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
121 LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
122 << toString(inputType);
123 return false;
124 }
125 NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
126 OperandType::INT32, OperandType::INT32}));
127 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
128 return validateHalVersion(context, HalVersion::V1_2);
129 }
130
prepare(IOperationExecutionContext * context)131 bool prepare(IOperationExecutionContext* context) {
132 Shape input = context->getInputShape(kInputTensor);
133 Shape weights = context->getInputShape(kWeightsTensor);
134 Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
135 Shape bias = context->getInputShape(kBiasTensor);
136 Shape hiddenState = context->getInputShape(kHiddenStateTensor);
137
138 int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
139 NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
140 const uint32_t batchSize =
141 timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
142 const uint32_t maxTime =
143 timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
144 const uint32_t numUnits = getSizeOfDimension(weights, 0);
145 const uint32_t inputSize = getSizeOfDimension(input, 2);
146
147 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
148 NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
149 NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
150 NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
151 NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);
152
153 NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
154 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
155 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
156 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
157 NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
158 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));
159
160 Shape output = context->getOutputShape(kOutputTensor);
161 output.dimensions[0] = timeMajor ? maxTime : batchSize;
162 output.dimensions[1] = timeMajor ? batchSize : maxTime;
163 output.dimensions[2] = numUnits;
164
165 return context->setOutputShape(kOutputTensor, output);
166 }
167
execute(IOperationExecutionContext * context)168 bool execute(IOperationExecutionContext* context) {
169 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
170 executeTyped<_Float16>(context);
171 } else {
172 executeTyped<float>(context);
173 }
174 return true;
175 }
176
177 } // namespace unidirectional_sequence_rnn
178
179 NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
180 unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
181 unidirectional_sequence_rnn::execute);
182
183 } // namespace nn
184 } // namespace android
185