1# Copyright (c) 2015 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import array
6import logging
7import mox
8import multiprocessing
9import struct
10import unittest
11
12import common
13from autotest_lib.client.cros.cellular.mbim_compliance import mbim_channel
14from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
15
16
class MBIMChannelTestCase(unittest.TestCase):
    """ Test cases for the MBIMChannel class. """

    def setUp(self):
        # Arguments passed to MBIMChannel. Irrelevant for these tests, mostly.
        self._device = None
        self._interface_number = 0
        self._interrupt_endpoint_address = 0x01
        self._in_buffer_size = 100

        # Creates |self._channel| around a mocked-out subprocess.
        self._setup_mock_subprocess()
        self._mox = mox.Mox()

        # Reach into |MBIMChannel| and mock out the request queue, so we can set
        # expectations on it.
        # |multiprocessing.Queue| is actually a function that returns some
        # hidden |multiprocessing.queues.Queue| class. We'll grab the class from
        # a temporary object so we can mock it.
        some_queue = multiprocessing.Queue()
        queue_class = some_queue.__class__
        # Only the class is needed; release the temporary queue's pipe now
        # instead of leaving it to garbage collection.
        some_queue.close()
        self._mock_request_queue = self._mox.CreateMock(queue_class)
        self._channel._request_queue = self._mock_request_queue

        # On the other hand, just grab the real response queue.
        self._response_queue = self._channel._response_queue

        # Decrease timeouts to small values to speed up tests.
        self._channel.FRAGMENT_TIMEOUT_S = 0.2
        self._channel.TRANSACTION_TIMEOUT_S = 0.5


    def tearDown(self):
        # Closing the channel is what satisfies the shutdown expectations
        # recorded in |_setup_mock_subprocess|, so it must precede VerifyAll.
        self._channel.close()
        self._subprocess_mox.VerifyAll()


    def _setup_mock_subprocess(self):
        """
        Set up long-term expectations on the mocked out subprocess, and create
        |self._channel| using that mocked subprocess.

        These expectations are only met when |self._channel.close| is called in
        |tearDown|.

        """
        self._subprocess_mox = mox.Mox()
        mock_process = self._subprocess_mox.CreateMock(multiprocessing.Process)
        mock_process(target=mox.IgnoreArg(),
                     args=mox.IgnoreArg()).AndReturn(mock_process)
        mock_process.start()

        # Each API call into MBIMChannel results in an aliveness ping to the
        # subprocess.
        # Finally, when |self._channel| is destructed, it will attempt to
        # terminate the |mock_process|, with increasingly drastic actions.
        mock_process.is_alive().MultipleTimes().AndReturn(True)
        mock_process.join(mox.IgnoreArg())
        mock_process.is_alive().AndReturn(True)
        mock_process.terminate()

        self._subprocess_mox.ReplayAll()
        self._channel = mbim_channel.MBIMChannel(
                self._device,
                self._interface_number,
                self._interrupt_endpoint_address,
                self._in_buffer_size,
                mock_process)


    def test_creation(self):
        """ A trivial test that we mocked out the |Process| class correctly. """
        pass


    def test_unfragmented_packet_successful(self):
        """ Test that we can synchronously send an unfragmented packet. """
        packet = self._get_unfragmented_packet(1)
        response_packet = self._get_unfragmented_packet(1)
        self._expect_transaction([packet], [response_packet])
        self._verify_transaction_successful([packet], [response_packet])


    def test_unfragmented_packet_timeout(self):
        """ Test the case when an unfragmented packet receives no response. """
        packet = self._get_unfragmented_packet(1)
        self._expect_transaction([packet])
        self._verify_transaction_failed([packet])


    def test_single_fragment_successful(self):
        """ Test that we can synchronously send a fragmented packet. """
        fragment = self._get_fragment(1, 1, 0)
        response_fragment = self._get_fragment(1, 1, 0)
        self._expect_transaction([fragment], [response_fragment])
        self._verify_transaction_successful([fragment], [response_fragment])


    def test_single_fragment_timeout(self):
        """ Test the case when a fragmented packet receives no response. """
        fragment = self._get_fragment(1, 1, 0)
        self._expect_transaction([fragment])
        self._verify_transaction_failed([fragment])


    def test_single_fragment_corrupted_reply(self):
        """ Test the case when the response has a corrupted fragment header. """
        fragment = self._get_fragment(1, 1, 0)
        response_fragment = self._get_fragment(1, 1, 0)
        # Drop the last byte so the fragment header is truncated.
        response_fragment = response_fragment[:-1]
        self._expect_transaction([fragment], [response_fragment])
        self._verify_transaction_failed([fragment])


    def test_multiple_fragments_successful(self):
        """ Test that we can send/receive multi-fragment packets. """
        fragment_0 = self._get_fragment(1, 2, 0)
        fragment_1 = self._get_fragment(1, 2, 1)
        response_fragment_0 = self._get_fragment(1, 2, 0)
        response_fragment_1 = self._get_fragment(1, 2, 1)
        self._expect_transaction([fragment_0, fragment_1],
                                 [response_fragment_0, response_fragment_1])
        self._verify_transaction_successful(
                [fragment_0, fragment_1],
                [response_fragment_0, response_fragment_1])


    def test_multiple_fragments_incorrect_total_fragments(self):
        """ Test the case when one of the fragments reports incorrect total. """
        fragment = self._get_fragment(1, 1, 0)
        response_fragment_0 = self._get_fragment(1, 2, 0)
        # total_fragment should have been 2, but is 99.
        response_fragment_1 = self._get_fragment(1, 99, 1)
        self._expect_transaction([fragment],
                                 [response_fragment_0, response_fragment_1])
        self._verify_transaction_failed([fragment])


    def test_multiple_fragments_reordered_reply_1(self):
        """ Test the case when the first fragment reports incorrect index. """
        fragment = self._get_fragment(1, 1, 0)
        # Incorrect first fragment number.
        response_fragment = self._get_fragment(1, 2, 1)
        self._expect_transaction([fragment], [response_fragment])
        self._verify_transaction_failed([fragment])


    def test_multiple_fragments_reordered_reply_2(self):
        """ Test the case when a follow up fragment reports incorrect index. """
        fragment = self._get_fragment(1, 1, 0)
        response_fragment_0 = self._get_fragment(1, 2, 0)
        # Incorrect second fragment number.
        response_fragment_1 = self._get_fragment(1, 2, 99)
        self._expect_transaction([fragment],
                                 [response_fragment_0, response_fragment_1])
        self._verify_transaction_failed([fragment])


    def test_multiple_fragments_insufficient_reply_timeout(self):
        """ Test the case when we receive only part of the response. """
        fragment = self._get_fragment(1, 1, 0)
        # The second fragment will never arrive. The test expects the partial
        # response to be returned as-is rather than raising.
        response_fragment_0 = self._get_fragment(1, 2, 0)
        self._expect_transaction([fragment], [response_fragment_0])
        self._verify_transaction_successful([fragment], [response_fragment_0])


    def test_unfragmented_packet_notification(self):
        """ Test the case when a notification comes before the response. """
        packet = self._get_unfragmented_packet(1)
        response = self._get_unfragmented_packet(1)
        notification = self._get_unfragmented_packet(0)
        self._expect_transaction([packet], [notification, response])
        self._verify_transaction_successful([packet], [response])
        self.assertEqual([[notification]],
                         self._channel.get_outstanding_packets())


    def test_fragmented_notification(self):
        """ Test the case when a fragmented notification precedes response. """
        packet_fragment_0 = self._get_fragment(1, 2, 0)
        packet_fragment_1 = self._get_fragment(1, 2, 1)
        response_fragment_0 = self._get_fragment(1, 2, 0)
        response_fragment_1 = self._get_fragment(1, 2, 1)
        notification_0_fragment_0 = self._get_fragment(0, 2, 0)
        notification_0_fragment_1 = self._get_fragment(0, 2, 1)
        notification_1_fragment_0 = self._get_fragment(99, 2, 0)
        notification_1_fragment_1 = self._get_fragment(99, 2, 1)

        self._expect_transaction(
                [packet_fragment_0, packet_fragment_1],
                [notification_0_fragment_0, notification_0_fragment_1,
                 notification_1_fragment_0, notification_1_fragment_1,
                 response_fragment_0, response_fragment_1])
        self._verify_transaction_successful(
                [packet_fragment_0, packet_fragment_1],
                [response_fragment_0, response_fragment_1])
        self.assertEqual(
                [[notification_0_fragment_0, notification_0_fragment_1],
                 [notification_1_fragment_0, notification_1_fragment_1]],
                self._channel.get_outstanding_packets())


    def test_multiple_packets_rollover_notification(self):
        """
        Test the case when we receive incomplete response, followed by
        fragmented notifications.

        We have to be smart enough to realize that the incorrect fragment
        received at the end of the response belongs to the next notification
        instead.

        """
        packet = self._get_fragment(1, 1, 0)
        # The second fragment never comes, instead we get a notification
        # fragment.
        response_fragment_0 = self._get_fragment(1, 2, 0)
        notification_0_fragment_0 = self._get_fragment(0, 2, 0)
        notification_0_fragment_1 = self._get_fragment(0, 2, 1)
        notification_1_fragment_0 = self._get_fragment(99, 2, 0)
        notification_1_fragment_1 = self._get_fragment(99, 2, 1)

        self._expect_transaction(
                [packet],
                [response_fragment_0,
                 notification_0_fragment_0, notification_0_fragment_1,
                 notification_1_fragment_0, notification_1_fragment_1])
        self._verify_transaction_successful(
                [packet],
                [response_fragment_0])
        self.assertEqual(
                [[notification_0_fragment_0, notification_0_fragment_1],
                 [notification_1_fragment_0, notification_1_fragment_1]],
                self._channel.get_outstanding_packets())


    def test_data(self):
        """ Test that data is transferred transparently. """
        packet = self._get_unfragmented_packet(1)
        packet.fromlist([0xFF, 0xFF, 0xFF, 0xFF, 0xDD, 0xDD, 0xDD, 0xDD])
        response_packet = self._get_unfragmented_packet(1)
        response_packet.fromlist([0xAA, 0xAA, 0xBB, 0xBB])
        self._expect_transaction([packet], [response_packet])
        self._verify_transaction_successful([packet], [response_packet])


    def test_flush_successful(self):
        """ Test that flush clears all queues. """
        response = self._get_unfragmented_packet(1)
        notification_1 = self._get_fragment(0, 1, 0)
        self._response_queue.put_nowait(notification_1)

        def _queue_late_response():
            """ Simulate a response arriving as the request queue drains. """
            self._response_queue.put_nowait(response)

        self._mock_request_queue.qsize().AndReturn(1)
        self._mock_request_queue.empty().AndReturn(False)
        # The side effect must be a callable. Calling |put_nowait| here directly
        # would enqueue the response while recording expectations and hand
        # |None| to mox, so the late-response case would never be exercised.
        self._mock_request_queue.empty().WithSideEffects(
                _queue_late_response).AndReturn(True)
        self._mox.ReplayAll()
        self._channel.flush()
        self._mox.VerifyAll()
        self.assertEqual(0, self._response_queue.qsize())


    def test_flush_failed(self):
        """ Test the case when the request queue fails to empty out. """
        self._mock_request_queue.qsize().AndReturn(1)
        self._mock_request_queue.empty().MultipleTimes().AndReturn(False)
        self._mox.ReplayAll()
        self.assertRaises(
                mbim_errors.MBIMComplianceChannelError,
                self._channel.flush)
        self._mox.VerifyAll()


    def _queue_responses(self, responses):
        """ Helper method for |_expect_transaction|. Do not use directly. """
        for response in responses:
            self._response_queue.put_nowait(response)


    def _expect_transaction(self, requests, responses=None):
        """
        Helper method to set up expectations on the queues.

        @param requests: A non-empty list of packets to expect, in order, on the
                |_request_queue|.
        @param responses: An optional list of packets to respond with after the
                last request.

        """
        earlier_requests, last_request = requests[:-1], requests[-1]
        for request in earlier_requests:
            self._mock_request_queue.put_nowait(request)
        if responses:
            # Responses are queued from the side effect of the final request,
            # so they only appear once the whole transaction has been sent.
            self._mock_request_queue.put_nowait(last_request).WithSideEffects(
                    lambda _: self._queue_responses(responses))
        else:
            self._mock_request_queue.put_nowait(last_request)


    def _verify_transaction_successful(self, requests, responses):
        """
        Helper method to assert that the transaction was successful.

        @param requests: List of packets sent.
        @param responses: List of packets expected back.

        """
        self._mox.ReplayAll()
        self.assertEqual(responses,
                         self._channel.bidirectional_transaction(*requests))
        self._mox.VerifyAll()


    def _verify_transaction_failed(self, requests):
        """
        Helper method to assert that the transaction failed.

        @param requests: List of packets sent.

        """
        self._mox.ReplayAll()
        self.assertRaises(mbim_errors.MBIMComplianceChannelError,
                          self._channel.bidirectional_transaction,
                          *requests)
        self._mox.VerifyAll()


    def _get_unfragmented_packet(self, transaction_id):
        """
        Creates a packet that has no fragment header.

        @param transaction_id: Value written into the transaction id field.
        @returns: An array('B') holding the packed 12-byte header.

        """
        packet_format = '<LLL'  # This does not contain a fragment header.
        packet = self._create_buffer(struct.calcsize(packet_format))
        struct.pack_into(packet_format,
                         packet,
                         0,
                         0x00000000,  # 0x0 does not need fragments.
                         struct.calcsize(packet_format),
                         transaction_id)
        return packet


    def _get_fragment(self, transaction_id, total_fragments, current_fragment):
        """
        Creates a fragment with the given fields.

        @param transaction_id: Value written into the transaction id field.
        @param total_fragments: Value written into the total fragments field.
        @param current_fragment: Value written into the current fragment field.
        @returns: An array('B') holding the packed 20-byte header.

        """
        fragment_header_format = '<LLLLL'
        message_type = 0x00000003  # MBIM_COMMAND_MSG has fragments.
        fragment = self._create_buffer(struct.calcsize(fragment_header_format))
        struct.pack_into(fragment_header_format,
                         fragment,
                         0,
                         message_type,
                         struct.calcsize(fragment_header_format),
                         transaction_id,
                         total_fragments,
                         current_fragment)
        return fragment


    def _create_buffer(self, size):
        """ Create an array of the given size initialized to 0x00. """
        # A list initializer works for typecode 'B' on both Python 2 and 3;
        # Python 3 rejects a str initializer for this typecode.
        return array.array('B', [0] * size)
376
if __name__ == '__main__':
    # Show DEBUG-level log records while the test suite runs.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
380