1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RandomVariable.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "RandomGraphGeneratorUtils.h"
28 
29 namespace android {
30 namespace nn {
31 namespace fuzzing_test {
32 
// Running counter of created nodes: every RandomVariableBase constructor stores the current
// value as its index and then increments it.
unsigned int RandomVariableBase::globalIndex = 0;
// Upper bound of the [1, defaultValue] range used by RandomVariable(RandomVariableType).
int RandomVariable::defaultValue = 10;
35 
// Creates a CONST node. The range is built from `value` alone and `value` is also stored as the
// node's value. The node takes the next global index and is stamped with the network's current
// global time.
RandomVariableBase::RandomVariableBase(int value)
    : index(globalIndex++),
      type(RandomVariableType::CONST),
      range(value),
      value(value),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
42 
// Creates a FREE node whose range is built from the bounds `lower` and `upper`. The node takes
// the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(int lower, int upper)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(lower, upper),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
48 
// Creates a FREE node whose range is built from an explicit list of `choices`. The node takes
// the next global index and is stamped with the network's current global time.
RandomVariableBase::RandomVariableBase(const std::vector<int>& choices)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(choices),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
54 
// Creates an OP node computing `op` over the parents `lhs` and `rhs`.
// The initial range comes from op->getInitRange(); when `rhs` is null, RandomVariableRange(0) is
// passed in place of the second operand's range. `lhs` is dereferenced unconditionally and so
// must not be null.
RandomVariableBase::RandomVariableBase(const RandomVariableNode& lhs, const RandomVariableNode& rhs,
                                       const std::shared_ptr<const IRandomVariableOp>& op)
    : index(globalIndex++),
      type(RandomVariableType::OP),
      range(op->getInitRange(lhs->range, rhs == nullptr ? RandomVariableRange(0) : rhs->range)),
      op(op),
      parent1(lhs),
      parent2(rhs),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
64 
setRange(int lower,int upper)65 void RandomVariableRange::setRange(int lower, int upper) {
66     // kInvalidValue indicates unlimited bound.
67     auto head = lower == kInvalidValue ? mChoices.begin()
68                                        : std::lower_bound(mChoices.begin(), mChoices.end(), lower);
69     auto tail = upper == kInvalidValue ? mChoices.end()
70                                        : std::upper_bound(mChoices.begin(), mChoices.end(), upper);
71     NN_FUZZER_CHECK(head <= tail) << "Invalid range!";
72     if (head != mChoices.begin() || tail != mChoices.end()) {
73         mChoices = std::vector<int>(head, tail);
74     }
75 }
76 
toConst()77 int RandomVariableRange::toConst() {
78     if (mChoices.size() > 1) mChoices = {getRandomChoice(mChoices)};
79     return mChoices[0];
80 }
81 
// Returns the range holding the choices common to both `lhs` and `rhs`.
// std::set_intersection is used, so both choice lists are presumably kept sorted.
RandomVariableRange operator&(const RandomVariableRange& lhs, const RandomVariableRange& rhs) {
    // An intersection can never hold more elements than the smaller operand, so that is all the
    // scratch space required.
    std::vector<int> result(std::min(lhs.mChoices.size(), rhs.mChoices.size()));
    auto it = std::set_intersection(lhs.mChoices.begin(), lhs.mChoices.end(), rhs.mChoices.begin(),
                                    rhs.mChoices.end(), result.begin());
    // Drop the unused tail of the scratch buffer.
    result.erase(it, result.end());
    return RandomVariableRange(std::move(result));
}
89 
freeze()90 void RandomVariableBase::freeze() {
91     if (type == RandomVariableType::CONST) return;
92     value = range.toConst();
93     type = RandomVariableType::CONST;
94 }
95 
getValue() const96 int RandomVariableBase::getValue() const {
97     switch (type) {
98         case RandomVariableType::CONST:
99             return value;
100         case RandomVariableType::OP:
101             return op->eval(parent1->getValue(), parent2 == nullptr ? 0 : parent2->getValue());
102         default:
103             NN_FUZZER_CHECK(false) << "Invalid type when getting value of var" << index;
104             return 0;
105     }
106 }
107 
updateTimestamp()108 void RandomVariableBase::updateTimestamp() {
109     timestamp = RandomVariableNetwork::get()->getGlobalTime();
110     NN_FUZZER_LOG << "Update timestamp of var" << index << " to " << timestamp;
111 }
112 
// Creates a variable backed by a new CONST node holding `value`, and adds the node to the
// RandomVariableNetwork.
RandomVariable::RandomVariable(int value) : mVar(new RandomVariableBase(value)) {
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a variable backed by a new FREE node with range bounds `lower` and `upper`, and adds
// the node to the RandomVariableNetwork.
RandomVariable::RandomVariable(int lower, int upper) : mVar(new RandomVariableBase(lower, upper)) {
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a variable backed by a new FREE node restricted to the listed `choices`, and adds the
// node to the RandomVariableNetwork.
RandomVariable::RandomVariable(const std::vector<int>& choices)
    : mVar(new RandomVariableBase(choices)) {
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a variable backed by a new FREE node with the default range [1, defaultValue], and adds
// the node to the RandomVariableNetwork. `type` must be RandomVariableType::FREE; note that the
// check only runs after the node has already been allocated.
RandomVariable::RandomVariable(RandomVariableType type)
    : mVar(new RandomVariableBase(1, defaultValue)) {
    NN_FUZZER_CHECK(type == RandomVariableType::FREE);
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
    RandomVariableNetwork::get()->add(mVar);
}
// Creates a variable backed by a new OP node that applies `op` to `lhs` and `rhs`.
// The node is appended to the children of each parent it has and then added to the
// RandomVariableNetwork. `rhs` may wrap a null node, in which case only parent1 is linked.
RandomVariable::RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
                               const std::shared_ptr<const IRandomVariableOp>& op)
    : mVar(new RandomVariableBase(lhs.get(), rhs.get(), op)) {
    // A CONST parent is swapped for a private copy holding the same value. This resolves the
    // fake dependency problem.
    const auto detachIfConst = [](auto& parent) {
        if (parent->type != RandomVariableType::CONST) return;
        parent = RandomVariable(parent->value).get();
    };

    detachIfConst(mVar->parent1);
    const bool hasSecondParent = mVar->parent2 != nullptr;
    if (hasSecondParent) detachIfConst(mVar->parent2);

    mVar->parent1->children.push_back(mVar);
    if (hasSecondParent) mVar->parent2->children.push_back(mVar);

    RandomVariableNetwork::get()->add(mVar);
    NN_FUZZER_LOG << "New RandomVariable " << mVar;
}
147 
setRange(int lower,int upper)148 void RandomVariable::setRange(int lower, int upper) {
149     NN_FUZZER_CHECK(mVar != nullptr) << "setRange() on nullptr";
150     NN_FUZZER_LOG << "Set range [" << lower << ", " << upper << "] on var" << mVar->index;
151     size_t oldSize = mVar->range.size();
152     mVar->range.setRange(lower, upper);
153     // Only update the timestamp if the range is *indeed* narrowed down.
154     if (mVar->range.size() != oldSize) mVar->updateTimestamp();
155 }
156 
// Default initial range of an op: evaluates every (lhs, rhs) pair of choices and collects the
// distinct results that fall within [-kMaxValue, kMaxValue].
RandomVariableRange IRandomVariableOp::getInitRange(const RandomVariableRange& lhs,
                                                    const RandomVariableRange& rhs) const {
    std::set<int> results;
    const auto& rhsChoices = rhs.getChoices();
    for (const int a : lhs.getChoices()) {
        for (const int b : rhsChoices) {
            const int res = eval(a, b);
            if (res >= -kMaxValue && res <= kMaxValue) results.insert(res);
        }
    }
    return RandomVariableRange(results);
}
169 
// Check if the range contains exactly all values in [min, max].
// An empty range has no min or max and is reported as not continuous.
static inline bool isContinuous(const std::set<int>* range) {
    if (range->empty()) return false;
    // Widen before subtracting so that extreme bounds cannot overflow int.
    const long long span =
            static_cast<long long>(*range->rbegin()) - static_cast<long long>(*range->begin()) + 1;
    return span == static_cast<long long>(range->size());
}
174 
// Fill the set with a range of values specified by [lower, upper].
// Nothing is added when lower > upper. Values already present in the set are kept.
static inline void fillRange(std::set<int>* range, int lower, int upper) {
    if (lower > upper) return;
    // Values are generated in ascending order, so each one belongs right after the previous one;
    // passing that position as a hint avoids a full tree search per insertion.
    auto hint = range->lower_bound(lower);
    for (int i = lower;; ++i) {
        hint = range->insert(hint, i);
        ++hint;
        // Test before incrementing so that upper == INT_MAX cannot overflow the counter.
        if (i == upper) break;
    }
}
179 
180 // The slowest algorithm: iterate through every combinations of parents and save the valid pairs.
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const181 void IRandomVariableOp::eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
182                              const std::set<int>* childIn, std::set<int>* parent1Out,
183                              std::set<int>* parent2Out, std::set<int>* childOut) const {
184     // Avoid the binary search if the child is a closed range.
185     bool isChildInContinuous = isContinuous(childIn);
186     std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
187     for (auto i : *parent1In) {
188         bool valid = false;
189         for (auto j : *parent2In) {
190             int res = this->eval(i, j);
191             // Avoid the binary search if obviously out of range.
192             if (res > child.second || res < child.first) continue;
193             if (isChildInContinuous || childIn->find(res) != childIn->end()) {
194                 parent2Out->insert(j);
195                 childOut->insert(res);
196                 valid = true;
197             }
198         }
199         if (valid) parent1Out->insert(i);
200     }
201 }
202 
// A helper template exposing one shared, immutable instance of T.
// The instance is created on the first call to get() and reused by every later call.
template <class T>
class Singleton : public T {
   public:
    static const std::shared_ptr<const T>& get() {
        static const std::shared_ptr<const T> sInstance = std::make_shared<T>();
        return sInstance;
    }
};
212 
213 // A set of operations that only compute on a single input value.
214 class IUnaryOp : public IRandomVariableOp {
215    public:
216     using IRandomVariableOp::eval;
217     virtual int eval(int val) const = 0;
eval(int lhs,int) const218     virtual int eval(int lhs, int) const override { return eval(lhs); }
219     // The slowest algorithm: iterate through every value of the parent and save the valid one.
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const220     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
221                       const std::set<int>* childIn, std::set<int>* parent1Out,
222                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
223         NN_FUZZER_CHECK(parent2In == nullptr);
224         NN_FUZZER_CHECK(parent2Out == nullptr);
225         bool isChildInContinuous = isContinuous(childIn);
226         std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
227         for (auto i : *parent1In) {
228             int res = this->eval(i);
229             if (res > child.second || res < child.first) continue;
230             if (isChildInContinuous || childIn->find(res) != childIn->end()) {
231                 parent1Out->insert(i);
232                 childOut->insert(res);
233             }
234         }
235     }
236 };
237 
238 // A set of operations that only check conditional constraints.
239 class IConstraintOp : public IRandomVariableOp {
240    public:
241     using IRandomVariableOp::eval;
242     virtual bool check(int lhs, int rhs) const = 0;
eval(int lhs,int rhs) const243     virtual int eval(int lhs, int rhs) const override {
244         return check(lhs, rhs) ? 0 : kInvalidValue;
245     }
246     // The range for a constraint op is always {0}.
getInitRange(const RandomVariableRange &,const RandomVariableRange &) const247     virtual RandomVariableRange getInitRange(const RandomVariableRange&,
248                                              const RandomVariableRange&) const override {
249         return RandomVariableRange(0);
250     }
251     // The slowest algorithm:
252     // iterate through every combinations of parents and save the valid pairs.
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> *,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const253     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
254                       const std::set<int>*, std::set<int>* parent1Out, std::set<int>* parent2Out,
255                       std::set<int>* childOut) const override {
256         for (auto i : *parent1In) {
257             bool valid = false;
258             for (auto j : *parent2In) {
259                 if (this->check(i, j)) {
260                     parent2Out->insert(j);
261                     valid = true;
262                 }
263             }
264             if (valid) parent1Out->insert(i);
265         }
266         if (!parent1Out->empty()) childOut->insert(0);
267     }
268 };
269 
270 class Addition : public IRandomVariableOp {
271    public:
eval(int lhs,int rhs) const272     virtual int eval(int lhs, int rhs) const override { return lhs + rhs; }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const273     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
274                                              const RandomVariableRange& rhs) const override {
275         return RandomVariableRange(lhs.min() + rhs.min(), lhs.max() + rhs.max());
276     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const277     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
278                       const std::set<int>* childIn, std::set<int>* parent1Out,
279                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
280         if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
281             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
282                                     childOut);
283         } else {
284             // For parents and child with close range, the out range can be computed directly
285             // without iterations.
286             std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
287             std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
288             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
289 
290             // From ranges for parent, evaluate range for child.
291             // [a, b] + [c, d] -> [a + c, b + d]
292             fillRange(childOut, std::max(child.first, parent1.first + parent2.first),
293                       std::min(child.second, parent1.second + parent2.second));
294 
295             // From ranges for child and one parent, evaluate range for another parent.
296             // [a, b] - [c, d] -> [a - d, b - c]
297             fillRange(parent1Out, std::max(parent1.first, child.first - parent2.second),
298                       std::min(parent1.second, child.second - parent2.first));
299             fillRange(parent2Out, std::max(parent2.first, child.first - parent1.second),
300                       std::min(parent2.second, child.second - parent1.first));
301         }
302     }
getName() const303     virtual const char* getName() const override { return "ADD"; }
304 };
305 
306 class Subtraction : public IRandomVariableOp {
307    public:
eval(int lhs,int rhs) const308     virtual int eval(int lhs, int rhs) const override { return lhs - rhs; }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const309     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
310                                              const RandomVariableRange& rhs) const override {
311         return RandomVariableRange(lhs.min() - rhs.max(), lhs.max() - rhs.min());
312     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const313     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
314                       const std::set<int>* childIn, std::set<int>* parent1Out,
315                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
316         if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
317             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
318                                     childOut);
319         } else {
320             // Similar algorithm as Addition.
321             std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
322             std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
323             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
324             fillRange(childOut, std::max(child.first, parent1.first - parent2.second),
325                       std::min(child.second, parent1.second - parent2.first));
326             fillRange(parent1Out, std::max(parent1.first, child.first + parent2.first),
327                       std::min(parent1.second, child.second + parent2.second));
328             fillRange(parent2Out, std::max(parent2.first, parent1.first - child.second),
329                       std::min(parent2.second, parent1.second - child.first));
330         }
331     }
getName() const332     virtual const char* getName() const override { return "SUB"; }
333 };
334 
335 class Multiplication : public IRandomVariableOp {
336    public:
eval(int lhs,int rhs) const337     virtual int eval(int lhs, int rhs) const override { return lhs * rhs; }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const338     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
339                                              const RandomVariableRange& rhs) const override {
340         if (lhs.min() < 0 || rhs.min() < 0) {
341             return IRandomVariableOp::getInitRange(lhs, rhs);
342         } else {
343             int lower = std::min(lhs.min() * rhs.min(), kMaxValue);
344             int upper = std::min(lhs.max() * rhs.max(), kMaxValue);
345             return RandomVariableRange(lower, upper);
346         }
347     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const348     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
349                       const std::set<int>* childIn, std::set<int>* parent1Out,
350                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
351         if (*parent1In->begin() < 0 || *parent2In->begin() < 0 || *childIn->begin() < 0) {
352             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
353                                     childOut);
354         } else {
355             bool isChildInContinuous = isContinuous(childIn);
356             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
357             for (auto i : *parent1In) {
358                 bool valid = false;
359                 for (auto j : *parent2In) {
360                     int res = this->eval(i, j);
361                     // Since MUL increases monotonically with one value, break the loop if the
362                     // result is larger than the limit.
363                     if (res > child.second) break;
364                     if (res < child.first) continue;
365                     if (isChildInContinuous || childIn->find(res) != childIn->end()) {
366                         valid = true;
367                         parent2Out->insert(j);
368                         childOut->insert(res);
369                     }
370                 }
371                 if (valid) parent1Out->insert(i);
372             }
373         }
374     }
getName() const375     virtual const char* getName() const override { return "MUL"; }
376 };
377 
378 class Division : public IRandomVariableOp {
379    public:
eval(int lhs,int rhs) const380     virtual int eval(int lhs, int rhs) const override {
381         return rhs == 0 ? kInvalidValue : lhs / rhs;
382     }
getInitRange(const RandomVariableRange & lhs,const RandomVariableRange & rhs) const383     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
384                                              const RandomVariableRange& rhs) const override {
385         if (lhs.min() < 0 || rhs.min() <= 0) {
386             return IRandomVariableOp::getInitRange(lhs, rhs);
387         } else {
388             return RandomVariableRange(lhs.min() / rhs.max(), lhs.max() / rhs.min());
389         }
390     }
getName() const391     virtual const char* getName() const override { return "DIV"; }
392 };
393 
394 class ExactDivision : public Division {
395    public:
eval(int lhs,int rhs) const396     virtual int eval(int lhs, int rhs) const override {
397         return (rhs == 0 || lhs % rhs != 0) ? kInvalidValue : lhs / rhs;
398     }
getName() const399     virtual const char* getName() const override { return "EXACT_DIV"; }
400 };
401 
402 class Modulo : public IRandomVariableOp {
403    public:
eval(int lhs,int rhs) const404     virtual int eval(int lhs, int rhs) const override {
405         return rhs == 0 ? kInvalidValue : lhs % rhs;
406     }
getInitRange(const RandomVariableRange &,const RandomVariableRange & rhs) const407     virtual RandomVariableRange getInitRange(const RandomVariableRange&,
408                                              const RandomVariableRange& rhs) const override {
409         return RandomVariableRange(0, rhs.max());
410     }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const411     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
412                       const std::set<int>* childIn, std::set<int>* parent1Out,
413                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
414         if (*childIn->begin() != 0 || childIn->size() != 1u) {
415             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
416                                     childOut);
417         } else {
418             // For the special case that child is a const 0, it would be faster if the range for
419             // parents are evaluated separately.
420 
421             // Evaluate parent1 directly.
422             for (auto i : *parent1In) {
423                 for (auto j : *parent2In) {
424                     if (i % j == 0) {
425                         parent1Out->insert(i);
426                         break;
427                     }
428                 }
429             }
430             // Evaluate parent2, see if a multiple of parent2 value can be found in parent1.
431             int parent1Max = *parent1In->rbegin();
432             for (auto i : *parent2In) {
433                 int jMax = parent1Max / i;
434                 for (int j = 1; j <= jMax; j++) {
435                     if (parent1In->find(i * j) != parent1In->end()) {
436                         parent2Out->insert(i);
437                         break;
438                     }
439                 }
440             }
441             if (!parent1Out->empty()) childOut->insert(0);
442         }
443     }
getName() const444     virtual const char* getName() const override { return "MOD"; }
445 };
446 
447 class Maximum : public IRandomVariableOp {
448    public:
eval(int lhs,int rhs) const449     virtual int eval(int lhs, int rhs) const override { return std::max(lhs, rhs); }
getName() const450     virtual const char* getName() const override { return "MAX"; }
451 };
452 
453 class Minimum : public IRandomVariableOp {
454    public:
eval(int lhs,int rhs) const455     virtual int eval(int lhs, int rhs) const override { return std::min(lhs, rhs); }
getName() const456     virtual const char* getName() const override { return "MIN"; }
457 };
458 
459 class Square : public IUnaryOp {
460    public:
eval(int val) const461     virtual int eval(int val) const override { return val * val; }
getName() const462     virtual const char* getName() const override { return "SQUARE"; }
463 };
464 
465 class UnaryEqual : public IUnaryOp {
466    public:
eval(int val) const467     virtual int eval(int val) const override { return val; }
getName() const468     virtual const char* getName() const override { return "UNARY_EQUAL"; }
469 };
470 
471 class Equal : public IConstraintOp {
472    public:
check(int lhs,int rhs) const473     virtual bool check(int lhs, int rhs) const override { return lhs == rhs; }
eval(const std::set<int> * parent1In,const std::set<int> * parent2In,const std::set<int> * childIn,std::set<int> * parent1Out,std::set<int> * parent2Out,std::set<int> * childOut) const474     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
475                       const std::set<int>* childIn, std::set<int>* parent1Out,
476                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
477         NN_FUZZER_CHECK(childIn->size() == 1u && *childIn->begin() == 0);
478         // The intersection of two sets can be found in O(n).
479         std::set_intersection(parent1In->begin(), parent1In->end(), parent2In->begin(),
480                               parent2In->end(), std::inserter(*parent1Out, parent1Out->begin()));
481         *parent2Out = *parent1Out;
482         childOut->insert(0);
483     }
getName() const484     virtual const char* getName() const override { return "EQUAL"; }
485 };
486 
487 class GreaterThan : public IConstraintOp {
488    public:
check(int lhs,int rhs) const489     virtual bool check(int lhs, int rhs) const override { return lhs > rhs; }
getName() const490     virtual const char* getName() const override { return "GREATER_THAN"; }
491 };
492 
493 class GreaterEqual : public IConstraintOp {
494    public:
check(int lhs,int rhs) const495     virtual bool check(int lhs, int rhs) const override { return lhs >= rhs; }
getName() const496     virtual const char* getName() const override { return "GREATER_EQUAL"; }
497 };
498 
499 class FloatMultiplication : public IUnaryOp {
500    public:
FloatMultiplication(float multiplicand)501     FloatMultiplication(float multiplicand) : mMultiplicand(multiplicand) {}
eval(int val) const502     virtual int eval(int val) const override {
503         return static_cast<int>(std::floor(static_cast<float>(val) * mMultiplicand));
504     }
getName() const505     virtual const char* getName() const override { return "MUL_FLOAT"; }
506 
507    private:
508     float mMultiplicand;
509 };
510 
511 // Arithmetic operators and methods on RandomVariables will create OP RandomVariableNodes.
512 // Since there must be at most one edge between two RandomVariableNodes, we have to do something
// special when both sides are referring to the same node.
514 
// Sum of two variables. A node added to itself is expressed as a multiplication by 2, so that
// the same pair of nodes is never connected twice.
RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(lhs, 2, Singleton<Multiplication>::get());
    }
    return RandomVariable(lhs, rhs, Singleton<Addition>::get());
}
// Difference of two variables. A node subtracted from itself is the constant 0.
RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(0);
    }
    return RandomVariable(lhs, rhs, Singleton<Subtraction>::get());
}
// Product of two variables. A node multiplied by itself is expressed as the unary Square op on
// that node, with an empty RandomVariable as the unused second operand.
RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(lhs, RandomVariable(), Singleton<Square>::get());
    }
    return RandomVariable(lhs, rhs, Singleton<Multiplication>::get());
}
// Product of a variable and a float factor, built as a unary FloatMultiplication op that is
// created for this particular factor.
RandomVariable operator*(const RandomVariable& lhs, const float& rhs) {
    const auto scaleOp = std::make_shared<FloatMultiplication>(rhs);
    return RandomVariable(lhs, RandomVariable(), scaleOp);
}
// Quotient of two variables. A node divided by itself is the constant 1.
RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(1);
    }
    return RandomVariable(lhs, rhs, Singleton<Division>::get());
}
// Remainder of two variables. A node taken modulo itself is the constant 0.
RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) {
        return RandomVariable(0);
    }
    return RandomVariable(lhs, rhs, Singleton<Modulo>::get());
}
// The larger of two variables. When both sides refer to the same node, that variable is returned
// as is and no new node is created.
RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) return lhs;
    return RandomVariable(lhs, rhs, Singleton<Maximum>::get());
}
// The smaller of two variables. When both sides refer to the same node, that variable is
// returned as is and no new node is created.
RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs) {
    if (lhs.get() == rhs.get()) return lhs;
    return RandomVariable(lhs, rhs, Singleton<Minimum>::get());
}
544 
// Quotient of this variable and `other` that is only valid when the division leaves no
// remainder. A node divided by itself is the constant 1.
RandomVariable RandomVariable::exactDiv(const RandomVariable& other) {
    if (mVar == other.get()) {
        return RandomVariable(1);
    }
    return RandomVariable(*this, other, Singleton<ExactDivision>::get());
}
549 
// Constrains this variable and `other` to be equal.
// Returns an empty RandomVariable when the two are already known to be equal, the variable that
// was turned into a UnaryEqual child when one node is subordinate to the other, and otherwise
// a new Equal constraint node over both.
RandomVariable RandomVariable::setEqual(const RandomVariable& other) const {
    RandomVariableNode node1 = mVar;
    RandomVariableNode node2 = other.get();
    NN_FUZZER_LOG << "Set equality of var" << node1->index << " and var" << node2->index;

    const auto& unaryEqual = Singleton<UnaryEqual>::get();

    // True when `child` is already a UnaryEqual child of `parent`.
    const auto isUnaryEqualChildOf = [&unaryEqual](const RandomVariableNode& child,
                                                   const RandomVariableNode& parent) {
        return child->op == unaryEqual && child->parent1 == parent;
    };

    // Do not setEqual on the same pair twice.
    if (node1 == node2 || isUnaryEqualChildOf(node1, node2) ||
        isUnaryEqualChildOf(node2, node1)) {
        NN_FUZZER_LOG << "Already equal. Return.";
        return RandomVariable();
    }

    // Turns `child` into a UnaryEqual OP node under `parent` and joins them in the network.
    const auto linkAsUnaryEqual = [&unaryEqual](const RandomVariableNode& parent,
                                                const RandomVariableNode& child) {
        child->type = RandomVariableType::OP;
        child->parent1 = parent;
        child->op = unaryEqual;
        parent->children.push_back(child);
        RandomVariableNetwork::get()->join(parent, child);
    };

    // If possible, always try UnaryEqual first to reduce the search space.
    // UnaryEqual can be used if node B is FREE and is evaluated later than node A.
    if (RandomVariableNetwork::get()->isSubordinate(node1, node2)) {
        NN_FUZZER_LOG << "  Make var" << node2->index << " a child of var" << node1->index;
        linkAsUnaryEqual(node1, node2);
        node1->updateTimestamp();
        return other;
    }
    if (RandomVariableNetwork::get()->isSubordinate(node2, node1)) {
        NN_FUZZER_LOG << "  Make var" << node1->index << " a child of var" << node2->index;
        linkAsUnaryEqual(node2, node1);
        // NOTE(review): node1 is the child in this branch, yet it is node1's timestamp that is
        // refreshed, just as in the branch above where node1 is the parent. Confirm that this
        // asymmetry is intended.
        node1->updateTimestamp();
        return *this;
    }
    return RandomVariable(*this, other, Singleton<Equal>::get());
}
586 
// Adds the constraint "this > other" and returns the GreaterThan operation node.
// The two operands must be distinct nodes; this is enforced with a fuzzer check.
RandomVariable RandomVariable::setGreaterThan(const RandomVariable& other) const {
    const bool distinctOperands = mVar != other.get();
    NN_FUZZER_CHECK(distinctOperands);
    return RandomVariable(*this, other, Singleton<GreaterThan>::get());
}
// Adds the constraint "this >= other". Comparing a node against itself needs no operation node,
// so this variable is returned unchanged in that case.
RandomVariable RandomVariable::setGreaterEqual(const RandomVariable& other) const {
    if (mVar == other.get()) return *this;
    return RandomVariable(*this, other, Singleton<GreaterEqual>::get());
}
595 
add(const RandomVariableNode & var)596 void DisjointNetwork::add(const RandomVariableNode& var) {
597     // Find the subnet index of the parents and decide the index for var.
598     int ind1 = var->parent1 == nullptr ? -1 : mIndexMap[var->parent1];
599     int ind2 = var->parent2 == nullptr ? -1 : mIndexMap[var->parent2];
600     int ind = join(ind1, ind2);
601     // If no parent, put it into a new subnet component.
602     if (ind == -1) ind = mNextIndex++;
603     NN_FUZZER_LOG << "Add RandomVariable var" << var->index << " to network #" << ind;
604     mIndexMap[var] = ind;
605     mEvalOrderMap[ind].push_back(var);
606 }
607 
join(int ind1,int ind2)608 int DisjointNetwork::join(int ind1, int ind2) {
609     if (ind1 == -1) return ind2;
610     if (ind2 == -1) return ind1;
611     if (ind1 == ind2) return ind1;
612     NN_FUZZER_LOG << "Join network #" << ind1 << " and #" << ind2;
613     auto &order1 = mEvalOrderMap[ind1], &order2 = mEvalOrderMap[ind2];
614     // Append every node in ind2 to the end of ind1
615     for (const auto& var : order2) {
616         order1.push_back(var);
617         mIndexMap[var] = ind1;
618     }
619     // Remove ind2 from mEvalOrderMap.
620     mEvalOrderMap.erase(mEvalOrderMap.find(ind2));
621     return ind1;
622 }
623 
get()624 RandomVariableNetwork* RandomVariableNetwork::get() {
625     static RandomVariableNetwork instance;
626     return &instance;
627 }
628 
initialize(int defaultValue)629 void RandomVariableNetwork::initialize(int defaultValue) {
630     RandomVariableBase::globalIndex = 0;
631     RandomVariable::defaultValue = defaultValue;
632     mIndexMap.clear();
633     mEvalOrderMap.clear();
634     mDimProd.clear();
635     mNextIndex = 0;
636     mGlobalTime = 0;
637     mTimestamp = -1;
638 }
639 
isSubordinate(const RandomVariableNode & node1,const RandomVariableNode & node2)640 bool RandomVariableNetwork::isSubordinate(const RandomVariableNode& node1,
641                                           const RandomVariableNode& node2) {
642     if (node2->type != RandomVariableType::FREE) return false;
643     int ind1 = mIndexMap[node1];
644     // node2 is of a different subnet.
645     if (ind1 != mIndexMap[node2]) return true;
646     for (const auto& node : mEvalOrderMap[ind1]) {
647         if (node == node2) return false;
648         // node2 is of the same subnet but evaluated later than node1.
649         if (node == node1) return true;
650     }
651     NN_FUZZER_CHECK(false) << "Code executed in non-reachable region.";
652     return false;
653 }
654 
655 struct EvalInfo {
656     // The RandomVariableNode that this EvalInfo is associated with.
657     // var->value is the current value during evaluation.
658     RandomVariableNode var;
659 
660     // The RandomVariable value is staged when a valid combination is found.
661     std::set<int> staging;
662 
663     // The staging values are committed after a subnet evaluation.
664     std::set<int> committed;
665 
666     // Keeps track of the latest timestamp that committed is updated.
667     int timestamp;
668 
669     // For evalSubnetWithLocalNetwork.
670     RandomVariableType originalType;
671 
672     // Should only invoke eval on OP RandomVariable.
evalandroid::nn::fuzzing_test::EvalInfo673     bool eval() {
674         NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
675         var->value = var->op->eval(var->parent1->value,
676                                    var->parent2 == nullptr ? 0 : var->parent2->value);
677         if (var->value == kInvalidValue) return false;
678         return committed.find(var->value) != committed.end();
679     }
stageandroid::nn::fuzzing_test::EvalInfo680     void stage() { staging.insert(var->value); }
commitandroid::nn::fuzzing_test::EvalInfo681     void commit() {
682         // Only update committed and timestamp if the range is *indeed* changed.
683         if (staging.size() != committed.size()) {
684             committed = std::move(staging);
685             timestamp = RandomVariableNetwork::get()->getGlobalTime();
686         }
687         staging.clear();
688     }
updateRangeandroid::nn::fuzzing_test::EvalInfo689     void updateRange() {
690         // Only update range and timestamp if the range is *indeed* changed.
691         if (committed.size() != var->range.size()) {
692             var->range = RandomVariableRange(committed);
693             var->timestamp = timestamp;
694         }
695         committed.clear();
696     }
697 
EvalInfoandroid::nn::fuzzing_test::EvalInfo698     EvalInfo(const RandomVariableNode& var)
699         : var(var),
700           committed(var->range.getChoices().begin(), var->range.getChoices().end()),
701           timestamp(var->timestamp) {}
702 };
703 using EvalContext = std::unordered_map<RandomVariableNode, EvalInfo>;
704 
705 // For logging only.
toString(const RandomVariableNode & var,EvalContext * context)706 inline std::string toString(const RandomVariableNode& var, EvalContext* context) {
707     std::stringstream ss;
708     ss << "var" << var->index << " = ";
709     const auto& committed = context->at(var).committed;
710     switch (var->type) {
711         case RandomVariableType::FREE:
712             ss << "FREE ["
713                << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end())) << "]";
714             break;
715         case RandomVariableType::CONST:
716             ss << "CONST " << var->value;
717             break;
718         case RandomVariableType::OP:
719             ss << "var" << var->parent1->index << " " << var->op->getName();
720             if (var->parent2 != nullptr) ss << " var" << var->parent2->index;
721             ss << ", [" << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end()))
722                << "]";
723             break;
724         default:
725             NN_FUZZER_CHECK(false);
726     }
727     ss << ", timestamp = " << context->at(var).timestamp;
728     return ss.str();
729 }
730 
731 // Check if the subnet needs to be re-evaluated by comparing the timestamps.
needEvaluate(const EvaluationOrder & evalOrder,int subnetTime,EvalContext * context=nullptr)732 static inline bool needEvaluate(const EvaluationOrder& evalOrder, int subnetTime,
733                                 EvalContext* context = nullptr) {
734     for (const auto& var : evalOrder) {
735         int timestamp = context == nullptr ? var->timestamp : context->at(var).timestamp;
736         // If we find a node that has been modified since last evaluation, the subnet needs to be
737         // re-evaluated.
738         if (timestamp > subnetTime) return true;
739     }
740     return false;
741 }
742 
743 // Helper function to evaluate the subnet recursively.
744 // Iterate through all combinations of FREE RandomVariables choices.
evalSubnetHelper(const EvaluationOrder & evalOrder,EvalContext * context,size_t i=0)745 static void evalSubnetHelper(const EvaluationOrder& evalOrder, EvalContext* context, size_t i = 0) {
746     if (i == evalOrder.size()) {
747         // Reach the end of the evaluation, find a valid combination.
748         for (auto& var : evalOrder) context->at(var).stage();
749         return;
750     }
751     const auto& var = evalOrder[i];
752     if (var->type == RandomVariableType::FREE) {
753         // For FREE RandomVariable, iterate through all valid choices.
754         for (int val : context->at(var).committed) {
755             var->value = val;
756             evalSubnetHelper(evalOrder, context, i + 1);
757         }
758         return;
759     } else if (var->type == RandomVariableType::OP) {
760         // For OP RandomVariable, evaluate from parents and terminate if the result is invalid.
761         if (!context->at(var).eval()) return;
762     }
763     evalSubnetHelper(evalOrder, context, i + 1);
764 }
765 
766 // Check if the subnet has only one single OP RandomVariable.
isSingleOpSubnet(const EvaluationOrder & evalOrder)767 static inline bool isSingleOpSubnet(const EvaluationOrder& evalOrder) {
768     int numOp = 0;
769     for (const auto& var : evalOrder) {
770         if (var->type == RandomVariableType::OP) numOp++;
771         if (numOp > 1) return false;
772     }
773     return numOp != 0;
774 }
775 
776 // Evaluate with a potentially faster approach provided by IRandomVariableOp.
evalSubnetSingleOpHelper(const EvaluationOrder & evalOrder,EvalContext * context)777 static inline void evalSubnetSingleOpHelper(const EvaluationOrder& evalOrder,
778                                             EvalContext* context) {
779     NN_FUZZER_LOG << "Identified as single op subnet";
780     const auto& var = evalOrder.back();
781     NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
782     var->op->eval(&context->at(var->parent1).committed,
783                   var->parent2 == nullptr ? nullptr : &context->at(var->parent2).committed,
784                   &context->at(var).committed, &context->at(var->parent1).staging,
785                   var->parent2 == nullptr ? nullptr : &context->at(var->parent2).staging,
786                   &context->at(var).staging);
787 }
788 
789 // Check if the number of combinations of FREE RandomVariables exceeds the limit.
getNumCombinations(const EvaluationOrder & evalOrder,EvalContext * context=nullptr)790 static inline uint64_t getNumCombinations(const EvaluationOrder& evalOrder,
791                                           EvalContext* context = nullptr) {
792     constexpr uint64_t kLimit = 1e8;
793     uint64_t numCombinations = 1;
794     for (const auto& var : evalOrder) {
795         if (var->type == RandomVariableType::FREE) {
796             size_t size =
797                     context == nullptr ? var->range.size() : context->at(var).committed.size();
798             numCombinations *= size;
799             // To prevent overflow.
800             if (numCombinations > kLimit) return kLimit;
801         }
802     }
803     return numCombinations;
804 }
805 
806 // Evaluate the subnet recursively. Will return fail if the number of combinations of FREE
807 // RandomVariable exceeds the threshold kMaxNumCombinations.
evalSubnetWithBruteForce(const EvaluationOrder & evalOrder,EvalContext * context)808 static bool evalSubnetWithBruteForce(const EvaluationOrder& evalOrder, EvalContext* context) {
809     constexpr uint64_t kMaxNumCombinations = 1e7;
810     NN_FUZZER_LOG << "Evaluate with brute force";
811     if (isSingleOpSubnet(evalOrder)) {
812         // If the network only have one single OP, dispatch to a faster evaluation.
813         evalSubnetSingleOpHelper(evalOrder, context);
814     } else {
815         if (getNumCombinations(evalOrder, context) > kMaxNumCombinations) {
816             NN_FUZZER_LOG << "Terminate the evaluation because of large search range";
817             std::cout << "[          ]   Terminate the evaluation because of large search range"
818                       << std::endl;
819             return false;
820         }
821         evalSubnetHelper(evalOrder, context);
822     }
823     for (auto& var : evalOrder) {
824         if (context->at(var).staging.empty()) {
825             NN_FUZZER_LOG << "Evaluation failed at " << toString(var, context);
826             return false;
827         }
828         context->at(var).commit();
829     }
830     return true;
831 }
832 
833 struct LocalNetwork {
834     EvaluationOrder evalOrder;
835     std::vector<RandomVariableNode> bridgeNodes;
836     int timestamp = 0;
837 
evalandroid::nn::fuzzing_test::LocalNetwork838     bool eval(EvalContext* context) {
839         NN_FUZZER_LOG << "Evaluate local network with timestamp = " << timestamp;
840         // Temporarily treat bridge nodes as FREE RandomVariables.
841         for (const auto& var : bridgeNodes) {
842             context->at(var).originalType = var->type;
843             var->type = RandomVariableType::FREE;
844         }
845         for (const auto& var : evalOrder) {
846             context->at(var).staging.clear();
847             NN_FUZZER_LOG << "  - " << toString(var, context);
848         }
849         bool success = evalSubnetWithBruteForce(evalOrder, context);
850         // Reset the RandomVariable types for bridge nodes.
851         for (const auto& var : bridgeNodes) var->type = context->at(var).originalType;
852         return success;
853     }
854 };
855 
// Partition the network further into LocalNetworks based on the result from bridge annotation
// algorithm. A bridge is an edge whose removal disconnects the (undirected) graph; cutting the
// subnet at its bridges yields smaller pieces that can be evaluated separately.
class GraphPartitioner : public DisjointNetwork {
   public:
    GraphPartitioner() = default;

    // Partitions evalOrder into LocalNetworks, each stamped with the given timestamp:
    // annotate the bridges, add every node while ignoring bridge edges, then collect the result.
    std::vector<LocalNetwork> partition(const EvaluationOrder& evalOrder, int timestamp) {
        annotateBridge(evalOrder);
        for (const auto& var : evalOrder) add(var);
        return get(timestamp);
    }

   private:
    GraphPartitioner(const GraphPartitioner&) = delete;
    GraphPartitioner& operator=(const GraphPartitioner&) = delete;

    // Find the parent-child relationship between var1 and var2, and mark the edge as a bridge.
    // The flag is stored on the child, for whichever of its parent slots holds the other node.
    // If var2 is not a parent of var1, the roles are swapped and the lookup is retried.
    void setBridgeFlag(const RandomVariableNode& var1, const RandomVariableNode& var2) {
        if (var1->parent1 == var2) {
            mBridgeInfo[var1].isParent1Bridge = true;
        } else if (var1->parent2 == var2) {
            mBridgeInfo[var1].isParent2Bridge = true;
        } else {
            setBridgeFlag(var2, var1);
        }
    }

    // Annotate the bridges with DFS -- an edge [u, v] is a bridge if none of u's ancestor is
    // reachable from a node in the subtree of v. The complexity is O(V + E).
    // discoveryTime: The timestamp a node is visited
    // lowTime: The min discovery time of all reachable nodes from the subtree of the node.
    void annotateBridgeHelper(const RandomVariableNode& var, int* time) {
        mBridgeInfo[var].visited = true;
        mBridgeInfo[var].discoveryTime = mBridgeInfo[var].lowTime = (*time)++;

        // The algorithm operates on undirected graph. First find all adjacent nodes.
        // adj is a copy of the children list, extended with the parents.
        auto adj = var->children;
        if (var->parent1 != nullptr) adj.push_back(var->parent1);
        if (var->parent2 != nullptr) adj.push_back(var->parent2);

        for (const auto& weakChild : adj) {
            auto child = weakChild.lock();
            NN_FUZZER_CHECK(child != nullptr);
            // Skip neighbors that were not registered by annotateBridge, i.e. nodes outside of
            // the evaluation order being partitioned.
            if (mBridgeInfo.find(child) == mBridgeInfo.end()) continue;
            if (!mBridgeInfo[child].visited) {
                // Tree edge: descend into the neighbor first.
                mBridgeInfo[child].parent = var;
                annotateBridgeHelper(child, time);

                // If none of nodes in the subtree of child is connected to any ancestors of var,
                // then it is a bridge.
                mBridgeInfo[var].lowTime =
                        std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].lowTime);
                if (mBridgeInfo[child].lowTime > mBridgeInfo[var].discoveryTime)
                    setBridgeFlag(var, child);
            } else if (mBridgeInfo[var].parent != child) {
                // Back edge to an already visited node other than the DFS parent.
                mBridgeInfo[var].lowTime =
                        std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].discoveryTime);
            }
        }
    }

    // Find all bridges in the subnet with DFS.
    void annotateBridge(const EvaluationOrder& evalOrder) {
        // Create a default BridgeInfo entry for every node of the subnet up front; the helper
        // uses the presence of an entry to tell subnet members from outside nodes.
        for (const auto& var : evalOrder) mBridgeInfo[var];
        int time = 0;
        for (const auto& var : evalOrder) {
            if (!mBridgeInfo[var].visited) annotateBridgeHelper(var, &time);
        }
    }

    // Re-partition the network by treating bridges as no edge.
    // The parent pointers reached through a bridge are temporarily cleared so that
    // DisjointNetwork::add does not join var with that parent's component, then restored.
    void add(const RandomVariableNode& var) {
        auto parent1 = var->parent1;
        auto parent2 = var->parent2;
        if (mBridgeInfo[var].isParent1Bridge) var->parent1 = nullptr;
        if (mBridgeInfo[var].isParent2Bridge) var->parent2 = nullptr;
        DisjointNetwork::add(var);
        var->parent1 = parent1;
        var->parent2 = parent2;
    }

    // Add bridge nodes to the local network and remove single node subnet.
    // Returns one LocalNetwork per remaining component, each carrying the given timestamp.
    std::vector<LocalNetwork> get(int timestamp) {
        std::vector<LocalNetwork> res;
        for (auto& pair : mEvalOrderMap) {
            // We do not need to evaluate subnet with only a single node.
            if (pair.second.size() == 1 && pair.second[0]->parent1 == nullptr) continue;
            res.emplace_back();
            for (const auto& var : pair.second) {
                // A parent reached through a bridge is placed right before its child in the
                // evaluation order and is also recorded as a bridge node.
                if (mBridgeInfo[var].isParent1Bridge) {
                    res.back().evalOrder.push_back(var->parent1);
                    res.back().bridgeNodes.push_back(var->parent1);
                }
                if (mBridgeInfo[var].isParent2Bridge) {
                    res.back().evalOrder.push_back(var->parent2);
                    res.back().bridgeNodes.push_back(var->parent2);
                }
                res.back().evalOrder.push_back(var);
            }
            res.back().timestamp = timestamp;
        }
        return res;
    }

    // For bridge discovery algorithm.
    struct BridgeInfo {
        // Whether the edge to parent1 / parent2 of this node is a bridge.
        bool isParent1Bridge = false;
        bool isParent2Bridge = false;
        // DFS bookkeeping, see annotateBridgeHelper.
        int discoveryTime = 0;
        int lowTime = 0;
        bool visited = false;
        // The node from which the DFS reached this node (not necessarily a graph parent).
        std::shared_ptr<RandomVariableBase> parent = nullptr;
    };
    std::unordered_map<RandomVariableNode, BridgeInfo> mBridgeInfo;
};
971 
972 // Evaluate subnets repeatedly until converge.
973 // Class T_Subnet must have member evalOrder, timestamp, and member function eval.
974 template <class T_Subnet>
evalSubnetsRepeatedly(std::vector<T_Subnet> * subnets,EvalContext * context)975 inline bool evalSubnetsRepeatedly(std::vector<T_Subnet>* subnets, EvalContext* context) {
976     bool terminate = false;
977     while (!terminate) {
978         terminate = true;
979         for (auto& subnet : *subnets) {
980             if (needEvaluate(subnet.evalOrder, subnet.timestamp, context)) {
981                 if (!subnet.eval(context)) return false;
982                 subnet.timestamp = RandomVariableNetwork::get()->getGlobalTime();
983                 terminate = false;
984             }
985         }
986     }
987     return true;
988 }
989 
990 // Evaluate the subnet by first partitioning it further into LocalNetworks.
evalSubnetWithLocalNetwork(const EvaluationOrder & evalOrder,int timestamp,EvalContext * context)991 static bool evalSubnetWithLocalNetwork(const EvaluationOrder& evalOrder, int timestamp,
992                                        EvalContext* context) {
993     NN_FUZZER_LOG << "Evaluate with local network";
994     auto localNetworks = GraphPartitioner().partition(evalOrder, timestamp);
995     return evalSubnetsRepeatedly(&localNetworks, context);
996 }
997 
998 struct LeafNetwork {
999     EvaluationOrder evalOrder;
1000     int timestamp = 0;
LeafNetworkandroid::nn::fuzzing_test::LeafNetwork1001     LeafNetwork(const RandomVariableNode& var, int timestamp) : timestamp(timestamp) {
1002         std::set<RandomVariableNode> visited;
1003         constructorHelper(var, &visited);
1004     }
1005     // Construct the leaf network by recursively including parent nodes.
constructorHelperandroid::nn::fuzzing_test::LeafNetwork1006     void constructorHelper(const RandomVariableNode& var, std::set<RandomVariableNode>* visited) {
1007         if (var == nullptr || visited->find(var) != visited->end()) return;
1008         constructorHelper(var->parent1, visited);
1009         constructorHelper(var->parent2, visited);
1010         visited->insert(var);
1011         evalOrder.push_back(var);
1012     }
evalandroid::nn::fuzzing_test::LeafNetwork1013     bool eval(EvalContext* context) {
1014         return evalSubnetWithLocalNetwork(evalOrder, timestamp, context);
1015     }
1016 };
1017 
1018 // Evaluate the subnet by leaf network.
1019 // NOTE: This algorithm will only produce correct result for *most* of the time (> 99%).
1020 //       The random graph generator is expected to retry if it fails.
evalSubnetWithLeafNetwork(const EvaluationOrder & evalOrder,int timestamp,EvalContext * context)1021 static bool evalSubnetWithLeafNetwork(const EvaluationOrder& evalOrder, int timestamp,
1022                                       EvalContext* context) {
1023     NN_FUZZER_LOG << "Evaluate with leaf network";
1024     // Construct leaf networks.
1025     std::vector<LeafNetwork> leafNetworks;
1026     for (const auto& var : evalOrder) {
1027         if (var->children.empty()) {
1028             NN_FUZZER_LOG << "Found leaf " << toString(var, context);
1029             leafNetworks.emplace_back(var, timestamp);
1030         }
1031     }
1032     return evalSubnetsRepeatedly(&leafNetworks, context);
1033 }
1034 
addDimensionProd(const std::vector<RandomVariable> & dims)1035 void RandomVariableNetwork::addDimensionProd(const std::vector<RandomVariable>& dims) {
1036     if (dims.size() <= 1) return;
1037     EvaluationOrder order;
1038     for (const auto& dim : dims) order.push_back(dim.get());
1039     mDimProd.push_back(order);
1040 }
1041 
enforceDimProd(const std::vector<EvaluationOrder> & mDimProd,const std::unordered_map<RandomVariableNode,int> & indexMap,EvalContext * context,std::set<int> * dirtySubnets)1042 bool enforceDimProd(const std::vector<EvaluationOrder>& mDimProd,
1043                     const std::unordered_map<RandomVariableNode, int>& indexMap,
1044                     EvalContext* context, std::set<int>* dirtySubnets) {
1045     for (auto& evalOrder : mDimProd) {
1046         NN_FUZZER_LOG << "  Dimension product network size = " << evalOrder.size();
1047         // Initialize EvalInfo of each RandomVariable.
1048         for (auto& var : evalOrder) {
1049             if (context->find(var) == context->end()) context->emplace(var, var);
1050             NN_FUZZER_LOG << "  - " << toString(var, context);
1051         }
1052 
1053         // Enforce the product of the dimension values below kMaxValue:
1054         // max(dimA) = kMaxValue / (min(dimB) * min(dimC) * ...)
1055         int prod = 1;
1056         for (const auto& var : evalOrder) prod *= (*context->at(var).committed.begin());
1057         for (auto& var : evalOrder) {
1058             auto& committed = context->at(var).committed;
1059             int maxValue = kMaxValue / (prod / *committed.begin());
1060             auto it = committed.upper_bound(maxValue);
1061             // var has empty range -> no solution.
1062             if (it == committed.begin()) return false;
1063             // The range is not modified -> continue.
1064             if (it == committed.end()) continue;
1065             // The range is modified -> the subnet of var is dirty, i.e. needs re-evaluation.
1066             committed.erase(it, committed.end());
1067             context->at(var).timestamp = RandomVariableNetwork::get()->getGlobalTime();
1068             dirtySubnets->insert(indexMap.at(var));
1069         }
1070     }
1071     return true;
1072 }
1073 
evalRange()1074 bool RandomVariableNetwork::evalRange() {
1075     constexpr uint64_t kMaxNumCombinationsWithBruteForce = 500;
1076     constexpr uint64_t kMaxNumCombinationsWithLocalNetwork = 1e5;
1077     NN_FUZZER_LOG << "Evaluate on " << mEvalOrderMap.size() << " sub-networks";
1078     EvalContext context;
1079     std::set<int> dirtySubnets;  // Which subnets needs evaluation.
1080     for (auto& pair : mEvalOrderMap) {
1081         const auto& evalOrder = pair.second;
1082         // Decide whether needs evaluation by timestamp -- if no range has changed after the last
1083         // evaluation, then the subnet does not need re-evaluation.
1084         if (evalOrder.size() == 1 || !needEvaluate(evalOrder, mTimestamp)) continue;
1085         dirtySubnets.insert(pair.first);
1086     }
1087     if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
1088 
1089     // Repeat until the ranges converge.
1090     while (!dirtySubnets.empty()) {
1091         for (int ind : dirtySubnets) {
1092             const auto& evalOrder = mEvalOrderMap[ind];
1093             NN_FUZZER_LOG << "  Sub-network #" << ind << " size = " << evalOrder.size();
1094 
1095             // Initialize EvalInfo of each RandomVariable.
1096             for (auto& var : evalOrder) {
1097                 if (context.find(var) == context.end()) context.emplace(var, var);
1098                 NN_FUZZER_LOG << "  - " << toString(var, &context);
1099             }
1100 
1101             // Dispatch to different algorithm according to search range.
1102             bool success;
1103             uint64_t numCombinations = getNumCombinations(evalOrder);
1104             if (numCombinations <= kMaxNumCombinationsWithBruteForce) {
1105                 success = evalSubnetWithBruteForce(evalOrder, &context);
1106             } else if (numCombinations <= kMaxNumCombinationsWithLocalNetwork) {
1107                 success = evalSubnetWithLocalNetwork(evalOrder, mTimestamp, &context);
1108             } else {
1109                 success = evalSubnetWithLeafNetwork(evalOrder, mTimestamp, &context);
1110             }
1111             if (!success) return false;
1112         }
1113         dirtySubnets.clear();
1114         if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
1115     }
1116     // A successful evaluation, update RandomVariables from EvalContext.
1117     for (auto& pair : context) pair.second.updateRange();
1118     mTimestamp = getGlobalTime();
1119     NN_FUZZER_LOG << "Finish range evaluation";
1120     return true;
1121 }
1122 
unsetEqual(const RandomVariableNode & node)1123 static void unsetEqual(const RandomVariableNode& node) {
1124     if (node == nullptr) return;
1125     NN_FUZZER_LOG << "Unset equality of var" << node->index;
1126     auto weakPtrEqual = [&node](const std::weak_ptr<RandomVariableBase>& ptr) {
1127         return ptr.lock() == node;
1128     };
1129     RandomVariableNode parent1 = node->parent1, parent2 = node->parent2;
1130     parent1->children.erase(
1131             std::find_if(parent1->children.begin(), parent1->children.end(), weakPtrEqual));
1132     node->parent1 = nullptr;
1133     if (parent2 != nullptr) {
1134         // For Equal.
1135         parent2->children.erase(
1136                 std::find_if(parent2->children.begin(), parent2->children.end(), weakPtrEqual));
1137         node->parent2 = nullptr;
1138     } else {
1139         // For UnaryEqual.
1140         node->type = RandomVariableType::FREE;
1141         node->op = nullptr;
1142     }
1143 }
1144 
1145 // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
1146 // constructed. Only used when setEqualIfCompatible results in incompatible.
1147 class RandomVariableNetwork::Reverter {
1148    public:
1149     // Take a snapshot of RandomVariableNetwork when Reverter is constructed.
Reverter()1150     Reverter() : mSnapshot(*RandomVariableNetwork::get()) {}
1151     // Add constraint (Equal) nodes to the reverter.
addNode(const RandomVariableNode & node)1152     void addNode(const RandomVariableNode& node) { mEqualNodes.push_back(node); }
revert()1153     void revert() {
1154         NN_FUZZER_LOG << "Revert RandomVariableNetwork";
1155         // Release the constraints.
1156         for (const auto& node : mEqualNodes) unsetEqual(node);
1157         // Reset all member variables.
1158         *RandomVariableNetwork::get() = std::move(mSnapshot);
1159     }
1160 
1161    private:
1162     Reverter(const Reverter&) = delete;
1163     Reverter& operator=(const Reverter&) = delete;
1164     RandomVariableNetwork mSnapshot;
1165     std::vector<RandomVariableNode> mEqualNodes;
1166 };
1167 
setEqualIfCompatible(const std::vector<RandomVariable> & lhs,const std::vector<RandomVariable> & rhs)1168 bool RandomVariableNetwork::setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
1169                                                  const std::vector<RandomVariable>& rhs) {
1170     NN_FUZZER_LOG << "Check compatibility of {" << joinStr(", ", lhs) << "} and {"
1171                   << joinStr(", ", rhs) << "}";
1172     if (lhs.size() != rhs.size()) return false;
1173     Reverter reverter;
1174     bool result = true;
1175     for (size_t i = 0; i < lhs.size(); i++) {
1176         auto node = lhs[i].setEqual(rhs[i]).get();
1177         reverter.addNode(node);
1178         // Early terminate if there is no common choice between two ranges.
1179         if (node != nullptr && node->range.empty()) result = false;
1180     }
1181     result = result && evalRange();
1182     if (!result) reverter.revert();
1183     NN_FUZZER_LOG << "setEqualIfCompatible: " << (result ? "[COMPATIBLE]" : "[INCOMPATIBLE]");
1184     return result;
1185 }
1186 
freeze()1187 bool RandomVariableNetwork::freeze() {
1188     NN_FUZZER_LOG << "Freeze the random network";
1189     if (!evalRange()) return false;
1190 
1191     std::vector<RandomVariableNode> nodes;
1192     for (const auto& pair : mEvalOrderMap) {
1193         // Find all FREE RandomVariables in the subnet.
1194         for (const auto& var : pair.second) {
1195             if (var->type == RandomVariableType::FREE) nodes.push_back(var);
1196         }
1197     }
1198 
1199     // Randomly shuffle the order, this is for a more uniform randomness.
1200     randomShuffle(&nodes);
1201 
1202     // An inefficient algorithm that does freeze -> re-evaluate for every FREE RandomVariable.
1203     // TODO: Might be able to optimize this.
1204     for (const auto& var : nodes) {
1205         if (var->type != RandomVariableType::FREE) continue;
1206         size_t size = var->range.size();
1207         NN_FUZZER_LOG << "Freeze " << var;
1208         var->freeze();
1209         NN_FUZZER_LOG << "  " << var;
1210         // There is no need to re-evaluate if the FREE RandomVariable have only one choice.
1211         if (size > 1) {
1212             var->updateTimestamp();
1213             if (!evalRange()) {
1214                 NN_FUZZER_LOG << "Freeze failed at " << var;
1215                 return false;
1216             }
1217         }
1218     }
1219     NN_FUZZER_LOG << "Finish freezing the random network";
1220     return true;
1221 }
1222 
1223 }  // namespace fuzzing_test
1224 }  // namespace nn
1225 }  // namespace android
1226