1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "OperationResolver.h"
20 #include "RNN.h"
21 
22 namespace android {
23 namespace nn {
24 namespace unidirectional_sequence_rnn {
25 
// Operand indices for the UNIDIRECTIONAL_SEQUENCE_RNN operation.
constexpr uint32_t kNumInputs = 7;
constexpr uint32_t kInputTensor = 0;            // 3-D input sequence tensor
constexpr uint32_t kWeightsTensor = 1;          // input-to-hidden weights [numUnits, inputSize]
constexpr uint32_t kRecurrentWeightsTensor = 2; // hidden-to-hidden weights [numUnits, numUnits]
constexpr uint32_t kBiasTensor = 3;             // bias vector [numUnits]
constexpr uint32_t kHiddenStateTensor = 4;      // initial hidden state [batchSize, numUnits]
constexpr uint32_t kActivationParam = 5;        // scalar INT32 activation function selector
constexpr uint32_t kTimeMajorParam = 6;         // scalar INT32: nonzero = time-major layout

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;           // 3-D output sequence tensor
37 
38 namespace {
39 
40 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)41 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
42     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
43     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
44     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
45     for (int f = 0; f < firstDimSize; ++f) {
46         for (int s = 0; s < secondDimSize; ++s) {
47             for (int i = 0; i < inputSize; ++i) {
48                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
49                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
50                 output[outputIndex] = input[inputIndex];
51             }
52         }
53     }
54 }
55 
// Runs the whole unidirectional sequence RNN for element type T by applying
// RNN::RNNStep once per time step. Handles both time-major and batch-major
// inputs; batch-major data is transposed into temporaries first and the result
// transposed back at the end. Always returns true (per-step errors are not
// reported by RNN::RNNStep's signature here).
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);
    const T* weights = context->getInputBuffer<T>(kWeightsTensor);
    Shape weightsShape = context->getInputShape(kWeightsTensor);
    const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
    Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
    const T* bias = context->getInputBuffer<T>(kBiasTensor);
    const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
    int32_t activation = context->getInputValue<int32_t>(kActivationParam);

    T* output = context->getOutputBuffer<T>(kOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> outputTransposed;
    if (!timeMajor) {
        // Convert input and output to time major format.
        inputTransposed.resize(getNumberOfElements(inputShape));
        outputTransposed.resize(getNumberOfElements(outputShape));
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        // Redirect the working pointers at the temporaries; the real output
        // buffer is only written in the transpose-back step below.
        input = inputTransposed.data();
        output = outputTransposed.data();
        // Local Shape copies are updated so the dimension reads below see the
        // time-major layout.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
    }

    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);

    // A shape at a fixed step (removed time dimension).
    Shape fixedTimeInputShape = inputShape;
    fixedTimeInputShape.dimensions.resize(2);
    fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
    fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];

    for (int i = 0; i < maxTime; ++i) {
        RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
                        recurrentWeights, recurrentWeightsShape, activation, output);
        // Advance one time step: the slice just written becomes the hidden
        // state for the next step (the output buffer is deliberately reused as
        // hidden-state storage — do not reorder these three statements).
        input += batchSize * inputSize;
        hiddenState = output;
        output += batchSize * numUnits;
    }

    if (!timeMajor) {
        // Transpose the time-major temporary back into the caller's
        // batch-major output buffer.
        transposeFirstTwoDims(outputTransposed.data(), outputShape,
                              context->getOutputBuffer<T>(kOutputTensor));
    }
    return true;
}
113 
114 }  // namespace
115 
validate(const IOperationValidationContext * context)116 bool validate(const IOperationValidationContext* context) {
117     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
118     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
119     OperandType inputType = context->getInputType(kInputTensor);
120     if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
121         LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
122                    << toString(inputType);
123         return false;
124     }
125     NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
126                                               OperandType::INT32, OperandType::INT32}));
127     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
128     return validateHalVersion(context, HalVersion::V1_2);
129 }
130 
prepare(IOperationExecutionContext * context)131 bool prepare(IOperationExecutionContext* context) {
132     Shape input = context->getInputShape(kInputTensor);
133     Shape weights = context->getInputShape(kWeightsTensor);
134     Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
135     Shape bias = context->getInputShape(kBiasTensor);
136     Shape hiddenState = context->getInputShape(kHiddenStateTensor);
137 
138     int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
139     NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
140     const uint32_t batchSize =
141             timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
142     const uint32_t maxTime =
143             timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
144     const uint32_t numUnits = getSizeOfDimension(weights, 0);
145     const uint32_t inputSize = getSizeOfDimension(input, 2);
146 
147     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
148     NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
149     NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
150     NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
151     NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);
152 
153     NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
154     NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
155     NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
156     NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
157     NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
158     NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));
159 
160     Shape output = context->getOutputShape(kOutputTensor);
161     output.dimensions[0] = timeMajor ? maxTime : batchSize;
162     output.dimensions[1] = timeMajor ? batchSize : maxTime;
163     output.dimensions[2] = numUnits;
164 
165     return context->setOutputShape(kOutputTensor, output);
166 }
167 
execute(IOperationExecutionContext * context)168 bool execute(IOperationExecutionContext* context) {
169     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
170         executeTyped<_Float16>(context);
171     } else {
172         executeTyped<float>(context);
173     }
174     return true;
175 }
176 
177 }  // namespace unidirectional_sequence_rnn
178 
// Registers the operation with the OperationResolver so the framework can
// dispatch UNIDIRECTIONAL_SEQUENCE_RNN to the handlers above.
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
                      unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
                      unidirectional_sequence_rnn::execute);
182 
183 }  // namespace nn
184 }  // namespace android
185