/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RNN.h"

#include "CpuExecutor.h"
#include "HalInterfaces.h"

namespace android {
namespace nn {

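// RNN implements the basic recurrent layer operation on the CPU. The
// constructor below only binds pointers to the operation's operands and reads
// the activation scalar; shape checking happens in Prepare() and the actual
// computation in Eval().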
RNN::RNN(const Operation& operation,
         std::vector<RunTimeOperandInfo>& operands) {
  input_ = GetInput(operation, operands, kInputTensor);
  weights_ = GetInput(operation, operands, kWeightsTensor);
  recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
  hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
  bias_ = GetInput(operation, operands, kBiasTensor);

  activation_ = static_cast<ActivationFn>(
      getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));

  hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
  output_ = GetOutput(operation, operands, kOutputTensor);
}

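// Prepare() checks the operand counts and dimensions and derives the output
// shapes. The dimension checks below imply the expected layouts:
//   input:             [batch_size, input_size]
//   weights:           [num_units, input_size]
//   recurrent_weights: [num_units, num_units]
//   bias:              [num_units]
// Both the hidden state and the output are sized [batch_size, num_units].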
bool RNN::Prepare(const Operation &operation,
                  std::vector<RunTimeOperandInfo> &operands,
                  Shape *hiddenStateShape,
                  Shape *outputShape) {
  // Check we have all the inputs and outputs we need.
  const int num_inputs = NumInputsWithValues(operation, operands);
  NN_CHECK(num_inputs == 5 || num_inputs == 6);
  NN_CHECK_EQ(NumOutputs(operation), 2);

  const RunTimeOperandInfo *input =
      GetInput(operation, operands, kInputTensor);
  const RunTimeOperandInfo *input_weights =
      GetInput(operation, operands, kWeightsTensor);
  const RunTimeOperandInfo *recurrent_weights =
      GetInput(operation, operands, kRecurrentWeightsTensor);
  const RunTimeOperandInfo *bias =
      GetInput(operation, operands, kBiasTensor);

  // Check that the tensor dimensions are consistent with each other and with
  // the input configuration.
  const uint32_t batch_size = SizeOfDimension(input, 0);
  const uint32_t num_units = SizeOfDimension(input_weights, 0);
  NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
  NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
  NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
  NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));

  const Shape &inputShape = input->shape();

  // Resize state.
  hiddenStateShape->type = inputShape.type;
  hiddenStateShape->dimensions = { batch_size, num_units };

  // Resize output.
  outputShape->type = inputShape.type;
  outputShape->dimensions = { batch_size, num_units };

  return true;
}

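// Eval() runs one step of the basic RNN cell for each batch entry:
//   output = activation(input * weights^T + hidden_state_in * recurrent_weights^T + bias)
// The result is also copied into hidden_state_out as the new hidden state.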
bool RNN::Eval() {
  const float* bias_ptr = reinterpret_cast<float*>(bias_->buffer);

  const uint32_t batch_size = input_->shape().dimensions[0];
  const uint32_t num_units = weights_->shape().dimensions[0];
  const uint32_t input_size = input_->shape().dimensions[1];
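  // The weight matrices are stored row-major with one row per output unit, so
  // the row stride of each matrix is its second dimension.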
  const uint32_t input_weights_stride = weights_->shape().dimensions[1];
  const uint32_t recurrent_weights_stride =
      recurrent_weights_->shape().dimensions[1];

  // For each batch
  for (uint32_t b = 0; b < batch_size; b++) {
    // Initialize the per-batch pointers to the input, hidden state and output.
    const float* input_ptr_batch =
        reinterpret_cast<float*>(input_->buffer) + b * input_size;
    const float* hidden_state_in_ptr_batch =
        reinterpret_cast<float*>(hidden_state_in_->buffer) + b * num_units;
    float* output_ptr_batch =
        reinterpret_cast<float*>(output_->buffer) + b * num_units;
    float* hidden_state_out_ptr_batch =
        reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units;

    // Initialize input_weights and recurrent_weights.
    const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer);
    const float* recurrent_weights_ptr =
        reinterpret_cast<float*>(recurrent_weights_->buffer);

    // Output = bias
    for (uint32_t o = 0; o < num_units; o++) {
      output_ptr_batch[o] = bias_ptr[o];
    }

    // Output += input * input_weights
    for (uint32_t o = 0; o < num_units; o++) {
      for (uint32_t i = 0; i < input_size; i++) {
        output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
      }
      input_weights_ptr += input_weights_stride;
    }

    // Output += recurrent_weights * hidden_state
    for (uint32_t o = 0; o < num_units; o++) {
      for (uint32_t h = 0; h < num_units; h++) {
        output_ptr_batch[o] +=
            hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
      }
      recurrent_weights_ptr += recurrent_weights_stride;
    }

    // Output = activation(Output) and update hidden_state
    for (uint32_t o = 0; o < num_units; o++) {
      output_ptr_batch[o] =
          (ActivationFunctor(activation_))(output_ptr_batch[o]);
      hidden_state_out_ptr_batch[o] = output_ptr_batch[o];
    }
  }

  return true;
}

}  // namespace nn
}  // namespace android