1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "BidirectionalSequenceRNN.h"
20 
21 #include <algorithm>
22 #include <utility>
23 #include <vector>
24 
25 #include "OperationResolver.h"
26 #include "RNN.h"
27 
28 namespace android {
29 namespace nn {
30 namespace bidirectional_sequence_rnn {
31 
32 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
33 namespace {
34 
35 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)36 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
37     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
38     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
39     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
40     for (uint32_t f = 0; f < firstDimSize; ++f) {
41         for (uint32_t s = 0; s < secondDimSize; ++s) {
42             for (uint32_t i = 0; i < inputSize; ++i) {
43                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
44                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
45                 output[outputIndex] = input[inputIndex];
46             }
47         }
48     }
49 }
50 
// Returns a copy of |input| with its leading dimension dropped; all other
// Shape fields are preserved.
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    output.dimensions.assign(input.dimensions.begin() + 1, input.dimensions.end());
    return output;
}
59 
// How the optional auxiliary input/weights participate in the computation;
// derived from which optional operands are present (see getLinkingMode()).
enum class LinkingMode {
    NO_LINKING,        // No auxiliary tensors provided.
    PARALLEL_LINKING,  // Aux input feeds the backward network directly.
    CROSS_LINKING,     // Aux input is multiplied by the aux weights.
};
65 
getLinkingMode(IOperationExecutionContext * context,LinkingMode * linkingMode)66 bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
67     const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor);
68     const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor);
69     const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor);
70 
71     // Three possible configurations for three possible linking modes:
72     // 1) NO_LINKING -- no auxiliary tensors at all
73     // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
74     //    input to the backward network, so the auxiliary weights are omitted.
75     // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
76     //    auxiliary weights.
77     if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
78         *linkingMode = LinkingMode::NO_LINKING;
79     } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
80         *linkingMode = LinkingMode::PARALLEL_LINKING;
81     } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) {
82         *linkingMode = LinkingMode::CROSS_LINKING;
83     } else {
84         NN_RET_CHECK_FAIL()
85                 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
86     }
87 
88     return true;
89 }
90 
// Runs the bidirectional sequence RNN on the CPU: a forward RNN over time
// steps [0, maxTime) and a backward RNN over (maxTime, 0], per batch. The
// two outputs are either written to separate tensors or interleaved into a
// single merged output, depending on kMergeOutputsParam. Returns false if
// the auxiliary-operand configuration is unsupported.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    // Forward-direction parameters.
    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    // Backward-direction parameters.
    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary operands are only fetched for the linking modes that use them;
    // otherwise the pointers stay null.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        // The RNN below reads/writes these; results are transposed back at the
        // end of this function.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all tensors are in time-major layout: [maxTime, batch, ...].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // Per-step shapes with the time dimension removed, as expected by RNNStep.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // In PARALLEL_LINKING mode the auxiliary input IS the backward network's
    // regular input, so it is not passed as an aux input to RNNStep.
    // NOTE(review): the backward pass below strides bwInput by inputSize and
    // passes fixedTimeInputShape; prepare() for PARALLEL_LINKING only requires
    // aux dim 2 to match bwWeights dim 1 — confirm these sizes must coincide.
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    // When the hidden states are not exposed as outputs, one scratch buffer is
    // shared by both directions (the passes run sequentially, so no clash).
    std::vector<T> tempHiddenState;
    if (outputState) {
        // When outputs are merged there is one fewer output tensor, so the
        // hidden-state output indices shift down by one.
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    for (uint32_t i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // With merged outputs the fw results occupy the first fwNumUnits of
        // each row; the row stride covers both directions.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // Chain the freshly-computed hidden state into the next step.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Merged: bw results go after the fw results within each row.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
285 
286 }  // namespace
287 
// Validates operand counts, ranks, and dimension consistency for
// BIDIRECTIONAL_SEQUENCE_RNN, then computes and sets the output shapes.
// Returns false (via NN_RET_CHECK*) on any validation failure.
bool prepare(IOperationExecutionContext* context) {
    // The output count depends on whether outputs are merged and whether the
    // final hidden states are also returned.
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const int32_t numOutputs = context->getNumOutputs();
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    // The first two input dimensions are [maxTime, batch] in time-major
    // layout, [batch, maxTime] otherwise.
    bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2u);

    // Forward-direction dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward-direction dimension consistency. In PARALLEL_LINKING the
    // backward network consumes the aux input (checked below), so bwWeights'
    // input size need not match the primary input's.
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    // Auxiliary tensor checks, depending on the linking mode.
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    // Output shapes: [time, batch] or [batch, time] leading dims; the last
    // dimension covers both directions when outputs are merged.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        // With merged outputs the hidden-state output indices shift down by
        // one (there is no separate bw output tensor).
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}
406 
execute(IOperationExecutionContext * context)407 bool execute(IOperationExecutionContext* context) {
408     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
409         executeTyped<_Float16>(context);
410     } else {
411         executeTyped<float>(context);
412     }
413     return true;
414 }
415 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
416 
417 }  // namespace bidirectional_sequence_rnn
418 
// Register the operation with the resolver. allowOmittedOperand is set
// because the auxiliary input/weight tensors may legally be omitted
// (see getLinkingMode()).
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BIDIRECTIONAL_SEQUENCE_RNN,
                                         bidirectional_sequence_rnn::prepare,
                                         bidirectional_sequence_rnn::execute,
                                         .allowOmittedOperand = true);
423 
424 }  // namespace nn
425 }  // namespace android
426