1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simple module for declaring C-like structures.
16
17Example usage:
18
19>>> # Declare a struct type by specifying name, field formats and field names.
20... # Field formats are the same as those used in the struct module, except:
21... # - S: Nested Struct.
... # - A: NULL-padded ASCII string. Packed like s, but trailing NULL bytes
... #      are stripped when the struct is printed.
24... import cstruct
25>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
26>>>
27>>>
28>>> # Create instances from tuples or raw bytes. Data past the end is ignored.
29... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
30>>> print n1
31NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
32>>>
33>>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
34...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
35>>> print n2
36NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
37>>>
38>>> # Serialize to raw bytes.
39... print n1.Pack().encode("hex")
402c0000002000020000000000eb010000
41>>>
42>>> # Parse the beginning of a byte stream as a struct, and return the struct
43... # and the remainder of the stream for further reading.
44... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
45...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
46...         "more data")
47>>> cstruct.Read(data, NLMsgHdr)
48(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
49>>>
50>>> # Structs can contain one or more nested structs. The nested struct types
51... # are specified in a list as an optional last argument. Nested structs may
52... # contain nested structs.
53... S = cstruct.Struct("S", "=BI", "byte1 int2")
>>> N = cstruct.Struct("N", "!BSiS", "byte1 s2 int3 s4", [S, S])
55>>> NN = cstruct.Struct("NN", "SHS", "s1 word2 n3", [S, N])
>>> nn = NN((S((1, 25000)), 29876, N((55, S((5, 6)), 1111, S((7, 8))))))
57>>> nn.n3.s2.int2 = 5
58>>>
59"""
60
61import ctypes
62import string
63import struct
64
65
def CalcSize(fmt):
  """Returns the packed size in bytes of a cstruct format string.

  Same as struct.calcsize, except that the cstruct-specific "A" format
  character (NULL-padded ASCII string) is accepted and sized like "s".

  A trailing repeat count with no format character after it is ignored. Such
  formats arise when a format is sliced just before an "A" field (e.g. "=16"
  out of "=16A"). Python 2's struct module silently ignores the dangling
  count, but Python 3's raises struct.error, so it is stripped explicitly.

  Args:
    fmt: A struct-module format string, optionally containing "A".

  Returns:
    The size in bytes of the packed representation.
  """
  fmt = fmt.replace("A", "s").rstrip(string.digits)
  return struct.calcsize(fmt)
70
def CalcNumElements(fmt):
  """Returns how many values a cstruct format string packs or unpacks.

  Each "S" (nested struct) counts as one element. Each "A" (NULL-padded ASCII
  string) counts as one element, exactly like "s". A trailing repeat count
  with no format character after it (e.g. the "16" in "=16", produced by
  slicing a format just before a "16A" field) is ignored.

  Args:
    fmt: A cstruct format string, possibly a prefix of a longer one.

  Returns:
    The number of values described by fmt.
  """
  numstructs = fmt.count("S")
  # struct.unpack understands neither "S" nor "A": drop the former (counted
  # above) and treat the latter as the plain string format it packs as.
  plain = fmt.replace("S", "").replace("A", "s").rstrip(string.digits)
  # b"" literals are bytes on Python 3 and str on Python 2.6+.
  elements = struct.unpack(plain, b"\x00" * struct.calcsize(plain))
  return len(elements) + numstructs
78
79
def Struct(name, fmt, fieldnames, substructs=None):
  """Function that returns struct classes.

  Args:
    name: Name of the struct type. It becomes the class's __name__ and is
      used in the string representation of instances.
    fmt: struct-module format string. In addition to the standard format
      characters, "S" denotes a nested struct and "A" a NULL-padded ASCII
      string (packed like "s"; trailing NULs are stripped when printing).
    fieldnames: Field names, either as a list or as a single string of names
      separated by single spaces.
    substructs: Struct classes (as returned by this function) for the "S"
      fields of fmt, indexed in the order the "S" characters appear.
      Defaults to no nested structs.

  Returns:
    A new class. Instances are built from a tuple of field values or from a
    string of at least len(cls) bytes. Fields are readable and writable as
    attributes, and Pack() serializes them.

  Raises:
    IndexError: fmt has more "S" characters than substructs has entries
      (when substructs is a list).
  """
  # None instead of a shared mutable default. "S" fields index this
  # positionally, so an empty list is the natural default.
  if substructs is None:
    substructs = []
  if isinstance(fieldnames, str):
    fieldnames = fieldnames.split(" ")

  # Translate fmt into the format actually handed to the struct module. Each
  # S becomes "XXs", where XX is the packed length of the corresponding nested
  # struct type, and each A becomes "s". This runs here rather than inside the
  # class body so that loop temporaries do not become class attributes. A
  # leftover class attribute such as "i" would shadow a field of the same
  # name, because __getattr__ is only consulted when normal lookup fails.
  nested = {}      # Maps value index -> nested struct class.
  asciiz = set()   # Value indices of "A" (NULL-padded ASCII) fields.
  packfmt = ""
  numnested = 0
  for pos, char in enumerate(fmt):
    if char == "S":
      # Nested struct. Record the value index it occupies.
      index = CalcNumElements(fmt[:pos])
      nested[index] = substructs[numnested]
      numnested += 1
      packfmt += "%ds" % len(nested[index])
    elif char == "A":
      # NULL-padded ASCII string.
      asciiz.add(CalcNumElements(fmt[:pos]))
      packfmt += "s"
    else:
      # Standard struct format character.
      packfmt += char

  class Meta(type):
    """Per-struct metaclass that makes len() work on the class itself."""

    def __new__(mcs, unused_name, bases, namespace):
      # Give the class object the name that's passed in. This must happen in
      # __new__: type.__init__ ignores its name argument.
      return type.__new__(mcs, str(namespace["_name"]), bases, namespace)

    def __len__(cls):
      return cls._length

  class CStruct(object):
    """Class representing a C-like structure."""

    __metaclass__ = Meta

    # Name of the struct.
    _name = name
    # List of field names.
    _fieldnames = fieldnames
    # Dict mapping value indices to nested struct classes.
    _nested = nested
    # Set of value indices of fields that are NULL-padded ASCII strings.
    _asciiz = asciiz
    # Format string passed to the struct module ("S" and "A" translated).
    _format = packfmt
    # Size in bytes of the packed representation.
    _length = CalcSize(packfmt)

    def _SetValues(self, values):
      # Bypass our own __setattr__, which only accepts field names.
      super(CStruct, self).__setattr__("_values", list(values))

    def _Parse(self, data):
      """Unpacks the first len(self) bytes of data into field values."""
      data = data[:self._length]
      values = list(struct.unpack(self._format, data))
      for index, value in enumerate(values):
        if isinstance(value, str) and index in self._nested:
          values[index] = self._nested[index](value)
      self._SetValues(values)

    def __init__(self, values):
      """Creates an instance from a string or a sequence of field values.

      Args:
        values: Either a string of at least len(self) bytes, which is parsed
          (bytes past the end are ignored), or a sequence with exactly one
          value per field name.

      Raises:
        TypeError: values is too short, or has the wrong number of items.
      """
      # Initializing from a string.
      if isinstance(values, str):
        if len(values) < self._length:
          raise TypeError("%s requires string of length %d, got %d" %
                          (self._name, self._length, len(values)))
        self._Parse(values)
      else:
        # Initializing from a tuple.
        if len(values) != len(self._fieldnames):
          raise TypeError("%s has exactly %d fieldnames (%d given)" %
                          (self._name, len(self._fieldnames), len(values)))
        self._SetValues(values)

    def _FieldIndex(self, attr):
      """Returns the position of field attr, or raises AttributeError."""
      try:
        return self._fieldnames.index(attr)
      except ValueError:
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, attr))

    def __getattr__(self, name):
      return self._values[self._FieldIndex(name)]

    def __setattr__(self, name, value):
      self._values[self._FieldIndex(name)] = value

    @classmethod
    def __len__(cls):
      # Lets len() work on instances as well as on the class.
      return cls._length

    def __ne__(self, other):
      return not self.__eq__(other)

    def __eq__(self, other):
      return (isinstance(other, self.__class__) and
              self._name == other._name and
              self._fieldnames == other._fieldnames and
              self._values == other._values)

    @staticmethod
    def _MaybePackStruct(value):
      """Packs value if it is a struct instance, else returns it unchanged."""
      # Duck-typed check: every Struct() call creates its own Meta, so nested
      # struct instances cannot be recognized by comparing metaclasses.
      if hasattr(value, "__metaclass__"):
        return value.Pack()
      else:
        return value

    def Pack(self):
      """Serializes the struct, including nested structs, to a string."""
      values = [self._MaybePackStruct(v) for v in self._values]
      return struct.pack(self._format, *values)

    def __str__(self):
      def FieldDesc(index, name, value):
        if isinstance(value, str):
          if index in self._asciiz:
            value = value.rstrip("\x00")
          elif any(c not in string.printable for c in value):
            value = value.encode("hex")
        return "%s=%s" % (name, value)

      descriptions = [
          FieldDesc(i, n, v) for i, (n, v) in
          enumerate(zip(self._fieldnames, self._values))]

      return "%s(%s)" % (self._name, ", ".join(descriptions))

    def __repr__(self):
      return str(self)

    def CPointer(self):
      """Returns a C pointer to the serialized structure."""
      buf = ctypes.create_string_buffer(self.Pack())
      # Store the C buffer in the object so it doesn't get garbage collected.
      super(CStruct, self).__setattr__("_buffer", buf)
      return ctypes.addressof(self._buffer)

  return CStruct
219
220
def Read(data, struct_type):
  """Parses a struct from the front of data.

  Args:
    data: The input byte string. It must start with a packed struct_type.
    struct_type: A struct class whose len() is its packed size and whose
      constructor accepts the raw data.

  Returns:
    A (parsed_struct, remaining_data) tuple. remaining_data is everything in
    data after the first len(struct_type) bytes.
  """
  consumed = len(struct_type)
  remainder = data[consumed:]
  parsed = struct_type(data)
  return parsed, remainder
224