1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import collections
6import logging
7import os.path
8import time
9import uuid
10
11from autotest_lib.client.bin import site_utils
12from autotest_lib.client.common_lib import error
13from autotest_lib.client.common_lib.cros import path_utils
14
15
class PacketCapturesDisabledError(Exception):
    """Raised when a capture is requested from a capturer that cannot take one.

    DisabledPacketCapturer.start_capture() raises this error.
    """
19
20
# Local-host paths of one finished capture:
#   local_pcap_path - where the captured pcap file was copied to.
#   local_log_path  - where the tcpdump log output was copied to.
CaptureResult = collections.namedtuple(
        'CaptureResult', 'local_pcap_path local_log_path')
25
# Snap length, in bytes, sufficient to capture WiFi probe requests.
# There is no exact bound: frame contents vary (the radiotap header may
# carry different fields, and the SSID is not guaranteed to be the first
# tagged parameter). The value chosen is twice the largest frame size seen
# in a quick sample.
SNAPLEN_WIFI_PROBE_REQUEST = 600

# How long start_capture() waits for tcpdump to report that it is
# listening, and how often it checks while waiting.
TCPDUMP_START_TIMEOUT_SECONDS = 5
TCPDUMP_START_POLL_SECONDS = 0.1
35
def get_packet_capturer(host, host_description=None, cmd_ifconfig=None,
                        cmd_ip=None, cmd_iw=None, cmd_netdump=None,
                        ignore_failures=False):
    """Build a packet capturer for |host|.

    Any command path that is not supplied is looked up on |host| with
    path_utils.get_install_path().

    @param host: host object the capturer will run commands on.
    @param host_description: string used to name remote capture files;
            defaults to a random 'cap_<hex>' name.
    @param cmd_ifconfig: path to ifconfig, or None to look it up.
    @param cmd_ip: path to ip, or None to look it up.
    @param cmd_iw: path to iw, or None to look it up.
    @param cmd_netdump: path to tcpdump, or None to look it up.
    @param ignore_failures: if True, return a DisabledPacketCapturer instead
            of raising when a required command cannot be found.
    @return PacketCapturer, or DisabledPacketCapturer if a command is
            missing and |ignore_failures| is True.
    @raises error.TestFail if a command is missing and |ignore_failures| is
            False.

    """
    # Lookups happen in the same order as before, so any side effects on
    # |host| are unchanged.
    cmd_ifconfig = (cmd_ifconfig or
                    path_utils.get_install_path('ifconfig', host=host))
    cmd_iw = cmd_iw or path_utils.get_install_path('iw', host=host)
    cmd_ip = cmd_ip or path_utils.get_install_path('ip', host=host)
    cmd_netdump = (cmd_netdump or
                   path_utils.get_install_path('tcpdump', host=host))
    host_description = host_description or 'cap_%s' % uuid.uuid4().hex

    # Name the missing tools so the failure is actionable.
    missing = [name for name, path in (('ifconfig', cmd_ifconfig),
                                       ('ip', cmd_ip),
                                       ('iw', cmd_iw),
                                       ('tcpdump', cmd_netdump))
               if path is None]
    if missing:
        if ignore_failures:
            logging.warning('Creating a disabled packet capturer for %s '
                            '(missing commands: %s).',
                            host_description, ', '.join(missing))
            return DisabledPacketCapturer()
        raise error.TestFail('Missing commands needed for capturing '
                             'packets: %s' % ', '.join(missing))

    return PacketCapturer(host, host_description, cmd_ifconfig, cmd_ip, cmd_iw,
                          cmd_netdump)
57
58
class DisabledPacketCapturer(object):
    """Delegate meant to look like it could take packet captures.

    get_packet_capturer() returns one of these when the required tools are
    missing and failures are being ignored. Every method is a no-op except
    start_capture(), which raises PacketCapturesDisabledError. Method
    signatures mirror PacketCapturer so callers can use either type.

    """

    @property
    def capture_running(self):
        """@return False; a disabled capturer never has a running capture."""
        return False


    def __init__(self):
        pass


    def __enter__(self):
        return self


    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Leave a `with` block without suppressing any exception.

        The context-manager protocol passes three arguments here. They
        default to None so direct zero-argument calls still work.

        """
        return None


    def close(self):
        """No-op"""


    def create_raw_monitor(self, phy, frequency, ht_type=None,
                           monitor_device=None):
        """Appears to fail while creating a raw monitor device.

        @param phy string ignored.
        @param frequency int ignored.
        @param ht_type string ignored.
        @param monitor_device string ignored.
        @return None.

        """
        return None


    def configure_raw_monitor(self, monitor_device, frequency, ht_type=None):
        """Fails to configure a raw monitor.

        @param monitor_device string ignored.
        @param frequency int ignored.
        @param ht_type string ignored.

        """


    def create_managed_monitor(self, existing_dev, monitor_device=None):
        """Fails to create a managed monitor device.

        @param existing_dev string ignored.
        @param monitor_device string ignored.
        @return None

        """
        return None


    def start_capture(self, interface, local_save_dir,
                      remote_file=None, snaplen=None):
        """Fails to start a packet capture.

        @param interface string ignored.
        @param local_save_dir string ignored.
        @param remote_file string ignored.
        @param snaplen int ignored.

        @raises PacketCapturesDisabledError.

        """
        raise PacketCapturesDisabledError()


    def stop_capture(self, capture_pid=None, local_save_dir=None,
                     local_pcap_filename=None):
        """Pretend to stop packet captures; none are ever running.

        Accepts the same arguments as PacketCapturer.stop_capture().

        @param capture_pid int ignored.
        @param local_save_dir string ignored.
        @param local_pcap_filename string ignored.
        @return empty list, since no captures were stopped.

        """
        return []
141
class PacketCapturer(object):
    """Delegate with capability to initiate packet captures on a remote host.

    Captures started with start_capture() and monitor interfaces created by
    create_raw_monitor() / create_managed_monitor() are tracked. close(), or
    leaving a `with` block, stops the captures and deletes the interfaces.

    """

    # stop_capture() sleeps for twice this many seconds before interrupting
    # tcpdump, so libpcap can complete its last poll().
    LIBPCAP_POLL_FREQ_SECS = 1

    @property
    def capture_running(self):
        """@return True iff we have at least one ongoing packet capture."""
        return bool(self._ongoing_captures)


    def __init__(self, host, host_description, cmd_ifconfig, cmd_ip,
                 cmd_iw, cmd_netdump, disable_captures=False):
        """Construct a capturer.

        @param host: host object used to run commands (host.run) and fetch
                files (host.get_file).
        @param host_description: string prefix for default remote capture
                file names.
        @param cmd_ifconfig: path to ifconfig on |host|.
        @param cmd_ip: path to ip on |host|.
        @param cmd_iw: path to iw on |host|.
        @param cmd_netdump: path to tcpdump on |host|.
        @param disable_captures: unused; kept for backward compatibility.

        """
        self._cmd_netdump = cmd_netdump
        self._cmd_iw = cmd_iw
        self._cmd_ip = cmd_ip
        self._cmd_ifconfig = cmd_ifconfig
        self._host = host
        # tcpdump pid -> (remote pcap path, remote log path, local save dir).
        self._ongoing_captures = {}
        # Counters that generate unique capture-file and interface names.
        self._cap_num = 0
        self._if_num = 0
        self._created_managed_devices = []
        self._created_raw_devices = []
        self._host_description = host_description


    def __enter__(self):
        return self


    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Call close() on leaving a `with` block; exceptions propagate.

        The context-manager protocol passes three arguments here. They
        default to None so direct zero-argument calls still work.

        """
        self.close()


    def close(self):
        """Stop ongoing captures and destroy all created devices."""
        self.stop_capture()
        for device in self._created_managed_devices:
            self._host.run("%s dev %s del" % (self._cmd_iw, device))
        self._created_managed_devices = []
        for device in self._created_raw_devices:
            self._host.run("%s link set %s down" % (self._cmd_ip, device))
            self._host.run("%s dev %s del" % (self._cmd_iw, device))
        self._created_raw_devices = []


    def create_raw_monitor(self, phy, frequency, ht_type=None,
                           monitor_device=None):
        """Create and configure a monitor type WiFi interface on a phy.

        If a device called |monitor_device| already exists, it is first removed.

        @param phy string phy name for created monitor (e.g. phy0).
        @param frequency int frequency for created monitor to watch.
        @param ht_type string optional HT type ('HT20', 'HT40+', or 'HT40-').
        @param monitor_device string name of monitor interface to create.
        @return string monitor device name created or None on failure.
        @raises error.TestError if |ht_type| is not a supported HT mode.

        """
        if not monitor_device:
            monitor_device = 'mon%d' % self._if_num
            self._if_num += 1

        self._host.run('%s dev %s del' % (self._cmd_iw, monitor_device),
                       ignore_status=True)
        result = self._host.run('%s phy %s interface add %s type monitor' %
                                (self._cmd_iw,
                                 phy,
                                 monitor_device),
                                ignore_status=True)
        if result.exit_status:
            logging.error('Failed creating raw monitor.')
            return None

        # Track the device before configuring it. If configuration raises,
        # close() still removes the interface we just created.
        self._created_raw_devices.append(monitor_device)
        self.configure_raw_monitor(monitor_device, frequency, ht_type)
        return monitor_device


    def configure_raw_monitor(self, monitor_device, frequency, ht_type=None):
        """Configure a raw monitor with frequency and HT params.

        Note that this will stomp on earlier device settings.

        @param monitor_device string name of device to configure.
        @param frequency int WiFi frequency to dwell on.
        @param ht_type string optional HT type ('HT20', 'HT40+', or 'HT40-'),
                case-insensitive.
        @raises error.TestError if |ht_type| is not a supported HT mode.

        """
        channel_args = str(frequency)
        if ht_type:
            ht_type = ht_type.upper()
            if ht_type not in ('HT20', 'HT40+', 'HT40-'):
                # Format the message here: exception constructors do not
                # apply %-style arguments the way logging calls do.
                raise error.TestError('Cannot set HT mode: %s' % ht_type)
            channel_args = '%s %s' % (channel_args, ht_type)

        self._host.run("%s link set %s up" % (self._cmd_ip, monitor_device))
        self._host.run("%s dev %s set freq %s" % (self._cmd_iw,
                                                  monitor_device,
                                                  channel_args))


    def create_managed_monitor(self, existing_dev, monitor_device=None):
        """Create a monitor type WiFi interface next to a managed interface.

        If a device called |monitor_device| already exists, it is first removed.

        @param existing_dev string existing interface (e.g. mlan0).
        @param monitor_device string name of monitor interface to create.
        @return string monitor device name created or None on failure.

        """
        if not monitor_device:
            monitor_device = 'mon%d' % self._if_num
            self._if_num += 1
        self._host.run('%s dev %s del' % (self._cmd_iw, monitor_device),
                       ignore_status=True)
        result = self._host.run('%s dev %s interface add %s type monitor' %
                                (self._cmd_iw,
                                 existing_dev,
                                 monitor_device),
                                ignore_status=True)
        if result.exit_status:
            logging.warning('Failed creating monitor.')
            return None

        self._host.run('%s %s up' % (self._cmd_ifconfig, monitor_device))
        self._created_managed_devices.append(monitor_device)
        return monitor_device


    def _is_capture_active(self, remote_log_file):
        """Check if a packet capture has completed initialization.

        @param remote_log_file string path to the capture's log file
        @return True iff log file indicates that tcpdump is listening.
        """
        return self._host.run(
            'grep "listening on" "%s"' % remote_log_file, ignore_status=True
            ).exit_status == 0


    def start_capture(self, interface, local_save_dir,
                      remote_file=None, snaplen=None):
        """Start a packet capture on an existing interface.

        Blocks until tcpdump's log reports "listening on". It waits at most
        TCPDUMP_START_TIMEOUT_SECONDS; on timeout, site_utils.poll_for_condition
        is expected to raise (exception type defined there).

        @param interface string existing interface to capture on.
        @param local_save_dir string directory on local machine to hold results.
        @param remote_file string full path on remote host to hold the capture.
        @param snaplen int maximum captured frame length (0/None = tcpdump's
                default).
        @return int pid of started packet capture.

        """
        remote_file = (remote_file or
                       '/tmp/%s.%d.pcap' % (self._host_description,
                                            self._cap_num))
        self._cap_num += 1
        remote_log_file = '%s.log' % remote_file
        # Redirect output because SSH refuses to return until the child file
        # descriptors are closed.
        cmd = '%s -U -i %s -w %s -s %d >%s 2>&1 & echo $!' % (
            self._cmd_netdump,
            interface,
            remote_file,
            snaplen or 0,
            remote_log_file)
        logging.debug('Starting managed packet capture')
        pid = int(self._host.run(cmd).stdout)
        self._ongoing_captures[pid] = (remote_file,
                                       remote_log_file,
                                       local_save_dir)
        is_capture_active = lambda: self._is_capture_active(remote_log_file)
        site_utils.poll_for_condition(
            is_capture_active,
            timeout=TCPDUMP_START_TIMEOUT_SECONDS,
            sleep_interval=TCPDUMP_START_POLL_SECONDS,
            desc='Timeout waiting for tcpdump to start.')
        return pid


    @staticmethod
    def _local_capture_paths(remote_pcap, remote_pcap_log, save_dir,
                             local_pcap_filename=None):
        """Compute where a capture and its log are copied on the local host.

        @param remote_pcap string path of the pcap file on the remote host.
        @param remote_pcap_log string path of the log file on the remote host.
        @param save_dir string local directory to copy into; a falsy value
                means the current working directory.
        @param local_pcap_filename optional basename for the local pcap file.
                The log is then saved next to it as '<name>.log'. If None,
                the remote basenames are kept.
        @return tuple (local pcap path, local log path).

        """
        if local_pcap_filename:
            pcap_name = local_pcap_filename
            log_name = '%s.log' % local_pcap_filename
        else:
            pcap_name = os.path.basename(remote_pcap)
            log_name = os.path.basename(remote_pcap_log)
        save_dir = save_dir or ''
        return (os.path.join(save_dir, pcap_name),
                os.path.join(save_dir, log_name))


    def stop_capture(self, capture_pid=None, local_save_dir=None,
                     local_pcap_filename=None):
        """Stop an ongoing packet capture, or all ongoing packet captures.

        If |capture_pid| is given, stops that capture, otherwise stops all
        ongoing captures.

        When there is something to stop, this method first sleeps for a small
        amount of time so libpcap can complete its last poll(). The caller
        must ensure that no unwanted traffic is received during this time.

        Each capture's pcap and log files are copied to the local host, then
        removed from the remote host.

        @param capture_pid int pid of ongoing packet capture or None.
        @param local_save_dir path to directory to save pcap file in locally;
                overrides the directory given to start_capture().
        @param local_pcap_filename name of file to store pcap in
                (basename only).
        @return list of CaptureResult tuples, one per stopped capture.
        @raises KeyError if |capture_pid| is not a capture started by this
                object. No process is killed in that case.

        """
        if capture_pid:
            pids_to_kill = [capture_pid]
        else:
            pids_to_kill = list(self._ongoing_captures.keys())

        if not pids_to_kill:
            # Nothing is capturing, so there is no final poll to wait for.
            return []

        time.sleep(self.LIBPCAP_POLL_FREQ_SECS * 2)

        results = []
        for pid in pids_to_kill:
            # Look up before killing so an untracked pid is never signalled.
            remote_pcap, remote_pcap_log, save_dir = self._ongoing_captures[pid]
            self._host.run('kill -INT %d' % pid, ignore_status=True)
            pcap_filename, pcap_log_filename = self._local_capture_paths(
                    remote_pcap, remote_pcap_log, local_save_dir or save_dir,
                    local_pcap_filename)

            for remote_file, local_file in ((remote_pcap, pcap_filename),
                                            (remote_pcap_log,
                                             pcap_log_filename)):
                self._host.get_file(remote_file, local_file)
                self._host.run('rm -f %s' % remote_file)

            del self._ongoing_captures[pid]
            results.append(CaptureResult(pcap_filename,
                                         pcap_log_filename))
        return results
372