1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/const_folding_rules.h"
16 
17 #include "source/opt/ir_context.h"
18 
19 namespace spvtools {
20 namespace opt {
21 namespace {
22 
// In-operand index of the composite operand of an OpCompositeExtract; the
// literal indices that select the component follow it.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24 
25 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)26 bool HasFloatingPoint(const analysis::Type* type) {
27   if (type->AsFloat()) {
28     return true;
29   } else if (const analysis::Vector* vec_type = type->AsVector()) {
30     return vec_type->element_type()->AsFloat() != nullptr;
31   }
32 
33   return false;
34 }
35 
36 // Folds an OpcompositeExtract where input is a composite constant.
FoldExtractWithConstants()37 ConstantFoldingRule FoldExtractWithConstants() {
38   return [](IRContext* context, Instruction* inst,
39             const std::vector<const analysis::Constant*>& constants)
40              -> const analysis::Constant* {
41     const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
42     if (c == nullptr) {
43       return nullptr;
44     }
45 
46     for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
47       uint32_t element_index = inst->GetSingleWordInOperand(i);
48       if (c->AsNullConstant()) {
49         // Return Null for the return type.
50         analysis::ConstantManager* const_mgr = context->get_constant_mgr();
51         analysis::TypeManager* type_mgr = context->get_type_mgr();
52         return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
53       }
54 
55       auto cc = c->AsCompositeConstant();
56       assert(cc != nullptr);
57       auto components = cc->GetComponents();
58       c = components[element_index];
59     }
60     return c;
61   };
62 }
63 
FoldVectorShuffleWithConstants()64 ConstantFoldingRule FoldVectorShuffleWithConstants() {
65   return [](IRContext* context, Instruction* inst,
66             const std::vector<const analysis::Constant*>& constants)
67              -> const analysis::Constant* {
68     assert(inst->opcode() == SpvOpVectorShuffle);
69     const analysis::Constant* c1 = constants[0];
70     const analysis::Constant* c2 = constants[1];
71     if (c1 == nullptr || c2 == nullptr) {
72       return nullptr;
73     }
74 
75     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
76     const analysis::Type* element_type = c1->type()->AsVector()->element_type();
77 
78     std::vector<const analysis::Constant*> c1_components;
79     if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
80       c1_components = vec_const->GetComponents();
81     } else {
82       assert(c1->AsNullConstant());
83       const analysis::Constant* element =
84           const_mgr->GetConstant(element_type, {});
85       c1_components.resize(c1->type()->AsVector()->element_count(), element);
86     }
87     std::vector<const analysis::Constant*> c2_components;
88     if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
89       c2_components = vec_const->GetComponents();
90     } else {
91       assert(c2->AsNullConstant());
92       const analysis::Constant* element =
93           const_mgr->GetConstant(element_type, {});
94       c2_components.resize(c2->type()->AsVector()->element_count(), element);
95     }
96 
97     std::vector<uint32_t> ids;
98     const uint32_t undef_literal_value = 0xffffffff;
99     for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
100       uint32_t index = inst->GetSingleWordInOperand(i);
101       if (index == undef_literal_value) {
102         // Don't fold shuffle with undef literal value.
103         return nullptr;
104       } else if (index < c1_components.size()) {
105         Instruction* member_inst =
106             const_mgr->GetDefiningInstruction(c1_components[index]);
107         ids.push_back(member_inst->result_id());
108       } else {
109         Instruction* member_inst = const_mgr->GetDefiningInstruction(
110             c2_components[index - c1_components.size()]);
111         ids.push_back(member_inst->result_id());
112       }
113     }
114 
115     analysis::TypeManager* type_mgr = context->get_type_mgr();
116     return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
117   };
118 }
119 
FoldVectorTimesScalar()120 ConstantFoldingRule FoldVectorTimesScalar() {
121   return [](IRContext* context, Instruction* inst,
122             const std::vector<const analysis::Constant*>& constants)
123              -> const analysis::Constant* {
124     assert(inst->opcode() == SpvOpVectorTimesScalar);
125     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
126     analysis::TypeManager* type_mgr = context->get_type_mgr();
127 
128     if (!inst->IsFloatingPointFoldingAllowed()) {
129       if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
130         return nullptr;
131       }
132     }
133 
134     const analysis::Constant* c1 = constants[0];
135     const analysis::Constant* c2 = constants[1];
136 
137     if (c1 && c1->IsZero()) {
138       return c1;
139     }
140 
141     if (c2 && c2->IsZero()) {
142       // Get or create the NullConstant for this type.
143       std::vector<uint32_t> ids;
144       return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
145     }
146 
147     if (c1 == nullptr || c2 == nullptr) {
148       return nullptr;
149     }
150 
151     // Check result type.
152     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
153     const analysis::Vector* vector_type = result_type->AsVector();
154     assert(vector_type != nullptr);
155     const analysis::Type* element_type = vector_type->element_type();
156     assert(element_type != nullptr);
157     const analysis::Float* float_type = element_type->AsFloat();
158     assert(float_type != nullptr);
159 
160     // Check types of c1 and c2.
161     assert(c1->type()->AsVector() == vector_type);
162     assert(c1->type()->AsVector()->element_type() == element_type &&
163            c2->type() == element_type);
164 
165     // Get a float vector that is the result of vector-times-scalar.
166     std::vector<const analysis::Constant*> c1_components =
167         c1->GetVectorComponents(const_mgr);
168     std::vector<uint32_t> ids;
169     if (float_type->width() == 32) {
170       float scalar = c2->GetFloat();
171       for (uint32_t i = 0; i < c1_components.size(); ++i) {
172         utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
173         std::vector<uint32_t> words = result.GetWords();
174         const analysis::Constant* new_elem =
175             const_mgr->GetConstant(float_type, words);
176         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
177       }
178       return const_mgr->GetConstant(vector_type, ids);
179     } else if (float_type->width() == 64) {
180       double scalar = c2->GetDouble();
181       for (uint32_t i = 0; i < c1_components.size(); ++i) {
182         utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
183                                          scalar);
184         std::vector<uint32_t> words = result.GetWords();
185         const analysis::Constant* new_elem =
186             const_mgr->GetConstant(float_type, words);
187         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
188       }
189       return const_mgr->GetConstant(vector_type, ids);
190     }
191     return nullptr;
192   };
193 }
194 
FoldCompositeWithConstants()195 ConstantFoldingRule FoldCompositeWithConstants() {
196   // Folds an OpCompositeConstruct where all of the inputs are constants to a
197   // constant.  A new constant is created if necessary.
198   return [](IRContext* context, Instruction* inst,
199             const std::vector<const analysis::Constant*>& constants)
200              -> const analysis::Constant* {
201     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
202     analysis::TypeManager* type_mgr = context->get_type_mgr();
203     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
204     Instruction* type_inst =
205         context->get_def_use_mgr()->GetDef(inst->type_id());
206 
207     std::vector<uint32_t> ids;
208     for (uint32_t i = 0; i < constants.size(); ++i) {
209       const analysis::Constant* element_const = constants[i];
210       if (element_const == nullptr) {
211         return nullptr;
212       }
213 
214       uint32_t component_type_id = 0;
215       if (type_inst->opcode() == SpvOpTypeStruct) {
216         component_type_id = type_inst->GetSingleWordInOperand(i);
217       } else if (type_inst->opcode() == SpvOpTypeArray) {
218         component_type_id = type_inst->GetSingleWordInOperand(0);
219       }
220 
221       uint32_t element_id =
222           const_mgr->FindDeclaredConstant(element_const, component_type_id);
223       if (element_id == 0) {
224         return nullptr;
225       }
226       ids.push_back(element_id);
227     }
228     return const_mgr->GetConstant(new_type, ids);
229   };
230 }
231 
232 // The interface for a function that returns the result of applying a scalar
233 // floating-point binary operation on |a| and |b|.  The type of the return value
234 // will be |type|.  The input constants must also be of type |type|.
235 using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
236     const analysis::Type* result_type, const analysis::Constant* a,
237     analysis::ConstantManager*)>;
238 
239 // The interface for a function that returns the result of applying a scalar
240 // floating-point binary operation on |a| and |b|.  The type of the return value
241 // will be |type|.  The input constants must also be of type |type|.
242 using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
243     const analysis::Type* result_type, const analysis::Constant* a,
244     const analysis::Constant* b, analysis::ConstantManager*)>;
245 
246 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
247 // using |scalar_rule| and unary float point vectors ops by applying
248 // |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
249 // that is returned assumes that |constants| contains 1 entry.  If they are
250 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
251 // whose element type is |Float| or |Integer|.
FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule)252 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
253   return [scalar_rule](IRContext* context, Instruction* inst,
254                        const std::vector<const analysis::Constant*>& constants)
255              -> const analysis::Constant* {
256     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
257     analysis::TypeManager* type_mgr = context->get_type_mgr();
258     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
259     const analysis::Vector* vector_type = result_type->AsVector();
260 
261     if (!inst->IsFloatingPointFoldingAllowed()) {
262       return nullptr;
263     }
264 
265     if (constants[0] == nullptr) {
266       return nullptr;
267     }
268 
269     if (vector_type != nullptr) {
270       std::vector<const analysis::Constant*> a_components;
271       std::vector<const analysis::Constant*> results_components;
272 
273       a_components = constants[0]->GetVectorComponents(const_mgr);
274 
275       // Fold each component of the vector.
276       for (uint32_t i = 0; i < a_components.size(); ++i) {
277         results_components.push_back(scalar_rule(vector_type->element_type(),
278                                                  a_components[i], const_mgr));
279         if (results_components[i] == nullptr) {
280           return nullptr;
281         }
282       }
283 
284       // Build the constant object and return it.
285       std::vector<uint32_t> ids;
286       for (const analysis::Constant* member : results_components) {
287         ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
288       }
289       return const_mgr->GetConstant(vector_type, ids);
290     } else {
291       return scalar_rule(result_type, constants[0], const_mgr);
292     }
293   };
294 }
295 
296 // Returns a |ConstantFoldingRule| that folds floating point scalars using
297 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
298 // elements of the vector.  The |ConstantFoldingRule| that is returned assumes
299 // that |constants| contains 2 entries.  If they are not |nullptr|, then their
300 // type is either |Float| or a |Vector| whose element type is |Float|.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule)301 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
302   return [scalar_rule](IRContext* context, Instruction* inst,
303                        const std::vector<const analysis::Constant*>& constants)
304              -> const analysis::Constant* {
305     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
306     analysis::TypeManager* type_mgr = context->get_type_mgr();
307     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
308     const analysis::Vector* vector_type = result_type->AsVector();
309 
310     if (!inst->IsFloatingPointFoldingAllowed()) {
311       return nullptr;
312     }
313 
314     if (constants[0] == nullptr || constants[1] == nullptr) {
315       return nullptr;
316     }
317 
318     if (vector_type != nullptr) {
319       std::vector<const analysis::Constant*> a_components;
320       std::vector<const analysis::Constant*> b_components;
321       std::vector<const analysis::Constant*> results_components;
322 
323       a_components = constants[0]->GetVectorComponents(const_mgr);
324       b_components = constants[1]->GetVectorComponents(const_mgr);
325 
326       // Fold each component of the vector.
327       for (uint32_t i = 0; i < a_components.size(); ++i) {
328         results_components.push_back(scalar_rule(vector_type->element_type(),
329                                                  a_components[i],
330                                                  b_components[i], const_mgr));
331         if (results_components[i] == nullptr) {
332           return nullptr;
333         }
334       }
335 
336       // Build the constant object and return it.
337       std::vector<uint32_t> ids;
338       for (const analysis::Constant* member : results_components) {
339         ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
340       }
341       return const_mgr->GetConstant(vector_type, ids);
342     } else {
343       return scalar_rule(result_type, constants[0], constants[1], const_mgr);
344     }
345   };
346 }
347 
348 // This macro defines a |UnaryScalarFoldingRule| that performs float to
349 // integer conversion.
350 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldFToIOp()351 UnaryScalarFoldingRule FoldFToIOp() {
352   return [](const analysis::Type* result_type, const analysis::Constant* a,
353             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
354     assert(result_type != nullptr && a != nullptr);
355     const analysis::Integer* integer_type = result_type->AsInteger();
356     const analysis::Float* float_type = a->type()->AsFloat();
357     assert(float_type != nullptr);
358     assert(integer_type != nullptr);
359     if (integer_type->width() != 32) return nullptr;
360     if (float_type->width() == 32) {
361       float fa = a->GetFloat();
362       uint32_t result = integer_type->IsSigned()
363                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
364                             : static_cast<uint32_t>(fa);
365       std::vector<uint32_t> words = {result};
366       return const_mgr->GetConstant(result_type, words);
367     } else if (float_type->width() == 64) {
368       double fa = a->GetDouble();
369       uint32_t result = integer_type->IsSigned()
370                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
371                             : static_cast<uint32_t>(fa);
372       std::vector<uint32_t> words = {result};
373       return const_mgr->GetConstant(result_type, words);
374     }
375     return nullptr;
376   };
377 }
378 
379 // This function defines a |UnaryScalarFoldingRule| that performs integer to
380 // float conversion.
381 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldIToFOp()382 UnaryScalarFoldingRule FoldIToFOp() {
383   return [](const analysis::Type* result_type, const analysis::Constant* a,
384             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
385     assert(result_type != nullptr && a != nullptr);
386     const analysis::Integer* integer_type = a->type()->AsInteger();
387     const analysis::Float* float_type = result_type->AsFloat();
388     assert(float_type != nullptr);
389     assert(integer_type != nullptr);
390     if (integer_type->width() != 32) return nullptr;
391     uint32_t ua = a->GetU32();
392     if (float_type->width() == 32) {
393       float result_val = integer_type->IsSigned()
394                              ? static_cast<float>(static_cast<int32_t>(ua))
395                              : static_cast<float>(ua);
396       utils::FloatProxy<float> result(result_val);
397       std::vector<uint32_t> words = {result.data()};
398       return const_mgr->GetConstant(result_type, words);
399     } else if (float_type->width() == 64) {
400       double result_val = integer_type->IsSigned()
401                               ? static_cast<double>(static_cast<int32_t>(ua))
402                               : static_cast<double>(ua);
403       utils::FloatProxy<double> result(result_val);
404       std::vector<uint32_t> words = result.GetWords();
405       return const_mgr->GetConstant(result_type, words);
406     }
407     return nullptr;
408   };
409 }
410 
// This macro defines a |BinaryScalarFoldingRule| that applies |op| to two
// floating-point constants |a| and |b|, whose type must equal |result_type|.
// It returns the result as a constant of that type.  The operator |op| must
// work for both float and double, and use syntax "f1 op f2".  The arithmetic
// is done in host float (32-bit) or double (64-bit) precision; any other
// width returns nullptr.
#define FOLD_FPARITH_OP(op)                                                \
  [](const analysis::Type* result_type, const analysis::Constant* a,       \
     const analysis::Constant* b,                                          \
     analysis::ConstantManager* const_mgr_in_macro)                        \
      -> const analysis::Constant* {                                       \
    assert(result_type != nullptr && a != nullptr && b != nullptr);        \
    assert(result_type == a->type() && result_type == b->type());          \
    const analysis::Float* float_type_in_macro = result_type->AsFloat();   \
    assert(float_type_in_macro != nullptr);                                \
    if (float_type_in_macro->width() == 32) {                              \
      float fa = a->GetFloat();                                            \
      float fb = b->GetFloat();                                            \
      utils::FloatProxy<float> result_in_macro(fa op fb);                  \
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    } else if (float_type_in_macro->width() == 64) {                       \
      double fa = a->GetDouble();                                          \
      double fb = b->GetDouble();                                          \
      utils::FloatProxy<double> result_in_macro(fa op fb);                 \
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    }                                                                      \
    return nullptr;                                                        \
  }
437 
438 // Define the folding rule for conversion between floating point and integer
FoldFToI()439 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
FoldIToF()440 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
441 
442 // Define the folding rules for subtraction, addition, multiplication, and
443 // division for floating point values.
FoldFSub()444 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
FoldFAdd()445 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
FoldFMul()446 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
FoldFDiv()447 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
448 
// Combines the raw result of a floating-point comparison with NaN handling.
// |op_result| is the C++ comparison result.  |op_unordered| is true when at
// least one operand is NaN.  |need_ordered| selects the ordered form (false
// whenever an operand is NaN) over the unordered form (true whenever an
// operand is NaN).
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  if (op_unordered) {
    // A NaN operand decides the outcome regardless of |op_result|.
    return !need_ordered;
  }
  return op_result;
}
459 
// This macro defines a |BinaryScalarFoldingRule| that compares two
// floating-point constants |a| and |b| (of the same type) with |op|.  It
// returns a constant of the bool type |result_type|.
//
// NaN operands are handled by CompareFloatingPoint: |ord| == true gives
// ordered semantics (false if either operand is NaN), and |ord| == false
// gives unordered semantics (true if either operand is NaN).
//
// The operator |op| must work for both float and double, and use syntax
// "f1 op f2".  Only 32- and 64-bit floats are folded; other widths return
// nullptr.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* float_type = a->type()->AsFloat();             \
    assert(float_type != nullptr);                                        \
    if (float_type->width() == 32) {                                      \
      float fa = a->GetFloat();                                           \
      float fb = b->GetFloat();                                           \
      bool result = CompareFloatingPoint(                                 \
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
      std::vector<uint32_t> words = {uint32_t(result)};                   \
      return const_mgr->GetConstant(result_type, words);                  \
    } else if (float_type->width() == 64) {                               \
      double fa = a->GetDouble();                                         \
      double fb = b->GetDouble();                                         \
      bool result = CompareFloatingPoint(                                 \
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
      std::vector<uint32_t> words = {uint32_t(result)};                   \
      return const_mgr->GetConstant(result_type, words);                  \
    }                                                                     \
    return nullptr;                                                       \
  }
488 
489 // Define the folding rules for ordered and unordered comparison for floating
490 // point values.
FoldFOrdEqual()491 ConstantFoldingRule FoldFOrdEqual() {
492   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
493 }
FoldFUnordEqual()494 ConstantFoldingRule FoldFUnordEqual() {
495   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
496 }
FoldFOrdNotEqual()497 ConstantFoldingRule FoldFOrdNotEqual() {
498   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
499 }
FoldFUnordNotEqual()500 ConstantFoldingRule FoldFUnordNotEqual() {
501   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
502 }
FoldFOrdLessThan()503 ConstantFoldingRule FoldFOrdLessThan() {
504   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
505 }
FoldFUnordLessThan()506 ConstantFoldingRule FoldFUnordLessThan() {
507   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
508 }
FoldFOrdGreaterThan()509 ConstantFoldingRule FoldFOrdGreaterThan() {
510   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
511 }
FoldFUnordGreaterThan()512 ConstantFoldingRule FoldFUnordGreaterThan() {
513   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
514 }
FoldFOrdLessThanEqual()515 ConstantFoldingRule FoldFOrdLessThanEqual() {
516   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
517 }
FoldFUnordLessThanEqual()518 ConstantFoldingRule FoldFUnordLessThanEqual() {
519   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
520 }
FoldFOrdGreaterThanEqual()521 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
522   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
523 }
FoldFUnordGreaterThanEqual()524 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
525   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
526 }
527 
528 // Folds an OpDot where all of the inputs are constants to a
529 // constant.  A new constant is created if necessary.
FoldOpDotWithConstants()530 ConstantFoldingRule FoldOpDotWithConstants() {
531   return [](IRContext* context, Instruction* inst,
532             const std::vector<const analysis::Constant*>& constants)
533              -> const analysis::Constant* {
534     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
535     analysis::TypeManager* type_mgr = context->get_type_mgr();
536     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
537     assert(new_type->AsFloat() && "OpDot should have a float return type.");
538     const analysis::Float* float_type = new_type->AsFloat();
539 
540     if (!inst->IsFloatingPointFoldingAllowed()) {
541       return nullptr;
542     }
543 
544     // If one of the operands is 0, then the result is 0.
545     bool has_zero_operand = false;
546 
547     for (int i = 0; i < 2; ++i) {
548       if (constants[i]) {
549         if (constants[i]->AsNullConstant() ||
550             constants[i]->AsVectorConstant()->IsZero()) {
551           has_zero_operand = true;
552           break;
553         }
554       }
555     }
556 
557     if (has_zero_operand) {
558       if (float_type->width() == 32) {
559         utils::FloatProxy<float> result(0.0f);
560         std::vector<uint32_t> words = result.GetWords();
561         return const_mgr->GetConstant(float_type, words);
562       }
563       if (float_type->width() == 64) {
564         utils::FloatProxy<double> result(0.0);
565         std::vector<uint32_t> words = result.GetWords();
566         return const_mgr->GetConstant(float_type, words);
567       }
568       return nullptr;
569     }
570 
571     if (constants[0] == nullptr || constants[1] == nullptr) {
572       return nullptr;
573     }
574 
575     std::vector<const analysis::Constant*> a_components;
576     std::vector<const analysis::Constant*> b_components;
577 
578     a_components = constants[0]->GetVectorComponents(const_mgr);
579     b_components = constants[1]->GetVectorComponents(const_mgr);
580 
581     utils::FloatProxy<double> result(0.0);
582     std::vector<uint32_t> words = result.GetWords();
583     const analysis::Constant* result_const =
584         const_mgr->GetConstant(float_type, words);
585     for (uint32_t i = 0; i < a_components.size(); ++i) {
586       if (a_components[i] == nullptr || b_components[i] == nullptr) {
587         return nullptr;
588       }
589 
590       const analysis::Constant* component = FOLD_FPARITH_OP(*)(
591           new_type, a_components[i], b_components[i], const_mgr);
592       result_const =
593           FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
594     }
595     return result_const;
596   };
597 }
598 
599 // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
600 // from zero.
FoldFNegateOp()601 UnaryScalarFoldingRule FoldFNegateOp() {
602   return [](const analysis::Type* result_type, const analysis::Constant* a,
603             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
604     assert(result_type != nullptr && a != nullptr);
605     assert(result_type == a->type());
606     const analysis::Float* float_type = result_type->AsFloat();
607     assert(float_type != nullptr);
608     if (float_type->width() == 32) {
609       float fa = a->GetFloat();
610       utils::FloatProxy<float> result(-fa);
611       std::vector<uint32_t> words = result.GetWords();
612       return const_mgr->GetConstant(result_type, words);
613     } else if (float_type->width() == 64) {
614       double da = a->GetDouble();
615       utils::FloatProxy<double> result(-da);
616       std::vector<uint32_t> words = result.GetWords();
617       return const_mgr->GetConstant(result_type, words);
618     }
619     return nullptr;
620   };
621 }
622 
FoldFNegate()623 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
624 
FoldFClampFeedingCompare(uint32_t cmp_opcode)625 ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
626   return [cmp_opcode](IRContext* context, Instruction* inst,
627                       const std::vector<const analysis::Constant*>& constants)
628              -> const analysis::Constant* {
629     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
630     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
631 
632     if (!inst->IsFloatingPointFoldingAllowed()) {
633       return nullptr;
634     }
635 
636     uint32_t non_const_idx = (constants[0] ? 1 : 0);
637     uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
638     Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
639 
640     analysis::TypeManager* type_mgr = context->get_type_mgr();
641     const analysis::Type* operand_type =
642         type_mgr->GetType(operand_inst->type_id());
643 
644     if (!operand_type->AsFloat()) {
645       return nullptr;
646     }
647 
648     if (operand_type->AsFloat()->width() != 32 &&
649         operand_type->AsFloat()->width() != 64) {
650       return nullptr;
651     }
652 
653     if (operand_inst->opcode() != SpvOpExtInst) {
654       return nullptr;
655     }
656 
657     if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
658       return nullptr;
659     }
660 
661     if (constants[1] == nullptr && constants[0] == nullptr) {
662       return nullptr;
663     }
664 
665     uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
666     const analysis::Constant* max_const =
667         const_mgr->FindDeclaredConstant(max_id);
668 
669     uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
670     const analysis::Constant* min_const =
671         const_mgr->FindDeclaredConstant(min_id);
672 
673     bool found_result = false;
674     bool result = false;
675 
676     switch (cmp_opcode) {
677       case SpvOpFOrdLessThan:
678       case SpvOpFUnordLessThan:
679       case SpvOpFOrdGreaterThanEqual:
680       case SpvOpFUnordGreaterThanEqual:
681         if (constants[0]) {
682           if (min_const) {
683             if (constants[0]->GetValueAsDouble() <
684                 min_const->GetValueAsDouble()) {
685               found_result = true;
686               result = (cmp_opcode == SpvOpFOrdLessThan ||
687                         cmp_opcode == SpvOpFUnordLessThan);
688             }
689           }
690           if (max_const) {
691             if (constants[0]->GetValueAsDouble() >=
692                 max_const->GetValueAsDouble()) {
693               found_result = true;
694               result = !(cmp_opcode == SpvOpFOrdLessThan ||
695                          cmp_opcode == SpvOpFUnordLessThan);
696             }
697           }
698         }
699 
700         if (constants[1]) {
701           if (max_const) {
702             if (max_const->GetValueAsDouble() <
703                 constants[1]->GetValueAsDouble()) {
704               found_result = true;
705               result = (cmp_opcode == SpvOpFOrdLessThan ||
706                         cmp_opcode == SpvOpFUnordLessThan);
707             }
708           }
709 
710           if (min_const) {
711             if (min_const->GetValueAsDouble() >=
712                 constants[1]->GetValueAsDouble()) {
713               found_result = true;
714               result = !(cmp_opcode == SpvOpFOrdLessThan ||
715                          cmp_opcode == SpvOpFUnordLessThan);
716             }
717           }
718         }
719         break;
720       case SpvOpFOrdGreaterThan:
721       case SpvOpFUnordGreaterThan:
722       case SpvOpFOrdLessThanEqual:
723       case SpvOpFUnordLessThanEqual:
724         if (constants[0]) {
725           if (min_const) {
726             if (constants[0]->GetValueAsDouble() <=
727                 min_const->GetValueAsDouble()) {
728               found_result = true;
729               result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
730                         cmp_opcode == SpvOpFUnordLessThanEqual);
731             }
732           }
733           if (max_const) {
734             if (constants[0]->GetValueAsDouble() >
735                 max_const->GetValueAsDouble()) {
736               found_result = true;
737               result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
738                          cmp_opcode == SpvOpFUnordLessThanEqual);
739             }
740           }
741         }
742 
743         if (constants[1]) {
744           if (max_const) {
745             if (max_const->GetValueAsDouble() <=
746                 constants[1]->GetValueAsDouble()) {
747               found_result = true;
748               result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
749                         cmp_opcode == SpvOpFUnordLessThanEqual);
750             }
751           }
752 
753           if (min_const) {
754             if (min_const->GetValueAsDouble() >
755                 constants[1]->GetValueAsDouble()) {
756               found_result = true;
757               result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
758                          cmp_opcode == SpvOpFUnordLessThanEqual);
759             }
760           }
761         }
762         break;
763       default:
764         return nullptr;
765     }
766 
767     if (!found_result) {
768       return nullptr;
769     }
770 
771     const analysis::Type* bool_type =
772         context->get_type_mgr()->GetType(inst->type_id());
773     const analysis::Constant* result_const =
774         const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
775     assert(result_const);
776     return result_const;
777   };
778 }
779 
780 }  // namespace
781 
ConstantFoldingRules()782 ConstantFoldingRules::ConstantFoldingRules() {
783   // Add all folding rules to the list for the opcodes to which they apply.
784   // Note that the order in which rules are added to the list matters. If a rule
785   // applies to the instruction, the rest of the rules will not be attempted.
786   // Take that into consideration.
787 
788   rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
789 
790   rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
791 
792   rules_[SpvOpConvertFToS].push_back(FoldFToI());
793   rules_[SpvOpConvertFToU].push_back(FoldFToI());
794   rules_[SpvOpConvertSToF].push_back(FoldIToF());
795   rules_[SpvOpConvertUToF].push_back(FoldIToF());
796 
797   rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
798   rules_[SpvOpFAdd].push_back(FoldFAdd());
799   rules_[SpvOpFDiv].push_back(FoldFDiv());
800   rules_[SpvOpFMul].push_back(FoldFMul());
801   rules_[SpvOpFSub].push_back(FoldFSub());
802 
803   rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
804 
805   rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
806 
807   rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
808 
809   rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
810 
811   rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
812   rules_[SpvOpFOrdLessThan].push_back(
813       FoldFClampFeedingCompare(SpvOpFOrdLessThan));
814 
815   rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
816   rules_[SpvOpFUnordLessThan].push_back(
817       FoldFClampFeedingCompare(SpvOpFUnordLessThan));
818 
819   rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
820   rules_[SpvOpFOrdGreaterThan].push_back(
821       FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
822 
823   rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
824   rules_[SpvOpFUnordGreaterThan].push_back(
825       FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
826 
827   rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
828   rules_[SpvOpFOrdLessThanEqual].push_back(
829       FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
830 
831   rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
832   rules_[SpvOpFUnordLessThanEqual].push_back(
833       FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
834 
835   rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
836   rules_[SpvOpFOrdGreaterThanEqual].push_back(
837       FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
838 
839   rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
840   rules_[SpvOpFUnordGreaterThanEqual].push_back(
841       FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
842 
843   rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
844   rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
845 
846   rules_[SpvOpFNegate].push_back(FoldFNegate());
847 }
848 }  // namespace opt
849 }  // namespace spvtools
850