1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "RNN.h"
18
19 #include "CpuExecutor.h"
20 #include "HalInterfaces.h"
21
22 namespace android {
23 namespace nn {
24
RNN(const Operation & operation,std::vector<RunTimeOperandInfo> & operands)25 RNN::RNN(const Operation& operation,
26 std::vector<RunTimeOperandInfo>& operands) {
27 input_ = GetInput(operation, operands, kInputTensor);
28 weights_ = GetInput(operation, operands, kWeightsTensor);
29 recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
30 hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
31 bias_ = GetInput(operation, operands, kBiasTensor);
32
33 activation_ = static_cast<ActivationFn>(
34 getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
35
36 hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
37 output_ = GetOutput(operation, operands, kOutputTensor);
38 }
39
Prepare(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,Shape * hiddenStateShape,Shape * outputShape)40 bool RNN::Prepare(const Operation &operation,
41 std::vector<RunTimeOperandInfo> &operands,
42 Shape *hiddenStateShape,
43 Shape *outputShape) {
44 // Check we have all the inputs and outputs we need.
45 const int num_inputs = NumInputsWithValues(operation, operands);
46 NN_CHECK(num_inputs == 5 || num_inputs == 6);
47 NN_CHECK_EQ(NumOutputs(operation), 2);
48
49 const RunTimeOperandInfo *input =
50 GetInput(operation, operands, kInputTensor);
51 const RunTimeOperandInfo *input_weights =
52 GetInput(operation, operands, kWeightsTensor);
53 const RunTimeOperandInfo *recurrent_weights =
54 GetInput(operation, operands, kRecurrentWeightsTensor);
55 const RunTimeOperandInfo *bias =
56 GetInput(operation, operands, kBiasTensor);
57
58 // Check all the parameters of tensor match within themselves and match the
59 // input configuration.
60 const uint32_t batch_size = SizeOfDimension(input, 0);
61 const uint32_t num_units = SizeOfDimension(input_weights, 0);
62 NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
63 NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
64 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
65 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
66
67 const Shape &inputShape = input->shape();
68
69 // Resize state.
70 hiddenStateShape->type = inputShape.type;
71 hiddenStateShape->dimensions = { batch_size, num_units };
72
73 // Resize output.
74 outputShape->type = inputShape.type;
75 outputShape->dimensions = { batch_size, num_units };
76
77 return true;
78 }
79
Eval()80 bool RNN::Eval() {
81 const float* bias_ptr = reinterpret_cast<float*>(bias_->buffer);
82
83 const uint32_t batch_size = input_->shape().dimensions[0];
84 const uint32_t num_units = weights_->shape().dimensions[0];
85 const uint32_t input_size = input_->shape().dimensions[1];
86 const uint32_t input_weights_stride = weights_->shape().dimensions[1];
87 const uint32_t recurrent_weights_stride =
88 recurrent_weights_->shape().dimensions[1];
89
90 // For each batch
91 for (uint32_t b = 0; b < batch_size; b++) {
92 // Initialize the pointer to input, output and bias.
93 const float* input_ptr_batch =
94 reinterpret_cast<float*>(input_->buffer) + b * input_size;
95 const float* hidden_state_in_ptr_batch =
96 reinterpret_cast<float*>(hidden_state_in_->buffer) + b * num_units;
97 float* output_ptr_batch =
98 reinterpret_cast<float*>(output_->buffer) + b * num_units;
99 float* hidden_state_out_ptr_batch =
100 reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units;
101
102 // Initialize input_weights and recurrent_weights.
103 const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer);
104 const float* recurrent_weights_ptr =
105 reinterpret_cast<float*>(recurrent_weights_->buffer);
106
107 // Output = bias
108 for (uint32_t o = 0; o < num_units; o++) {
109 output_ptr_batch[o] = bias_ptr[o];
110 }
111
112 // Output += input * input_weights
113 for (uint32_t o = 0; o < num_units; o++) {
114 for (uint32_t i = 0; i < input_size; i++) {
115 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
116 }
117 input_weights_ptr += input_weights_stride;
118 }
119
120 // Output += recurrent_weights * hidden_state
121 for (uint32_t o = 0; o < num_units; o++) {
122 for (uint32_t h = 0; h < num_units; h++) {
123 output_ptr_batch[o] +=
124 hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
125 }
126 recurrent_weights_ptr += recurrent_weights_stride;
127 }
128
129 // Output = activation(Output) and update hidden_state
130 for (uint32_t o = 0; o < num_units; o++) {
131 output_ptr_batch[o] =
132 (ActivationFunctor(activation_))(output_ptr_batch[o]);
133 hidden_state_out_ptr_batch[o] = output_ptr_batch[o];
134 }
135 }
136
137 return true;
138 }
139
140 } // namespace nn
141 } // namespace android
142