1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "BidirectionalSequenceRNN.h"
20
21 #include <algorithm>
22 #include <utility>
23 #include <vector>
24
25 #include "OperationResolver.h"
26 #include "RNN.h"
27
28 namespace android {
29 namespace nn {
30 namespace bidirectional_sequence_rnn {
31
32 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
33 namespace {
34
35 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)36 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
37 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
38 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
39 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
40 for (uint32_t f = 0; f < firstDimSize; ++f) {
41 for (uint32_t s = 0; s < secondDimSize; ++s) {
42 for (uint32_t i = 0; i < inputSize; ++i) {
43 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
44 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
45 output[outputIndex] = input[inputIndex];
46 }
47 }
48 }
49 }
50
// Returns a copy of |input| whose outermost dimension has been stripped,
// e.g. a [T, B, I] shape becomes [B, I]. All other Shape fields are kept.
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    output.dimensions.assign(input.dimensions.begin() + 1, input.dimensions.end());
    return output;
}
59
// How the optional auxiliary input is wired into the forward/backward RNNs
// (see getLinkingMode for the tensor configuration that selects each mode).
enum class LinkingMode {
    NO_LINKING,        // no auxiliary input or auxiliary weights at all
    PARALLEL_LINKING,  // aux input is the regular input of the backward RNN
    CROSS_LINKING,     // aux input is multiplied by dedicated aux weights
};
65
getLinkingMode(IOperationExecutionContext * context,LinkingMode * linkingMode)66 bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
67 const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor);
68 const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor);
69 const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor);
70
71 // Three possible configurations for three possible linking modes:
72 // 1) NO_LINKING -- no auxiliary tensors at all
73 // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
74 // input to the backward network, so the auxiliary weights are omitted.
75 // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
76 // auxiliary weights.
77 if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
78 *linkingMode = LinkingMode::NO_LINKING;
79 } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
80 *linkingMode = LinkingMode::PARALLEL_LINKING;
81 } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) {
82 *linkingMode = LinkingMode::CROSS_LINKING;
83 } else {
84 NN_RET_CHECK_FAIL()
85 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
86 }
87
88 return true;
89 }
90
// Runs one bidirectional-sequence-RNN evaluation for element type T
// (float or _Float16). Steps:
//   1. Fetch all input buffers/shapes from the execution context.
//   2. If inputs are batch-major (timeMajor == false), transpose input(s)
//      into temporary time-major buffers and redirect the output pointers
//      to temporaries as well; results are transposed back at the end.
//   3. Run the forward RNN over t = 0..maxTime-1 and the backward RNN over
//      t = maxTime-1..0, one RNN::RNNStep per time step, chaining the hidden
//      state from step to step.
// Returns false if the auxiliary-tensor configuration is invalid.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary buffers are only fetched for the modes that use them;
    // otherwise they stay null and RNNStep receives nullptr.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    // With merged outputs both directions write into the fw output tensor,
    // so the separate bw output tensor only exists when !mergeOutputs.
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all buffers and shapes are time-major: [maxTime, batch, ...].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // RNNStep operates on a single time step, so it needs 2-D ([batch, size])
    // shapes with the time dimension removed.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // In PARALLEL_LINKING mode the aux input *is* the backward network's
    // regular input; auxInput is cleared so neither pass treats it as a
    // cross-linked extra input.
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        // When merged, there is one fewer output tensor before the hidden
        // state outputs, hence the index shift by |delta|.
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        // One scratch buffer is shared by both passes: the forward loop
        // finishes completely before the backward loop starts.
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    for (uint32_t i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // When outputs are merged, fw and bw results are interleaved per batch
        // row; the fw slice occupies the first fwNumUnits of each row.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // Chain the hidden state into the next step.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass (iterates time in reverse; signed index so the loop
    // terminates at i == -1).
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Write into the fw output tensor, after the fw slice of each row.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
285
286 } // namespace
287
// Validates all input shapes/parameters of the operation and computes the
// output shapes. Returns false (via NN_RET_CHECK*) on the first violated
// constraint; the check order below determines which diagnostic is reported.
bool prepare(IOperationExecutionContext* context) {
    // The output count depends on whether outputs are merged and whether the
    // final hidden states are also returned.
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const int32_t numOutputs = context->getNumOutputs();
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    // Rejects unsupported combinations of optional auxiliary tensors.
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    // The input layout is [maxTime, batch, inputSize] when time-major,
    // [batch, maxTime, inputSize] otherwise.
    bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2u);

    // Forward-network dimension agreement.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward-network dimension agreement. In PARALLEL_LINKING mode the
    // backward network consumes the aux input, so its weights are checked
    // against the aux input size (below) instead of inputSize.
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    // Auxiliary tensor checks, depending on the linking mode.
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    // Output shapes follow the input's time/batch ordering; when merged the
    // fw output holds both directions concatenated on the last dimension.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    // Optional hidden-state outputs mirror the input hidden-state shapes.
    // |delta| accounts for the missing bw output tensor when merged.
    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}
406
execute(IOperationExecutionContext * context)407 bool execute(IOperationExecutionContext* context) {
408 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
409 executeTyped<_Float16>(context);
410 } else {
411 executeTyped<float>(context);
412 }
413 return true;
414 }
415 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
416
417 } // namespace bidirectional_sequence_rnn
418
419 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BIDIRECTIONAL_SEQUENCE_RNN,
420 bidirectional_sequence_rnn::prepare,
421 bidirectional_sequence_rnn::execute,
422 .allowOmittedOperand = true);
423
424 } // namespace nn
425 } // namespace android
426