1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simple module for declaring C-like structures.
16
17Example usage:
18
19>>> # Declare a struct type by specifying name, field formats and field names.
20... # Field formats are the same as those used in the struct module, except:
21... # - S: Nested Struct.
22... # - A: NULL-padded ASCII string. Like s, but printing ignores contiguous
23... #      trailing NULL blocks at the end.
24... import cstruct
25>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
26>>>
27>>>
28>>> # Create instances from a tuple of values, raw bytes, zero-initialized, or
29>>> # using keywords.
30... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
31>>> print(n1)
32NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
33>>>
34>>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
35...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
36>>> print(n2)
37NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
38>>>
>>> n3 = NLMsgHdr() # Zero-initialized
40>>> print(n3)
41NLMsgHdr(length=0, type=0, flags=0, seq=0, pid=0)
42>>>
>>> n4 = NLMsgHdr(length=44, type=33) # Other fields zero-initialized
44>>> print(n4)
45NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0)
46>>>
47>>> # Serialize to raw bytes.
48... print(n1.Pack().encode("hex"))
492c0000002000020000000000eb010000
50>>>
51>>> # Parse the beginning of a byte stream as a struct, and return the struct
52... # and the remainder of the stream for further reading.
53... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
54...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
55...         "more data")
56>>> cstruct.Read(data, NLMsgHdr)
57(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
58>>>
59>>> # Structs can contain one or more nested structs. The nested struct types
60... # are specified in a list as an optional last argument. Nested structs may
61... # contain nested structs.
62... S = cstruct.Struct("S", "=BI", "byte1 int2")
>>> N = cstruct.Struct("N", "!BSiS", "byte1 s2 int3 s4", [S, S])
64>>> NN = cstruct.Struct("NN", "SHS", "s1 word2 n3", [S, N])
65>>> nn = NN((S((1, 25000)), -29876, N((55, S((5, 6)), 1111, S((7, 8))))))
66>>> nn.n3.s2.int2 = 5
67>>>
68"""
69
70import ctypes
71import string
72import struct
73import re
74
75
def CalcSize(fmt):
  """Returns the packed size in bytes of a cstruct format string.

  "A" (NULL-padded ASCII string) is sized like the struct module's "s".
  A dangling repeat count at the end of fmt (e.g. the "16" in "=L16", which
  appears when a format is cut just before a "16A" or "16s" field) is ignored,
  because the struct module rejects it.
  """
  fmt = fmt.replace("A", "s")
  # Drop a trailing repeat count that has no format character after it.
  fmt = re.split(r"\d+$", fmt)[0]
  return struct.calcsize(fmt)
82
def CalcNumElements(fmt):
  """Returns the number of values described by a cstruct format string.

  Each "S" (nested struct) counts as one value, and "A" counts like "s".
  A dangling repeat count at the end of fmt is ignored, so this can be called
  on a format truncated just before an "<N>A" or "<N>s" field.
  """
  numstructs = fmt.count("S")
  fmt = fmt.replace("S", "").replace("A", "s")
  # struct.unpack rejects both "A" and a trailing count with no format
  # character, so normalize before unpacking.
  fmt = re.split(r"\d+$", fmt)[0]
  size = struct.calcsize(fmt)
  elements = struct.unpack(fmt, b"\x00" * size)
  return len(elements) + numstructs
90
91
def _ExpandFormat(fmt, substructs):
  """Converts a cstruct format string into a struct module format string.

  Args:
    fmt: cstruct format string. In addition to the struct module's format
      characters it may contain "S" (nested struct) and "A" (NULL-padded ASCII
      string).
    substructs: Nested struct classes, indexed 0, 1, ... in the order their
      "S" characters appear in fmt.

  Returns:
    A (format, nested, asciiz) tuple:
      format: fmt with each "S" replaced by "<N>s", where N is the packed
        length of the corresponding nested struct class, and each "A"
        replaced by "s".
      nested: dict mapping value indices to nested struct classes.
      asciiz: set of value indices that hold NULL-padded ASCII strings.
  """
  parts = []
  nested = {}
  asciiz = set()
  structindex = 0
  for i, char in enumerate(fmt):
    if char == "S":
      # Nested struct. Record the index in our struct it should go into.
      index = CalcNumElements(fmt[:i])
      nested[index] = substructs[structindex]
      structindex += 1
      parts.append("%ds" % len(nested[index]))
    elif char == "A":
      # NULL-padded ASCII string: packed like "s", printed differently.
      asciiz.add(CalcNumElements(fmt[:i]))
      parts.append("s")
    else:
      # Standard struct format character (or repeat-count digit).
      parts.append(char)
  return "".join(parts), nested, asciiz


def _FieldOffsets(fmt):
  """Returns the byte offset of each value that struct.unpack(fmt) yields.

  Handles repeat counts ("4H" is four values, "16s" is one), pad bytes ("x"
  yields no value) and, for native formats, the alignment padding inserted
  before a field.
  """
  prefix = ""
  if fmt and fmt[0] in "@=<>!":
    prefix, fmt = fmt[0], fmt[1:]
  offsets = []
  for count, code in re.findall(r"(\d*)([a-zA-Z?])", fmt):
    count = int(count) if count else 1
    if code in "sp":
      # A string is a single value and is never aligned.
      offsets.append(struct.calcsize(prefix))
      prefix += "%d%s" % (count, code)
    elif code == "x":
      # Pad bytes yield no value.
      prefix += "%dx" % count
    elif count == 0:
      # Yields no value, but may still add alignment padding.
      prefix += "0" + code
    else:
      for _ in range(count):
        # A zero repeat count pads the prefix to the alignment of code; this
        # is a no-op for the standard-size byte orders (=, <, >, !).
        offsets.append(struct.calcsize(prefix + "0" + code))
        prefix += code
  return offsets


def Struct(name, fmt, fieldnames, substructs=None):
  """Function that returns struct classes.

  Args:
    name: Name of the struct type. Used as the class name and when printing.
    fmt: Format string as accepted by the struct module, plus "S" for a
      nested struct and "A" for a NULL-padded ASCII string.
    fieldnames: Space-separated field names, one per value in fmt.
    substructs: Nested struct classes, one per "S" in fmt, in order. May be
      omitted if fmt contains no "S".

  Returns:
    A new class whose instances represent values of this struct type.

  Raises:
    ValueError: if the number of field names does not match the number of
      values in fmt.
  """
  if substructs is None:
    substructs = ()

  # Everything derived from fmt is computed here instead of in the class body.
  # Temporaries assigned in a class body become class attributes, and a class
  # attribute such as "index" would shadow a field of the same name, because
  # __getattr__ is only consulted when normal attribute lookup fails.
  fieldname_list = fieldnames.split(" ")
  struct_format, nested, asciiz = _ExpandFormat(fmt, substructs)
  length = CalcSize(struct_format)

  # Check that the number of field names matches the number of fields.
  numfields = len(struct.unpack(struct_format, b"\x00" * length))
  if len(fieldname_list) != numfields:
    raise ValueError("Invalid cstruct: \"%s\" has %d elements, \"%s\" has %d."
                     % (fmt, numfields, fieldnames, len(fieldname_list)))

  # A dictionary that maps field names to their offsets in the struct.
  offsets = dict(zip(fieldname_list, _FieldOffsets(struct_format)))

  class Meta(type):
    """Metaclass that makes len() of a struct class return its packed size."""

    def __len__(cls):
      return cls._length

    def __init__(cls, unused_name, unused_bases, namespace):
      type.__init__(cls, namespace["_name"], unused_bases, namespace)
      # type.__init__ ignores its arguments, so set the name explicitly to
      # make the class object have the name that's passed in.
      cls.__name__ = str(namespace["_name"])

  # Python 3 ignores the __metaclass__ class attribute, but a class always
  # gets the metaclass of its base, so deriving from an instance of Meta
  # applies Meta under both Python 2 and Python 3.
  base = Meta("CStructBase", (object,), {"_name": "CStructBase"})

  class CStruct(base):
    """Class representing a C-like structure."""

    # Kept even though Python 3 ignores it: _MaybePackStruct looks for this
    # attribute to recognize struct instances made by any call to Struct().
    __metaclass__ = Meta

    # Name of the struct.
    _name = name
    # List of field names.
    _fieldnames = fieldname_list
    # Dict mapping field indices to nested struct classes.
    _nested = nested
    # Set of indices of fields that are NULL-padded ASCII strings.
    _asciiz = asciiz
    # struct module format string, with "S" and "A" expanded.
    _format = struct_format
    # Packed length in bytes.
    _length = length
    # A dictionary that maps field names to their offsets in the struct.
    _offsets = offsets

    def _SetValues(self, values):
      # Replace self._values with the given list. We can't do direct assignment
      # because of the __setattr__ overload on this class.
      super(CStruct, self).__setattr__("_values", list(values))

    def _Parse(self, data):
      # Only the first _length bytes are parsed; any trailing data is ignored.
      data = data[:self._length]
      values = list(struct.unpack(self._format, data))
      for index, value in enumerate(values):
        # Nested structs unpack as bytes (str in Python 2); parse them.
        if isinstance(value, bytes) and index in self._nested:
          values[index] = self._nested[index](value)
      self._SetValues(values)

    def __init__(self, tuple_or_bytes=None, **kwargs):
      """Construct an instance of this Struct.

      1. With no args, the whole struct is zero-initialized.
      2. With keyword args, the matching fields are populated; rest are zeroed.
      3. With one tuple as the arg, the fields are assigned based on position.
      4. With one bytes arg (str in Python 2), the Struct is parsed from it.

      Raises:
        TypeError: if both a tuple and keyword args are given, if the bytes
          are too short, or if the tuple has the wrong number of elements.
      """
      if tuple_or_bytes and kwargs:
        raise TypeError(
            "%s: cannot specify both a tuple and keyword args" % self._name)

      if tuple_or_bytes is None:
        # Default construct from null bytes.
        self._Parse(b"\x00" * len(self))
        # If any keywords were supplied, set those fields.
        for k, v in kwargs.items():
          setattr(self, k, v)
      elif isinstance(tuple_or_bytes, (bytes, str)):
        # Initializing from packed data.
        if len(tuple_or_bytes) < self._length:
          raise TypeError("%s requires string of length %d, got %d" %
                          (self._name, self._length, len(tuple_or_bytes)))
        self._Parse(tuple_or_bytes)
      else:
        # Initializing from a tuple.
        if len(tuple_or_bytes) != len(self._fieldnames):
          raise TypeError("%s has exactly %d fieldnames (%d given)" %
                          (self._name, len(self._fieldnames),
                           len(tuple_or_bytes)))
        self._SetValues(tuple_or_bytes)

    def _FieldIndex(self, attr):
      try:
        return self._fieldnames.index(attr)
      except ValueError:
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, attr))

    def __getattr__(self, name):
      # _values is only missing if __init__ never ran (e.g. an instance made
      # by copy or pickle); fail cleanly instead of recursing forever.
      if name == "_values":
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, name))
      return self._values[self._FieldIndex(name)]

    def __setattr__(self, name, value):
      # TODO: check value type against self._format and throw here, or else
      # callers get an unhelpful exception when they call Pack().
      self._values[self._FieldIndex(name)] = value

    def offset(self, name):
      """Returns the byte offset of the named top-level field."""
      if "." in name:
        raise NotImplementedError("offset() on nested field")
      return self._offsets[name]

    # Makes len() work on instances; Meta.__len__ handles the class itself.
    @classmethod
    def __len__(cls):
      return cls._length

    def __ne__(self, other):
      return not self.__eq__(other)

    def __eq__(self, other):
      return (isinstance(other, self.__class__) and
              self._name == other._name and
              self._fieldnames == other._fieldnames and
              self._values == other._values)

    @staticmethod
    def _MaybePackStruct(value):
      # Each call to Struct() creates its own Meta, so a nested struct is
      # recognized by its __metaclass__ attribute rather than by identity.
      if hasattr(value, "__metaclass__"):
        return value.Pack()
      else:
        return value

    def Pack(self):
      """Returns the packed bytes of this struct, packing nested structs."""
      values = [self._MaybePackStruct(v) for v in self._values]
      return struct.pack(self._format, *values)

    def __str__(self):
      def FieldDesc(index, name, value):
        if isinstance(value, bytes):
          if index in self._asciiz:
            value = value.rstrip(b"\x00")
          elif any(chr(c) not in string.printable for c in bytearray(value)):
            value = "".join("%02x" % c for c in bytearray(value))
          if not isinstance(value, str):
            # Python 3: show the text rather than a b'...' literal.
            value = value.decode("latin-1")
        return "%s=%s" % (name, value)

      descriptions = [
          FieldDesc(i, n, v) for i, (n, v) in
          enumerate(zip(self._fieldnames, self._values))]

      return "%s(%s)" % (self._name, ", ".join(descriptions))

    def __repr__(self):
      return str(self)

    def CPointer(self):
      """Returns a C pointer to the serialized structure."""
      buf = ctypes.create_string_buffer(self.Pack())
      # Store the C buffer in the object so it doesn't get garbage collected.
      super(CStruct, self).__setattr__("_buffer", buf)
      return ctypes.addressof(self._buffer)

  return CStruct
274
275
def Read(data, struct_type):
  """Parses a struct_type from the start of data.

  Args:
    data: Packed data beginning with a struct_type.
    struct_type: A struct class; len(struct_type) is its packed size.

  Returns:
    A (parsed_struct, remaining_data) tuple, where remaining_data is
    everything in data after the first len(struct_type) bytes.
  """
  consumed = len(struct_type)
  parsed = struct_type(data)
  remainder = data[consumed:]
  return parsed, remainder
279