1 // Copyright 2019 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_TYPE_H_
16 #define SRC_TYPE_H_
17 
18 #include <cassert>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "src/format_data.h"
24 #include "src/make_unique.h"
25 
26 namespace amber {
27 namespace type {
28 
29 class List;
30 class Number;
31 class Struct;
32 
33 class Type {
34  public:
35   Type();
36   virtual ~Type();
37 
IsSignedInt(FormatMode mode)38   static bool IsSignedInt(FormatMode mode) {
39     return mode == FormatMode::kSInt || mode == FormatMode::kSNorm ||
40            mode == FormatMode::kSScaled;
41   }
42 
IsUnsignedInt(FormatMode mode)43   static bool IsUnsignedInt(FormatMode mode) {
44     return mode == FormatMode::kUInt || mode == FormatMode::kUNorm ||
45            mode == FormatMode::kUScaled || mode == FormatMode::kSRGB;
46   }
47 
IsInt(FormatMode mode)48   static bool IsInt(FormatMode mode) {
49     return IsSignedInt(mode) || IsUnsignedInt(mode);
50   }
51 
IsFloat(FormatMode mode)52   static bool IsFloat(FormatMode mode) {
53     return mode == FormatMode::kSFloat || mode == FormatMode::kUFloat;
54   }
55 
IsInt8(FormatMode mode,uint32_t num_bits)56   static bool IsInt8(FormatMode mode, uint32_t num_bits) {
57     return IsSignedInt(mode) && num_bits == 8;
58   }
IsInt16(FormatMode mode,uint32_t num_bits)59   static bool IsInt16(FormatMode mode, uint32_t num_bits) {
60     return IsSignedInt(mode) && num_bits == 16;
61   }
IsInt32(FormatMode mode,uint32_t num_bits)62   static bool IsInt32(FormatMode mode, uint32_t num_bits) {
63     return IsSignedInt(mode) && num_bits == 32;
64   }
IsInt64(FormatMode mode,uint32_t num_bits)65   static bool IsInt64(FormatMode mode, uint32_t num_bits) {
66     return IsSignedInt(mode) && num_bits == 64;
67   }
68 
IsUint8(FormatMode mode,uint32_t num_bits)69   static bool IsUint8(FormatMode mode, uint32_t num_bits) {
70     return IsUnsignedInt(mode) && num_bits == 8;
71   }
IsUint16(FormatMode mode,uint32_t num_bits)72   static bool IsUint16(FormatMode mode, uint32_t num_bits) {
73     return IsUnsignedInt(mode) && num_bits == 16;
74   }
IsUint32(FormatMode mode,uint32_t num_bits)75   static bool IsUint32(FormatMode mode, uint32_t num_bits) {
76     return IsUnsignedInt(mode) && num_bits == 32;
77   }
IsUint64(FormatMode mode,uint32_t num_bits)78   static bool IsUint64(FormatMode mode, uint32_t num_bits) {
79     return IsUnsignedInt(mode) && num_bits == 64;
80   }
81 
IsFloat16(FormatMode mode,uint32_t num_bits)82   static bool IsFloat16(FormatMode mode, uint32_t num_bits) {
83     return IsFloat(mode) && num_bits == 16;
84   }
IsFloat32(FormatMode mode,uint32_t num_bits)85   static bool IsFloat32(FormatMode mode, uint32_t num_bits) {
86     return IsFloat(mode) && num_bits == 32;
87   }
IsFloat64(FormatMode mode,uint32_t num_bits)88   static bool IsFloat64(FormatMode mode, uint32_t num_bits) {
89     return IsFloat(mode) && num_bits == 64;
90   }
91 
92   // Returns the size in bytes of a single element of the type. This does not
93   // include space for arrays, vectors, etc.
94   virtual uint32_t SizeInBytes() const = 0;
95 
96   virtual bool Equal(const Type* b) const = 0;
97 
IsList()98   virtual bool IsList() const { return false; }
IsNumber()99   virtual bool IsNumber() const { return false; }
IsStruct()100   virtual bool IsStruct() const { return false; }
101 
102   List* AsList();
103   Number* AsNumber();
104   Struct* AsStruct();
105 
106   const List* AsList() const;
107   const Number* AsNumber() const;
108   const Struct* AsStruct() const;
109 
SetRowCount(uint32_t size)110   void SetRowCount(uint32_t size) { row_count_ = size; }
RowCount()111   uint32_t RowCount() const { return row_count_; }
112 
SetColumnCount(uint32_t size)113   void SetColumnCount(uint32_t size) { column_count_ = size; }
ColumnCount()114   uint32_t ColumnCount() const { return column_count_; }
115 
SetIsRuntimeArray()116   void SetIsRuntimeArray() { is_array_ = true; }
SetIsSizedArray(uint32_t size)117   void SetIsSizedArray(uint32_t size) {
118     is_array_ = true;
119     array_size_ = size;
120   }
IsArray()121   bool IsArray() const { return is_array_; }
IsSizedArray()122   bool IsSizedArray() const { return is_array_ && array_size_ > 0; }
IsRuntimeArray()123   bool IsRuntimeArray() const { return is_array_ && array_size_ == 0; }
ArraySize()124   uint32_t ArraySize() const { return array_size_; }
125 
IsVec()126   bool IsVec() const { return column_count_ == 1 && row_count_ > 1; }
127 
128   // Returns true if this type holds a vec3.
IsVec3()129   bool IsVec3() const { return column_count_ == 1 && row_count_ == 3; }
130 
131   // Returns true if this type holds a matrix.
IsMatrix()132   bool IsMatrix() const { return column_count_ > 1 && row_count_ > 1; }
133 
134  private:
135   uint32_t row_count_ = 1;
136   uint32_t column_count_ = 1;
137   uint32_t array_size_ = 0;
138   bool is_array_ = false;
139 };
140 
141 class Number : public Type {
142  public:
143   explicit Number(FormatMode mode);
144   Number(FormatMode mode, uint32_t bits);
145   ~Number() override;
146 
147   static std::unique_ptr<Number> Int(uint32_t bits);
148   static std::unique_ptr<Number> Uint(uint32_t bits);
149   static std::unique_ptr<Number> Float(uint32_t bits);
150 
IsNumber()151   bool IsNumber() const override { return true; }
152 
NumBits()153   uint32_t NumBits() const { return bits_; }
SizeInBytes()154   uint32_t SizeInBytes() const override { return bits_ / 8; }
155 
Equal(const Type * b)156   bool Equal(const Type* b) const override {
157     if (!b->IsNumber())
158       return false;
159 
160     auto n = b->AsNumber();
161     return format_mode_ == n->format_mode_ && bits_ == n->bits_;
162   }
163 
GetFormatMode()164   FormatMode GetFormatMode() const { return format_mode_; }
165 
166  private:
167   FormatMode format_mode_ = FormatMode::kSInt;
168   uint32_t bits_ = 32;
169 };
170 
171 // The list type only holds lists of scalar float and int values.
172 class List : public Type {
173  public:
174   struct Member {
MemberMember175     Member(FormatComponentType t, FormatMode m, uint32_t b)
176         : name(t), mode(m), num_bits(b) {}
177 
SizeInBytesMember178     uint32_t SizeInBytes() const { return num_bits / 8; }
179 
180     FormatComponentType name = FormatComponentType::kR;
181     FormatMode mode = FormatMode::kSInt;
182     uint32_t num_bits = 0;
183   };
184 
185   List();
186   ~List() override;
187 
IsList()188   bool IsList() const override { return true; }
189 
Equal(const Type * b)190   bool Equal(const Type* b) const override {
191     if (!b->IsList())
192       return false;
193 
194     auto l = b->AsList();
195     if (pack_size_in_bits_ != l->pack_size_in_bits_)
196       return false;
197     if (members_.size() != l->members_.size())
198       return false;
199 
200     auto& lm = l->Members();
201     for (size_t i = 0; i < members_.size(); ++i) {
202       if (members_[i].name != lm[i].name)
203         return false;
204       if (members_[i].mode != lm[i].mode)
205         return false;
206       if (members_[i].num_bits != lm[i].num_bits)
207         return false;
208     }
209     return true;
210   }
211 
SetPackSizeInBits(uint32_t size)212   void SetPackSizeInBits(uint32_t size) { pack_size_in_bits_ = size; }
PackSizeInBits()213   uint32_t PackSizeInBits() const { return pack_size_in_bits_; }
IsPacked()214   bool IsPacked() const { return pack_size_in_bits_ > 0; }
215 
AddMember(FormatComponentType name,FormatMode mode,uint32_t num_bits)216   void AddMember(FormatComponentType name, FormatMode mode, uint32_t num_bits) {
217     members_.push_back({name, mode, num_bits});
218   }
219 
Members()220   const std::vector<Member>& Members() const { return members_; }
Members()221   std::vector<Member>& Members() { return members_; }
222 
223   uint32_t SizeInBytes() const override;
224 
225  private:
226   std::vector<Member> members_;
227   uint32_t pack_size_in_bits_ = 0;
228 };
229 
230 class Struct : public Type {
231  public:
232   struct Member {
233     std::string name;
234     Type* type;
235     int32_t offset_in_bytes = -1;
236     int32_t array_stride_in_bytes = -1;
237     int32_t matrix_stride_in_bytes = -1;
238 
HasOffsetMember239     bool HasOffset() const { return offset_in_bytes >= 0; }
HasArrayStrideMember240     bool HasArrayStride() const { return array_stride_in_bytes > 0; }
HasMatrixStrideMember241     bool HasMatrixStride() const { return matrix_stride_in_bytes > 0; }
242   };
243 
244   Struct();
245   ~Struct() override;
246 
247   uint32_t SizeInBytes() const override;
IsStruct()248   bool IsStruct() const override { return true; }
249 
Equal(const Type * b)250   bool Equal(const Type* b) const override {
251     if (!b->IsStruct())
252       return false;
253 
254     auto s = b->AsStruct();
255     if (is_stride_specified_ != s->is_stride_specified_)
256       return false;
257     if (stride_in_bytes_ != s->stride_in_bytes_)
258       return false;
259     if (members_.size() != s->members_.size())
260       return false;
261 
262     auto& sm = s->Members();
263     for (size_t i = 0; i < members_.size(); ++i) {
264       if (members_[i].offset_in_bytes != sm[i].offset_in_bytes)
265         return false;
266       if (members_[i].array_stride_in_bytes != sm[i].array_stride_in_bytes)
267         return false;
268       if (members_[i].matrix_stride_in_bytes != sm[i].matrix_stride_in_bytes)
269         return false;
270       if (!members_[i].type->Equal(sm[i].type))
271         return false;
272     }
273     return true;
274   }
275 
HasStride()276   bool HasStride() const { return is_stride_specified_; }
StrideInBytes()277   uint32_t StrideInBytes() const { return stride_in_bytes_; }
SetStrideInBytes(uint32_t stride)278   void SetStrideInBytes(uint32_t stride) {
279     is_stride_specified_ = true;
280     stride_in_bytes_ = stride;
281   }
282 
AddMember(Type * type)283   Member* AddMember(Type* type) {
284     members_.push_back({});
285     members_.back().type = type;
286     return &members_.back();
287   }
288 
Members()289   const std::vector<Member>& Members() const { return members_; }
290 
291  private:
292   std::vector<Member> members_;
293   bool is_stride_specified_ = false;
294   uint32_t stride_in_bytes_ = 0;
295 };
296 
297 }  // namespace type
298 }  // namespace amber
299 
300 #endif  // SRC_TYPE_H_
301