1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_RNN_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_RNN_H
19 
20 #include <vector>
21 
22 #include "ActivationFunctor.h"
23 #include "OperationsValidationUtils.h"
24 #include "nnapi/Types.h"
25 
26 namespace android {
27 namespace nn {
namespace rnn {

// TODO: Add input/output labels.
// NOTE(review): this namespace is currently empty. The operand index
// constants for this operation are declared as static members of class RNN
// (kInputTensor ... kActivationParam, kHiddenStateOutTensor, kOutputTensor);
// presumably the TODO above intends to expose them here as named labels.

}  // namespace rnn
33 
34 struct RunTimeOperandInfo;
35 struct Shape;
36 
37 class RNN {
38    public:
39     RNN(const Operation& operation, RunTimeOperandInfo* operands);
40 
41     static bool Prepare(const Operation& operation, RunTimeOperandInfo* operands,
42                         Shape* hiddenStateShape, Shape* outputShape);
43     bool Eval();
44 
45     static constexpr int kInputTensor = 0;
46     static constexpr int kWeightsTensor = 1;
47     static constexpr int kRecurrentWeightsTensor = 2;
48     static constexpr int kBiasTensor = 3;
49     static constexpr int kHiddenStateInTensor = 4;
50     static constexpr int kActivationParam = 5;
51 
52     static constexpr int kHiddenStateOutTensor = 0;
53     static constexpr int kOutputTensor = 1;
54 
55     template <typename T>
56     static bool RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
57                         const T* biasData, const T* weightsData, const Shape& weightsShape,
58                         const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
59                         int32_t activation, T* outputData);
60 
61     template <typename T>
62     static bool RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
63                         const Shape& auxInputShape, const T* hiddenStateInputData,
64                         const T* biasData, const T* weightsData, const Shape& weightsShape,
65                         const T* auxWeightsData, const Shape& auxWeightsShape,
66                         const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
67                         int32_t activation, uint32_t outputBatchStride, uint32_t outputBatchStep,
68                         T* outputData, T* hiddenStateOutput = nullptr);
69 
70    private:
71     ActivationFn activation_;
72 
73     const RunTimeOperandInfo* input_;
74     const RunTimeOperandInfo* weights_;
75     const RunTimeOperandInfo* recurrent_weights_;
76     const RunTimeOperandInfo* bias_;
77     const RunTimeOperandInfo* hidden_state_in_;
78 
79     RunTimeOperandInfo* hidden_state_out_;
80     RunTimeOperandInfo* output_;
81 };
82 
83 }  // namespace nn
84 }  // namespace android
85 
86 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_RNN_H
87