1import types
2import weakref
3
4from .lock import allocate_lock
5from .error import CDefError, VerificationError, VerificationMissing
6
# Type qualifiers, encoded as bit flags that can be OR-ed together.
Q_CONST    = 0x01
Q_RESTRICT = 0x02
Q_VOLATILE = 0x04


def qualify(quals, replace_with):
    """Prepend the C qualifier keywords selected by `quals` to a declarator.

    `quals` is a bitwise OR of the Q_* flags.  `replace_with` is the text
    that is about to be substituted for the '&' marker of a C type name.
    The string is returned unchanged when no flag is set.
    """
    # Every selected keyword is prepended in turn, so a later entry of
    # this table ends up to the left of an earlier one in the result.
    keywords = (
        (Q_CONST, ' const '),
        (Q_VOLATILE, ' volatile '),
        # It seems that __restrict is supported by gcc and msvc.
        # If you hit some different compiler, add a #define in
        # _cffi_include.h for it (and in its copies, documented there)
        (Q_RESTRICT, ' __restrict '),
    )
    for flag, keyword in keywords:
        if quals & flag:
            replace_with = keyword + replace_with.lstrip()
    return replace_with
23
24
class BaseTypeByIdentity(object):
    """Base class of the type model; instances are compared by identity.

    Subclasses provide `c_name_with_marker` (the C spelling of the type,
    containing exactly one '&' at the place where a declarator goes),
    the tuple `_attrs_` of attribute names read by _get_items(), and
    the method build_backend_type(ffi, finishlist).
    """
    is_array_type = False
    is_raw_function = False

    def get_c_name(self, replace_with='', context='a C file', quals=0):
        """Return the C source text of this type with `replace_with`
        (for example a variable name) substituted at the marker.

        `quals` is a combination of Q_* flags handed to qualify().
        Raises VerificationError when the resulting text contains a '$';
        `context` is only used in that error message.
        """
        template = self.c_name_with_marker
        assert template.count('&') == 1
        # some logic duplication with ffi.getctype()... :-(
        replace_with = replace_with.strip()
        if replace_with:
            if replace_with.startswith('*') and '&[' in template:
                # a '*' declarator placed in front of '[...]' is
                # wrapped in parentheses
                replace_with = '(%s)' % replace_with
            elif replace_with[0] not in '[(':
                replace_with = ' ' + replace_with
        c_source = template.replace('&', qualify(quals, replace_with))
        if '$' not in c_source:
            return c_source
        raise VerificationError(
            "cannot generate '%s' in %s: unknown type name"
            % (self._get_c_name(), context))

    def _get_c_name(self):
        """Return the type name with the '&' marker simply removed."""
        marked_name = self.c_name_with_marker
        return marked_name.replace('&', '')

    def has_c_name(self):
        """Tell whether the name is free of '$' characters."""
        plain_name = self._get_c_name()
        return '$' not in plain_name

    def is_integer_type(self):
        """Default answer: not an integer type."""
        return False

    def get_cached_btype(self, ffi, finishlist, can_delay=False):
        """Return the backend type for `self`, building it on first use.

        The result is stored in `ffi._cached_btypes`.  `can_delay` is
        accepted but not used by this base implementation.
        """
        try:
            backend_type = ffi._cached_btypes[self]
        except KeyError:
            backend_type = self.build_backend_type(ffi, finishlist)
            stored = ffi._cached_btypes.setdefault(self, backend_type)
            # nobody else may have registered a different object meanwhile
            assert stored is backend_type
        return backend_type

    def __repr__(self):
        plain_name = self._get_c_name()
        return '<%s>' % (plain_name,)

    def _get_items(self):
        """Return the list of (attribute name, value) pairs for `_attrs_`."""
        items = []
        for attr_name in self._attrs_:
            items.append((attr_name, getattr(self, attr_name)))
        return items
70
71
class BaseType(BaseTypeByIdentity):
    """Type compared structurally: two instances are equal when they have
    the same class and the same `_attrs_` values."""

    def __eq__(self, other):
        same_class = self.__class__ == other.__class__
        return same_class and self._get_items() == other._get_items()

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # hash on the same data that __eq__ compares
        items = tuple(self._get_items())
        return hash((self.__class__, items))
83
84
class VoidType(BaseType):
    """The C type 'void'."""
    _attrs_ = ()

    def __init__(self):
        self.c_name_with_marker = 'void&'

    def build_backend_type(self, ffi, finishlist):
        """Build the backend type through the 'new_void_type' function."""
        backend_funcname = 'new_void_type'
        return global_cache(self, ffi, backend_funcname)

void_type = VoidType()
95
96
class BasePrimitiveType(BaseType):
    """Common base of the primitive type classes."""

    def is_complex_type(self):
        """Default answer: not a complex type."""
        return False
100
101
class PrimitiveType(BasePrimitiveType):
    """A primitive type, identified by one of the known names."""
    _attrs_ = ('name',)

    # Maps every accepted type name to a one-letter kind code:
    # 'c' = char-like, 'i' = integer, 'f' = float, 'j' = complex
    # (see the is_*_type() methods below).
    ALL_PRIMITIVE_TYPES = {
        'char':               'c',
        'short':              'i',
        'int':                'i',
        'long':               'i',
        'long long':          'i',
        'signed char':        'i',
        'unsigned char':      'i',
        'unsigned short':     'i',
        'unsigned int':       'i',
        'unsigned long':      'i',
        'unsigned long long': 'i',
        'float':              'f',
        'double':             'f',
        'long double':        'f',
        'float _Complex':     'j',
        'double _Complex':    'j',
        '_Bool':              'i',
        # the following types are not primitive in the C sense
        'wchar_t':            'c',
        'char16_t':           'c',
        'char32_t':           'c',
        'int8_t':             'i',
        'uint8_t':            'i',
        'int16_t':            'i',
        'uint16_t':           'i',
        'int32_t':            'i',
        'uint32_t':           'i',
        'int64_t':            'i',
        'uint64_t':           'i',
        'int_least8_t':       'i',
        'uint_least8_t':      'i',
        'int_least16_t':      'i',
        'uint_least16_t':     'i',
        'int_least32_t':      'i',
        'uint_least32_t':     'i',
        'int_least64_t':      'i',
        'uint_least64_t':     'i',
        'int_fast8_t':        'i',
        'uint_fast8_t':       'i',
        'int_fast16_t':       'i',
        'uint_fast16_t':      'i',
        'int_fast32_t':       'i',
        'uint_fast32_t':      'i',
        'int_fast64_t':       'i',
        'uint_fast64_t':      'i',
        'intptr_t':           'i',
        'uintptr_t':          'i',
        'intmax_t':           'i',
        'uintmax_t':          'i',
        'ptrdiff_t':          'i',
        'size_t':             'i',
        'ssize_t':            'i',
        }

    def __init__(self, name):
        assert name in self.ALL_PRIMITIVE_TYPES
        self.name = name
        self.c_name_with_marker = name + '&'

    def is_char_type(self):
        kind = self.ALL_PRIMITIVE_TYPES[self.name]
        return kind == 'c'

    def is_integer_type(self):
        kind = self.ALL_PRIMITIVE_TYPES[self.name]
        return kind == 'i'

    def is_float_type(self):
        kind = self.ALL_PRIMITIVE_TYPES[self.name]
        return kind == 'f'

    def is_complex_type(self):
        kind = self.ALL_PRIMITIVE_TYPES[self.name]
        return kind == 'j'

    def build_backend_type(self, ffi, finishlist):
        """Build the backend type via 'new_primitive_type' and the name."""
        type_name = self.name
        return global_cache(self, ffi, 'new_primitive_type', type_name)
176
177
class UnknownIntegerType(BasePrimitiveType):
    """An integer type known only by name; it has no backend type here."""
    _attrs_ = ('name',)

    def __init__(self, name):
        self.c_name_with_marker = name + '&'
        self.name = name

    def is_integer_type(self):
        return True

    def build_backend_type(self, ffi, finishlist):
        """Always raises NotImplementedError."""
        message = ("integer type '%s' can only be used after "
                   "compilation" % self.name)
        raise NotImplementedError(message)
191
class UnknownFloatType(BasePrimitiveType):
    """A float type known only by name; it has no backend type here."""
    _attrs_ = ('name', )

    def __init__(self, name):
        self.c_name_with_marker = name + '&'
        self.name = name

    def build_backend_type(self, ffi, finishlist):
        """Always raises NotImplementedError."""
        message = ("float type '%s' can only be used after "
                   "compilation" % self.name)
        raise NotImplementedError(message)
202
203
class BaseFunctionType(BaseType):
    """Shared part of the function types.

    Subclasses define `_base_pattern`, a format string with one '%s'
    for the argument list and a '&' marker for the declarator.
    """
    _attrs_ = ('args', 'result', 'ellipsis', 'abi')

    def __init__(self, args, result, ellipsis, abi=None):
        self.args = args
        self.result = result
        self.ellipsis = ellipsis
        self.abi = abi
        #
        # build the textual argument list: "void" if there is nothing
        arg_names = []
        for arg in args:
            arg_names.append(arg._get_c_name())
        if ellipsis:
            arg_names.append('...')
        if not arg_names:
            arg_names = ['void']
        declarator = self._base_pattern % (', '.join(arg_names),)
        if abi is not None:
            # the abi goes right after the first character of the pattern
            declarator = declarator[:1] + abi + ' ' + declarator[1:]
        result_template = result.c_name_with_marker
        self.c_name_with_marker = result_template.replace('&', declarator)
222
223
class RawFunctionType(BaseFunctionType):
    # Corresponds to a C type like 'int(int)', which is the C type of
    # a function, but not a pointer-to-function.  The backend has no
    # notion of such a type; it's used temporarily by parsing.
    _base_pattern = '(&)(%s)'
    is_raw_function = True

    def build_backend_type(self, ffi, finishlist):
        """Always raises CDefError: there is no backend type for this."""
        message = ("cannot render the type %r: it is a function "
                   "type, not a pointer-to-function type" % (self,))
        raise CDefError(message)

    def as_function_pointer(self):
        """Return the FunctionPtrType with the same signature and abi."""
        signature = (self.args, self.result, self.ellipsis, self.abi)
        return FunctionPtrType(*signature)
237
238
class FunctionPtrType(BaseFunctionType):
    """A pointer-to-function type."""
    _base_pattern = '(*&)(%s)'

    def build_backend_type(self, ffi, finishlist):
        """Build the backend function type from the backend types of the
        result and of each argument."""
        BResult = self.result.get_cached_btype(ffi, finishlist)
        BArgs = tuple([tp.get_cached_btype(ffi, finishlist)
                       for tp in self.args])
        abi_args = ()
        # __stdcall ignored for variadic funcs
        if self.abi == "__stdcall" and not self.ellipsis:
            try:
                abi_args = (ffi._backend.FFI_STDCALL,)
            except AttributeError:
                # FFI_STDCALL is not available: pass no abi argument
                pass
        return global_cache(self, ffi, 'new_function_type',
                            BArgs, BResult, self.ellipsis, *abi_args)

    def as_raw_function(self):
        """Return the RawFunctionType with the same signature and abi."""
        signature = (self.args, self.result, self.ellipsis, self.abi)
        return RawFunctionType(*signature)
259
260
class PointerType(BaseType):
    """A pointer to `totype`, with optional Q_* qualifiers in `quals`."""
    _attrs_ = ('totype', 'quals')

    def __init__(self, totype, quals=0):
        self.totype = totype
        self.quals = quals
        declarator = qualify(quals, " *&")
        if totype.is_array_type:
            # pointer to an array: the declarator goes in parentheses
            declarator = "(" + declarator.lstrip() + ")"
        target_template = totype.c_name_with_marker
        self.c_name_with_marker = target_template.replace('&', declarator)

    def build_backend_type(self, ffi, finishlist):
        """Build the backend pointer type; the pointed-to type is
        requested with can_delay=True."""
        BTarget = self.totype.get_cached_btype(ffi, finishlist,
                                               can_delay=True)
        return global_cache(self, ffi, 'new_pointer_type', BTarget)
275
# the type 'void *'
voidp_type = PointerType(void_type)

def ConstPointerType(totype):
    """Return a PointerType to `totype` carrying the Q_CONST qualifier."""
    return PointerType(totype, quals=Q_CONST)

# same as voidp_type, with the Q_CONST qualifier
const_voidp_type = ConstPointerType(void_type)
282
283
class NamedPointerType(PointerType):
    """A pointer type whose C name is `name` itself instead of a
    spelling derived from the pointed-to type."""
    _attrs_ = ('totype', 'name')

    def __init__(self, totype, name, quals=0):
        PointerType.__init__(self, totype, quals)
        self.name = name
        # replace the name computed by PointerType.__init__
        self.c_name_with_marker = '%s&' % (name,)
291
292
class ArrayType(BaseType):
    """An array of `item`.

    `length` is None (written '[]'), the string '...' (length not known
    yet) or the value written between the brackets.
    """
    _attrs_ = ('item', 'length')
    is_array_type = True

    def __init__(self, item, length):
        self.item = item
        self.length = length
        #
        if length is None:
            inside = ''
        elif length == '...':
            inside = '/*...*/'
        else:
            inside = '%s' % length
        brackets = '&[' + inside + ']'
        item_template = item.c_name_with_marker
        self.c_name_with_marker = item_template.replace('&', brackets)

    def resolve_length(self, newlength):
        """Return a new ArrayType of the same item with `newlength`."""
        return ArrayType(self.item, newlength)

    def build_backend_type(self, ffi, finishlist):
        """Build the backend array type from the pointer-to-item type.

        Raises CDefError if the length is still '...'.
        """
        if self.length == '...':
            message = "cannot render the type %r: unknown length" % (self,)
            raise CDefError(message)
        # force the item BType
        self.item.get_cached_btype(ffi, finishlist)
        item_pointer = PointerType(self.item)
        BPtrItem = item_pointer.get_cached_btype(ffi, finishlist)
        return global_cache(self, ffi, 'new_array_type', BPtrItem, self.length)

char_array_type = ArrayType(PrimitiveType('char'), None)
322
323
class StructOrUnionOrEnum(BaseTypeByIdentity):
    """Common base of the struct, union and enum types.

    Subclasses define `kind`; the C name is "<kind> <name>" unless a
    different name has been forced.
    """
    _attrs_ = ('name',)
    forcename = None

    def build_c_name_with_marker(self):
        """(Re)compute `c_name_with_marker` from forcename or kind/name."""
        c_name = self.forcename
        if not c_name:
            c_name = '%s %s' % (self.kind, self.name)
        self.c_name_with_marker = c_name + '&'

    def force_the_name(self, forcename):
        """Set `forcename` and recompute the C name accordingly."""
        self.forcename = forcename
        self.build_c_name_with_marker()

    def get_official_name(self):
        """Return the C name, i.e. the marked name minus the final '&'."""
        marked_name = self.c_name_with_marker
        assert marked_name.endswith('&')
        return marked_name[:-1]
339
340
class StructOrUnion(StructOrUnionOrEnum):
    """Common implementation of the struct and union types.

    The fields are described by the parallel sequences `fldnames`,
    `fldtypes`, `fldbitsize` and (optionally) `fldquals`.  When
    `fldtypes` is None the type is treated as opaque (see
    finish_backend_type()).
    """
    # When not None, a 4-tuple (fieldofs, fieldsize, totalsize,
    # totalalignment), as unpacked in finish_backend_type().
    fixedlayout = None
    # Progress of finish_backend_type(): 0 = not started, 1 = started,
    # 2 = done.
    completed = 0
    # Checked by check_not_partial(), together with `fixedlayout`.
    partial = False
    # 0 = no extra flags; 1 = pass the SF_PACKED flag; any other true
    # value is passed on to the backend as is (see finish_backend_type()).
    packed = 0

    def __init__(self, name, fldnames, fldtypes, fldbitsize, fldquals=None):
        """Store the field description and compute the C name."""
        self.name = name
        self.fldnames = fldnames
        self.fldtypes = fldtypes
        self.fldbitsize = fldbitsize
        self.fldquals = fldquals
        self.build_c_name_with_marker()

    def anonymous_struct_fields(self):
        """Yield the types of the fields that have an empty name and whose
        type is itself a StructOrUnion.  Yields nothing if `fldtypes` is
        None."""
        if self.fldtypes is not None:
            for name, type in zip(self.fldnames, self.fldtypes):
                if name == '' and isinstance(type, StructOrUnion):
                    yield type

    def enumfields(self, expand_anonymous_struct_union=True):
        """Yield a (name, type, bitsize, quals) tuple for every field.

        With `expand_anonymous_struct_union` true (the default), a field
        with an empty name and a StructOrUnion type is replaced by the
        fields of that nested type, recursively.
        """
        fldquals = self.fldquals
        if fldquals is None:
            # no qualifiers were given: use 0 for every field
            fldquals = (0,) * len(self.fldnames)
        for name, type, bitsize, quals in zip(self.fldnames, self.fldtypes,
                                              self.fldbitsize, fldquals):
            if (name == '' and isinstance(type, StructOrUnion)
                    and expand_anonymous_struct_union):
                # nested anonymous struct/union
                for result in type.enumfields():
                    yield result
            else:
                yield (name, type, bitsize, quals)

    def force_flatten(self):
        """Replace the four field sequences by tuples listing the fields
        exactly as returned by enumfields()."""
        # force the struct or union to have a declaration that lists
        # directly all fields returned by enumfields(), flattening
        # nested anonymous structs/unions.
        names = []
        types = []
        bitsizes = []
        fldquals = []
        for name, type, bitsize, quals in self.enumfields():
            names.append(name)
            types.append(type)
            bitsizes.append(bitsize)
            fldquals.append(quals)
        self.fldnames = tuple(names)
        self.fldtypes = tuple(types)
        self.fldbitsize = tuple(bitsizes)
        self.fldquals = tuple(fldquals)

    def get_cached_btype(self, ffi, finishlist, can_delay=False):
        """Return the backend type like the base class does and, unless
        `can_delay` is true, also complete it right away with
        finish_backend_type()."""
        BType = StructOrUnionOrEnum.get_cached_btype(self, ffi, finishlist,
                                                     can_delay)
        if not can_delay:
            self.finish_backend_type(ffi, finishlist)
        return BType

    def finish_backend_type(self, ffi, finishlist):
        """Give the list of fields to the backend type, once.

        The backend type must already be in `ffi._cached_btypes`.
        Raises NotImplementedError if called again while a previous call
        has started but not reached its end (`completed` == 1).
        """
        if self.completed:
            if self.completed != 2:
                raise NotImplementedError("recursive structure declaration "
                                          "for '%s'" % (self.name,))
            return
        BType = ffi._cached_btypes[self]
        #
        # NOTE(review): if an exception is raised below, `completed` stays
        # at 1 and any later call reports a recursive declaration --
        # confirm that this is intended.
        self.completed = 1
        #
        if self.fldtypes is None:
            pass    # not completing it: it's an opaque struct
            #
        elif self.fixedlayout is None:
            # no fixed layout: only names, backend types and bit sizes
            # are given to the backend
            fldtypes = [tp.get_cached_btype(ffi, finishlist)
                        for tp in self.fldtypes]
            lst = list(zip(self.fldnames, fldtypes, self.fldbitsize))
            extra_flags = ()
            if self.packed:
                if self.packed == 1:
                    extra_flags = (8,)    # SF_PACKED
                else:
                    # presumably (flags, pack value) -- confirm against
                    # the backend's complete_struct_or_union()
                    extra_flags = (0, self.packed)
            # -1, -1 are in the positions that receive totalsize and
            # totalalignment in the fixed-layout call further down
            ffi._backend.complete_struct_or_union(BType, lst, self,
                                                  -1, -1, *extra_flags)
            #
        else:
            # fixed layout: check each field's size against the size of
            # its backend type, and also pass the field offsets
            fldtypes = []
            fieldofs, fieldsize, totalsize, totalalignment = self.fixedlayout
            for i in range(len(self.fldnames)):
                fsize = fieldsize[i]
                ftype = self.fldtypes[i]
                #
                if isinstance(ftype, ArrayType) and ftype.length == '...':
                    # fix the length to match the total size
                    BItemType = ftype.item.get_cached_btype(ffi, finishlist)
                    nlen, nrest = divmod(fsize, ffi.sizeof(BItemType))
                    if nrest != 0:
                        self._verification_error(
                            "field '%s.%s' has a bogus size?" % (
                            self.name, self.fldnames[i] or '{}'))
                    ftype = ftype.resolve_length(nlen)
                    # store the resolved array type back into self.fldtypes
                    self.fldtypes = (self.fldtypes[:i] + (ftype,) +
                                     self.fldtypes[i+1:])
                #
                BFieldType = ftype.get_cached_btype(ffi, finishlist)
                if isinstance(ftype, ArrayType) and ftype.length is None:
                    # array declared with '[]': the recorded size must be 0
                    assert fsize == 0
                else:
                    bitemsize = ffi.sizeof(BFieldType)
                    if bitemsize != fsize:
                        self._verification_error(
                            "field '%s.%s' is declared as %d bytes, but is "
                            "really %d bytes" % (self.name,
                                                 self.fldnames[i] or '{}',
                                                 bitemsize, fsize))
                fldtypes.append(BFieldType)
            #
            lst = list(zip(self.fldnames, fldtypes, self.fldbitsize, fieldofs))
            ffi._backend.complete_struct_or_union(BType, lst, self,
                                                  totalsize, totalalignment)
        self.completed = 2

    def _verification_error(self, msg):
        """Raise VerificationError(msg)."""
        raise VerificationError(msg)

    def check_not_partial(self):
        """Raise VerificationMissing if the type is marked partial and has
        no fixed layout."""
        if self.partial and self.fixedlayout is None:
            raise VerificationMissing(self._get_c_name())

    def build_backend_type(self, ffi, finishlist):
        """Create the backend type, not completed yet, and append `self`
        to `finishlist`.

        Raises VerificationMissing via check_not_partial().
        """
        self.check_not_partial()
        finishlist.append(self)
        #
        # key=self: the global cache entry is keyed on this object instead
        # of on the (funcname, args) pair
        return global_cache(self, ffi, 'new_%s_type' % self.kind,
                            self.get_official_name(), key=self)
476
477
class StructType(StructOrUnion):
    """A StructOrUnion whose kind is 'struct'."""
    kind = "struct"
480
481
class UnionType(StructOrUnion):
    """A StructOrUnion whose kind is 'union'."""
    kind = "union"
484
485
class EnumType(StructOrUnionOrEnum):
    """An enum type, described by the parallel sequences `enumerators`
    and `enumvalues` and by an optional explicit `baseinttype`."""
    kind = 'enum'
    partial = False
    partial_resolved = False

    def __init__(self, name, enumerators, enumvalues, baseinttype=None):
        self.name = name
        self.enumerators = enumerators
        self.enumvalues = enumvalues
        self.baseinttype = baseinttype
        self.build_c_name_with_marker()

    def force_the_name(self, forcename):
        """Like the base method; additionally, when `forcename` is None,
        set `self.forcename` to '$' followed by the official name with
        spaces turned into underscores."""
        StructOrUnionOrEnum.force_the_name(self, forcename)
        if self.forcename is not None:
            return
        official_name = self.get_official_name()
        self.forcename = '$' + official_name.replace(' ', '_')

    def check_not_partial(self):
        """Raise VerificationMissing if the enum is partial and has not
        been resolved."""
        if not self.partial or self.partial_resolved:
            return
        raise VerificationMissing(self._get_c_name())

    def build_backend_type(self, ffi, finishlist):
        """Build the backend enum type on top of the base integer type
        chosen by build_baseinttype()."""
        self.check_not_partial()
        BBaseInt = self.build_baseinttype(ffi, finishlist)
        return global_cache(self, ffi, 'new_enum_type',
                            self.get_official_name(),
                            self.enumerators, self.enumvalues,
                            BBaseInt, key=self)

    def build_baseinttype(self, ffi, finishlist):
        """Return the backend integer type used to hold the enum values.

        This is `baseinttype` if one was given.  Otherwise it is the
        first type that can hold all values among 'int' and 'long' (if
        some value is negative) or 'unsigned int' and 'unsigned long'.
        Raises CDefError if neither candidate fits.
        """
        if self.baseinttype is not None:
            return self.baseinttype.get_cached_btype(ffi, finishlist)
        #
        if not self.enumvalues:
            import warnings
            try:
                # XXX!  The goal is to ensure that the warnings.warn()
                # will not suppress the warning.  We want to get it
                # several times if we reach this point several times.
                __warningregistry__.clear()
            except NameError:
                pass
            warnings.warn("%r has no values explicitly defined; "
                          "guessing that it is equivalent to 'unsigned int'"
                          % self._get_c_name())
            smallest_value = largest_value = 0
        else:
            smallest_value = min(self.enumvalues)
            largest_value = max(self.enumvalues)
        if smallest_value < 0:   # needs a signed type
            sign = 1
            candidate_names = ("int", "long")
        else:
            sign = 0
            candidate_names = ("unsigned int", "unsigned long")
        # fetch both backend types first, then both sizes
        btypes = [PrimitiveType(type_name).get_cached_btype(ffi, finishlist)
                  for type_name in candidate_names]
        sizes = [ffi.sizeof(btype) for btype in btypes]
        for btype, size in zip(btypes, sizes):
            lower_bound = (-1) << (8*size-1)
            upper_bound = 1 << (8*size-sign)
            if smallest_value >= lower_bound and largest_value < upper_bound:
                return btype
        raise CDefError("%s values don't all fit into either 'long' "
                        "or 'unsigned long'" % self._get_c_name())
556
def unknown_type(name, structname=None):
    """Return an opaque StructType (no field information) whose C name
    is forced to `name`.

    The struct itself is named `structname`, or '$' + name by default.
    """
    struct_name = structname if structname is not None else '$%s' % name
    opaque = StructType(struct_name, None, None, None)
    opaque.force_the_name(name)
    opaque.origin = "unknown_type"
    return opaque
564
def unknown_ptr_type(name, structname=None):
    """Return a NamedPointerType called `name` that points to an opaque
    StructType.

    The struct is named `structname`, or '$$' + name by default.
    """
    struct_name = structname if structname is not None else '$$%s' % name
    target = StructType(struct_name, None, None, None)
    return NamedPointerType(target, name)
570
571
# lock taken around updates of the type caches
global_lock = allocate_lock()
# type cache returned by get_typecache() for module backends
_typecache_cffi_backend = weakref.WeakValueDictionary()

def get_typecache(backend):
    """Return the type cache that goes with `backend`."""
    # returns _typecache_cffi_backend if backend is the _cffi_backend
    # module, or type(backend).__typecache if backend is an instance of
    # CTypesBackend (or some FakeBackend class during tests)
    if isinstance(backend, types.ModuleType):
        return _typecache_cffi_backend
    backend_class = type(backend)
    with global_lock:
        # create the cache on the class the first time it is requested
        if not hasattr(backend_class, '__typecache'):
            setattr(backend_class, '__typecache',
                    weakref.WeakValueDictionary())
        return getattr(backend_class, '__typecache')
585
def global_cache(srctype, ffi, funcname, *args, **kwds):
    """Return the backend object made by `ffi._backend.<funcname>(*args)`,
    sharing it through `ffi._typecache`.

    The cache key is the keyword argument `key` if given, and the pair
    (funcname, args) otherwise; no other keyword is accepted.  `srctype`
    is only used in the message of the NotImplementedError that is
    re-raised when the backend function raises NotImplementedError.
    """
    cache_key = kwds.pop('key', (funcname, args))
    assert not kwds
    try:
        return ffi._typecache[cache_key]
    except KeyError:
        pass
    try:
        fresh = getattr(ffi._backend, funcname)(*args)
    except NotImplementedError as e:
        raise NotImplementedError("%s: %r: %s" % (funcname, srctype, e))
    # note that setdefault() on WeakValueDictionary is not atomic
    # and contains a rare bug (http://bugs.python.org/issue19542);
    # we have to use a lock and do it ourselves
    cache = ffi._typecache
    with global_lock:
        existing = cache.get(cache_key)
        if existing is not None:
            # an entry appeared in the meantime: keep that one
            return existing
        cache[cache_key] = fresh
        return fresh
608
def pointer_cache(ffi, BType):
    """Return the globally cached backend pointer type to `BType`."""
    # '?' stands in for the source type, which is not known here
    placeholder_srctype = '?'
    return global_cache(placeholder_srctype, ffi, 'new_pointer_type', BType)
611
def attach_exception_info(e, name):
    """Prefix the message of the exception `e` with "<name>: ", in place.

    Only `e.args[0]` is rewritten, and only when it is a string; the
    remaining arguments are kept.  Exceptions without arguments, or
    whose first argument is not a string, are left untouched.
    """
    if not e.args:
        return
    message = e.args[0]
    # isinstance() rather than an exact type comparison, so that a message
    # that is an instance of a str subclass gets the prefix as well
    if isinstance(message, str):
        e.args = ('%s: %s' % (name, message),) + e.args[1:]
615