1# Copyright 2015 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""
5All of the MBIM messages are created using the MBIMControlMessageMeta metaclass.
6The metaclass supports a hierarchy of message definitions so that each message
7definition extends the structure of the base class it inherits.
8
9(mbim_message.py)
10MBIMControlMessage|         (mbim_message_request.py)
11                  |>MBIMControlMessageRequest |
12                  |                           |>MBIMOpen
13                  |                           |>MBIMClose
14                  |                           |>MBIMCommand    |
15                  |                           |                |>MBIMSetConnect
16                  |                           |                |>...
17                  |                           |
18                  |                           |>MBIMHostError
19                  |
20                  |         (mbim_message_response.py)
21                  |>MBIMControlMessageResponse|
22                                              |>MBIMOpenDone
23                                              |>MBIMCloseDone
24                                              |>MBIMCommandDone|
25                                              |                |>MBIMConnectInfo
26                                              |                |>...
27                                              |
28                                              |>MBIMHostError
29"""
30import array
31import logging
32import struct
33import sys
34from collections import namedtuple
35
36from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
37
38
# Type of message classes. The value of each field in the message is stored
# as an attribute of the object created.
# Request message classes accept values for the attributes of the object.
42MESSAGE_TYPE_REQUEST = 1
# Response message classes accept raw_data, which is parsed into attributes of
# the object.
45MESSAGE_TYPE_RESPONSE = 2
46
47# Message field types.
48# Just a normal field type. No special properties.
49FIELD_TYPE_NORMAL = 1
# Identifies the payload ID for a message. This is used when parsing
# response messages to help identify the child message class.
52FIELD_TYPE_PAYLOAD_ID = 2
53# Total length of the message including any payload_buffer it may contain.
54FIELD_TYPE_TOTAL_LEN = 3
55# Length of the payload contained in the payload_buffer.
56FIELD_TYPE_PAYLOAD_LEN = 4
57# Number of fragments of this message.
58FIELD_TYPE_NUM_FRAGMENTS = 5
59# Transaction ID of this message
60FIELD_TYPE_TRANSACTION_ID = 6
61
62
def message_class_new(cls, **kwargs):
    """
    Creates a message instance with either the given field name/value
    pairs or raw data buffer.

    This function is installed as |__new__| on every message class built by
    MBIMControlMessageMeta, so |cls| is a namedtuple-derived message class.

    The total_length and transaction_id fields are automatically calculated
    if not explicitly provided in the message args. Any other field that is
    not supplied falls back to the consolidated defaults of |cls|.

    @param kwargs: Dictionary of (field_name, field_value) pairs or
                    raw_data=Packed binary array. An optional
                    |payload_buffer| entry carries the variable sized data
                    that trails the fixed structure; None means no payload.
    @returns New message object created. The object always has a
             |payload_buffer| attribute (None when there is no payload).
    @raises MBIMComplianceControlMessageError (via mbim_errors.log_and_raise)
            if |raw_data| is shorter than the structure, if a field has
            neither a value nor a default, or if unknown fields are passed.

    """
    raw_data = kwargs.get('raw_data')
    if raw_data:
        # We unpack the raw data received into the appropriate fields
        # for this class. If there is some additional data present in
        # |raw_data| that does not fit the format of the structure,
        # it is stored in the variable sized |payload_buffer| field.
        data_format = cls.get_field_format_string(get_all=True)
        unpack_length = cls.get_struct_len(get_all=True)
        data_length = len(raw_data)
        if data_length < unpack_length:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    'Length of Data (%d) to be parsed less than message'
                    ' structure length (%d)' %
                    (data_length, unpack_length))
        obj = super(cls, cls).__new__(
                cls, *struct.unpack_from(data_format, raw_data))
        trailing_data = None
        if data_length > unpack_length:
            trailing_data = raw_data[unpack_length:]
        setattr(obj, 'payload_buffer', trailing_data)
        return obj

    # Check if all the fields have been populated for this message
    # except for transaction ID and message length since these
    # are generated here when missing.
    field_values = []
    fields = cls.get_fields(get_all=True)
    defaults = cls.get_defaults(get_all=True)
    for _, field_name, field_type in fields:
        if field_name in kwargs:
            field_values.append(kwargs.pop(field_name))
            continue
        if field_type == FIELD_TYPE_TOTAL_LEN:
            field_value = cls.get_struct_len(get_all=True)
            # An explicit payload_buffer=None means "no payload", exactly
            # like omitting the argument, so it must not be measured.
            payload_buffer = kwargs.get('payload_buffer')
            if payload_buffer is not None:
                field_value += len(payload_buffer)
        elif field_type == FIELD_TYPE_TRANSACTION_ID:
            field_value = cls.get_next_transaction_id()
        else:
            field_value = defaults.get(field_name, None)
        if field_value is None:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    'Missing field value (%s) in %s' % (
                            field_name, cls.__name__))
        field_values.append(field_value)
    obj = super(cls, cls).__new__(cls, *field_values)
    # We need to account for optional variable sized payload_buffer
    # in some messages which are not explicitly mentioned in the
    # |cls._FIELDS| attribute.
    setattr(obj, 'payload_buffer', kwargs.pop('payload_buffer', None))
    # Anything still left in |kwargs| did not match a field of the message.
    if kwargs:
        mbim_errors.log_and_raise(
                mbim_errors.MBIMComplianceControlMessageError,
                'Unexpected fields (%s) in %s' % (
                        list(kwargs), cls.__name__))
    return obj
137
138
class MBIMControlMessageMeta(type):
    """
    Metaclass for all the control message parsing/generation.

    The metaclass creates each class by concatenating all the message fields
    from its base classes to create a hierarchy of messages.
    Thus the payload class of each message class becomes the subclass of that
    message.

    Message definition attributes->
    _FIELDS(optional): Used to define structure elements. Each entry is a
                       (format, field_name, field_type) triple. The fields
                       of a message are the concatenation of its own _FIELDS
                       attribute with the _FIELDS attributes of all its
                       parent classes.
    _DEFAULTS(optional): Field name/value pairs to be assigned to some
                         of the fields if they are fixed for a message type.
                         These are generally used to assign values to fields in
                         the parent class.
    _IDENTIFIERS(optional): Field name/value pairs to be used to identify this
                            message during parsing from raw_data.
    _SECONDARY_FRAGMENT(optional): Used to identify if this class can be
                                   fragmented and name of secondary class
                                   definition.
    MESSAGE_TYPE: Used to identify request/response classes.

    Message internal attributes->
    _CONSOLIDATED_FIELDS: Consolidated list of all the fields defining this
                          message.
    _CONSOLIDATED_DEFAULTS: Consolidated list of all the default field
                            name/value pairs for this message.

    """
    def __new__(mcs, name, bases, attrs):
        """
        Builds a message class whose instances are namedtuples of the
        consolidated fields.

        @param name: Name of the class being created.
        @param bases: Tuple of base classes of the class being created.
        @param attrs: Class namespace; |_CONSOLIDATED_FIELDS|,
                |_CONSOLIDATED_DEFAULTS| and |__new__| are added to it.
        @returns The newly created class.
        @raises MBIMComplianceControlMessageError (via
                mbim_errors.log_and_raise) if the class ends up with no
                fields at all.

        """
        # The MBIMControlMessage base class, which inherits from 'object',
        # is merely used to establish the class hierarchy and is never
        # constructed on its own.
        if object in bases:
            return super(MBIMControlMessageMeta, mcs).__new__(
                    mcs, name, bases, attrs)

        # Append the current class fields, defaults to any base parent class
        # fields. If several bases carry consolidated data, the last one in
        # |bases| wins.
        fields = []
        defaults = {}
        for base_class in bases:
            if hasattr(base_class, '_CONSOLIDATED_FIELDS'):
                # Copy so that this class never shares (and can never
                # accidentally mutate) the list object owned by its parent.
                fields = list(getattr(base_class, '_CONSOLIDATED_FIELDS'))
            if hasattr(base_class, '_CONSOLIDATED_DEFAULTS'):
                defaults = getattr(base_class, '_CONSOLIDATED_DEFAULTS').copy()
        if '_FIELDS' in attrs:
            # Built with a comprehension rather than map() so the
            # concatenation works whether or not map() returns a list.
            fields = fields + [list(field) for field in attrs['_FIELDS']]
        if '_DEFAULTS' in attrs:
            defaults.update(attrs['_DEFAULTS'])
        attrs['_CONSOLIDATED_FIELDS'] = fields
        attrs['_CONSOLIDATED_DEFAULTS'] = defaults

        if not fields:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    '%s message must have some fields defined' % name)

        attrs['__new__'] = message_class_new
        # Each field is a (format, name, type) triple; only the names are
        # needed to build the tuple type.
        _, field_names, _ = zip(*fields)
        message_class = namedtuple(name, field_names)
        # Prepend the class created via namedtuple to |bases| in order to
        # correctly resolve the __new__ method while preserving the class
        # hierarchy.
        cls = super(MBIMControlMessageMeta, mcs).__new__(
                mcs, name, (message_class,) + bases, attrs)
        return cls
209
210
class MBIMControlMessage(object):
    """
    MBIMControlMessage base class.

    This class should not be instantiated or used directly. Concrete message
    classes derive from it and are assembled by MBIMControlMessageMeta, which
    turns every derived class into a namedtuple of its consolidated fields
    with an extra |payload_buffer| attribute.

    """
    __metaclass__ = MBIMControlMessageMeta

    # Last transaction id handed out. Stored on this base class so that all
    # message classes share a single counter.
    _NEXT_TRANSACTION_ID = 0X00000000
    # Largest transaction id handed out before the counter wraps around.
    # The MBIM message header defines TransactionId as an unsigned 32-bit
    # value, so ids must stay within that range.
    _MAX_TRANSACTION_ID = 0xFFFFFFFF


    @classmethod
    def _find_subclasses(cls):
        """
        Helper function to find all the derived payload classes of this
        class.

        @returns List of the immediate subclasses of this class.

        """
        return list(cls.__subclasses__())


    @classmethod
    def get_fields(cls, get_all=False):
        """
        Helper function to find all the fields of this class.

        Returns either the total message fields or only the current
        substructure fields in the nested message.

        NOTE: |_FIELDS| is resolved through normal attribute lookup, so a
        class that does not define |_FIELDS| itself yields the |_FIELDS| of
        the nearest parent that does.

        @param get_all: Whether to return the total struct fields or sub struct
                         fields.
        @returns Fields of the structure.

        """
        if get_all:
            return cls._CONSOLIDATED_FIELDS
        return cls._FIELDS


    @classmethod
    def get_defaults(cls, get_all=False):
        """
        Helper function to find all the default field values of this class.

        Returns either the total message default field name/value pairs or only
        the current substructure defaults in the nested message.

        @param get_all: Whether to return the total struct defaults or sub
                         struct defaults.
        @returns Defaults of the structure.

        """
        if get_all:
            return cls._CONSOLIDATED_DEFAULTS
        return cls._DEFAULTS


    @classmethod
    def _get_identifiers(cls):
        """
        Helper function to find all the identifier field name/value pairs of
        this class.

        @returns The |_IDENTIFIERS| attribute of this class, or None if it
                 is not defined.

        """
        return getattr(cls, '_IDENTIFIERS', None)


    @classmethod
    def _find_field_names_of_type(cls, find_type, get_all=False):
        """
        Helper function to find all the field names which matches the field_type
        specified.

        @param find_type: One of the FIELD_TYPE_* enum values specified above.
        @param get_all: Whether to search the total struct fields or only the
                sub struct fields.
        @returns List of the matching field names; empty if none match.

        """
        return [field_name
                for _, field_name, field_type in cls.get_fields(get_all=get_all)
                if field_type == find_type]


    @classmethod
    def get_secondary_fragment(cls):
        """
        Helper function to retrieve the associated secondary fragment class.

        @returns |_SECONDARY_FRAGMENT| attribute of the class, or None if it
                 is not defined.

        """
        return getattr(cls, '_SECONDARY_FRAGMENT', None)


    @classmethod
    def get_field_names(cls, get_all=True):
        """
        Helper function to return the field names of the message.

        @param get_all: Whether to use the total struct fields or only the
                sub struct fields.
        @returns The field names of the message structure.

        """
        _, field_names, _ = zip(*cls.get_fields(get_all=get_all))
        return field_names


    @classmethod
    def get_field_formats(cls, get_all=True):
        """
        Helper function to return the field formats of the message.

        @param get_all: Whether to use the total struct fields or only the
                sub struct fields.
        @returns The format of fields of the message structure.

        """
        field_formats, _, _ = zip(*cls.get_fields(get_all=get_all))
        return field_formats


    @classmethod
    def get_field_format_string(cls, get_all=True):
        """
        Helper function to return the field format string of the message.

        @param get_all: Whether to use the total struct fields or only the
                sub struct fields.
        @returns The |struct| format string of the message structure. It is
                 prefixed with '<' (little-endian, no padding).

        """
        return '<' + ''.join(cls.get_field_formats(get_all=get_all))


    @classmethod
    def get_struct_len(cls, get_all=False):
        """
        Returns the length of the structure representing the message.

        Returns the length of either the total message or only the current
        substructure in the nested message.

        @param get_all: Whether to return the total struct length or sub struct
                length.
        @returns Length of the structure.

        """
        return struct.calcsize(cls.get_field_format_string(get_all=get_all))


    @classmethod
    def find_primary_parent_fragment(cls):
        """
        Traverses up the message tree to find the primary fragment class
        at the same tree level as the secondary frag class associated with this
        message class. This should only be called on primary fragment derived
        classes!

        Calling it on a class without a |_SECONDARY_FRAGMENT| fails with an
        AttributeError.

        @returns Primary frag class associated with the message.

        """
        # MBIMControlMessageMeta prepends the generated namedtuple class to
        # the bases of every message class, so the parent message class is
        # found at index 1 of |__bases__|.
        secondary_frag_cls = cls.get_secondary_fragment()
        secondary_frag_parent_cls = secondary_frag_cls.__bases__[1]
        message_cls = cls
        message_parent_cls = message_cls.__bases__[1]
        while message_parent_cls != secondary_frag_parent_cls:
            message_cls = message_parent_cls
            message_parent_cls = message_cls.__bases__[1]
        return message_cls


    @classmethod
    def get_next_transaction_id(cls):
        """
        Returns incrementing transaction ids on successive calls.

        The counter is shared by all message classes. Ids start at 1 and wrap
        back to 1 after |_MAX_TRANSACTION_ID|; 0 is never returned.

        @returns The transaction id for control message delivery.

        """
        if (MBIMControlMessage._NEXT_TRANSACTION_ID >=
            MBIMControlMessage._MAX_TRANSACTION_ID):
            MBIMControlMessage._NEXT_TRANSACTION_ID = 0x00000000
        MBIMControlMessage._NEXT_TRANSACTION_ID += 1
        return MBIMControlMessage._NEXT_TRANSACTION_ID


    def _get_fields_of_type(self, field_type, get_all=False):
        """
        Helper function to find all the field name/value of the specified type
        in the given object.

        @param field_type: One of the FIELD_TYPE_* enum values.
        @param get_all: Whether to search the total struct fields or only the
                sub struct fields.
        @returns Corresponding map of field name/value pairs extracted from the
                object.

        """
        field_names = self.__class__._find_field_names_of_type(field_type,
                                                               get_all=get_all)
        return {f: getattr(self, f) for f in field_names}


    def _get_single_field_of_type(self, field_type, description,
                                  get_all=False):
        """
        Helper function to fetch the value of the one field of the given type.

        @param field_type: One of the FIELD_TYPE_* enum values.
        @param description: Human readable name of the field, used in the
                error message.
        @param get_all: Whether to search the total struct fields or only the
                sub struct fields.
        @returns Value of the single matching field.
        @raises MBIMComplianceControlMessageError (via
                mbim_errors.log_and_raise) unless exactly one field matches.

        """
        matching_fields = self._get_fields_of_type(field_type, get_all=get_all)
        if len(matching_fields) != 1:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    "Erorr in finding %s field in message: %s" %
                    (description, self.__class__.__name__))
        # list() keeps this working where dict.values() is a view.
        return list(matching_fields.values())[0]


    def _get_payload_id_fields(self):
        """
        Helper function to find all the payload id field name/value in the given
        object.

        Only the sub struct fields of the object's class are searched.

        @returns Corresponding field name/value pairs extracted from the object.

        """
        return self._get_fields_of_type(FIELD_TYPE_PAYLOAD_ID)


    def get_payload_len(self):
        """
        Helper function to find the payload len field value in the given
        object.

        Only the sub struct fields of the object's class are searched.

        @returns Corresponding field value extracted from the object.
        @raises MBIMComplianceControlMessageError (via
                mbim_errors.log_and_raise) unless exactly one payload len
                field is found.

        """
        return self._get_single_field_of_type(FIELD_TYPE_PAYLOAD_LEN,
                                              'payload len')


    def get_total_len(self):
        """
        Helper function to find the total len field value in the given
        object.

        All the fields of the message are searched.

        @returns Corresponding field value extracted from the object.
        @raises MBIMComplianceControlMessageError (via
                mbim_errors.log_and_raise) unless exactly one total len
                field is found.

        """
        return self._get_single_field_of_type(FIELD_TYPE_TOTAL_LEN,
                                              'total len',
                                              get_all=True)


    def get_num_fragments(self):
        """
        Helper function to find the fragment num field value in the given
        object.

        Only the sub struct fields of the object's class are searched.

        @returns Corresponding field value extracted from the object.
        @raises MBIMComplianceControlMessageError (via
                mbim_errors.log_and_raise) unless exactly one num fragments
                field is found.

        """
        return self._get_single_field_of_type(FIELD_TYPE_NUM_FRAGMENTS,
                                              'num fragments')


    def find_payload_class(self):
        """
        Helper function to find the derived class which has the default
        |payload_id| fields matching the current message contents.

        Only the immediate subclasses of the object's class are considered.

        @returns Corresponding class if found, else None.

        """
        # The ids of this message do not change while scanning the
        # candidates, so they are extracted only once.
        message_ids = self._get_payload_id_fields()
        for payload_cls in self.__class__._find_subclasses():
            if message_ids == payload_cls._get_identifiers():
                return payload_cls
        return None


    def calculate_total_len(self):
        """
        Helper function to calculate the total len of a given message
        object.

        @returns Total length of the message: the length of the complete
                 structure plus the length of |payload_buffer|, if any.

        """
        total_len = self.__class__.get_struct_len(get_all=True)
        if self.payload_buffer:
            total_len += len(self.payload_buffer)
        return total_len


    def pack(self, format_string, field_names):
        """
        Packs a list of fields based on their formats.

        @param format_string: The concatenated formats for the fields given in
                |field_names|.
        @param field_names: The name of the fields to be packed.
        @returns The packet in binary array form.

        """
        field_values = [getattr(self, name) for name in field_names]
        return array.array('B', struct.pack(format_string, *field_values))


    def print_all_fields(self):
        """Prints all the field name, value pair of this message."""
        logging.debug('Class Name: %s', self.__class__.__name__)
        for field_name in self.__class__.get_field_names(get_all=True):
            logging.debug('Field Name: %s, Field Value: %s',
                          field_name, getattr(self, field_name))
        if self.payload_buffer:
            logging.debug('Payload: %s', self.payload_buffer)


    def create_raw_data(self):
        """
        Creates the raw binary data corresponding to the message struct.

        All the fields of the message are packed and the |payload_buffer| of
        the message, if any, is attached at the end of the packet.

        @returns Packed byte array of the message.

        """
        message_class = self.__class__
        packet = self.pack(message_class.get_field_format_string(),
                           message_class.get_field_names())
        if self.payload_buffer:
            packet.extend(self.payload_buffer)
        return packet


    def copy(self, **fields_to_alter):
        """
        Replaces the message tuple with updated field values.

        The message itself is left untouched; a new message is returned.

        @param fields_to_alter: Field name/value pairs to be changed.
        @returns Updated message with the field values updated.

        """
        message = self._replace(**fields_to_alter)
        # Copy the associated payload_buffer field to the new tuple, since
        # it is a plain attribute and not one of the tuple fields.
        message.payload_buffer = self.payload_buffer
        return message
562