1#!/usr/bin/python3
2#
3# Copyright (c) 2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8#
9# Purpose:      This script checks some "business logic" in the XML registry.
10
11import argparse
12import re
13import sys
14from pathlib import Path
15from collections import defaultdict, deque, namedtuple
16
17from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
18from reg import Registry
19from spec_tools.consistency_tools import XMLChecker
20from spec_tools.util import findNamedElem, getElemName, getElemType
21from apiconventions import APIConventions
22from parse_dependency import dependencyNames
23
# These are extensions which do not follow the usual naming conventions,
# specifying the alternate convention they follow.
# Maps extension name -> the prefix used for its *_SPEC_VERSION and
# *_EXTENSION_NAME enums (consulted by Checker.check_extension).
EXTENSION_ENUM_NAME_SPELLING_CHANGE = {
    'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE',
}

# These are extensions whose names *look* like they end in version numbers,
# but do not. For these, the enum prefix is simply the upper-cased extension
# name, rather than having the trailing digits split off as a separate word.
EXTENSION_NAME_VERSION_EXCEPTIONS = (
    'VK_AMD_gpu_shader_int16',
    'VK_EXT_index_type_uint8',
    'VK_EXT_shader_image_atomic_int64',
    'VK_KHR_video_decode_h264',
    'VK_KHR_video_decode_h265',
    'VK_EXT_video_encode_h264',
    'VK_EXT_video_encode_h265',
    'VK_KHR_external_fence_win32',
    'VK_KHR_external_memory_win32',
    'VK_KHR_external_semaphore_win32',
    'VK_KHR_shader_atomic_int64',
    'VK_KHR_shader_float16_int8',
    'VK_KHR_spirv_1_4',
    'VK_NV_external_memory_win32',
    'VK_RESERVED_do_not_use_146',
    'VK_RESERVED_do_not_use_94',
)

# These are APIs which can be required by an extension despite not having
# suffixes matching the vendor ID of that extension.
# Most are external types.
# We could make this an (extension name, api name) set to be more specific.
EXTENSION_API_NAME_EXCEPTIONS = {
    'AHardwareBuffer',
    'ANativeWindow',
    'CAMetalLayer',
    'IOSurfaceRef',
    'MTLBuffer_id',
    'MTLCommandQueue_id',
    'MTLDevice_id',
    'MTLSharedEvent_id',
    'MTLTexture_id',
    'VK_SAMPLER_ADDRESS_MODE_MIRROR_CLAMP_TO_EDGE',
    'VkFlags64',
    'VkPipelineCacheCreateFlagBits',
    'VkPipelineColorBlendStateCreateFlagBits',
    'VkPipelineDepthStencilStateCreateFlagBits',
    'VkPipelineLayoutCreateFlagBits',
}

# These are APIs which contain _RESERVED_ intentionally; any other enum name
# containing _RESERVED_ in a supported extension is reported as an error.
EXTENSION_NAME_RESERVED_EXCEPTIONS = {
    'VK_STRUCTURE_TYPE_PRIVATE_VENDOR_INFO_RESERVED_OFFSET_0_NV'
}

# Exceptions to pointer parameter naming rules
# Keyed by (entity name, type, name). Only key membership is tested; the
# values are unused.
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
    ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display') : None,
}

# Exceptions to pNext member requiring an optional attribute.
# Violations in these structures are reported as warnings, not errors.
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
    'VkVideoEncodeInfoKHR',
    'VkVideoEncodeRateControlLayerInfoKHR',
)

# Exceptions to VK_INCOMPLETE being required for, and only applicable to, array
# enumeration functions. Violations by these commands are reported as warnings.
CHECK_ARRAY_ENUMERATION_RETURN_CODE_EXCEPTIONS = (
    'vkGetDeviceFaultInfoEXT',
    'vkEnumerateDeviceLayerProperties',
)

# Exceptions to unknown structure type constants.
# This is most likely an error in this script, not the XML.
# It does not understand Vulkan SC (alternate 'api') types.
CHECK_TYPE_STYPE_EXCEPTIONS = (
    'VK_STRUCTURE_TYPE_PERFORMANCE_QUERY_RESERVATION_INFO_KHR',
    'VK_STRUCTURE_TYPE_PIPELINE_POOL_SIZE',
    'VK_STRUCTURE_TYPE_FAULT_DATA',
    'VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_SC_1_0_FEATURES',
    'VK_STRUCTURE_TYPE_DEVICE_OBJECT_RESERVATION_CREATE_INFO',
    'VK_STRUCTURE_TYPE_PIPELINE_OFFLINE_CREATE_INFO',
    'VK_STRUCTURE_TYPE_FAULT_CALLBACK_INFO',
    'VK_STRUCTURE_TYPE_COMMAND_POOL_MEMORY_RESERVATION_CREATE_INFO',
    'VK_STRUCTURE_TYPE_DEVICE_SEMAPHORE_SCI_SYNC_POOL_RESERVATION_CREATE_INFO_NV',
    'VK_STRUCTURE_TYPE_COMMAND_POOL_MEMORY_CONSUMPTION',
    'VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_SC_1_0_PROPERTIES',
)
113
def get_extension_commands(reg):
    """Return the set of command names required by any extension.

    reg - registry object; each entry of ``reg.extensions`` is searched
          for ``./require/command`` elements carrying a ``name`` attribute."""
    return {
        command.get('name')
        for extension in reg.extensions
        for command in extension.findall('./require/command[@name]')
    }
120
121
def get_enum_value_names(reg, enum_type):
    """Return the set of ``name`` attributes of the <enum> children of a group.

    reg       - registry object whose ``groupdict`` maps group names to
                objects with an ``elem`` element
    enum_type - name of the enumerated type (group) to read, e.g. 'VkResult'"""
    group_elem = reg.groupdict[enum_type].elem
    return {enum_elem.get('name') for enum_elem in group_elem.findall('./enum[@name]')}
128
129
# Regular expression matching an extension name ending in a (possible) version
# number. Intended for use with fullmatch(): 'base' must end in a letter and
# 'version' is the trailing run of digits, e.g. 'VK_KHR_maintenance1' gives
# base 'VK_KHR_maintenance' and version '1'.
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')

DESTROY_PREFIX = 'vkDestroy'
# Name of the enumerated type holding structure type (sType) values
TYPEENUM = 'VkStructureType'

# Root of the specification tree: the parent of the directory containing this
# script. __file__ is resolved first because it may be a relative path (for
# example on Python < 3.9 when run as 'python3 xml_consistency.py' from the
# scripts directory), in which case .parent.parent would stop at '.' instead
# of reaching the specification root.
SPECIFICATION_DIR = Path(__file__).resolve().parent.parent

# Matches a revision-history line in an extension appendix, such as
# '  * Revision 3, 2019-01-01 (Author)', capturing the revision number as 'num'.
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')
138
139
def get_extension_source(extname):
    """Return, as a string, the path of the specification appendix for an extension.

    The file is located at appendices/<extname>.adoc beneath SPECIFICATION_DIR."""
    return str(SPECIFICATION_DIR.joinpath('appendices', f'{extname}.adoc'))
143
144
class EntityDatabase(OrigEntityDatabase):
    """Entity database used by the XML consistency checker.

    Differs from the base class in that no 'support=' values are excluded,
    and, when lxml is importable, the registry is parsed from the first file
    given on the command line (or xml/vk.xml beneath SPECIFICATION_DIR)."""

    def __init__(self, args):
        """Retain command-line arguments for later use in makeRegistry"""
        # Stored before the base-class constructor runs.
        # NOTE(review): presumably because that constructor may call
        # makeRegistry(), which reads self.args - confirm.
        self.args = args
        super().__init__()

    # Override base class method to not exclude 'disabled' extensions
    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Called only during construction. Always empty here, so that nothing
        (including 'disabled' extensions) is excluded."""
        return set()

    def makeRegistry(self):
        """Create and load the Registry to check.

        Without lxml, defer to the base-class implementation. Otherwise parse
        args.files[0] (or xml/vk.xml beneath SPECIFICATION_DIR when no file was
        given) with lxml and load it into a new Registry."""
        try:
            import lxml.etree as etree
        except ImportError:
            etree = None
        if etree is None:
            return super().makeRegistry()

        files = self.args.files
        registryFile = files[0] if len(files) > 0 else str(SPECIFICATION_DIR / 'xml/vk.xml')

        registry = Registry()
        registry.filename = registryFile
        registry.loadElementTree(etree.parse(registryFile))
        return registry
179
180
181class Checker(XMLChecker):
182    def __init__(self, args):
183        manual_types_to_codes = {
184            # These are hard-coded "manual" return codes:
185            # the codes of the value (string, list, or tuple)
186            # are available for a command if-and-only-if
187            # the key type is passed as an input.
188            'VkFormat': 'VK_ERROR_FORMAT_NOT_SUPPORTED'
189        }
190        forward_only = {
191            # Like the above, but these are only valid in the
192            # "type implies return code" direction
193        }
194        reverse_only = {
195            # like the above, but these are only valid in the
196            # "return code implies type or its descendant" direction
197            # "XrDuration": "XR_TIMEOUT_EXPIRED"
198        }
199        # Some return codes are related in that only one of a set
200        # may be returned by a command
201        # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING)
202        self.exclusive_return_code_sets = tuple(
203            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
204        )
205
206        # This is used to report collisions.
207        conventions = APIConventions()
208        db = EntityDatabase(args)
209
210        self.extension_cmds = get_extension_commands(db.registry)
211        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
212        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)
213
214        # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
215        # Keys are entity names, values are tuples or lists of message text to suppress.
216        suppressions = {}
217
218        # Structures explicitly allowed to have 'limittype' attributes
219        self.allowedStructs = set((
220            'VkFormatProperties',
221            'VkFormatProperties2',
222            'VkPhysicalDeviceProperties',
223            'VkPhysicalDeviceProperties2',
224            'VkPhysicalDeviceLimits',
225            'VkQueueFamilyProperties',
226            'VkQueueFamilyProperties2',
227            'VkSparseImageFormatProperties',
228            'VkSparseImageFormatProperties2',
229        ))
230
231        # Substructures of allowed structures. This can be found by looking
232        # at tags, but there are so few cases that it is hardwired for now.
233        self.nestedStructs = set((
234            'VkPhysicalDeviceLimits',
235            'VkPhysicalDeviceSparseProperties',
236            'VkPhysicalDeviceProperties',
237            'VkQueueFamilyProperties',
238            'VkSparseImageFormatProperties',
239        ))
240
241        # Structures all of whose (non pNext/sType) members are required to
242        # have 'limittype' attributes, as are their descendants
243        self.requiredStructs = set((
244            'VkPhysicalDeviceProperties',
245            'VkPhysicalDeviceProperties2',
246            'VkPhysicalDeviceLimits',
247            'VkSparseImageFormatProperties',
248            'VkSparseImageFormatProperties2',
249        ))
250
251        # Structures which have already have their limittype attributes validated
252        self.validatedLimittype = set()
253
254        # Initialize superclass
255        super().__init__(entity_db=db, conventions=conventions,
256                         manual_types_to_codes=manual_types_to_codes,
257                         forward_only_types_to_codes=forward_only,
258                         reverse_only_types_to_codes=reverse_only,
259                         suppressions=suppressions,
260                         display_warnings=args.warn)
261
262    def check(self):
263        """Extends base class behavior with additional checks"""
264
265        # This test is not run on a per-structure basis, but loops over
266        # specific structures
267        self.check_type_required_limittype()
268
269        super().check()
270
271    def check_command(self, name, info):
272        """Extends base class behavior with additional checks"""
273
274        if name[0:5] == 'vkCmd':
275            if info.elem.get('tasks') is None:
276                self.record_error(f'{name} is a vkCmd* command, but is missing a "tasks" attribute')
277
278        super().check_command(name, info)
279
280    def check_command_return_codes_basic(self, name, info,
281                                         successcodes, errorcodes):
282        """Check a command's return codes for consistency.
283
284        Called on every command."""
285        # Check that all extension commands can return the code associated
286        # with trying to use an extension that was not enabled.
287        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
288        #     self.record_error('Missing expected return code',
289        #                       UNSUPPORTED,
290        #                       'implied due to being an extension command')
291
292        codes = successcodes.union(errorcodes)
293
294        # Check that all return codes are recognized.
295        unrecognized = codes - self.return_codes
296        if unrecognized:
297            self.record_error('Unrecognized return code(s):',
298                              unrecognized)
299
300        elem = info.elem
301        params = [(getElemName(elt), elt) for elt in elem.findall('param')]
302
303        def is_count_output(name, elt):
304            # Must end with Count or Size,
305            # not be const,
306            # and be a pointer (detected by naming convention)
307            return (name.endswith('Count') or name.endswith('Size')) \
308                and (elt.tail is None or 'const' not in elt.tail) \
309                and (name.startswith('p'))
310
311        countParams = [elt
312                       for name, elt in params
313                       if is_count_output(name, elt)]
314        if countParams:
315            assert(len(countParams) == 1)
316            if 'VK_INCOMPLETE' not in successcodes:
317                message = "Apparent enumeration of an array without VK_INCOMPLETE in successcodes for command {}.".format(name)
318                if name in CHECK_ARRAY_ENUMERATION_RETURN_CODE_EXCEPTIONS:
319                    self.record_warning('(Allowed exception)', message)
320                else:
321                    self.record_error(message)
322
323        elif 'VK_INCOMPLETE' in successcodes:
324            message = "VK_INCOMPLETE in successcodes of command {} that is apparently not an array enumeration.".format(name)
325            if name in CHECK_ARRAY_ENUMERATION_RETURN_CODE_EXCEPTIONS:
326                self.record_warning('(Allowed exception)', message)
327            else:
328                self.record_error(message)
329
330    def check_param(self, param):
331        """Check a member of a struct or a param of a function.
332
333        Called from check_params."""
334        super().check_param(param)
335
336        if not self.is_api_type(param):
337            return
338
339        param_text = ''.join(param.itertext())
340        param_name = getElemName(param)
341
342        # Make sure the number of leading 'p' matches the pointer count.
343        pointercount = param.find('type').tail
344        if pointercount:
345            pointercount = pointercount.count('*')
346        if pointercount:
347            prefix = 'p' * pointercount
348            if not param_name.startswith(prefix):
349                param_type = param.find('type').text
350                message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
351                    param_text, prefix)
352                if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
353                    self.record_warning('(Allowed exception)', message, elem=param)
354                else:
355                    self.record_error(message, elem=param)
356
357        # Make sure no members have optional="false" attributes
358        optional = param.get('optional')
359        if optional == 'false':
360            message = f'{self.entity}.{param_name} member has disallowed \'optional="false"\' attribute (remove this attribute)'
361            self.record_error(message, elem=param)
362
363        # Make sure pNext members have optional="true" attributes
364        if param_name == self.conventions.nextpointer_member_name:
365            optional = param.get('optional')
366            if optional is None or optional != 'true':
367                message = f'{self.entity}.pNext member is missing \'optional="true"\' attribute'
368                if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
369                    self.record_warning('(Allowed exception)', message, elem=param)
370                else:
371                    self.record_error(message, elem=param)
372
373    def check_type_stype(self, name, info, type_elts):
374        """Check a struct type's sType member"""
375        if len(type_elts) > 1:
376            self.record_error(
377                'Have more than one member of type', TYPEENUM)
378        else:
379            type_elt = type_elts[0]
380            val = type_elt.get('values')
381            if val and val not in self.structure_types:
382                message = f'{self.entity} has unknown structure type constant {val}'
383                if val in CHECK_TYPE_STYPE_EXCEPTIONS:
384                    self.record_warning('(Allowed exception)', message)
385                else:
386                    self.record_error(message)
387
388    def check_type_pnext(self, name, info):
389        """Check a struct type's pNext member, if present"""
390
391        next_name = self.conventions.nextpointer_member_name
392        next_member = findNamedElem(info.elem.findall('member'), next_name)
393        if next_member is not None:
394            # Ensure that the 'optional' attribute is set to 'true'
395            optional = next_member.get('optional')
396            if optional is None or optional != 'true':
397                message = f'{name}.{next_name} member is missing \'optional="true"\' attribute'
398                if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
399                    self.record_warning('(Allowed exception)', message)
400                else:
401                    self.record_error(message)
402
403    def __isLimittypeStruct(self, name, info, allowedStructs):
404        """Check if a type element is a structure allowed to have 'limittype' attributes
405           name - name of a structure
406           info - corresponding TypeInfo object
407           allowedStructs - set of struct names explicitly allowed"""
408
409        # Is this an explicitly allowed struct?
410        if name in allowedStructs:
411            return True
412
413        # Is this a struct extending an explicitly allowed struct?
414        extends = info.elem.get('structextends')
415        if extends is not None:
416            # See if any name in the structextends attribute is an allowed
417            # struct
418            if len(set(extends.split(',')) & allowedStructs) > 0:
419                return True
420
421        return False
422
423    def __validateStructLimittypes(self, name, info, requiredLimittype):
424        """Validate 'limittype' attributes for a single struct.
425           info - TypeInfo for a struct <type>
426           requiredLimittype - True if members *must* have a limittype"""
427
428        # Do not re-check structures
429        if name in self.validatedLimittype:
430            return {}
431        self.validatedLimittype.add(name)
432
433        limittypeDiags = namedtuple('limittypeDiags', ['missing', 'invalid'])
434        badFields = defaultdict(lambda : limittypeDiags(missing=[], invalid=[]))
435        validLimittypes = { 'min', 'max', 'pot', 'mul', 'bits', 'bitmask', 'range', 'struct', 'exact', 'noauto' }
436        for member in info.getMembers():
437            memberName = member.findtext('name')
438            if memberName in ['sType', 'pNext']:
439                continue
440            limittype = member.get('limittype')
441            if limittype is None:
442                # Do not tag this as missing if it is not required
443                if requiredLimittype:
444                    badFields[info.elem.get('name')].missing.append(memberName)
445            elif limittype == 'struct':
446                typeName = member.findtext('type')
447                memberType = self.reg.typedict[typeName]
448                badFields.update(self.__validateStructLimittypes(typeName, memberType, requiredLimittype))
449            else:
450                for value in limittype.split(','):
451                    if value not in validLimittypes:
452                        badFields[info.elem.get('name')].invalid.append(memberName)
453
454        return badFields
455
456    def check_type_disallowed_limittype(self, name, info):
457        """Check if a struct type's members cannot have the 'limittype' attribute"""
458
459        # If not allowed to have limittypes, verify this for each member
460        if not self.__isLimittypeStruct(name, info, self.allowedStructs.union(self.nestedStructs)):
461            for member in info.getMembers():
462                if member.get('limittype') is not None:
463                    memname = member.findtext('name')
464                    self.record_error(f'{name} member {memname} has disallowed limittype attribute')
465
466    def check_type_optional_value(self, name, info):
467        """Check if a struct type's members have disallowed 'optional' attribute values"""
468
469        for member in info.getMembers():
470            # Make sure no members have optional="false" attributes
471            optional = member.get('optional')
472            if optional == 'false':
473                memname = member.findtext('name')
474                message = f'{name} member {memname} has disallowed \'optional="false"\' attribute (remove this attribute)'
475                self.record_error(message, elem=member)
476
477    def check_type_required_limittype(self):
478        """Check struct type members which must have the 'limittype' attribute
479
480        Called from check."""
481
482        for name in self.allowedStructs:
483            # Assume that only extending structs of structs explicitly
484            # requiring limittypes also require them
485            requiredLimittype = (name in self.requiredStructs)
486
487            info = self.reg.typedict[name]
488
489            self.set_error_context(entity=name, elem=info.elem)
490
491            badFields = self.__validateStructLimittypes(name, info, requiredLimittype)
492            for extendingStructName in self.reg.validextensionstructs[name]:
493                extendingStruct = self.reg.typedict[extendingStructName]
494                badFields.update(self.__validateStructLimittypes(extendingStructName, extendingStruct, requiredLimittype))
495
496            if badFields:
497                for key in sorted(badFields.keys()):
498                    diags = badFields[key]
499                    if diags.missing:
500                        self.record_error(f'{name} missing limittype for members {", ".join(badFields[key].missing)}')
501                    if diags.invalid:
502                        self.record_error(f'{name} has invalid limittype for members {", ".join(badFields[key].invalid)}')
503
504    def check_type(self, name, info, category):
505        """Check a type's XML data for consistency.
506
507        Called from check."""
508
509        if category == 'struct':
510            type_elts = [elt
511                         for elt in info.elem.findall('member')
512                         if getElemType(elt) == TYPEENUM]
513
514            if type_elts:
515                self.check_type_stype(name, info, type_elts)
516                self.check_type_pnext(name, info)
517
518            # Check for disallowed limittypes on all structures
519            self.check_type_disallowed_limittype(name, info)
520
521            # Check for disallowed 'optional' values
522            self.check_type_optional_value(name, info)
523        elif category == 'bitmask':
524            if 'Flags' in name:
525                expected_require = name.replace('Flags', 'FlagBits')
526                require = info.elem.get('require')
527                if require is not None and expected_require != require:
528                    self.record_error('Unexpected require attribute value:',
529                                      'got', require,
530                                      'but expected', expected_require)
531        super().check_type(name, info, category)
532
533    def check_suffixes(self, name, info, supported, name_exceptions):
534        """Check suffixes of new APIs required by an extension, which should
535           match the author ID of the extension.
536
537           Called from check_extension.
538
539           name - extension name
540           info - extdict entry for name
541           supported - True if extension supported by API being checked
542           name_exceptions - set of API names to not check, in addition to
543                             the global exception list."""
544
545        def has_suffix(apiname, author):
546            return apiname[-len(author):] == author
547
548        def has_any_suffixes(apiname, authors):
549            for author in authors:
550                if has_suffix(apiname, author):
551                    return True
552            return False
553
554        def check_names(elems, author, alt_authors, name_exceptions):
555            """Check names in a list of elements for consistency
556
557               elems - list of elements to check
558               author - author ID of the <extension> tag
559               alt_authors - set of other allowed author IDs
560               name_exceptions - additional set of allowed exceptions"""
561
562            for elem in elems:
563                apiname = elem.get('name', 'NO NAME ATTRIBUTE')
564                suffix = apiname[-len(author):]
565
566                if (not has_suffix(apiname, author) and
567                    apiname not in EXTENSION_API_NAME_EXCEPTIONS and
568                    apiname not in name_exceptions):
569
570                    msg = f'Extension {name} <{elem.tag}> {apiname} does not have expected suffix {author}'
571
572                    # Explicit 'aliased' deprecations not matching the
573                    # naming rules are allowed, but warned.
574                    if has_any_suffixes(apiname, alt_authors):
575                        self.record_warning('Allowed alternate author ID:', msg)
576                    elif not supported:
577                        self.record_warning('Allowed inconsistency for disabled extension:', msg)
578                    elif elem.get('deprecated') == 'aliased':
579                        self.record_warning('Allowed aliasing deprecation:', msg)
580                    else:
581                        msg += '\n\
582This may be due to an extension interaction not having the correct <require depends="">\n\
583Other exceptions can be added to xml_consistency.py:EXTENSION_API_NAME_EXCEPTIONS'
584                        self.record_error(msg)
585
586        elem = info.elem
587
588        self.set_error_context(entity=name, elem=elem)
589
590        # Extract author ID from the extension name.
591        author = name.split('_')[1]
592
593        # Loop over each <require> tag checking the API name suffixes in
594        # that tag for consistency.
595        # Names in tags whose 'depends' attribute includes extensions with
596        # different author IDs may be suffixed with those IDs.
597        for req_elem in elem.findall('./require'):
598            depends = req_elem.get('depends', '')
599            alt_authors = set()
600            if len(depends) > 0:
601                for name in dependencyNames(depends):
602                    # Skip core versions
603                    if name[0:11] != 'VK_VERSION_':
604                        # Extract author ID from extension name
605                        id = name.split('_')[1]
606                        alt_authors.add(id)
607
608            check_names(req_elem.findall('./command'), author, alt_authors, name_exceptions)
609            check_names(req_elem.findall('./type'), author, alt_authors, name_exceptions)
610            check_names(req_elem.findall('./enum'), author, alt_authors, name_exceptions)
611
612    def check_extension(self, name, info, supported):
613        """Check an extension's XML data for consistency.
614
615        Called from check."""
616
617        elem = info.elem
618        enums = elem.findall('./require/enum[@name]')
619
620        # If extension name is not on the exception list and matches the
621        # versioned-extension pattern, map the extension name to the version
622        # name with the version as a separate word. Otherwise just map it to
623        # the upper-case version of the extension name.
624
625        matches = EXTNAME_RE.fullmatch(name)
626        ext_versioned_name = False
627        if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
628            ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE.get(name)
629        elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
630            # This is the usual case, either a name that does not look
631            # versioned, or one that does but is on the exception list.
632            ext_enum_name = name.upper()
633        else:
634            # This is a versioned extension name.
635            # Treat the version number as a separate word.
636            base = matches.group('base')
637            version = matches.group('version')
638            ext_enum_name = base.upper() + '_' + version
639            # Keep track of this case
640            ext_versioned_name = True
641
642        # Look for the expected SPEC_VERSION token name
643        version_name = f'{ext_enum_name}_SPEC_VERSION'
644        version_elem = findNamedElem(enums, version_name)
645
646        if version_elem is None:
647            # Did not find a SPEC_VERSION enum matching the extension name
648            if ext_versioned_name:
649                suffix = '\n\
650    Make sure that trailing version numbers in extension names are treated\n\
651    as separate words in extension enumerant names. If this is an extension\n\
652    whose name ends in a number which is not a version, such as "...h264"\n\
653    or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n\
654    scripts/xml_consistency.py.'
655            else:
656                suffix = ''
657            self.record_error(f'Missing version enum {version_name}{suffix}')
658        elif supported:
659            # Skip unsupported / disabled extensions for these checks
660
661            fn = get_extension_source(name)
662            revisions = []
663            with open(fn, 'r', encoding='utf-8') as fp:
664                for line in fp:
665                    line = line.rstrip()
666                    match = REVISION_RE.match(line)
667                    if match:
668                        revisions.append(int(match.group('num')))
669            ver_from_xml = version_elem.get('value')
670            if revisions:
671                ver_from_text = str(max(revisions))
672                if ver_from_xml != ver_from_text:
673                    self.record_error('Version enum mismatch: spec text indicates', ver_from_text,
674                                      'but XML says', ver_from_xml)
675            else:
676                if ver_from_xml == '1':
677                    self.record_warning(
678                        "Cannot find version history in spec text - make sure it has lines starting exactly like '  * Revision 1, ....'",
679                        filename=fn)
680                else:
681                    self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
682                                        " - make sure the spec text has lines starting exactly like '  * Revision 1, ....'",
683                                        filename=fn)
684
685            for enum in enums:
686                enum_name = enum.get('name')
687                if '_RESERVED_' in enum_name and enum_name not in EXTENSION_NAME_RESERVED_EXCEPTIONS:
688                    self.record_error(enum_name, 'should not contain _RESERVED_ for a supported extension.\n\
689If this is intentional, add it to EXTENSION_NAME_RESERVED_EXCEPTIONS in scripts/xml_consistency.py.')
690
691        name_define = f'{ext_enum_name}_EXTENSION_NAME'
692        name_elem = findNamedElem(enums, name_define)
693        if name_elem is None:
694            self.record_error('Missing name enum', name_define)
695        else:
696            # Note: etree handles the XML entities here and turns &quot; back into "
697            expected_name = f'"{name}"'
698            name_val = name_elem.get('value')
699            if name_val != expected_name:
700                self.record_error('Incorrect name enum: expected', expected_name,
701                                  'got', name_val)
702
703        self.check_suffixes(name, info, supported, { version_name, name_define })
704
705        # More general checks
706        super().check_extension(name, info, supported)
707
708    def check_format(self):
709        """Check an extension's XML data for consistency.
710
711        Called from check."""
712
713        astc3d_formats = [
714                'VK_FORMAT_ASTC_3x3x3_UNORM_BLOCK_EXT',
715                'VK_FORMAT_ASTC_3x3x3_SRGB_BLOCK_EXT',
716                'VK_FORMAT_ASTC_3x3x3_SFLOAT_BLOCK_EXT',
717                'VK_FORMAT_ASTC_4x3x3_UNORM_BLOCK_EXT',
718                'VK_FORMAT_ASTC_4x3x3_SRGB_BLOCK_EXT',
719                'VK_FORMAT_ASTC_4x3x3_SFLOAT_BLOCK_EXT',
720                'VK_FORMAT_ASTC_4x4x3_UNORM_BLOCK_EXT',
721                'VK_FORMAT_ASTC_4x4x3_SRGB_BLOCK_EXT',
722                'VK_FORMAT_ASTC_4x4x3_SFLOAT_BLOCK_EXT',
723                'VK_FORMAT_ASTC_4x4x4_UNORM_BLOCK_EXT',
724                'VK_FORMAT_ASTC_4x4x4_SRGB_BLOCK_EXT',
725                'VK_FORMAT_ASTC_4x4x4_SFLOAT_BLOCK_EXT',
726                'VK_FORMAT_ASTC_5x4x4_UNORM_BLOCK_EXT',
727                'VK_FORMAT_ASTC_5x4x4_SRGB_BLOCK_EXT',
728                'VK_FORMAT_ASTC_5x4x4_SFLOAT_BLOCK_EXT',
729                'VK_FORMAT_ASTC_5x5x4_UNORM_BLOCK_EXT',
730                'VK_FORMAT_ASTC_5x5x4_SRGB_BLOCK_EXT',
731                'VK_FORMAT_ASTC_5x5x4_SFLOAT_BLOCK_EXT',
732                'VK_FORMAT_ASTC_5x5x5_UNORM_BLOCK_EXT',
733                'VK_FORMAT_ASTC_5x5x5_SRGB_BLOCK_EXT',
734                'VK_FORMAT_ASTC_5x5x5_SFLOAT_BLOCK_EXT',
735                'VK_FORMAT_ASTC_6x5x5_UNORM_BLOCK_EXT',
736                'VK_FORMAT_ASTC_6x5x5_SRGB_BLOCK_EXT',
737                'VK_FORMAT_ASTC_6x5x5_SFLOAT_BLOCK_EXT',
738                'VK_FORMAT_ASTC_6x6x5_UNORM_BLOCK_EXT',
739                'VK_FORMAT_ASTC_6x6x5_SRGB_BLOCK_EXT',
740                'VK_FORMAT_ASTC_6x6x5_SFLOAT_BLOCK_EXT',
741                'VK_FORMAT_ASTC_6x6x6_UNORM_BLOCK_EXT',
742                'VK_FORMAT_ASTC_6x6x6_SRGB_BLOCK_EXT',
743                'VK_FORMAT_ASTC_6x6x6_SFLOAT_BLOCK_EXT'
744        ]
745
746        # Need to build list of formats from rest of <enums>
747        enum_formats = []
748        for enum in self.reg.groupdict['VkFormat'].elem:
749            if enum.get('alias') is None and enum.get('name') != 'VK_FORMAT_UNDEFINED':
750                enum_formats.append(enum.get('name'))
751
752        found_formats = []
753        for name, info in self.reg.formatsdict.items():
754            found_formats.append(name)
755            self.set_error_context(entity=name, elem=info.elem)
756
757            if name not in enum_formats:
758                self.record_error('The <format> has no matching <enum> for', name)
759
760            # Check never just 1 plane
761            plane_elems = info.elem.findall('plane')
762            if len(plane_elems) == 1:
763                self.record_error('The <format> has only 1 <plane> for', name)
764
765            valid_chroma = ['420', '422', '444']
766            if info.elem.get('chroma') and info.elem.get('chroma') not in valid_chroma:
767                self.record_error('The <format> has chroma is not a valid value for', name)
768
769            # The formatsgenerator.py assumes only 1 <spirvimageformat> tag.
770            # If this changes in the future, remove this warning and update generator script
771            spirv_image_format = info.elem.findall('spirvimageformat')
772            if len(spirv_image_format) > 1:
773                self.record_error('More than 1 <spirvimageformat> but formatsgenerator.py is not updated, for format', name)
774
775        # Re-loop to check the other way if the <format> is missing
776        for enum in self.reg.groupdict['VkFormat'].elem:
777            name = enum.get('name')
778            if enum.get('alias') is None and name != 'VK_FORMAT_UNDEFINED':
779                if name not in found_formats and name not in astc3d_formats:
780                    self.set_error_context(entity=name, elem=enum)
781                    self.record_error('The <enum> has no matching <format> for ', name)
782
783        super().check_format()
784
785        # This should be called from check() but as a first pass, do it here
786        # Check for invalid version names in e.g.
787        #    <enable version="VK_VERSION_1_2"/>
788        # Could also consistency check struct / extension tags here
789        for capname in self.reg.spirvcapdict:
790            for elem in self.reg.spirvcapdict[capname].elem.findall('enable'):
791                version = elem.get('version')
792                if version is not None and version not in self.reg.apidict:
793                    self.set_error_context(entity=capname, elem=elem)
794                    self.record_error(f'<spirvcapability> {capname} enabled by a nonexistent version {version}')
795
if __name__ == '__main__':

    # Command-line interface: an optional -warn flag followed by zero or
    # more XML registry filenames.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-warn',
        action='store_true',
        help='Enable display of warning messages',
    )
    parser.add_argument(
        'files',
        metavar='filename',
        nargs='*',
        help='XML filename to check',
    )
    args = parser.parse_args()

    # Build the checker from the parsed options and run all checks.
    ckr = Checker(args)
    ckr.check()

    # Report failure through the process exit status when ckr.fail is set.
    if ckr.fail:
        sys.exit(1)
811