1 // -*- mode: c++ -*-
2 #ifndef ARES_TEST_H
3 #define ARES_TEST_H
4 
5 #include "dns-proto.h"
6 // Include ares internal file for DNS protocol constants
7 #include "nameser.h"
8 
9 #include "ares_setup.h"
10 #include "ares.h"
11 
12 #include "gtest/gtest.h"
13 #include "gmock/gmock.h"
14 
15 #ifdef HAVE_CONFIG_H
16 #include "config.h"
17 #endif
18 #if defined(HAVE_USER_NAMESPACE) && defined(HAVE_UTS_NAMESPACE)
19 #define HAVE_CONTAINER
20 #endif
21 
22 #include <functional>
23 #include <list>
24 #include <map>
25 #include <memory>
26 #include <set>
27 #include <string>
28 #include <utility>
29 #include <vector>
30 
31 namespace ares {
32 
// Raw octet type used for DNS wire-format buffers throughout the tests.
using byte = unsigned char;
34 
35 namespace test {
36 
// Global test settings. These are only declared here (extern); the
// definitions live in another translation unit.
// NOTE(review): presumably populated from command-line flags by the test
// driver — confirm.
extern bool verbose;
extern int mock_port;
// Canned lists of address families to parameterize tests over.
extern const std::vector<int> both_families;
extern const std::vector<int> ipv4_family;
extern const std::vector<int> ipv6_family;

// Canned <family, mode> parameter lists. MockChannelTest consumes such a pair
// as <address family, force_tcp>, so the bool selects TCP-only operation.
extern const std::vector<std::pair<int, bool>> both_families_both_modes;
extern const std::vector<std::pair<int, bool>> ipv4_family_both_modes;
extern const std::vector<std::pair<int, bool>> ipv6_family_both_modes;

// Which parameters to use in tests
extern std::vector<int> families;
extern std::vector<std::pair<int, bool>> families_modes;

// Process all pending work on ares-owned file descriptors, plus
// optionally the given set-of-FDs + work function.
//   get_extrafds:  returns additional FDs to watch alongside the channel's.
//   process_extra: invoked with an extra FD that has activity.
void ProcessWork(ares_channel channel,
                 std::function<std::set<int>()> get_extrafds,
                 std::function<void(int)> process_extra);
// Convenience get_extrafds callback for ProcessWork when there are no extra
// FDs (per its name it supplies an empty set — definition not visible here).
std::set<int> NoExtraFDs();
57 
// Test fixture that ensures library initialization, and allows
// memory allocations to be failed.
//
// The constructor initializes c-ares via ares_library_init_mem() with this
// class's static allocator hooks (amalloc/afree/arealloc). The destructor
// calls ares_library_cleanup() and then drops any still-pending injected
// allocation failures.
class LibraryTest : public ::testing::Test {
 public:
  LibraryTest() {
    // Route every library allocation through the hooks below so tests can
    // make chosen allocations fail.
    EXPECT_EQ(ARES_SUCCESS,
              ares_library_init_mem(ARES_LIB_INIT_ALL,
                                    &LibraryTest::amalloc,
                                    &LibraryTest::afree,
                                    &LibraryTest::arealloc));
  }
  ~LibraryTest() {
    ares_library_cleanup();
    // Failure state is static, so clear it to avoid leaking into later tests.
    ClearFails();
  }
  // Set the n-th malloc call (of any size) from the library to fail.
  // (nth == 1 means the next call)
  static void SetAllocFail(int nth);
  // Set the next malloc call for the given size to fail.
  static void SetAllocSizeFail(size_t size);
  // Remove any pending alloc failures.
  static void ClearFails();

  // Allocator hooks passed to ares_library_init_mem() in the constructor.
  static void *amalloc(size_t size);
  static void* arealloc(void *ptr, size_t size);
  static void afree(void *ptr);
 private:
  // NOTE(review): presumably consults fails_/size_fails_ to decide whether
  // the current allocation should fail — definition not visible here.
  static bool ShouldAllocFail(size_t size);
  // Pending "fail the n-th call" requests (representation defined elsewhere;
  // possibly a bitmask of upcoming calls — confirm).
  static unsigned long long fails_;
  // Pending size-specific failures, keyed by allocation size.
  static std::map<size_t, int> size_fails_;
};
89 
90 // Test fixture that uses a default channel.
91 class DefaultChannelTest : public LibraryTest {
92  public:
DefaultChannelTest()93   DefaultChannelTest() : channel_(nullptr) {
94     EXPECT_EQ(ARES_SUCCESS, ares_init(&channel_));
95     EXPECT_NE(nullptr, channel_);
96   }
97 
~DefaultChannelTest()98   ~DefaultChannelTest() {
99     ares_destroy(channel_);
100     channel_ = nullptr;
101   }
102 
103   // Process all pending work on ares-owned file descriptors.
104   void Process();
105 
106  protected:
107   ares_channel channel_;
108 };
109 
110 // Test fixture that uses a default channel with the specified lookup mode.
111 class DefaultChannelModeTest
112     : public LibraryTest,
113       public ::testing::WithParamInterface<std::string> {
114  public:
DefaultChannelModeTest()115   DefaultChannelModeTest() : channel_(nullptr) {
116     struct ares_options opts = {0};
117     opts.lookups = strdup(GetParam().c_str());
118     int optmask = ARES_OPT_LOOKUPS;
119     EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
120     EXPECT_NE(nullptr, channel_);
121     free(opts.lookups);
122   }
123 
~DefaultChannelModeTest()124   ~DefaultChannelModeTest() {
125     ares_destroy(channel_);
126     channel_ = nullptr;
127   }
128 
129   // Process all pending work on ares-owned file descriptors.
130   void Process();
131 
132  protected:
133   ares_channel channel_;
134 };
135 
136 // Mock DNS server to allow responses to be scripted by tests.
137 class MockServer {
138  public:
139   MockServer(int family, int port, int tcpport = 0);
140   ~MockServer();
141 
142   // Mock method indicating the processing of a particular <name, RRtype>
143   // request.
144   MOCK_METHOD2(OnRequest, void(const std::string& name, int rrtype));
145 
146   // Set the reply to be sent next; the query ID field will be overwritten
147   // with the value from the request.
SetReplyData(const std::vector<byte> & reply)148   void SetReplyData(const std::vector<byte>& reply) { reply_ = reply; }
SetReply(const DNSPacket * reply)149   void SetReply(const DNSPacket* reply) { SetReplyData(reply->data()); }
SetReplyQID(int qid)150   void SetReplyQID(int qid) { qid_ = qid; }
151 
152   // The set of file descriptors that the server handles.
153   std::set<int> fds() const;
154 
155   // Process activity on a file descriptor.
156   void ProcessFD(int fd);
157 
158   // Ports the server is responding to
udpport()159   int udpport() const { return udpport_; }
tcpport()160   int tcpport() const { return tcpport_; }
161 
162  private:
163   void ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
164                       int qid, const std::string& name, int rrtype);
165 
166   int udpport_;
167   int tcpport_;
168   int udpfd_;
169   int tcpfd_;
170   std::set<int> connfds_;
171   std::vector<byte> reply_;
172   int qid_;
173 };
174 
175 // Test fixture that uses a mock DNS server.
176 class MockChannelOptsTest : public LibraryTest {
177  public:
178   MockChannelOptsTest(int count, int family, bool force_tcp, struct ares_options* givenopts, int optmask);
179   ~MockChannelOptsTest();
180 
181   // Process all pending work on ares-owned and mock-server-owned file descriptors.
182   void Process();
183 
184  protected:
185   // NiceMockServer doesn't complain about uninteresting calls.
186   typedef testing::NiceMock<MockServer> NiceMockServer;
187   typedef std::vector< std::unique_ptr<NiceMockServer> > NiceMockServers;
188 
189   std::set<int> fds() const;
190   void ProcessFD(int fd);
191 
192   static NiceMockServers BuildServers(int count, int family, int base_port);
193 
194   NiceMockServers servers_;
195   // Convenience reference to first server.
196   NiceMockServer& server_;
197   ares_channel channel_;
198 };
199 
200 class MockChannelTest
201     : public MockChannelOptsTest,
202       public ::testing::WithParamInterface< std::pair<int, bool> > {
203  public:
MockChannelTest()204   MockChannelTest() : MockChannelOptsTest(1, GetParam().first, GetParam().second, nullptr, 0) {}
205 };
206 
207 class MockUDPChannelTest
208     : public MockChannelOptsTest,
209       public ::testing::WithParamInterface<int> {
210  public:
MockUDPChannelTest()211   MockUDPChannelTest() : MockChannelOptsTest(1, GetParam(), false, nullptr, 0) {}
212 };
213 
214 class MockTCPChannelTest
215     : public MockChannelOptsTest,
216       public ::testing::WithParamInterface<int> {
217  public:
MockTCPChannelTest()218   MockTCPChannelTest() : MockChannelOptsTest(1, GetParam(), true, nullptr, 0) {}
219 };
220 
// gMock action to set the reply for a mock server: forwards `data` (reply
// bytes) to mockserver->SetReplyData().
ACTION_P2(SetReplyData, mockserver, data) {
  mockserver->SetReplyData(data);
}
// gMock action forwarding `reply` (a DNSPacket pointer, per
// MockServer::SetReply) to mockserver->SetReply().
ACTION_P2(SetReply, mockserver, reply) {
  mockserver->SetReply(reply);
}
// gMock action forwarding `qid` to mockserver->SetReplyQID().
ACTION_P2(SetReplyQID, mockserver, qid) {
  mockserver->SetReplyQID(qid);
}
// gMock action to cancel a channel via ares_cancel(channel).
// NOTE: `mockserver` is accepted for signature symmetry but is not used.
ACTION_P2(CancelChannel, mockserver, channel) {
  ares_cancel(channel);
}
235 
// Value-type mirror of a C `struct hostent`, convenient for comparisons and
// printing in tests.
struct HostEnt {
  // Empty name/aliases/addresses; addrtype_ is -1, i.e. neither AF_INET nor
  // AF_INET6.
  HostEnt() : name_(), aliases_(), addrtype_(-1), addrs_() {}
  // Copies the contents of a C hostent (defined out of line).
  HostEnt(const struct hostent* hostent);
  std::string name_;
  std::vector<std::string> aliases_;
  int addrtype_;  // AF_INET or AF_INET6
  std::vector<std::string> addrs_;
};
std::ostream& operator<<(std::ostream& os, const HostEnt& result);
246 
247 // Structure that describes the result of an ares_host_callback invocation.
248 struct HostResult {
249   // Whether the callback has been invoked.
250   bool done_;
251   // Explicitly provided result information.
252   int status_;
253   int timeouts_;
254   // Contents of the hostent structure, if provided.
255   HostEnt host_;
256 };
257 std::ostream& operator<<(std::ostream& os, const HostResult& result);
258 
259 // Structure that describes the result of an ares_callback invocation.
260 struct SearchResult {
261   // Whether the callback has been invoked.
262   bool done_;
263   // Explicitly provided result information.
264   int status_;
265   int timeouts_;
266   std::vector<byte> data_;
267 };
268 std::ostream& operator<<(std::ostream& os, const SearchResult& result);
269 
// Structure that describes the result of an ares_nameinfo_callback invocation.
//
// The scalar fields carry default member initializers, so a freshly declared
// result reads as "callback not invoked" (done_ == false) instead of holding
// indeterminate values.
struct NameInfoResult {
  // Whether the callback has been invoked.
  bool done_ = false;
  // Explicitly provided result information.
  int status_ = 0;
  int timeouts_ = 0;
  std::string node_;
  std::string service_;
};
std::ostream& operator<<(std::ostream& os, const NameInfoResult& result);
281 
// Standard implementation of ares callbacks that fill out the corresponding
// structures.
// NOTE(review): `data` is presumably a pointer to the matching result struct
// (HostResult / SearchResult / NameInfoResult respectively) — confirm against
// the definitions.
void HostCallback(void *data, int status, int timeouts,
                  struct hostent *hostent);
void SearchCallback(void *data, int status, int timeouts,
                    unsigned char *abuf, int alen);
void NameInfoCallback(void *data, int status, int timeouts,
                      char *node, char *service);

// Retrieve the name servers used by a channel, one string per server.
std::vector<std::string> GetNameServers(ares_channel channel);
293 
294 
// RAII class to temporarily create a directory of a given name.
// Constructor and destructor are defined out of line.
//
// Not copyable: the destructor is responsible for the directory named by
// dirname_, so two objects owning the same name must never exist.
class TransientDir {
 public:
  TransientDir(const std::string& dirname);
  ~TransientDir();

  TransientDir(const TransientDir&) = delete;
  TransientDir& operator=(const TransientDir&) = delete;

 private:
  std::string dirname_;
};
304 
// C++ wrapper around tempnam(); presumably returns a fresh temporary file
// name under `dir` using `prefix` (defined out of line).
std::string TempNam(const char *dir, const char *prefix);

// RAII class to temporarily create file of a given name and contents.
// Constructor and destructor are defined out of line.
//
// Not copyable: the destructor is responsible for the file named by
// filename_, so two objects owning the same name must never exist.
class TransientFile {
 public:
  TransientFile(const std::string &filename, const std::string &contents);
  ~TransientFile();

  TransientFile(const TransientFile&) = delete;
  TransientFile& operator=(const TransientFile&) = delete;

 protected:
  std::string filename_;
};

// RAII class for a temporary file with the given contents. The constructor
// (defined out of line) supplies the file name, exposed via filename().
// Non-copyable, like its TransientFile base.
class TempFile : public TransientFile {
 public:
  TempFile(const std::string& contents);
  const char* filename() const { return filename_.c_str(); }
};
324 
325 #ifndef WIN32
// RAII class for a temporary environment variable value.
//
// On construction, the current value of `name` (if any) is saved and the
// variable is set to `value`. A null `value` unsets the variable instead.
// On destruction, the saved value is restored, or the variable is unset if
// it did not exist beforehand.
//
// Not copyable: every copy would restore the variable again on destruction,
// clobbering whatever value was in place at that point.
class EnvValue {
 public:
  EnvValue(const char *name, const char *value) : name_(name), restore_(false) {
    // Take a private copy of the old value before modifying the environment:
    // the pointer returned by getenv() may not survive setenv()/unsetenv().
    const char *original = getenv(name);
    if (original) {
      restore_ = true;
      original_ = original;
    }
    if (value) {
      setenv(name_.c_str(), value, 1);
    } else {
      // setenv() requires a non-null value string; treat null as "unset".
      unsetenv(name_.c_str());
    }
  }
  ~EnvValue() {
    if (restore_) {
      setenv(name_.c_str(), original_.c_str(), 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

  EnvValue(const EnvValue&) = delete;
  EnvValue& operator=(const EnvValue&) = delete;

 private:
  std::string name_;      // Name of the variable being overridden.
  bool restore_;          // True if the variable existed before construction.
  std::string original_;  // Saved prior value; meaningful only if restore_.
};
349 #endif
350 
351 
352 #ifdef HAVE_CONTAINER
// Linux-specific functionality for running code in a container, implemented
// in ares-test-ns.cc

// Callable executed inside a container; returns an int status.
using VoidToIntFn = std::function<int(void)>;
// Ordered list of <name, contents> string pairs (consumed by
// ContainerFilesystem).
using NameContentList = std::vector<std::pair<std::string, std::string>>;
357 
358 class ContainerFilesystem {
359  public:
360   ContainerFilesystem(NameContentList files, const std::string& mountpt);
361   ~ContainerFilesystem();
root()362   std::string root() const { return rootdir_; };
mountpt()363   std::string mountpt() const { return mountpt_; };
364  private:
365   void EnsureDirExists(const std::string& dir);
366   std::string rootdir_;
367   std::string mountpt_;
368   std::list<std::string> dirs_;
369   std::vector<std::unique_ptr<TransientFile>> files_;
370 };
371 
// Runs `fn` in a container rooted at `fs`, passing the given hostname and
// domain name (presumably applied inside the container — confirm in
// ares-test-ns.cc). Returns an int; CONTAINED_TEST_F expects 0 on success.
int RunInContainer(ContainerFilesystem* fs, const std::string& hostname,
                   const std::string& domainname, VoidToIntFn fn);

// Name of the fixture subclass generated by CONTAINED_TEST_F.
#define ICLASS_NAME(casename, testname) Contained##casename##_##testname
// Defines a test whose body runs inside a container:
//  - derives ICLASS_NAME(casename, testname) from fixture `casename`, adding
//    a static `int InnerTestBody()`;
//  - registers TEST_F(ICLASS_NAME(casename, testname), _), which builds a
//    ContainerFilesystem from `files` with mount point "..", then calls
//    RunInContainer(..., hostname, domainname, InnerTestBody) and
//    EXPECT_EQs a result of 0.
// The braces following the macro invocation become InnerTestBody's body.
// Because InnerTestBody is static, it cannot touch fixture instance members.
#define CONTAINED_TEST_F(casename, testname, hostname, domainname, files)       \
  class ICLASS_NAME(casename, testname) : public casename {                     \
   public:                                                                      \
    ICLASS_NAME(casename, testname)() {}                                        \
    static int InnerTestBody();                                                 \
  };                                                                            \
  TEST_F(ICLASS_NAME(casename, testname), _) {                                  \
    ContainerFilesystem chroot(files, "..");                                    \
    VoidToIntFn fn(ICLASS_NAME(casename, testname)::InnerTestBody);             \
    EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));            \
  }                                                                             \
  int ICLASS_NAME(casename, testname)::InnerTestBody()
388 
389 #endif
390 
391 /* Assigns virtual IO functions to a channel. These functions simply call
392  * the actual system functions.
393  */
394 class VirtualizeIO {
395 public:
396   VirtualizeIO(ares_channel);
397   ~VirtualizeIO();
398 
399   static const ares_socket_functions default_functions;
400 private:
401   ares_channel channel_;
402 };
403 
404 /*
405  * Slightly white-box macro to generate two runs for a given test case:
406  * One with no modifications, and one with all IO functions set to use
407  * the virtual io structure.
408  * Since no magic socket setup or anything is done in the latter case
409  * this should probably only be used for test with very vanilla IO
410  * requirements.
411  */
412 #define VCLASS_NAME(casename, testname) Virt##casename##_##testname
413 #define VIRT_NONVIRT_TEST_F(casename, testname)                                 \
414   class VCLASS_NAME(casename, testname) : public casename {                     \
415   public:                                                                       \
416     VCLASS_NAME(casename, testname)() {}                                        \
417     void InnerTestBody();                                                       \
418   };                                                                            \
419   GTEST_TEST_(casename, testname, VCLASS_NAME(casename, testname),              \
420               ::testing::internal::GetTypeId<casename>()) {                     \
421     InnerTestBody();                                                            \
422   }                                                                             \
423   GTEST_TEST_(casename, testname##_virtualized,                                 \
424               VCLASS_NAME(casename, testname),                                  \
425               ::testing::internal::GetTypeId<casename>()) {                     \
426     VirtualizeIO vio(channel_);                                                 \
427     InnerTestBody();                                                            \
428   }                                                                             \
429   void VCLASS_NAME(casename, testname)::InnerTestBody()
430 
431 
432 }  // namespace test
433 }  // namespace ares
434 
435 #endif
436