1 //
2 // Copyright (C) 2013 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/shims/netfilter_queue_processor.h"
18 
19 #include <arpa/inet.h>
20 #include <errno.h>
21 #include <libnetfilter_queue/libnetfilter_queue.h>
22 #include <linux/ip.h>
23 #include <linux/netfilter.h>    /* for NF_ACCEPT */
24 #include <linux/types.h>
25 #include <linux/udp.h>
26 #include <net/if.h>
27 #include <netinet/in.h>
28 #include <string.h>
29 #include <sys/ioctl.h>
30 #include <unistd.h>
31 
32 #include <deque>
33 
34 #include <base/files/scoped_file.h>
35 #include <base/logging.h>
36 #include <base/strings/stringprintf.h>
37 
38 using std::deque;
39 
40 namespace shill {
41 
42 namespace shims {
43 
44 // static
45 const int NetfilterQueueProcessor::kBufferSize = 4096;
46 const int NetfilterQueueProcessor::kExpirationIntervalSeconds = 5;
47 const int NetfilterQueueProcessor::kIPHeaderLengthUnitBytes = 4;
48 const int NetfilterQueueProcessor::kMaxIPHeaderLength =
49     16;  // ihl is a 4-bit field.
50 const size_t NetfilterQueueProcessor::kMaxListenerEntries = 32;
51 const int NetfilterQueueProcessor::kPayloadCopySize = 0xffff;
52 
Packet()53 NetfilterQueueProcessor::Packet::Packet()
54     : packet_id_(0),
55       in_device_(0),
56       out_device_(0),
57       is_udp_(false),
58       source_ip_(INADDR_ANY),
59       destination_ip_(INADDR_ANY),
60       source_port_(0),
61       destination_port_(0) {}
62 
~Packet()63 NetfilterQueueProcessor::Packet::~Packet() {}
64 
AddressAndPortToString(uint32_t ip,uint16_t port)65 std::string NetfilterQueueProcessor::AddressAndPortToString(uint32_t ip,
66                                                             uint16_t port) {
67   struct in_addr addr;
68   addr.s_addr = htonl(ip);
69   return base::StringPrintf("%s:%d", inet_ntoa(addr), port);
70 }
71 
ParseNetfilterData(struct nfq_data * netfilter_data)72 bool NetfilterQueueProcessor::Packet::ParseNetfilterData(
73     struct nfq_data* netfilter_data) {
74   struct nfqnl_msg_packet_hdr* packet_header =
75       nfq_get_msg_packet_hdr(netfilter_data);
76   if (!packet_header) {
77     return false;
78   }
79   packet_id_ = ntohl(packet_header->packet_id);
80   in_device_ = nfq_get_indev(netfilter_data);
81   out_device_ = nfq_get_outdev(netfilter_data);
82 
83   unsigned char* payload;
84   int payload_len = nfq_get_payload(netfilter_data, &payload);
85   if (payload_len >= 0) {
86     is_udp_ = ParsePayloadUDPData(payload, payload_len);
87   }
88 
89   return true;
90 }
91 
ParsePayloadUDPData(const unsigned char * payload,size_t payload_len)92 bool NetfilterQueueProcessor::Packet::ParsePayloadUDPData(
93     const unsigned char* payload, size_t payload_len) {
94   struct iphdr ip;
95 
96   if (payload_len <= sizeof(ip)) {
97     return false;
98   }
99 
100   memcpy(&ip, payload, sizeof(ip));
101 
102   size_t iphdr_len = ip.ihl * kIPHeaderLengthUnitBytes;
103   if (iphdr_len < sizeof(ip) ||
104       ip.version != IPVERSION ||
105       ip.protocol != IPPROTO_UDP) {
106     return false;
107   }
108 
109   struct udphdr udp;
110   if (payload_len < iphdr_len + sizeof(udp)) {
111     return false;
112   }
113 
114   memcpy(&udp, payload + iphdr_len, sizeof(udp));
115 
116   source_ip_ = ntohl(ip.saddr);
117   destination_ip_ = ntohl(ip.daddr);
118   source_port_ = ntohs(udp.source);
119   destination_port_ = ntohs(udp.dest);
120 
121   return true;
122 }
123 
SetValues(int in_device,int out_device,bool is_udp,uint32_t packet_id,uint32_t source_ip,uint32_t destination_ip,uint16_t source_port,uint16_t destination_port)124 void NetfilterQueueProcessor::Packet::SetValues(int in_device,
125                                                 int out_device,
126                                                 bool is_udp,
127                                                 uint32_t packet_id,
128                                                 uint32_t source_ip,
129                                                 uint32_t destination_ip,
130                                                 uint16_t source_port,
131                                                 uint16_t destination_port) {
132   in_device_ = in_device;
133   out_device_ = out_device;
134   is_udp_ = is_udp;
135   packet_id_ = packet_id;
136   source_ip_ = source_ip;
137   destination_ip_ = destination_ip;
138   source_port_ = source_port;
139   destination_port_ = destination_port;
140 }
141 
NetfilterQueueProcessor(int input_queue,int output_queue)142 NetfilterQueueProcessor::NetfilterQueueProcessor(
143     int input_queue, int output_queue)
144     : input_queue_(input_queue),
145       output_queue_(output_queue),
146       nfq_handle_(NULL),
147       input_queue_handle_(NULL),
148       output_queue_handle_(NULL)  {
149   VLOG(2) << "Created netfilter queue processor.";
150 }
151 
~NetfilterQueueProcessor()152 NetfilterQueueProcessor::~NetfilterQueueProcessor() {
153   Stop();
154 }
155 
Run()156 void NetfilterQueueProcessor::Run() {
157   LOG(INFO) << "Netfilter queue processor running.";
158   CHECK(nfq_handle_);
159 
160   int file_handle = nfq_fd(nfq_handle_);
161   char buffer[kBufferSize] __attribute__((aligned));
162 
163   for (;;) {
164     int receive_count = recv(file_handle, buffer, sizeof(buffer), 0);
165     if (receive_count <= 0) {
166       if (receive_count < 0 && errno == ENOBUFS) {
167         LOG(WARNING) << "Packets dropped in the queue.";
168         continue;
169       }
170       LOG(ERROR) << "Receive failed; exiting";
171       break;
172     }
173 
174     nfq_handle_packet(nfq_handle_, buffer, receive_count);
175   }
176 }
177 
Start()178 bool NetfilterQueueProcessor::Start() {
179   VLOG(2) << "Netfilter queue processor starting.";
180   if (!nfq_handle_) {
181     nfq_handle_ = nfq_open();
182     if (!nfq_handle_) {
183       LOG(ERROR) << "nfq_open() returned an error";
184       return false;
185     }
186   }
187 
188   if (nfq_unbind_pf(nfq_handle_, AF_INET) < 0) {
189     LOG(ERROR) << "nfq_unbind_pf() returned an error";
190     return false;
191   }
192 
193   if (nfq_bind_pf(nfq_handle_, AF_INET) < 0) {
194     LOG(ERROR) << "nfq_bind_pf() returned an error";
195     return false;
196   }
197 
198   input_queue_handle_ = nfq_create_queue(
199       nfq_handle_, input_queue_,
200       &NetfilterQueueProcessor::InputQueueCallback, this);
201   if (!input_queue_handle_) {
202     LOG(ERROR) << "nfq_create_queue() failed for input queue " << input_queue_;
203     return false;
204   }
205 
206   if (nfq_set_mode(input_queue_handle_, NFQNL_COPY_PACKET,
207                    kPayloadCopySize) < 0) {
208     LOG(ERROR) << "nfq_set_mode() failed: can't set input queue packet_copy.";
209     return false;
210   }
211 
212   output_queue_handle_ = nfq_create_queue(
213       nfq_handle_, output_queue_,
214       &NetfilterQueueProcessor::OutputQueueCallback, this);
215   if (!output_queue_handle_) {
216     LOG(ERROR) << "nfq_create_queue() failed for output queue "
217                << output_queue_;
218     return false;
219   }
220 
221   if (nfq_set_mode(output_queue_handle_, NFQNL_COPY_PACKET,
222                    kPayloadCopySize) < 0) {
223     LOG(ERROR) << "nfq_set_mode() failed: can't set output queue packet_copy.";
224     return false;
225   }
226 
227   return true;
228 }
229 
Stop()230 void NetfilterQueueProcessor::Stop() {
231   if (input_queue_handle_) {
232     nfq_destroy_queue(input_queue_handle_);
233     input_queue_handle_ = NULL;
234   }
235 
236   if (output_queue_handle_) {
237     nfq_destroy_queue(output_queue_handle_);
238     output_queue_handle_ = NULL;
239   }
240 
241   if (nfq_handle_) {
242     nfq_close(nfq_handle_);
243     nfq_handle_ = NULL;
244   }
245 }
246 
247 // static
InputQueueCallback(struct nfq_q_handle * queue_handle,struct nfgenmsg * generic_message,struct nfq_data * netfilter_data,void * private_data)248 int NetfilterQueueProcessor::InputQueueCallback(
249     struct nfq_q_handle* queue_handle,
250     struct nfgenmsg* generic_message,
251     struct nfq_data* netfilter_data,
252     void* private_data) {
253   Packet packet;
254   if (!packet.ParseNetfilterData(netfilter_data)) {
255     LOG(FATAL) << "Unable to parse netfilter data.";
256   }
257 
258   NetfilterQueueProcessor* processor =
259       reinterpret_cast<NetfilterQueueProcessor*>(private_data);
260   uint32_t verdict;
261   time_t now = time(NULL);
262   if (processor->IsIncomingPacketAllowed(packet, now)) {
263     verdict = NF_ACCEPT;
264   } else {
265     verdict = NF_DROP;
266   }
267   return nfq_set_verdict(queue_handle, packet.packet_id(), verdict, 0, NULL);
268 }
269 
270 // static
OutputQueueCallback(struct nfq_q_handle * queue_handle,struct nfgenmsg * generic_message,struct nfq_data * netfilter_data,void * private_data)271 int NetfilterQueueProcessor::OutputQueueCallback(
272     struct nfq_q_handle* queue_handle,
273     struct nfgenmsg* generic_message,
274     struct nfq_data* netfilter_data,
275     void* private_data) {
276   Packet packet;
277   if (!packet.ParseNetfilterData(netfilter_data)) {
278     LOG(FATAL) << "Unable to get parse netfilter data.";
279   }
280 
281   NetfilterQueueProcessor* processor =
282       reinterpret_cast<NetfilterQueueProcessor*>(private_data);
283   time_t now = time(NULL);
284   processor->LogOutgoingPacket(packet, now);
285   return nfq_set_verdict(queue_handle, packet.packet_id(), NF_ACCEPT, 0, NULL);
286 }
287 
288 // static
GetNetmaskForDevice(int device_index)289 uint32_t NetfilterQueueProcessor::GetNetmaskForDevice(int device_index) {
290   struct ifreq ifr;
291   memset(&ifr, 0, sizeof(ifr));
292   if (if_indextoname(device_index, ifr.ifr_name) != ifr.ifr_name) {
293     return INADDR_NONE;
294   }
295 
296   int socket_fd = socket(AF_INET, SOCK_DGRAM, 0);
297   if (socket_fd < 0) {
298     return INADDR_NONE;
299   }
300 
301   base::ScopedFD scoped_fd(socket_fd);
302 
303   if (ioctl(socket_fd, SIOCGIFNETMASK, &ifr) != 0) {
304     return INADDR_NONE;
305   }
306 
307   struct sockaddr_in* netmask_addr =
308       reinterpret_cast<struct sockaddr_in*>(&ifr.ifr_netmask);
309   return ntohl(netmask_addr->sin_addr.s_addr);
310 }
311 
ExpireListeners(time_t now)312 void NetfilterQueueProcessor::ExpireListeners(time_t now) {
313   time_t expiration_threshold = now - kExpirationIntervalSeconds;
314   VLOG(2) << __func__ << " entered.";
315   while (!listeners_.empty()) {
316     const ListenerEntryPtr& last_listener = listeners_.back();
317     if (last_listener->last_transmission >= expiration_threshold &&
318         listeners_.size() <= kMaxListenerEntries) {
319       break;
320     }
321     VLOG(2) << "Expired listener for "
322             << AddressAndPortToString(last_listener->address,
323                                       last_listener->port);
324     listeners_.pop_back();
325   }
326 }
327 
328 deque<NetfilterQueueProcessor::ListenerEntryPtr>::iterator
FindListener(uint16_t port,int device_index,uint32_t address)329     NetfilterQueueProcessor::FindListener(uint16_t port,
330                                           int device_index,
331                                           uint32_t address) {
332   deque<ListenerEntryPtr>::iterator it;
333   for (it = listeners_.begin(); it != listeners_.end(); ++it) {
334     if ((*it)->port == port &&
335         (*it)->device_index == device_index &&
336         (*it)->address == address) {
337       break;
338     }
339   }
340   return it;
341 }
342 
343 deque<NetfilterQueueProcessor::ListenerEntryPtr>::iterator
FindDestination(uint16_t port,int device_index,uint32_t destination)344     NetfilterQueueProcessor::FindDestination(uint16_t port,
345                                              int device_index,
346                                              uint32_t destination) {
347   deque<ListenerEntryPtr>::iterator it;
348   for (it = listeners_.begin(); it != listeners_.end(); ++it) {
349     if ((*it)->port == port &&
350         (*it)->device_index == device_index &&
351         (*it)->destination == destination) {
352       break;
353     }
354   }
355   return it;
356 }
357 
IsIncomingPacketAllowed(const Packet & packet,time_t now)358 bool NetfilterQueueProcessor::IsIncomingPacketAllowed(
359     const Packet& packet, time_t now) {
360   VLOG(2) << __func__ << " entered.";
361   VLOG(3) << "Incoming packet is from "
362           << AddressAndPortToString(packet.source_ip(),
363                                     packet.source_port())
364           << " and to "
365           << AddressAndPortToString(packet.destination_ip(),
366                                     packet.destination_port());
367   if (!packet.is_udp()) {
368     VLOG(2) << "Incoming packet is not udp.";
369     return false;
370   }
371 
372   ExpireListeners(now);
373 
374   uint16_t port = packet.destination_port();
375   uint32_t address = packet.destination_ip();
376   int device_index = packet.in_device();
377 
378   deque<ListenerEntryPtr>::iterator entry_ptr = listeners_.end();
379   if (IN_MULTICAST(address)) {
380     VLOG(2) << "Incoming packet is multicast.";
381     entry_ptr = FindDestination(port, device_index, address);
382   } else {
383     entry_ptr = FindListener(port, device_index, address);
384   }
385 
386   if (entry_ptr == listeners_.end()) {
387     VLOG(2) << "Incoming does not match any listener.";
388     return false;
389   }
390 
391   uint32_t netmask = (*entry_ptr)->netmask;
392   if ((packet.source_ip() & netmask) != ((*entry_ptr)->address & netmask)) {
393     VLOG(2) << "Incoming packet is from a non-local address.";
394     return false;
395   }
396 
397   VLOG(3) << "Accepting packet.";
398   return true;
399 }
400 
LogOutgoingPacket(const Packet & packet,time_t now)401 void NetfilterQueueProcessor::LogOutgoingPacket(
402     const Packet& packet, time_t now) {
403   VLOG(2) << __func__ << " entered.";
404   if (!packet.is_udp()) {
405     VLOG(2) << "Outgoing packet is not udp.";
406     return;
407   }
408   if (!IN_MULTICAST(packet.destination_ip())) {
409     VLOG(2) << "Outgoing packet is not multicast.";
410     return;
411   }
412   int device_index = packet.out_device();
413   if (device_index == 0) {
414     VLOG(2) << "Outgoing packet is not assigned a valid device.";
415     return;
416   }
417   uint16_t port = packet.source_port();
418   uint32_t address = packet.source_ip();
419   uint32_t destination = 0;
420   // Allow multicast replies if the destination port of the packet is the
421   // same as the port the sender transmitted from;
422   if (packet.source_port() == packet.destination_port()) {
423     destination = packet.destination_ip();
424   }
425   deque<ListenerEntryPtr>::iterator entry_it =
426       FindListener(port, device_index, address);
427   if (entry_it != listeners_.end()) {
428     if (entry_it != listeners_.begin()) {
429       // Make this the newest entry.
430       ListenerEntryPtr entry_ptr = *entry_it;
431       listeners_.erase(entry_it);
432       listeners_.push_front(entry_ptr);
433       entry_it = listeners_.begin();
434     }
435     (*entry_it)->last_transmission = now;
436   } else {
437     uint32_t netmask = GetNetmaskForDevice(device_index);
438     ListenerEntryPtr entry_ptr(
439         new ListenerEntry(now, port, device_index,
440                           address, netmask, destination));
441     listeners_.push_front(entry_ptr);
442     VLOG(2) << "Added listener for " << AddressAndPortToString(address, port)
443             << " with destination "
444             << AddressAndPortToString(destination, port);
445   }
446 
447   // Perform expiration at the end, so that we don't end up expiring something
448   // just to resurrect it again.
449   ExpireListeners(now);
450 }
451 
452 }  // namespace shims
453 
454 }  // namespace shill
455 
456