1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/private_to_local_pass.h"
16
17 #include <memory>
18 #include <utility>
19 #include <vector>
20
21 #include "source/opt/ir_context.h"
22
23 namespace spvtools {
24 namespace opt {
namespace {

// In-operand index of the Storage Class operand of OpVariable.
constexpr uint32_t kVariableStorageClassInIdx = 0;

// In-operand index of the pointee Type operand of OpTypePointer
// (in-operand 0 is the pointer's Storage Class).
constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;

}  // namespace
31
Process()32 Pass::Status PrivateToLocalPass::Process() {
33 bool modified = false;
34
35 // Private variables require the shader capability. If this is not a shader,
36 // there is no work to do.
37 if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses))
38 return Status::SuccessWithoutChange;
39
40 std::vector<std::pair<Instruction*, Function*>> variables_to_move;
41 for (auto& inst : context()->types_values()) {
42 if (inst.opcode() != SpvOpVariable) {
43 continue;
44 }
45
46 if (inst.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
47 SpvStorageClassPrivate) {
48 continue;
49 }
50
51 Function* target_function = FindLocalFunction(inst);
52 if (target_function != nullptr) {
53 variables_to_move.push_back({&inst, target_function});
54 }
55 }
56
57 modified = !variables_to_move.empty();
58 for (auto p : variables_to_move) {
59 MoveVariable(p.first, p.second);
60 }
61
62 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
63 }
64
FindLocalFunction(const Instruction & inst) const65 Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
66 bool found_first_use = false;
67 Function* target_function = nullptr;
68 context()->get_def_use_mgr()->ForEachUser(
69 inst.result_id(),
70 [&target_function, &found_first_use, this](Instruction* use) {
71 BasicBlock* current_block = context()->get_instr_block(use);
72 if (current_block == nullptr) {
73 return;
74 }
75
76 if (!IsValidUse(use)) {
77 found_first_use = true;
78 target_function = nullptr;
79 return;
80 }
81 Function* current_function = current_block->GetParent();
82 if (!found_first_use) {
83 found_first_use = true;
84 target_function = current_function;
85 } else if (target_function != current_function) {
86 target_function = nullptr;
87 }
88 });
89 return target_function;
90 } // namespace opt
91
MoveVariable(Instruction * variable,Function * function)92 void PrivateToLocalPass::MoveVariable(Instruction* variable,
93 Function* function) {
94 // The variable needs to be removed from the global section, and placed in the
95 // header of the function. First step remove from the global list.
96 variable->RemoveFromList();
97 std::unique_ptr<Instruction> var(variable); // Take ownership.
98 context()->ForgetUses(variable);
99
100 // Update the storage class of the variable.
101 variable->SetInOperand(kVariableStorageClassInIdx, {SpvStorageClassFunction});
102
103 // Update the type as well.
104 uint32_t new_type_id = GetNewType(variable->type_id());
105 variable->SetResultType(new_type_id);
106
107 // Place the variable at the start of the first basic block.
108 context()->AnalyzeUses(variable);
109 context()->set_instr_block(variable, &*function->begin());
110 function->begin()->begin()->InsertBefore(move(var));
111
112 // Update uses where the type may have changed.
113 UpdateUses(variable->result_id());
114 }
115
GetNewType(uint32_t old_type_id)116 uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
117 auto type_mgr = context()->get_type_mgr();
118 Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id);
119 uint32_t pointee_type_id =
120 old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
121 uint32_t new_type_id =
122 type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction);
123 context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
124 return new_type_id;
125 }
126
IsValidUse(const Instruction * inst) const127 bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const {
128 // The cases in this switch have to match the cases in |UpdateUse|.
129 // If we don't know how to update it, it is not valid.
130 switch (inst->opcode()) {
131 case SpvOpLoad:
132 case SpvOpStore:
133 case SpvOpImageTexelPointer: // Treat like a load
134 return true;
135 case SpvOpAccessChain:
136 return context()->get_def_use_mgr()->WhileEachUser(
137 inst, [this](const Instruction* user) {
138 if (!IsValidUse(user)) return false;
139 return true;
140 });
141 case SpvOpName:
142 return true;
143 default:
144 return spvOpcodeIsDecoration(inst->opcode());
145 }
146 }
147
UpdateUse(Instruction * inst)148 void PrivateToLocalPass::UpdateUse(Instruction* inst) {
149 // The cases in this switch have to match the cases in |IsValidUse|. If we
150 // don't think it is valid, the optimization will not view the variable as a
151 // candidate, and therefore the use will not be updated.
152 switch (inst->opcode()) {
153 case SpvOpLoad:
154 case SpvOpStore:
155 case SpvOpImageTexelPointer: // Treat like a load
156 // The type is fine because it is the type pointed to, and that does not
157 // change.
158 break;
159 case SpvOpAccessChain:
160 context()->ForgetUses(inst);
161 inst->SetResultType(GetNewType(inst->type_id()));
162 context()->AnalyzeUses(inst);
163
164 // Update uses where the type may have changed.
165 UpdateUses(inst->result_id());
166 break;
167 case SpvOpName:
168 break;
169 default:
170 assert(spvOpcodeIsDecoration(inst->opcode()) &&
171 "Do not know how to update the type for this instruction.");
172 break;
173 }
174 }
UpdateUses(uint32_t id)175 void PrivateToLocalPass::UpdateUses(uint32_t id) {
176 std::vector<Instruction*> uses;
177 context()->get_def_use_mgr()->ForEachUser(
178 id, [&uses](Instruction* use) { uses.push_back(use); });
179
180 for (Instruction* use : uses) {
181 UpdateUse(use);
182 }
183 }
184
185 } // namespace opt
186 } // namespace spvtools
187