1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  * IptablesBaseTest.cpp - utility class for tests that use iptables
17  */
18 
19 #include <deque>
20 #include <string>
21 #include <vector>
22 
23 #include <gtest/gtest.h>
24 
25 #include <android-base/stringprintf.h>
26 
27 #include "IptablesBaseTest.h"
28 #include "NetdConstants.h"
29 
30 #define LOG_TAG "IptablesBaseTest"
31 #include <cutils/log.h>
32 
33 using android::base::StringPrintf;
34 
IptablesBaseTest()35 IptablesBaseTest::IptablesBaseTest() {
36     sCmds.clear();
37     sRestoreCmds.clear();
38     sReturnValues.clear();
39 }
40 
fake_android_fork_exec(int argc,char * argv[],int * status,bool,bool)41 int IptablesBaseTest::fake_android_fork_exec(int argc, char* argv[], int *status, bool, bool) {
42     std::string cmd = argv[0];
43     for (int i = 1; i < argc; i++) {
44         if (argv[i] == NULL) break;  // NatController likes to pass in invalid argc values.
45         cmd += " ";
46         cmd += argv[i];
47     }
48     sCmds.push_back(cmd);
49 
50     int ret;
51     if (sReturnValues.size()) {
52         ret = sReturnValues.front();
53         sReturnValues.pop_front();
54     } else {
55         ret = 0;
56     }
57 
58     if (status) {
59         *status = ret;
60     }
61     return ret;
62 }
63 
fakeExecIptables(IptablesTarget target,...)64 int IptablesBaseTest::fakeExecIptables(IptablesTarget target, ...) {
65     std::string cmd = " -w";
66     va_list args;
67     va_start(args, target);
68     const char *arg;
69     do {
70         arg = va_arg(args, const char *);
71         if (arg != nullptr) {
72             cmd += " ";
73             cmd += arg;
74         }
75     } while (arg);
76 
77     if (target == V4 || target == V4V6) {
78         sCmds.push_back(IPTABLES_PATH + cmd);
79     }
80     if (target == V6 || target == V4V6) {
81         sCmds.push_back(IP6TABLES_PATH + cmd);
82     }
83 
84     return 0;
85 }
86 
fake_popen(const char *,const char * type)87 FILE *IptablesBaseTest::fake_popen(const char * /* cmd */, const char *type) {
88     if (sPopenContents.empty() || strcmp(type, "r") != 0) {
89         return NULL;
90     }
91 
92     std::string realCmd = StringPrintf("echo '%s'", sPopenContents.front().c_str());
93     sPopenContents.pop_front();
94     return popen(realCmd.c_str(), "r");
95 }
96 
fakeExecIptablesRestoreWithOutput(IptablesTarget target,const std::string & commands,std::string * output)97 int IptablesBaseTest::fakeExecIptablesRestoreWithOutput(IptablesTarget target,
98                                                         const std::string& commands,
99                                                         std::string *output) {
100     sRestoreCmds.push_back({ target, commands });
101     if (output != nullptr) {
102         *output = sIptablesRestoreOutput.size() ? sIptablesRestoreOutput.front().c_str() : "";
103     }
104     if (sIptablesRestoreOutput.size()) {
105         sIptablesRestoreOutput.pop_front();
106     }
107     return 0;
108 }
109 
fakeExecIptablesRestore(IptablesTarget target,const std::string & commands)110 int IptablesBaseTest::fakeExecIptablesRestore(IptablesTarget target, const std::string& commands) {
111     return fakeExecIptablesRestoreWithOutput(target, commands, nullptr);
112 }
113 
fakeExecIptablesRestoreCommand(IptablesTarget target,const std::string & table,const std::string & command,std::string * output)114 int IptablesBaseTest::fakeExecIptablesRestoreCommand(IptablesTarget target,
115                                                      const std::string& table,
116                                                      const std::string& command,
117                                                      std::string *output) {
118     std::string fullCmd = StringPrintf("-t %s %s", table.c_str(), command.c_str());
119     return fakeExecIptablesRestoreWithOutput(target, fullCmd, output);
120 }
121 
expectIptablesCommand(IptablesTarget target,int pos,const std::string & cmd)122 int IptablesBaseTest::expectIptablesCommand(IptablesTarget target, int pos,
123                                             const std::string& cmd) {
124 
125     if ((unsigned) pos >= sCmds.size()) {
126         ADD_FAILURE() << "Expected too many iptables commands, want command "
127                << pos + 1 << "/" << sCmds.size();
128         return -1;
129     }
130 
131     if (target == V4 || target == V4V6) {
132         EXPECT_EQ("/system/bin/iptables -w " + cmd, sCmds[pos++]);
133     }
134     if (target == V6 || target == V4V6) {
135         EXPECT_EQ("/system/bin/ip6tables -w " + cmd, sCmds[pos++]);
136     }
137 
138     return target == V4V6 ? 2 : 1;
139 }
140 
expectIptablesCommands(const std::vector<std::string> & expectedCmds)141 void IptablesBaseTest::expectIptablesCommands(const std::vector<std::string>& expectedCmds) {
142     ExpectedIptablesCommands expected;
143     for (auto cmd : expectedCmds) {
144         expected.push_back({ V4V6, cmd });
145     }
146     expectIptablesCommands(expected);
147 }
148 
expectIptablesCommands(const ExpectedIptablesCommands & expectedCmds)149 void IptablesBaseTest::expectIptablesCommands(const ExpectedIptablesCommands& expectedCmds) {
150     size_t pos = 0;
151     for (size_t i = 0; i < expectedCmds.size(); i ++) {
152         auto target = expectedCmds[i].first;
153         auto cmd = expectedCmds[i].second;
154         int numConsumed = expectIptablesCommand(target, pos, cmd);
155         if (numConsumed < 0) {
156             // Read past the end of the array.
157             break;
158         }
159         pos += numConsumed;
160     }
161 
162     EXPECT_EQ(pos, sCmds.size());
163     sCmds.clear();
164 }
165 
expectIptablesCommands(const std::vector<ExpectedIptablesCommands> & snippets)166 void IptablesBaseTest::expectIptablesCommands(
167         const std::vector<ExpectedIptablesCommands>& snippets) {
168     ExpectedIptablesCommands expected;
169     for (const auto& snippet: snippets) {
170         expected.insert(expected.end(), snippet.begin(), snippet.end());
171     }
172     expectIptablesCommands(expected);
173 }
174 
expectIptablesRestoreCommands(const std::vector<std::string> & expectedCmds)175 void IptablesBaseTest::expectIptablesRestoreCommands(const std::vector<std::string>& expectedCmds) {
176     ExpectedIptablesCommands expected;
177     for (auto cmd : expectedCmds) {
178         expected.push_back({ V4V6, cmd });
179     }
180     expectIptablesRestoreCommands(expected);
181 }
182 
expectIptablesRestoreCommands(const ExpectedIptablesCommands & expectedCmds)183 void IptablesBaseTest::expectIptablesRestoreCommands(const ExpectedIptablesCommands& expectedCmds) {
184     EXPECT_EQ(expectedCmds.size(), sRestoreCmds.size());
185     for (size_t i = 0; i < expectedCmds.size(); i++) {
186         EXPECT_EQ(expectedCmds[i], sRestoreCmds[i]) <<
187             "iptables-restore command " << i << " differs";
188     }
189     sRestoreCmds.clear();
190 }
191 
setReturnValues(const std::deque<int> & returnValues)192 void IptablesBaseTest::setReturnValues(const std::deque<int>& returnValues) {
193     sReturnValues = returnValues;
194 }
195 
196 std::vector<std::string> IptablesBaseTest::sCmds = {};
197 IptablesBaseTest::ExpectedIptablesCommands IptablesBaseTest::sRestoreCmds = {};
198 std::deque<std::string> IptablesBaseTest::sPopenContents = {};
199 std::deque<std::string> IptablesBaseTest::sIptablesRestoreOutput = {};
200 std::deque<int> IptablesBaseTest::sReturnValues = {};
201