1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/const_folding_rules.h"
16
17 #include "source/opt/ir_context.h"
18
19 namespace spvtools {
20 namespace opt {
21 namespace {
22
// In-operand index of the composite operand of OpCompositeExtract; the
// literal indexes follow it starting at in-operand 1.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24
25 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)26 bool HasFloatingPoint(const analysis::Type* type) {
27 if (type->AsFloat()) {
28 return true;
29 } else if (const analysis::Vector* vec_type = type->AsVector()) {
30 return vec_type->element_type()->AsFloat() != nullptr;
31 }
32
33 return false;
34 }
35
36 // Folds an OpcompositeExtract where input is a composite constant.
FoldExtractWithConstants()37 ConstantFoldingRule FoldExtractWithConstants() {
38 return [](IRContext* context, Instruction* inst,
39 const std::vector<const analysis::Constant*>& constants)
40 -> const analysis::Constant* {
41 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
42 if (c == nullptr) {
43 return nullptr;
44 }
45
46 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
47 uint32_t element_index = inst->GetSingleWordInOperand(i);
48 if (c->AsNullConstant()) {
49 // Return Null for the return type.
50 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
51 analysis::TypeManager* type_mgr = context->get_type_mgr();
52 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
53 }
54
55 auto cc = c->AsCompositeConstant();
56 assert(cc != nullptr);
57 auto components = cc->GetComponents();
58 c = components[element_index];
59 }
60 return c;
61 };
62 }
63
FoldVectorShuffleWithConstants()64 ConstantFoldingRule FoldVectorShuffleWithConstants() {
65 return [](IRContext* context, Instruction* inst,
66 const std::vector<const analysis::Constant*>& constants)
67 -> const analysis::Constant* {
68 assert(inst->opcode() == SpvOpVectorShuffle);
69 const analysis::Constant* c1 = constants[0];
70 const analysis::Constant* c2 = constants[1];
71 if (c1 == nullptr || c2 == nullptr) {
72 return nullptr;
73 }
74
75 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
76 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
77
78 std::vector<const analysis::Constant*> c1_components;
79 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
80 c1_components = vec_const->GetComponents();
81 } else {
82 assert(c1->AsNullConstant());
83 const analysis::Constant* element =
84 const_mgr->GetConstant(element_type, {});
85 c1_components.resize(c1->type()->AsVector()->element_count(), element);
86 }
87 std::vector<const analysis::Constant*> c2_components;
88 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
89 c2_components = vec_const->GetComponents();
90 } else {
91 assert(c2->AsNullConstant());
92 const analysis::Constant* element =
93 const_mgr->GetConstant(element_type, {});
94 c2_components.resize(c2->type()->AsVector()->element_count(), element);
95 }
96
97 std::vector<uint32_t> ids;
98 const uint32_t undef_literal_value = 0xffffffff;
99 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
100 uint32_t index = inst->GetSingleWordInOperand(i);
101 if (index == undef_literal_value) {
102 // Don't fold shuffle with undef literal value.
103 return nullptr;
104 } else if (index < c1_components.size()) {
105 Instruction* member_inst =
106 const_mgr->GetDefiningInstruction(c1_components[index]);
107 ids.push_back(member_inst->result_id());
108 } else {
109 Instruction* member_inst = const_mgr->GetDefiningInstruction(
110 c2_components[index - c1_components.size()]);
111 ids.push_back(member_inst->result_id());
112 }
113 }
114
115 analysis::TypeManager* type_mgr = context->get_type_mgr();
116 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
117 };
118 }
119
FoldVectorTimesScalar()120 ConstantFoldingRule FoldVectorTimesScalar() {
121 return [](IRContext* context, Instruction* inst,
122 const std::vector<const analysis::Constant*>& constants)
123 -> const analysis::Constant* {
124 assert(inst->opcode() == SpvOpVectorTimesScalar);
125 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
126 analysis::TypeManager* type_mgr = context->get_type_mgr();
127
128 if (!inst->IsFloatingPointFoldingAllowed()) {
129 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
130 return nullptr;
131 }
132 }
133
134 const analysis::Constant* c1 = constants[0];
135 const analysis::Constant* c2 = constants[1];
136
137 if (c1 && c1->IsZero()) {
138 return c1;
139 }
140
141 if (c2 && c2->IsZero()) {
142 // Get or create the NullConstant for this type.
143 std::vector<uint32_t> ids;
144 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
145 }
146
147 if (c1 == nullptr || c2 == nullptr) {
148 return nullptr;
149 }
150
151 // Check result type.
152 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
153 const analysis::Vector* vector_type = result_type->AsVector();
154 assert(vector_type != nullptr);
155 const analysis::Type* element_type = vector_type->element_type();
156 assert(element_type != nullptr);
157 const analysis::Float* float_type = element_type->AsFloat();
158 assert(float_type != nullptr);
159
160 // Check types of c1 and c2.
161 assert(c1->type()->AsVector() == vector_type);
162 assert(c1->type()->AsVector()->element_type() == element_type &&
163 c2->type() == element_type);
164
165 // Get a float vector that is the result of vector-times-scalar.
166 std::vector<const analysis::Constant*> c1_components =
167 c1->GetVectorComponents(const_mgr);
168 std::vector<uint32_t> ids;
169 if (float_type->width() == 32) {
170 float scalar = c2->GetFloat();
171 for (uint32_t i = 0; i < c1_components.size(); ++i) {
172 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
173 std::vector<uint32_t> words = result.GetWords();
174 const analysis::Constant* new_elem =
175 const_mgr->GetConstant(float_type, words);
176 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
177 }
178 return const_mgr->GetConstant(vector_type, ids);
179 } else if (float_type->width() == 64) {
180 double scalar = c2->GetDouble();
181 for (uint32_t i = 0; i < c1_components.size(); ++i) {
182 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
183 scalar);
184 std::vector<uint32_t> words = result.GetWords();
185 const analysis::Constant* new_elem =
186 const_mgr->GetConstant(float_type, words);
187 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
188 }
189 return const_mgr->GetConstant(vector_type, ids);
190 }
191 return nullptr;
192 };
193 }
194
FoldCompositeWithConstants()195 ConstantFoldingRule FoldCompositeWithConstants() {
196 // Folds an OpCompositeConstruct where all of the inputs are constants to a
197 // constant. A new constant is created if necessary.
198 return [](IRContext* context, Instruction* inst,
199 const std::vector<const analysis::Constant*>& constants)
200 -> const analysis::Constant* {
201 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
202 analysis::TypeManager* type_mgr = context->get_type_mgr();
203 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
204 Instruction* type_inst =
205 context->get_def_use_mgr()->GetDef(inst->type_id());
206
207 std::vector<uint32_t> ids;
208 for (uint32_t i = 0; i < constants.size(); ++i) {
209 const analysis::Constant* element_const = constants[i];
210 if (element_const == nullptr) {
211 return nullptr;
212 }
213
214 uint32_t component_type_id = 0;
215 if (type_inst->opcode() == SpvOpTypeStruct) {
216 component_type_id = type_inst->GetSingleWordInOperand(i);
217 } else if (type_inst->opcode() == SpvOpTypeArray) {
218 component_type_id = type_inst->GetSingleWordInOperand(0);
219 }
220
221 uint32_t element_id =
222 const_mgr->FindDeclaredConstant(element_const, component_type_id);
223 if (element_id == 0) {
224 return nullptr;
225 }
226 ids.push_back(element_id);
227 }
228 return const_mgr->GetConstant(new_type, ids);
229 };
230 }
231
232 // The interface for a function that returns the result of applying a scalar
233 // floating-point binary operation on |a| and |b|. The type of the return value
234 // will be |type|. The input constants must also be of type |type|.
235 using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
236 const analysis::Type* result_type, const analysis::Constant* a,
237 analysis::ConstantManager*)>;
238
239 // The interface for a function that returns the result of applying a scalar
240 // floating-point binary operation on |a| and |b|. The type of the return value
241 // will be |type|. The input constants must also be of type |type|.
242 using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
243 const analysis::Type* result_type, const analysis::Constant* a,
244 const analysis::Constant* b, analysis::ConstantManager*)>;
245
246 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
247 // using |scalar_rule| and unary float point vectors ops by applying
248 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
249 // that is returned assumes that |constants| contains 1 entry. If they are
250 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
251 // whose element type is |Float| or |Integer|.
FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule)252 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
253 return [scalar_rule](IRContext* context, Instruction* inst,
254 const std::vector<const analysis::Constant*>& constants)
255 -> const analysis::Constant* {
256 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
257 analysis::TypeManager* type_mgr = context->get_type_mgr();
258 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
259 const analysis::Vector* vector_type = result_type->AsVector();
260
261 if (!inst->IsFloatingPointFoldingAllowed()) {
262 return nullptr;
263 }
264
265 if (constants[0] == nullptr) {
266 return nullptr;
267 }
268
269 if (vector_type != nullptr) {
270 std::vector<const analysis::Constant*> a_components;
271 std::vector<const analysis::Constant*> results_components;
272
273 a_components = constants[0]->GetVectorComponents(const_mgr);
274
275 // Fold each component of the vector.
276 for (uint32_t i = 0; i < a_components.size(); ++i) {
277 results_components.push_back(scalar_rule(vector_type->element_type(),
278 a_components[i], const_mgr));
279 if (results_components[i] == nullptr) {
280 return nullptr;
281 }
282 }
283
284 // Build the constant object and return it.
285 std::vector<uint32_t> ids;
286 for (const analysis::Constant* member : results_components) {
287 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
288 }
289 return const_mgr->GetConstant(vector_type, ids);
290 } else {
291 return scalar_rule(result_type, constants[0], const_mgr);
292 }
293 };
294 }
295
296 // Returns a |ConstantFoldingRule| that folds floating point scalars using
297 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
298 // elements of the vector. The |ConstantFoldingRule| that is returned assumes
299 // that |constants| contains 2 entries. If they are not |nullptr|, then their
300 // type is either |Float| or a |Vector| whose element type is |Float|.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule)301 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
302 return [scalar_rule](IRContext* context, Instruction* inst,
303 const std::vector<const analysis::Constant*>& constants)
304 -> const analysis::Constant* {
305 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
306 analysis::TypeManager* type_mgr = context->get_type_mgr();
307 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
308 const analysis::Vector* vector_type = result_type->AsVector();
309
310 if (!inst->IsFloatingPointFoldingAllowed()) {
311 return nullptr;
312 }
313
314 if (constants[0] == nullptr || constants[1] == nullptr) {
315 return nullptr;
316 }
317
318 if (vector_type != nullptr) {
319 std::vector<const analysis::Constant*> a_components;
320 std::vector<const analysis::Constant*> b_components;
321 std::vector<const analysis::Constant*> results_components;
322
323 a_components = constants[0]->GetVectorComponents(const_mgr);
324 b_components = constants[1]->GetVectorComponents(const_mgr);
325
326 // Fold each component of the vector.
327 for (uint32_t i = 0; i < a_components.size(); ++i) {
328 results_components.push_back(scalar_rule(vector_type->element_type(),
329 a_components[i],
330 b_components[i], const_mgr));
331 if (results_components[i] == nullptr) {
332 return nullptr;
333 }
334 }
335
336 // Build the constant object and return it.
337 std::vector<uint32_t> ids;
338 for (const analysis::Constant* member : results_components) {
339 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
340 }
341 return const_mgr->GetConstant(vector_type, ids);
342 } else {
343 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
344 }
345 };
346 }
347
348 // This macro defines a |UnaryScalarFoldingRule| that performs float to
349 // integer conversion.
350 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldFToIOp()351 UnaryScalarFoldingRule FoldFToIOp() {
352 return [](const analysis::Type* result_type, const analysis::Constant* a,
353 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
354 assert(result_type != nullptr && a != nullptr);
355 const analysis::Integer* integer_type = result_type->AsInteger();
356 const analysis::Float* float_type = a->type()->AsFloat();
357 assert(float_type != nullptr);
358 assert(integer_type != nullptr);
359 if (integer_type->width() != 32) return nullptr;
360 if (float_type->width() == 32) {
361 float fa = a->GetFloat();
362 uint32_t result = integer_type->IsSigned()
363 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
364 : static_cast<uint32_t>(fa);
365 std::vector<uint32_t> words = {result};
366 return const_mgr->GetConstant(result_type, words);
367 } else if (float_type->width() == 64) {
368 double fa = a->GetDouble();
369 uint32_t result = integer_type->IsSigned()
370 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
371 : static_cast<uint32_t>(fa);
372 std::vector<uint32_t> words = {result};
373 return const_mgr->GetConstant(result_type, words);
374 }
375 return nullptr;
376 };
377 }
378
379 // This function defines a |UnaryScalarFoldingRule| that performs integer to
380 // float conversion.
381 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldIToFOp()382 UnaryScalarFoldingRule FoldIToFOp() {
383 return [](const analysis::Type* result_type, const analysis::Constant* a,
384 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
385 assert(result_type != nullptr && a != nullptr);
386 const analysis::Integer* integer_type = a->type()->AsInteger();
387 const analysis::Float* float_type = result_type->AsFloat();
388 assert(float_type != nullptr);
389 assert(integer_type != nullptr);
390 if (integer_type->width() != 32) return nullptr;
391 uint32_t ua = a->GetU32();
392 if (float_type->width() == 32) {
393 float result_val = integer_type->IsSigned()
394 ? static_cast<float>(static_cast<int32_t>(ua))
395 : static_cast<float>(ua);
396 utils::FloatProxy<float> result(result_val);
397 std::vector<uint32_t> words = {result.data()};
398 return const_mgr->GetConstant(result_type, words);
399 } else if (float_type->width() == 64) {
400 double result_val = integer_type->IsSigned()
401 ? static_cast<double>(static_cast<int32_t>(ua))
402 : static_cast<double>(ua);
403 utils::FloatProxy<double> result(result_val);
404 std::vector<uint32_t> words = result.GetWords();
405 return const_mgr->GetConstant(result_type, words);
406 }
407 return nullptr;
408 };
409 }
410
// This macro expands to a |BinaryScalarFoldingRule| lambda that computes
// "a op b" for two scalar floating-point constants whose type is
// |result_type|. The operator |op| must work for both float and double, and
// use syntax "f1 op f2". Only 32- and 64-bit floats are folded; any other
// width yields nullptr. The result constant is obtained from the constant
// manager passed to the lambda. Locals carry an "_in_macro" suffix,
// presumably to avoid clashing with names at the expansion site.
#define FOLD_FPARITH_OP(op)                                                \
  [](const analysis::Type* result_type, const analysis::Constant* a,       \
     const analysis::Constant* b,                                          \
     analysis::ConstantManager* const_mgr_in_macro)                        \
      -> const analysis::Constant* {                                       \
    assert(result_type != nullptr && a != nullptr && b != nullptr);        \
    assert(result_type == a->type() && result_type == b->type());          \
    const analysis::Float* float_type_in_macro = result_type->AsFloat();   \
    assert(float_type_in_macro != nullptr);                                \
    if (float_type_in_macro->width() == 32) {                              \
      float fa = a->GetFloat();                                            \
      float fb = b->GetFloat();                                            \
      utils::FloatProxy<float> result_in_macro(fa op fb);                  \
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    } else if (float_type_in_macro->width() == 64) {                       \
      double fa = a->GetDouble();                                          \
      double fb = b->GetDouble();                                          \
      utils::FloatProxy<double> result_in_macro(fa op fb);                 \
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    }                                                                      \
    return nullptr;                                                        \
  }
437
438 // Define the folding rule for conversion between floating point and integer
FoldFToI()439 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
FoldIToF()440 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
441
442 // Define the folding rules for subtraction, addition, multiplication, and
443 // division for floating point values.
FoldFSub()444 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
FoldFAdd()445 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
FoldFMul()446 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
FoldFDiv()447 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
448
// Combines the raw result of a floating-point comparison with NaN handling.
// |op_result| is the C++ comparison "a op b". |op_unordered| is true when
// either operand is NaN.
// - Ordered comparisons (|need_ordered| true) hold only when the operands are
//   ordered and |op_result| holds.
// - Unordered comparisons hold when the operands are unordered or
//   |op_result| holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
459
// This macro expands to a |BinaryScalarFoldingRule| lambda that compares two
// scalar floating-point constants of the same type with |op| and returns a
// bool constant of type |result_type|. The operator |op| must work for both
// float and double, and use syntax "f1 op f2".
// |ord| selects the NaN behavior:
// - true: an ordered comparison, false whenever either operand is NaN;
// - false: an unordered comparison, true whenever either operand is NaN.
// Only 32- and 64-bit floats are folded; any other width yields nullptr.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* float_type = a->type()->AsFloat();             \
    assert(float_type != nullptr);                                        \
    if (float_type->width() == 32) {                                      \
      float fa = a->GetFloat();                                           \
      float fb = b->GetFloat();                                           \
      bool result = CompareFloatingPoint(                                 \
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
      std::vector<uint32_t> words = {uint32_t(result)};                   \
      return const_mgr->GetConstant(result_type, words);                  \
    } else if (float_type->width() == 64) {                               \
      double fa = a->GetDouble();                                         \
      double fb = b->GetDouble();                                         \
      bool result = CompareFloatingPoint(                                 \
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
      std::vector<uint32_t> words = {uint32_t(result)};                   \
      return const_mgr->GetConstant(result_type, words);                  \
    }                                                                     \
    return nullptr;                                                       \
  }
488
489 // Define the folding rules for ordered and unordered comparison for floating
490 // point values.
FoldFOrdEqual()491 ConstantFoldingRule FoldFOrdEqual() {
492 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
493 }
FoldFUnordEqual()494 ConstantFoldingRule FoldFUnordEqual() {
495 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
496 }
FoldFOrdNotEqual()497 ConstantFoldingRule FoldFOrdNotEqual() {
498 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
499 }
FoldFUnordNotEqual()500 ConstantFoldingRule FoldFUnordNotEqual() {
501 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
502 }
FoldFOrdLessThan()503 ConstantFoldingRule FoldFOrdLessThan() {
504 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
505 }
FoldFUnordLessThan()506 ConstantFoldingRule FoldFUnordLessThan() {
507 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
508 }
FoldFOrdGreaterThan()509 ConstantFoldingRule FoldFOrdGreaterThan() {
510 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
511 }
FoldFUnordGreaterThan()512 ConstantFoldingRule FoldFUnordGreaterThan() {
513 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
514 }
FoldFOrdLessThanEqual()515 ConstantFoldingRule FoldFOrdLessThanEqual() {
516 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
517 }
FoldFUnordLessThanEqual()518 ConstantFoldingRule FoldFUnordLessThanEqual() {
519 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
520 }
FoldFOrdGreaterThanEqual()521 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
522 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
523 }
FoldFUnordGreaterThanEqual()524 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
525 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
526 }
527
528 // Folds an OpDot where all of the inputs are constants to a
529 // constant. A new constant is created if necessary.
FoldOpDotWithConstants()530 ConstantFoldingRule FoldOpDotWithConstants() {
531 return [](IRContext* context, Instruction* inst,
532 const std::vector<const analysis::Constant*>& constants)
533 -> const analysis::Constant* {
534 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
535 analysis::TypeManager* type_mgr = context->get_type_mgr();
536 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
537 assert(new_type->AsFloat() && "OpDot should have a float return type.");
538 const analysis::Float* float_type = new_type->AsFloat();
539
540 if (!inst->IsFloatingPointFoldingAllowed()) {
541 return nullptr;
542 }
543
544 // If one of the operands is 0, then the result is 0.
545 bool has_zero_operand = false;
546
547 for (int i = 0; i < 2; ++i) {
548 if (constants[i]) {
549 if (constants[i]->AsNullConstant() ||
550 constants[i]->AsVectorConstant()->IsZero()) {
551 has_zero_operand = true;
552 break;
553 }
554 }
555 }
556
557 if (has_zero_operand) {
558 if (float_type->width() == 32) {
559 utils::FloatProxy<float> result(0.0f);
560 std::vector<uint32_t> words = result.GetWords();
561 return const_mgr->GetConstant(float_type, words);
562 }
563 if (float_type->width() == 64) {
564 utils::FloatProxy<double> result(0.0);
565 std::vector<uint32_t> words = result.GetWords();
566 return const_mgr->GetConstant(float_type, words);
567 }
568 return nullptr;
569 }
570
571 if (constants[0] == nullptr || constants[1] == nullptr) {
572 return nullptr;
573 }
574
575 std::vector<const analysis::Constant*> a_components;
576 std::vector<const analysis::Constant*> b_components;
577
578 a_components = constants[0]->GetVectorComponents(const_mgr);
579 b_components = constants[1]->GetVectorComponents(const_mgr);
580
581 utils::FloatProxy<double> result(0.0);
582 std::vector<uint32_t> words = result.GetWords();
583 const analysis::Constant* result_const =
584 const_mgr->GetConstant(float_type, words);
585 for (uint32_t i = 0; i < a_components.size(); ++i) {
586 if (a_components[i] == nullptr || b_components[i] == nullptr) {
587 return nullptr;
588 }
589
590 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
591 new_type, a_components[i], b_components[i], const_mgr);
592 result_const =
593 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
594 }
595 return result_const;
596 };
597 }
598
599 // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
600 // from zero.
FoldFNegateOp()601 UnaryScalarFoldingRule FoldFNegateOp() {
602 return [](const analysis::Type* result_type, const analysis::Constant* a,
603 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
604 assert(result_type != nullptr && a != nullptr);
605 assert(result_type == a->type());
606 const analysis::Float* float_type = result_type->AsFloat();
607 assert(float_type != nullptr);
608 if (float_type->width() == 32) {
609 float fa = a->GetFloat();
610 utils::FloatProxy<float> result(-fa);
611 std::vector<uint32_t> words = result.GetWords();
612 return const_mgr->GetConstant(result_type, words);
613 } else if (float_type->width() == 64) {
614 double da = a->GetDouble();
615 utils::FloatProxy<double> result(-da);
616 std::vector<uint32_t> words = result.GetWords();
617 return const_mgr->GetConstant(result_type, words);
618 }
619 return nullptr;
620 };
621 }
622
FoldFNegate()623 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
624
FoldFClampFeedingCompare(uint32_t cmp_opcode)625 ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
626 return [cmp_opcode](IRContext* context, Instruction* inst,
627 const std::vector<const analysis::Constant*>& constants)
628 -> const analysis::Constant* {
629 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
630 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
631
632 if (!inst->IsFloatingPointFoldingAllowed()) {
633 return nullptr;
634 }
635
636 uint32_t non_const_idx = (constants[0] ? 1 : 0);
637 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
638 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
639
640 analysis::TypeManager* type_mgr = context->get_type_mgr();
641 const analysis::Type* operand_type =
642 type_mgr->GetType(operand_inst->type_id());
643
644 if (!operand_type->AsFloat()) {
645 return nullptr;
646 }
647
648 if (operand_type->AsFloat()->width() != 32 &&
649 operand_type->AsFloat()->width() != 64) {
650 return nullptr;
651 }
652
653 if (operand_inst->opcode() != SpvOpExtInst) {
654 return nullptr;
655 }
656
657 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
658 return nullptr;
659 }
660
661 if (constants[1] == nullptr && constants[0] == nullptr) {
662 return nullptr;
663 }
664
665 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
666 const analysis::Constant* max_const =
667 const_mgr->FindDeclaredConstant(max_id);
668
669 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
670 const analysis::Constant* min_const =
671 const_mgr->FindDeclaredConstant(min_id);
672
673 bool found_result = false;
674 bool result = false;
675
676 switch (cmp_opcode) {
677 case SpvOpFOrdLessThan:
678 case SpvOpFUnordLessThan:
679 case SpvOpFOrdGreaterThanEqual:
680 case SpvOpFUnordGreaterThanEqual:
681 if (constants[0]) {
682 if (min_const) {
683 if (constants[0]->GetValueAsDouble() <
684 min_const->GetValueAsDouble()) {
685 found_result = true;
686 result = (cmp_opcode == SpvOpFOrdLessThan ||
687 cmp_opcode == SpvOpFUnordLessThan);
688 }
689 }
690 if (max_const) {
691 if (constants[0]->GetValueAsDouble() >=
692 max_const->GetValueAsDouble()) {
693 found_result = true;
694 result = !(cmp_opcode == SpvOpFOrdLessThan ||
695 cmp_opcode == SpvOpFUnordLessThan);
696 }
697 }
698 }
699
700 if (constants[1]) {
701 if (max_const) {
702 if (max_const->GetValueAsDouble() <
703 constants[1]->GetValueAsDouble()) {
704 found_result = true;
705 result = (cmp_opcode == SpvOpFOrdLessThan ||
706 cmp_opcode == SpvOpFUnordLessThan);
707 }
708 }
709
710 if (min_const) {
711 if (min_const->GetValueAsDouble() >=
712 constants[1]->GetValueAsDouble()) {
713 found_result = true;
714 result = !(cmp_opcode == SpvOpFOrdLessThan ||
715 cmp_opcode == SpvOpFUnordLessThan);
716 }
717 }
718 }
719 break;
720 case SpvOpFOrdGreaterThan:
721 case SpvOpFUnordGreaterThan:
722 case SpvOpFOrdLessThanEqual:
723 case SpvOpFUnordLessThanEqual:
724 if (constants[0]) {
725 if (min_const) {
726 if (constants[0]->GetValueAsDouble() <=
727 min_const->GetValueAsDouble()) {
728 found_result = true;
729 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
730 cmp_opcode == SpvOpFUnordLessThanEqual);
731 }
732 }
733 if (max_const) {
734 if (constants[0]->GetValueAsDouble() >
735 max_const->GetValueAsDouble()) {
736 found_result = true;
737 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
738 cmp_opcode == SpvOpFUnordLessThanEqual);
739 }
740 }
741 }
742
743 if (constants[1]) {
744 if (max_const) {
745 if (max_const->GetValueAsDouble() <=
746 constants[1]->GetValueAsDouble()) {
747 found_result = true;
748 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
749 cmp_opcode == SpvOpFUnordLessThanEqual);
750 }
751 }
752
753 if (min_const) {
754 if (min_const->GetValueAsDouble() >
755 constants[1]->GetValueAsDouble()) {
756 found_result = true;
757 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
758 cmp_opcode == SpvOpFUnordLessThanEqual);
759 }
760 }
761 }
762 break;
763 default:
764 return nullptr;
765 }
766
767 if (!found_result) {
768 return nullptr;
769 }
770
771 const analysis::Type* bool_type =
772 context->get_type_mgr()->GetType(inst->type_id());
773 const analysis::Constant* result_const =
774 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
775 assert(result_const);
776 return result_const;
777 };
778 }
779
780 } // namespace
781
ConstantFoldingRules()782 ConstantFoldingRules::ConstantFoldingRules() {
783 // Add all folding rules to the list for the opcodes to which they apply.
784 // Note that the order in which rules are added to the list matters. If a rule
785 // applies to the instruction, the rest of the rules will not be attempted.
786 // Take that into consideration.
787
788 rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
789
790 rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
791
792 rules_[SpvOpConvertFToS].push_back(FoldFToI());
793 rules_[SpvOpConvertFToU].push_back(FoldFToI());
794 rules_[SpvOpConvertSToF].push_back(FoldIToF());
795 rules_[SpvOpConvertUToF].push_back(FoldIToF());
796
797 rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
798 rules_[SpvOpFAdd].push_back(FoldFAdd());
799 rules_[SpvOpFDiv].push_back(FoldFDiv());
800 rules_[SpvOpFMul].push_back(FoldFMul());
801 rules_[SpvOpFSub].push_back(FoldFSub());
802
803 rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
804
805 rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
806
807 rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
808
809 rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
810
811 rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
812 rules_[SpvOpFOrdLessThan].push_back(
813 FoldFClampFeedingCompare(SpvOpFOrdLessThan));
814
815 rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
816 rules_[SpvOpFUnordLessThan].push_back(
817 FoldFClampFeedingCompare(SpvOpFUnordLessThan));
818
819 rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
820 rules_[SpvOpFOrdGreaterThan].push_back(
821 FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
822
823 rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
824 rules_[SpvOpFUnordGreaterThan].push_back(
825 FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
826
827 rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
828 rules_[SpvOpFOrdLessThanEqual].push_back(
829 FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
830
831 rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
832 rules_[SpvOpFUnordLessThanEqual].push_back(
833 FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
834
835 rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
836 rules_[SpvOpFOrdGreaterThanEqual].push_back(
837 FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
838
839 rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
840 rules_[SpvOpFUnordGreaterThanEqual].push_back(
841 FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
842
843 rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
844 rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
845
846 rules_[SpvOpFNegate].push_back(FoldFNegate());
847 }
848 } // namespace opt
849 } // namespace spvtools
850