1#!/usr/bin/env python3.4
2#
3#   Copyright 2019 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the 'License');
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an 'AS IS' BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
import collections
import functools
import itertools
import json
import logging
import math
import os
import re
import statistics
import time
from concurrent.futures import ThreadPoolExecutor

import bokeh, bokeh.plotting

from acts import utils
from acts.controllers.android_device import AndroidDevice
from acts.controllers.utils_lib import ssh
from acts.test_utils.wifi import wifi_test_utils as wutils
32
# Sleep durations, in seconds (used with time.sleep below).
SHORT_SLEEP = 1
MED_SLEEP = 6
# Generic timeout value for test operations.
TEST_TIMEOUT = 10
# Shell commands used to query wifi state on the DUT.
STATION_DUMP = 'iw wlan0 station dump'
SCAN = 'wpa_cli scan'
SCAN_RESULTS = 'wpa_cli scan_results'
SIGNAL_POLL = 'wpa_cli signal_poll'
WPA_CLI_STATUS = 'wpa_cli status'
# 3 dB as an exact power ratio: 10*log10(2).
CONST_3dB = 3.01029995664
# NaN sentinel recorded when an RSSI reading is missing or invalid.
RSSI_ERROR_VAL = float('nan')
# Parses a ping output line into its timestamp and round-trip time.
RTT_REGEX = re.compile(r'^\[(?P<timestamp>\S+)\] .*? time=(?P<rtt>\S+)')
# Parses the packet loss percentage from ping's summary line.
LOSS_REGEX = re.compile(r'(?P<loss>\S+)% packet loss')
45
46
47# Threading decorator
def nonblocking(f):
    """Creates a decorator transforming function calls to non-blocking.

    The decorated function runs in a single-use background thread and
    immediately returns a concurrent.futures.Future for its result.

    Args:
        f: the function to run asynchronously.

    Returns:
        A wrapper that submits f to a ThreadPoolExecutor and returns the
        resulting Future.
    """
    # functools.wraps preserves f's name/docstring on the wrapper so the
    # decorated function remains introspectable.
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        executor = ThreadPoolExecutor(max_workers=1)
        thread_future = executor.submit(f, *args, **kwargs)
        # Ensure resources are freed up when executor returns or raises
        executor.shutdown(wait=False)
        return thread_future

    return wrap
58
59
60# Link layer stats utilities
class LinkLayerStats():
    """Parses and accumulates wifi link layer stats from the DUT.

    Stats are read from debugfs over adb (see LLSTATS_CMD). The class keeps
    both cumulative counters and incremental counters covering only the
    interval since the previous update, keyed by MCS (modulation and coding
    scheme) configuration.
    """

    # adb shell command used to dump link layer stats from debugfs.
    LLSTATS_CMD = 'cat /d/wlan0/ll_stats'
    # Marker string indicating that per-peer stats are present in the dump.
    PEER_REGEX = 'LL_STATS_PEER_ALL'
    # Matches one per-MCS stats line in the ll_stats dump.
    MCS_REGEX = re.compile(
        r'preamble: (?P<mode>\S+), nss: (?P<num_streams>\S+), bw: (?P<bw>\S+), '
        'mcs: (?P<mcs>\S+), bitrate: (?P<rate>\S+), txmpdu: (?P<txmpdu>\S+), '
        'rxmpdu: (?P<rxmpdu>\S+), mpdu_lost: (?P<mpdu_lost>\S+), '
        'retries: (?P<retries>\S+), retries_short: (?P<retries_short>\S+), '
        'retries_long: (?P<retries_long>\S+)')
    # Unique identifier for one MCS configuration.
    MCS_ID = collections.namedtuple(
        'mcs_id', ['mode', 'num_streams', 'bandwidth', 'mcs', 'rate'])
    # Maps the dump's preamble codes to human-readable wifi modes.
    MODE_MAP = {'0': '11a/g', '1': '11b', '2': '11n', '3': '11ac'}
    # Maps the dump's bandwidth codes to MHz values.
    BW_MAP = {'0': 20, '1': 40, '2': 80}

    def __init__(self, dut):
        # dut: android device object whose link layer stats are collected.
        self.dut = dut
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def update_stats(self):
        """Fetches fresh stats from the DUT and updates internal counters."""
        llstats_output = self.dut.adb.shell(self.LLSTATS_CMD)
        self._update_stats(llstats_output)

    def reset_stats(self):
        """Clears both cumulative and incremental stats."""
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def _empty_llstats(self):
        """Returns an empty stats container with mcs_stats and summary."""
        return collections.OrderedDict(mcs_stats=collections.OrderedDict(),
                                       summary=collections.OrderedDict())

    def _empty_mcs_stat(self):
        """Returns a zeroed counter dict for one MCS configuration."""
        return collections.OrderedDict(txmpdu=0,
                                       rxmpdu=0,
                                       mpdu_lost=0,
                                       retries=0,
                                       retries_short=0,
                                       retries_long=0)

    def _mcs_id_to_string(self, mcs_id):
        """Formats an MCS_ID namedtuple as a human-readable string key."""
        mcs_string = '{} {}MHz Nss{} MCS{} {}Mbps'.format(
            mcs_id.mode, mcs_id.bandwidth, mcs_id.num_streams, mcs_id.mcs,
            mcs_id.rate)
        return mcs_string

    def _parse_mcs_stats(self, llstats_output):
        """Parses per-MCS counters out of an ll_stats dump.

        Args:
            llstats_output: raw string output of LLSTATS_CMD.
        Returns:
            dict mapping MCS description strings to their counter dicts;
            empty (and stats are reset) when no per-peer stats were found.
        """
        llstats_dict = {}
        # Look for per-peer stats
        match = re.search(self.PEER_REGEX, llstats_output)
        if not match:
            self.reset_stats()
            return collections.OrderedDict()
        # Find and process all matches for per stream stats
        match_iter = re.finditer(self.MCS_REGEX, llstats_output)
        for match in match_iter:
            # nss in the dump is zero-based; rate is hex and divided by 1000
            # (presumably Kbps -> Mbps — confirm against driver docs).
            current_mcs = self.MCS_ID(self.MODE_MAP[match.group('mode')],
                                      int(match.group('num_streams')) + 1,
                                      self.BW_MAP[match.group('bw')],
                                      int(match.group('mcs')),
                                      int(match.group('rate'), 16) / 1000)
            current_stats = collections.OrderedDict(
                txmpdu=int(match.group('txmpdu')),
                rxmpdu=int(match.group('rxmpdu')),
                mpdu_lost=int(match.group('mpdu_lost')),
                retries=int(match.group('retries')),
                retries_short=int(match.group('retries_short')),
                retries_long=int(match.group('retries_long')))
            llstats_dict[self._mcs_id_to_string(current_mcs)] = current_stats
        return llstats_dict

    def _diff_mcs_stats(self, new_stats, old_stats):
        """Returns per-key differences new_stats - old_stats."""
        stats_diff = collections.OrderedDict()
        for stat_key in new_stats.keys():
            stats_diff[stat_key] = new_stats[stat_key] - old_stats[stat_key]
        return stats_diff

    def _generate_stats_summary(self, llstats_dict):
        """Finds the most common tx/rx MCS and its relative frequency.

        Args:
            llstats_dict: stats container holding an 'mcs_stats' dict.
        Returns:
            OrderedDict with the dominant tx/rx MCS, its mpdu count, and
            its share of total tx/rx mpdus.
        """
        llstats_summary = collections.OrderedDict(common_tx_mcs=None,
                                                  common_tx_mcs_count=0,
                                                  common_tx_mcs_freq=0,
                                                  common_rx_mcs=None,
                                                  common_rx_mcs_count=0,
                                                  common_rx_mcs_freq=0)
        txmpdu_count = 0
        rxmpdu_count = 0
        for mcs_id, mcs_stats in llstats_dict['mcs_stats'].items():
            if mcs_stats['txmpdu'] > llstats_summary['common_tx_mcs_count']:
                llstats_summary['common_tx_mcs'] = mcs_id
                llstats_summary['common_tx_mcs_count'] = mcs_stats['txmpdu']
            if mcs_stats['rxmpdu'] > llstats_summary['common_rx_mcs_count']:
                llstats_summary['common_rx_mcs'] = mcs_id
                llstats_summary['common_rx_mcs_count'] = mcs_stats['rxmpdu']
            txmpdu_count += mcs_stats['txmpdu']
            rxmpdu_count += mcs_stats['rxmpdu']
        # Guard against division by zero when no mpdus were counted.
        if txmpdu_count:
            llstats_summary['common_tx_mcs_freq'] = (
                llstats_summary['common_tx_mcs_count'] / txmpdu_count)
        if rxmpdu_count:
            llstats_summary['common_rx_mcs_freq'] = (
                llstats_summary['common_rx_mcs_count'] / rxmpdu_count)
        return llstats_summary

    def _update_stats(self, llstats_output):
        """Updates cumulative and incremental stats from a raw dump.

        Incremental stats cover only the delta since the previous call;
        both containers get their summaries regenerated.
        """
        # Parse stats
        new_llstats = self._empty_llstats()
        new_llstats['mcs_stats'] = self._parse_mcs_stats(llstats_output)
        # Save old stats and set new cumulative stats
        old_llstats = self.llstats_cumulative.copy()
        self.llstats_cumulative = new_llstats.copy()
        # Compute difference between new and old stats
        self.llstats_incremental = self._empty_llstats()
        for mcs_id, new_mcs_stats in new_llstats['mcs_stats'].items():
            # MCS ids unseen in the old snapshot diff against zeroed counters.
            old_mcs_stats = old_llstats['mcs_stats'].get(
                mcs_id, self._empty_mcs_stat())
            self.llstats_incremental['mcs_stats'][
                mcs_id] = self._diff_mcs_stats(new_mcs_stats, old_mcs_stats)
        # Generate llstats summary
        self.llstats_incremental['summary'] = self._generate_stats_summary(
            self.llstats_incremental)
        self.llstats_cumulative['summary'] = self._generate_stats_summary(
            self.llstats_cumulative)
183
184
185# JSON serializer
def serialize_dict(input_dict):
    """Function to serialize dicts to enable JSON output"""
    return collections.OrderedDict(
        (_serialize_value(key), _serialize_value(value))
        for key, value in input_dict.items())
192
193
def _serialize_value(value):
    """Function to recursively serialize dict entries to enable JSON output"""
    if isinstance(value, tuple):
        # Tuples are not JSON-native; render them as strings.
        return str(value)
    elif isinstance(value, dict):
        return serialize_dict(value)
    elif isinstance(value, list):
        return [_serialize_value(element) for element in value]
    return value
204
205
206# Plotting Utilities
class BokehFigure():
    """Class enabling simplified Bokeh plotting.

    Series are accumulated with add_line/add_scatter and rendered into a
    single figure by generate_figure. Figures can be saved to HTML along
    with a JSON dump of the underlying plot data.
    """

    # Color cycle used when no explicit line color is requested.
    COLORS = [
        'black',
        'blue',
        'blueviolet',
        'brown',
        'burlywood',
        'cadetblue',
        'cornflowerblue',
        'crimson',
        'cyan',
        'darkblue',
        'darkgreen',
        'darkmagenta',
        'darkorange',
        'darkred',
        'deepskyblue',
        'goldenrod',
        'green',
        'grey',
        'indigo',
        'navy',
        'olive',
        'orange',
        'red',
        'salmon',
        'teal',
        'yellow',
    ]
    # Marker cycle used when no explicit scatter marker is requested.
    MARKERS = [
        'asterisk', 'circle', 'circle_cross', 'circle_x', 'cross', 'diamond',
        'diamond_cross', 'hex', 'inverted_triangle', 'square', 'square_x',
        'square_cross', 'triangle', 'x'
    ]

    def __init__(self,
                 title=None,
                 x_label=None,
                 primary_y_label=None,
                 secondary_y_label=None,
                 height=700,
                 width=1100,
                 title_size='15pt',
                 axis_label_size='12pt'):
        self.figure_data = []
        self.fig_property = {
            'title': title,
            'x_label': x_label,
            'primary_y_label': primary_y_label,
            'secondary_y_label': secondary_y_label,
            'num_lines': 0,
            'height': height,
            'width': width,
            'title_size': title_size,
            'axis_label_size': axis_label_size
        }
        self.TOOLS = (
            'box_zoom,box_select,pan,crosshair,redo,undo,reset,hover,save')
        self.TOOLTIPS = [
            ('index', '$index'),
            ('(x,y)', '($x, $y)'),
            ('info', '@hover_text'),
        ]

    def init_plot(self):
        """Initializes the underlying bokeh figure with tools and tooltips."""
        self.plot = bokeh.plotting.figure(
            sizing_mode='scale_width',
            plot_width=self.fig_property['width'],
            plot_height=self.fig_property['height'],
            title=self.fig_property['title'],
            tools=self.TOOLS,
            output_backend='webgl')
        self.plot.hover.tooltips = self.TOOLTIPS
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='width'))
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='height'))

    def _filter_line(self, x_data, y_data, hover_text=None):
        """Function to remove NaN points from bokeh plots.

        Args:
            x_data: list of x-axis values
            y_data: list of y-axis values; NaN entries are dropped
            hover_text: list of hover labels kept in sync with the data
        Returns:
            tuple of filtered (x_data, y_data, hover_text) lists
        """
        x_data_filtered = []
        y_data_filtered = []
        hover_text_filtered = []
        for x, y, hover in itertools.zip_longest(x_data, y_data, hover_text):
            if not math.isnan(y):
                x_data_filtered.append(x)
                y_data_filtered.append(y)
                hover_text_filtered.append(hover)
        return x_data_filtered, y_data_filtered, hover_text_filtered

    def add_line(self,
                 x_data,
                 y_data,
                 legend,
                 hover_text=None,
                 color=None,
                 width=3,
                 style='solid',
                 marker=None,
                 marker_size=10,
                 shaded_region=None,
                 y_axis='default'):
        """Function to add line to existing BokehFigure.

        Args:
            x_data: list containing x-axis values for line
            y_data: list containing y_axis values for line
            legend: string containing line title
            hover_text: text to display when hovering over lines
            color: string describing line color
            width: integer line width
            style: string describing line style, e.g, solid or dashed
            marker: string specifying line marker, e.g., cross
            marker_size: integer size of the marker glyphs
            shaded_region: data describing shaded region to plot
            y_axis: identifier for y-axis to plot line against
        Raises:
            ValueError: if y_axis is not 'default' or 'secondary'
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            # Cycle through the palette so each new line gets a fresh color.
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if style == 'dashed':
            style = [5, 5]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        x_data_filter, y_data_filter, hover_text_filter = self._filter_line(
            x_data, y_data, hover_text)
        self.figure_data.append({
            'x_data': x_data_filter,
            'y_data': y_data_filter,
            'legend': legend,
            'hover_text': hover_text_filter,
            'color': color,
            'width': width,
            'style': style,
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': shaded_region,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def add_scatter(self,
                    x_data,
                    y_data,
                    legend,
                    hover_text=None,
                    color=None,
                    marker=None,
                    marker_size=10,
                    y_axis='default'):
        """Function to add scatter series to existing BokehFigure.

        Args:
            x_data: list containing x-axis values for scatter points
            y_data: list containing y_axis values for scatter points
            legend: string containing series title
            hover_text: text to display when hovering over points
            color: string describing marker color
            marker: string specifying marker, e.g., cross
            marker_size: integer size of the marker glyphs
            y_axis: identifier for y-axis to plot series against
        Raises:
            ValueError: if y_axis is not 'default' or 'secondary'
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if marker is None:
            marker = self.MARKERS[self.fig_property['num_lines'] %
                                  len(self.MARKERS)]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        # width=0 marks this series as scatter-only in generate_figure.
        self.figure_data.append({
            'x_data': x_data,
            'y_data': y_data,
            'legend': legend,
            'hover_text': hover_text,
            'color': color,
            'width': 0,
            'style': 'solid',
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': None,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def generate_figure(self, output_file=None):
        """Function to generate and save BokehFigure.

        Args:
            output_file: string specifying output file path
        Returns:
            the populated bokeh plot object
        """
        self.init_plot()
        two_axes = False
        for line in self.figure_data:
            source = bokeh.models.ColumnDataSource(
                data=dict(x=line['x_data'],
                          y=line['y_data'],
                          hover_text=line['hover_text']))
            if line['width'] > 0:
                self.plot.line(x='x',
                               y='y',
                               legend_label=line['legend'],
                               line_width=line['width'],
                               color=line['color'],
                               line_dash=line['style'],
                               name=line['y_axis'],
                               y_range_name=line['y_axis'],
                               source=source)
            if line['shaded_region']:
                # Copy the region data so repeated generate_figure calls do
                # not mutate (and double) the caller's shaded_region lists.
                band_x = list(line['shaded_region']['x_vector'])
                band_x.extend(line['shaded_region']['x_vector'][::-1])
                band_y = list(line['shaded_region']['lower_limit'])
                band_y.extend(line['shaded_region']['upper_limit'][::-1])
                self.plot.patch(band_x,
                                band_y,
                                color='#7570B3',
                                line_alpha=0.1,
                                fill_alpha=0.1)
            if line['marker'] in self.MARKERS:
                marker_func = getattr(self.plot, line['marker'])
                marker_func(x='x',
                            y='y',
                            size=line['marker_size'],
                            legend_label=line['legend'],
                            line_color=line['color'],
                            fill_color=line['color'],
                            name=line['y_axis'],
                            y_range_name=line['y_axis'],
                            source=source)
            if line['y_axis'] == 'secondary':
                two_axes = True

        #x-axis formatting
        self.plot.xaxis.axis_label = self.fig_property['x_label']
        self.plot.x_range.range_padding = 0
        self.plot.xaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        #y-axis formatting
        self.plot.yaxis[0].axis_label = self.fig_property['primary_y_label']
        self.plot.yaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        self.plot.y_range = bokeh.models.DataRange1d(names=['default'])
        if two_axes and 'secondary' not in self.plot.extra_y_ranges:
            self.plot.extra_y_ranges = {
                'secondary': bokeh.models.DataRange1d(names=['secondary'])
            }
            self.plot.add_layout(
                bokeh.models.LinearAxis(
                    y_range_name='secondary',
                    axis_label=self.fig_property['secondary_y_label'],
                    axis_label_text_font_size=self.
                    fig_property['axis_label_size']), 'right')
        # plot formatting
        self.plot.legend.location = 'top_right'
        self.plot.legend.click_policy = 'hide'
        self.plot.title.text_font_size = self.fig_property['title_size']

        if output_file is not None:
            self.save_figure(output_file)
        return self.plot

    def _save_figure_json(self, output_file):
        """Function to save a json format of a figure"""
        figure_dict = collections.OrderedDict(fig_property=self.fig_property,
                                              figure_data=self.figure_data,
                                              tools=self.TOOLS,
                                              tooltips=self.TOOLTIPS)
        output_file = output_file.replace('.html', '_plot_data.json')
        with open(output_file, 'w') as outfile:
            json.dump(figure_dict, outfile, indent=4)

    def save_figure(self, output_file):
        """Function to save BokehFigure.

        Args:
            output_file: string specifying output file path
        """
        bokeh.plotting.output_file(output_file)
        bokeh.plotting.save(self.plot)
        # Also dump the raw plot data alongside the rendered HTML.
        self._save_figure_json(output_file)

    @staticmethod
    def save_figures(figure_array, output_file_path):
        """Function to save list of BokehFigures in one file.

        Args:
            figure_array: list of BokehFigure object to be plotted
            output_file_path: string specifying output file path
        """
        for idx, figure in enumerate(figure_array):
            figure.generate_figure()
            json_file_path = output_file_path.replace(
                '.html', '{}-plot_data.json'.format(idx))
            figure._save_figure_json(json_file_path)
        plot_array = [figure.plot for figure in figure_array]
        all_plots = bokeh.layouts.column(children=plot_array,
                                         sizing_mode='scale_width')
        bokeh.plotting.output_file(output_file_path)
        bokeh.plotting.save(all_plots)
510
511
512# Ping utilities
class PingResult(object):
    """Parsed results of running a ping command.

    Attributes:
        connected: True if a connection was made. False otherwise.
        packet_loss_percentage: total percentage of packets lost.
        transmission_times: list of PingTransmissionTimes gathered for
            transmitted packets.
        rtts: list-like view of the round-trip-times of all transmitted
            packets.
        timestamps: list-like view of the start timestamp of each packet
            transmission.
        ping_interarrivals: list-like view of the time between the starts
            of subsequent transmissions.
    """
    def __init__(self, ping_output):
        self.packet_loss_percentage = 100
        self.transmission_times = []

        # Live views over transmission_times; they reflect later appends.
        self.rtts = _ListWrap(self.transmission_times, lambda entry: entry.rtt)
        self.timestamps = _ListWrap(self.transmission_times,
                                    lambda entry: entry.timestamp)
        self.ping_interarrivals = _PingInterarrivals(self.transmission_times)

        self.start_time = 0
        for output_line in ping_output:
            if 'loss' in output_line:
                loss_match = re.search(LOSS_REGEX, output_line)
                self.packet_loss_percentage = float(loss_match.group('loss'))
            if 'time=' in output_line:
                rtt_match = re.search(RTT_REGEX, output_line)
                packet_time = float(rtt_match.group('timestamp'))
                if self.start_time == 0:
                    # Anchor all timestamps to the first transmission.
                    self.start_time = packet_time
                self.transmission_times.append(
                    PingTransmissionTimes(packet_time - self.start_time,
                                          float(rtt_match.group('rtt'))))
        self.connected = (len(ping_output) > 1
                          and self.packet_loss_percentage < 100)

    def __getitem__(self, item):
        supported_keys = {
            'rtt': self.rtts,
            'connected': self.connected,
            'packet_loss_percentage': self.packet_loss_percentage,
        }
        if item in supported_keys:
            return supported_keys[item]
        raise ValueError('Invalid key. Please use an attribute instead.')

    def as_dict(self):
        """Returns a plain-dict representation for serialization."""
        return {
            'connected': 1 if self.connected else 0,
            'rtt': list(self.rtts),
            'time_stamp': list(self.timestamps),
            'ping_interarrivals': list(self.ping_interarrivals),
            'packet_loss_percentage': self.packet_loss_percentage
        }
570
571
class PingTransmissionTimes(object):
    """Holds the timing data of one packet sent via the ping command.

    Attributes:
        rtt: round trip time of the packet.
        timestamp: time at which the packet started its trip.
    """
    def __init__(self, timestamp, rtt):
        self.timestamp = timestamp
        self.rtt = rtt
582
583
class _ListWrap(object):
    """Presents a per-element transformed view of a list as a list."""
    def __init__(self, wrapped_list, func):
        self._items = wrapped_list
        self._transform = func

    def __getitem__(self, key):
        return self._transform(self._items[key])

    def __iter__(self):
        return (self._transform(item) for item in self._items)

    def __len__(self):
        return len(self._items)
599
600
class _PingInterarrivals(object):
    """Presents ping interarrival times as a list-like object."""
    def __init__(self, ping_entries):
        self._entries = ping_entries

    def __getitem__(self, key):
        # Gap between consecutive transmissions, in the entries' time units.
        next_entry = self._entries[key + 1]
        return next_entry.timestamp - self._entries[key].timestamp

    def __iter__(self):
        for idx in range(len(self._entries) - 1):
            yield self[idx]

    def __len__(self):
        # n entries produce n-1 gaps; never negative for an empty list.
        return max(0, len(self._entries) - 1)
616
617
def get_ping_stats(src_device, dest_address, ping_duration, ping_interval,
                   ping_size):
    """Run ping to or from the DUT.

    The function either pings the DUT or pings a remote ip from the DUT.

    Args:
        src_device: object representing device to ping from
        dest_address: ip address to ping
        ping_duration: timeout to set on the the ping process (in seconds)
        ping_interval: time between pings (in seconds)
        ping_size: size of ping packet payload
    Returns:
        ping_result: dict containing ping results and other meta data
    Raises:
        TypeError: if src_device is neither an AndroidDevice nor an
            SshConnection.
    """
    ping_count = int(ping_duration / ping_interval)
    # Deadline slightly exceeds the nominal ping run time.
    ping_deadline = int(ping_count * ping_interval) + 1
    base_cmd = 'ping -c {} -w {} -i {} -s {} -D'.format(
        ping_count, ping_deadline, ping_interval, ping_size)
    if isinstance(src_device, AndroidDevice):
        ping_output = src_device.adb.shell(
            '{} {}'.format(base_cmd, dest_address),
            timeout=ping_deadline + SHORT_SLEEP,
            ignore_status=True)
    elif isinstance(src_device, ssh.connection.SshConnection):
        ping_output = src_device.run(
            'sudo {} {}'.format(base_cmd, dest_address),
            timeout=ping_deadline + SHORT_SLEEP,
            ignore_status=True).stdout
    else:
        raise TypeError('Unable to ping using src_device of type %s.' %
                        type(src_device))
    return PingResult(ping_output.splitlines())
656
657
@nonblocking
def get_ping_stats_nb(src_device, dest_address, ping_duration, ping_interval,
                      ping_size):
    """Non-blocking variant of get_ping_stats.

    Returns a concurrent.futures.Future whose result() yields the
    PingResult produced by get_ping_stats.
    """
    return get_ping_stats(src_device, dest_address, ping_duration,
                          ping_interval, ping_size)
663
664
@nonblocking
def start_iperf_client_nb(iperf_client, iperf_server_address, iperf_args, tag,
                          timeout):
    """Non-blocking start of an iperf client.

    Returns a concurrent.futures.Future whose result() yields the value
    returned by iperf_client.start.
    """
    return iperf_client.start(iperf_server_address, iperf_args, tag, timeout)
669
670
671# Rssi Utilities
def empty_rssi_result():
    """Returns a fresh, empty RSSI result with data/mean/stdev fields."""
    result = collections.OrderedDict()
    result['data'] = []
    result['mean'] = None
    result['stdev'] = None
    return result
675
676
def get_connected_rssi(dut,
                       num_measurements=1,
                       polling_frequency=SHORT_SLEEP,
                       first_measurement_delay=0,
                       disconnect_warning=True):
    """Gets all RSSI values reported for the connected access point/BSSID.

    Polls wpa_cli signal_poll and iw station dump over adb, once per
    measurement, and aggregates the per-source readings with their
    statistics.

    Args:
        dut: android device object from which to get RSSI
        num_measurements: number of scans done, and RSSIs collected
        polling_frequency: time to wait between RSSI measurements
        first_measurement_delay: seconds to sleep before the first
            measurement
        disconnect_warning: boolean controlling disconnection logging messages
    Returns:
        connected_rssi: dict containing the measurements results for
        all reported RSSI values (signal_poll, per chain, etc.) and their
        statistics
    """
    # yapf: disable
    connected_rssi = collections.OrderedDict(
        [('time_stamp', []),
         ('bssid', []), ('frequency', []),
         ('signal_poll_rssi', empty_rssi_result()),
         ('signal_poll_avg_rssi', empty_rssi_result()),
         ('chain_0_rssi', empty_rssi_result()),
         ('chain_1_rssi', empty_rssi_result())])
    # yapf: enable
    previous_bssid = 'disconnected'
    t0 = time.time()
    time.sleep(first_measurement_delay)
    for idx in range(num_measurements):
        measurement_start_time = time.time()
        connected_rssi['time_stamp'].append(measurement_start_time - t0)
        # Get signal poll RSSI
        status_output = dut.adb.shell(WPA_CLI_STATUS)
        match = re.search('bssid=.*', status_output)
        if match:
            current_bssid = match.group(0).split('=')[1]
            connected_rssi['bssid'].append(current_bssid)
        else:
            # No bssid in status output: treat as disconnected and warn once
            # per transition from connected to disconnected.
            current_bssid = 'disconnected'
            connected_rssi['bssid'].append(current_bssid)
            if disconnect_warning and previous_bssid != 'disconnected':
                logging.warning('WIFI DISCONNECT DETECTED!')
        previous_bssid = current_bssid
        signal_poll_output = dut.adb.shell(SIGNAL_POLL)
        match = re.search('FREQUENCY=.*', signal_poll_output)
        if match:
            frequency = int(match.group(0).split('=')[1])
            connected_rssi['frequency'].append(frequency)
        else:
            connected_rssi['frequency'].append(RSSI_ERROR_VAL)
        match = re.search('RSSI=.*', signal_poll_output)
        if match:
            temp_rssi = int(match.group(0).split('=')[1])
            # -9999 and 0 are treated as invalid readings.
            if temp_rssi == -9999 or temp_rssi == 0:
                connected_rssi['signal_poll_rssi']['data'].append(
                    RSSI_ERROR_VAL)
            else:
                connected_rssi['signal_poll_rssi']['data'].append(temp_rssi)
        else:
            connected_rssi['signal_poll_rssi']['data'].append(RSSI_ERROR_VAL)
        match = re.search('AVG_RSSI=.*', signal_poll_output)
        if match:
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                int(match.group(0).split('=')[1]))
        else:
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                RSSI_ERROR_VAL)
        # Get per chain RSSI
        per_chain_rssi = dut.adb.shell(STATION_DUMP)
        match = re.search('.*signal avg:.*', per_chain_rssi)
        if match:
            # Per-chain values appear bracketed, e.g. "[-50, -52]".
            per_chain_rssi = per_chain_rssi[per_chain_rssi.find('[') +
                                            1:per_chain_rssi.find(']')]
            per_chain_rssi = per_chain_rssi.split(', ')
            connected_rssi['chain_0_rssi']['data'].append(
                int(per_chain_rssi[0]))
            connected_rssi['chain_1_rssi']['data'].append(
                int(per_chain_rssi[1]))
        else:
            connected_rssi['chain_0_rssi']['data'].append(RSSI_ERROR_VAL)
            connected_rssi['chain_1_rssi']['data'].append(RSSI_ERROR_VAL)
        measurement_elapsed_time = time.time() - measurement_start_time
        # Sleep only for the remainder of the polling interval.
        time.sleep(max(0, polling_frequency - measurement_elapsed_time))

    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no valid connected readings found.
    for key, val in connected_rssi.copy().items():
        if 'data' not in val:
            continue
        filtered_rssi_values = [x for x in val['data'] if not math.isnan(x)]
        if filtered_rssi_values:
            connected_rssi[key]['mean'] = statistics.mean(filtered_rssi_values)
            if len(filtered_rssi_values) > 1:
                connected_rssi[key]['stdev'] = statistics.stdev(
                    filtered_rssi_values)
            else:
                # stdev is undefined for a single sample; report 0.
                connected_rssi[key]['stdev'] = 0
        else:
            connected_rssi[key]['mean'] = RSSI_ERROR_VAL
            connected_rssi[key]['stdev'] = RSSI_ERROR_VAL
    return connected_rssi
779
780
@nonblocking
def get_connected_rssi_nb(dut,
                          num_measurements=1,
                          polling_frequency=SHORT_SLEEP,
                          first_measurement_delay=0,
                          disconnect_warning=True):
    """Non-blocking wrapper around get_connected_rssi.

    Args:
        dut: android device object from which to get RSSI
        num_measurements: number of scans done, and RSSIs collected
        polling_frequency: time to wait between RSSI measurements
        first_measurement_delay: delay before the first measurement
        disconnect_warning: boolean controlling disconnection logging/warnings
    Returns:
        future resolving to the get_connected_rssi result dict
    """
    # Forward disconnect_warning as a keyword; it was previously dropped,
    # so callers passing disconnect_warning=False had no effect.
    return get_connected_rssi(dut, num_measurements, polling_frequency,
                              first_measurement_delay,
                              disconnect_warning=disconnect_warning)
789
790
def get_scan_rssi(dut, tracked_bssids, num_measurements=1):
    """Gets scan RSSI for specified BSSIDs.

    Args:
        dut: android device object from which to get RSSI
        tracked_bssids: array of BSSIDs to gather RSSI data for
        num_measurements: number of scans done, and RSSIs collected
    Returns:
        scan_rssi: dict containing the measurement results as well as the
        statistics of the scan RSSI for all BSSIDs in tracked_bssids
    """
    scan_rssi = collections.OrderedDict(
        (bssid, empty_rssi_result()) for bssid in tracked_bssids)
    for _ in range(num_measurements):
        # Trigger a scan, give it time to complete, then read the results.
        dut.adb.shell(SCAN)
        time.sleep(MED_SLEEP)
        scan_output = dut.adb.shell(SCAN_RESULTS)
        for bssid in tracked_bssids:
            bssid_match = re.search(bssid + '.*',
                                    scan_output,
                                    flags=re.IGNORECASE)
            if bssid_match:
                # Scan result lines are tab separated; RSSI is the third field.
                rssi_value = int(bssid_match.group(0).split('\t')[2])
                scan_rssi[bssid]['data'].append(rssi_value)
            else:
                scan_rssi[bssid]['data'].append(RSSI_ERROR_VAL)
    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no readings found.
    for result in scan_rssi.values():
        valid_readings = [x for x in result['data'] if not math.isnan(x)]
        if not valid_readings:
            result['mean'] = RSSI_ERROR_VAL
            result['stdev'] = RSSI_ERROR_VAL
            continue
        result['mean'] = statistics.mean(valid_readings)
        # stdev requires at least two data points.
        result['stdev'] = (statistics.stdev(valid_readings)
                           if len(valid_readings) > 1 else 0)
    return scan_rssi
833
834
@nonblocking
def get_scan_rssi_nb(dut, tracked_bssids, num_measurements=1):
    """Non-blocking variant of get_scan_rssi; returns a future."""
    return get_scan_rssi(dut,
                         tracked_bssids,
                         num_measurements=num_measurements)
838
839
840# Attenuator Utilities
def atten_by_label(atten_list, path_label, atten_level):
    """Attenuate signals according to their path label.

    Args:
        atten_list: list of attenuators to iterate over
        path_label: path label on which to set desired attenuation
        atten_level: attenuation desired on path
    """
    # Only attenuators whose path contains the label are touched.
    matching_attens = (
        attenuator for attenuator in atten_list
        if path_label in attenuator.path)
    for attenuator in matching_attens:
        attenuator.set_atten(atten_level)
852
853
def get_current_atten_dut_chain_map(attenuators, dut, ping_server):
    """Function to detect mapping between attenuator ports and DUT chains.

    This function detects the mapping between attenuator ports and DUT chains
    in cases where DUT chains are connected to only one attenuator port. The
    function assumes the DUT is already connected to a wifi network. The
    function starts by measuring per chain RSSI at 0 attenuation, then
    attenuates one port at a time looking for the chain that reports a lower
    RSSI.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server; traffic to the DUT
            keeps the connection alive and the RSSI reports up to date
    Returns:
        chain_map: list of dut chains, one entry per attenuator port. Entries
            are 'DUT-Chain-0', 'DUT-Chain-1', or None when the attenuated
            port could not be matched to a chain.
    """
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    # Start ping traffic in the background so RSSI readings stay fresh
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    ping_future = get_ping_stats_nb(ping_server, dut_ip, 11, 0.02, 64)
    # Measure starting RSSI (4 samples, 0.25s apart, after a 1s delay)
    base_rssi = get_connected_rssi(dut, 4, 0.25, 1)
    chain0_base_rssi = base_rssi['chain_0_rssi']['mean']
    chain1_base_rssi = base_rssi['chain_1_rssi']['mean']
    if chain0_base_rssi < -70 or chain1_base_rssi < -70:
        logging.warning('RSSI might be too low to get reliable chain map.')
    # Compile chain map by attenuating one path at a time and seeing which
    # chain's RSSI degrades
    chain_map = []
    for test_atten in attenuators:
        # Set one attenuator to 30 dB down
        test_atten.set_atten(30, strict=False)
        # Get new RSSI
        test_rssi = get_connected_rssi(dut, 4, 0.25, 1)
        # Assign attenuator to path that has lower RSSI. A chain is matched
        # only if its baseline was credible (> -70 dBm) and its RSSI dropped
        # by more than 10 dB under attenuation.
        if chain0_base_rssi > -70 and chain0_base_rssi - test_rssi[
                'chain_0_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-0')
        elif chain1_base_rssi > -70 and chain1_base_rssi - test_rssi[
                'chain_1_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-1')
        else:
            chain_map.append(None)
        # Reset attenuator to 0
        test_atten.set_atten(0, strict=False)
    # Wait for the background ping traffic to complete before returning
    ping_future.result()
    logging.debug('Chain Map: {}'.format(chain_map))
    return chain_map
906
907
def get_full_rf_connection_map(attenuators, dut, ping_server, networks):
    """Function to detect per-network connections between attenuator and DUT.

    This function detects the mapping between attenuator ports and DUT chains
    on all networks in its arguments. The function connects the DUT to each
    network then calls get_current_atten_dut_chain_map to get the connection
    map on the current network. The function outputs the results in two formats
    to enable easy access when users are interested in indexing by network or
    attenuator port.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
        networks: dict of network IDs and configs
    Returns:
        rf_map_by_network: dict of RF connections indexed by network.
        rf_map_by_atten: list of RF connections indexed by attenuator
    """
    # Zero out every attenuation path before probing.
    for attenuator in attenuators:
        attenuator.set_atten(0, strict=False)

    rf_map_by_network = collections.OrderedDict()
    rf_map_by_atten = [[] for _ in attenuators]
    for net_id, net_config in networks.items():
        # Connect to the network under test; connection failures are tolerated
        # so the remaining networks can still be mapped.
        wutils.reset_wifi(dut)
        wutils.wifi_connect(dut,
                            net_config,
                            num_of_tries=1,
                            assert_on_fail=False,
                            check_connectivity=False)
        current_map = get_current_atten_dut_chain_map(attenuators, dut,
                                                      ping_server)
        rf_map_by_network[net_id] = current_map
        # Invert the per-network map into the per-attenuator view.
        for port_idx, chain in enumerate(current_map):
            if chain:
                rf_map_by_atten[port_idx].append({
                    "network": net_id,
                    "dut_chain": chain
                })
    logging.debug("RF Map (by Network): {}".format(rf_map_by_network))
    logging.debug("RF Map (by Atten): {}".format(rf_map_by_atten))

    return rf_map_by_network, rf_map_by_atten
951
952
953# Miscellaneous Wifi Utilities
def validate_network(dut, ssid):
    """Check that DUT has a valid internet connection through expected SSID.

    Args:
        dut: android device of interest
        ssid: expected ssid
    Returns:
        True if the DUT passes the connectivity check while connected to the
        expected SSID, False otherwise.
    """
    current_network = dut.droid.wifiGetConnectionInfo()
    try:
        connected = wutils.validate_connection(dut) is not None
    except Exception:
        # Narrowed from a bare except: treat any validation failure as not
        # connected, but do not swallow KeyboardInterrupt/SystemExit.
        connected = False
    return connected and current_network['SSID'] == ssid
970
971
def get_server_address(ssh_connection, dut_ip, subnet_mask):
    """Get server address on a specific subnet.

    This function retrieves the LAN IP of a remote machine used in testing,
    i.e., it returns the server's IP belonging to the same LAN as the DUT.

    Args:
        ssh_connection: object representing server for which we want an ip
        dut_ip: string in ip address format, i.e., xxx.xxx.xxx.xxx, specifying
        the DUT LAN IP we wish to connect to
        subnet_mask: string representing subnet mask
    Returns:
        the server IP on the DUT's subnet as a string, or None (after logging
        an error) if no server interface belongs to that subnet.
    """
    subnet_mask = subnet_mask.split('.')
    dut_subnet = [
        int(dut) & int(subnet)
        for dut, subnet in zip(dut_ip.split('.'), subnet_mask)
    ]
    ifconfig_out = ssh_connection.run('ifconfig').stdout
    # Raw string with escaped dots so '.' only matches literal separators
    # (previously unescaped dots could match any character).
    ip_list = re.findall(r'inet (?:addr:)?(\d+\.\d+\.\d+\.\d+)', ifconfig_out)
    for current_ip in ip_list:
        current_subnet = [
            int(ip) & int(subnet)
            for ip, subnet in zip(current_ip.split('.'), subnet_mask)
        ]
        if current_subnet == dut_subnet:
            return current_ip
    logging.error('No IP address found in requested subnet')
999
1000
def get_iperf_arg_string(duration,
                         reverse_direction,
                         interval=1,
                         traffic_type='TCP',
                         tcp_window=None,
                         tcp_processes=1,
                         udp_throughput='1000M'):
    """Function to format iperf client arguments.

    This function takes in iperf client parameters and returns a properly
    formatted iperf arg string to be used in throughput tests.

    Args:
        duration: iperf duration in seconds
        reverse_direction: boolean controlling the -R flag for iperf clients
        interval: iperf print interval
        traffic_type: string specifying TCP or UDP traffic
        tcp_window: string specifying TCP window, e.g., 2M
        tcp_processes: int specifying number of tcp processes
        udp_throughput: string specifying TX throughput in UDP tests, e.g. 100M
    Returns:
        iperf_args: string of formatted iperf args
    """
    iperf_args = '-i {} -t {} -J '.format(interval, duration)
    if traffic_type.upper() == 'UDP':
        iperf_args = iperf_args + '-u -b {} -l 1400'.format(udp_throughput)
    elif traffic_type.upper() == 'TCP':
        iperf_args = iperf_args + '-P {}'.format(tcp_processes)
        if tcp_window:
            # Leading space required: without it the window flag fused with
            # the -P argument (e.g. '-P 1-w 2M'), which iperf rejects.
            iperf_args = iperf_args + ' -w {}'.format(tcp_window)
    if reverse_direction:
        iperf_args = iperf_args + ' -R'
    return iperf_args
1034
1035
def get_dut_temperature(dut):
    """Function to get dut temperature.

    The function fetches and returns the reading from the temperature sensor
    used for skin temperature and thermal throttling.

    Args:
        dut: AndroidDevice of interest
    Returns:
        temperature: device temperature. 0 if temperature could not be read
    """
    temperature = 0
    for zone in ('skin-therm', 'sdm-therm-monitor', 'sdm-therm-adc',
                 'back_therm'):
        try:
            temperature = int(
                dut.adb.shell(
                    'cat /sys/class/thermal/tz-by-name/{}/temp'.format(zone)))
            break
        except ValueError:
            # Zone missing or unreadable on this device; try the next one.
            temperature = 0
    if temperature == 0:
        logging.debug('Could not check DUT temperature.')
    elif temperature > 100:
        # Readings above 100 are treated as milli-degree values and scaled.
        temperature = temperature / 1000
    return temperature
1063
1064
def wait_for_dut_cooldown(dut, target_temp=50, timeout=300):
    """Function to wait for a DUT to cool down.

    Polls the DUT temperature until it drops below target_temp or until
    timeout elapses, whichever comes first.

    Args:
        dut: AndroidDevice of interest
        target_temp: target cooldown temperature
        timeout: max time to wait for cooldown
    """
    start_time = time.time()
    while True:
        # Always take at least one reading so `temperature` is defined for
        # the final log line even when timeout <= 0 (previously a NameError).
        temperature = get_dut_temperature(dut)
        if temperature < target_temp:
            break
        if time.time() - start_time > timeout:
            break
        time.sleep(SHORT_SLEEP)
    elapsed_time = time.time() - start_time
    logging.debug("DUT Final Temperature: {}C. Cooldown duration: {}".format(
        temperature, elapsed_time))
1082
1083
def health_check(dut, batt_thresh=5, temp_threshold=53, cooldown=1):
    """Function to check health status of a DUT.

    The function checks both battery levels and temperature to avoid DUT
    powering off during the test.

    Args:
        dut: AndroidDevice of interest
        batt_thresh: battery level threshold
        temp_threshold: temperature threshold
        cooldown: flag to wait for DUT to cool down when overheating
    Returns:
        health_check: boolean confirming device is healthy
    """
    healthy = True
    # Battery check: a low battery fails the health check outright.
    battery_level = utils.get_battery_level(dut)
    if battery_level < batt_thresh:
        logging.warning("Battery level low ({}%)".format(battery_level))
        healthy = False
    else:
        logging.debug("Battery level = {}%".format(battery_level))

    # Temperature check: overheating either triggers a cooldown wait or,
    # when cooldown is disabled, fails the health check.
    temperature = get_dut_temperature(dut)
    if temperature <= temp_threshold:
        logging.debug("DUT Temperature = {} C".format(temperature))
    elif cooldown:
        logging.warning(
            "Waiting for DUT to cooldown. ({} C)".format(temperature))
        wait_for_dut_cooldown(dut, target_temp=temp_threshold - 5)
    else:
        logging.warning("DUT Overheating ({} C)".format(temperature))
        healthy = False
    return healthy
1118
1119
def push_bdf(dut, bdf_file):
    """Function to push Wifi BDF files

    This function checks for existing wifi bdf files and over writes them all,
    for simplicity, with the bdf file provided in the arguments. The dut is
    rebooted for the bdf file to take effect

    Args:
        dut: dut to push bdf file to
        bdf_file: path to bdf_file to push
    """
    # Every existing BDF file on the device is replaced by the provided one.
    existing_bdf_files = dut.adb.shell(
        'ls /vendor/firmware/bdwlan*').splitlines()
    for device_path in existing_bdf_files:
        dut.push_system_file(bdf_file, device_path)
    # Reboot so the new BDF takes effect.
    dut.reboot()
1135
1136
def push_firmware(dut, wlanmdsp_file, datamsc_file):
    """Function to push Wifi firmware files

    Args:
        dut: dut to push bdf file to
        wlanmdsp_file: path to wlanmdsp.mbn file
        datamsc_file: path to Data.msc file
    """
    # Push each firmware file to its fixed destination, then reboot so the
    # new firmware is loaded.
    for local_path, device_path in (
            (wlanmdsp_file, '/vendor/firmware/wlanmdsp.mbn'),
            (datamsc_file, '/vendor/firmware/Data.msc')):
        dut.push_system_file(local_path, device_path)
    dut.reboot()
1148
1149
1150def _set_ini_fields(ini_file_path, ini_field_dict):
1151    template_regex = r'^{}=[0-9,.x-]+'
1152    with open(ini_file_path, 'r') as f:
1153        ini_lines = f.read().splitlines()
1154        for idx, line in enumerate(ini_lines):
1155            for field_name, field_value in ini_field_dict.items():
1156                line_regex = re.compile(template_regex.format(field_name))
1157                if re.match(line_regex, line):
1158                    ini_lines[idx] = "{}={}".format(field_name, field_value)
1159                    print(ini_lines[idx])
1160    with open(ini_file_path, 'w') as f:
1161        f.write("\n".join(ini_lines) + "\n")
1162
1163
def _edit_dut_ini(dut, ini_fields):
    """Function to edit Wifi ini files."""
    remote_ini_path = '/vendor/firmware/wlan/qca_cld/WCNSS_qcom_cfg.ini'
    local_ini_path = os.path.expanduser('~/WCNSS_qcom_cfg.ini')

    # Pull the ini locally, rewrite the requested fields, then push it back.
    dut.pull_files(remote_ini_path, local_ini_path)
    _set_ini_fields(local_ini_path, ini_fields)
    dut.push_system_file(local_ini_path, remote_ini_path)

    # Reboot so the driver picks up the new configuration.
    dut.reboot()
1174
1175
def set_ini_single_chain_mode(dut, chain):
    """Configure the DUT ini for single-chain (1x1) operation on `chain`."""
    # Chainmask values are 1-indexed while chains are 0-indexed.
    chainmask = chain + 1
    _edit_dut_ini(
        dut, {
            'gEnable2x2': 0,
            'gSetTxChainmask1x1': chainmask,
            'gSetRxChainmask1x1': chainmask,
            'gDualMacFeatureDisable': 1,
            'gDot11Mode': 0
        })
1185
1186
def set_ini_two_chain_mode(dut):
    """Configure the DUT ini for two-chain (2x2) operation."""
    _edit_dut_ini(
        dut, {
            'gEnable2x2': 2,
            'gSetTxChainmask1x1': 1,
            'gSetRxChainmask1x1': 1,
            'gDualMacFeatureDisable': 6,
            'gDot11Mode': 0
        })
1196
1197
def set_ini_tx_mode(dut, mode):
    """Set the DUT's 802.11 mode via the wlan ini file.

    Args:
        dut: dut whose ini file to edit
        mode: human-readable mode string, e.g. 'Auto', '11ac', '11n only'
    Raises:
        KeyError: if mode is not a recognized mode string
    """
    # Mapping from human-readable mode names to gDot11Mode codes.
    mode_codes = {
        "Auto": 0,
        "11abg": 1,
        "11b": 2,
        "11g": 3,
        "11n": 4,
        "11g only": 5,
        "11n only": 6,
        "11b only": 7,
        "11ac only": 8,
        "11ac": 9
    }

    _edit_dut_ini(
        dut, {
            'gEnable2x2': 2,
            'gSetTxChainmask1x1': 1,
            'gSetRxChainmask1x1': 1,
            'gDualMacFeatureDisable': 6,
            'gDot11Mode': mode_codes[mode]
        })
1220