1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MODULE_H
18 #define MODULE_H
19 
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <vector>

#include "core_defs.h"
#include "entity.h"
#include "instructions.h"
#include "stl_util.h"
#include "types_generated.h"
#include "visitor.h"
30 
31 namespace android {
32 namespace spirit {
33 
34 class Builder;
35 class AnnotationSection;
36 class CapabilityInst;
37 class DebugInfoSection;
38 class ExtensionInst;
39 class ExtInstImportInst;
40 class EntryPointInst;
41 class ExecutionModeInst;
42 class EntryPointDefinition;
43 class FunctionDeclaration;
44 class FunctionDefinition;
45 class GlobalSection;
46 class InputWordStream;
47 class Instruction;
48 class MemoryModelInst;
49 
// The SPIR-V version word from the module header, viewable as a whole word,
// as raw bytes, or as named major/minor bytes.
// The field order of mMajorMinor assumes a little-endian host: byte 0
// (mLowZero) is the least significant byte of mWord. Per the SPIR-V physical
// layout, the word is 0 | major | minor | 0 (high to low), so mMajorNumber
// corresponds to bits 16-23 and mMinorNumber to bits 8-15.
// NOTE: reading a member other than the one last written is type punning;
// GCC and Clang support it, but ISO C++ leaves it undefined.
union VersionNumber {
  struct {
    uint8_t mLowZero;     // expected to be 0
    uint8_t mMinorNumber;
    uint8_t mMajorNumber;
    uint8_t mHighZero;    // expected to be 0
  } mMajorMinor;
  uint8_t mBytes[4];
  uint32_t mWord;
};

// The union is read and written as a single header word, so its size must
// match exactly.
static_assert(sizeof(VersionNumber) == sizeof(uint32_t),
              "VersionNumber must overlay exactly one 32-bit word");
60 
61 class Module : public Entity {
62 public:
63   static Module *getCurrentModule();
nextId()64   uint32_t nextId() { return mNextId++; }
65 
66   Module();
67 
68   Module(Builder *b);
69 
~Module()70   virtual ~Module() {}
71 
72   bool DeserializeInternal(InputWordStream &IS) override;
73 
74   void Serialize(OutputWordStream &OS) const override;
75 
76   void SerializeHeader(OutputWordStream &OS) const;
77 
registerId(uint32_t id,Instruction * inst)78   void registerId(uint32_t id, Instruction *inst) {
79     mIdTable.insert(std::make_pair(id, inst));
80   }
81 
82   void initialize();
83 
84   bool resolveIds();
85 
accept(IVisitor * v)86   void accept(IVisitor *v) override {
87     for (auto cap : mCapabilities) {
88       v->visit(cap);
89     }
90     for (auto ext : mExtensions) {
91       v->visit(ext);
92     }
93     for (auto imp : mExtInstImports) {
94       v->visit(imp);
95     }
96 
97     v->visit(mMemoryModel.get());
98 
99     for (auto entry : mEntryPoints) {
100       v->visit(entry);
101     }
102 
103     for (auto mode : mExecutionModes) {
104       v->visit(mode);
105     }
106 
107     v->visit(mDebugInfo.get());
108     if (mAnnotations) {
109       v->visit(mAnnotations.get());
110     }
111     if (mGlobals) {
112       v->visit(mGlobals.get());
113     }
114 
115     for (auto def : mFunctionDefinitions) {
116       v->visit(def);
117     }
118   }
119 
errs()120   static std::ostream &errs() { return std::cerr; }
121 
122   Module *addCapability(Capability cap);
123   Module *setMemoryModel(AddressingModel am, MemoryModel mm);
124   Module *addExtInstImport(const char *extName);
125   Module *addSource(SourceLanguage lang, int version);
126   Module *addSourceExtension(const char *ext);
127   Module *addString(const char *ext);
128   Module *addEntryPoint(EntryPointDefinition *entry);
129 
getGLExt()130   ExtInstImportInst *getGLExt() const { return mGLExt; }
131 
132   GlobalSection *getGlobalSection();
133 
134   Instruction *lookupByName(const char *) const;
135   FunctionDefinition *
136   getFunctionDefinitionFromInstruction(FunctionInst *) const;
137   FunctionDefinition *lookupFunctionDefinitionByName(const char *name) const;
138 
139   // Find the name of the instruction, e.g., the name of a function (OpFunction
140   // instruction).
141   // The returned string is owned by the OpName instruction, whose first operand
142   // is the instruction being queried on.
143   const char *lookupNameByInstruction(const Instruction *) const;
144 
145   VariableInst *getInvocationId();
146   VariableInst *getNumWorkgroups();
147 
148   // Adds a struct type built somewhere else.
149   Module *addStructType(TypeStructInst *structType);
150   Module *addVariable(VariableInst *var);
151 
152   // Methods to look up types. Create them if not found.
153   TypeVoidInst *getVoidType();
154   TypeIntInst *getIntType(int bits, bool isSigned = true);
155   TypeIntInst *getUnsignedIntType(int bits);
156   TypeFloatInst *getFloatType(int bits);
157   TypeVectorInst *getVectorType(Instruction *componentType, int width);
158   TypePointerInst *getPointerType(StorageClass storage,
159                                   Instruction *pointeeType);
160   TypeRuntimeArrayInst *getRuntimeArrayType(Instruction *elementType);
161 
162   // This implies that struct types are strictly structural equivalent, i.e.,
163   // two structs are equivalent i.f.f. their fields are equivalent, recursively.
164   TypeStructInst *getStructType(Instruction *fieldType[], int numField);
165   TypeStructInst *getStructType(const std::vector<Instruction *> &fieldType);
166   TypeStructInst *getStructType(Instruction *field0Type);
167   TypeStructInst *getStructType(Instruction *field0Type,
168                                 Instruction *field1Type);
169   TypeStructInst *getStructType(Instruction *field0Type,
170                                 Instruction *field1Type,
171                                 Instruction *field2Type);
172 
173   // TODO: Can function types of different decorations be considered the same?
174   TypeFunctionInst *getFunctionType(Instruction *retType,
175                                     Instruction *const argType[],
176                                     size_t numArg);
177   TypeFunctionInst *getFunctionType(Instruction *retType,
178                                     const std::vector<Instruction *> &argTypes);
179 
180   size_t getSize(TypeVoidInst *voidTy);
181   size_t getSize(TypeIntInst *intTy);
182   size_t getSize(TypeFloatInst *fpTy);
183   size_t getSize(TypeVectorInst *vTy);
184   size_t getSize(TypePointerInst *ptrTy);
185   size_t getSize(TypeStructInst *structTy);
186   size_t getSize(TypeFunctionInst *funcTy);
187   size_t getSize(Instruction *inst);
188 
189   ConstantInst *getConstant(TypeIntInst *type, int32_t value);
190   ConstantInst *getConstant(TypeIntInst *type, uint32_t value);
191   ConstantInst *getConstant(TypeFloatInst *type, float value);
192 
193   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
194                                               ConstantInst *components[],
195                                               size_t width);
196   ConstantCompositeInst *
197   getConstantComposite(Instruction *type,
198                        const std::vector<ConstantInst *> &components);
199   ConstantCompositeInst *getConstantComposite(Instruction *type,
200                                               ConstantInst *comp0,
201                                               ConstantInst *comp1);
202   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
203                                               ConstantInst *comp0,
204                                               ConstantInst *comp1,
205                                               ConstantInst *comp2);
206   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
207                                               ConstantInst *comp0,
208                                               ConstantInst *comp1,
209                                               ConstantInst *comp2,
210                                               ConstantInst *comp3);
211 
212   Module *addFunctionDefinition(FunctionDefinition *func);
213 
214   void consolidateAnnotations();
215 
216 private:
217   static Module *mInstance;
218   uint32_t mNextId;
219   std::map<uint32_t, Instruction *> mIdTable;
220 
221   uint32_t mMagicNumber;
222   VersionNumber mVersion;
223   uint32_t mGeneratorMagicNumber;
224   uint32_t mBound;
225   uint32_t mReserved;
226 
227   std::vector<CapabilityInst *> mCapabilities;
228   std::vector<ExtensionInst *> mExtensions;
229   std::vector<ExtInstImportInst *> mExtInstImports;
230   std::unique_ptr<MemoryModelInst> mMemoryModel;
231   std::vector<EntryPointInst *> mEntryPointInsts;
232   std::vector<ExecutionModeInst *> mExecutionModes;
233   std::vector<EntryPointDefinition *> mEntryPoints;
234   std::unique_ptr<DebugInfoSection> mDebugInfo;
235   std::unique_ptr<AnnotationSection> mAnnotations;
236   std::unique_ptr<GlobalSection> mGlobals;
237   std::vector<FunctionDefinition *> mFunctionDefinitions;
238 
239   ExtInstImportInst *mGLExt;
240 
241   ContainerDeleter<std::vector<CapabilityInst *>> mCapabilitiesDeleter;
242   ContainerDeleter<std::vector<ExtensionInst *>> mExtensionsDeleter;
243   ContainerDeleter<std::vector<ExtInstImportInst *>> mExtInstImportsDeleter;
244   ContainerDeleter<std::vector<EntryPointInst *>> mEntryPointInstsDeleter;
245   ContainerDeleter<std::vector<ExecutionModeInst *>> mExecutionModesDeleter;
246   ContainerDeleter<std::vector<EntryPointDefinition *>> mEntryPointsDeleter;
247   ContainerDeleter<std::vector<FunctionDefinition *>>
248       mFunctionDefinitionsDeleter;
249 };
250 
// A width/height/depth triple. EntryPointDefinition stores one as mLocalSize
// (presumably filled in by EntryPointDefinition::setLocalSize).
// The members have no default initializers, so a default-initialized
// Extent3D holds indeterminate values.
struct Extent3D {
  uint32_t mWidth;
  uint32_t mHeight;
  uint32_t mDepth;
};
256 
257 class EntryPointDefinition : public Entity {
258 public:
EntryPointDefinition()259   EntryPointDefinition() {}
260   EntryPointDefinition(Builder *builder, ExecutionModel execModel,
261                        FunctionDefinition *func, const char *name);
262 
~EntryPointDefinition()263   virtual ~EntryPointDefinition() {
264     // Nothing to do here since ~Module() will delete entities referenced here
265   }
266 
accept(IVisitor * visitor)267   void accept(IVisitor *visitor) override {
268     visitor->visit(mEntryPointInst);
269     // Do not visit the ExecutionMode instructions here. They are linked here
270     // for convinience, and for convinience only. They are all grouped, stored,
271     // and serialized directly in the module in a section right after all
272     // EntryPoint instructions. Visit them from there.
273   }
274 
275   bool DeserializeInternal(InputWordStream &IS) override;
276 
277   EntryPointDefinition *addToInterface(VariableInst *var);
addExecutionMode(ExecutionModeInst * mode)278   EntryPointDefinition *addExecutionMode(ExecutionModeInst *mode) {
279     mExecutionModeInsts.push_back(mode);
280     return this;
281   }
getExecutionModes()282   const std::vector<ExecutionModeInst *> &getExecutionModes() const {
283     return mExecutionModeInsts;
284   }
285 
286   EntryPointDefinition *setLocalSize(uint32_t width, uint32_t height,
287                                      uint32_t depth);
288 
289   EntryPointDefinition *applyExecutionMode(ExecutionModeInst *mode);
290 
getInstruction()291   EntryPointInst *getInstruction() const { return mEntryPointInst; }
292 
293 private:
294   const char *mName;
295   FunctionInst *mFunction;
296   ExecutionModel mExecutionModel;
297   std::vector<VariableInst *> mInterface;
298   Extent3D mLocalSize;
299 
300   EntryPointInst *mEntryPointInst;
301   std::vector<ExecutionModeInst *> mExecutionModeInsts;
302 };
303 
304 class DebugInfoSection : public Entity {
305 public:
DebugInfoSection()306   DebugInfoSection() : mSourcesDeleter(mSources), mNamesDeleter(mNames) {}
DebugInfoSection(Builder * b)307   DebugInfoSection(Builder *b)
308       : Entity(b), mSourcesDeleter(mSources), mNamesDeleter(mNames) {}
309 
~DebugInfoSection()310   virtual ~DebugInfoSection() {}
311 
312   bool DeserializeInternal(InputWordStream &IS) override;
313 
314   DebugInfoSection *addSource(SourceLanguage lang, int version);
315   DebugInfoSection *addSourceExtension(const char *ext);
316   DebugInfoSection *addString(const char *str);
317 
318   Instruction *lookupByName(const char *name) const;
319   const char *lookupNameByInstruction(const Instruction *) const;
320 
accept(IVisitor * v)321   void accept(IVisitor *v) override {
322     for (auto source : mSources) {
323       v->visit(source);
324     }
325     for (auto name : mNames) {
326       v->visit(name);
327     }
328   }
329 
330 private:
331   // (OpString|OpSource|OpSourceExtension|OpSourceContinued)*
332   std::vector<Instruction *> mSources;
333   // (OpName|OpMemberName)*
334   std::vector<Instruction *> mNames;
335 
336   ContainerDeleter<std::vector<Instruction *>> mSourcesDeleter;
337   ContainerDeleter<std::vector<Instruction *>> mNamesDeleter;
338 };
339 
340 class AnnotationSection : public Entity {
341 public:
342   AnnotationSection();
343   AnnotationSection(Builder *b);
344 
~AnnotationSection()345   virtual ~AnnotationSection() {}
346 
347   bool DeserializeInternal(InputWordStream &IS) override;
348 
accept(IVisitor * v)349   void accept(IVisitor *v) override {
350     for (auto inst : mAnnotations) {
351       v->visit(inst);
352     }
353   }
354 
addAnnotations(T begin,T end)355   template <typename T> void addAnnotations(T begin, T end) {
356     mAnnotations.insert<T>(std::end(mAnnotations), begin, end);
357   }
358 
begin()359   std::vector<Instruction *>::const_iterator begin() const {
360     return mAnnotations.begin();
361   }
362 
end()363   std::vector<Instruction *>::const_iterator end() const {
364     return mAnnotations.end();
365   }
366 
clear()367   void clear() { mAnnotations.clear(); }
368 
369 private:
370   std::vector<Instruction *> mAnnotations; // OpDecorate, etc.
371 
372   ContainerDeleter<std::vector<Instruction *>> mAnnotationsDeleter;
373 };
374 
375 // Types, constants, and globals
376 class GlobalSection : public Entity {
377 public:
378   GlobalSection();
379   GlobalSection(Builder *builder);
380 
~GlobalSection()381   virtual ~GlobalSection() {}
382 
383   bool DeserializeInternal(InputWordStream &IS) override;
384 
accept(IVisitor * v)385   void accept(IVisitor *v) override {
386     for (auto inst : mGlobalDefs) {
387       v->visit(inst);
388     }
389 
390     if (mInvocationId) {
391       v->visit(mInvocationId.get());
392     }
393 
394     if (mNumWorkgroups) {
395       v->visit(mNumWorkgroups.get());
396     }
397   }
398 
399   ConstantInst *getConstant(TypeIntInst *type, int32_t value);
400   ConstantInst *getConstant(TypeIntInst *type, uint32_t value);
401   ConstantInst *getConstant(TypeFloatInst *type, float value);
402   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
403                                               ConstantInst *components[],
404                                               size_t width);
405 
406   // Methods to look up types. Create them if not found.
407   TypeVoidInst *getVoidType();
408   TypeIntInst *getIntType(int bits, bool isSigned = true);
409   TypeFloatInst *getFloatType(int bits);
410   TypeVectorInst *getVectorType(Instruction *componentType, int width);
411   TypePointerInst *getPointerType(StorageClass storage,
412                                   Instruction *pointeeType);
413   TypeRuntimeArrayInst *getRuntimeArrayType(Instruction *elementType);
414 
415   // This implies that struct types are strictly structural equivalent, i.e.,
416   // two structs are equivalent i.f.f. their fields are equivalent, recursively.
417   TypeStructInst *getStructType(Instruction *fieldType[], int numField);
418   // TypeStructInst *getStructType(const std::vector<Instruction *>
419   // &fieldTypes);
420 
421   // TODO: Can function types of different decorations be considered the same?
422   TypeFunctionInst *getFunctionType(Instruction *retType,
423                                     Instruction *const argType[],
424                                     size_t numArg);
425   // TypeStructInst *addStructType(Instruction *fieldType[], int numField);
426   GlobalSection *addStructType(TypeStructInst *structType);
427   GlobalSection *addVariable(VariableInst *var);
428 
429   VariableInst *getInvocationId();
430   VariableInst *getNumWorkgroups();
431 
432 private:
433   // TODO: Add structure to this.
434   // Separate types, constants, variables, etc.
435   std::vector<Instruction *> mGlobalDefs;
436   std::unique_ptr<VariableInst> mInvocationId;
437   std::unique_ptr<VariableInst> mNumWorkgroups;
438 
439   ContainerDeleter<std::vector<Instruction *>> mGlobalDefsDeleter;
440 };
441 
442 class FunctionDeclaration : public Entity {
443 public:
~FunctionDeclaration()444   virtual ~FunctionDeclaration() {}
445 
446   bool DeserializeInternal(InputWordStream &IS) override;
447 
accept(IVisitor * v)448   void accept(IVisitor *v) override {
449     v->visit(mFunc);
450     for (auto param : mParams) {
451       v->visit(param);
452     }
453     v->visit(mFuncEnd);
454   }
455 
456 private:
457   FunctionInst *mFunc;
458   std::vector<FunctionParameterInst *> mParams;
459   FunctionEndInst *mFuncEnd;
460 };
461 
462 class Block : public Entity {
463 public:
Block()464   Block() {}
Block(Builder * b)465   Block(Builder *b) : Entity(b) {}
466 
~Block()467   virtual ~Block() {}
468 
469   bool DeserializeInternal(InputWordStream &IS) override;
470 
accept(IVisitor * v)471   void accept(IVisitor *v) override {
472     for (auto inst : mInsts) {
473       v->visit(inst);
474     }
475   }
476 
addInstruction(Instruction * inst)477   Block *addInstruction(Instruction *inst) {
478     mInsts.push_back(inst);
479     return this;
480   }
481 
482 private:
483   std::vector<Instruction *> mInsts;
484 };
485 
486 class FunctionDefinition : public Entity {
487 public:
488   FunctionDefinition();
489   FunctionDefinition(Builder *builder, FunctionInst *func,
490                      FunctionEndInst *end);
491 
~FunctionDefinition()492   virtual ~FunctionDefinition() {}
493 
494   bool DeserializeInternal(InputWordStream &IS) override;
495 
accept(IVisitor * v)496   void accept(IVisitor *v) override {
497     v->visit(mFunc.get());
498     for (auto param : mParams) {
499       v->visit(param);
500     }
501     for (auto block : mBlocks) {
502       v->visit(block);
503     }
504     v->visit(mFuncEnd.get());
505   }
506 
addBlock(Block * b)507   FunctionDefinition *addBlock(Block *b) {
508     mBlocks.push_back(b);
509     return this;
510   }
511 
getInstruction()512   FunctionInst *getInstruction() const { return mFunc.get(); }
getParameter(uint32_t i)513   FunctionParameterInst *getParameter(uint32_t i) const { return mParams[i]; }
514 
515   Instruction *getReturnType() const;
516 
517 private:
518   std::unique_ptr<FunctionInst> mFunc;
519   std::vector<FunctionParameterInst *> mParams;
520   std::vector<Block *> mBlocks;
521   std::unique_ptr<FunctionEndInst> mFuncEnd;
522 
523   ContainerDeleter<std::vector<FunctionParameterInst *>> mParamsDeleter;
524   ContainerDeleter<std::vector<Block *>> mBlocksDeleter;
525 };
526 
527 } // namespace spirit
528 } // namespace android
529 
530 #endif // MODULE_H
531