1# Copyright (c) 2003-2016 CORE Security Technologies
2#
3# This software is provided under under a slightly modified version
4# of the Apache Software License. See the accompanying LICENSE file
5# for more information.
6#
7
8from struct import pack, unpack, calcsize
9
class Structure:
    """ Subclasses can define commonHdr and/or structure.
        Each of them is a tuple of either two: (fieldName, format) or three: (fieldName, ':', class) fields.
        [it can't be a dictionary, because order is important]

        where format specifies how the data in the field will be converted to/from bytes (string)
        class is the class to use when unpacking ':' fields.

        each field can only contain one value (or an array of values for *)
           i.e. struct.pack('Hl',1,2) is valid, but format specifier 'Hl' is not (you must use 2 different fields)

        format specifiers:
          specifiers from module struct can be used with the same format
          see struct.__doc__ (pack/unpack is finally called)
            x       [padding byte]
            c       [character]
            b       [signed byte]
            B       [unsigned byte]
            h       [signed short]
            H       [unsigned short]
            l       [signed long]
            L       [unsigned long]
            i       [signed integer]
            I       [unsigned integer]
            q       [signed long long (quad)]
            Q       [unsigned long long (quad)]
            s       [string (array of chars), must be preceded with length in format specifier, padded with zeros]
            p       [pascal string (includes byte count), must be preceded with length in format specifier, padded with zeros]
            f       [float]
            d       [double]
            =       [native byte ordering, standard size, no alignment]
            @       [native byte ordering, native size and alignment]
            !       [network byte ordering]
            <       [little endian]
            >       [big endian]

          usual printf like specifiers can be used (if started with %)
          [not recommended, there is no way to unpack this]

            %08x    will output a zero-padded hex number (at least 8 digits)
            %s      will output a string
            %s\\x00  will output a NUL terminated string
            %d%d    will output 2 decimal digits (against the very same specification of Structure)
            ...

          some additional format specifiers:
            :       just copy the bytes from the field into the output string (input may be string, other structure, or anything responding to __str__()) (for unpacking, all what's left is returned)
            z       same as :, but adds a NUL byte at the end (asciiz) (for unpacking the first NUL byte is used as terminator)  [asciiz string]
            u       same as z, but adds two NUL bytes at the end (after padding to an even size with NULs). (same for unpacking) [unicode string]
            w       DCE-RPC/NDR string (it's a macro for [  '<L=(len(field)+1)/2','"\\x00\\x00\\x00\\x00','<L=(len(field)+1)/2',':' ]
            ?-field length of field named 'field', formatted as specified with ? ('?' may be '!H' for example). The input value overrides the real length
            ?1*?2   array of elements. Each formatted as '?2', the number of elements in the array is stored as specified by '?1' (?1 is optional, or can also be a constant (number), for unpacking)
            'xxxx   literal xxxx (field's value doesn't change the output. quotes must not be closed or escaped)
            "xxxx   literal xxxx (field's value doesn't change the output. quotes must not be closed or escaped)
            _       will not pack the field. Accepts a third argument, which is an unpack code. See _Test_UnpackCode for an example
            ?=packcode  will evaluate packcode in the context of the structure, and pack the result as specified by ?. Unpacking is made plain
            ?&fieldname "Address of field fieldname".
                        For packing it will simply pack the id() of fieldname. Or use 0 if fieldname doesn't exist.
                        For unpacking, it's used to know whether fieldname has to be unpacked or not, i.e. by adding a & field you turn another field (fieldname) in an optional field.

        Pack codes ('?=' specifiers) and unpack codes (third element of '_'
        fields) are evaluated with eval(); they come from the class definition
        and must never be built from untrusted input.
    """
    commonHdr = ()
    structure = ()
    debug = 0

    def __init__(self, data = None, alignment = 0):
        """Create a structure, parsing *data* right away when it is given.

        A class-level ``alignment`` attribute (defined by a subclass) takes
        precedence over the *alignment* argument.
        """
        if not hasattr(self, 'alignment'):
            self.alignment = alignment

        self.fields    = {}
        self.rawData   = data
        # Initialise the packed-data cache up front: fromString() only clears
        # it through __setitem__, which never runs for a structure without
        # fields, and getData() reads it unconditionally.
        self.data      = None
        if data is not None:
            self.fromString(data)

    @classmethod
    def fromFile(cls, file):
        """Build an instance from *file*, reading as many bytes as an empty
        instance of the class packs to (len(cls()))."""
        answer = cls()
        answer.fromString(file.read(len(answer)))
        return answer

    def setAlignment(self, alignment):
        """Set the boundary each packed/unpacked field is padded to (0 = none)."""
        self.alignment = alignment

    def setData(self, data):
        """Override the packed representation returned by getData()."""
        self.data = data

    def packField(self, fieldName, format = None):
        """Pack the current value of *fieldName* with *format*.

        When *format* is None it is looked up in the structure definition.
        An unset field is packed as None; pack() decides whether that is
        acceptable (void, literal, computed or optional fields).
        """
        if self.debug:
            print("packField( %s | %s )" % (fieldName, format))

        if format is None:
            format = self.formatForField(fieldName)

        ans = self.pack(format, self.fields.get(fieldName), field = fieldName)

        if self.debug:
            print("\tanswer %r" % ans)

        return ans

    def getData(self):
        """Return the packed representation of the structure.

        Data set through setData() is returned as-is.  Otherwise every field
        of commonHdr + structure is packed in order, padding with NULs after
        each field up to a multiple of self.alignment when it is non-zero.
        Exceptions raised while packing get a description of the offending
        field appended to their args before being re-raised.
        """
        if self.data is not None:
            return self.data
        data = ''
        for field in self.commonHdr+self.structure:
            try:
                data += self.packField(field[0], field[1])
            except Exception as e:
                if field[0] in self.fields:
                    e.args += ("When packing field '%s | %s | %r' in %s" % (field[0], field[1], self[field[0]], self.__class__),)
                else:
                    e.args += ("When packing field '%s | %s' in %s" % (field[0], field[1], self.__class__),)
                raise
            if self.alignment and len(data) % self.alignment:
                data += '\x00' * (self.alignment - len(data) % self.alignment)

        return data

    def fromString(self, data):
        """Parse *data* into the fields of commonHdr + structure, in order.

        For each field, calcUnpackSize() decides how much input is handed to
        unpack(); the amount actually consumed is then recomputed with
        calcPackSize() on the unpacked value (rounded up to self.alignment).
        Exceptions get a description of the offending field appended to their
        args before being re-raised.  Returns self.
        """
        self.rawData = data
        for field in self.commonHdr+self.structure:
            if self.debug:
                print("fromString( %s | %s | %r )" % (field[0], field[1], data))
            size = self.calcUnpackSize(field[1], data, field[0])
            if self.debug:
                print("  size = %d" % size)
            # Optional third element: the class for ':' fields, or the unpack
            # code for '_' fields.
            dataClassOrCode = str
            if len(field) > 2:
                dataClassOrCode = field[2]
            try:
                self[field[0]] = self.unpack(field[1], data[:size], dataClassOrCode = dataClassOrCode, field = field[0])
            except Exception as e:
                e.args += ("When unpacking field '%s | %s | %r[:%d]'" % (field[0], field[1], data, size),)
                raise

            size = self.calcPackSize(field[1], self[field[0]], field[0])
            if self.alignment and size % self.alignment:
                size += self.alignment - (size % self.alignment)
            data = data[size:]

        return self

    def __setitem__(self, key, value):
        self.fields[key] = value
        self.data = None        # force recompute

    def __getitem__(self, key):
        return self.fields[key]

    def __delitem__(self, key):
        del self.fields[key]
        self.data = None        # force recompute, same as __setitem__

    def __str__(self):
        return self.getData()

    def __len__(self):
        # XXX: improve -- this packs the whole structure just to measure it
        return len(self.getData())

    def pack(self, format, data, field = None):
        """Pack a single value *data* according to *format*; return the bytes.

        *field*, when given, is the name of the field being packed: if another
        field's format is an '&' (address) specifier for it and *data* is
        None, nothing is emitted.  See the class docstring for the specifiers.
        """
        if self.debug:
            print("  pack( %s | %r | %s)" %  (format, data, field))

        if field:
            addressField = self.findAddressFieldFor(field)
            if (addressField is not None) and (data is None):
                return ''

        # void specifier
        if format[:1] == '_':
            return ''

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            return format[1:]

        # code specifier: an explicit value wins if it packs; otherwise the
        # code is evaluated with the structure's fields in scope.  Split only
        # on the first '=' so codes may contain '==', '<=', '>=' or '!='.
        two = format.split('=', 1)
        if len(two) >= 2:
            try:
                return self.pack(two[0], data)
            except Exception:
                fields = {'self':self}
                fields.update(self.fields)
                return self.pack(two[0], eval(two[1], {}, fields))

        # address specifier: an explicit value wins if it packs; otherwise the
        # id() of the referenced field truncated to the specifier's size, or 0
        # when that field is unset/None.
        two = format.split('&')
        if len(two) == 2:
            try:
                return self.pack(two[0], data)
            except Exception:
                if (two[1] in self.fields) and (self[two[1]] is not None):
                    return self.pack(two[0], id(self[two[1]]) & ((1<<(calcsize(two[0])*8))-1) )
                else:
                    return self.pack(two[0], 0)

        # length specifier: an explicit value wins if it packs; otherwise the
        # packed size of the referenced field.
        two = format.split('-')
        if len(two) == 2:
            try:
                return self.pack(two[0],data)
            except Exception:
                return self.pack(two[0], self.calcPackFieldSize(two[1]))

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = ''
            for each in data:
                answer += self.pack(two[1], each)
            if two[0]:
                if two[0].isdigit():
                    if int(two[0]) != len(data):
                        raise Exception("Array field has a constant size, and it doesn't match the actual value")
                else:
                    # element count is packed in front of the elements
                    return self.pack(two[0], len(data))+answer
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            # format string like specifier
            return format % data

        # asciiz specifier
        if format[:1] == 'z':
            return str(data)+'\0'

        # unicode specifier: NUL-NUL terminator, plus one NUL for odd lengths
        if format[:1] == 'u':
            return str(data)+'\0\0' + ('\0' if len(data) & 1 else '')

        # DCE-RPC/NDR string specifier: max count, offset 0, actual count, data
        if format[:1] == 'w':
            if len(data) == 0:
                data = '\0\0'
            elif len(data) % 2:
                data += '\0'
            l = pack('<L', len(data)//2)
            return '%s\0\0\0\0%s%s' % (l,l,data)

        if data is None:
            raise Exception("Trying to pack None")

        # literal specifier
        if format[:1] == ':':
            return str(data)

        # struct like specifier
        return pack(format, data)

    def unpack(self, format, data, dataClassOrCode = str, field = None):
        """Unpack *data* (already cut to the field's size) according to *format*.

        *dataClassOrCode* is the class applied to ':' fields, or, for '_'
        fields, code evaluated with the structure's fields and the remaining
        input ('inputDataLeft') in scope.  A field whose '&' (address) field
        is falsy unpacks to None.
        """
        if self.debug:
            print("  unpack( %s | %r )" %  (format, data))

        if field:
            addressField = self.findAddressFieldFor(field)
            if addressField is not None:
                if not self[addressField]:
                    return

        # void specifier
        if format[:1] == '_':
            if dataClassOrCode != str:
                # NOTE: eval() of code taken from the structure definition.
                fields = {'self':self, 'inputDataLeft':data}
                fields.update(self.fields)
                return eval(dataClassOrCode, {}, fields)
            else:
                return None

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            answer = format[1:]
            if answer != data:
                raise Exception("Unpacked data doesn't match constant value '%r' should be '%r'" % (data, answer))
            return answer

        # address specifier
        two = format.split('&')
        if len(two) == 2:
            return self.unpack(two[0],data)

        # code specifier
        two = format.split('=')
        if len(two) >= 2:
            return self.unpack(two[0],data)

        # length specifier
        two = format.split('-')
        if len(two) == 2:
            return self.unpack(two[0],data)

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = []
            sofar = 0
            if two[0].isdigit():
                number = int(two[0])
            elif two[0]:
                sofar += self.calcUnpackSize(two[0], data)
                number = self.unpack(two[0], data[:sofar])
            else:
                number = -1     # no count given: read until data runs out

            while number and sofar < len(data):
                nsofar = sofar + self.calcUnpackSize(two[1],data[sofar:])
                answer.append(self.unpack(two[1], data[sofar:nsofar], dataClassOrCode))
                number -= 1
                sofar = nsofar
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            # format string like specifier
            return format % data

        # asciiz specifier
        if format == 'z':
            # slice (not index) so empty input raises this error, not IndexError
            if data[-1:] != '\x00':
                raise Exception("%s 'z' field is not NUL terminated: %r" % (field, data))
            return data[:-1] # remove trailing NUL

        # unicode specifier
        if format == 'u':
            if data[-2:] != '\x00\x00':
                raise Exception("%s 'u' field is not NUL-NUL terminated: %r" % (field, data))
            return data[:-2] # remove trailing NUL-NUL

        # DCE-RPC/NDR string specifier: skip the 12-byte header, read count*2 bytes
        if format == 'w':
            l = unpack('<L', data[:4])[0]
            return data[12:12+l*2]

        # literal specifier
        if format == ':':
            return dataClassOrCode(data)

        # struct like specifier
        return unpack(format, data)[0]

    def calcPackSize(self, format, data, field = None):
        """Return how many bytes *data* occupies when packed with *format*.

        fromString() also uses this to know how much input a just-unpacked
        field consumed.  A field whose '&' (address) field is falsy has size 0.
        """
        if field:
            addressField = self.findAddressFieldFor(field)
            if addressField is not None:
                if not self[addressField]:
                    return 0

        # void specifier
        if format[:1] == '_':
            return 0

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            return len(format)-1

        # address specifier
        two = format.split('&')
        if len(two) == 2:
            return self.calcPackSize(two[0], data)

        # code specifier
        two = format.split('=')
        if len(two) >= 2:
            return self.calcPackSize(two[0], data)

        # length specifier
        two = format.split('-')
        if len(two) == 2:
            return self.calcPackSize(two[0], data)

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = 0
            if two[0].isdigit():
                if int(two[0]) != len(data):
                    raise Exception("Array field has a constant size, and it doesn't match the actual value")
            elif two[0]:
                answer += self.calcPackSize(two[0], len(data))

            for each in data:
                answer += self.calcPackSize(two[1], each)
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            # format string like specifier
            return len(format % data)

        # asciiz specifier
        if format[:1] == 'z':
            return len(data)+1

        # unicode specifier
        if format[:1] == 'u':
            l = len(data)
            return l + (3 if l & 1 else 2)

        # DCE-RPC/NDR string specifier
        # NOTE(review): pack() emits 14 bytes for '' (it substitutes '\0\0'),
        # but fromString() relies on 12 here for a zero-count string read from
        # the wire, so the two are intentionally left as they are.
        if format[:1] == 'w':
            l = len(data)
            return 12+l+l % 2

        # literal specifier
        if format[:1] == ':':
            return len(data)

        # struct like specifier
        return calcsize(format)

    def calcUnpackSize(self, format, data, field = None):
        """Return how many bytes at the start of *data* belong to the next
        value described by *format*.

        If a '-' (length) field referring to *field* has already been
        unpacked, its value is used as the size.  A field whose '&' (address)
        field is falsy has size 0.
        """
        if self.debug:
            print("  calcUnpackSize( %s | %s | %r)" %  (field, format, data))

        # void specifier
        if format[:1] == '_':
            return 0

        addressField = self.findAddressFieldFor(field)
        if addressField is not None:
            if not self[addressField]:
                return 0

        # An explicit length field, once unpacked, dictates the size.
        lengthField = self.findLengthFieldFor(field)
        if lengthField is not None:
            try:
                return self[lengthField]
            except KeyError:
                pass    # not unpacked yet: fall back to the format

        # XXX: Try to match to actual values, raise if no match

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            return len(format)-1

        # address specifier
        two = format.split('&')
        if len(two) == 2:
            return self.calcUnpackSize(two[0], data)

        # code specifier
        two = format.split('=')
        if len(two) >= 2:
            return self.calcUnpackSize(two[0], data)

        # length specifier
        two = format.split('-')
        if len(two) == 2:
            return self.calcUnpackSize(two[0], data)

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = 0
            if two[0]:
                if two[0].isdigit():
                    number = int(two[0])
                else:
                    answer += self.calcUnpackSize(two[0], data)
                    number = self.unpack(two[0], data[:answer])

                while number:
                    number -= 1
                    answer += self.calcUnpackSize(two[1], data[answer:])
            else:
                while answer < len(data):
                    answer += self.calcUnpackSize(two[1], data[answer:])
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            raise Exception("Can't guess the size of a printf like specifier for unpacking")

        # asciiz specifier
        if format[:1] == 'z':
            return data.index('\x00')+1

        # unicode specifier
        if format[:1] == 'u':
            l = data.index('\x00\x00')
            return l + (3 if l & 1 else 2)

        # DCE-RPC/NDR string specifier
        if format[:1] == 'w':
            l = unpack('<L', data[:4])[0]
            return 12+l*2

        # literal specifier
        if format[:1] == ':':
            return len(data)

        # struct like specifier
        return calcsize(format)

    def calcPackFieldSize(self, fieldName, format = None):
        """Return the packed size of the current value of *fieldName*."""
        if format is None:
            format = self.formatForField(fieldName)

        return self.calcPackSize(format, self[fieldName])

    def formatForField(self, fieldName):
        """Return the format string declared for *fieldName*; raise if unknown."""
        for field in self.commonHdr+self.structure:
            if field[0] == fieldName:
                return field[1]
        raise Exception("Field %s not found" % fieldName)

    def findAddressFieldFor(self, fieldName):
        """Return the name of the field whose format ends in '&fieldName', or None."""
        descriptor = '&%s' % fieldName
        for field in self.commonHdr+self.structure:
            if field[1].endswith(descriptor):
                return field[0]
        return None

    def findLengthFieldFor(self, fieldName):
        """Return the name of the field whose format ends in '-fieldName', or None."""
        descriptor = '-%s' % fieldName
        for field in self.commonHdr+self.structure:
            if field[1].endswith(descriptor):
                return field[0]
        return None

    def zeroValue(self, format):
        """Return a neutral value for *format*, as used by clear()."""
        two = format.split('*')
        if len(two) == 2:
            if two[0].isdigit():
                return (self.zeroValue(two[1]),)*int(two[0])

        if '*' in format: return ()
        if 's' in format: return ''
        if format in ['z',':','u']: return ''
        if format == 'w': return '\x00\x00'

        return 0

    def clear(self):
        """Reset every declared field to zeroValue() of its format."""
        for field in self.commonHdr + self.structure:
            self[field[0]] = self.zeroValue(field[1])

    def dump(self, msg = None, indent = 0):
        """Print the fields: declared ones in declaration order first, then any
        extra ones; nested Structure values are dumped with a deeper indent."""
        if msg is None: msg = self.__class__.__name__
        ind = ' '*indent

        def dumpOne(name):
            value = self[name]
            if isinstance(value, Structure):
                value.dump('%s%s:{' % (ind,name), indent = indent + 4)
                print("%s}" % ind)
            else:
                print("%s%s: {%r}" % (ind,name,value))

        print("\n%s" % msg)
        fixedFields = []
        for field in self.commonHdr+self.structure:
            i = field[0]
            if i in self.fields:
                fixedFields.append(i)
                dumpOne(i)
        # Do we have remaining fields not defined in the structures? let's
        # print them
        remainingFields = list(set(self.fields) - set(fixedFields))
        for i in remainingFields:
            dumpOne(i)
581
582
class _StructureTest:
    """Round-trip harness shared by the _Test_* classes.

    Subclasses provide ``theClass`` (the Structure subclass under test) and a
    ``populate(a)`` method; ``alignment`` is forwarded to every instance.
    """
    alignment = 0

    def create(self, data = None):
        """Return a new theClass instance, parsed from *data* when it is given."""
        if data is None:
            return self.theClass(alignment = self.alignment)
        return self.theClass(data, alignment = self.alignment)

    def run(self):
        """Pack a populated instance, unpack that output, repack it and compare."""
        print("")
        print("-" * 70)
        print("starting test: %s....." % self.__class__.__name__)

        original = self.create()
        self.populate(original)
        original.dump("packing.....")
        packed = str(original)
        print("packed: %r" % packed)

        print("unpacking.....")
        parsed = self.create(packed)
        parsed.dump("unpacked.....")

        print("repacking.....")
        repacked = str(parsed)
        if repacked == packed:
            return
        print("ERROR: original packed and repacked don't match")
        print("packed: %r" % repacked)
609
class _Test_simple(_StructureTest):
    """Scalar, length, array, asciiz, unicode, literal and code specifiers."""
    class theClass(Structure):
        commonHdr = ()
        structure = (
            ('int1', '!L'),
            ('len1', '!L-z1'),
            ('arr1', 'B*<L'),
            ('z1', 'z'),
            ('u1', 'u'),
            ('', '"COCA'),
            ('len2', '!H-:1'),
            ('', '"COCA'),
            (':1', ':'),
            ('int3', '>L'),
            ('code1', '>L=len(arr1)*2+0x1000'),
        )

    def populate(self, a):
        """Set the fields; len1, len2 and code1 are left for pack() to compute.

        'default' is not part of the structure and only shows up in dump().
        """
        assignments = (
            ('default', 'hola'),
            ('int1', 0x3131),
            ('int3', 0x45444342),
            ('z1', 'hola'),
            ('u1', 'hola'.encode('utf_16_le')),
            (':1', ':1234:'),
            ('arr1', (0x12341234, 0x88990077, 0x41414141)),
        )
        for name, value in assignments:
            a[name] = value
636
class _Test_fixedLength(_Test_simple):
    """_Test_simple with an explicit, bogus value for the len1 length field."""
    def populate(self, a):
        _Test_simple.populate(self, a)
        bogusLength = 0x42424242
        a['len1'] = bogusLength
641
class _Test_simple_aligned4(_Test_simple):
    """_Test_simple with every field padded to a 4-byte boundary."""
    alignment = 4
644
class _Test_nested(_StructureTest):
    """Two Structure instances nested as ':' fields, unpacked through _Inner."""
    class theClass(Structure):
        class _Inner(Structure):
            structure = (('data', 'z'),)

        structure = (
            ('nest1', ':', _Inner),
            ('nest2', ':', _Inner),
            ('int', '<L'),
        )

    def populate(self, a):
        inner = _Test_nested.theClass._Inner
        first, second = inner(), inner()
        first['data'] = 'hola manola'
        second['data'] = 'chau loco'
        a['nest1'] = first
        a['nest2'] = second
        a['int'] = 0x12345678
662
class _Test_Optional(_StructureTest):
    """Address ('&') fields that make Name and List optional."""
    class theClass(Structure):
        structure = (
            ('pName', '<L&Name'),
            ('pList', '<L&List'),
            ('Name', 'w'),
            ('List', '<H*<L'),
        )

    def populate(self, a):
        for key, value in (('Name', 'Optional test'), ('List', (1, 2, 3, 4))):
            a[key] = value
675
class _Test_Optional_sparse(_Test_Optional):
    """_Test_Optional with Name removed, so pName is packed as 0."""
    def populate(self, a):
        base = _Test_Optional.populate
        base(self, a)
        del a['Name']
680
class _Test_AsciiZArray(_StructureTest):
    """A byte-counted ('B*z') array of NUL-terminated strings between two ints."""
    class theClass(Structure):
        structure = (
            ('head', '<L'),
            ('array', 'B*z'),
            ('tail', '<L'),
        )

    def populate(self, a):
        a['head'], a['tail'] = 0x1234, 0xabcd
        a['array'] = ('hola', 'manola', 'te traje')
693
class _Test_UnpackCode(_StructureTest):
    """A '_' field whose value is computed from another field at unpack time."""
    class theClass(Structure):
        structure = (
            ('leni', '<L=len(uno)*2'),
            ('cuchi', '_-uno', 'leni/2'),
            ('uno', ':'),
            ('dos', ':'),
        )

    def populate(self, a):
        for key, text in (('uno', 'soy un loco!'), ('dos', 'que haces fiera')):
            a[key] = text
706
class _Test_AAA(_StructureTest):
    """Bit-packed header: iv is built from init_vector/pad/keyid when packing
    and split back into them by the '_' unpack codes."""
    class theClass(Structure):
        commonHdr = ()
        structure = (
            ('iv', '!L=((init_vector & 0xFFFFFF) << 8) | ((pad & 0x3f) << 2) | (keyid & 3)'),
            ('init_vector', '_', '(iv >> 8)'),
            ('pad', '_', '((iv >>2) & 0x3F)'),
            ('keyid', '_', '( iv & 0x03 )'),
            ('dataLen', '_-data', 'len(inputDataLeft)-4'),
            ('data', ':'),
            ('icv', '>L'),
        )

    def populate(self, a):
        # The iv code keeps only the low 24 bits of init_vector, 6 bits of pad
        # and 2 bits of keyid, so keyid = 0x07 unpacks back as 3.
        for name, value in (
            ('init_vector', 0x01020304),
            ('pad', 0x15),          # 0b010101
            ('keyid', 0x07),
            ('data', "\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9"),
            ('icv', 0x05060708),
        ):
            a[name] = value
728
if __name__ == '__main__':
    _Test_simple().run()

    try:
        _Test_fixedLength().run()
    except Exception:
        # len1 is forced to a bogus value, so the round trip is expected to
        # fail; report it and carry on with the remaining tests.  (A bare
        # except would also swallow KeyboardInterrupt/SystemExit.)
        print("cannot repack because length is bogus")

    for testClass in (_Test_simple_aligned4, _Test_nested, _Test_Optional,
                      _Test_Optional_sparse, _Test_AsciiZArray,
                      _Test_UnpackCode, _Test_AAA):
        testClass().run()
744