1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/strength_reduction_pass.h"
16 
17 #include <algorithm>
18 #include <cstdio>
19 #include <cstring>
20 #include <memory>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25 
26 #include "source/opt/def_use_manager.h"
27 #include "source/opt/ir_context.h"
28 #include "source/opt/log.h"
29 #include "source/opt/reflect.h"
30 
31 namespace {
// Count the number of trailing zeros in the binary representation of
// |constVal|.  Returns 32 when |constVal| is 0: every bit is then a trailing
// zero, and without this guard the loop below would never terminate.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // Faster if we use the hardware count trailing zeros instruction.
  // If not available, we could create a table.
  if (constVal == 0) return 32;

  uint32_t shiftAmount = 0;
  while ((constVal & 1) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}
44 
// Returns true iff exactly one bit of |val| is set, i.e. |val| is 2^k for
// some k in [0, 31].  Zero has no bits set and is therefore rejected.
bool IsPowerOf2(uint32_t val) {
  // Subtracting 1 flips the lowest set bit and every bit below it, so the AND
  // is zero exactly when that lowest set bit was also the only one.
  const uint32_t lowest_bit_cleared = val & (val - 1);
  return val != 0 && lowest_bit_cleared == 0;
}
53 
54 }  // namespace
55 
56 namespace spvtools {
57 namespace opt {
58 
Process()59 Pass::Status StrengthReductionPass::Process() {
60   // Initialize the member variables on a per module basis.
61   bool modified = false;
62   int32_type_id_ = 0;
63   uint32_type_id_ = 0;
64   std::memset(constant_ids_, 0, sizeof(constant_ids_));
65 
66   FindIntTypesAndConstants();
67   modified = ScanFunctions();
68   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
69 }
70 
ReplaceMultiplyByPowerOf2(BasicBlock::iterator * inst)71 bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
72     BasicBlock::iterator* inst) {
73   assert((*inst)->opcode() == SpvOp::SpvOpIMul &&
74          "Only works for multiplication of integers.");
75   bool modified = false;
76 
77   // Currently only works on 32-bit integers.
78   if ((*inst)->type_id() != int32_type_id_ &&
79       (*inst)->type_id() != uint32_type_id_) {
80     return modified;
81   }
82 
83   // Check the operands for a constant that is a power of 2.
84   for (int i = 0; i < 2; i++) {
85     uint32_t opId = (*inst)->GetSingleWordInOperand(i);
86     Instruction* opInst = get_def_use_mgr()->GetDef(opId);
87     if (opInst->opcode() == SpvOp::SpvOpConstant) {
88       // We found a constant operand.
89       uint32_t constVal = opInst->GetSingleWordOperand(2);
90 
91       if (IsPowerOf2(constVal)) {
92         modified = true;
93         uint32_t shiftAmount = CountTrailingZeros(constVal);
94         uint32_t shiftConstResultId = GetConstantId(shiftAmount);
95 
96         // Create the new instruction.
97         uint32_t newResultId = TakeNextId();
98         std::vector<Operand> newOperands;
99         newOperands.push_back((*inst)->GetInOperand(1 - i));
100         Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
101                              {shiftConstResultId});
102         newOperands.push_back(shiftOperand);
103         std::unique_ptr<Instruction> newInstruction(
104             new Instruction(context(), SpvOp::SpvOpShiftLeftLogical,
105                             (*inst)->type_id(), newResultId, newOperands));
106 
107         // Insert the new instruction and update the data structures.
108         (*inst) = (*inst).InsertBefore(std::move(newInstruction));
109         get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
110         ++(*inst);
111         context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
112 
113         // Remove the old instruction.
114         Instruction* inst_to_delete = &*(*inst);
115         --(*inst);
116         context()->KillInst(inst_to_delete);
117 
118         // We do not want to replace the instruction twice if both operands
119         // are constants that are a power of 2.  So we break here.
120         break;
121       }
122     }
123   }
124 
125   return modified;
126 }
127 
FindIntTypesAndConstants()128 void StrengthReductionPass::FindIntTypesAndConstants() {
129   analysis::Integer int32(32, true);
130   int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
131   analysis::Integer uint32(32, false);
132   uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
133   for (auto iter = get_module()->types_values_begin();
134        iter != get_module()->types_values_end(); ++iter) {
135     switch (iter->opcode()) {
136       case SpvOp::SpvOpConstant:
137         if (iter->type_id() == uint32_type_id_) {
138           uint32_t value = iter->GetSingleWordOperand(2);
139           if (value <= 32) constant_ids_[value] = iter->result_id();
140         }
141         break;
142       default:
143         break;
144     }
145   }
146 }
147 
GetConstantId(uint32_t val)148 uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
149   assert(val <= 32 &&
150          "This function does not handle constants larger than 32.");
151 
152   if (constant_ids_[val] == 0) {
153     if (uint32_type_id_ == 0) {
154       analysis::Integer uint(32, false);
155       uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
156     }
157 
158     // Construct the constant.
159     uint32_t resultId = TakeNextId();
160     Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
161                      {val});
162     std::unique_ptr<Instruction> newConstant(
163         new Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_,
164                         resultId, {constant}));
165     get_module()->AddGlobalValue(std::move(newConstant));
166 
167     // Notify the DefUseManager about this constant.
168     auto constantIter = --get_module()->types_values_end();
169     get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
170 
171     // Store the result id for next time.
172     constant_ids_[val] = resultId;
173   }
174 
175   return constant_ids_[val];
176 }
177 
ScanFunctions()178 bool StrengthReductionPass::ScanFunctions() {
179   // I did not use |ForEachInst| in the module because the function that acts on
180   // the instruction gets a pointer to the instruction.  We cannot use that to
181   // insert a new instruction.  I want an iterator.
182   bool modified = false;
183   for (auto& func : *get_module()) {
184     for (auto& bb : func) {
185       for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
186         switch (inst->opcode()) {
187           case SpvOp::SpvOpIMul:
188             if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
189             break;
190           default:
191             break;
192         }
193       }
194     }
195   }
196   return modified;
197 }
198 
199 }  // namespace opt
200 }  // namespace spvtools
201