/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RNN.h"

#include "NeuralNetworksWrapper.h"

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

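// Builds one FloatNear matcher per expected value so that
// EXPECT_THAT(..., ElementsAreArray(ArrayFloatNear(expected))) compares the
// output elementwise within max_abs_error.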
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-5) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

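// Input sequence data: 8-value time steps laid out back to back. The test
// below feeds the same time step to both batches.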
static float rnn_input[] = {
        0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 0.43773448,
        0.60379338, 0.35562468, -0.69424844, -0.93421471, -0.87287879, 0.37144363,
        -0.62476718, 0.23791671, 0.40060222, 0.1356622, -0.99774903, -0.98858172,
        -0.38952237, -0.47685933, 0.31073618, 0.71511042, -0.63767755, -0.31729108,
        0.33468103, 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
        -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, -0.61777675,
        -0.21095741, 0.41213346, 0.73784804, 0.094794154, 0.47791874, 0.86496925,
        -0.53376222, 0.85315156, 0.10288584, 0.86684, -0.011186242, 0.10513687,
        0.87825835, 0.59929144, 0.62827742, 0.18899453, 0.31440187, 0.99059987,
        0.87170351, -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
        0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, -0.66609079,
        0.59098077, 0.73017097, 0.74604273, 0.32882881, -0.17503482, 0.22396147,
        0.19379807, 0.29120302, 0.077113032, -0.70331609, 0.15804303, -0.93407321,
        0.40182066, 0.036301374, 0.66521823, 0.0300982, -0.7747041, -0.02038002,
        0.020698071, -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
        -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 0.43519354,
        0.14744234, 0.62589407, 0.1653645, -0.10651493, -0.045277178, 0.99032974,
        -0.88255352, -0.85147917, 0.28153265, 0.19455957, -0.55479527, -0.56042433,
        0.26048636, 0.84702539, 0.47587705, -0.074295521, -0.12287641, 0.70117295,
        0.90532446, 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
        -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 0.93455386,
        -0.6324693, -0.083922029};

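// Reference outputs, one 16-value chunk (one value per unit) per time step.
// The operation uses a ReLU activation, so many entries are clamped to 0.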
static float rnn_golden_output[] = {
        0.496726, 0, 0.965996, 0, 0.0584254, 0, 0, 0.12315,
        0, 0, 0.612266, 0.456601, 0, 0.52286, 1.16099, 0.0291232,

        0, 0, 0.524901, 0, 0, 0, 0, 1.02116,
        0, 1.35762, 0, 0.356909, 0.436415, 0.0355727, 0, 0,

        0, 0, 0, 0.262335, 0, 0, 0, 1.33992,
        0, 2.9739, 0, 0, 1.31914, 2.66147, 0, 0,

        0.942568, 0, 0, 0, 0.025507, 0, 0, 0,
        0.321429, 0.569141, 1.25274, 1.57719, 0.8158, 1.21805, 0.586239, 0.25427,

        1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 0.363026, 0,
        0.533426, 0, 1.25926, 0.722707, 0, 1.22031, 1.30117, 0.495867,

        0.222187, 0, 0.72725, 0, 0.767003, 0, 0, 0.147835,
        0, 0, 0, 0.608758, 0.469394, 0.00720298, 0.927537, 0,

        0.856974, 0.424257, 0, 0, 0.937329, 0, 0, 0,
        0.476425, 0, 0.566017, 0.418462, 0.141911, 0.996214, 1.13063, 0,

        0.967899, 0, 0, 0, 0.0831304, 0, 0, 1.00378,
        0, 0, 0, 1.44818, 1.01768, 0.943891, 0.502745, 0,

        0.940135, 0, 0, 0, 0, 0, 0, 2.13243,
        0, 0.71208, 0.123918, 1.53907, 1.30225, 1.59644, 0.70222, 0,

        0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 0.343448, 0,
        0.107756, 0.614544, 1.44549, 1.52311, 0.0454298, 0.300267, 0.562784, 0.395095,

        0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 0, 0,
        0, 0.735363, 0.0759267, 1.91017, 0.941888, 0, 0, 0,

        0, 0, 1.5909, 0, 0, 0, 0, 0.5755,
        0, 0.184687, 0, 1.56296, 0.625285, 0, 0, 0,

        0, 0, 0.0857888, 0, 0, 0, 0, 0.488383,
        0.252786, 0, 0, 0, 1.02817, 1.85665, 0, 0,

        0.00981836, 0, 1.06371, 0, 0, 0, 0, 0,
        0, 0.290445, 0.316406, 0, 0.304161, 1.25079, 0.0707152, 0,

        0.986264, 0.309201, 0, 0, 0, 0, 0, 1.64896,
        0.346248, 0, 0.918175, 0.78884, 0.524981, 1.92076, 2.07013, 0.333244,

        0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616,
        0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0};

} // anonymous namespace

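// X-macro lists of the RNN operand tensors. They are expanded below to
// generate the setters, the backing host buffers, and the execution bindings.
// For all input and weight tensors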
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(Weights)                              \
    ACTION(RecurrentWeights)                     \
    ACTION(Bias)                                 \
    ACTION(HiddenStateIn)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(HiddenStateOut)             \
    ACTION(Output)

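// Test harness that builds a model containing a single ANEURALNETWORKS_RNN
// operation through the NeuralNetworksWrapper API and keeps host-side buffers
// for every operand.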
class BasicRNNOpModel {
   public:
    BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
        : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) {
        std::vector<uint32_t> inputs;

        OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
        inputs.push_back(model_.addOperand(&InputTy));
        OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
        inputs.push_back(model_.addOperand(&WeightTy));
        OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
        inputs.push_back(model_.addOperand(&RecurrentWeightTy));
        OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
        inputs.push_back(model_.addOperand(&BiasTy));
        OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
        inputs.push_back(model_.addOperand(&HiddenStateTy));
        OperandType ActionParamTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActionParamTy));

        std::vector<uint32_t> outputs;

        outputs.push_back(model_.addOperand(&HiddenStateTy));
        OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
        outputs.push_back(model_.addOperand(&OutputTy));

        Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
        HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
        HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
        Output_.insert(Output_.end(), batches_ * units_, 0.f);

        model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        model_.finish();
    }

#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

    void SetInput(int offset, float* begin, float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

    void ResetHiddenState() {
        std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
        std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
    }

    const std::vector<float>& GetOutput() const { return Output_; }

    uint32_t input_size() const { return input_size_; }
    uint32_t num_units() const { return units_; }
    uint32_t num_batches() const { return batches_; }

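    // Runs one step of the RNN: the hidden state produced by the previous call
    // becomes the hidden-state input of this call, then all operand buffers and
    // the activation scalar are bound and the model is executed.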
    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        HiddenStateIn_.swap(HiddenStateOut_);

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                     \
    ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
              Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                             \
    ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
              Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;

    const uint32_t batches_;
    const uint32_t units_;
    const uint32_t input_size_;

    const int activation_;

#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

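// Black-box test: builds a 2-batch, 16-unit, 8-input RNN, loads fixed weights
// and biases, and steps through the input sequence, comparing each step's
// output against the golden values.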
TEST(RNNOpTest, BlackBoxTest) {
    BasicRNNOpModel rnn(2, 16, 8);
    rnn.SetWeights(
            {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 0.317493,
             0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 0.448504, 0.317662,
             0.523556, -0.323514, 0.480877, 0.333113, -0.757714, -0.674487, -0.643585,
             0.217766, -0.0251462, 0.79512, -0.595574, -0.422444, 0.371572, -0.452178,
             -0.556069, -0.482188, -0.685456, -0.727851, 0.841829, 0.551535, -0.232336,
             0.729158, -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
             0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 0.306261,
             -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 0.0354295, 0.566564,
             -0.485469, -0.620498, 0.832546, 0.697884, -0.279115, 0.294415, -0.584313,
             0.548772, 0.0648819, 0.968726, 0.723834, -0.0080452, -0.350386, -0.272803,
             0.115121, -0.412644, -0.824713, -0.992843, -0.592904, -0.417893, 0.863791,
             -0.423461, -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
             0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 0.0960841,
             0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 0.37225, -0.623598,
             -0.405423, 0.455101, 0.673656, -0.145345, -0.511346, -0.901675, -0.81252,
             -0.127006, 0.809865, -0.721884, 0.636255, 0.868989, -0.347973, -0.10179,
             -0.777449, 0.917274, 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872,
             0.972934, -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
             0.277308, 0.415818});

    rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
                 -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178,
                 0.37197268, 0.61957061, 0.3956964, -0.37609905});

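    // The recurrent weights form a 16x16 matrix with 0.1 on the diagonal
    // (0.1 * identity), so each unit only feeds back into itself.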
    rnn.SetRecurrentWeights(
            {0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1});

    rnn.ResetHiddenState();
    const int input_sequence_size =
            sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches());

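    // For each step, feed the same 8-value input slice to both batch rows and
    // expect the corresponding golden slice, duplicated once per batch.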
    for (int i = 0; i < input_sequence_size; i++) {
        float* batch_start = rnn_input + i * rnn.input_size();
        float* batch_end = batch_start + rnn.input_size();
        rnn.SetInput(0, batch_start, batch_end);
        rnn.SetInput(rnn.input_size(), batch_start, batch_end);

        rnn.Invoke();

        float* golden_start = rnn_golden_output + i * rnn.num_units();
        float* golden_end = golden_start + rnn.num_units();
        std::vector<float> expected;
        expected.insert(expected.end(), golden_start, golden_end);
        expected.insert(expected.end(), golden_start, golden_end);

        EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    }
}

} // namespace wrapper
} // namespace nn
} // namespace android