1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RandomVariable.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "RandomGraphGeneratorUtils.h"
27 
28 namespace android {
29 namespace nn {
30 namespace fuzzing_test {
31 
// Running counter; every RandomVariableBase constructor takes the current value as its `index`
// and then increments it.
unsigned int RandomVariableBase::globalIndex = 0;
// Upper bound of the range [1, defaultValue] used when a RandomVariable is created from
// RandomVariableType::FREE.
int RandomVariable::defaultValue = 10;
34 
// Builds a CONST node: both the range and the stored value are initialized from `value`.
// The node takes the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(int value)
    : index(globalIndex++),
      type(RandomVariableType::CONST),
      range(value),
      value(value),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
41 
// Builds a FREE node whose range is constructed from the bounds `lower` and `upper`.
// The node takes the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(int lower, int upper)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(lower, upper),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
47 
// Builds a FREE node whose range is constructed from the explicit list `choices`.
// The node takes the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(const std::vector<int>& choices)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(choices),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
53 
RandomVariableBase(const RandomVariableNode & lhs,const RandomVariableNode & rhs,const std::shared_ptr<const IRandomVariableOp> & op)54 RandomVariableBase::RandomVariableBase(const RandomVariableNode& lhs, const RandomVariableNode& rhs,
55                                        const std::shared_ptr<const IRandomVariableOp>& op)
56     : index(globalIndex++),
57       type(RandomVariableType::OP),
58       range(op->getInitRange(lhs->range, rhs == nullptr ? RandomVariableRange(0) : rhs->range)),
59       op(op),
60       parent1(lhs),
61       parent2(rhs),
62       timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
63 
setRange(int lower,int upper)64 void RandomVariableRange::setRange(int lower, int upper) {
65     // kInvalidValue indicates unlimited bound.
66     auto head = lower == kInvalidValue ? mChoices.begin()
67                                        : std::lower_bound(mChoices.begin(), mChoices.end(), lower);
68     auto tail = upper == kInvalidValue ? mChoices.end()
69                                        : std::upper_bound(mChoices.begin(), mChoices.end(), upper);
70     NN_FUZZER_CHECK(head <= tail) << "Invalid range!";
71     if (head != mChoices.begin() || tail != mChoices.end()) {
72         mChoices = std::vector<int>(head, tail);
73     }
74 }
75 
toConst()76 int RandomVariableRange::toConst() {
77     if (mChoices.size() > 1) mChoices = {getRandomChoice(mChoices)};
78     return mChoices[0];
79 }
80 
// Returns the range made of the choices present in both `lhs` and `rhs`.
// std::set_intersection requires both inputs to be sorted.
RandomVariableRange operator&(const RandomVariableRange& lhs, const RandomVariableRange& rhs) {
    // An intersection can never hold more elements than the smaller operand, so that is the
    // tightest safe size for the scratch buffer.
    std::vector<int> result(std::min(lhs.size(), rhs.size()));
    auto it = std::set_intersection(lhs.mChoices.begin(), lhs.mChoices.end(), rhs.mChoices.begin(),
                                    rhs.mChoices.end(), result.begin());
    // Drop the unused tail of the buffer.
    result.resize(it - result.begin());
    return RandomVariableRange(std::move(result));
}
88 
freeze()89 void RandomVariableBase::freeze() {
90     if (type == RandomVariableType::CONST) return;
91     value = range.toConst();
92     type = RandomVariableType::CONST;
93 }
94 
getValue() const95 int RandomVariableBase::getValue() const {
96     switch (type) {
97         case RandomVariableType::CONST:
98             return value;
99         case RandomVariableType::OP:
100             return op->eval(parent1->getValue(), parent2 == nullptr ? 0 : parent2->getValue());
101         default:
102             NN_FUZZER_CHECK(false) << "Invalid type when getting value of var" << index;
103             return 0;
104     }
105 }
106 
// Stamps this node with the network's current global time and logs the new timestamp.
void RandomVariableBase::updateTimestamp() {
    timestamp = RandomVariableNetwork::get()->getGlobalTime();
    NN_FUZZER_LOG << "Update timestamp of var" << index << " to " << timestamp;
}
111 
// Creates a variable backed by a new CONST node holding `value`, logs it, and registers the
// node with the RandomVariableNetwork.
RandomVariable::RandomVariable(int value) : mVar(new RandomVariableBase(value)) {
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a variable backed by a new FREE node built from the bounds `lower` and `upper`,
// logs it, and registers the node with the RandomVariableNetwork.
RandomVariable::RandomVariable(int lower, int upper) : mVar(new RandomVariableBase(lower, upper)) {
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a variable backed by a new FREE node built from the explicit list `choices`,
// logs it, and registers the node with the RandomVariableNetwork.
RandomVariable::RandomVariable(const std::vector<int>& choices)
    : mVar(new RandomVariableBase(choices)) {
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a FREE variable with the default bounds (1, defaultValue).
// `type` must be RandomVariableType::FREE; any other value fails the check. The node is then
// logged and registered with the RandomVariableNetwork.
RandomVariable::RandomVariable(RandomVariableType type)
    : mVar(new RandomVariableBase(1, defaultValue)) {
    NN_FUZZER_CHECK(type == RandomVariableType::FREE);
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// Creates an OP variable that applies `op` to `lhs` and `rhs`, links the new node into the
// child lists of its parents and registers it with the RandomVariableNetwork.
RandomVariable::RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
                               const std::shared_ptr<const IRandomVariableOp>& op)
    : mVar(new RandomVariableBase(lhs.get(), rhs.get(), op)) {
    auto& first = mVar->parent1;
    auto& second = mVar->parent2;

    // A CONST parent is replaced by a fresh copy holding the same value. This resolves the
    // fake dependency problem that sharing one constant node would create.
    if (first->type == RandomVariableType::CONST) {
        first = RandomVariable(first->value).get();
    }
    if (second != nullptr && second->type == RandomVariableType::CONST) {
        second = RandomVariable(second->value).get();
    }

    // Hook the new node up as a child of each parent it has.
    first->children.push_back(mVar);
    if (second != nullptr) {
        second->children.push_back(mVar);
    }

    RandomVariableNetwork::get()->add(mVar);
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
}
146 
setRange(int lower,int upper)147 void RandomVariable::setRange(int lower, int upper) {
148     NN_FUZZER_CHECK(mVar != nullptr) << "setRange() on nullptr";
149     NN_FUZZER_LOG << "Set range [" << lower << ", " << upper << "] on var" << mVar->index;
150     size_t oldSize = mVar->range.size();
151     mVar->range.setRange(lower, upper);
152     // Only update the timestamp if the range is *indeed* narrowed down.
153     if (mVar->range.size() != oldSize) mVar->updateTimestamp();
154 }
155 
// Default initial range of an op: evaluate every (lhs, rhs) pair of choices and collect the
// distinct results that lie within [-kMaxValue, kMaxValue].
RandomVariableRange IRandomVariableOp::getInitRange(const RandomVariableRange& lhs,
                                                    const RandomVariableRange& rhs) const {
    std::set<int> results;
    for (auto a : lhs.getChoices()) {
        for (auto b : rhs.getChoices()) {
            const int value = this->eval(a, b);
            const bool withinLimit = value <= kMaxValue && value >= -kMaxValue;
            if (withinLimit) results.insert(value);
        }
    }
    return RandomVariableRange(results);
}
168 
// Check if the range contains exactly all values in [min, max].
// An empty set has no min/max and is reported as not continuous rather than dereferencing
// begin()/rbegin(). The span is computed in long long so that widely separated bounds cannot
// overflow int.
static inline bool isContinuous(const std::set<int>* range) {
    if (range->empty()) return false;
    const long long span = static_cast<long long>(*range->rbegin()) - *range->begin() + 1;
    return span == static_cast<long long>(range->size());
}
173 
// Fill the set with a range of values specified by [lower, upper].
// Nothing is inserted when lower > upper. The counter is a long long so that upper == INT_MAX
// terminates instead of overflowing int. Values are produced in ascending order, so end() is
// passed as an insertion hint; the hint affects speed only and any existing content is kept.
static inline void fillRange(std::set<int>* range, int lower, int upper) {
    for (long long v = lower; v <= upper; ++v) {
        range->insert(range->end(), static_cast<int>(v));
    }
}
178 
179 // The slowest algorithm: iterate through every combinations of parents and save the valid pairs.
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const180 void IRandomVariableOp::eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
181                              const std::set<int>* childIn, std::set<int>* parent1Out,
182                              std::set<int>* parent2Out, std::set<int>* childOut) const {
183     // Avoid the binary search if the child is a closed range.
184     bool isChildInContinuous = isContinuous(childIn);
185     std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
186     for (auto i : *parent1In) {
187         bool valid = false;
188         for (auto j : *parent2In) {
189             int res = this->eval(i, j);
190             // Avoid the binary search if obviously out of range.
191             if (res > child.second || res < child.first) continue;
192             if (isChildInContinuous || childIn->find(res) != childIn->end()) {
193                 parent2Out->insert(j);
194                 childOut->insert(res);
195                 valid = true;
196             }
197         }
198         if (valid) parent1Out->insert(i);
199     }
200 }
201 
// A helper template to make a class into a Singleton.
// get() creates one shared, immutable instance of T the first time it is called and returns a
// reference to the same pointer on every call. T must be default-constructible.
template <class T>
class Singleton : public T {
   public:
    static const std::shared_ptr<const T>& get() {
        // make_shared allocates the object and its control block together and avoids a naked new.
        static const std::shared_ptr<const T> instance = std::make_shared<T>();
        return instance;
    }
};
211 
212 // A set of operations that only compute on a single input value.
213 class IUnaryOp : public IRandomVariableOp {
214    public:
215     using IRandomVariableOp::eval;
216     virtual int eval(int val) const = 0;
eval(int lhs,int) const217     virtual int eval(int lhs, int) const override { return eval(lhs); }
218     // The slowest algorithm: iterate through every value of the parent and save the valid one.
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const219     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
220                       const std::set<int>* childIn, std::set<int>* parent1Out,
221                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
222         NN_FUZZER_CHECK(parent2In == nullptr);
223         NN_FUZZER_CHECK(parent2Out == nullptr);
224         bool isChildInContinuous = isContinuous(childIn);
225         std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
226         for (auto i : *parent1In) {
227             int res = this->eval(i);
228             if (res > child.second || res < child.first) continue;
229             if (isChildInContinuous || childIn->find(res) != childIn->end()) {
230                 parent1Out->insert(i);
231                 childOut->insert(res);
232             }
233         }
234     }
235 };
236 
237 // A set of operations that only check conditional constraints.
238 class IConstraintOp : public IRandomVariableOp {
239    public:
240     using IRandomVariableOp::eval;
241     virtual bool check(int lhs, int rhs) const = 0;
eval(int lhs,int rhs) const242     virtual int eval(int lhs, int rhs) const override {
243         return check(lhs, rhs) ? 0 : kInvalidValue;
244     }
245     // The range for a constraint op is always {0}.
getInitRange(const RandomVariableRange &,const RandomVariableRange &) const246     virtual RandomVariableRange getInitRange(const RandomVariableRange&,
247                                              const RandomVariableRange&) const override {
248         return RandomVariableRange(0);
249     }
250     // The slowest algorithm:
251     // iterate through every combinations of parents and save the valid pairs.
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> *,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const252     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
253                       const std::set<int>*, std::set<int>* parent1Out, std::set<int>* parent2Out,
254                       std::set<int>* childOut) const override {
255         for (auto i : *parent1In) {
256             bool valid = false;
257             for (auto j : *parent2In) {
258                 if (this->check(i, j)) {
259                     parent2Out->insert(j);
260                     valid = true;
261                 }
262             }
263             if (valid) parent1Out->insert(i);
264         }
265         if (!parent1Out->empty()) childOut->insert(0);
266     }
267 };
268 
269 class Addition : public IRandomVariableOp {
270    public:
eval(int lhs,int rhs) const271     virtual int eval(int lhs, int rhs) const override { return lhs + rhs; }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const272     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
273                                              const RandomVariableRange& rhs) const override {
274         return RandomVariableRange(lhs.min() + rhs.min(), lhs.max() + rhs.max());
275     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const276     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
277                       const std::set<int>* childIn, std::set<int>* parent1Out,
278                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
279         if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
280             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
281                                     childOut);
282         } else {
283             // For parents and child with close range, the out range can be computed directly
284             // without iterations.
285             std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
286             std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
287             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
288 
289             // From ranges for parent, evalute range for child.
290             // [a, b] + [c, d] -> [a + c, b + d]
291             fillRange(childOut, std::max(child.first, parent1.first + parent2.first),
292                       std::min(child.second, parent1.second + parent2.second));
293 
294             // From ranges for child and one parent, evalute range for another parent.
295             // [a, b] - [c, d] -> [a - d, b - c]
296             fillRange(parent1Out, std::max(parent1.first, child.first - parent2.second),
297                       std::min(parent1.second, child.second - parent2.first));
298             fillRange(parent2Out, std::max(parent2.first, child.first - parent1.second),
299                       std::min(parent2.second, child.second - parent1.first));
300         }
301     }
getName() const302     virtual const char* getName() const override { return "ADD"; }
303 };
304 
305 class Subtraction : public IRandomVariableOp {
306    public:
eval(int lhs,int rhs) const307     virtual int eval(int lhs, int rhs) const override { return lhs - rhs; }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const308     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
309                                              const RandomVariableRange& rhs) const override {
310         return RandomVariableRange(lhs.min() - rhs.max(), lhs.max() - rhs.min());
311     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const312     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
313                       const std::set<int>* childIn, std::set<int>* parent1Out,
314                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
315         if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
316             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
317                                     childOut);
318         } else {
319             // Similar algorithm as Addition.
320             std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
321             std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
322             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
323             fillRange(childOut, std::max(child.first, parent1.first - parent2.second),
324                       std::min(child.second, parent1.second - parent2.first));
325             fillRange(parent1Out, std::max(parent1.first, child.first + parent2.first),
326                       std::min(parent1.second, child.second + parent2.second));
327             fillRange(parent2Out, std::max(parent2.first, parent1.first - child.second),
328                       std::min(parent2.second, parent1.second - child.first));
329         }
330     }
getName() const331     virtual const char* getName() const override { return "SUB"; }
332 };
333 
334 class Multiplication : public IRandomVariableOp {
335    public:
eval(int lhs,int rhs) const336     virtual int eval(int lhs, int rhs) const override { return lhs * rhs; }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const337     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
338                                              const RandomVariableRange& rhs) const override {
339         if (lhs.min() < 0 || rhs.min() < 0) {
340             return IRandomVariableOp::getInitRange(lhs, rhs);
341         } else {
342             int lower = std::min(lhs.min() * rhs.min(), kMaxValue);
343             int upper = std::min(lhs.max() * rhs.max(), kMaxValue);
344             return RandomVariableRange(lower, upper);
345         }
346     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const347     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
348                       const std::set<int>* childIn, std::set<int>* parent1Out,
349                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
350         if (*parent1In->begin() < 0 || *parent2In->begin() < 0 || *childIn->begin() < 0) {
351             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
352                                     childOut);
353         } else {
354             bool isChildInContinuous = isContinuous(childIn);
355             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
356             for (auto i : *parent1In) {
357                 bool valid = false;
358                 for (auto j : *parent2In) {
359                     int res = this->eval(i, j);
360                     // Since MUL increases monotonically with one value, break the loop if the
361                     // result is larger than the limit.
362                     if (res > child.second) break;
363                     if (res < child.first) continue;
364                     if (isChildInContinuous || childIn->find(res) != childIn->end()) {
365                         valid = true;
366                         parent2Out->insert(j);
367                         childOut->insert(res);
368                     }
369                 }
370                 if (valid) parent1Out->insert(i);
371             }
372         }
373     }
getName() const374     virtual const char* getName() const override { return "MUL"; }
375 };
376 
377 class Division : public IRandomVariableOp {
378    public:
eval(int lhs,int rhs) const379     virtual int eval(int lhs, int rhs) const override {
380         return rhs == 0 ? kInvalidValue : lhs / rhs;
381     }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const382     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
383                                              const RandomVariableRange& rhs) const override {
384         if (lhs.min() < 0 || rhs.min() <= 0) {
385             return IRandomVariableOp::getInitRange(lhs, rhs);
386         } else {
387             return RandomVariableRange(lhs.min() / rhs.max(), lhs.max() / rhs.min());
388         }
389     }
getName() const390     virtual const char* getName() const override { return "DIV"; }
391 };
392 
393 class ExactDivision : public Division {
394    public:
eval(int lhs,int rhs) const395     virtual int eval(int lhs, int rhs) const override {
396         return (rhs == 0 || lhs % rhs != 0) ? kInvalidValue : lhs / rhs;
397     }
getName() const398     virtual const char* getName() const override { return "EXACT_DIV"; }
399 };
400 
401 class Modulo : public IRandomVariableOp {
402    public:
eval(int lhs,int rhs) const403     virtual int eval(int lhs, int rhs) const override {
404         return rhs == 0 ? kInvalidValue : lhs % rhs;
405     }
getInitRange(const RandomVariableRange &,const RandomVariableRange & rhs) const406     virtual RandomVariableRange getInitRange(const RandomVariableRange&,
407                                              const RandomVariableRange& rhs) const override {
408         return RandomVariableRange(0, rhs.max());
409     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const410     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
411                       const std::set<int>* childIn, std::set<int>* parent1Out,
412                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
413         if (*childIn->begin() != 0 || childIn->size() != 1u) {
414             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
415                                     childOut);
416         } else {
417             // For the special case that child is a const 0, it would be faster if the range for
418             // parents are evaluated separately.
419 
420             // Evalute parent1 directly.
421             for (auto i : *parent1In) {
422                 for (auto j : *parent2In) {
423                     if (i % j == 0) {
424                         parent1Out->insert(i);
425                         break;
426                     }
427                 }
428             }
429             // Evalute parent2, see if a multiple of parent2 value can be found in parent1.
430             int parent1Max = *parent1In->rbegin();
431             for (auto i : *parent2In) {
432                 int jMax = parent1Max / i;
433                 for (int j = 1; j <= jMax; j++) {
434                     if (parent1In->find(i * j) != parent1In->end()) {
435                         parent2Out->insert(i);
436                         break;
437                     }
438                 }
439             }
440             if (!parent1Out->empty()) childOut->insert(0);
441         }
442     }
getName() const443     virtual const char* getName() const override { return "MOD"; }
444 };
445 
446 class Maximum : public IRandomVariableOp {
447    public:
eval(int lhs,int rhs) const448     virtual int eval(int lhs, int rhs) const override { return std::max(lhs, rhs); }
getName() const449     virtual const char* getName() const override { return "MAX"; }
450 };
451 
452 class Minimum : public IRandomVariableOp {
453    public:
eval(int lhs,int rhs) const454     virtual int eval(int lhs, int rhs) const override { return std::min(lhs, rhs); }
getName() const455     virtual const char* getName() const override { return "MIN"; }
456 };
457 
458 class Square : public IUnaryOp {
459    public:
eval(int val) const460     virtual int eval(int val) const override { return val * val; }
getName() const461     virtual const char* getName() const override { return "SQUARE"; }
462 };
463 
464 class UnaryEqual : public IUnaryOp {
465    public:
eval(int val) const466     virtual int eval(int val) const override { return val; }
getName() const467     virtual const char* getName() const override { return "UNARY_EQUAL"; }
468 };
469 
470 class Equal : public IConstraintOp {
471    public:
check(int lhs,int rhs) const472     virtual bool check(int lhs, int rhs) const override { return lhs == rhs; }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const473     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
474                       const std::set<int>* childIn, std::set<int>* parent1Out,
475                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
476         NN_FUZZER_CHECK(childIn->size() == 1u && *childIn->begin() == 0);
477         // The intersection of two sets can be found in O(n).
478         std::set_intersection(parent1In->begin(), parent1In->end(), parent2In->begin(),
479                               parent2In->end(), std::inserter(*parent1Out, parent1Out->begin()));
480         *parent2Out = *parent1Out;
481         childOut->insert(0);
482     }
getName() const483     virtual const char* getName() const override { return "EQUAL"; }
484 };
485 
486 class GreaterThan : public IConstraintOp {
487    public:
check(int lhs,int rhs) const488     virtual bool check(int lhs, int rhs) const override { return lhs > rhs; }
getName() const489     virtual const char* getName() const override { return "GREATER_THAN"; }
490 };
491 
492 class GreaterEqual : public IConstraintOp {
493    public:
check(int lhs,int rhs) const494     virtual bool check(int lhs, int rhs) const override { return lhs >= rhs; }
getName() const495     virtual const char* getName() const override { return "GREATER_EQUAL"; }
496 };
497 
498 class FloatMultiplication : public IUnaryOp {
499    public:
FloatMultiplication(float multiplicand)500     FloatMultiplication(float multiplicand) : mMultiplicand(multiplicand) {}
eval(int val) const501     virtual int eval(int val) const override {
502         return static_cast<int>(std::floor(static_cast<float>(val) * mMultiplicand));
503     }
getName() const504     virtual const char* getName() const override { return "MUL_FLOAT"; }
505 
506    private:
507     float mMultiplicand;
508 };
509 
// Arithmetic operators and methods on RandomVariables will create OP RandomVariableNodes.
// Since there must be at most one edge between two RandomVariableNodes, we have to do something
// special when both sides are referring to the same node.
513 
// lhs + rhs. When both operands are the same node, the sum is expressed as node * 2 so that
// only a single edge connects the node to the result.
RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(lhs, 2, Singleton<Multiplication>::get());
    }
    return RandomVariable(lhs, rhs, Singleton<Addition>::get());
}
// lhs - rhs. A node minus itself is the constant 0.
RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(0);
    }
    return RandomVariable(lhs, rhs, Singleton<Subtraction>::get());
}
// lhs * rhs. A node times itself is expressed with the unary Square op, which takes a
// default-constructed RandomVariable in place of the second operand.
RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(lhs, RandomVariable(), Singleton<Square>::get());
    }
    return RandomVariable(lhs, rhs, Singleton<Multiplication>::get());
}
// lhs scaled by the float factor `rhs`. Each call creates its own FloatMultiplication op
// carrying the factor; a default-constructed RandomVariable stands in for the second operand.
RandomVariable operator*(const RandomVariable& lhs, const float& rhs) {
    const RandomVariable noSecondOperand;
    return RandomVariable(lhs, noSecondOperand, std::make_shared<FloatMultiplication>(rhs));
}
// lhs / rhs. A node divided by itself is the constant 1.
RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(1);
    }
    return RandomVariable(lhs, rhs, Singleton<Division>::get());
}
// lhs % rhs. A node modulo itself is the constant 0.
RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(0);
    }
    return RandomVariable(lhs, rhs, Singleton<Modulo>::get());
}
// Maximum of two variables. The maximum of a node with itself is the node, so no op is created.
RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return lhs;
    }
    return RandomVariable(lhs, rhs, Singleton<Maximum>::get());
}
// Minimum of two variables. The minimum of a node with itself is the node, so no op is created.
RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return lhs;
    }
    return RandomVariable(lhs, rhs, Singleton<Minimum>::get());
}
543 
// Division that must leave no remainder. A node divided by itself is the constant 1.
RandomVariable RandomVariable::exactDiv(const RandomVariable& other) {
    if (mVar == other.get()) {
        return RandomVariable(1);
    }
    return RandomVariable(*this, other, Singleton<ExactDivision>::get());
}
548 
// Constrains this variable and `other` to take the same value.
// Returns a default-constructed RandomVariable when the two are already tied together.
// Returns the variable that became a UnaryEqual child when one node is subordinate to the
// other, and a new Equal constraint variable otherwise.
RandomVariable RandomVariable::setEqual(const RandomVariable& other) const {
    RandomVariableNode node1 = mVar, node2 = other.get();
    NN_FUZZER_LOG << "Set equality of var" << node1->index << " and var" << node2->index;

    // Do not setEqual on the same pair twice.
    const auto& unaryEqual = Singleton<UnaryEqual>::get();
    if (node1 == node2 || (node1->op == unaryEqual && node1->parent1 == node2) ||
        (node2->op == unaryEqual && node2->parent1 == node1)) {
        NN_FUZZER_LOG << "Already equal. Return.";
        return RandomVariable();
    }

    // Rewires `child` into an OP node that mirrors `parent` through UnaryEqual and merges the
    // two nodes' subnets. Logging stays in the callers below.
    const auto makeUnaryEqualChild = [](RandomVariableNode parent, RandomVariableNode child) {
        child->type = RandomVariableType::OP;
        child->parent1 = parent;
        child->op = Singleton<UnaryEqual>::get();
        parent->children.push_back(child);
        RandomVariableNetwork::get()->join(parent, child);
    };

    // If possible, always try UnaryEqual first to reduce the search space.
    // UnaryEqual can be used if node B is FREE and is evaluated later than node A.
    if (RandomVariableNetwork::get()->isSubordinate(node1, node2)) {
        NN_FUZZER_LOG << "  Make var" << node2->index << " a child of var" << node1->index;
        makeUnaryEqualChild(node1, node2);
        node1->updateTimestamp();
        return other;
    }
    if (RandomVariableNetwork::get()->isSubordinate(node2, node1)) {
        NN_FUZZER_LOG << "  Make var" << node1->index << " a child of var" << node2->index;
        makeUnaryEqualChild(node2, node1);
        // NOTE(review): the branch above refreshes the timestamp of the parent (node1), while
        // this one refreshes the child (also node1). Confirm whether node2 was intended here.
        node1->updateTimestamp();
        return *this;
    }
    return RandomVariable(*this, other, Singleton<Equal>::get());
}
585 
// Builds a RandomVariable from this variable, other and the GreaterThan operation.
// The two operands must not share the same underlying node.
RandomVariable RandomVariable::setGreaterThan(const RandomVariable& other) const {
    NN_FUZZER_CHECK(mVar != other.get());
    const auto greaterThanOp = Singleton<GreaterThan>::get();
    return RandomVariable(*this, other, greaterThanOp);
}
// Builds a RandomVariable from this variable, other and the GreaterEqual operation.
// When both operands share the same node, this variable is returned unchanged.
RandomVariable RandomVariable::setGreaterEqual(const RandomVariable& other) const {
    if (mVar == other.get()) return *this;
    return RandomVariable(*this, other, Singleton<GreaterEqual>::get());
}
594 
add(const RandomVariableNode & var)595 void DisjointNetwork::add(const RandomVariableNode& var) {
596     // Find the subnet index of the parents and decide the index for var.
597     int ind1 = var->parent1 == nullptr ? -1 : mIndexMap[var->parent1];
598     int ind2 = var->parent2 == nullptr ? -1 : mIndexMap[var->parent2];
599     int ind = join(ind1, ind2);
600     // If no parent, put it into a new subnet component.
601     if (ind == -1) ind = mNextIndex++;
602     NN_FUZZER_LOG << "Add RandomVariable var" << var->index << " to network #" << ind;
603     mIndexMap[var] = ind;
604     mEvalOrderMap[ind].push_back(var);
605 }
606 
join(int ind1,int ind2)607 int DisjointNetwork::join(int ind1, int ind2) {
608     if (ind1 == -1) return ind2;
609     if (ind2 == -1) return ind1;
610     if (ind1 == ind2) return ind1;
611     NN_FUZZER_LOG << "Join network #" << ind1 << " and #" << ind2;
612     auto &order1 = mEvalOrderMap[ind1], &order2 = mEvalOrderMap[ind2];
613     // Append every node in ind2 to the end of ind1
614     for (const auto& var : order2) {
615         order1.push_back(var);
616         mIndexMap[var] = ind1;
617     }
618     // Remove ind2 from mEvalOrderMap.
619     mEvalOrderMap.erase(mEvalOrderMap.find(ind2));
620     return ind1;
621 }
622 
get()623 RandomVariableNetwork* RandomVariableNetwork::get() {
624     static RandomVariableNetwork instance;
625     return &instance;
626 }
627 
initialize(int defaultValue)628 void RandomVariableNetwork::initialize(int defaultValue) {
629     RandomVariableBase::globalIndex = 0;
630     RandomVariable::defaultValue = defaultValue;
631     mIndexMap.clear();
632     mEvalOrderMap.clear();
633     mDimProd.clear();
634     mNextIndex = 0;
635     mGlobalTime = 0;
636     mTimestamp = -1;
637 }
638 
isSubordinate(const RandomVariableNode & node1,const RandomVariableNode & node2)639 bool RandomVariableNetwork::isSubordinate(const RandomVariableNode& node1,
640                                           const RandomVariableNode& node2) {
641     if (node2->type != RandomVariableType::FREE) return false;
642     int ind1 = mIndexMap[node1];
643     // node2 is of a different subnet.
644     if (ind1 != mIndexMap[node2]) return true;
645     for (const auto& node : mEvalOrderMap[ind1]) {
646         if (node == node2) return false;
647         // node2 is of the same subnet but evaluated later than node1.
648         if (node == node1) return true;
649     }
650     NN_FUZZER_CHECK(false) << "Code executed in non-reachable region.";
651     return false;
652 }
653 
654 struct EvalInfo {
655     // The RandomVariableNode that this EvalInfo is associated with.
656     // var->value is the current value during evaluation.
657     RandomVariableNode var;
658 
659     // The RandomVariable value is staged when a valid combination is found.
660     std::set<int> staging;
661 
662     // The staging values are committed after a subnet evaluation.
663     std::set<int> committed;
664 
665     // Keeps track of the latest timestamp that committed is updated.
666     int timestamp;
667 
668     // For evalSubnetWithLocalNetwork.
669     RandomVariableType originalType;
670 
671     // Should only invoke eval on OP RandomVariable.
evalandroid::nn::fuzzing_test::EvalInfo672     bool eval() {
673         NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
674         var->value = var->op->eval(var->parent1->value,
675                                    var->parent2 == nullptr ? 0 : var->parent2->value);
676         if (var->value == kInvalidValue) return false;
677         return committed.find(var->value) != committed.end();
678     }
stageandroid::nn::fuzzing_test::EvalInfo679     void stage() { staging.insert(var->value); }
commitandroid::nn::fuzzing_test::EvalInfo680     void commit() {
681         // Only update committed and timestamp if the range is *indeed* changed.
682         if (staging.size() != committed.size()) {
683             committed = std::move(staging);
684             timestamp = RandomVariableNetwork::get()->getGlobalTime();
685         }
686         staging.clear();
687     }
updateRangeandroid::nn::fuzzing_test::EvalInfo688     void updateRange() {
689         // Only update range and timestamp if the range is *indeed* changed.
690         if (committed.size() != var->range.size()) {
691             var->range = RandomVariableRange(committed);
692             var->timestamp = timestamp;
693         }
694         committed.clear();
695     }
696 
EvalInfoandroid::nn::fuzzing_test::EvalInfo697     EvalInfo(const RandomVariableNode& var)
698         : var(var),
699           committed(var->range.getChoices().begin(), var->range.getChoices().end()),
700           timestamp(var->timestamp) {}
701 };
702 using EvalContext = std::unordered_map<RandomVariableNode, EvalInfo>;
703 
704 // For logging only.
toString(const RandomVariableNode & var,EvalContext * context)705 inline std::string toString(const RandomVariableNode& var, EvalContext* context) {
706     std::stringstream ss;
707     ss << "var" << var->index << " = ";
708     const auto& committed = context->at(var).committed;
709     switch (var->type) {
710         case RandomVariableType::FREE:
711             ss << "FREE ["
712                << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end())) << "]";
713             break;
714         case RandomVariableType::CONST:
715             ss << "CONST " << toString(var->value);
716             break;
717         case RandomVariableType::OP:
718             ss << "var" << var->parent1->index << " " << var->op->getName();
719             if (var->parent2 != nullptr) ss << " var" << var->parent2->index;
720             ss << ", [" << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end()))
721                << "]";
722             break;
723         default:
724             NN_FUZZER_CHECK(false);
725     }
726     ss << ", timestamp = " << context->at(var).timestamp;
727     return ss.str();
728 }
729 
730 // Check if the subnet needs to be re-evaluated by comparing the timestamps.
needEvaluate(const EvaluationOrder & evalOrder,int subnetTime,EvalContext * context=nullptr)731 static inline bool needEvaluate(const EvaluationOrder& evalOrder, int subnetTime,
732                                 EvalContext* context = nullptr) {
733     for (const auto& var : evalOrder) {
734         int timestamp = context == nullptr ? var->timestamp : context->at(var).timestamp;
735         // If we find a node that has been modified since last evaluation, the subnet needs to be
736         // re-evaluated.
737         if (timestamp > subnetTime) return true;
738     }
739     return false;
740 }
741 
742 // Helper function to evaluate the subnet recursively.
743 // Iterate through all combinations of FREE RandomVariables choices.
evalSubnetHelper(const EvaluationOrder & evalOrder,EvalContext * context,size_t i=0)744 static void evalSubnetHelper(const EvaluationOrder& evalOrder, EvalContext* context, size_t i = 0) {
745     if (i == evalOrder.size()) {
746         // Reach the end of the evaluation, find a valid combination.
747         for (auto& var : evalOrder) context->at(var).stage();
748         return;
749     }
750     const auto& var = evalOrder[i];
751     if (var->type == RandomVariableType::FREE) {
752         // For FREE RandomVariable, iterate through all valid choices.
753         for (int val : context->at(var).committed) {
754             var->value = val;
755             evalSubnetHelper(evalOrder, context, i + 1);
756         }
757         return;
758     } else if (var->type == RandomVariableType::OP) {
759         // For OP RandomVariable, evaluate from parents and terminate if the result is invalid.
760         if (!context->at(var).eval()) return;
761     }
762     evalSubnetHelper(evalOrder, context, i + 1);
763 }
764 
765 // Check if the subnet has only one single OP RandomVariable.
isSingleOpSubnet(const EvaluationOrder & evalOrder)766 static inline bool isSingleOpSubnet(const EvaluationOrder& evalOrder) {
767     int numOp = 0;
768     for (const auto& var : evalOrder) {
769         if (var->type == RandomVariableType::OP) numOp++;
770         if (numOp > 1) return false;
771     }
772     return numOp != 0;
773 }
774 
775 // Evaluate with a potentially faster approach provided by IRandomVariableOp.
evalSubnetSingleOpHelper(const EvaluationOrder & evalOrder,EvalContext * context)776 static inline void evalSubnetSingleOpHelper(const EvaluationOrder& evalOrder,
777                                             EvalContext* context) {
778     NN_FUZZER_LOG << "Identified as single op subnet";
779     const auto& var = evalOrder.back();
780     NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
781     var->op->eval(&context->at(var->parent1).committed,
782                   var->parent2 == nullptr ? nullptr : &context->at(var->parent2).committed,
783                   &context->at(var).committed, &context->at(var->parent1).staging,
784                   var->parent2 == nullptr ? nullptr : &context->at(var->parent2).staging,
785                   &context->at(var).staging);
786 }
787 
788 // Check if the number of combinations of FREE RandomVariables exceeds the limit.
getNumCombinations(const EvaluationOrder & evalOrder,EvalContext * context=nullptr)789 static inline uint64_t getNumCombinations(const EvaluationOrder& evalOrder,
790                                           EvalContext* context = nullptr) {
791     constexpr uint64_t kLimit = 1e8;
792     uint64_t numCombinations = 1;
793     for (const auto& var : evalOrder) {
794         if (var->type == RandomVariableType::FREE) {
795             size_t size =
796                     context == nullptr ? var->range.size() : context->at(var).committed.size();
797             numCombinations *= size;
798             // To prevent overflow.
799             if (numCombinations > kLimit) return kLimit;
800         }
801     }
802     return numCombinations;
803 }
804 
805 // Evaluate the subnet recursively. Will return fail if the number of combinations of FREE
806 // RandomVariable exceeds the threshold kMaxNumCombinations.
evalSubnetWithBruteForce(const EvaluationOrder & evalOrder,EvalContext * context)807 static bool evalSubnetWithBruteForce(const EvaluationOrder& evalOrder, EvalContext* context) {
808     constexpr uint64_t kMaxNumCombinations = 1e7;
809     NN_FUZZER_LOG << "Evaluate with brute force";
810     if (isSingleOpSubnet(evalOrder)) {
811         // If the network only have one single OP, dispatch to a faster evaluation.
812         evalSubnetSingleOpHelper(evalOrder, context);
813     } else {
814         if (getNumCombinations(evalOrder, context) > kMaxNumCombinations) {
815             NN_FUZZER_LOG << "Terminate the evaluation because of large search range";
816             std::cout << "[          ]   Terminate the evaluation because of large search range"
817                       << std::endl;
818             return false;
819         }
820         evalSubnetHelper(evalOrder, context);
821     }
822     for (auto& var : evalOrder) {
823         if (context->at(var).staging.empty()) {
824             NN_FUZZER_LOG << "Evaluation failed at " << toString(var, context);
825             return false;
826         }
827         context->at(var).commit();
828     }
829     return true;
830 }
831 
832 struct LocalNetwork {
833     EvaluationOrder evalOrder;
834     std::vector<RandomVariableNode> bridgeNodes;
835     int timestamp = 0;
836 
evalandroid::nn::fuzzing_test::LocalNetwork837     bool eval(EvalContext* context) {
838         NN_FUZZER_LOG << "Evaluate local network with timestamp = " << timestamp;
839         // Temporarily treat bridge nodes as FREE RandomVariables.
840         for (const auto& var : bridgeNodes) {
841             context->at(var).originalType = var->type;
842             var->type = RandomVariableType::FREE;
843         }
844         for (const auto& var : evalOrder) {
845             context->at(var).staging.clear();
846             NN_FUZZER_LOG << "  - " << toString(var, context);
847         }
848         bool success = evalSubnetWithBruteForce(evalOrder, context);
849         // Reset the RandomVariable types for bridge nodes.
850         for (const auto& var : bridgeNodes) var->type = context->at(var).originalType;
851         return success;
852     }
853 };
854 
855 // Partition the network further into LocalNetworks based on the result from bridge annotation
856 // algorithm.
857 class GraphPartitioner : public DisjointNetwork {
858    public:
859     GraphPartitioner() = default;
860 
partition(const EvaluationOrder & evalOrder,int timestamp)861     std::vector<LocalNetwork> partition(const EvaluationOrder& evalOrder, int timestamp) {
862         annotateBridge(evalOrder);
863         for (const auto& var : evalOrder) add(var);
864         return get(timestamp);
865     }
866 
867    private:
868     GraphPartitioner(const GraphPartitioner&) = delete;
869     GraphPartitioner& operator=(const GraphPartitioner&) = delete;
870 
871     // Find the parent-child relationship between var1 and var2, and reset the bridge.
setBridgeFlag(const RandomVariableNode & var1,const RandomVariableNode & var2)872     void setBridgeFlag(const RandomVariableNode& var1, const RandomVariableNode& var2) {
873         if (var1->parent1 == var2) {
874             mBridgeInfo[var1].isParent1Bridge = true;
875         } else if (var1->parent2 == var2) {
876             mBridgeInfo[var1].isParent2Bridge = true;
877         } else {
878             setBridgeFlag(var2, var1);
879         }
880     }
881 
882     // Annoate the bridges with DFS -- an edge [u, v] is a bridge if none of u's ancestor is
883     // reachable from a node in the subtree of b. The complexity is O(V + E).
884     // discoveryTime: The timestamp a node is visited
885     // lowTime: The min discovery time of all reachable nodes from the subtree of the node.
annotateBridgeHelper(const RandomVariableNode & var,int * time)886     void annotateBridgeHelper(const RandomVariableNode& var, int* time) {
887         mBridgeInfo[var].visited = true;
888         mBridgeInfo[var].discoveryTime = mBridgeInfo[var].lowTime = (*time)++;
889 
890         // The algorithm operates on undirected graph. First find all adjacent nodes.
891         auto adj = var->children;
892         if (var->parent1 != nullptr) adj.push_back(var->parent1);
893         if (var->parent2 != nullptr) adj.push_back(var->parent2);
894 
895         for (const auto& child : adj) {
896             if (mBridgeInfo.find(child) == mBridgeInfo.end()) continue;
897             if (!mBridgeInfo[child].visited) {
898                 mBridgeInfo[child].parent = var;
899                 annotateBridgeHelper(child, time);
900 
901                 // If none of nodes in the subtree of child is connected to any ancestors of var,
902                 // then it is a bridge.
903                 mBridgeInfo[var].lowTime =
904                         std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].lowTime);
905                 if (mBridgeInfo[child].lowTime > mBridgeInfo[var].discoveryTime)
906                     setBridgeFlag(var, child);
907             } else if (mBridgeInfo[var].parent != child) {
908                 mBridgeInfo[var].lowTime =
909                         std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].discoveryTime);
910             }
911         }
912     }
913 
914     // Find all bridges in the subnet with DFS.
annotateBridge(const EvaluationOrder & evalOrder)915     void annotateBridge(const EvaluationOrder& evalOrder) {
916         for (const auto& var : evalOrder) mBridgeInfo[var];
917         int time = 0;
918         for (const auto& var : evalOrder) {
919             if (!mBridgeInfo[var].visited) annotateBridgeHelper(var, &time);
920         }
921     }
922 
923     // Re-partition the network by treating bridges as no edge.
add(const RandomVariableNode & var)924     void add(const RandomVariableNode& var) {
925         auto parent1 = var->parent1;
926         auto parent2 = var->parent2;
927         if (mBridgeInfo[var].isParent1Bridge) var->parent1 = nullptr;
928         if (mBridgeInfo[var].isParent2Bridge) var->parent2 = nullptr;
929         DisjointNetwork::add(var);
930         var->parent1 = parent1;
931         var->parent2 = parent2;
932     }
933 
934     // Add bridge nodes to the local network and remove single node subnet.
get(int timestamp)935     std::vector<LocalNetwork> get(int timestamp) {
936         std::vector<LocalNetwork> res;
937         for (auto& pair : mEvalOrderMap) {
938             // We do not need to evaluate subnet with only a single node.
939             if (pair.second.size() == 1 && pair.second[0]->parent1 == nullptr) continue;
940             res.emplace_back();
941             for (const auto& var : pair.second) {
942                 if (mBridgeInfo[var].isParent1Bridge) {
943                     res.back().evalOrder.push_back(var->parent1);
944                     res.back().bridgeNodes.push_back(var->parent1);
945                 }
946                 if (mBridgeInfo[var].isParent2Bridge) {
947                     res.back().evalOrder.push_back(var->parent2);
948                     res.back().bridgeNodes.push_back(var->parent2);
949                 }
950                 res.back().evalOrder.push_back(var);
951             }
952             res.back().timestamp = timestamp;
953         }
954         return res;
955     }
956 
957     // For bridge discovery algorithm.
958     struct BridgeInfo {
959         bool isParent1Bridge = false;
960         bool isParent2Bridge = false;
961         int discoveryTime = 0;
962         int lowTime = 0;
963         bool visited = false;
964         std::shared_ptr<RandomVariableBase> parent = nullptr;
965     };
966     std::unordered_map<RandomVariableNode, BridgeInfo> mBridgeInfo;
967 };
968 
969 // Evaluate subnets repeatedly until converge.
970 // Class T_Subnet must have member evalOrder, timestamp, and member function eval.
971 template <class T_Subnet>
evalSubnetsRepeatedly(std::vector<T_Subnet> * subnets,EvalContext * context)972 inline bool evalSubnetsRepeatedly(std::vector<T_Subnet>* subnets, EvalContext* context) {
973     bool terminate = false;
974     while (!terminate) {
975         terminate = true;
976         for (auto& subnet : *subnets) {
977             if (needEvaluate(subnet.evalOrder, subnet.timestamp, context)) {
978                 if (!subnet.eval(context)) return false;
979                 subnet.timestamp = RandomVariableNetwork::get()->getGlobalTime();
980                 terminate = false;
981             }
982         }
983     }
984     return true;
985 }
986 
987 // Evaluate the subnet by first partitioning it further into LocalNetworks.
evalSubnetWithLocalNetwork(const EvaluationOrder & evalOrder,int timestamp,EvalContext * context)988 static bool evalSubnetWithLocalNetwork(const EvaluationOrder& evalOrder, int timestamp,
989                                        EvalContext* context) {
990     NN_FUZZER_LOG << "Evaluate with local network";
991     auto localNetworks = GraphPartitioner().partition(evalOrder, timestamp);
992     return evalSubnetsRepeatedly(&localNetworks, context);
993 }
994 
995 struct LeafNetwork {
996     EvaluationOrder evalOrder;
997     int timestamp = 0;
LeafNetworkandroid::nn::fuzzing_test::LeafNetwork998     LeafNetwork(const RandomVariableNode& var, int timestamp) : timestamp(timestamp) {
999         std::set<RandomVariableNode> visited;
1000         constructorHelper(var, &visited);
1001     }
1002     // Construct the leaf network by recursively including parent nodes.
constructorHelperandroid::nn::fuzzing_test::LeafNetwork1003     void constructorHelper(const RandomVariableNode& var, std::set<RandomVariableNode>* visited) {
1004         if (var == nullptr || visited->find(var) != visited->end()) return;
1005         constructorHelper(var->parent1, visited);
1006         constructorHelper(var->parent2, visited);
1007         visited->insert(var);
1008         evalOrder.push_back(var);
1009     }
evalandroid::nn::fuzzing_test::LeafNetwork1010     bool eval(EvalContext* context) {
1011         return evalSubnetWithLocalNetwork(evalOrder, timestamp, context);
1012     }
1013 };
1014 
1015 // Evaluate the subnet by leaf network.
1016 // NOTE: This algorithm will only produce correct result for *most* of the time (> 99%).
1017 //       The random graph generator is expected to retry if it fails.
evalSubnetWithLeafNetwork(const EvaluationOrder & evalOrder,int timestamp,EvalContext * context)1018 static bool evalSubnetWithLeafNetwork(const EvaluationOrder& evalOrder, int timestamp,
1019                                       EvalContext* context) {
1020     NN_FUZZER_LOG << "Evaluate with leaf network";
1021     // Construct leaf networks.
1022     std::vector<LeafNetwork> leafNetworks;
1023     for (const auto& var : evalOrder) {
1024         if (var->children.empty()) {
1025             NN_FUZZER_LOG << "Found leaf " << toString(var, context);
1026             leafNetworks.emplace_back(var, timestamp);
1027         }
1028     }
1029     return evalSubnetsRepeatedly(&leafNetworks, context);
1030 }
1031 
addDimensionProd(const std::vector<RandomVariable> & dims)1032 void RandomVariableNetwork::addDimensionProd(const std::vector<RandomVariable>& dims) {
1033     if (dims.size() <= 1) return;
1034     EvaluationOrder order;
1035     for (const auto& dim : dims) order.push_back(dim.get());
1036     mDimProd.push_back(order);
1037 }
1038 
enforceDimProd(const std::vector<EvaluationOrder> & mDimProd,const std::unordered_map<RandomVariableNode,int> & indexMap,EvalContext * context,std::unordered_set<int> * dirtySubnets)1039 bool enforceDimProd(const std::vector<EvaluationOrder>& mDimProd,
1040                     const std::unordered_map<RandomVariableNode, int>& indexMap,
1041                     EvalContext* context, std::unordered_set<int>* dirtySubnets) {
1042     for (auto& evalOrder : mDimProd) {
1043         NN_FUZZER_LOG << "  Dimension product network size = " << evalOrder.size();
1044         // Initialize EvalInfo of each RandomVariable.
1045         for (auto& var : evalOrder) {
1046             if (context->find(var) == context->end()) context->emplace(var, var);
1047             NN_FUZZER_LOG << "  - " << toString(var, context);
1048         }
1049 
1050         // Enforce the product of the dimension values below kMaxValue:
1051         // max(dimA) = kMaxValue / (min(dimB) * min(dimC) * ...)
1052         int prod = 1;
1053         for (const auto& var : evalOrder) prod *= (*context->at(var).committed.begin());
1054         for (auto& var : evalOrder) {
1055             auto& committed = context->at(var).committed;
1056             int maxValue = kMaxValue / (prod / *committed.begin());
1057             auto it = committed.upper_bound(maxValue);
1058             // var has empty range -> no solution.
1059             if (it == committed.begin()) return false;
1060             // The range is not modified -> continue.
1061             if (it == committed.end()) continue;
1062             // The range is modified -> the subnet of var is dirty, i.e. needs re-evaluation.
1063             committed.erase(it, committed.end());
1064             context->at(var).timestamp = RandomVariableNetwork::get()->getGlobalTime();
1065             dirtySubnets->insert(indexMap.at(var));
1066         }
1067     }
1068     return true;
1069 }
1070 
evalRange()1071 bool RandomVariableNetwork::evalRange() {
1072     constexpr uint64_t kMaxNumCombinationsWithBruteForce = 500;
1073     constexpr uint64_t kMaxNumCombinationsWithLocalNetwork = 1e5;
1074     NN_FUZZER_LOG << "Evaluate on " << mEvalOrderMap.size() << " sub-networks";
1075     EvalContext context;
1076     std::unordered_set<int> dirtySubnets;  // Which subnets needs evaluation.
1077     for (auto& pair : mEvalOrderMap) {
1078         const auto& evalOrder = pair.second;
1079         // Decide whether needs evaluation by timestamp -- if no range has changed after the last
1080         // evaluation, then the subnet does not need re-evaluation.
1081         if (evalOrder.size() == 1 || !needEvaluate(evalOrder, mTimestamp)) continue;
1082         dirtySubnets.insert(pair.first);
1083     }
1084     if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
1085 
1086     // Repeat until the ranges converge.
1087     while (!dirtySubnets.empty()) {
1088         for (int ind : dirtySubnets) {
1089             const auto& evalOrder = mEvalOrderMap[ind];
1090             NN_FUZZER_LOG << "  Sub-network #" << ind << " size = " << evalOrder.size();
1091 
1092             // Initialize EvalInfo of each RandomVariable.
1093             for (auto& var : evalOrder) {
1094                 if (context.find(var) == context.end()) context.emplace(var, var);
1095                 NN_FUZZER_LOG << "  - " << toString(var, &context);
1096             }
1097 
1098             // Dispatch to different algorithm according to search range.
1099             bool success;
1100             uint64_t numCombinations = getNumCombinations(evalOrder);
1101             if (numCombinations <= kMaxNumCombinationsWithBruteForce) {
1102                 success = evalSubnetWithBruteForce(evalOrder, &context);
1103             } else if (numCombinations <= kMaxNumCombinationsWithLocalNetwork) {
1104                 success = evalSubnetWithLocalNetwork(evalOrder, mTimestamp, &context);
1105             } else {
1106                 success = evalSubnetWithLeafNetwork(evalOrder, mTimestamp, &context);
1107             }
1108             if (!success) return false;
1109         }
1110         dirtySubnets.clear();
1111         if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
1112     }
1113     // A successful evaluation, update RandomVariables from EvalContext.
1114     for (auto& pair : context) pair.second.updateRange();
1115     mTimestamp = getGlobalTime();
1116     NN_FUZZER_LOG << "Finish range evaluation";
1117     return true;
1118 }
1119 
unsetEqual(const RandomVariableNode & node)1120 static void unsetEqual(const RandomVariableNode& node) {
1121     if (node == nullptr) return;
1122     NN_FUZZER_LOG << "Unset equality of var" << node->index;
1123     RandomVariableNode parent1 = node->parent1, parent2 = node->parent2;
1124     parent1->children.erase(std::find(parent1->children.begin(), parent1->children.end(), node));
1125     node->parent1 = nullptr;
1126     if (parent2 != nullptr) {
1127         // For Equal.
1128         parent2->children.erase(
1129                 std::find(parent2->children.begin(), parent2->children.end(), node));
1130         node->parent2 = nullptr;
1131     } else {
1132         // For UnaryEqual.
1133         node->type = RandomVariableType::FREE;
1134         node->op = nullptr;
1135     }
1136 }
1137 
1138 // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
1139 // constructed. Only used when setEqualIfCompatible results in incompatible.
1140 class RandomVariableNetwork::Reverter {
1141    public:
1142     // Take a snapshot of RandomVariableNetwork when Reverter is constructed.
Reverter()1143     Reverter() : mSnapshot(*RandomVariableNetwork::get()) {}
1144     // Add constraint (Equal) nodes to the reverter.
addNode(const RandomVariableNode & node)1145     void addNode(const RandomVariableNode& node) { mEqualNodes.push_back(node); }
revert()1146     void revert() {
1147         NN_FUZZER_LOG << "Revert RandomVariableNetwork";
1148         // Release the constraints.
1149         for (const auto& node : mEqualNodes) unsetEqual(node);
1150         // Reset all member variables.
1151         *RandomVariableNetwork::get() = std::move(mSnapshot);
1152     }
1153 
1154    private:
1155     Reverter(const Reverter&) = delete;
1156     Reverter& operator=(const Reverter&) = delete;
1157     RandomVariableNetwork mSnapshot;
1158     std::vector<RandomVariableNode> mEqualNodes;
1159 };
1160 
setEqualIfCompatible(const std::vector<RandomVariable> & lhs,const std::vector<RandomVariable> & rhs)1161 bool RandomVariableNetwork::setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
1162                                                  const std::vector<RandomVariable>& rhs) {
1163     NN_FUZZER_LOG << "Check compatibility of {" << joinStr(", ", lhs) << "} and {"
1164                   << joinStr(", ", rhs) << "}";
1165     if (lhs.size() != rhs.size()) return false;
1166     Reverter reverter;
1167     bool result = true;
1168     for (size_t i = 0; i < lhs.size(); i++) {
1169         auto node = lhs[i].setEqual(rhs[i]).get();
1170         reverter.addNode(node);
1171         // Early terminate if there is no common choice between two ranges.
1172         if (node != nullptr && node->range.empty()) result = false;
1173     }
1174     result = result && evalRange();
1175     if (!result) reverter.revert();
1176     NN_FUZZER_LOG << "setEqualIfCompatible: " << (result ? "[COMPATIBLE]" : "[INCOMPATIBLE]");
1177     return result;
1178 }
1179 
freeze()1180 bool RandomVariableNetwork::freeze() {
1181     NN_FUZZER_LOG << "Freeze the random network";
1182     if (!evalRange()) return false;
1183     for (const auto& pair : mEvalOrderMap) {
1184         // Find all FREE RandomVariables in the subnet.
1185         std::vector<RandomVariableNode> nodes;
1186         for (const auto& var : pair.second) {
1187             if (var->type == RandomVariableType::FREE) nodes.push_back(var);
1188         }
1189         // Randomly shuffle the order, this is for a more uniform randomness.
1190         randomShuffle(&nodes);
1191         // An inefficient algorithm that does freeze -> re-evaluate for every FREE RandomVariable.
1192         // TODO: Might be able to optimize this.
1193         for (const auto& var : nodes) {
1194             size_t size = var->range.size();
1195             NN_FUZZER_LOG << "Freeze " << toString(var);
1196             var->freeze();
1197             NN_FUZZER_LOG << "  " << toString(var);
1198             // There is no need to re-evaluate if the FREE RandomVariable have only one choice.
1199             if (size > 1) {
1200                 var->updateTimestamp();
1201                 if (!evalRange()) {
1202                     NN_FUZZER_LOG << "Freeze failed at " << toString(var);
1203                     return false;
1204                 }
1205             }
1206         }
1207     }
1208     NN_FUZZER_LOG << "Finish freezing the random network";
1209     return true;
1210 }
1211 
1212 }  // namespace fuzzing_test
1213 }  // namespace nn
1214 }  // namespace android
1215