1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/private_to_local_pass.h"
16 
17 #include <memory>
18 #include <utility>
19 #include <vector>
20 
21 #include "source/opt/ir_context.h"
22 
23 namespace spvtools {
24 namespace opt {
namespace {

// In-operand index at which this pass reads and rewrites the storage class of
// an OpVariable instruction.
constexpr uint32_t kVariableStorageClassInIdx = 0;

// In-operand index at which this pass reads the pointee type id from the
// instruction defining a variable's (pointer) result type.
constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;

}  // namespace
31 
Process()32 Pass::Status PrivateToLocalPass::Process() {
33   bool modified = false;
34 
35   // Private variables require the shader capability.  If this is not a shader,
36   // there is no work to do.
37   if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses))
38     return Status::SuccessWithoutChange;
39 
40   std::vector<std::pair<Instruction*, Function*>> variables_to_move;
41   for (auto& inst : context()->types_values()) {
42     if (inst.opcode() != SpvOpVariable) {
43       continue;
44     }
45 
46     if (inst.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
47         SpvStorageClassPrivate) {
48       continue;
49     }
50 
51     Function* target_function = FindLocalFunction(inst);
52     if (target_function != nullptr) {
53       variables_to_move.push_back({&inst, target_function});
54     }
55   }
56 
57   modified = !variables_to_move.empty();
58   for (auto p : variables_to_move) {
59     MoveVariable(p.first, p.second);
60   }
61 
62   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
63 }
64 
FindLocalFunction(const Instruction & inst) const65 Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
66   bool found_first_use = false;
67   Function* target_function = nullptr;
68   context()->get_def_use_mgr()->ForEachUser(
69       inst.result_id(),
70       [&target_function, &found_first_use, this](Instruction* use) {
71         BasicBlock* current_block = context()->get_instr_block(use);
72         if (current_block == nullptr) {
73           return;
74         }
75 
76         if (!IsValidUse(use)) {
77           found_first_use = true;
78           target_function = nullptr;
79           return;
80         }
81         Function* current_function = current_block->GetParent();
82         if (!found_first_use) {
83           found_first_use = true;
84           target_function = current_function;
85         } else if (target_function != current_function) {
86           target_function = nullptr;
87         }
88       });
89   return target_function;
90 }  // namespace opt
91 
MoveVariable(Instruction * variable,Function * function)92 void PrivateToLocalPass::MoveVariable(Instruction* variable,
93                                       Function* function) {
94   // The variable needs to be removed from the global section, and placed in the
95   // header of the function.  First step remove from the global list.
96   variable->RemoveFromList();
97   std::unique_ptr<Instruction> var(variable);  // Take ownership.
98   context()->ForgetUses(variable);
99 
100   // Update the storage class of the variable.
101   variable->SetInOperand(kVariableStorageClassInIdx, {SpvStorageClassFunction});
102 
103   // Update the type as well.
104   uint32_t new_type_id = GetNewType(variable->type_id());
105   variable->SetResultType(new_type_id);
106 
107   // Place the variable at the start of the first basic block.
108   context()->AnalyzeUses(variable);
109   context()->set_instr_block(variable, &*function->begin());
110   function->begin()->begin()->InsertBefore(move(var));
111 
112   // Update uses where the type may have changed.
113   UpdateUses(variable->result_id());
114 }
115 
GetNewType(uint32_t old_type_id)116 uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
117   auto type_mgr = context()->get_type_mgr();
118   Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id);
119   uint32_t pointee_type_id =
120       old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
121   uint32_t new_type_id =
122       type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction);
123   context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
124   return new_type_id;
125 }
126 
IsValidUse(const Instruction * inst) const127 bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const {
128   // The cases in this switch have to match the cases in |UpdateUse|.
129   // If we don't know how to update it, it is not valid.
130   switch (inst->opcode()) {
131     case SpvOpLoad:
132     case SpvOpStore:
133     case SpvOpImageTexelPointer:  // Treat like a load
134       return true;
135     case SpvOpAccessChain:
136       return context()->get_def_use_mgr()->WhileEachUser(
137           inst, [this](const Instruction* user) {
138             if (!IsValidUse(user)) return false;
139             return true;
140           });
141     case SpvOpName:
142       return true;
143     default:
144       return spvOpcodeIsDecoration(inst->opcode());
145   }
146 }
147 
UpdateUse(Instruction * inst)148 void PrivateToLocalPass::UpdateUse(Instruction* inst) {
149   // The cases in this switch have to match the cases in |IsValidUse|.  If we
150   // don't think it is valid, the optimization will not view the variable as a
151   // candidate, and therefore the use will not be updated.
152   switch (inst->opcode()) {
153     case SpvOpLoad:
154     case SpvOpStore:
155     case SpvOpImageTexelPointer:  // Treat like a load
156       // The type is fine because it is the type pointed to, and that does not
157       // change.
158       break;
159     case SpvOpAccessChain:
160       context()->ForgetUses(inst);
161       inst->SetResultType(GetNewType(inst->type_id()));
162       context()->AnalyzeUses(inst);
163 
164       // Update uses where the type may have changed.
165       UpdateUses(inst->result_id());
166       break;
167     case SpvOpName:
168       break;
169     default:
170       assert(spvOpcodeIsDecoration(inst->opcode()) &&
171              "Do not know how to update the type for this instruction.");
172       break;
173   }
174 }
UpdateUses(uint32_t id)175 void PrivateToLocalPass::UpdateUses(uint32_t id) {
176   std::vector<Instruction*> uses;
177   context()->get_def_use_mgr()->ForEachUser(
178       id, [&uses](Instruction* use) { uses.push_back(use); });
179 
180   for (Instruction* use : uses) {
181     UpdateUse(use);
182   }
183 }
184 
185 }  // namespace opt
186 }  // namespace spvtools
187