1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/fold.h"
16
17 #include <cassert>
18 #include <cstdint>
19 #include <vector>
20
21 #include "source/opt/const_folding_rules.h"
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/folding_rules.h"
24 #include "source/opt/ir_builder.h"
25 #include "source/opt/ir_context.h"
26
27 namespace spvtools {
28 namespace opt {
namespace {

// Fallback definitions for toolchains whose <cstdint> does not provide the
// integer limit macros. Note that macros are not scoped by this namespace; the
// namespace only marks the section as file-local configuration.

#ifndef INT32_MIN
// Spelled as (-INT32_MAX - 1): the literal 2147483648 does not fit in a 32-bit
// int, so "(-2147483648)" would have type long/long long rather than int.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif

}  // namespace
44
UnaryOperate(SpvOp opcode,uint32_t operand) const45 uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const {
46 switch (opcode) {
47 // Arthimetics
48 case SpvOp::SpvOpSNegate:
49 return -static_cast<int32_t>(operand);
50 case SpvOp::SpvOpNot:
51 return ~operand;
52 case SpvOp::SpvOpLogicalNot:
53 return !static_cast<bool>(operand);
54 default:
55 assert(false &&
56 "Unsupported unary operation for OpSpecConstantOp instruction");
57 return 0u;
58 }
59 }
60
BinaryOperate(SpvOp opcode,uint32_t a,uint32_t b) const61 uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a,
62 uint32_t b) const {
63 switch (opcode) {
64 // Arthimetics
65 case SpvOp::SpvOpIAdd:
66 return a + b;
67 case SpvOp::SpvOpISub:
68 return a - b;
69 case SpvOp::SpvOpIMul:
70 return a * b;
71 case SpvOp::SpvOpUDiv:
72 if (b != 0) {
73 return a / b;
74 } else {
75 // Dividing by 0 is undefined, so we will just pick 0.
76 return 0;
77 }
78 case SpvOp::SpvOpSDiv:
79 if (b != 0u) {
80 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
81 } else {
82 // Dividing by 0 is undefined, so we will just pick 0.
83 return 0;
84 }
85 case SpvOp::SpvOpSRem: {
86 // The sign of non-zero result comes from the first operand: a. This is
87 // guaranteed by C++11 rules for integer division operator. The division
88 // result is rounded toward zero, so the result of '%' has the sign of
89 // the first operand.
90 if (b != 0u) {
91 return static_cast<int32_t>(a) % static_cast<int32_t>(b);
92 } else {
93 // Remainder when dividing with 0 is undefined, so we will just pick 0.
94 return 0;
95 }
96 }
97 case SpvOp::SpvOpSMod: {
98 // The sign of non-zero result comes from the second operand: b
99 if (b != 0u) {
100 int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
101 int32_t b_prim = static_cast<int32_t>(b);
102 return (rem + b_prim) % b_prim;
103 } else {
104 // Mod with 0 is undefined, so we will just pick 0.
105 return 0;
106 }
107 }
108 case SpvOp::SpvOpUMod:
109 if (b != 0u) {
110 return (a % b);
111 } else {
112 // Mod with 0 is undefined, so we will just pick 0.
113 return 0;
114 }
115
116 // Shifting
117 case SpvOp::SpvOpShiftRightLogical:
118 if (b > 32) {
119 // This is undefined behaviour. Choose 0 for consistency.
120 return 0;
121 }
122 return a >> b;
123 case SpvOp::SpvOpShiftRightArithmetic:
124 if (b > 32) {
125 // This is undefined behaviour. Choose 0 for consistency.
126 return 0;
127 }
128 return (static_cast<int32_t>(a)) >> b;
129 case SpvOp::SpvOpShiftLeftLogical:
130 if (b > 32) {
131 // This is undefined behaviour. Choose 0 for consistency.
132 return 0;
133 }
134 return a << b;
135
136 // Bitwise operations
137 case SpvOp::SpvOpBitwiseOr:
138 return a | b;
139 case SpvOp::SpvOpBitwiseAnd:
140 return a & b;
141 case SpvOp::SpvOpBitwiseXor:
142 return a ^ b;
143
144 // Logical
145 case SpvOp::SpvOpLogicalEqual:
146 return (static_cast<bool>(a)) == (static_cast<bool>(b));
147 case SpvOp::SpvOpLogicalNotEqual:
148 return (static_cast<bool>(a)) != (static_cast<bool>(b));
149 case SpvOp::SpvOpLogicalOr:
150 return (static_cast<bool>(a)) || (static_cast<bool>(b));
151 case SpvOp::SpvOpLogicalAnd:
152 return (static_cast<bool>(a)) && (static_cast<bool>(b));
153
154 // Comparison
155 case SpvOp::SpvOpIEqual:
156 return a == b;
157 case SpvOp::SpvOpINotEqual:
158 return a != b;
159 case SpvOp::SpvOpULessThan:
160 return a < b;
161 case SpvOp::SpvOpSLessThan:
162 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
163 case SpvOp::SpvOpUGreaterThan:
164 return a > b;
165 case SpvOp::SpvOpSGreaterThan:
166 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
167 case SpvOp::SpvOpULessThanEqual:
168 return a <= b;
169 case SpvOp::SpvOpSLessThanEqual:
170 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
171 case SpvOp::SpvOpUGreaterThanEqual:
172 return a >= b;
173 case SpvOp::SpvOpSGreaterThanEqual:
174 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
175 default:
176 assert(false &&
177 "Unsupported binary operation for OpSpecConstantOp instruction");
178 return 0u;
179 }
180 }
181
TernaryOperate(SpvOp opcode,uint32_t a,uint32_t b,uint32_t c) const182 uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b,
183 uint32_t c) const {
184 switch (opcode) {
185 case SpvOp::SpvOpSelect:
186 return (static_cast<bool>(a)) ? b : c;
187 default:
188 assert(false &&
189 "Unsupported ternary operation for OpSpecConstantOp instruction");
190 return 0u;
191 }
192 }
193
OperateWords(SpvOp opcode,const std::vector<uint32_t> & operand_words) const194 uint32_t InstructionFolder::OperateWords(
195 SpvOp opcode, const std::vector<uint32_t>& operand_words) const {
196 switch (operand_words.size()) {
197 case 1:
198 return UnaryOperate(opcode, operand_words.front());
199 case 2:
200 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
201 case 3:
202 return TernaryOperate(opcode, operand_words[0], operand_words[1],
203 operand_words[2]);
204 default:
205 assert(false && "Invalid number of operands");
206 return 0;
207 }
208 }
209
FoldInstructionInternal(Instruction * inst) const210 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
211 auto identity_map = [](uint32_t id) { return id; };
212 Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
213 if (folded_inst != nullptr) {
214 inst->SetOpcode(SpvOpCopyObject);
215 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
216 return true;
217 }
218
219 SpvOp opcode = inst->opcode();
220 analysis::ConstantManager* const_manager = context_->get_constant_mgr();
221
222 std::vector<const analysis::Constant*> constants =
223 const_manager->GetOperandConstants(inst);
224
225 for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) {
226 if (rule(context_, inst, constants)) {
227 return true;
228 }
229 }
230 return false;
231 }
232
233 // Returns the result of performing an operation on scalar constant operands.
234 // This function extracts the operand values as 32 bit words and returns the
235 // result in 32 bit word. Scalar constants with longer than 32-bit width are
236 // not accepted in this function.
FoldScalars(SpvOp opcode,const std::vector<const analysis::Constant * > & operands) const237 uint32_t InstructionFolder::FoldScalars(
238 SpvOp opcode,
239 const std::vector<const analysis::Constant*>& operands) const {
240 assert(IsFoldableOpcode(opcode) &&
241 "Unhandled instruction opcode in FoldScalars");
242 std::vector<uint32_t> operand_values_in_raw_words;
243 for (const auto& operand : operands) {
244 if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
245 const auto& scalar_words = scalar->words();
246 assert(scalar_words.size() == 1 &&
247 "Scalar constants with longer than 32-bit width are not allowed "
248 "in FoldScalars()");
249 operand_values_in_raw_words.push_back(scalar_words.front());
250 } else if (operand->AsNullConstant()) {
251 operand_values_in_raw_words.push_back(0u);
252 } else {
253 assert(false &&
254 "FoldScalars() only accepts ScalarConst or NullConst type of "
255 "constant");
256 }
257 }
258 return OperateWords(opcode, operand_values_in_raw_words);
259 }
260
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const261 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
262 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
263 uint32_t* result) const {
264 SpvOp opcode = inst->opcode();
265 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
266
267 uint32_t ids[2];
268 const analysis::IntConstant* constants[2];
269 for (uint32_t i = 0; i < 2; i++) {
270 const Operand* operand = &inst->GetInOperand(i);
271 if (operand->type != SPV_OPERAND_TYPE_ID) {
272 return false;
273 }
274 ids[i] = id_map(operand->words[0]);
275 const analysis::Constant* constant =
276 const_manger->FindDeclaredConstant(ids[i]);
277 constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
278 }
279
280 switch (opcode) {
281 // Arthimetics
282 case SpvOp::SpvOpIMul:
283 for (uint32_t i = 0; i < 2; i++) {
284 if (constants[i] != nullptr && constants[i]->IsZero()) {
285 *result = 0;
286 return true;
287 }
288 }
289 break;
290 case SpvOp::SpvOpUDiv:
291 case SpvOp::SpvOpSDiv:
292 case SpvOp::SpvOpSRem:
293 case SpvOp::SpvOpSMod:
294 case SpvOp::SpvOpUMod:
295 // This changes undefined behaviour (ie divide by 0) into a 0.
296 for (uint32_t i = 0; i < 2; i++) {
297 if (constants[i] != nullptr && constants[i]->IsZero()) {
298 *result = 0;
299 return true;
300 }
301 }
302 break;
303
304 // Shifting
305 case SpvOp::SpvOpShiftRightLogical:
306 case SpvOp::SpvOpShiftLeftLogical:
307 if (constants[1] != nullptr) {
308 // When shifting by a value larger than the size of the result, the
309 // result is undefined. We are setting the undefined behaviour to a
310 // result of 0.
311 uint32_t shift_amount = constants[1]->GetU32BitValue();
312 if (shift_amount >= 32) {
313 *result = 0;
314 return true;
315 }
316 }
317 break;
318
319 // Bitwise operations
320 case SpvOp::SpvOpBitwiseOr:
321 for (uint32_t i = 0; i < 2; i++) {
322 if (constants[i] != nullptr) {
323 // TODO: Change the mask against a value based on the bit width of the
324 // instruction result type. This way we can handle say 16-bit values
325 // as well.
326 uint32_t mask = constants[i]->GetU32BitValue();
327 if (mask == 0xFFFFFFFF) {
328 *result = 0xFFFFFFFF;
329 return true;
330 }
331 }
332 }
333 break;
334 case SpvOp::SpvOpBitwiseAnd:
335 for (uint32_t i = 0; i < 2; i++) {
336 if (constants[i] != nullptr) {
337 if (constants[i]->IsZero()) {
338 *result = 0;
339 return true;
340 }
341 }
342 }
343 break;
344
345 // Comparison
346 case SpvOp::SpvOpULessThan:
347 if (constants[0] != nullptr &&
348 constants[0]->GetU32BitValue() == UINT32_MAX) {
349 *result = false;
350 return true;
351 }
352 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
353 *result = false;
354 return true;
355 }
356 break;
357 case SpvOp::SpvOpSLessThan:
358 if (constants[0] != nullptr &&
359 constants[0]->GetS32BitValue() == INT32_MAX) {
360 *result = false;
361 return true;
362 }
363 if (constants[1] != nullptr &&
364 constants[1]->GetS32BitValue() == INT32_MIN) {
365 *result = false;
366 return true;
367 }
368 break;
369 case SpvOp::SpvOpUGreaterThan:
370 if (constants[0] != nullptr && constants[0]->IsZero()) {
371 *result = false;
372 return true;
373 }
374 if (constants[1] != nullptr &&
375 constants[1]->GetU32BitValue() == UINT32_MAX) {
376 *result = false;
377 return true;
378 }
379 break;
380 case SpvOp::SpvOpSGreaterThan:
381 if (constants[0] != nullptr &&
382 constants[0]->GetS32BitValue() == INT32_MIN) {
383 *result = false;
384 return true;
385 }
386 if (constants[1] != nullptr &&
387 constants[1]->GetS32BitValue() == INT32_MAX) {
388 *result = false;
389 return true;
390 }
391 break;
392 case SpvOp::SpvOpULessThanEqual:
393 if (constants[0] != nullptr && constants[0]->IsZero()) {
394 *result = true;
395 return true;
396 }
397 if (constants[1] != nullptr &&
398 constants[1]->GetU32BitValue() == UINT32_MAX) {
399 *result = true;
400 return true;
401 }
402 break;
403 case SpvOp::SpvOpSLessThanEqual:
404 if (constants[0] != nullptr &&
405 constants[0]->GetS32BitValue() == INT32_MIN) {
406 *result = true;
407 return true;
408 }
409 if (constants[1] != nullptr &&
410 constants[1]->GetS32BitValue() == INT32_MAX) {
411 *result = true;
412 return true;
413 }
414 break;
415 case SpvOp::SpvOpUGreaterThanEqual:
416 if (constants[0] != nullptr &&
417 constants[0]->GetU32BitValue() == UINT32_MAX) {
418 *result = true;
419 return true;
420 }
421 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
422 *result = true;
423 return true;
424 }
425 break;
426 case SpvOp::SpvOpSGreaterThanEqual:
427 if (constants[0] != nullptr &&
428 constants[0]->GetS32BitValue() == INT32_MAX) {
429 *result = true;
430 return true;
431 }
432 if (constants[1] != nullptr &&
433 constants[1]->GetS32BitValue() == INT32_MIN) {
434 *result = true;
435 return true;
436 }
437 break;
438 default:
439 break;
440 }
441 return false;
442 }
443
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const444 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
445 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
446 uint32_t* result) const {
447 SpvOp opcode = inst->opcode();
448 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
449
450 uint32_t ids[2];
451 const analysis::BoolConstant* constants[2];
452 for (uint32_t i = 0; i < 2; i++) {
453 const Operand* operand = &inst->GetInOperand(i);
454 if (operand->type != SPV_OPERAND_TYPE_ID) {
455 return false;
456 }
457 ids[i] = id_map(operand->words[0]);
458 const analysis::Constant* constant =
459 const_manger->FindDeclaredConstant(ids[i]);
460 constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
461 }
462
463 switch (opcode) {
464 // Logical
465 case SpvOp::SpvOpLogicalOr:
466 for (uint32_t i = 0; i < 2; i++) {
467 if (constants[i] != nullptr) {
468 if (constants[i]->value()) {
469 *result = true;
470 return true;
471 }
472 }
473 }
474 break;
475 case SpvOp::SpvOpLogicalAnd:
476 for (uint32_t i = 0; i < 2; i++) {
477 if (constants[i] != nullptr) {
478 if (!constants[i]->value()) {
479 *result = false;
480 return true;
481 }
482 }
483 }
484 break;
485
486 default:
487 break;
488 }
489 return false;
490 }
491
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const492 bool InstructionFolder::FoldIntegerOpToConstant(
493 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
494 uint32_t* result) const {
495 assert(IsFoldableOpcode(inst->opcode()) &&
496 "Unhandled instruction opcode in FoldScalars");
497 switch (inst->NumInOperands()) {
498 case 2:
499 return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
500 FoldBinaryBooleanOpToConstant(inst, id_map, result);
501 default:
502 return false;
503 }
504 }
505
FoldVectors(SpvOp opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const506 std::vector<uint32_t> InstructionFolder::FoldVectors(
507 SpvOp opcode, uint32_t num_dims,
508 const std::vector<const analysis::Constant*>& operands) const {
509 assert(IsFoldableOpcode(opcode) &&
510 "Unhandled instruction opcode in FoldVectors");
511 std::vector<uint32_t> result;
512 for (uint32_t d = 0; d < num_dims; d++) {
513 std::vector<uint32_t> operand_values_for_one_dimension;
514 for (const auto& operand : operands) {
515 if (const analysis::VectorConstant* vector_operand =
516 operand->AsVectorConstant()) {
517 // Extract the raw value of the scalar component constants
518 // in 32-bit words here. The reason of not using FoldScalars() here
519 // is that we do not create temporary null constants as components
520 // when the vector operand is a NullConstant because Constant creation
521 // may need extra checks for the validity and that is not manageed in
522 // here.
523 if (const analysis::ScalarConstant* scalar_component =
524 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
525 const auto& scalar_words = scalar_component->words();
526 assert(
527 scalar_words.size() == 1 &&
528 "Vector components with longer than 32-bit width are not allowed "
529 "in FoldVectors()");
530 operand_values_for_one_dimension.push_back(scalar_words.front());
531 } else if (operand->AsNullConstant()) {
532 operand_values_for_one_dimension.push_back(0u);
533 } else {
534 assert(false &&
535 "VectorConst should only has ScalarConst or NullConst as "
536 "components");
537 }
538 } else if (operand->AsNullConstant()) {
539 operand_values_for_one_dimension.push_back(0u);
540 } else {
541 assert(false &&
542 "FoldVectors() only accepts VectorConst or NullConst type of "
543 "constant");
544 }
545 }
546 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
547 }
548 return result;
549 }
550
IsFoldableOpcode(SpvOp opcode) const551 bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const {
552 // NOTE: Extend to more opcodes as new cases are handled in the folder
553 // functions.
554 switch (opcode) {
555 case SpvOp::SpvOpBitwiseAnd:
556 case SpvOp::SpvOpBitwiseOr:
557 case SpvOp::SpvOpBitwiseXor:
558 case SpvOp::SpvOpIAdd:
559 case SpvOp::SpvOpIEqual:
560 case SpvOp::SpvOpIMul:
561 case SpvOp::SpvOpINotEqual:
562 case SpvOp::SpvOpISub:
563 case SpvOp::SpvOpLogicalAnd:
564 case SpvOp::SpvOpLogicalEqual:
565 case SpvOp::SpvOpLogicalNot:
566 case SpvOp::SpvOpLogicalNotEqual:
567 case SpvOp::SpvOpLogicalOr:
568 case SpvOp::SpvOpNot:
569 case SpvOp::SpvOpSDiv:
570 case SpvOp::SpvOpSelect:
571 case SpvOp::SpvOpSGreaterThan:
572 case SpvOp::SpvOpSGreaterThanEqual:
573 case SpvOp::SpvOpShiftLeftLogical:
574 case SpvOp::SpvOpShiftRightArithmetic:
575 case SpvOp::SpvOpShiftRightLogical:
576 case SpvOp::SpvOpSLessThan:
577 case SpvOp::SpvOpSLessThanEqual:
578 case SpvOp::SpvOpSMod:
579 case SpvOp::SpvOpSNegate:
580 case SpvOp::SpvOpSRem:
581 case SpvOp::SpvOpUDiv:
582 case SpvOp::SpvOpUGreaterThan:
583 case SpvOp::SpvOpUGreaterThanEqual:
584 case SpvOp::SpvOpULessThan:
585 case SpvOp::SpvOpULessThanEqual:
586 case SpvOp::SpvOpUMod:
587 return true;
588 default:
589 return false;
590 }
591 }
592
IsFoldableConstant(const analysis::Constant * cst) const593 bool InstructionFolder::IsFoldableConstant(
594 const analysis::Constant* cst) const {
595 // Currently supported constants are 32-bit values or null constants.
596 if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
597 return scalar->words().size() == 1;
598 else
599 return cst->AsNullConstant() != nullptr;
600 }
601
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const602 Instruction* InstructionFolder::FoldInstructionToConstant(
603 Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
604 analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
605
606 if (!inst->IsFoldableByFoldScalar() &&
607 !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
608 return nullptr;
609 }
610 // Collect the values of the constant parameters.
611 std::vector<const analysis::Constant*> constants;
612 bool missing_constants = false;
613 inst->ForEachInId([&constants, &missing_constants, const_mgr,
614 &id_map](uint32_t* op_id) {
615 uint32_t id = id_map(*op_id);
616 const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
617 if (!const_op) {
618 constants.push_back(nullptr);
619 missing_constants = true;
620 } else {
621 constants.push_back(const_op);
622 }
623 });
624
625 if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
626 const analysis::Constant* folded_const = nullptr;
627 for (auto rule :
628 GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
629 folded_const = rule(context_, inst, constants);
630 if (folded_const != nullptr) {
631 Instruction* const_inst =
632 const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
633 assert(const_inst->type_id() == inst->type_id());
634 // May be a new instruction that needs to be analysed.
635 context_->UpdateDefUse(const_inst);
636 return const_inst;
637 }
638 }
639 }
640
641 uint32_t result_val = 0;
642 bool successful = false;
643 // If all parameters are constant, fold the instruction to a constant.
644 if (!missing_constants && inst->IsFoldableByFoldScalar()) {
645 result_val = FoldScalars(inst->opcode(), constants);
646 successful = true;
647 }
648
649 if (!successful && inst->IsFoldableByFoldScalar()) {
650 successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
651 }
652
653 if (successful) {
654 const analysis::Constant* result_const =
655 const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
656 Instruction* folded_inst =
657 const_mgr->GetDefiningInstruction(result_const, inst->type_id());
658 return folded_inst;
659 }
660 return nullptr;
661 }
662
IsFoldableType(Instruction * type_inst) const663 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
664 // Support 32-bit integers.
665 if (type_inst->opcode() == SpvOpTypeInt) {
666 return type_inst->GetSingleWordInOperand(0) == 32;
667 }
668 // Support booleans.
669 if (type_inst->opcode() == SpvOpTypeBool) {
670 return true;
671 }
672 // Nothing else yet.
673 return false;
674 }
675
FoldInstruction(Instruction * inst) const676 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
677 bool modified = false;
678 Instruction* folded_inst(inst);
679 while (folded_inst->opcode() != SpvOpCopyObject &&
680 FoldInstructionInternal(&*folded_inst)) {
681 modified = true;
682 }
683 return modified;
684 }
685
686 } // namespace opt
687 } // namespace spvtools
688