1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_OPT_CONSTANTS_H_
16 #define SOURCE_OPT_CONSTANTS_H_
17 
#include <cassert>
#include <cinttypes>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/module.h"
#include "source/opt/type_manager.h"
#include "source/opt/types.h"
#include "source/util/hex_float.h"
#include "source/util/make_unique.h"
31 
32 namespace spvtools {
33 namespace opt {
34 
35 class IRContext;
36 
37 namespace analysis {
38 
// Class hierarchy to represent the normal constants defined through
// OpConstantTrue, OpConstantFalse, OpConstant, OpConstantNull and
// OpConstantComposite instructions.
// TODO(qining): Add class for constants defined with OpConstantSampler.
//
// The classes are forward-declared here so that the base class |Constant| can
// declare its As<Subclass>() reflection methods before the subclasses are
// defined. |ConstantManager| is forward-declared because
// Constant::GetVectorComponents takes a pointer to it.
class Constant;
class ScalarConstant;
class IntConstant;
class FloatConstant;
class BoolConstant;
class CompositeConstant;
class StructConstant;
class VectorConstant;
class MatrixConstant;
class ArrayConstant;
class NullConstant;
class ConstantManager;
55 
56 // Abstract class for a SPIR-V constant. It has a bunch of As<subclass> methods,
57 // which is used as a way to probe the actual <subclass>
58 class Constant {
59  public:
60   Constant() = delete;
~Constant()61   virtual ~Constant() {}
62 
63   // Make a deep copy of this constant.
64   virtual std::unique_ptr<Constant> Copy() const = 0;
65 
66   // reflections
AsScalarConstant()67   virtual ScalarConstant* AsScalarConstant() { return nullptr; }
AsIntConstant()68   virtual IntConstant* AsIntConstant() { return nullptr; }
AsFloatConstant()69   virtual FloatConstant* AsFloatConstant() { return nullptr; }
AsBoolConstant()70   virtual BoolConstant* AsBoolConstant() { return nullptr; }
AsCompositeConstant()71   virtual CompositeConstant* AsCompositeConstant() { return nullptr; }
AsStructConstant()72   virtual StructConstant* AsStructConstant() { return nullptr; }
AsVectorConstant()73   virtual VectorConstant* AsVectorConstant() { return nullptr; }
AsMatrixConstant()74   virtual MatrixConstant* AsMatrixConstant() { return nullptr; }
AsArrayConstant()75   virtual ArrayConstant* AsArrayConstant() { return nullptr; }
AsNullConstant()76   virtual NullConstant* AsNullConstant() { return nullptr; }
77 
AsScalarConstant()78   virtual const ScalarConstant* AsScalarConstant() const { return nullptr; }
AsIntConstant()79   virtual const IntConstant* AsIntConstant() const { return nullptr; }
AsFloatConstant()80   virtual const FloatConstant* AsFloatConstant() const { return nullptr; }
AsBoolConstant()81   virtual const BoolConstant* AsBoolConstant() const { return nullptr; }
AsCompositeConstant()82   virtual const CompositeConstant* AsCompositeConstant() const {
83     return nullptr;
84   }
AsStructConstant()85   virtual const StructConstant* AsStructConstant() const { return nullptr; }
AsVectorConstant()86   virtual const VectorConstant* AsVectorConstant() const { return nullptr; }
AsMatrixConstant()87   virtual const MatrixConstant* AsMatrixConstant() const { return nullptr; }
AsArrayConstant()88   virtual const ArrayConstant* AsArrayConstant() const { return nullptr; }
AsNullConstant()89   virtual const NullConstant* AsNullConstant() const { return nullptr; }
90 
91   // Returns the float representation of the constant. Must be a 32 bit
92   // Float type.
93   float GetFloat() const;
94 
95   // Returns the double representation of the constant. Must be a 64 bit
96   // Float type.
97   double GetDouble() const;
98 
99   // Returns the double representation of the constant. Must be a 32-bit or
100   // 64-bit Float type.
101   double GetValueAsDouble() const;
102 
103   // Returns uint32_t representation of the constant. Must be a 32 bit
104   // Integer type.
105   uint32_t GetU32() const;
106 
107   // Returns uint64_t representation of the constant. Must be a 64 bit
108   // Integer type.
109   uint64_t GetU64() const;
110 
111   // Returns int32_t representation of the constant. Must be a 32 bit
112   // Integer type.
113   int32_t GetS32() const;
114 
115   // Returns int64_t representation of the constant. Must be a 64 bit
116   // Integer type.
117   int64_t GetS64() const;
118 
119   // Returns true if the constant is a zero or a composite containing 0s.
IsZero()120   virtual bool IsZero() const { return false; }
121 
type()122   const Type* type() const { return type_; }
123 
124   // Returns an std::vector containing the elements of |constant|.  The type of
125   // |constant| must be |Vector|.
126   std::vector<const Constant*> GetVectorComponents(
127       ConstantManager* const_mgr) const;
128 
129  protected:
Constant(const Type * ty)130   Constant(const Type* ty) : type_(ty) {}
131 
132   // The type of this constant.
133   const Type* type_;
134 };
135 
136 // Abstract class for scalar type constants.
137 class ScalarConstant : public Constant {
138  public:
139   ScalarConstant() = delete;
AsScalarConstant()140   ScalarConstant* AsScalarConstant() override { return this; }
AsScalarConstant()141   const ScalarConstant* AsScalarConstant() const override { return this; }
142 
143   // Returns a const reference of the value of this constant in 32-bit words.
words()144   virtual const std::vector<uint32_t>& words() const { return words_; }
145 
146   // Returns true if the value is zero.
IsZero()147   bool IsZero() const override {
148     bool is_zero = true;
149     for (uint32_t v : words()) {
150       if (v != 0) {
151         is_zero = false;
152         break;
153       }
154     }
155     return is_zero;
156   }
157 
158  protected:
ScalarConstant(const Type * ty,const std::vector<uint32_t> & w)159   ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
160       : Constant(ty), words_(w) {}
ScalarConstant(const Type * ty,std::vector<uint32_t> && w)161   ScalarConstant(const Type* ty, std::vector<uint32_t>&& w)
162       : Constant(ty), words_(std::move(w)) {}
163   std::vector<uint32_t> words_;
164 };
165 
166 // Integer type constant.
167 class IntConstant : public ScalarConstant {
168  public:
IntConstant(const Integer * ty,const std::vector<uint32_t> & w)169   IntConstant(const Integer* ty, const std::vector<uint32_t>& w)
170       : ScalarConstant(ty, w) {}
IntConstant(const Integer * ty,std::vector<uint32_t> && w)171   IntConstant(const Integer* ty, std::vector<uint32_t>&& w)
172       : ScalarConstant(ty, std::move(w)) {}
173 
AsIntConstant()174   IntConstant* AsIntConstant() override { return this; }
AsIntConstant()175   const IntConstant* AsIntConstant() const override { return this; }
176 
GetS32BitValue()177   int32_t GetS32BitValue() const {
178     // Relies on signed values smaller than 32-bit being sign extended.  See
179     // section 2.2.1 of the SPIR-V spec.
180     assert(words().size() == 1);
181     return words()[0];
182   }
183 
GetU32BitValue()184   uint32_t GetU32BitValue() const {
185     // Relies on unsigned values smaller than 32-bit being zero extended.  See
186     // section 2.2.1 of the SPIR-V spec.
187     assert(words().size() == 1);
188     return words()[0];
189   }
190 
GetS64BitValue()191   int64_t GetS64BitValue() const {
192     // Relies on unsigned values smaller than 64-bit being sign extended.  See
193     // section 2.2.1 of the SPIR-V spec.
194     assert(words().size() == 2);
195     return static_cast<uint64_t>(words()[1]) << 32 |
196            static_cast<uint64_t>(words()[0]);
197   }
198 
GetU64BitValue()199   uint64_t GetU64BitValue() const {
200     // Relies on unsigned values smaller than 64-bit being zero extended.  See
201     // section 2.2.1 of the SPIR-V spec.
202     assert(words().size() == 2);
203     return static_cast<uint64_t>(words()[1]) << 32 |
204            static_cast<uint64_t>(words()[0]);
205   }
206 
207   // Make a copy of this IntConstant instance.
CopyIntConstant()208   std::unique_ptr<IntConstant> CopyIntConstant() const {
209     return MakeUnique<IntConstant>(type_->AsInteger(), words_);
210   }
Copy()211   std::unique_ptr<Constant> Copy() const override {
212     return std::unique_ptr<Constant>(CopyIntConstant().release());
213   }
214 };
215 
216 // Float type constant.
217 class FloatConstant : public ScalarConstant {
218  public:
FloatConstant(const Float * ty,const std::vector<uint32_t> & w)219   FloatConstant(const Float* ty, const std::vector<uint32_t>& w)
220       : ScalarConstant(ty, w) {}
FloatConstant(const Float * ty,std::vector<uint32_t> && w)221   FloatConstant(const Float* ty, std::vector<uint32_t>&& w)
222       : ScalarConstant(ty, std::move(w)) {}
223 
AsFloatConstant()224   FloatConstant* AsFloatConstant() override { return this; }
AsFloatConstant()225   const FloatConstant* AsFloatConstant() const override { return this; }
226 
227   // Make a copy of this FloatConstant instance.
CopyFloatConstant()228   std::unique_ptr<FloatConstant> CopyFloatConstant() const {
229     return MakeUnique<FloatConstant>(type_->AsFloat(), words_);
230   }
Copy()231   std::unique_ptr<Constant> Copy() const override {
232     return std::unique_ptr<Constant>(CopyFloatConstant().release());
233   }
234 
235   // Returns the float value of |this|.  The type of |this| must be |Float| with
236   // width of 32.
GetFloatValue()237   float GetFloatValue() const {
238     assert(type()->AsFloat()->width() == 32 &&
239            "Not a 32-bit floating point value.");
240     utils::FloatProxy<float> a(words()[0]);
241     return a.getAsFloat();
242   }
243 
244   // Returns the double value of |this|.  The type of |this| must be |Float|
245   // with width of 64.
GetDoubleValue()246   double GetDoubleValue() const {
247     assert(type()->AsFloat()->width() == 64 &&
248            "Not a 32-bit floating point value.");
249     uint64_t combined_words = words()[1];
250     combined_words = combined_words << 32;
251     combined_words |= words()[0];
252     utils::FloatProxy<double> a(combined_words);
253     return a.getAsFloat();
254   }
255 };
256 
257 // Bool type constant.
258 class BoolConstant : public ScalarConstant {
259  public:
BoolConstant(const Bool * ty,bool v)260   BoolConstant(const Bool* ty, bool v)
261       : ScalarConstant(ty, {static_cast<uint32_t>(v)}), value_(v) {}
262 
AsBoolConstant()263   BoolConstant* AsBoolConstant() override { return this; }
AsBoolConstant()264   const BoolConstant* AsBoolConstant() const override { return this; }
265 
266   // Make a copy of this BoolConstant instance.
CopyBoolConstant()267   std::unique_ptr<BoolConstant> CopyBoolConstant() const {
268     return MakeUnique<BoolConstant>(type_->AsBool(), value_);
269   }
Copy()270   std::unique_ptr<Constant> Copy() const override {
271     return std::unique_ptr<Constant>(CopyBoolConstant().release());
272   }
273 
value()274   bool value() const { return value_; }
275 
276  private:
277   bool value_;
278 };
279 
280 // Abstract class for composite constants.
281 class CompositeConstant : public Constant {
282  public:
283   CompositeConstant() = delete;
AsCompositeConstant()284   CompositeConstant* AsCompositeConstant() override { return this; }
AsCompositeConstant()285   const CompositeConstant* AsCompositeConstant() const override { return this; }
286 
287   // Returns a const reference of the components held in this composite
288   // constant.
GetComponents()289   virtual const std::vector<const Constant*>& GetComponents() const {
290     return components_;
291   }
292 
IsZero()293   bool IsZero() const override {
294     for (const Constant* c : GetComponents()) {
295       if (!c->IsZero()) {
296         return false;
297       }
298     }
299     return true;
300   }
301 
302  protected:
CompositeConstant(const Type * ty)303   CompositeConstant(const Type* ty) : Constant(ty), components_() {}
CompositeConstant(const Type * ty,const std::vector<const Constant * > & components)304   CompositeConstant(const Type* ty,
305                     const std::vector<const Constant*>& components)
306       : Constant(ty), components_(components) {}
CompositeConstant(const Type * ty,std::vector<const Constant * > && components)307   CompositeConstant(const Type* ty, std::vector<const Constant*>&& components)
308       : Constant(ty), components_(std::move(components)) {}
309   std::vector<const Constant*> components_;
310 };
311 
312 // Struct type constant.
313 class StructConstant : public CompositeConstant {
314  public:
StructConstant(const Struct * ty)315   StructConstant(const Struct* ty) : CompositeConstant(ty) {}
StructConstant(const Struct * ty,const std::vector<const Constant * > & components)316   StructConstant(const Struct* ty,
317                  const std::vector<const Constant*>& components)
318       : CompositeConstant(ty, components) {}
StructConstant(const Struct * ty,std::vector<const Constant * > && components)319   StructConstant(const Struct* ty, std::vector<const Constant*>&& components)
320       : CompositeConstant(ty, std::move(components)) {}
321 
AsStructConstant()322   StructConstant* AsStructConstant() override { return this; }
AsStructConstant()323   const StructConstant* AsStructConstant() const override { return this; }
324 
325   // Make a copy of this StructConstant instance.
CopyStructConstant()326   std::unique_ptr<StructConstant> CopyStructConstant() const {
327     return MakeUnique<StructConstant>(type_->AsStruct(), components_);
328   }
Copy()329   std::unique_ptr<Constant> Copy() const override {
330     return std::unique_ptr<Constant>(CopyStructConstant().release());
331   }
332 };
333 
334 // Vector type constant.
335 class VectorConstant : public CompositeConstant {
336  public:
VectorConstant(const Vector * ty)337   VectorConstant(const Vector* ty)
338       : CompositeConstant(ty), component_type_(ty->element_type()) {}
VectorConstant(const Vector * ty,const std::vector<const Constant * > & components)339   VectorConstant(const Vector* ty,
340                  const std::vector<const Constant*>& components)
341       : CompositeConstant(ty, components),
342         component_type_(ty->element_type()) {}
VectorConstant(const Vector * ty,std::vector<const Constant * > && components)343   VectorConstant(const Vector* ty, std::vector<const Constant*>&& components)
344       : CompositeConstant(ty, std::move(components)),
345         component_type_(ty->element_type()) {}
346 
AsVectorConstant()347   VectorConstant* AsVectorConstant() override { return this; }
AsVectorConstant()348   const VectorConstant* AsVectorConstant() const override { return this; }
349 
350   // Make a copy of this VectorConstant instance.
CopyVectorConstant()351   std::unique_ptr<VectorConstant> CopyVectorConstant() const {
352     auto another = MakeUnique<VectorConstant>(type_->AsVector());
353     another->components_.insert(another->components_.end(), components_.begin(),
354                                 components_.end());
355     return another;
356   }
Copy()357   std::unique_ptr<Constant> Copy() const override {
358     return std::unique_ptr<Constant>(CopyVectorConstant().release());
359   }
360 
component_type()361   const Type* component_type() const { return component_type_; }
362 
363  private:
364   const Type* component_type_;
365 };
366 
367 // Matrix type constant.
368 class MatrixConstant : public CompositeConstant {
369  public:
MatrixConstant(const Matrix * ty)370   MatrixConstant(const Matrix* ty)
371       : CompositeConstant(ty), component_type_(ty->element_type()) {}
MatrixConstant(const Matrix * ty,const std::vector<const Constant * > & components)372   MatrixConstant(const Matrix* ty,
373                  const std::vector<const Constant*>& components)
374       : CompositeConstant(ty, components),
375         component_type_(ty->element_type()) {}
MatrixConstant(const Vector * ty,std::vector<const Constant * > && components)376   MatrixConstant(const Vector* ty, std::vector<const Constant*>&& components)
377       : CompositeConstant(ty, std::move(components)),
378         component_type_(ty->element_type()) {}
379 
AsMatrixConstant()380   MatrixConstant* AsMatrixConstant() override { return this; }
AsMatrixConstant()381   const MatrixConstant* AsMatrixConstant() const override { return this; }
382 
383   // Make a copy of this MatrixConstant instance.
CopyMatrixConstant()384   std::unique_ptr<MatrixConstant> CopyMatrixConstant() const {
385     auto another = MakeUnique<MatrixConstant>(type_->AsMatrix());
386     another->components_.insert(another->components_.end(), components_.begin(),
387                                 components_.end());
388     return another;
389   }
Copy()390   std::unique_ptr<Constant> Copy() const override {
391     return std::unique_ptr<Constant>(CopyMatrixConstant().release());
392   }
393 
component_type()394   const Type* component_type() { return component_type_; }
395 
396  private:
397   const Type* component_type_;
398 };
399 
400 // Array type constant.
401 class ArrayConstant : public CompositeConstant {
402  public:
ArrayConstant(const Array * ty)403   ArrayConstant(const Array* ty) : CompositeConstant(ty) {}
ArrayConstant(const Array * ty,const std::vector<const Constant * > & components)404   ArrayConstant(const Array* ty, const std::vector<const Constant*>& components)
405       : CompositeConstant(ty, components) {}
ArrayConstant(const Array * ty,std::vector<const Constant * > && components)406   ArrayConstant(const Array* ty, std::vector<const Constant*>&& components)
407       : CompositeConstant(ty, std::move(components)) {}
408 
AsArrayConstant()409   ArrayConstant* AsArrayConstant() override { return this; }
AsArrayConstant()410   const ArrayConstant* AsArrayConstant() const override { return this; }
411 
412   // Make a copy of this ArrayConstant instance.
CopyArrayConstant()413   std::unique_ptr<ArrayConstant> CopyArrayConstant() const {
414     return MakeUnique<ArrayConstant>(type_->AsArray(), components_);
415   }
Copy()416   std::unique_ptr<Constant> Copy() const override {
417     return std::unique_ptr<Constant>(CopyArrayConstant().release());
418   }
419 };
420 
421 // Null type constant.
422 class NullConstant : public Constant {
423  public:
NullConstant(const Type * ty)424   NullConstant(const Type* ty) : Constant(ty) {}
AsNullConstant()425   NullConstant* AsNullConstant() override { return this; }
AsNullConstant()426   const NullConstant* AsNullConstant() const override { return this; }
427 
428   // Make a copy of this NullConstant instance.
CopyNullConstant()429   std::unique_ptr<NullConstant> CopyNullConstant() const {
430     return MakeUnique<NullConstant>(type_);
431   }
Copy()432   std::unique_ptr<Constant> Copy() const override {
433     return std::unique_ptr<Constant>(CopyNullConstant().release());
434   }
IsZero()435   bool IsZero() const override { return true; };
436 };
437 
438 // Hash function for Constant instances. Use the structure of the constant as
439 // the key.
440 struct ConstantHash {
add_pointerConstantHash441   void add_pointer(std::u32string* h, const void* p) const {
442     uint64_t ptr_val = reinterpret_cast<uint64_t>(p);
443     h->push_back(static_cast<uint32_t>(ptr_val >> 32));
444     h->push_back(static_cast<uint32_t>(ptr_val));
445   }
446 
operatorConstantHash447   size_t operator()(const Constant* const_val) const {
448     std::u32string h;
449     add_pointer(&h, const_val->type());
450     if (const auto scalar = const_val->AsScalarConstant()) {
451       for (const auto& w : scalar->words()) {
452         h.push_back(w);
453       }
454     } else if (const auto composite = const_val->AsCompositeConstant()) {
455       for (const auto& c : composite->GetComponents()) {
456         add_pointer(&h, c);
457       }
458     } else if (const_val->AsNullConstant()) {
459       h.push_back(0);
460     } else {
461       assert(
462           false &&
463           "Tried to compute the hash value of an invalid Constant instance.");
464     }
465 
466     return std::hash<std::u32string>()(h);
467   }
468 };
469 
470 // Equality comparison structure for two constants.
471 struct ConstantEqual {
operatorConstantEqual472   bool operator()(const Constant* c1, const Constant* c2) const {
473     if (c1->type() != c2->type()) {
474       return false;
475     }
476 
477     if (const auto& s1 = c1->AsScalarConstant()) {
478       const auto& s2 = c2->AsScalarConstant();
479       return s2 && s1->words() == s2->words();
480     } else if (const auto& composite1 = c1->AsCompositeConstant()) {
481       const auto& composite2 = c2->AsCompositeConstant();
482       return composite2 &&
483              composite1->GetComponents() == composite2->GetComponents();
484     } else if (c1->AsNullConstant()) {
485       return c2->AsNullConstant() != nullptr;
486     } else {
487       assert(false && "Tried to compare two invalid Constant instances.");
488     }
489     return false;
490   }
491 };
492 
493 // This class represents a pool of constants.
494 class ConstantManager {
495  public:
496   ConstantManager(IRContext* ctx);
497 
context()498   IRContext* context() const { return ctx_; }
499 
500   // Gets or creates a unique Constant instance of type |type| and a vector of
501   // constant defining words |words|. If a Constant instance existed already in
502   // the constant pool, it returns a pointer to it.  Otherwise, it creates one
503   // using CreateConstant. If a new Constant instance cannot be created, it
504   // returns nullptr.
505   const Constant* GetConstant(
506       const Type* type, const std::vector<uint32_t>& literal_words_or_ids);
507 
508   template <class C>
GetConstant(const Type * type,const C & literal_words_or_ids)509   const Constant* GetConstant(const Type* type, const C& literal_words_or_ids) {
510     return GetConstant(type, std::vector<uint32_t>(literal_words_or_ids.begin(),
511                                                    literal_words_or_ids.end()));
512   }
513 
514   // Gets or creates a Constant instance to hold the constant value of the given
515   // instruction. It returns a pointer to a Constant instance or nullptr if it
516   // could not create the constant.
517   const Constant* GetConstantFromInst(Instruction* inst);
518 
519   // Gets or creates a constant defining instruction for the given Constant |c|.
520   // If |c| had already been defined, it returns a pointer to the existing
521   // declaration. Otherwise, it calls BuildInstructionAndAddToModule. If the
522   // optional |pos| is given, it will insert any newly created instructions at
523   // the given instruction iterator position. Otherwise, it inserts the new
524   // instruction at the end of the current module's types section.
525   //
526   // |type_id| is an optional argument for disambiguating equivalent types. If
527   // |type_id| is specified, it is used as the type of the constant when a new
528   // instruction is created. Otherwise the type of the constant is derived by
529   // getting an id from the type manager for |c|.
530   //
531   // When |type_id| is not zero, the type of |c| must be the type returned by
532   // type manager when given |type_id|.
533   Instruction* GetDefiningInstruction(const Constant* c, uint32_t type_id = 0,
534                                       Module::inst_iterator* pos = nullptr);
535 
536   // Creates a constant defining instruction for the given Constant instance
537   // and inserts the instruction at the position specified by the given
538   // instruction iterator. Returns a pointer to the created instruction if
539   // succeeded, otherwise returns a null pointer. The instruction iterator
540   // points to the same instruction before and after the insertion. This is the
541   // only method that actually manages id creation/assignment and instruction
542   // creation/insertion for a new Constant instance.
543   //
544   // |type_id| is an optional argument for disambiguating equivalent types. If
545   // |type_id| is specified, it is used as the type of the constant. Otherwise
546   // the type of the constant is derived by getting an id from the type manager
547   // for |c|.
548   Instruction* BuildInstructionAndAddToModule(const Constant* c,
549                                               Module::inst_iterator* pos,
550                                               uint32_t type_id = 0);
551 
552   // A helper function to get the result type of the given instruction. Returns
553   // nullptr if the instruction does not have a type id (type id is 0).
554   Type* GetType(const Instruction* inst) const;
555 
556   // A helper function to get the collected normal constant with the given id.
557   // Returns the pointer to the Constant instance in case it is found.
558   // Otherwise, it returns a null pointer.
FindDeclaredConstant(uint32_t id)559   const Constant* FindDeclaredConstant(uint32_t id) const {
560     auto iter = id_to_const_val_.find(id);
561     return (iter != id_to_const_val_.end()) ? iter->second : nullptr;
562   }
563 
564   // A helper function to get the id of a collected constant with the pointer
565   // to the Constant instance. Returns 0 in case the constant is not found.
566   uint32_t FindDeclaredConstant(const Constant* c, uint32_t type_id) const;
567 
568   // Returns the canonical constant that has the same structure and value as the
569   // given Constant |cst|. If none is found, it returns nullptr.
570   //
571   // TODO: Should be able to give a type id to disambiguate types with the same
572   // structure.
FindConstant(const Constant * c)573   const Constant* FindConstant(const Constant* c) const {
574     auto it = const_pool_.find(c);
575     return (it != const_pool_.end()) ? *it : nullptr;
576   }
577 
578   // Registers a new constant |cst| in the constant pool. If the constant
579   // existed already, it returns a pointer to the previously existing Constant
580   // in the pool. Otherwise, it returns |cst|.
RegisterConstant(std::unique_ptr<Constant> cst)581   const Constant* RegisterConstant(std::unique_ptr<Constant> cst) {
582     auto ret = const_pool_.insert(cst.get());
583     if (ret.second) {
584       owned_constants_.emplace_back(std::move(cst));
585     }
586     return *ret.first;
587   }
588 
589   // A helper function to get a vector of Constant instances with the specified
590   // ids. If it can not find the Constant instance for any one of the ids,
591   // it returns an empty vector.
592   std::vector<const Constant*> GetConstantsFromIds(
593       const std::vector<uint32_t>& ids) const;
594 
595   // Returns a vector of constants representing each in operand. If an operand
596   // is not constant its entry is nullptr.
597   std::vector<const Constant*> GetOperandConstants(Instruction* inst) const;
598 
599   // Records a mapping between |inst| and the constant value generated by it.
600   // It returns true if a new Constant was successfully mapped, false if |inst|
601   // generates no constant values.
MapInst(Instruction * inst)602   bool MapInst(Instruction* inst) {
603     if (auto cst = GetConstantFromInst(inst)) {
604       MapConstantToInst(cst, inst);
605       return true;
606     }
607     return false;
608   }
609 
RemoveId(uint32_t id)610   void RemoveId(uint32_t id) {
611     auto it = id_to_const_val_.find(id);
612     if (it != id_to_const_val_.end()) {
613       const_val_to_id_.erase(it->second);
614       id_to_const_val_.erase(it);
615     }
616   }
617 
618   // Records a new mapping between |inst| and |const_value|. This updates the
619   // two mappings |id_to_const_val_| and |const_val_to_id_|.
MapConstantToInst(const Constant * const_value,Instruction * inst)620   void MapConstantToInst(const Constant* const_value, Instruction* inst) {
621     if (id_to_const_val_.insert({inst->result_id(), const_value}).second) {
622       const_val_to_id_.insert({const_value, inst->result_id()});
623     }
624   }
625 
626  private:
627   // Creates a Constant instance with the given type and a vector of constant
628   // defining words. Returns a unique pointer to the created Constant instance
629   // if the Constant instance can be created successfully. To create scalar
630   // type constants, the vector should contain the constant value in 32 bit
631   // words and the given type must be of type Bool, Integer or Float. To create
632   // composite type constants, the vector should contain the component ids, and
633   // those component ids should have been recorded before as Normal Constants.
634   // And the given type must be of type Struct, Vector or Array. When creating
635   // VectorType Constant instance, the components must be scalars of the same
636   // type, either Bool, Integer or Float. If any of the rules above failed, the
637   // creation will fail and nullptr will be returned. If the vector is empty,
638   // a NullConstant instance will be created with the given type.
639   std::unique_ptr<Constant> CreateConstant(
640       const Type* type,
641       const std::vector<uint32_t>& literal_words_or_ids) const;
642 
643   // Creates an instruction with the given result id to declare a constant
644   // represented by the given Constant instance. Returns an unique pointer to
645   // the created instruction if the instruction can be created successfully.
646   // Otherwise, returns a null pointer.
647   //
648   // |type_id| is an optional argument for disambiguating equivalent types. If
649   // |type_id| is specified, it is used as the type of the constant. Otherwise
650   // the type of the constant is derived by getting an id from the type manager
651   // for |c|.
652   std::unique_ptr<Instruction> CreateInstruction(uint32_t result_id,
653                                                  const Constant* c,
654                                                  uint32_t type_id = 0) const;
655 
656   // Creates an OpConstantComposite instruction with the given result id and
657   // the CompositeConst instance which represents a composite constant. Returns
658   // an unique pointer to the created instruction if succeeded. Otherwise
659   // returns a null pointer.
660   //
661   // |type_id| is an optional argument for disambiguating equivalent types. If
662   // |type_id| is specified, it is used as the type of the constant. Otherwise
663   // the type of the constant is derived by getting an id from the type manager
664   // for |c|.
665   std::unique_ptr<Instruction> CreateCompositeInstruction(
666       uint32_t result_id, const CompositeConstant* cc,
667       uint32_t type_id = 0) const;
668 
669   // IR context that owns this constant manager.
670   IRContext* ctx_;
671 
672   // A mapping from the result ids of Normal Constants to their
673   // Constant instances. All Normal Constants in the module, either
674   // existing ones before optimization or the newly generated ones, should have
675   // their Constant instance stored and their result id registered in this map.
676   std::unordered_map<uint32_t, const Constant*> id_to_const_val_;
677 
678   // A mapping from the Constant instance of Normal Constants to their
679   // result id in the module. This is a mirror map of |id_to_const_val_|. All
680   // Normal Constants that defining instructions in the module should have
681   // their Constant and their result id registered here.
682   std::multimap<const Constant*, uint32_t> const_val_to_id_;
683 
684   // The constant pool.  All created constants are registered here.
685   std::unordered_set<const Constant*, ConstantHash, ConstantEqual> const_pool_;
686 
687   // The constant that are owned by the constant manager.  Every constant in
688   // |const_pool_| should be in |owned_constants_| as well.
689   std::vector<std::unique_ptr<Constant>> owned_constants_;
690 };
691 
692 }  // namespace analysis
693 }  // namespace opt
694 }  // namespace spvtools
695 
696 #endif  // SOURCE_OPT_CONSTANTS_H_
697