1#!/usr/bin/env python3.4
2#
3#   Copyright 2019 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the 'License');
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an 'AS IS' BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17import bokeh, bokeh.plotting, bokeh.io
18import collections
19import hashlib
20import ipaddress
21import itertools
22import json
23import logging
24import math
25import os
26import re
27import statistics
28import time
29from acts.controllers.android_device import AndroidDevice
30from acts.controllers.utils_lib import ssh
31from acts import asserts
32from acts import utils
33from acts_contrib.test_utils.wifi import wifi_test_utils as wutils
34from concurrent.futures import ThreadPoolExecutor
35
# Sleep/timeout durations in seconds. SHORT_SLEEP pads the ping command
# timeout in get_ping_stats; MED_SLEEP and TEST_TIMEOUT are presumably used by
# other helpers in this module (TODO confirm).
SHORT_SLEEP = 1
MED_SLEEP = 6
TEST_TIMEOUT = 10
# Shell commands (iw / wpa_cli) used to query wifi state on the device.
STATION_DUMP = 'iw wlan0 station dump'
SCAN = 'wpa_cli scan'
SCAN_RESULTS = 'wpa_cli scan_results'
SIGNAL_POLL = 'wpa_cli signal_poll'
WPA_CLI_STATUS = 'wpa_cli status'
# Presumably the output text signalling that the (Broadcom, per the name)
# wifi driver adapter has gone away -- verify against its users.
DISCONNECTION_MESSAGE_BRCM = 'driver adapter not found'
# 10 * log10(2): the dB equivalent of a factor-of-two power ratio.
CONST_3dB = 3.01029995664
# NaN placeholder for RSSI values that could not be obtained; its users are
# outside this part of the module -- confirm with callers.
RSSI_ERROR_VAL = float('nan')
# Parses a timestamped ping reply line such as
# "[1577836800.123456] 64 bytes from 1.2.3.4: icmp_seq=1 ttl=64 time=1.23 ms",
# capturing the bracketed timestamp and the value after "time=".
RTT_REGEX = re.compile(r'^\[(?P<timestamp>\S+)\] .*? time=(?P<rtt>\S+)')
# Captures the number preceding "% packet loss" in ping's summary line.
LOSS_REGEX = re.compile(r'(?P<loss>\S+)% packet loss')
# Captures the token between "FW:" and " HW:" (a firmware version string).
FW_REGEX = re.compile(r'FW:(?P<firmware>\S+) HW:')
50
51
52# Threading decorator
def nonblocking(f):
    """Creates a decorator transforming function calls to non-blocking.

    Each call of the decorated function submits `f(*args, **kwargs)` to a
    fresh single-worker ThreadPoolExecutor and immediately returns the
    resulting concurrent.futures.Future. Call `.result()` on it to wait for
    the return value (or to re-raise the exception raised by `f`).

    Args:
        f: callable to run in the background.
    Returns:
        A wrapper that keeps `f`'s name/docstring and returns a Future.
    """
    import functools

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        executor = ThreadPoolExecutor(max_workers=1)
        thread_future = executor.submit(f, *args, **kwargs)
        # shutdown(wait=False) does not cancel the submitted call; it returns
        # immediately and lets the executor free its resources once the call
        # returns or raises.
        executor.shutdown(wait=False)
        return thread_future

    return wrap
63
64
65# JSON serializer
def serialize_dict(input_dict):
    """Return a JSON-friendly OrderedDict copy of `input_dict`.

    Every key and value is passed through _serialize_value, so tuple keys
    (which json cannot encode) become strings and nested lists/dicts are
    converted recursively. The iteration order of `input_dict` is kept.
    """
    return collections.OrderedDict(
        (_serialize_value(key), _serialize_value(value))
        for key, value in input_dict.items())
72
73
74def _serialize_value(value):
75    """Function to recursively serialize dict entries to enable JSON output"""
76    if isinstance(value, tuple):
77        return str(value)
78    if isinstance(value, list):
79        return [_serialize_value(x) for x in value]
80    elif isinstance(value, dict):
81        return serialize_dict(value)
82    else:
83        return value
84
85
86# Miscellaneous Wifi Utilities
def extract_sub_dict(full_dict, fields):
    """Return an OrderedDict containing only `fields` from `full_dict`.

    Entries appear in the order given by `fields`; a KeyError propagates if
    any requested field is missing from `full_dict`.
    """
    sub_dict = collections.OrderedDict()
    for field in fields:
        sub_dict[field] = full_dict[field]
    return sub_dict
91
92
def validate_network(dut, ssid):
    """Check that DUT has a valid internet connection through expected SSID.

    Args:
        dut: android device of interest
        ssid: expected ssid
    Returns:
        True if wutils.validate_connection(dut) returns a non-None result and
        the connection info reported by the DUT has SSID == `ssid`; False
        otherwise, including when the connection check raises.
    """
    current_network = dut.droid.wifiGetConnectionInfo()
    try:
        connected = wutils.validate_connection(dut) is not None
    except Exception as err:
        # Any failure of the connectivity check counts as "not connected".
        # Catching Exception rather than using a bare except lets
        # KeyboardInterrupt / SystemExit propagate.
        logging.debug('Connection validation failed: {}'.format(err))
        connected = False
    return connected and current_network['SSID'] == ssid
109
110
def get_server_address(ssh_connection, dut_ip, subnet_mask):
    """Get server address on a specific subnet.

    This function retrieves the LAN or WAN IP of a remote machine used in
    testing. If subnet_mask is set to 'public' it returns a machines global ip,
    else it returns the ip belonging to the dut local network given the dut's
    ip and subnet mask.

    Args:
        ssh_connection: object representing server for which we want an ip
        dut_ip: string in ip address format, i.e., xxx.xxx.xxx.xxx
        subnet_mask: string representing subnet mask (public for global ip)
    Returns:
        The first matching IPv4 address (in ifconfig order) as a string, or
        None (after logging an error) if no address matches.
    """
    ifconfig_out = ssh_connection.run('ifconfig').stdout
    # Accepts both "inet 1.2.3.4" and "inet addr:1.2.3.4" output styles. Dots
    # are escaped so only dotted-quad tokens are captured, and the space after
    # "inet" keeps "inet6" lines out.
    ip_list = re.findall(r'inet (?:addr:)?(\d+\.\d+\.\d+\.\d+)', ifconfig_out)
    ip_list = [ipaddress.ip_address(ip) for ip in ip_list]

    if subnet_mask == 'public':
        for ip in ip_list:
            # is_global is not used to allow for CGNAT ips in 100.x.y.z range
            if not ip.is_private:
                return str(ip)
    else:
        dut_network = ipaddress.ip_network('{}/{}'.format(dut_ip, subnet_mask),
                                           strict=False)
        for ip in ip_list:
            if ip in dut_network:
                return str(ip)
    logging.error('No IP address found in requested subnet')
    return None
140
141
142# Plotting Utilities
class BokehFigure():
    """Class enabling simplified Bokeh plotting.

    Series added with add_line/add_scatter are stored as plain dicts in
    `figure_data`, and figure-wide settings in `fig_property`, so a figure can
    be saved to and re-loaded from JSON. The bokeh plot itself (`self.plot`)
    is only built by generate_figure().
    """

    COLORS = [
        'black',
        'blue',
        'blueviolet',
        'brown',
        'burlywood',
        'cadetblue',
        'cornflowerblue',
        'crimson',
        'cyan',
        'darkblue',
        'darkgreen',
        'darkmagenta',
        'darkorange',
        'darkred',
        'deepskyblue',
        'goldenrod',
        'green',
        'grey',
        'indigo',
        'navy',
        'olive',
        'orange',
        'red',
        'salmon',
        'teal',
        'yellow',
    ]
    MARKERS = [
        'asterisk', 'circle', 'circle_cross', 'circle_x', 'cross', 'diamond',
        'diamond_cross', 'hex', 'inverted_triangle', 'square', 'square_x',
        'square_cross', 'triangle', 'x'
    ]

    TOOLS = ('box_zoom,box_select,pan,crosshair,redo,undo,reset,hover,save')
    TOOLTIPS = [
        ('index', '$index'),
        ('(x,y)', '($x, $y)'),
        ('info', '@hover_text'),
    ]

    def __init__(self,
                 title=None,
                 x_label=None,
                 primary_y_label=None,
                 secondary_y_label=None,
                 height=700,
                 width=1100,
                 title_size='15pt',
                 axis_label_size='12pt',
                 json_file=None):
        """Create an empty figure, or load one previously saved as JSON.

        Args:
            title: plot title
            x_label: x-axis label
            primary_y_label: label of the default y-axis
            secondary_y_label: label of the secondary (right-hand) y-axis
            height: plot height passed to bokeh as plot_height
            width: plot width passed to bokeh as plot_width
            title_size: title font size, e.g., '15pt'
            axis_label_size: axis label font size, e.g., '12pt'
            json_file: optional path to a JSON file written by
                _save_figure_json; when given, all other arguments are ignored
        """
        if json_file:
            self.load_from_json(json_file)
        else:
            self.figure_data = []
            self.fig_property = {
                'title': title,
                'x_label': x_label,
                'primary_y_label': primary_y_label,
                'secondary_y_label': secondary_y_label,
                'num_lines': 0,
                'height': height,
                'width': width,
                'title_size': title_size,
                'axis_label_size': axis_label_size
            }

    def init_plot(self):
        """Create the bokeh figure in self.plot with hover and zoom tools.

        Two extra WheelZoomTools are added: one restricted to the x direction
        ('width') and one restricted to the y direction ('height').
        """
        self.plot = bokeh.plotting.figure(
            sizing_mode='scale_both',
            plot_width=self.fig_property['width'],
            plot_height=self.fig_property['height'],
            title=self.fig_property['title'],
            tools=self.TOOLS,
            output_backend='webgl')
        self.plot.hover.tooltips = self.TOOLTIPS
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='width'))
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='height'))

    def _filter_line(self, x_data, y_data, hover_text=None):
        """Function to remove NaN points from bokeh plots.

        Args:
            x_data: x values
            y_data: y values; points whose y is NaN or None are dropped
            hover_text: optional per-point hover strings
        Returns:
            Tuple of filtered (x_data, y_data, hover_text) lists. Inputs of
            unequal length are padded with None (itertools.zip_longest).
        """
        if hover_text is None:
            # Missing hover text is padded with None by zip_longest below.
            hover_text = []
        x_data_filtered = []
        y_data_filtered = []
        hover_text_filtered = []
        for x, y, hover in itertools.zip_longest(x_data, y_data, hover_text):
            if y is None or math.isnan(y):
                continue
            x_data_filtered.append(x)
            y_data_filtered.append(y)
            hover_text_filtered.append(hover)
        return x_data_filtered, y_data_filtered, hover_text_filtered

    def add_line(self,
                 x_data,
                 y_data,
                 legend,
                 hover_text=None,
                 color=None,
                 width=3,
                 style='solid',
                 marker=None,
                 marker_size=10,
                 shaded_region=None,
                 y_axis='default'):
        """Function to add line to existing BokehFigure.

        Points whose y value is NaN are dropped before storing the line.

        Args:
            x_data: list containing x-axis values for line
            y_data: list containing y_axis values for line
            legend: string containing line title
            hover_text: text to display when hovering over lines; defaults
                to 'y=<value>' for each point
            color: string describing line color; defaults to the next entry
                of COLORS
            width: integer line width
            style: string describing line style, e.g, solid or dashed
            marker: string specifying line marker, e.g., cross
            marker_size: marker size
            shaded_region: optional dict with 'x_vector', 'lower_limit' and
                'upper_limit' lists describing a shaded band to plot
            y_axis: identifier for y-axis to plot line against ('default' or
                'secondary')
        Raises:
            ValueError: if y_axis is not 'default' or 'secondary'.
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if style == 'dashed':
            style = [5, 5]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        x_data_filter, y_data_filter, hover_text_filter = self._filter_line(
            x_data, y_data, hover_text)
        self.figure_data.append({
            'x_data': x_data_filter,
            'y_data': y_data_filter,
            'legend': legend,
            'hover_text': hover_text_filter,
            'color': color,
            'width': width,
            'style': style,
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': shaded_region,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def add_scatter(self,
                    x_data,
                    y_data,
                    legend,
                    hover_text=None,
                    color=None,
                    marker=None,
                    marker_size=10,
                    y_axis='default'):
        """Function to add a scatter (markers only) series to BokehFigure.

        The series is stored with width 0, so generate_figure draws markers
        but no connecting line.

        Args:
            x_data: list containing x-axis values for the series
            y_data: list containing y_axis values for the series
            legend: string containing series title
            hover_text: text to display when hovering over points; defaults
                to 'y=<value>' for each point
            color: string describing marker color; defaults to the next entry
                of COLORS
            marker: string specifying marker, e.g., cross; defaults to the
                next entry of MARKERS
            marker_size: marker size
            y_axis: identifier for y-axis to plot series against ('default'
                or 'secondary')
        Raises:
            ValueError: if y_axis is not 'default' or 'secondary'.
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if marker is None:
            marker = self.MARKERS[self.fig_property['num_lines'] %
                                  len(self.MARKERS)]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        self.figure_data.append({
            'x_data': x_data,
            'y_data': y_data,
            'legend': legend,
            'hover_text': hover_text,
            'color': color,
            'width': 0,
            'style': 'solid',
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': None,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    @staticmethod
    def _shaded_band(shaded_region):
        """Build the closed polygon outlining a shaded region.

        The outline runs along lower_limit over x_vector and then back along
        upper_limit over x_vector reversed. New lists are built so that the
        stored region data (and the caller's lists) are never modified, which
        keeps repeated generate_figure() calls from growing the band.

        Args:
            shaded_region: dict with 'x_vector', 'lower_limit' and
                'upper_limit' sequences
        Returns:
            (band_x, band_y) lists suitable for bokeh's patch().
        """
        x_vector = list(shaded_region['x_vector'])
        band_x = x_vector + x_vector[::-1]
        band_y = (list(shaded_region['lower_limit']) +
                  list(shaded_region['upper_limit'])[::-1])
        return band_x, band_y

    def generate_figure(self, output_file=None, save_json=True):
        """Function to generate and save BokehFigure.

        Builds a fresh bokeh plot from figure_data. Series with width > 0 get
        a line glyph, series with a shaded_region get a shaded patch, and
        series whose marker is in MARKERS get marker glyphs. A right-hand
        secondary y-axis is added if any series uses y_axis='secondary'.

        Args:
            output_file: optional output file path; when given the figure is
                saved there via save_figure
            save_json: flag controlling whether a JSON copy of the figure data
                is saved alongside output_file
        Returns:
            the generated bokeh figure (also stored in self.plot)
        """
        self.init_plot()
        two_axes = False
        for line in self.figure_data:
            source = bokeh.models.ColumnDataSource(
                data=dict(x=line['x_data'],
                          y=line['y_data'],
                          hover_text=line['hover_text']))
            if line['width'] > 0:
                self.plot.line(x='x',
                               y='y',
                               legend_label=line['legend'],
                               line_width=line['width'],
                               color=line['color'],
                               line_dash=line['style'],
                               name=line['y_axis'],
                               y_range_name=line['y_axis'],
                               source=source)
            if line['shaded_region']:
                band_x, band_y = self._shaded_band(line['shaded_region'])
                self.plot.patch(band_x,
                                band_y,
                                color='#7570B3',
                                line_alpha=0.1,
                                fill_alpha=0.1)
            if line['marker'] in self.MARKERS:
                marker_func = getattr(self.plot, line['marker'])
                marker_func(x='x',
                            y='y',
                            size=line['marker_size'],
                            legend_label=line['legend'],
                            line_color=line['color'],
                            fill_color=line['color'],
                            name=line['y_axis'],
                            y_range_name=line['y_axis'],
                            source=source)
            if line['y_axis'] == 'secondary':
                two_axes = True

        # x-axis formatting
        self.plot.xaxis.axis_label = self.fig_property['x_label']
        self.plot.x_range.range_padding = 0
        self.plot.xaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        # y-axis formatting; the default range only tracks glyphs named
        # 'default' so secondary-axis data does not stretch it.
        self.plot.yaxis[0].axis_label = self.fig_property['primary_y_label']
        self.plot.yaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        self.plot.y_range = bokeh.models.DataRange1d(names=['default'])
        if two_axes and 'secondary' not in self.plot.extra_y_ranges:
            self.plot.extra_y_ranges = {
                'secondary': bokeh.models.DataRange1d(names=['secondary'])
            }
            self.plot.add_layout(
                bokeh.models.LinearAxis(
                    y_range_name='secondary',
                    axis_label=self.fig_property['secondary_y_label'],
                    axis_label_text_font_size=self.
                    fig_property['axis_label_size']), 'right')
        # plot formatting
        self.plot.legend.location = 'top_right'
        self.plot.legend.click_policy = 'hide'
        self.plot.title.text_font_size = self.fig_property['title_size']

        if output_file is not None:
            self.save_figure(output_file, save_json)
        return self.plot

    def load_from_json(self, file_path):
        """Load fig_property and figure_data from a JSON file.

        Args:
            file_path: path to a JSON file as written by _save_figure_json
        """
        with open(file_path, 'r') as json_file:
            fig_dict = json.load(json_file)
        self.fig_property = fig_dict['fig_property']
        self.figure_data = fig_dict['figure_data']

    def _save_figure_json(self, output_file):
        """Function to save a json format of a figure.

        Args:
            output_file: figure path; every '.html' in it is replaced with
                '_plot_data.json' to form the JSON file path
        """
        figure_dict = collections.OrderedDict(fig_property=self.fig_property,
                                              figure_data=self.figure_data)
        output_file = output_file.replace('.html', '_plot_data.json')
        with open(output_file, 'w') as outfile:
            json.dump(figure_dict, outfile, indent=4)

    def save_figure(self, output_file, save_json=True):
        """Function to save BokehFigure.

        Requires self.plot, i.e., generate_figure()/init_plot() must have run.

        Args:
            output_file: string specifying output file path
            save_json: flag controlling json outputs
        """
        if save_json:
            self._save_figure_json(output_file)
        bokeh.io.output_file(output_file)
        bokeh.io.save(self.plot)

    @staticmethod
    def save_figures(figure_array, output_file_path, save_json=True):
        """Function to save list of BokehFigures in one file.

        Each figure is generated and the resulting plots are stacked in a
        single column layout written to output_file_path.

        Args:
            figure_array: list of BokehFigure object to be plotted
            output_file_path: string specifying output file path
            save_json: flag controlling json outputs; figure number idx is
                saved to output_file_path with '.html' replaced by
                '<idx>-plot_data.json'
        """
        for idx, figure in enumerate(figure_array):
            figure.generate_figure()
            if save_json:
                json_file_path = output_file_path.replace(
                    '.html', '{}-plot_data.json'.format(idx))
                figure._save_figure_json(json_file_path)
        plot_array = [figure.plot for figure in figure_array]
        all_plots = bokeh.layouts.column(children=plot_array,
                                         sizing_mode='scale_width')
        bokeh.plotting.output_file(output_file_path)
        bokeh.plotting.save(all_plots)
457
458
459# Ping utilities
class PingResult(object):
    """An object that contains the results of running ping command.

    Attributes:
        connected: True if the ping output had more than one line and the
            reported packet loss was below 100%. False otherwise.
        packet_loss_percentage: The total percentage of packets lost, taken
            from the "% packet loss" summary line (100 if none was parsed).
        start_time: Absolute timestamp of the first parsed reply; stored
            transmission timestamps are relative to it (0 if none parsed).
        transmission_times: The list of PingTransmissionTimes containing the
            timestamps gathered for transmitted packets.
        rtts: A list-like object enumerating all round-trip-times of
            transmitted packets.
        timestamps: A list-like object enumerating the beginning timestamps of
            each packet transmission.
        ping_interarrivals: A list-like object enumerating the amount of time
            between the beginning of each subsequent transmission.
    """
    def __init__(self, ping_output):
        """Parse the output of a ping command.

        Args:
            ping_output: list of ping output lines. Reply lines are parsed
                with RTT_REGEX, which expects a leading bracketed timestamp
                (as printed by `ping -D`). Lines containing 'time=' or 'loss'
                that do not match the expected format are skipped rather than
                raising.
        """
        self.packet_loss_percentage = 100
        self.transmission_times = []

        # Views onto transmission_times; they reflect entries appended below.
        self.rtts = _ListWrap(self.transmission_times, lambda entry: entry.rtt)
        self.timestamps = _ListWrap(self.transmission_times,
                                    lambda entry: entry.timestamp)
        self.ping_interarrivals = _PingInterarrivals(self.transmission_times)

        self.start_time = 0
        for line in ping_output:
            if 'loss' in line:
                match = re.search(LOSS_REGEX, line)
                if match:
                    self.packet_loss_percentage = float(match.group('loss'))
            if 'time=' in line:
                match = re.search(RTT_REGEX, line)
                if not match:
                    # e.g. a reply line without the leading [timestamp].
                    logging.debug(
                        'Skipping unparsable ping line: {}'.format(line))
                    continue
                timestamp = float(match.group('timestamp'))
                if self.start_time == 0:
                    self.start_time = timestamp
                self.transmission_times.append(
                    PingTransmissionTimes(timestamp - self.start_time,
                                          float(match.group('rtt'))))
        self.connected = len(
            ping_output) > 1 and self.packet_loss_percentage < 100

    def __getitem__(self, item):
        """Dict-style access to 'rtt', 'connected' or 'packet_loss_percentage'.

        Raises:
            ValueError: for any other key.
        """
        if item == 'rtt':
            return self.rtts
        if item == 'connected':
            return self.connected
        if item == 'packet_loss_percentage':
            return self.packet_loss_percentage
        raise ValueError('Invalid key. Please use an attribute instead.')

    def as_dict(self):
        """Return the results as a plain dict of ints, floats and lists."""
        return {
            'connected': 1 if self.connected else 0,
            'rtt': list(self.rtts),
            'time_stamp': list(self.timestamps),
            'ping_interarrivals': list(self.ping_interarrivals),
            'packet_loss_percentage': self.packet_loss_percentage
        }
517
518
class PingTransmissionTimes:
    """Timing record for one packet sent by the ping command.

    Attributes:
        rtt: round trip time reported for the packet.
        timestamp: time at which the packet started its trip, as supplied by
            the caller (PingResult passes it relative to the first reply).
    """
    def __init__(self, timestamp, rtt):
        self.rtt, self.timestamp = rtt, timestamp
529
530
531class _ListWrap(object):
532    """A convenient helper class for treating list iterators as native lists."""
533    def __init__(self, wrapped_list, func):
534        self.__wrapped_list = wrapped_list
535        self.__func = func
536
537    def __getitem__(self, key):
538        return self.__func(self.__wrapped_list[key])
539
540    def __iter__(self):
541        for item in self.__wrapped_list:
542            yield self.__func(item)
543
544    def __len__(self):
545        return len(self.__wrapped_list)
546
547
548class _PingInterarrivals(object):
549    """A helper class for treating ping interarrivals as a native list."""
550    def __init__(self, ping_entries):
551        self.__ping_entries = ping_entries
552
553    def __getitem__(self, key):
554        return (self.__ping_entries[key + 1].timestamp -
555                self.__ping_entries[key].timestamp)
556
557    def __iter__(self):
558        for index in range(len(self.__ping_entries) - 1):
559            yield self[index]
560
561    def __len__(self):
562        return max(0, len(self.__ping_entries) - 1)
563
564
def get_ping_stats(src_device, dest_address, ping_duration, ping_interval,
                   ping_size):
    """Run ping to or from the DUT.

    Pings dest_address from src_device. src_device may be an AndroidDevice
    (ping runs through adb shell) or an ssh connection to a Linux or macOS
    host (ping runs with sudo).

    Args:
        src_device: object representing device to ping from
        dest_address: ip address to ping
        ping_duration: timeout to set on the the ping process (in seconds)
        ping_interval: time between pings (in seconds)
        ping_size: size of ping packet payload
    Returns:
        ping_result: PingResult parsed from the ping output
    Raises:
        TypeError: if src_device is neither an AndroidDevice nor an
            SshConnection.
        ValueError: if an ssh host's `uname` reports neither Linux nor Darwin.
    """
    # Round before truncating so float error (0.3 / 0.1 evaluates to
    # 2.9999999999999996) does not drop a packet; send at least one packet.
    ping_count = max(1, int(round(ping_duration / ping_interval, 6)))
    ping_deadline = int(ping_count * ping_interval) + 1
    ping_cmd_linux = 'ping -c {} -w {} -i {} -s {} -D'.format(
        ping_count,
        ping_deadline,
        ping_interval,
        ping_size,
    )

    ping_cmd_macos = 'ping -c {} -t {} -i {} -s {}'.format(
        ping_count,
        ping_deadline,
        ping_interval,
        ping_size,
    )

    if isinstance(src_device, AndroidDevice):
        ping_cmd = '{} {}'.format(ping_cmd_linux, dest_address)
        ping_output = src_device.adb.shell(ping_cmd,
                                           timeout=ping_deadline + SHORT_SLEEP,
                                           ignore_status=True)
    elif isinstance(src_device, ssh.connection.SshConnection):
        platform = src_device.run('uname').stdout
        if 'linux' in platform.lower():
            ping_cmd = 'sudo {} {}'.format(ping_cmd_linux, dest_address)
        elif 'darwin' in platform.lower():
            # The macOS command has no -D flag, so each output line is
            # prefixed with "[<epoch seconds>.<nanoseconds>]" from gdate
            # (presumably GNU coreutils date, which supports %N) to produce
            # the timestamped format PingResult parses.
            ping_cmd = (
                "sudo {} {}| while IFS= read -r line; do "
                "printf '[%s] %s\n' \"$(gdate '+%s.%N')\" \"$line\"; done"
            ).format(ping_cmd_macos, dest_address)
        else:
            # Previously fell through with ping_cmd unbound.
            raise ValueError(
                'Unsupported platform for ping source: {}'.format(platform))
        ping_output = src_device.run(ping_cmd,
                                     timeout=ping_deadline + SHORT_SLEEP,
                                     ignore_status=True).stdout
    else:
        raise TypeError('Unable to ping using src_device of type %s.' %
                        type(src_device))
    return PingResult(ping_output.splitlines())
616
617
@nonblocking
def get_ping_stats_nb(src_device, dest_address, ping_duration, ping_interval,
                      ping_size):
    """Non-blocking variant of get_ping_stats.

    Through the nonblocking decorator this returns a concurrent.futures.Future
    right away; its result() is the PingResult from get_ping_stats.
    """
    ping_args = (src_device, dest_address, ping_duration, ping_interval,
                 ping_size)
    return get_ping_stats(*ping_args)
623
624
625# Iperf utilities
@nonblocking
def start_iperf_client_nb(iperf_client, iperf_server_address, iperf_args, tag,
                          timeout):
    """Non-blocking wrapper around iperf_client.start.

    Through the nonblocking decorator this returns a concurrent.futures.Future
    right away; its result() is whatever iperf_client.start(...) returns.
    """
    start_client = iperf_client.start
    return start_client(iperf_server_address, iperf_args, tag, timeout)
630
631
def get_iperf_arg_string(duration,
                         reverse_direction,
                         interval=1,
                         traffic_type='TCP',
                         socket_size=None,
                         num_processes=1,
                         udp_throughput='1000M',
                         ipv6=False):
    """Build the iperf client argument string used in throughput tests.

    Args:
        duration: iperf duration in seconds
        reverse_direction: boolean controlling the -R flag for iperf clients
        interval: iperf print interval
        traffic_type: string specifying TCP or UDP traffic (case insensitive);
            any other value adds neither UDP options nor -P
        socket_size: string specifying TCP window or socket buffer, e.g., 2M
        num_processes: int specifying number of iperf processes
        udp_throughput: string specifying TX throughput in UDP tests, e.g. 100M
        ipv6: boolean controlling the use of IP V6
    Returns:
        iperf_args: string of formatted iperf args
    """
    arg_parts = ['-i {} -t {} -J '.format(interval, duration)]
    if ipv6:
        arg_parts.append('-6 ')
    protocol = traffic_type.upper()
    if protocol == 'UDP':
        arg_parts.append('-u -b {} -l 1470 -P {} '.format(
            udp_throughput, num_processes))
    elif protocol == 'TCP':
        arg_parts.append('-P {} '.format(num_processes))
    if socket_size:
        arg_parts.append('-w {} '.format(socket_size))
    if reverse_direction:
        # The leading space (after the previous part's trailing space) yields
        # a double space before -R; kept so the output string is unchanged.
        arg_parts.append(' -R')
    return ''.join(arg_parts)
670
671
672# Attenuator Utilities
def atten_by_label(atten_list, path_label, atten_level):
    """Set `atten_level` on every attenuator whose path matches `path_label`.

    Args:
        atten_list: list of attenuators to iterate over
        path_label: value tested with `in` against each attenuator's `path`
        atten_level: attenuation desired on path
    """
    matching = (atten for atten in atten_list if path_label in atten.path)
    for attenuator in matching:
        attenuator.set_atten(atten_level)
684
685
def _measure_rssi_under_ping(dut, ping_server, dut_ip):
    """Measure mean signal_poll RSSI while ping_server pings the DUT.

    Starts a 1.5 s ping (20 ms interval, 64 byte payload) from ping_server to
    dut_ip in the background, samples RSSI with get_connected_rssi, and waits
    for the ping to complete before returning.

    Returns:
        The 'mean' of the 'signal_poll_rssi' entry from get_connected_rssi.
    """
    ping_future = get_ping_stats_nb(src_device=ping_server,
                                    dest_address=dut_ip,
                                    ping_duration=1.5,
                                    ping_interval=0.02,
                                    ping_size=64)
    rssi_result = get_connected_rssi(dut,
                                     num_measurements=4,
                                     polling_frequency=0.25,
                                     first_measurement_delay=0.5,
                                     disconnect_warning=1,
                                     ignore_samples=1)
    ping_future.result()
    return rssi_result['signal_poll_rssi']['mean']


def get_atten_for_target_rssi(target_rssi, attenuators, dut, ping_server):
    """Function to estimate attenuation to hit a target RSSI.

    This function estimates a constant attenuation setting on all atennuation
    ports to hit a target RSSI. The estimate is not meant to be exact or
    guaranteed. Starting from 0 dB, each iteration steps the attenuation by
    the remaining RSSI error (clamped to +/-20 dB, truncated to a 0.25 dB
    grid). The search stops once two consecutive readings are within 1 dB of
    the target, or after 20 iterations.

    Args:
        target_rssi: rssi of interest
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
    Returns:
        target_atten: attenuation setting to achieve target_rssi. 0 if the
            estimate drops below 0, and attenuators[0].get_max_atten() if it
            exceeds that maximum.
    """
    logging.info('Searching attenuation for RSSI = {}dB'.format(target_rssi))
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    # Measure starting RSSI (with ping traffic running)
    current_rssi = _measure_rssi_under_ping(dut, ping_server, dut_ip)
    target_atten = 0
    logging.debug('RSSI @ {0:.2f}dB attenuation = {1:.2f}'.format(
        target_atten, current_rssi))
    within_range = False
    for idx in range(20):
        atten_delta = max(min(current_rssi - target_rssi, 20), -20)
        # int() truncates, keeping the setting on a 0.25 dB grid.
        target_atten = int((target_atten + atten_delta) * 4) / 4
        if target_atten < 0:
            return 0
        max_atten = attenuators[0].get_max_atten()
        if target_atten > max_atten:
            return max_atten
        for atten in attenuators:
            atten.set_atten(target_atten, strict=False)
        current_rssi = _measure_rssi_under_ping(dut, ping_server, dut_ip)
        logging.info('RSSI @ {0:.2f}dB attenuation = {1:.2f}'.format(
            target_atten, current_rssi))
        # Accept only after two consecutive readings within 1 dB of target.
        if abs(current_rssi - target_rssi) < 1:
            if within_range:
                logging.info(
                    'Reached RSSI: {0:.2f}. Target RSSI: {1:.2f}. '
                    'Attenuation: {2:.2f}, Iterations = {3}'.format(
                        current_rssi, target_rssi, target_atten, idx))
                return target_atten
            within_range = True
        else:
            within_range = False
    return target_atten
761
762
def get_current_atten_dut_chain_map(attenuators,
                                    dut,
                                    ping_server,
                                    ping_from_dut=False):
    """Function to detect mapping between attenuator ports and DUT chains.

    This function detects the mapping between attenuator ports and DUT chains
    in cases where DUT chains are connected to only one attenuator port. The
    function assumes the DUT is already connected to a wifi network. The
    function starts by measuring per chain RSSI at 0 attenuation, then
    attenuates one port at a time looking for the chain that reports a lower
    RSSI. Each port is restored to 0 dB after its measurement, even if the
    measurement raises.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
        ping_from_dut: boolean controlling whether to ping from or to dut
    Returns:
        chain_map: list of dut chains, one entry per attenuator port:
            'DUT-Chain-0' or 'DUT-Chain-1' for the chain whose RSSI dropped
            by more than 10 dB, or None if neither did.
    """
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    # Start background ping traffic (11 s, 20 ms interval, 64 byte payload)
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    if ping_from_dut:
        ping_future = get_ping_stats_nb(dut, ping_server._settings.hostname,
                                        11, 0.02, 64)
    else:
        ping_future = get_ping_stats_nb(ping_server, dut_ip, 11, 0.02, 64)
    # Measure starting RSSI
    base_rssi = get_connected_rssi(dut, 4, 0.25, 1)
    chain0_base_rssi = base_rssi['chain_0_rssi']['mean']
    chain1_base_rssi = base_rssi['chain_1_rssi']['mean']
    if chain0_base_rssi < -70 or chain1_base_rssi < -70:
        logging.warning('RSSI might be too low to get reliable chain map.')
    # Compile chain map by attenuating one path at a time and seeing which
    # chain's RSSI degrades
    chain_map = []
    for test_atten in attenuators:
        # Set one attenuator to 30 dB down
        test_atten.set_atten(30, strict=False)
        try:
            # Get new RSSI
            test_rssi = get_connected_rssi(dut, 4, 0.25, 1)
        finally:
            # Reset attenuator to 0 even if the measurement failed, so the
            # port is not left at 30 dB.
            test_atten.set_atten(0, strict=False)
        # Assign attenuator to path that has lower RSSI
        if chain0_base_rssi > -70 and chain0_base_rssi - test_rssi[
                'chain_0_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-0')
        elif chain1_base_rssi > -70 and chain1_base_rssi - test_rssi[
                'chain_1_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-1')
        else:
            chain_map.append(None)
    ping_future.result()
    logging.debug('Chain Map: {}'.format(chain_map))
    return chain_map
822
823
def get_full_rf_connection_map(attenuators,
                               dut,
                               ping_server,
                               networks,
                               ping_from_dut=False):
    """Maps attenuator ports to DUT chains for every network in `networks`.

    All ports are first set to 0 dB. Then, for each network, wifi is reset,
    the DUT is connected (single try, no assert, no connectivity check) and
    get_current_atten_dut_chain_map is run. The results are returned indexed
    two ways, so callers can look up either by network or by attenuator port.

    Args:
        attenuators: list of attenuator ports
        dut: android device object to map
        ping_server: ssh connection object to ping server
        networks: dict of network IDs and configs
        ping_from_dut: boolean controlling whether to ping from or to dut
    Returns:
        rf_map_by_network: OrderedDict mapping each network ID to its chain
            map (one entry per attenuator port, None when unmapped).
        rf_map_by_atten: list with one entry per attenuator port; each entry
            is a list of {'network': net_id, 'dut_chain': chain} dicts for the
            networks where that port maps to a chain.
    """
    for port in attenuators:
        port.set_atten(0, strict=False)

    rf_map_by_network = collections.OrderedDict()
    rf_map_by_atten = [[] for _ in attenuators]
    for network_id, network_config in networks.items():
        wutils.reset_wifi(dut)
        wutils.wifi_connect(dut,
                            network_config,
                            num_of_tries=1,
                            assert_on_fail=False,
                            check_connectivity=False)
        network_chain_map = get_current_atten_dut_chain_map(
            attenuators, dut, ping_server, ping_from_dut)
        rf_map_by_network[network_id] = network_chain_map
        for port_idx, dut_chain in enumerate(network_chain_map):
            if not dut_chain:
                continue
            rf_map_by_atten[port_idx].append({
                'network': network_id,
                'dut_chain': dut_chain
            })
    logging.debug('RF Map (by Network): {}'.format(rf_map_by_network))
    logging.debug('RF Map (by Atten): {}'.format(rf_map_by_atten))

    return rf_map_by_network, rf_map_by_atten
871
872
873# Generic device utils
def get_dut_temperature(dut):
    """Function to get dut temperature.

    The function fetches and returns the reading from the temperature sensor
    used for skin temperature and thermal throttling. Candidate thermal zones
    are tried in order and the first one that can be read and parsed as an
    integer is used. Readings above 100 are assumed to be in millidegrees and
    are divided by 1000.

    Args:
        dut: AndroidDevice of interest
    Returns:
        temperature: device temperature. 0 if temperature could not be read
    """
    candidate_zones = [
        '/sys/devices/virtual/thermal/tz-by-name/skin-therm/temp',
        '/sys/devices/virtual/thermal/tz-by-name/sdm-therm-monitor/temp',
        '/sys/devices/virtual/thermal/tz-by-name/sdm-therm-adc/temp',
        '/sys/devices/virtual/thermal/tz-by-name/back_therm/temp',
        '/dev/thermal/tz-by-name/quiet_therm/temp'
    ]
    temperature = 0
    for zone in candidate_zones:
        try:
            temperature = int(dut.adb.shell('cat {}'.format(zone)))
        except Exception:
            # Zone missing or unreadable on this device: try the next one.
            # (Only Exception is caught so KeyboardInterrupt/SystemExit
            # still propagate.)
            temperature = 0
            continue
        break
    if temperature == 0:
        logging.debug('Could not check DUT temperature.')
    elif temperature > 100:
        temperature = temperature / 1000
    return temperature
903
904
def wait_for_dut_cooldown(dut, target_temp=50, timeout=300):
    """Function to wait for a DUT to cool down.

    Polls the DUT temperature every SHORT_SLEEP seconds until it drops below
    target_temp or timeout seconds have elapsed. It returns in both cases and
    does not report which one happened.

    Args:
        dut: AndroidDevice of interest
        target_temp: target cooldown temperature
        timeout: max time (seconds) to wait for cooldown
    """
    start_time = time.time()
    # Read at least once so the final log always has a temperature, even when
    # timeout <= 0 (previously this raised UnboundLocalError).
    temperature = get_dut_temperature(dut)
    while temperature >= target_temp and time.time() - start_time < timeout:
        time.sleep(SHORT_SLEEP)
        temperature = get_dut_temperature(dut)
    elapsed_time = time.time() - start_time
    logging.debug('DUT Final Temperature: {}C. Cooldown duration: {}'.format(
        temperature, elapsed_time))
922
923
def health_check(dut, batt_thresh=5, temp_threshold=53, cooldown=1):
    """Function to check health status of a DUT.

    The function checks both battery levels and temperature to avoid DUT
    powering off during the test.

    Args:
        dut: AndroidDevice of interest
        batt_thresh: battery level threshold
        temp_threshold: temperature threshold
        cooldown: flag to wait for DUT to cool down when overheating. The wait
            targets temp_threshold - 5; if the DUT is still above
            temp_threshold afterwards, it is reported unhealthy.
    Returns:
        health_check: boolean confirming device is healthy
    """
    health_check = True
    battery_level = utils.get_battery_level(dut)
    if battery_level < batt_thresh:
        logging.warning('Battery level low ({}%)'.format(battery_level))
        health_check = False
    else:
        logging.debug('Battery level = {}%'.format(battery_level))

    temperature = get_dut_temperature(dut)
    if temperature > temp_threshold:
        if cooldown:
            logging.warning(
                'Waiting for DUT to cooldown. ({} C)'.format(temperature))
            wait_for_dut_cooldown(dut, target_temp=temp_threshold - 5)
            # wait_for_dut_cooldown gives up after its timeout, so re-read the
            # temperature instead of assuming the DUT has cooled down.
            temperature = get_dut_temperature(dut)
            if temperature > temp_threshold:
                logging.warning('DUT Overheating ({} C)'.format(temperature))
                health_check = False
        else:
            logging.warning('DUT Overheating ({} C)'.format(temperature))
            health_check = False
    else:
        logging.debug('DUT Temperature = {} C'.format(temperature))
    return health_check
958
959
960# Wifi Device utils
def detect_wifi_platform(dut):
    """Returns the wifi chipset vendor of the DUT.

    A DUT is considered 'qcom' when /vendor/firmware/wlan/qca_cld/ lists at
    least one file, and 'brcm' otherwise.

    Args:
        dut: AndroidDevice to inspect
    Returns:
        'qcom' or 'brcm'
    """
    qca_cld_files = dut.get_file_names('/vendor/firmware/wlan/qca_cld/')
    return 'qcom' if len(qca_cld_files) else 'brcm'
968
969
def detect_wifi_decorator(f):
    """Dispatches calls of f to its platform-specific implementation.

    A call to the decorated function is forwarded, with the same arguments,
    to the module-level function named '<f.__name__>_<platform>', where
    platform is detect_wifi_platform(dut). The body of f itself is never
    executed. The DUT is taken from the 'dut' keyword argument when present,
    otherwise from the first positional argument that is an AndroidDevice.

    Raises:
        ValueError: if no DUT can be found among the call arguments.
    """
    import functools

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        if 'dut' in kwargs:
            dut = kwargs['dut']
        else:
            # isinstance (not an exact type check) so AndroidDevice
            # subclasses and spec'd test doubles are also accepted.
            dut = next(
                (arg for arg in args if isinstance(arg, AndroidDevice)), None)
            if dut is None:
                raise ValueError(
                    '{}: no AndroidDevice argument found to detect the wifi '
                    'platform.'.format(f.__name__))
        platform_function_name = '{}_{}'.format(f.__name__,
                                                detect_wifi_platform(dut))
        platform_function = globals()[platform_function_name]
        return platform_function(*args, **kwargs)

    return wrap
981
982
983# Rssi Utilities
def empty_rssi_result():
    """Returns a new, empty RSSI result container.

    The container is an OrderedDict with keys 'data' (an empty list of
    samples), 'mean' and 'stdev' (both None until statistics are computed).
    A fresh object is built on every call, so results never share state.
    """
    result = collections.OrderedDict()
    result['data'] = []
    result['mean'] = None
    result['stdev'] = None
    return result
987
988
@detect_wifi_decorator
def get_connected_rssi(dut,
                       num_measurements=1,
                       polling_frequency=SHORT_SLEEP,
                       first_measurement_delay=0,
                       disconnect_warning=True,
                       ignore_samples=0,
                       interface=None):
    """Collects all RSSI values reported for the connected access point/BSSID.

    Dispatch stub: detect_wifi_decorator forwards every call to
    get_connected_rssi_qcom or get_connected_rssi_brcm depending on the
    detected wifi platform, so this body is never run.

    Args:
        dut: android device object from which to get RSSI
        num_measurements: number of measurement rounds
        polling_frequency: time (seconds) to wait between RSSI measurements
        first_measurement_delay: time (seconds) to wait before the first
            measurement
        disconnect_warning: boolean controlling disconnection logging messages
        ignore_samples: number of leading samples to ignore
        interface: optional wifi interface name (platform-dependent use)
    Returns:
        connected_rssi: dict containing the measurements results for
        all reported RSSI values (signal_poll, per chain, etc.) and their
        statistics
    """
1011
1012
@nonblocking
def get_connected_rssi_nb(dut,
                          num_measurements=1,
                          polling_frequency=SHORT_SLEEP,
                          first_measurement_delay=0,
                          disconnect_warning=True,
                          ignore_samples=0,
                          interface=None):
    """Non-blocking variant of get_connected_rssi.

    The @nonblocking wrapper submits this call to a worker thread; see that
    decorator for what is returned to the caller (presumably a future whose
    result is the get_connected_rssi dict). Arguments are identical to
    get_connected_rssi.
    """
    return get_connected_rssi(dut,
                              num_measurements=num_measurements,
                              polling_frequency=polling_frequency,
                              first_measurement_delay=first_measurement_delay,
                              disconnect_warning=disconnect_warning,
                              ignore_samples=ignore_samples,
                              interface=interface)
1024
1025
def get_connected_rssi_qcom(dut,
                            num_measurements=1,
                            polling_frequency=SHORT_SLEEP,
                            first_measurement_delay=0,
                            disconnect_warning=True,
                            ignore_samples=0,
                            interface=None):
    """Collects connection and RSSI measurements on Qualcomm-based DUTs.

    Polls `wpa_cli status`, `wpa_cli signal_poll` and, when no interface is
    given, the station dump num_measurements times, then computes mean and
    stdev for each RSSI series over the valid (non-NaN) readings.

    Args:
        dut: android device object from which to get RSSI
        num_measurements: number of measurement rounds
        polling_frequency: target period (seconds) between measurement rounds
        first_measurement_delay: seconds to wait before the first round
        disconnect_warning: log a warning on a connected->disconnected change
        ignore_samples: number of leading valid samples dropped before
            computing statistics (not applied if it would drop every sample)
        interface: optional wpa_cli interface name; when set, per-chain RSSI
            is not collected and is reported as RSSI_ERROR_VAL
    Returns:
        connected_rssi: OrderedDict with per-round 'time_stamp', 'bssid',
        'ssid' and 'frequency' lists, plus result dicts (data/mean/stdev) for
        'signal_poll_rssi', 'signal_poll_avg_rssi', 'chain_0_rssi' and
        'chain_1_rssi'. Invalid readings are RSSI_ERROR_VAL.
    """
    # yapf: disable
    connected_rssi = collections.OrderedDict(
        [('time_stamp', []),
         ('bssid', []), ('ssid', []), ('frequency', []),
         ('signal_poll_rssi', empty_rssi_result()),
         ('signal_poll_avg_rssi', empty_rssi_result()),
         ('chain_0_rssi', empty_rssi_result()),
         ('chain_1_rssi', empty_rssi_result())])
    # yapf: enable
    if interface is None:
        status_cmd = WPA_CLI_STATUS
        signal_poll_cmd = SIGNAL_POLL
    else:
        status_cmd = 'wpa_cli -i {} status'.format(interface)
        signal_poll_cmd = 'wpa_cli -i {} signal_poll'.format(interface)
    previous_bssid = 'disconnected'
    t0 = time.time()
    time.sleep(first_measurement_delay)
    for _ in range(num_measurements):
        measurement_start_time = time.time()
        connected_rssi['time_stamp'].append(measurement_start_time - t0)
        # Get connection status (BSSID / SSID)
        try:
            status_output = dut.adb.shell(status_cmd)
        except Exception:
            status_output = ''
        match = re.search('bssid=.*', status_output)
        if match:
            current_bssid = match.group(0).split('=')[1]
            connected_rssi['bssid'].append(current_bssid)
        else:
            current_bssid = 'disconnected'
            connected_rssi['bssid'].append(current_bssid)
            if disconnect_warning and previous_bssid != 'disconnected':
                logging.warning('WIFI DISCONNECT DETECTED!')
        previous_bssid = current_bssid
        # Leading whitespace prevents matching the 'bssid=' line.
        match = re.search(r'\s+ssid=.*', status_output)
        if match:
            # maxsplit=1 keeps SSIDs that contain '=' intact.
            connected_rssi['ssid'].append(match.group(0).split('=', 1)[1])
        else:
            connected_rssi['ssid'].append('disconnected')
        # Get signal poll RSSI and frequency
        try:
            signal_poll_output = dut.adb.shell(signal_poll_cmd)
        except Exception:
            signal_poll_output = ''
        match = re.search('FREQUENCY=.*', signal_poll_output)
        if match:
            frequency = int(match.group(0).split('=')[1])
            connected_rssi['frequency'].append(frequency)
        else:
            connected_rssi['frequency'].append(RSSI_ERROR_VAL)
        match = re.search('RSSI=.*', signal_poll_output)
        if match:
            temp_rssi = int(match.group(0).split('=')[1])
            # -9999 and 0 are treated as invalid readings
            if temp_rssi == -9999 or temp_rssi == 0:
                connected_rssi['signal_poll_rssi']['data'].append(
                    RSSI_ERROR_VAL)
            else:
                connected_rssi['signal_poll_rssi']['data'].append(temp_rssi)
        else:
            connected_rssi['signal_poll_rssi']['data'].append(RSSI_ERROR_VAL)
        match = re.search('AVG_RSSI=.*', signal_poll_output)
        if match:
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                int(match.group(0).split('=')[1]))
        else:
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                RSSI_ERROR_VAL)

        # Get per chain RSSI (only for the default interface)
        chain_rssi = None
        if interface is None:
            try:
                station_dump = dut.adb.shell(STATION_DUMP)
            except Exception:
                station_dump = ''
            if re.search('.*signal avg:.*', station_dump):
                # NOTE(review): this parses the first bracketed list in the
                # whole station dump, which may come from the 'signal:' line
                # rather than 'signal avg:' -- confirm which one is intended.
                bracket_start = station_dump.find('[')
                bracket_end = station_dump.find(']')
                if 0 <= bracket_start < bracket_end:
                    chain_values = station_dump[bracket_start +
                                                1:bracket_end].split(', ')
                    try:
                        chain_rssi = (int(chain_values[0]),
                                      int(chain_values[1]))
                    except (IndexError, ValueError):
                        # e.g. single-chain output: record as invalid
                        # instead of aborting the whole measurement.
                        chain_rssi = None
        if chain_rssi is not None:
            connected_rssi['chain_0_rssi']['data'].append(chain_rssi[0])
            connected_rssi['chain_1_rssi']['data'].append(chain_rssi[1])
        else:
            connected_rssi['chain_0_rssi']['data'].append(RSSI_ERROR_VAL)
            connected_rssi['chain_1_rssi']['data'].append(RSSI_ERROR_VAL)
        measurement_elapsed_time = time.time() - measurement_start_time
        time.sleep(max(0, polling_frequency - measurement_elapsed_time))

    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no valid connected readings found.
    for val in connected_rssi.values():
        # Only RSSI result dicts carry statistics. A membership test on the
        # plain lists (e.g. an SSID literally named 'data') must not match.
        if not isinstance(val, dict):
            continue
        filtered_rssi_values = [x for x in val['data'] if not math.isnan(x)]
        if len(filtered_rssi_values) > ignore_samples:
            filtered_rssi_values = filtered_rssi_values[ignore_samples:]
        if filtered_rssi_values:
            val['mean'] = statistics.mean(filtered_rssi_values)
            if len(filtered_rssi_values) > 1:
                val['stdev'] = statistics.stdev(filtered_rssi_values)
            else:
                val['stdev'] = 0
        else:
            val['mean'] = RSSI_ERROR_VAL
            val['stdev'] = RSSI_ERROR_VAL
    return connected_rssi
1147
1148
def get_connected_rssi_brcm(dut,
                            num_measurements=1,
                            polling_frequency=SHORT_SLEEP,
                            first_measurement_delay=0,
                            disconnect_warning=True,
                            ignore_samples=0,
                            interface=None):
    """Collects connection and RSSI measurements on Broadcom-based DUTs.

    Polls `wl assoc` and `wl phy_rssi_ant` num_measurements times, then
    computes mean and stdev for each RSSI series over the valid (non-NaN)
    readings. wpa_cli signal_poll is not queried. Instead, 'signal_poll_rssi'
    and 'signal_poll_avg_rssi' both hold the power sum of the two per-chain
    readings.

    Args:
        dut: android device object from which to get RSSI
        num_measurements: number of measurement rounds
        polling_frequency: target period (seconds) between measurement rounds
        first_measurement_delay: seconds to wait before the first round
        disconnect_warning: log a warning on a connected->disconnected change
        ignore_samples: number of leading valid samples dropped before
            computing statistics (not applied if it would drop every sample)
        interface: accepted for signature compatibility; unused here
    Returns:
        connected_rssi: OrderedDict with per-round 'time_stamp', 'bssid',
        'ssid' and 'frequency' lists, plus result dicts (data/mean/stdev) for
        'signal_poll_rssi', 'signal_poll_avg_rssi', 'chain_0_rssi' and
        'chain_1_rssi'. Note 'frequency' holds the primary channel number
        parsed from `wl assoc`, not a frequency.
    """
    # yapf: disable
    connected_rssi = collections.OrderedDict(
        [('time_stamp', []),
         ('bssid', []), ('ssid', []), ('frequency', []),
         ('signal_poll_rssi', empty_rssi_result()),
         ('signal_poll_avg_rssi', empty_rssi_result()),
         ('chain_0_rssi', empty_rssi_result()),
         ('chain_1_rssi', empty_rssi_result())])
    # yapf: enable
    previous_bssid = 'disconnected'
    t0 = time.time()
    time.sleep(first_measurement_delay)
    for _ in range(num_measurements):
        measurement_start_time = time.time()
        connected_rssi['time_stamp'].append(measurement_start_time - t0)
        # Get association status. A failing adb call is treated as
        # disconnected, matching the qcom implementation.
        try:
            status_output = dut.adb.shell('wl assoc')
        except Exception:
            status_output = ''
        match = re.search('BSSID:.*', status_output)
        if match:
            # Keep the token after 'BSSID:' and before the first tab.
            current_bssid = match.group(0).split('\t')[0].split(' ')[1]
            connected_rssi['bssid'].append(current_bssid)
        else:
            current_bssid = 'disconnected'
            connected_rssi['bssid'].append(current_bssid)
            if disconnect_warning and previous_bssid != 'disconnected':
                logging.warning('WIFI DISCONNECT DETECTED!')
        previous_bssid = current_bssid
        match = re.search('SSID:.*', status_output)
        if match:
            # maxsplit=1 keeps SSIDs that contain ': ' intact.
            connected_rssi['ssid'].append(match.group(0).split(': ', 1)[1])
        else:
            connected_rssi['ssid'].append('disconnected')

        # TODO: SEARCH MAP ; PICK CENTER CHANNEL
        match = re.search('Primary channel:.*', status_output)
        if match:
            frequency = int(match.group(0).split(':')[1])
            connected_rssi['frequency'].append(frequency)
        else:
            connected_rssi['frequency'].append(RSSI_ERROR_VAL)

        # Get per chain RSSI
        try:
            per_chain_output = dut.adb.shell('wl phy_rssi_ant')
        except Exception:
            per_chain_output = DISCONNECTION_MESSAGE_BRCM
        chain_rssi = None
        if DISCONNECTION_MESSAGE_BRCM not in per_chain_output:
            fields = per_chain_output.split(' ')
            try:
                chain_rssi = (int(fields[1]), int(fields[4]))
            except (IndexError, ValueError):
                # Unexpected output format: record this round as invalid
                # instead of aborting the whole measurement.
                chain_rssi = None
        if chain_rssi is not None:
            chain_0_rssi, chain_1_rssi = chain_rssi
            connected_rssi['chain_0_rssi']['data'].append(chain_0_rssi)
            connected_rssi['chain_1_rssi']['data'].append(chain_1_rssi)
            # Combine chains by summing linear powers, then back to dB.
            combined_rssi = math.pow(10, chain_0_rssi / 10) + math.pow(
                10, chain_1_rssi / 10)
            combined_rssi = 10 * math.log10(combined_rssi)
            connected_rssi['signal_poll_rssi']['data'].append(combined_rssi)
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                combined_rssi)
        else:
            connected_rssi['chain_0_rssi']['data'].append(RSSI_ERROR_VAL)
            connected_rssi['chain_1_rssi']['data'].append(RSSI_ERROR_VAL)
            connected_rssi['signal_poll_rssi']['data'].append(RSSI_ERROR_VAL)
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                RSSI_ERROR_VAL)
        measurement_elapsed_time = time.time() - measurement_start_time
        time.sleep(max(0, polling_frequency - measurement_elapsed_time))

    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no valid readings found.
    for val in connected_rssi.values():
        # Only RSSI result dicts carry statistics. A membership test on the
        # plain lists (e.g. an SSID literally named 'data') must not match.
        if not isinstance(val, dict):
            continue
        filtered_rssi_values = [x for x in val['data'] if not math.isnan(x)]
        if len(filtered_rssi_values) > ignore_samples:
            filtered_rssi_values = filtered_rssi_values[ignore_samples:]
        if filtered_rssi_values:
            val['mean'] = statistics.mean(filtered_rssi_values)
            if len(filtered_rssi_values) > 1:
                val['stdev'] = statistics.stdev(filtered_rssi_values)
            else:
                val['stdev'] = 0
        else:
            val['mean'] = RSSI_ERROR_VAL
            val['stdev'] = RSSI_ERROR_VAL

    return connected_rssi
1247
1248
@detect_wifi_decorator
def get_scan_rssi(dut, tracked_bssids, num_measurements=1):
    """Collects scan RSSI readings for a set of BSSIDs.

    Dispatch stub: detect_wifi_decorator forwards every call to
    get_scan_rssi_qcom or get_scan_rssi_brcm depending on the detected wifi
    platform, so this body is never run.

    Args:
        dut: android device object from which to get RSSI
        tracked_bssids: iterable of BSSIDs to gather RSSI data for
        num_measurements: number of scans performed
    Returns:
        scan_rssi: OrderedDict keyed by BSSID, each value holding the raw
        readings ('data') and their 'mean'/'stdev'
    """
1262
1263
@nonblocking
def get_scan_rssi_nb(dut, tracked_bssids, num_measurements=1):
    """Non-blocking variant of get_scan_rssi (run via @nonblocking).

    Arguments are identical to get_scan_rssi.
    """
    return get_scan_rssi(dut,
                         tracked_bssids,
                         num_measurements=num_measurements)
1267
1268
def get_scan_rssi_qcom(dut, tracked_bssids, num_measurements=1):
    """Gets scan RSSI for the tracked BSSIDs on Qualcomm-based DUTs.

    Each round triggers `wpa_cli scan`, waits MED_SLEEP seconds, then reads
    `wpa_cli scan_results`. The third tab-separated field of the first line
    matching each BSSID (case-insensitive) is taken as its RSSI. Missing
    BSSIDs record RSSI_ERROR_VAL.

    Args:
        dut: android device object from which to get RSSI
        tracked_bssids: iterable of BSSIDs to track
        num_measurements: number of scans performed
    Returns:
        scan_rssi: OrderedDict keyed by BSSID with 'data', 'mean' and 'stdev'
        (RSSI_ERROR_VAL when no valid reading was found)
    """
    scan_rssi = collections.OrderedDict(
        (bssid, empty_rssi_result()) for bssid in tracked_bssids)
    for _ in range(num_measurements):
        dut.adb.shell(SCAN)
        time.sleep(MED_SLEEP)
        results_output = dut.adb.shell(SCAN_RESULTS)
        for bssid in tracked_bssids:
            hit = re.search(bssid + '.*',
                            results_output,
                            flags=re.IGNORECASE)
            if hit:
                reading = int(hit.group(0).split('\t')[2])
            else:
                reading = RSSI_ERROR_VAL
            scan_rssi[bssid]['data'].append(reading)
    # Statistics over valid (non-NaN) readings only
    for stats in scan_rssi.values():
        valid_readings = [x for x in stats['data'] if not math.isnan(x)]
        if not valid_readings:
            stats['mean'] = RSSI_ERROR_VAL
            stats['stdev'] = RSSI_ERROR_VAL
            continue
        stats['mean'] = statistics.mean(valid_readings)
        stats['stdev'] = (statistics.stdev(valid_readings)
                          if len(valid_readings) > 1 else 0)
    return scan_rssi
1301
1302
def get_scan_rssi_brcm(dut, tracked_bssids, num_measurements=1):
    """Gets scan RSSI for the tracked BSSIDs on Broadcom-based DUTs.

    Each round triggers `cmd wifi start-scan`, waits MED_SLEEP seconds, then
    reads `cmd wifi list-scan-results`. The third whitespace-separated field
    of the first line matching each BSSID (case-insensitive) is taken as its
    RSSI. Missing BSSIDs record RSSI_ERROR_VAL.

    Args:
        dut: android device object from which to get RSSI
        tracked_bssids: iterable of BSSIDs to track
        num_measurements: number of scans performed
    Returns:
        scan_rssi: OrderedDict keyed by BSSID with 'data', 'mean' and 'stdev'
        (RSSI_ERROR_VAL when no valid reading was found)
    """
    scan_rssi = collections.OrderedDict()
    for bssid in tracked_bssids:
        scan_rssi[bssid] = empty_rssi_result()
    for _ in range(num_measurements):
        dut.adb.shell('cmd wifi start-scan')
        time.sleep(MED_SLEEP)
        scan_output = dut.adb.shell('cmd wifi list-scan-results')
        for bssid in tracked_bssids:
            bssid_result = re.search(bssid + '.*',
                                     scan_output,
                                     flags=re.IGNORECASE)
            if bssid_result:
                bssid_result = bssid_result.group(0).split()
                # Route debug output through logging instead of stdout.
                logging.debug('Scan result for {}: {}'.format(
                    bssid, bssid_result))
                scan_rssi[bssid]['data'].append(int(bssid_result[2]))
            else:
                scan_rssi[bssid]['data'].append(RSSI_ERROR_VAL)
    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no readings found.
    for val in scan_rssi.values():
        filtered_rssi_values = [x for x in val['data'] if not math.isnan(x)]
        if filtered_rssi_values:
            val['mean'] = statistics.mean(filtered_rssi_values)
            if len(filtered_rssi_values) > 1:
                val['stdev'] = statistics.stdev(filtered_rssi_values)
            else:
                val['stdev'] = 0
        else:
            val['mean'] = RSSI_ERROR_VAL
            val['stdev'] = RSSI_ERROR_VAL
    return scan_rssi
1336
1337
@detect_wifi_decorator
def get_sw_signature(dut):
    """Computes a signature of the DUT's wifi firmware and config files.

    Dispatch stub: detect_wifi_decorator forwards every call to
    get_sw_signature_qcom or get_sw_signature_brcm depending on the detected
    wifi platform, so this body is never run.

    Args:
        dut: AndroidDevice to inspect
    Returns:
        dict with 'config_signature' (sum of config file checksums mod 1000),
        'fw_signature' (float firmware version signature) and 'serial_hash'
        (md5 of the DUT serial mod 1000)
    """
1347
1348
def get_sw_signature_qcom(dut):
    """Computes the wifi software signature of a Qualcomm-based DUT.

    Args:
        dut: AndroidDevice to inspect
    Returns:
        dict with:
          config_signature: sum of the cksum values of
              /vendor/firmware/bdwlan* files, modulo 1000
          fw_signature: float built from the 3rd- and 2nd-to-last dotted
              components of the firmware version reported by halutil
          serial_hash: md5 of the DUT serial as an integer, modulo 1000
    """
    cksum_output = dut.adb.shell('cksum /vendor/firmware/bdwlan*')
    logging.debug('BDF Checksum output: {}'.format(cksum_output))
    checksum_total = 0
    for cksum_line in cksum_output.splitlines():
        checksum_total += int(cksum_line.split(' ')[0])
    bdf_signature = checksum_total % 1000

    version_output = dut.adb.shell('halutil -logger -get fw')
    logging.debug('Firmware version output: {}'.format(version_output))
    firmware_version = re.search(FW_REGEX, version_output).group('firmware')
    version_parts = firmware_version.split('.')[-3:-1]
    fw_signature = float('.'.join(version_parts))
    serial_digest = hashlib.md5(dut.serial.encode()).hexdigest()
    return {
        'config_signature': bdf_signature,
        'fw_signature': fw_signature,
        'serial_hash': int(serial_digest, 16) % 1000
    }
1366
1367
def get_sw_signature_brcm(dut):
    """Computes the wifi software signature of a Broadcom-based DUT.

    Args:
        dut: AndroidDevice to inspect
    Returns:
        dict with:
          config_signature: sum of the cksum values of
              /vendor/etc/wifi/bcmdhd* files, modulo 1000
          fw_signature: float '<fw>.<driver>' built from the last dotted
              component of the firmware and driver version properties
          serial_hash: md5 of the DUT serial as an integer, modulo 1000
    """
    cksum_output = dut.adb.shell('cksum /vendor/etc/wifi/bcmdhd*')
    logging.debug('BDF Checksum output: {}'.format(cksum_output))
    checksum_total = 0
    for cksum_line in cksum_output.splitlines():
        checksum_total += int(cksum_line.split(' ')[0])
    bdf_signature = checksum_total % 1000

    fw_property = dut.adb.shell('getprop vendor.wlan.firmware.version')
    logging.debug('Firmware version output: {}'.format(fw_property))
    fw_build = fw_property.split('.')[-1]
    driver_property = dut.adb.shell('getprop vendor.wlan.driver.version')
    driver_build = driver_property.split('.')[-1]
    serial_digest = hashlib.md5(dut.serial.encode()).hexdigest()
    return {
        'config_signature': bdf_signature,
        'fw_signature': float('{}.{}'.format(fw_build, driver_build)),
        'serial_hash': int(serial_digest, 16) % 1000
    }
1386
1387
@detect_wifi_decorator
def push_config(dut, config_file):
    """Pushes a wifi BDF/config file to the DUT and reboots it.

    Every existing wifi config file of the platform is overwritten with
    config_file, for simplicity. Dispatch stub: detect_wifi_decorator
    forwards the call to push_config_qcom or push_config_brcm, so this body
    is never run.

    Args:
        dut: dut to push bdf file to
        config_file: path to bdf_file to push
    """
1401
1402
def push_config_qcom(dut, config_file):
    """Overwrites every /vendor/firmware/bdwlan* file with config_file.

    The DUT is rebooted afterwards.

    Args:
        dut: dut to push the config file to
        config_file: local path of the config file to push
    """
    listing = dut.adb.shell('ls /vendor/firmware/bdwlan*')
    for target_path in listing.splitlines():
        dut.push_system_file(config_file, target_path)
    dut.reboot()
1409
1410
def push_config_brcm(dut, config_file):
    """Overwrites every /vendor/etc/*.cal file with config_file.

    The DUT is rebooted afterwards.

    Args:
        dut: dut to push the config file to
        config_file: local path of the config file to push
    """
    listing = dut.adb.shell('ls /vendor/etc/*.cal')
    for target_path in listing.splitlines():
        dut.push_system_file(config_file, target_path)
    dut.reboot()
1416
1417
def push_firmware(dut, firmware_files):
    """Pushes wifi firmware files to the DUT and reboots it.

    Args:
        dut: dut to push firmware files to
        firmware_files: iterable of local firmware file paths (e.g. a
            wlanmdsp.mbn and a Data.msc file); each is pushed into
            /vendor/firmware/
    """
    firmware_dir = '/vendor/firmware/'
    for fw_path in firmware_files:
        dut.push_system_file(fw_path, firmware_dir)
    dut.reboot()
1429
1430
@detect_wifi_decorator
def start_wifi_logging(dut):
    """Starts collecting wifi-related logs on the DUT.

    Dispatch stub: detect_wifi_decorator forwards the call to
    start_wifi_logging_qcom or start_wifi_logging_brcm, so this body is never
    run.
    """
1435
1436
def start_wifi_logging_qcom(dut):
    """Enables verbose wifi logging and starts cnss_diag on a Qualcomm DUT.

    Asserts that the verbose logging level reads back as 1, deletes existing
    files under /data/vendor/wifi/wlan_logs/, then launches
    `cnss_diag -f -s` without waiting for it.

    Args:
        dut: AndroidDevice on which to start logging
    """
    dut.droid.wifiEnableVerboseLogging(1)
    asserts.assert_equal(dut.droid.wifiGetVerboseLoggingLevel(), 1,
                         "Failed to enable WiFi verbose logging.")
    logging.info('Starting CNSS logs')
    # Clear logs left from a previous session before starting a new capture.
    dut.adb.shell("find /data/vendor/wifi/wlan_logs/ -type f -delete",
                  ignore_status=True)
    dut.adb.shell_nb('cnss_diag -f -s')
1445
1446
def start_wifi_logging_brcm(dut):
    """Broadcom counterpart of start_wifi_logging; currently a no-op."""
    return None
1449
1450
@detect_wifi_decorator
def stop_wifi_logging(dut):
    """Function to stop collecting wifi-related logs.

    Dispatch stub: detect_wifi_decorator forwards the call to
    stop_wifi_logging_qcom or stop_wifi_logging_brcm, so this body is never
    run.
    """
1455
1456
def stop_wifi_logging_qcom(dut):
    """Stops cnss_diag and pulls any collected wifi logs from the DUT.

    Files found under /data/vendor/wifi/wlan_logs/ are pulled into
    <dut.device_log_path>/CNSS_DIAG_<dut.serial>, which is created if needed.

    Args:
        dut: AndroidDevice on which cnss_diag logging was started
    """
    logging.info('Stopping CNSS logs')
    # killall exits non-zero when cnss_diag is not running; that must not
    # prevent pulling logs that already exist (start_wifi_logging_qcom uses
    # ignore_status the same way).
    dut.adb.shell('killall cnss_diag', ignore_status=True)
    logs = dut.get_file_names("/data/vendor/wifi/wlan_logs/")
    if logs:
        dut.log.info("Pulling cnss_diag logs %s", logs)
        log_path = os.path.join(dut.device_log_path,
                                "CNSS_DIAG_%s" % dut.serial)
        os.makedirs(log_path, exist_ok=True)
        dut.pull_files(logs, log_path)
1467
1468
def stop_wifi_logging_brcm(dut):
    """Broadcom counterpart of stop_wifi_logging; currently a no-op."""
    return None
1471
1472
1473def _set_ini_fields(ini_file_path, ini_field_dict):
1474    template_regex = r'^{}=[0-9,.x-]+'
1475    with open(ini_file_path, 'r') as f:
1476        ini_lines = f.read().splitlines()
1477        for idx, line in enumerate(ini_lines):
1478            for field_name, field_value in ini_field_dict.items():
1479                line_regex = re.compile(template_regex.format(field_name))
1480                if re.match(line_regex, line):
1481                    ini_lines[idx] = '{}={}'.format(field_name, field_value)
1482                    print(ini_lines[idx])
1483    with open(ini_file_path, 'w') as f:
1484        f.write('\n'.join(ini_lines) + '\n')
1485
1486
def _edit_dut_ini(dut, ini_fields):
    """Edits the DUT's WCNSS_qcom_cfg.ini and reboots the DUT.

    The ini is pulled to ~/WCNSS_qcom_cfg.ini, the requested fields are
    rewritten with _set_ini_fields, and the file is pushed back to its
    original location on the DUT.

    Args:
        dut: AndroidDevice whose ini file is edited
        ini_fields: dict mapping ini field names to their new values
    """
    remote_ini = '/vendor/firmware/wlan/qca_cld/WCNSS_qcom_cfg.ini'
    local_ini = os.path.expanduser('~/WCNSS_qcom_cfg.ini')
    dut.pull_files(remote_ini, local_ini)
    _set_ini_fields(local_ini, ini_fields)
    dut.push_system_file(local_ini, remote_ini)
    # Reboot, presumably so the driver reloads the edited ini.
    dut.reboot()
1497
1498
def set_ini_single_chain_mode(dut, chain):
    """Configures the DUT ini for single-chain operation on `chain`.

    Sets gEnable2x2=0, both 1x1 Tx/Rx chain masks to chain + 1 (presumably a
    bitmask: chain 0 -> 1, chain 1 -> 2), gDualMacFeatureDisable=1 and
    gDot11Mode=0, then reboots via _edit_dut_ini.

    Args:
        dut: AndroidDevice to configure
        chain: index of the chain to keep enabled
    """
    chain_mask = chain + 1
    _edit_dut_ini(
        dut, {
            'gEnable2x2': 0,
            'gSetTxChainmask1x1': chain_mask,
            'gSetRxChainmask1x1': chain_mask,
            'gDualMacFeatureDisable': 1,
            'gDot11Mode': 0,
        })
1508
1509
def set_ini_two_chain_mode(dut):
    """Configures the DUT ini for two-chain operation.

    Sets gEnable2x2=2, both 1x1 Tx/Rx chain masks to 1,
    gDualMacFeatureDisable=6 and gDot11Mode=0, then reboots via
    _edit_dut_ini.

    Args:
        dut: AndroidDevice to configure
    """
    _edit_dut_ini(
        dut, {
            'gEnable2x2': 2,
            'gSetTxChainmask1x1': 1,
            'gSetRxChainmask1x1': 1,
            'gDualMacFeatureDisable': 6,
            'gDot11Mode': 0,
        })
1519
1520
def set_ini_tx_mode(dut, mode):
    """Configures the DUT ini for two-chain operation with a given 802.11 mode.

    Same settings as two-chain mode, except gDot11Mode is set to the value
    mapped from `mode`. An unknown mode raises KeyError before anything is
    edited.

    Args:
        dut: AndroidDevice to configure
        mode: one of 'Auto', '11n', '11ac', '11abg', '11b', '11',
            '11g only', '11n only', '11b only', '11ac only'
    """
    dot11_mode_codes = {
        'Auto': 0,
        '11n': 4,
        '11ac': 9,
        '11abg': 1,
        '11b': 2,
        '11': 3,
        '11g only': 5,
        '11n only': 6,
        '11b only': 7,
        '11ac only': 8
    }
    _edit_dut_ini(
        dut, {
            'gEnable2x2': 2,
            'gSetTxChainmask1x1': 1,
            'gSetRxChainmask1x1': 1,
            'gDualMacFeatureDisable': 6,
            'gDot11Mode': dot11_mode_codes[mode],
        })
1543
1544
1545# Link layer stats utilities
class LinkLayerStats():
    """Factory for platform-specific link layer stats objects.

    LinkLayerStats(dut, llstats_enabled) returns a LinkLayerStatsQcom
    instance when detect_wifi_platform(dut) is 'qcom', and a
    LinkLayerStatsBrcm instance otherwise.
    """
    def __new__(cls, dut, llstats_enabled=True):
        # Forward llstats_enabled. It was previously hard-coded to True,
        # silently ignoring callers that asked to disable stats collection.
        if detect_wifi_platform(dut) == 'qcom':
            return LinkLayerStatsQcom(dut, llstats_enabled=llstats_enabled)
        return LinkLayerStatsBrcm(dut, llstats_enabled=llstats_enabled)
1552
1553
class LinkLayerStatsQcom():
    """Link layer stats collector for Qualcomm (qcom) Wi-Fi platforms.

    Each update reads the per-rate counters printed by LLSTATS_CMD on the
    DUT. ``llstats_cumulative`` holds the latest reading, and
    ``llstats_incremental`` holds its difference from the previous reading.
    Both are OrderedDicts of the form
    ``{'mcs_stats': {mcs_string: counters}, 'summary': {...}}``, where
    ``summary`` is built by ``_generate_stats_summary``.
    """

    LLSTATS_CMD = 'cat /d/wlan0/ll_stats'
    # Marker that must be present in the output for it to be parsed at all.
    PEER_REGEX = 'LL_STATS_PEER_ALL'
    # Every fragment is a raw string. The non-raw continuation fragments used
    # to contain invalid escape sequences ('\S'), which CPython warns about.
    MCS_REGEX = re.compile(
        r'preamble: (?P<mode>\S+), nss: (?P<num_streams>\S+), bw: (?P<bw>\S+), '
        r'mcs: (?P<mcs>\S+), bitrate: (?P<rate>\S+), txmpdu: (?P<txmpdu>\S+), '
        r'rxmpdu: (?P<rxmpdu>\S+), mpdu_lost: (?P<mpdu_lost>\S+), '
        r'retries: (?P<retries>\S+), retries_short: (?P<retries_short>\S+), '
        r'retries_long: (?P<retries_long>\S+)')
    MCS_ID = collections.namedtuple(
        'mcs_id', ['mode', 'num_streams', 'bandwidth', 'mcs', 'rate'])
    MODE_MAP = {'0': '11a/g', '1': '11b', '2': '11n', '3': '11ac'}
    # '3' -> 160 MHz assumes the Android Wi-Fi HAL wifi_rate.bw encoding
    # (0:20, 1:40, 2:80, 3:160 MHz). 160 MHz rates used to raise KeyError.
    BW_MAP = {'0': 20, '1': 40, '2': 80, '3': 160}

    def __init__(self, dut, llstats_enabled=True):
        """Args:
            dut: device whose ``adb.shell`` is used to run LLSTATS_CMD.
            llstats_enabled: when False, no adb read is attempted and every
                update is processed as empty output.
        """
        self.dut = dut
        self.llstats_enabled = llstats_enabled
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def update_stats(self):
        """Read a fresh ll_stats snapshot from the DUT and update both views.

        Collection is best-effort. If stats are disabled or the adb read
        fails, empty output is processed, which resets the stored stats
        because PEER_REGEX is then absent.
        """
        llstats_output = ''
        if self.llstats_enabled:
            try:
                llstats_output = self.dut.adb.shell(self.LLSTATS_CMD,
                                                    timeout=0.1)
            except Exception:
                # Best-effort: treat a failed or timed-out read as no data.
                # This used to be a bare except, which also swallowed
                # KeyboardInterrupt and SystemExit.
                llstats_output = ''
        self._update_stats(llstats_output)

    def reset_stats(self):
        """Clear both the cumulative and the incremental stats."""
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def _empty_llstats(self):
        """Return an empty stats container with mcs_stats and summary."""
        return collections.OrderedDict(mcs_stats=collections.OrderedDict(),
                                       summary=collections.OrderedDict())

    def _empty_mcs_stat(self):
        """Return zeroed counters, used for MCS entries not seen before."""
        return collections.OrderedDict(txmpdu=0,
                                       rxmpdu=0,
                                       mpdu_lost=0,
                                       retries=0,
                                       retries_short=0,
                                       retries_long=0)

    def _mcs_id_to_string(self, mcs_id):
        """Format an MCS_ID, e.g. '11ac 80MHz Nss2 MCS9 100.0Mbps'."""
        return '{} {}MHz Nss{} MCS{} {}Mbps'.format(
            mcs_id.mode, mcs_id.bandwidth, mcs_id.num_streams, mcs_id.mcs,
            mcs_id.rate)

    def _parse_mcs_stats(self, llstats_output):
        """Parse the per-rate counters from raw ll_stats output.

        Args:
            llstats_output: text printed by LLSTATS_CMD.

        Returns:
            OrderedDict mapping MCS description strings to OrderedDicts of
            integer counters. If PEER_REGEX is absent, both stored views are
            reset and an empty OrderedDict is returned.
        """
        if not re.search(self.PEER_REGEX, llstats_output):
            self.reset_stats()
            return collections.OrderedDict()
        llstats_dict = collections.OrderedDict()
        for match in self.MCS_REGEX.finditer(llstats_output):
            current_mcs = self.MCS_ID(
                self.MODE_MAP[match.group('mode')],
                # nss is reported zero-based (0 means a single stream).
                int(match.group('num_streams')) + 1,
                self.BW_MAP[match.group('bw')],
                int(match.group('mcs')),
                # The bitrate field is hexadecimal and scaled by 1/1000.
                int(match.group('rate'), 16) / 1000)
            current_stats = collections.OrderedDict(
                txmpdu=int(match.group('txmpdu')),
                rxmpdu=int(match.group('rxmpdu')),
                mpdu_lost=int(match.group('mpdu_lost')),
                retries=int(match.group('retries')),
                retries_short=int(match.group('retries_short')),
                retries_long=int(match.group('retries_long')))
            llstats_dict[self._mcs_id_to_string(current_mcs)] = current_stats
        return llstats_dict

    def _diff_mcs_stats(self, new_stats, old_stats):
        """Return new_stats - old_stats, counter by counter."""
        return collections.OrderedDict(
            (stat_key, stat_value - old_stats[stat_key])
            for stat_key, stat_value in new_stats.items())

    def _generate_stats_summary(self, llstats_dict):
        """Find the most used Tx and Rx MCS in ``llstats_dict['mcs_stats']``.

        Returns:
            OrderedDict with, for each direction, the MCS string with the
            most MPDUs (None if no MPDUs were seen), that MPDU count, and its
            share of the direction's total MPDUs (0 when the total is 0).
        """
        llstats_summary = collections.OrderedDict(common_tx_mcs=None,
                                                  common_tx_mcs_count=0,
                                                  common_tx_mcs_freq=0,
                                                  common_rx_mcs=None,
                                                  common_rx_mcs_count=0,
                                                  common_rx_mcs_freq=0)
        txmpdu_count = 0
        rxmpdu_count = 0
        for mcs_id, mcs_stats in llstats_dict['mcs_stats'].items():
            if mcs_stats['txmpdu'] > llstats_summary['common_tx_mcs_count']:
                llstats_summary['common_tx_mcs'] = mcs_id
                llstats_summary['common_tx_mcs_count'] = mcs_stats['txmpdu']
            if mcs_stats['rxmpdu'] > llstats_summary['common_rx_mcs_count']:
                llstats_summary['common_rx_mcs'] = mcs_id
                llstats_summary['common_rx_mcs_count'] = mcs_stats['rxmpdu']
            txmpdu_count += mcs_stats['txmpdu']
            rxmpdu_count += mcs_stats['rxmpdu']
        if txmpdu_count:
            llstats_summary['common_tx_mcs_freq'] = (
                llstats_summary['common_tx_mcs_count'] / txmpdu_count)
        if rxmpdu_count:
            llstats_summary['common_rx_mcs_freq'] = (
                llstats_summary['common_rx_mcs_count'] / rxmpdu_count)
        return llstats_summary

    def _update_stats(self, llstats_output):
        """Parse ``llstats_output`` and refresh the cumulative and incremental views."""
        # Parse stats. This may reset the stored stats if the output is empty.
        new_llstats = self._empty_llstats()
        new_llstats['mcs_stats'] = self._parse_mcs_stats(llstats_output)
        # Save old stats and set new cumulative stats
        old_llstats = self.llstats_cumulative.copy()
        self.llstats_cumulative = new_llstats.copy()
        # Compute the difference between new and old stats. MCS entries not
        # seen before are diffed against zeroed counters.
        self.llstats_incremental = self._empty_llstats()
        for mcs_id, new_mcs_stats in new_llstats['mcs_stats'].items():
            old_mcs_stats = old_llstats['mcs_stats'].get(
                mcs_id, self._empty_mcs_stat())
            self.llstats_incremental['mcs_stats'][
                mcs_id] = self._diff_mcs_stats(new_mcs_stats, old_mcs_stats)
        # Generate llstats summary
        self.llstats_incremental['summary'] = self._generate_stats_summary(
            self.llstats_incremental)
        self.llstats_cumulative['summary'] = self._generate_stats_summary(
            self.llstats_cumulative)
1684
1685
class LinkLayerStatsBrcm():
    """Link layer stats collector for Broadcom (brcm) Wi-Fi platforms.

    Only the current Tx/Rx rates reported by ``wl rate_info`` are collected.
    No per-MCS counters are gathered, so ``mcs_stats`` stays empty and the
    summary count/frequency fields are fixed placeholders of 1. The same
    dict is exposed as both ``llstats_incremental`` and
    ``llstats_cumulative``.
    """

    # Raw strings: the old non-raw patterns contained invalid escapes ('\d').
    # The group captures only the number, so ' Mbps' is not appended twice.
    TX_RATE_REGEX = re.compile(r'\[Tx\]: (\d+[.]*\d*) Mbps')
    RX_RATE_REGEX = re.compile(r'\[Rx\]: (\d+[.]*\d*) Mbps')

    def __init__(self, dut, llstats_enabled=True):
        """Args:
            dut: device whose ``adb.shell`` is used to run ``wl rate_info``.
            llstats_enabled: when False, no adb read is attempted and the
                rate fields stay None.
        """
        self.dut = dut
        self.llstats_enabled = llstats_enabled
        self.llstats_incremental = self._empty_llstats()
        self.llstats_cumulative = self.llstats_incremental

    def _empty_llstats(self):
        """Return an empty stats container with mcs_stats and summary."""
        return collections.OrderedDict(mcs_stats=collections.OrderedDict(),
                                       summary=collections.OrderedDict())

    def reset_stats(self):
        """Discard collected stats, matching LinkLayerStatsQcom's interface."""
        self.llstats_incremental = self._empty_llstats()
        self.llstats_cumulative = self.llstats_incremental

    @staticmethod
    def _format_rate(rate_regex, rate_info):
        """Return '<rate> Mbps' for the first match in rate_info, or None."""
        match = rate_regex.search(rate_info)
        return '{} Mbps'.format(match.group(1)) if match else None

    def update_stats(self):
        """Refresh the summary with the current Tx/Rx rates from the DUT.

        Collection is best-effort. If stats are disabled, the adb read fails,
        or no rate is found, the corresponding rate field stays None.
        """
        self.llstats_incremental = self._empty_llstats()
        self.llstats_incremental['summary'] = collections.OrderedDict(
            common_tx_mcs=None,
            common_tx_mcs_count=1,
            common_tx_mcs_freq=1,
            common_rx_mcs=None,
            common_rx_mcs_count=1,
            common_rx_mcs_freq=1)
        # Keep cumulative pointing at the fresh dict. Rebinding incremental
        # alone used to leave cumulative stuck at the empty initial dict.
        self.llstats_cumulative = self.llstats_incremental
        if not self.llstats_enabled:
            return
        summary = self.llstats_incremental['summary']
        try:
            rate_info = self.dut.adb.shell('wl rate_info', timeout=0.1)
            summary['common_tx_mcs'] = self._format_rate(
                self.TX_RATE_REGEX, rate_info)
            summary['common_rx_mcs'] = self._format_rate(
                self.RX_RATE_REGEX, rate_info)
        except Exception:
            # Best-effort collection. This used to be a bare except, which
            # also swallowed KeyboardInterrupt and SystemExit.
            pass
1719