1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
17from collections import OrderedDict
18from copy import copy
19from pathlib import Path, PurePosixPath
20
21import os
22import sys
23import shutil
24import subprocess
25
26# Class capturing a single file
27
28
class SingleFileModule(object):
    """One generated output file (e.g. a .h, .cpp, .py or .proto file).

    Usage is begin() -> append()* -> end(): begin() opens the file and writes
    `preamble`, append() writes arbitrary text, and end() writes `postamble`
    and closes the file. When `suppress` is true, all three calls are no-ops
    and no file is created.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        # Directory relative to the global output dir given to begin();
        # ignored when customAbsDir is set.
        self.directory = directory
        # File name without extension.
        self.basename = basename
        # Absolute directory overriding globalDir/directory when truthy.
        self.customAbsDir = customAbsDir
        # Extension appended to basename, including the dot (e.g. ".h").
        self.suffix = suffix
        # Open file object between begin() and end(); None before begin().
        self.file = None

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open the output file under `globalDir` (or customAbsDir) and write the preamble.

        The containing directory is created if it does not exist yet.
        """
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create subdirectory, if needed. An empty path means the current
        # directory, which always exists (and os.makedirs("") would raise).
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write `toAppend` to the open file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Write the postamble and close the file (no-op when suppressed)."""
        if self.suppress:
            return

        try:
            self.file.write(self.postamble)
        finally:
            # Close even if the final write fails so the handle is not leaked.
            self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list entry; plain files have none."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source-list entry; plain files have none."""
        return ""
75
76# Class capturing a .cpp file and a .h file (a "C++ module")
77
78
class Module(object):
    """A generated C++ module: a .h header plus a .cpp implementation file.

    Both files share `directory` and `basename`. `implOnly` suppresses the
    header, `headerOnly` suppresses the implementation, and `suppress`
    suppresses both. end() runs clang-format (with `--style=file`) on every
    file that was actually written.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        # Deliberately write-only: the effective state differs per file.
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # headerOnly/implOnly keep their file suppressed regardless of `value`.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the Makefile source-list line for the .cpp file."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the CMake source-list entry for the .cpp file (POSIX separators)."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".cpp "

    def begin(self, globalDir):
        """Open both (non-suppressed) files under `globalDir`."""
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then clang-format each file that was written.

        Raises:
            RuntimeError: if clang-format is needed but not on PATH, or if it
                exits with a non-zero status.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        writtenFiles = [
            Path(fileModule.file.name)
            for fileModule in (self._headerFileModule, self._implFileModule)
            if not fileModule.suppress
        ]
        # Nothing was written, so clang-format is not required at all.
        if not writtenFiles:
            return

        clang_format_command = shutil.which('clang-format')
        if clang_format_command is None:
            raise RuntimeError("clang-format was not found on PATH")

        for filename in writtenFiles:
            # The subprocess call must not live inside an `assert`: asserts
            # are stripped under `python -O`, which would skip formatting.
            returncode = subprocess.call(
                [clang_format_command, "-i", "--style=file", str(filename.resolve())])
            if returncode != 0:
                raise RuntimeError(
                    "clang-format failed (exit code %d) on %s" % (returncode, filename))
181
182
class PyScript(SingleFileModule):
    """A generated Python script: a SingleFileModule whose suffix is ".py"."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        pySuffix = ".py"
        super().__init__(pySuffix, directory, basename,
                         customAbsDir=customAbsDir, suppress=suppress)
186
187
188# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """A generated .proto protobuf definition file."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return the Makefile source-list line for this .proto file."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return the CMake source-list entry for this .proto file.

        The path always uses forward slashes, matching Module.getCMakeSrcEntry,
        because CMake source lists should not contain Windows backslashes.
        """
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "

        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".proto "
212
class CodeGen(object):
    """Accumulates generated C/C++ source text in `self.code`.

    Lines are emitted at the current indentation (`indentLevel` steps of four
    spaces). beginBlock()/endBlock() open and close a nesting level; each level
    also owns a counter used by var() to produce unique temporary names.
    """

    # Unsigned C integer type that holds a primitive of the given encoded
    # size (in bytes) while it is (de)serialized.
    _STORAGE_TYPE_BY_SIZE = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    # Suffix of the stream helpers (put*/get*/to*/from*) for a given encoded size.
    _STREAM_SUFFIX_BY_SIZE = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        # One name counter per open block; -1 means "no name generated yet".
        self.gensymCounter = [-1]

    @staticmethod
    def _fatal(message):
        """Print `message` and abort the generator process (never returns)."""
        print(message)
        os.abort()

    def var(self, prefix="cgen_var"):
        """Return a fresh variable name, unique within the current block nesting.

        The name is `prefix` followed by the counters of every enclosing block
        that has already produced a name, e.g. "cgen_var_1_0".
        """
        self.gensymCounter[-1] += 1
        return "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))

    def swapCode(self,):
        """Return the code accumulated so far and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the whitespace for the current indentation plus `extra` levels."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a nesting level (optionally emitting "{") with a fresh name scope."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close the innermost nesting level (optionally emitting "}")."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        """Emit `else if (cond)` when `cond` is given, otherwise a plain `else`."""
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        # No newline after the label; `blocked` decides whether "{" is emitted.
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        # Same layout as switchCase(), but for the `default:` label (which
        # takes no value to format).
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += self.indent() + "for (" + "; ".join([initial, condition, increment]) + ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Emit `for (T v = init; v < bound; ++v)` and open its block."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Emit `code` as an indented statement terminated by ";"."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit `code` as an indented line without a terminator."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit `code` at column 0 (e.g. preprocessor directives)."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit `[lhs = ]funcName(parameters...);`."""
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit `return funcName(parameters...);` (`_lhs` is ignored)."""
        self.code += self.indent() + "return " + self.makeCallExpr(funcName, parameters) + ";\n"

    @staticmethod
    def _pointerSpec(vulkanType):
        """Return the pointer suffix ("", "*", "**", "* const*", ...) for `vulkanType`."""
        levels = vulkanType.pointerIndirectionLevels
        if levels == 0:
            return ""
        if vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
            return ptrSpec
        return "*" * levels

    def _typeDecl(self, vulkanType, useParamName):
        """Shared core of makeCTypeDecl()/makeRichCTypeDecl() (no array suffix)."""
        constness = "const " if vulkanType.isConst else ""

        if useParamName and (vulkanType.paramName is not None):
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName,
                             self._pointerSpec(vulkanType), paramStr)

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        return self._typeDecl(vulkanType, useParamName)

    # Like makeCTypeDecl, but also appends the static array extent, if any:
    # [const] [typename][*][const*] [paramName][[staticArrExpr]]
    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        decl = self._typeDecl(vulkanType, useParamName)
        if vulkanType.staticArrExpr:
            decl += "[%s]" % vulkanType.staticArrExpr
        return decl

    # Given a VulkanAPI object, generate the C function prototype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        def getFuncArgDecl(param):
            decl = self.makeCTypeDecl(param, useParamName=useParamName)
            if param.staticArrExpr:
                decl += "[%s]" % param.staticArrExpr
            return decl

        separator = ",\n%s" % self.indent(1)
        protoParams = "(\n    %s)" % separator.join(
            getFuncArgDecl(param) for param in vulkanApi.parameters)

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition whose body is emitted by `codegenFunc(self)`.

        Any code already buffered is discarded, and the buffer is empty afterwards.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a function definition to the buffer; the body comes from `codegenFunc(self)`."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member `vulkanType` of `structVarName`.

        `structAsPtr` selects "->" vs "."; when `asPtr` is true and the member is
        not already pointer-accessible, its address is taken with "&".
        """
        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return (lengthExpr, guardExpr) using the raw length expression of `vulkanType`."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return (lengthExpr, guardExpr) for a struct member's length.

        guardExpr is the struct variable that must be valid before the length
        is read, or None when no length is known.
        """
        # Handle special cases first
        # Mostly when latexmath is involved
        def handleSpecialCases(structInfo, vulkanType, structVarName, asPtr):
            cases = [
                {
                    "structName": "VkShaderModuleCreateInfo",
                    "field": "pCode",
                    "lenExprMember": "codeSize",
                    "postprocess": lambda expr: "(%s / 4)" % expr
                },
                {
                    "structName": "VkPipelineMultisampleStateCreateInfo",
                    "field": "pSampleMask",
                    "lenExprMember": "rasterizationSamples",
                    "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
                },
                {
                    "structName": "VkAccelerationStructureVersionInfoKHR",
                    "field": "pVersionData",
                    "lenExprMember": "",
                    "postprocess": lambda _: "2*VK_UUID_SIZE"
                },
            ]

            for c in cases:
                if (structInfo.name, vulkanType.paramName) == (c["structName"], c["field"]):
                    deref = "->" if asPtr else "."
                    expr = "%s%s%s" % (structVarName, deref, c["lenExprMember"])
                    lenAccessGuardExpr = "%s" % structVarName
                    return c["postprocess"](expr), lenAccessGuardExpr

            return None, None

        specialCaseAccess = handleSpecialCases(structInfo, vulkanType, structVarName, asPtr)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        deref = "->" if asPtr else "."
        # deref is never empty, so the guard is always the struct variable.
        lenAccessGuardExpr = "%s" % structVarName
        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return (lengthExpr, guardExpr) for an API parameter's length.

        guardExpr names the pointer that must be non-null before the length is
        dereferenced, or is None when no dereference is involved.
        """
        # Handle special cases first
        # Mostly when :: is involved
        def handleSpecialCases(vulkanType):
            lenExpr = vulkanType.getLengthExpression()

            if lenExpr is None:
                return None, None

            if "::" in lenExpr:
                structVarName, memberVarName = lenExpr.split("::")
                lenAccessGuardExpr = "%s" % (structVarName)
                return "%s->%s" % (structVarName, memberVarName), lenAccessGuardExpr
            return None, None

        specialCaseAccess = handleSpecialCases(vulkanType)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        if lenExpr == "null-terminated":
            # paramName is a string attribute, not a method.
            return "strlen(%s)" % vulkanType.paramName, None

        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, prefixed with "&" when a pointer is wanted but it is a value."""
        if asPtr and param.pointerIndirectionLevels == 0:
            return "&%s" % param.paramName
        return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for `vulkanType` based on its parent kind.

        Raises:
            TypeError: if the parent is neither None, a struct nor an API.
        """
        if vulkanType.parent is None or isinstance(vulkanType.parent, VulkanAPI):
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        # os.abort() takes no arguments, so the message must travel in an exception.
        raise TypeError("Could not find a way to access Vulkan type %s" %
                        vulkanType.typeName)

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return (lengthExpr, guardExpr) for `vulkanType` based on its parent kind.

        Raises:
            TypeError: if the parent is neither None, a struct nor an API.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        raise TypeError("Could not find a way to access length of Vulkan type %s" %
                        vulkanType.typeName)

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
        """Emit a call to `api` and return (returnTypeName, returnVarName or None).

        A non-void result is stored in a freshly declared variable; optional
        device-lost / out-of-memory checks are emitted for VkResult returns.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            self.funcCall(
                callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
        else:
            self.funcCall(
                callLhs, customPrefix + api.name, customParameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True if `typeName` has a known primitive encoding size."""
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def _streamMethodName(self, typeInfo, typeName, prefix):
        """Return `prefix` + size suffix (e.g. "putBe32"), or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        suffix = self._STREAM_SUFFIX_BY_SIZE.get(typeInfo.getPrimitiveEncodingSize(typeName))
        return prefix + suffix if suffix else None

    def _storageTypeForSize(self, size):
        """Return the unsigned C storage type for `size` bytes; abort if unsupported."""
        storageType = self._STORAGE_TYPE_BY_SIZE.get(size)
        if storageType is None:
            self._fatal("Unsupported primitive encoding size: %s" % size)
        return storageType

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the stream put*/get* method name for a primitive, or None."""
        return self._streamMethodName(
            typeInfo, typeName, "put" if direction == "write" else "get")

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the in-place to*/from* byte-order helper name for a primitive, or None."""
        return self._streamMethodName(
            typeInfo, typeName, "to" if direction == "write" else "from")

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that writes (or reads) a primitive/pointer through `streamVar`.

        Pointers are transported as 64-bit integers. A temporary name is
        reserved via var() in both directions.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._fatal("Tried to stream a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            streamStorageVarType = "uint64_t"
            ptrCast = "(uintptr_t)"
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamStorageVarType = self._storageTypeForSize(
                typeInfo.getPrimitiveEncodingSize(accessTypeName))
            ptrCast = ""
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        streamStorageVar = self.var()

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr, accessCast, ptrCast, streamVar, streamMethod))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" % (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, variant, direction="write"):
        """Emit code that memcpy's a primitive/pointer to (or from) the buffer `streamVar`.

        The bytes are then converted in place with the matching to*/from*
        helper of gfxstream::guest::Stream (variant "guest") or
        android::base::Stream (any other variant).
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._fatal("Tried to stream a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            streamSize = 8
            streamStorageVarType = "uint64_t"
            ptrCast = "(uintptr_t)"
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._storageTypeForSize(streamSize)
            ptrCast = ""
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        streamStorageVar = self.var()

        streamNamespace = "gfxstream::guest" if variant == "guest" else "android::base"

        if direction == "read":
            # The destination is written, so cast away any constness.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast, accessExpr, streamVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)&%s)" % (
                streamNamespace, streamMethod, accessExpr))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)%s)" % (
                streamNamespace, streamMethod, streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the number of bytes a primitive/pointer occupies on the stream."""
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._fatal("Tried to count a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            return 8
        return typeInfo.getPrimitiveEncodingSize(accessTypeName)
829
830# Class to wrap a Vulkan API call.
831#
832# The user gives a generic callback, |codegenDef|,
833# that takes a CodeGen object and a VulkanAPI object as arguments.
834# codegenDef uses CodeGen along with the VulkanAPI object
835# to generate the function body.
class VulkanAPIWrapper(object):
    """Generates a customized variant of a Vulkan API (declaration and definition).

    The variant is named `customApiPrefix` + the original API name. Its
    parameters are `extraParameters` (a list, if given) followed by the
    original parameters, and its return type is `returnTypeOverride` if given.
    The function body is produced by `codegenDef(codegen, vulkanApi)`, which can
    also be supplied later through setCodegenDef().
    """

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Stored on the instance as a plain (unbound) function, so it is
        # called as `self.makeApi(self, typeInfo, apiName)`; that calling
        # convention is kept for existing callers.
        self.makeApi = VulkanAPIWrapper._makeCustomApi

    @staticmethod
    def _makeCustomApi(wrapper, typeInfo, apiName):
        """Return a shallow copy of typeInfo.apis[apiName] customized by `wrapper`.

        Raises:
            TypeError: if wrapper.extraParameters is set but is not a list.
        """
        customApi = copy(typeInfo.apis[apiName])
        customApi.name = wrapper.customApiPrefix + customApi.name

        if wrapper.extraParameters is not None:
            if not isinstance(wrapper.extraParameters, list):
                # os.abort() accepts no message, so report through an exception.
                raise TypeError(
                    "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                        wrapper.extraParameters))
            # Build a new list so the copied API's original list is not mutated.
            customApi.parameters = wrapper.extraParameters + customApi.parameters

        if wrapper.returnTypeOverride is not None:
            customApi.retType = wrapper.returnTypeOverride
        return customApi

    def setCodegenDef(self, codegenDefFunc):
        """Set the callback that emits the function body."""
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the prototype declaration of the customized API."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the full definition of the customized API.

        Exits the process with status 1 if no codegen callback has been set.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
895
# Base class for wrapping all Vulkan API objects.  These work with Vulkan
# Registry generators and expose on* hooks (onGenType, onGenCmd, ...) that
# correspond to the generator's gen* triggers.  They tend to contain
# VulkanAPIWrapper objects to make it easier to generate the code.
899class VulkanWrapperGenerator(object):
900
901    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
902        self.module: Module = module
903        self.typeInfo: VulkanTypeInfo = typeInfo
904        self.extensionStructTypes = OrderedDict()
905
906    def onBegin(self):
907        pass
908
909    def onEnd(self):
910        pass
911
912    def onBeginFeature(self, featureName, featureType):
913        pass
914
915    def onFeatureNewCmd(self, cmdName):
916        pass
917
918    def onEndFeature(self):
919        pass
920
921    def onGenType(self, typeInfo, name, alias):
922        category = self.typeInfo.categoryOf(name)
923        if category in ["struct", "union"] and not alias:
924            structInfo = self.typeInfo.structs[name]
925            if structInfo.structExtendsExpr:
926                self.extensionStructTypes[name] = structInfo
927        pass
928
929    def onGenStruct(self, typeInfo, name, alias):
930        pass
931
932    def onGenGroup(self, groupinfo, groupName, alias=None):
933        pass
934
935    def onGenEnum(self, enuminfo, name, alias):
936        pass
937
938    def onGenCmd(self, cmdinfo, name, alias):
939        pass
940
    # Some VkStructureType values below are shared by more than one Vulkan
    # struct because different Vulkan registries conflict (here, GOOGLE
    # structs reuse the FRAGMENT_DENSITY_MAP_* enum values). To pick the right
    # struct, emitForEachStructExtension() also switches on the sType of the
    # "root" struct the extension is chained to.
    #
    # Shape: {extension sType: {root sType or "default": struct type name}}.
    # NOTE(review): every struct name listed here must have been recorded in
    # extensionStructTypes, otherwise emitForEachStructExtension raises
    # KeyError when rootTypeVar is given.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }
964
    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a C++ switch that dispatches on an extension struct's type.

        The generated code:
          1. returns early (or runs |nullEmit|) if the trigger pointer is null;
          2. reads the struct type via goldfish_vk_struct_type(<trigger>);
          3. switches on it, with one case per recorded extension struct,
             wrapped in "#ifdef <feature>" and, when the enum element has a
             "protect" attribute, "#ifdef <protect>";
          4. runs |defaultEmit| (or a default return) for unknown types.

        Args:
            cgen: code generator (uses stmt/line/leftline/beginIf/endIf/
                beginBlock/endBlock).
            retType: type whose .typeName decides the default return:
                "return" for void, otherwise "return (<type>)0".
            triggerVar: parameter with .paramName (the extension struct
                pointer) and .isConst (constness of the emitted cast).
            forEachFunc: called as forEachFunc(structInfo, castedAccess, cgen)
                to emit the body of each case; castedAccess is a
                reinterpret_cast expression of the trigger to that struct.
            autoBreak: if true, emit "break" after each case body.
            defaultEmit: optional callback(cgen) for the "default:" body.
            nullEmit: optional callback(cgen) for the null-trigger body.
            rootTypeVar: optional parameter with .paramName. When given, sTypes
                listed in ROOT_TYPE_MAPPING get a nested switch on it to pick
                the concrete struct.
        """
        def readStructType(structTypeName, structVarName, cgen):
            # Emits: uint32_t <name> = (uint32_t)goldfish_vk_struct_type(<var>);
            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
                (structTypeName, "goldfish_vk_struct_type", structVarName))

        def castAsStruct(varName, typeName, const=True):
            # Returns: reinterpret_cast<[const ]<typeName>*>(<varName>)
            return "reinterpret_cast<%s%s*>(%s)" % \
                   ("const " if const else "", typeName, varName)

        def doDefaultReturn(cgen):
            # void functions return nothing; others return a zero-cast value.
            if retType.typeName == "void":
                cgen.stmt("return")
            else:
                cgen.stmt("return (%s)0" % retType.typeName)

        # Null trigger pointer: nothing to dispatch on.
        cgen.beginIf("!%s" % triggerVar.paramName)
        if nullEmit is None:
            doDefaultReturn(cgen)
        else:
            nullEmit(cgen)
        cgen.endIf()

        readStructType("structType", triggerVar.paramName, cgen)

        cgen.line("switch(structType)")
        cgen.beginBlock()

        # Feature whose "#ifdef" is currently open; consecutive structs that
        # share a feature share one guard.
        # NOTE(review): a falsy ext.feature would emit "#ifdef None" without
        # ever closing it — presumably every extension struct has a feature.
        currFeature = None

        for ext in self.extensionStructTypes.values():
            if not currFeature:
                cgen.leftline("#ifdef %s" % ext.feature)
                currFeature = ext.feature

            if currFeature and ext.feature != currFeature:
                cgen.leftline("#endif")
                cgen.leftline("#ifdef %s" % ext.feature)
                currFeature = ext.feature

            enum = ext.structEnumExpr
            protect = None
            if enum in self.typeInfo.enumElem:
                # NOTE(review): get(key, default=...) with a keyword argument
                # works on xml.etree Elements but would raise TypeError on a
                # plain dict — assumes enumElem values are Elements.
                protect = self.typeInfo.enumElem[enum].get("protect", default=None)
                if protect is not None:
                    cgen.leftline("#ifdef %s" % protect)

            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
                # Ambiguous sType: nested switch on the root struct's type to
                # select which concrete struct the pointer refers to.
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                kv = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
                for k in kv:
                    # Raises KeyError if the mapped struct was never recorded
                    # by onGenType().
                    v = self.extensionStructTypes[kv[k]]
                    if k == "default":
                        cgen.line("%s:" % k)
                    else:
                        cgen.line("case %s:" % k)
                    cgen.beginBlock()
                    castedAccess = castAsStruct(
                        triggerVar.paramName, v.name, const=triggerVar.isConst)
                    forEachFunc(v, castedAccess, cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()
            else:
                castedAccess = castAsStruct(
                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
                forEachFunc(ext, castedAccess, cgen)

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

            if protect is not None:
                cgen.leftline("#endif // %s" % protect)

        # Close the last open feature guard.
        if currFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        if defaultEmit is None:
            doDefaultReturn(cgen)
        else:
            defaultEmit(cgen)
        cgen.endBlock()

        cgen.endBlock()
1055
1056    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
1057        currFeature = None
1058
1059        for (i, ext) in enumerate(self.extensionStructTypes.values()):
1060            if doFeatureIfdefs:
1061                if not currFeature:
1062                    cgen.leftline("#ifdef %s" % ext.feature)
1063                    currFeature = ext.feature
1064
1065                if currFeature and ext.feature != currFeature:
1066                    cgen.leftline("#endif")
1067                    cgen.leftline("#ifdef %s" % ext.feature)
1068                    currFeature = ext.feature
1069
1070            forEachFunc(i, ext, cgen)
1071
1072        if doFeatureIfdefs:
1073            if currFeature:
1074                cgen.leftline("#endif")
1075