1#  Copyright (C) 2024 The Android Open Source Project
2#
3#  Licensed under the Apache License, Version 2.0 (the "License");
4#  you may not use this file except in compliance with the License.
5#  You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#  Unless required by applicable law or agreed to in writing, software
10#  distributed under the License is distributed on an "AS IS" BASIS,
11#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#  See the License for the specific language governing permissions and
13#  limitations under the License.
14
15"""Utils for handling Nearby Connection rpc."""
16
17import datetime
18import random
19import time
20
21from mobly import asserts
22from mobly import utils
23from mobly.controllers import android_device
24from mobly.controllers.android_device_lib import callback_handler_v2
25from mobly.controllers.android_device_lib import snippet_client_v2
26from mobly.snippet import callback_event
27
28from betocq import nc_constants
29
# Upper bound, in seconds, of the random pause that start_nearby_connection()
# inserts between starting advertising and starting discovery.
# This number should be large enough to cover the advertising interval, the
# firmware scheduling timing interval and the user action delay.
ADVERTISING_TO_DISCOVERY_MAX_DELAY_SEC = 4
33
34
class NearbyConnectionWrapper:
  """Wrapper for Nearby Connection Snippet Client Operations.

  Drives one advertiser device and one discoverer device through the snippet
  RPCs used in this class: advertising, discovery, connection request and
  accept, payload transfer and disconnect.

  Attributes:
    advertiser: The device that advertises.
    discoverer: The device that discovers and requests the connection.
    service_id: Random 8-character service id generated per wrapper instance.
    advertising_discovery_medium: Medium passed to startAdvertising and
      startDiscovery.
    connection_medium: Medium passed to requestConnection.
    upgrade_medium: Upgrade medium passed to startAdvertising and
      requestConnection.
    advertiser_nearby: Snippet client used for the advertiser side RPCs.
    discoverer_nearby: Snippet client used for the discoverer side RPCs.
    test_failure_reason: The step that is currently in progress, so that a
      failure can be attributed to it; SUCCESS once a flow has completed.
    connection_quality_info: Latencies and upgrade medium recorded while the
      connection is being set up.
  """

  def __init__(
      self,
      advertiser: android_device.AndroidDevice,
      discoverer: android_device.AndroidDevice,
      advertiser_nearby: snippet_client_v2.SnippetClientV2,
      discoverer_nearby: snippet_client_v2.SnippetClientV2,
      advertising_discovery_medium: nc_constants.NearbyMedium = (
          nc_constants.NearbyMedium.BLE_ONLY
      ),
      connection_medium: nc_constants.NearbyMedium = (
          nc_constants.NearbyMedium.BT_ONLY
      ),
      upgrade_medium: nc_constants.NearbyMedium = (
          nc_constants.NearbyMedium.BT_ONLY
      ),
  ):
    """Stores the devices, snippet clients and mediums; starts no RPC.

    Args:
      advertiser: The advertising device.
      discoverer: The discovering device.
      advertiser_nearby: Snippet client of the advertiser.
      discoverer_nearby: Snippet client of the discoverer.
      advertising_discovery_medium: Medium for advertising and discovery.
      connection_medium: Medium for the connection request.
      upgrade_medium: Medium the connection may be upgraded to.
    """
    self.advertiser = advertiser
    self.discoverer = discoverer
    self.service_id = utils.rand_ascii_str(8)
    self.advertising_discovery_medium = advertising_discovery_medium
    self.connection_medium = connection_medium
    self.upgrade_medium = upgrade_medium
    self.discoverer_nearby = discoverer_nearby
    self.advertiser_nearby = advertiser_nearby
    self.test_failure_reason = (
        nc_constants.SingleTestFailureReason.UNINITIALIZED
    )

    self.connection_quality_info: nc_constants.ConnectionSetupQualityInfo = (
        nc_constants.ConnectionSetupQualityInfo()
    )

    # Callback handlers and endpoint ids stay None until the matching step
    # (advertising, discovery, request, accept) has produced them.
    self._advertiser_connection_lifecycle_callback: (
        callback_handler_v2.CallbackHandlerV2 | None) = None
    self._discoverer_endpoint_discovery_callback: (
        callback_handler_v2.CallbackHandlerV2 | None) = None
    self._discoverer_connection_lifecycle_callback: (
        callback_handler_v2.CallbackHandlerV2 | None) = None
    self._advertiser_payload_callback: (
        callback_handler_v2.CallbackHandlerV2 | None) = None
    self._discoverer_payload_callback: (
        callback_handler_v2.CallbackHandlerV2 | None) = None
    # Endpoint id of the advertiser, as seen by the discoverer.
    self._advertiser_endpoint_id: str | None = None
    # Endpoint id of the discoverer, as seen by the advertiser.
    self._discoverer_endpoint_id: str | None = None

  def start_advertising(self) -> None:
    """Starts Nearby Connection advertising.

    The advertiser's serial is passed as the first argument; start_discovery()
    later checks the discovered endpoint name against it. The returned
    callback handler is kept for the advertiser's connection lifecycle events.
    """
    advertiser_callback = self.advertiser_nearby.startAdvertising(
        self.advertiser.serial,
        self.service_id,
        self.advertising_discovery_medium.value,
        self.upgrade_medium.value,
    )
    self.advertiser.log.info(
        f'Start advertising {self.advertising_discovery_medium.name}'
    )
    self._advertiser_connection_lifecycle_callback = advertiser_callback

  def start_discovery(self, timeout: datetime.timedelta) -> None:
    """Starts Nearby Connection discovery and waits for the advertiser.

    Records the discovery latency and the advertiser's endpoint id.

    Args:
      timeout: Maximum time to wait for the 'onEndpointFound' event.

    Raises:
      A Mobly assertion failure if the discovered endpoint name or service id
      is not the expected one. The wait itself presumably raises when no event
      arrives within the timeout (behavior of the callback handler).
    """
    self.discoverer.log.info(
        f'Start discovery {self.advertising_discovery_medium.name}'
    )
    self._discoverer_endpoint_discovery_callback = (
        self.discoverer_nearby.startDiscovery(
            self.service_id, self.advertising_discovery_medium.value
        )
    )

    endpoint_found_event = (
        self._discoverer_endpoint_discovery_callback.waitAndGet(
            'onEndpointFound', timeout=timeout.total_seconds()
        )
    )
    endpoint_info = endpoint_found_event.data['discoveredEndpointInfo']
    # The event reports nanoseconds; timedelta's finest unit is microseconds.
    self.connection_quality_info.discovery_latency = datetime.timedelta(
        microseconds=endpoint_found_event.data['discoveryTimeNs'] / 1_000
    )
    asserts.assert_equal(
        endpoint_info['endpointName'], self.advertiser.serial,
        'Received an unexpected endpoint during discovery: '
        f'{endpoint_found_event}')

    asserts.assert_equal(
        endpoint_info['serviceId'], self.service_id,
        'Received an unexpected service id during discovery: '
        f'{endpoint_found_event}')
    self._advertiser_endpoint_id = endpoint_found_event.data['endpointId']

  def stop_advertising(self) -> None:
    """Stops Nearby Connection advertising."""
    self.advertiser_nearby.stopAdvertising()
    self.advertiser.log.info('Stop advertising')

  def stop_discovery(self) -> None:
    """Stops Nearby Connection discovery."""
    self.discoverer_nearby.stopDiscovery()
    self.discoverer.log.info('Stop discovery')

  def request_connection(
      self,
      medium_upgrade_type: nc_constants.MediumUpgradeType,
      timeout: datetime.timedelta,
      keep_alive_timeout_ms: int = nc_constants.KEEP_ALIVE_TIMEOUT_BT_MS,
      keep_alive_interval_ms: int = nc_constants.KEEP_ALIVE_INTERVAL_BT_MS,
  ) -> None:
    """Requests Nearby Connection from the discoverer to the advertiser.

    Waits for 'onConnectionInitiated' on both sides, records the connection
    latency and the discoverer's endpoint id.

    Args:
      medium_upgrade_type: Upgrade type forwarded to requestConnection.
      timeout: Maximum time to wait for each 'onConnectionInitiated' event.
      keep_alive_timeout_ms: Keep-alive timeout forwarded to the snippet.
      keep_alive_interval_ms: Keep-alive interval forwarded to the snippet.

    Raises:
      A Mobly assertion failure if either side reports the wrong connection
      direction or an unexpected endpoint name.
    """
    self.discoverer.log.info(
        'Start connection request with keep_alive_timeout_ms'
        f' {keep_alive_timeout_ms}'
    )
    self._discoverer_connection_lifecycle_callback = (
        self.discoverer_nearby.requestConnection(
            self.discoverer.serial,
            self._advertiser_endpoint_id,
            self.connection_medium.value,
            self.upgrade_medium.value,
            medium_upgrade_type.value,
            keep_alive_timeout_ms,
            keep_alive_interval_ms,
        )
    )

    # Discoverer side: the connection must be outgoing, towards the advertiser.
    d_connection_init_event = (
        self._discoverer_connection_lifecycle_callback.waitAndGet(
            'onConnectionInitiated', timeout=timeout.total_seconds()
        )
    )
    self.connection_quality_info.connection_latency = datetime.timedelta(
        microseconds=d_connection_init_event.data['connectionTimeNs'] / 1_000
    )

    d_connection_info = d_connection_init_event.data['connectionInfo']
    asserts.assert_false(
        d_connection_info['isIncomingConnection'],
        f'Received an incoming connection: {d_connection_init_event}'
        ' but expected an outgoing connection')

    asserts.assert_equal(
        d_connection_info['endpointName'],
        self.advertiser.serial,
        f'Received an unexpected endpoint: {d_connection_init_event}')

    # Advertiser side: the connection must be incoming, from the discoverer.
    a_connection_init_event = (
        self._advertiser_connection_lifecycle_callback.waitAndGet(
            'onConnectionInitiated', timeout=timeout.total_seconds()
        )
    )
    a_connection_info = a_connection_init_event.data['connectionInfo']
    # Report the advertiser's own event on failure, not the discoverer's.
    asserts.assert_true(
        a_connection_info['isIncomingConnection'],
        f'Received an outgoing connection: {a_connection_init_event}'
        ' but expected an incoming connection')

    asserts.assert_equal(
        a_connection_info['endpointName'],
        self.discoverer.serial,
        f'Received an unexpected endpoint: {a_connection_init_event}')

    self._discoverer_endpoint_id = a_connection_init_event.data['endpointId']

  def accept_connection(
      self, timeout: datetime.timedelta
  ) -> None:
    """Accepts Nearby Connection on both devices and waits for the result.

    Both sides accept, then each side's 'onConnectionResult' event is checked.
    When the upgrade medium is a high quality one, this also waits for the
    bandwidth upgrade to finish.

    Args:
      timeout: Maximum time to wait for each 'onConnectionResult' event.

    Raises:
      TimeoutError: If the medium upgrade does not reach high quality in time.
      A Mobly assertion failure if a connection result is unsuccessful or
      refers to an unexpected endpoint.
    """
    self._advertiser_payload_callback = (
        self.advertiser_nearby.acceptConnection(
            self._discoverer_endpoint_id
        )
    )
    self.advertiser.log.info('Start connection accept')
    self._discoverer_payload_callback = (
        self.discoverer_nearby.acceptConnection(
            self._advertiser_endpoint_id
        )
    )
    self.discoverer.log.info('Start connection accept')

    advertiser_connection_event = (
        self._advertiser_connection_lifecycle_callback.waitAndGet(
            'onConnectionResult', timeout=timeout.total_seconds()
        )
    )

    asserts.assert_true(
        advertiser_connection_event.data['isSuccess'],
        f'Received an unsuccessful event: {advertiser_connection_event}')

    asserts.assert_equal(
        advertiser_connection_event.data['endpointId'],
        self._discoverer_endpoint_id,
        f'Received an unexpected endpoint: {advertiser_connection_event}')

    discoverer_connection_event = (
        self._discoverer_connection_lifecycle_callback.waitAndGet(
            'onConnectionResult', timeout=timeout.total_seconds()
        )
    )
    asserts.assert_true(
        discoverer_connection_event.data['isSuccess'],
        f'Received an unsuccessful event: {discoverer_connection_event}')

    asserts.assert_equal(
        discoverer_connection_event.data['endpointId'],
        self._advertiser_endpoint_id,
        f'Received an unexpected endpoint: {discoverer_connection_event}')

    if nc_constants.is_high_quality_medium(self.upgrade_medium):
      self.test_failure_reason = (
          nc_constants.SingleTestFailureReason.WIFI_MEDIUM_UPGRADE
      )
      self._wait_for_high_quality_medium()

  def _wait_for_high_quality_medium(self) -> None:
    """Waits until the discoverer reports a high bandwidth quality medium.

    Consumes 'onBandwidthChanged' events on the discoverer until one has
    'isHighBwQuality' set, then records the upgrade latency and medium.

    Raises:
      TimeoutError: If a non high quality event is received once
        CONNECTION_BANDWIDTH_CHANGED_TIMEOUT has elapsed since the wait began.
    """
    upgrade_timeout = nc_constants.CONNECTION_BANDWIDTH_CHANGED_TIMEOUT
    upgrade_start_time = datetime.datetime.now()
    while True:
      upgrade_event = (
          self._discoverer_connection_lifecycle_callback.waitAndGet(
              'onBandwidthChanged', timeout=upgrade_timeout.total_seconds()
          )
      )
      self.discoverer.log.info(f'medium upgrade to {upgrade_event.data}')
      if upgrade_event.data['isHighBwQuality']:
        break
      # Intermediate (not high quality) change: keep waiting unless the
      # overall budget has already been used up.
      if datetime.datetime.now() - upgrade_start_time >= upgrade_timeout:
        raise TimeoutError('medium upgrade timeout')

    quality_info = self.connection_quality_info
    quality_info.medium_upgrade_latency = (
        datetime.datetime.now() - upgrade_start_time
    )
    quality_info.upgrade_medium = nc_constants.NearbyConnectionMedium(
        upgrade_event.data['medium']
    )
    quality_info.medium_upgrade_expected = True
    self.discoverer.log.info(
        'upgraded to high quality medium: '
        f'{quality_info.upgrade_medium.name}')

  def disconnect_endpoint(self) -> None:
    """Disconnects Nearby Connection endpoint.

    Does nothing except logging when no advertiser endpoint has been
    discovered yet. Otherwise disconnects from the advertiser and, if a
    connection was requested, waits for the 'onDisconnected' event.

    Raises:
      A Mobly assertion failure if the disconnect event refers to another
      endpoint.
    """
    # An instance is always truthy, so test the state that actually says
    # whether there is an endpoint to disconnect from.
    if self._advertiser_endpoint_id is None:
      self.discoverer.log.info('no nearby connection setup yet')
      return

    self.discoverer_nearby.disconnectFromEndpoint(
        self._advertiser_endpoint_id
    )
    self.discoverer.log.info(
        f'Start disconnecting from endpoint: {self._advertiser_endpoint_id}'
    )

    if self._discoverer_connection_lifecycle_callback is not None:
      disconnected_event = (
          self._discoverer_connection_lifecycle_callback.waitAndGet(
              'onDisconnected',
              timeout=nc_constants.DISCONNECTION_TIMEOUT.total_seconds(),
          )
      )
      asserts.assert_equal(
          disconnected_event.data['endpointId'],
          self._advertiser_endpoint_id,
          f'Receive unexpected event on disconnect: {disconnected_event}')
    self.discoverer.log.info(
        f'disconnected with endpoint: {self._advertiser_endpoint_id}'
    )

  def start_nearby_connection(
      self,
      timeouts: nc_constants.ConnectionSetupTimeouts,
      medium_upgrade_type: nc_constants.MediumUpgradeType = (
          nc_constants.MediumUpgradeType.DEFAULT
      ),
      keep_alive_timeout_ms: int = 0,
      keep_alive_interval_ms: int = 0,
  ) -> None:
    """Starts Nearby Connection between two Android devices.

    Runs advertising, discovery, connection request and connection accept in
    order. test_failure_reason is updated before each step so that a failure
    is attributed to the step in progress, and set to SUCCESS at the end.

    Args:
      timeouts: Timeouts for discovery, connection init and connection result.
      medium_upgrade_type: Upgrade type forwarded to request_connection().
      keep_alive_timeout_ms: Keep-alive timeout forwarded to
        request_connection().
      keep_alive_interval_ms: Keep-alive interval forwarded to
        request_connection().
    """
    self.test_failure_reason = (
        nc_constants.SingleTestFailureReason.TARGET_START_ADVERTISING)
    # Start advertising.
    self.start_advertising()
    # Add a random delay between advertising and discovery
    # to mimic the random delay between two devices' user action.
    time.sleep(ADVERTISING_TO_DISCOVERY_MAX_DELAY_SEC * random.random())

    self.test_failure_reason = (
        nc_constants.SingleTestFailureReason.SOURCE_START_DISCOVERY)
    # Start discovery.
    self.start_discovery(timeout=timeouts.discovery_timeout)

    # Request connection.
    self.test_failure_reason = (
        nc_constants.SingleTestFailureReason.SOURCE_REQUEST_CONNECTION)
    self.request_connection(
        medium_upgrade_type=medium_upgrade_type,
        timeout=timeouts.connection_init_timeout,
        keep_alive_timeout_ms=keep_alive_timeout_ms,
        keep_alive_interval_ms=keep_alive_interval_ms)

    # Stop discovery.
    self.stop_discovery()

    # Accept connection.
    self.test_failure_reason = (
        nc_constants.SingleTestFailureReason.TARGET_ACCEPT_CONNECTION)
    self.accept_connection(timeout=timeouts.connection_result_timeout)

    # Stop advertising.
    self.stop_advertising()
    self.test_failure_reason = nc_constants.SingleTestFailureReason.SUCCESS

  def transfer_file(
      self,
      file_size_kb: int,
      timeout: datetime.timedelta,
      payload_type: nc_constants.PayloadType,
  ) -> float:
    """Sends payloads and returns the transfer speed in kBS.

    test_failure_reason is FILE_TRANSFER_FAIL while the transfer runs and
    SUCCESS once it has completed. transferFilesCleanup is called on both
    snippet clients whether or not the transfer succeeded.

    Args:
      file_size_kb: Size of the payload to send, in KB.
      timeout: Maximum time to wait for each payload event.
      payload_type: Payload type forwarded to the snippet.

    Returns:
      The transfer speed as computed by _transfer_file().
    """
    try:
      self.test_failure_reason = (
          nc_constants.SingleTestFailureReason.FILE_TRANSFER_FAIL
      )
      transfer_speed_kbs = self._transfer_file(
          file_size_kb, timeout, payload_type
      )
      self.test_failure_reason = nc_constants.SingleTestFailureReason.SUCCESS
    finally:
      # Clean up on both devices, also when the transfer failed.
      utils.concurrent_exec(
          lambda nb: nb.transferFilesCleanup(),
          param_list=[[self.discoverer_nearby], [self.advertiser_nearby]],
          raise_on_exception=True)
    return transfer_speed_kbs

  def _transfer_file(
      self, file_size_kb: int, timeout: datetime.timedelta,
      payload_type: nc_constants.PayloadType
  ) -> float:
    """Sends payloads and returns the transfer speed in kBS.

    The discoverer sends a payload with a random 8-character file name to the
    advertiser; the method then waits for the advertiser to receive it and for
    both sides to report a successful transfer update.

    Args:
      file_size_kb: Size of the payload to send, in KB.
      timeout: Maximum time to wait for each payload event.
      payload_type: Payload type forwarded to sendPayloadWithType.

    Returns:
      file_size_kb divided by the transfer time in seconds reported by the
      discoverer, rounded to the nearest integer.

    Raises:
      A Mobly assertion failure if the connection has not been accepted yet,
      i.e. a payload callback handler is still missing.
    """
    # Check the preconditions before anything is sent, so that a missing
    # connection fails without side effects on the devices.
    asserts.assert_is_not_none(
        self._advertiser_payload_callback,
        'No nearby connection is set up, advertiser payload cb is none.')
    asserts.assert_is_not_none(
        self._discoverer_payload_callback,
        'No nearby connection is set up, discoverer payload cb is none.')

    # Creates a file and send it to the advertiser.
    file_name = utils.rand_ascii_str(8)
    self.discoverer.log.info(
        f'Start sending payloads with type: {payload_type.name}'
    )
    payload_id = self.discoverer_nearby.sendPayloadWithType(
        self._advertiser_endpoint_id, file_name, file_size_kb, payload_type
    )

    # Waits for the advertiser received.
    def on_receive(event: callback_event.CallbackEvent) -> bool:
      """Matches the payload just sent, coming from the discoverer."""
      return (
          event.data['endpointId'] == self._discoverer_endpoint_id
          and event.data['payload']['id'] == payload_id
      )

    self._advertiser_payload_callback.waitForEvent(
        'onPayloadReceived',
        predicate=on_receive,
        timeout=timeout.total_seconds())

    # Waits for complete transfer.
    self._advertiser_payload_callback.waitForEvent(
        'onPayloadTransferUpdate',
        predicate=lambda event: event.data['update']['isSuccess'],
        timeout=timeout.total_seconds())

    payload_transfer_event = self._discoverer_payload_callback.waitForEvent(
        'onPayloadTransferUpdate',
        predicate=lambda event: event.data['update']['isSuccess'],
        timeout=timeout.total_seconds(),
    )
    self.advertiser.log.info('payload received')

    # The event reports nanoseconds; timedelta's finest unit is microseconds.
    transfer_time = datetime.timedelta(
        microseconds=payload_transfer_event.data['transferTimeNs'] / 1_000)
    return round(file_size_kb / transfer_time.total_seconds())
418