1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "RNN.h"

#include "NeuralNetworksWrapper.h"

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <vector>
23 
24 namespace android {
25 namespace nn {
26 namespace wrapper {
27 
28 using ::testing::Each;
29 using ::testing::FloatNear;
30 using ::testing::Matcher;
31 
32 namespace {
33 
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-5)34 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
35                                            float max_abs_error = 1.e-5) {
36     std::vector<Matcher<float>> matchers;
37     matchers.reserve(values.size());
38     for (const float& v : values) {
39         matchers.emplace_back(FloatNear(v, max_abs_error));
40     }
41     return matchers;
42 }
43 
// Input sequence for BlackBoxTest: 128 floats laid out as 16 groups of 8.
// BlackBoxTest reads group i (input_size() == 8 floats) at time step i and
// feeds that same group to every batch.
static float rnn_input[] = {
        // t = 0
        0.23689353, 0.285385, 0.037029743, -0.19858193,
        -0.27569133, 0.43773448, 0.60379338, 0.35562468,
        // t = 1
        -0.69424844, -0.93421471, -0.87287879, 0.37144363,
        -0.62476718, 0.23791671, 0.40060222, 0.1356622,
        // t = 2
        -0.99774903, -0.98858172, -0.38952237, -0.47685933,
        0.31073618, 0.71511042, -0.63767755, -0.31729108,
        // t = 3
        0.33468103, 0.75801885, 0.30660987, -0.37354088,
        0.77002847, -0.62747043, -0.68572164, 0.0069220066,
        // t = 4
        0.65791464, 0.35130811, 0.80834007, -0.61777675,
        -0.21095741, 0.41213346, 0.73784804, 0.094794154,
        // t = 5
        0.47791874, 0.86496925, -0.53376222, 0.85315156,
        0.10288584, 0.86684, -0.011186242, 0.10513687,
        // t = 6
        0.87825835, 0.59929144, 0.62827742, 0.18899453,
        0.31440187, 0.99059987, 0.87170351, -0.35091716,
        // t = 7
        0.74861872, 0.17831337, 0.2755419, 0.51864719,
        0.55084288, 0.58982027, -0.47443086, 0.20875752,
        // t = 8
        -0.058871567, -0.66609079, 0.59098077, 0.73017097,
        0.74604273, 0.32882881, -0.17503482, 0.22396147,
        // t = 9
        0.19379807, 0.29120302, 0.077113032, -0.70331609,
        0.15804303, -0.93407321, 0.40182066, 0.036301374,
        // t = 10
        0.66521823, 0.0300982, -0.7747041, -0.02038002,
        0.020698071, -0.90300065, 0.62870288, -0.23068321,
        // t = 11
        0.27531278, -0.095755219, -0.712036, -0.17384434,
        -0.50593495, -0.18646687, -0.96508682, 0.43519354,
        // t = 12
        0.14744234, 0.62589407, 0.1653645, -0.10651493,
        -0.045277178, 0.99032974, -0.88255352, -0.85147917,
        // t = 13
        0.28153265, 0.19455957, -0.55479527, -0.56042433,
        0.26048636, 0.84702539, 0.47587705, -0.074295521,
        // t = 14
        -0.12287641, 0.70117295, 0.90532446, 0.89782166,
        0.79817224, 0.53402734, -0.33286154, 0.073485017,
        // t = 15
        -0.56172788, -0.044897556, 0.89964068, -0.067662835,
        0.76863563, 0.93455386, -0.6324693, -0.083922029};
67 
// Expected RNN Output for BlackBoxTest: 256 floats laid out as 16 groups of
// 16. BlackBoxTest compares group i (num_units() == 16 floats), repeated once
// per batch, against the Output tensor produced at time step i.
static float rnn_golden_output[] = {
        // t = 0
        0.496726,   0,        0.965996,  0,         0.0584254, 0,          0,         0.12315,
        0,          0,        0.612266,  0.456601,  0,         0.52286,    1.16099,   0.0291232,
        // t = 1
        0,          0,        0.524901,  0,         0,         0,          0,         1.02116,
        0,          1.35762,  0,         0.356909,  0.436415,  0.0355727,  0,         0,
        // t = 2
        0,          0,        0,         0.262335,  0,         0,          0,         1.33992,
        0,          2.9739,   0,         0,         1.31914,   2.66147,    0,         0,
        // t = 3
        0.942568,   0,        0,         0,         0.025507,  0,          0,         0,
        0.321429,   0.569141, 1.25274,   1.57719,   0.8158,    1.21805,    0.586239,  0.25427,
        // t = 4
        1.04436,    0,        0.630725,  0,         0.133801,  0.210693,   0.363026,  0,
        0.533426,   0,        1.25926,   0.722707,  0,         1.22031,    1.30117,   0.495867,
        // t = 5
        0.222187,   0,        0.72725,   0,         0.767003,  0,          0,         0.147835,
        0,          0,        0,         0.608758,  0.469394,  0.00720298, 0.927537,  0,
        // t = 6
        0.856974,   0.424257, 0,         0,         0.937329,  0,          0,         0,
        0.476425,   0,        0.566017,  0.418462,  0.141911,  0.996214,   1.13063,   0,
        // t = 7
        0.967899,   0,        0,         0,         0.0831304, 0,          0,         1.00378,
        0,          0,        0,         1.44818,   1.01768,   0.943891,   0.502745,  0,
        // t = 8
        0.940135,   0,        0,         0,         0,         0,          0,         2.13243,
        0,          0.71208,  0.123918,  1.53907,   1.30225,   1.59644,    0.70222,   0,
        // t = 9
        0.804329,   0,        0.430576,  0,         0.505872,  0.509603,   0.343448,  0,
        0.107756,   0.614544, 1.44549,   1.52311,   0.0454298, 0.300267,   0.562784,  0.395095,
        // t = 10
        0.228154,   0,        0.675323,  0,         1.70536,   0.766217,   0,         0,
        0,          0.735363, 0.0759267, 1.91017,   0.941888,  0,          0,         0,
        // t = 11
        0,          0,        1.5909,    0,         0,         0,          0,         0.5755,
        0,          0.184687, 0,         1.56296,   0.625285,  0,          0,         0,
        // t = 12
        0,          0,        0.0857888, 0,         0,         0,          0,         0.488383,
        0.252786,   0,        0,         0,         1.02817,   1.85665,    0,         0,
        // t = 13
        0.00981836, 0,        1.06371,   0,         0,         0,          0,         0,
        0,          0.290445, 0.316406,  0,         0.304161,  1.25079,    0.0707152, 0,
        // t = 14
        0.986264,   0.309201, 0,         0,         0,         0,          0,         1.64896,
        0.346248,   0,        0.918175,  0.78884,   0.524981,  1.92076,    2.07013,   0.333244,
        // t = 15
        0.415153,   0.210318, 0,         0,         0,         0,          0,         2.02616,
        0,          0.728256, 0.84183,   0.0907453, 0.628881,  3.58099,    1.49974,   0};
116 
117 }  // anonymous namespace
118 
// X-macro over every tensor the test feeds into the RNN operation. Each entry
// NAME expands to ACTION(NAME). BasicRNNOpModel expands it to declare a
// `NAME_` float buffer, a `SetNAME()` setter, and an execution.setInput() call
// bound to RNN::kNAMETensor.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(Weights)                              \
    ACTION(RecurrentWeights)                     \
    ACTION(Bias)                                 \
    ACTION(HiddenStateIn)

// X-macro over every tensor the RNN operation writes: the per-step output and
// the updated hidden state. BasicRNNOpModel expands it to declare a `NAME_`
// buffer and an execution.setOutput() call bound to RNN::kNAMETensor.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(HiddenStateOut)             \
    ACTION(Output)
130 
131 class BasicRNNOpModel {
132    public:
BasicRNNOpModel(uint32_t batches,uint32_t units,uint32_t size)133     BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
134         : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) {
135         std::vector<uint32_t> inputs;
136 
137         OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
138         inputs.push_back(model_.addOperand(&InputTy));
139         OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
140         inputs.push_back(model_.addOperand(&WeightTy));
141         OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
142         inputs.push_back(model_.addOperand(&RecurrentWeightTy));
143         OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
144         inputs.push_back(model_.addOperand(&BiasTy));
145         OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
146         inputs.push_back(model_.addOperand(&HiddenStateTy));
147         OperandType ActionParamTy(Type::INT32, {});
148         inputs.push_back(model_.addOperand(&ActionParamTy));
149 
150         std::vector<uint32_t> outputs;
151 
152         outputs.push_back(model_.addOperand(&HiddenStateTy));
153         OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
154         outputs.push_back(model_.addOperand(&OutputTy));
155 
156         Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
157         HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
158         HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
159         Output_.insert(Output_.end(), batches_ * units_, 0.f);
160 
161         model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
162         model_.identifyInputsAndOutputs(inputs, outputs);
163 
164         model_.finish();
165     }
166 
167 #define DefineSetter(X) \
168     void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
169 
170     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
171 
172 #undef DefineSetter
173 
SetInput(int offset,float * begin,float * end)174     void SetInput(int offset, float* begin, float* end) {
175         for (; begin != end; begin++, offset++) {
176             Input_[offset] = *begin;
177         }
178     }
179 
ResetHiddenState()180     void ResetHiddenState() {
181         std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
182         std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
183     }
184 
GetOutput() const185     const std::vector<float>& GetOutput() const { return Output_; }
186 
input_size() const187     uint32_t input_size() const { return input_size_; }
num_units() const188     uint32_t num_units() const { return units_; }
num_batches() const189     uint32_t num_batches() const { return batches_; }
190 
Invoke()191     void Invoke() {
192         ASSERT_TRUE(model_.isValid());
193 
194         HiddenStateIn_.swap(HiddenStateOut_);
195 
196         Compilation compilation(&model_);
197         compilation.finish();
198         Execution execution(&compilation);
199 #define SetInputOrWeight(X)                                                                    \
200     ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
201               Result::NO_ERROR);
202 
203         FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
204 
205 #undef SetInputOrWeight
206 
207 #define SetOutput(X)                                                                            \
208     ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
209               Result::NO_ERROR);
210 
211         FOR_ALL_OUTPUT_TENSORS(SetOutput);
212 
213 #undef SetOutput
214 
215         ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)),
216                   Result::NO_ERROR);
217 
218         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
219     }
220 
221    private:
222     Model model_;
223 
224     const uint32_t batches_;
225     const uint32_t units_;
226     const uint32_t input_size_;
227 
228     const int activation_;
229 
230 #define DefineTensor(X) std::vector<float> X##_;
231 
232     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
233     FOR_ALL_OUTPUT_TENSORS(DefineTensor);
234 
235 #undef DefineTensor
236 };
237 
TEST(RNNOpTest,BlackBoxTest)238 TEST(RNNOpTest, BlackBoxTest) {
239     BasicRNNOpModel rnn(2, 16, 8);
240     rnn.SetWeights(
241             {0.461459,  0.153381,    0.529743,   -0.00371218, 0.676267,    -0.211346, 0.317493,
242              0.969689,  -0.343251,   0.186423,   0.398151,    0.152399,    0.448504,  0.317662,
243              0.523556,  -0.323514,   0.480877,   0.333113,    -0.757714,   -0.674487, -0.643585,
244              0.217766,  -0.0251462,  0.79512,    -0.595574,   -0.422444,   0.371572,  -0.452178,
245              -0.556069, -0.482188,   -0.685456,  -0.727851,   0.841829,    0.551535,  -0.232336,
246              0.729158,  -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,  -0.423241,
247              0.548547,  -0.0152023,  -0.757482,  -0.85491,    0.251331,    -0.989183, 0.306261,
248              -0.340716, 0.886103,    -0.0726757, -0.723523,   -0.784303,   0.0354295, 0.566564,
249              -0.485469, -0.620498,   0.832546,   0.697884,    -0.279115,   0.294415,  -0.584313,
250              0.548772,  0.0648819,   0.968726,   0.723834,    -0.0080452,  -0.350386, -0.272803,
251              0.115121,  -0.412644,   -0.824713,  -0.992843,   -0.592904,   -0.417893, 0.863791,
252              -0.423461, -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,  -0.639158,
253              0.816969,  -0.337228,   0.659878,   0.73107,     0.754768,    -0.337042, 0.0960841,
254              0.368357,  0.244191,    -0.817703,  -0.211223,   0.442012,    0.37225,   -0.623598,
255              -0.405423, 0.455101,    0.673656,   -0.145345,   -0.511346,   -0.901675, -0.81252,
256              -0.127006, 0.809865,    -0.721884,  0.636255,    0.868989,    -0.347973, -0.10179,
257              -0.777449, 0.917274,    0.819286,   0.206218,    -0.00785118, 0.167141,  0.45872,
258              0.972934,  -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057, -0.469077,
259              0.277308,  0.415818});
260 
261     rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
262                  -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, 0.37197268,
263                  0.61957061, 0.3956964, -0.37609905});
264 
265     rnn.SetRecurrentWeights(
266             {0.1, 0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,
267              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,
268              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,
269              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,
270              0,   0,   0, 0,   0, 0,   0, 0,   0,  0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,
271              0,   0,   0, 0,   0, 0,   0, 0.1, 0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
272              0,   0,   0, 0,   0, 0.1, 0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
273              0,   0,   0, 0.1, 0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
274              0,   0.1, 0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1,
275              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,
276              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,
277              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,
278              0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,
279              0,   0,   0, 0,   0, 0,   0, 0,   0.1});
280 
281     rnn.ResetHiddenState();
282     const int input_sequence_size =
283             sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches());
284 
285     for (int i = 0; i < input_sequence_size; i++) {
286         float* batch_start = rnn_input + i * rnn.input_size();
287         float* batch_end = batch_start + rnn.input_size();
288         rnn.SetInput(0, batch_start, batch_end);
289         rnn.SetInput(rnn.input_size(), batch_start, batch_end);
290 
291         rnn.Invoke();
292 
293         float* golden_start = rnn_golden_output + i * rnn.num_units();
294         float* golden_end = golden_start + rnn.num_units();
295         std::vector<float> expected;
296         expected.insert(expected.end(), golden_start, golden_end);
297         expected.insert(expected.end(), golden_start, golden_end);
298 
299         EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
300     }
301 }
302 
303 }  // namespace wrapper
304 }  // namespace nn
305 }  // namespace android
306