1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ModelUtils"
18 
19 #include "ModelUtils.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <algorithm>
24 #include <numeric>
25 #include <unordered_set>
26 #include <utility>
27 #include <vector>
28 
29 #include "nnapi/TypeUtils.h"
30 #include "nnapi/Types.h"
31 #include "nnapi/Validation.h"
32 
33 namespace android::nn {
34 namespace {
35 
// Assign consecutive integers, starting at 0, to the `true` entries of `includes`, preserving their
// relative order. Entries at `false` positions are placeholders ("X" below) and must not be relied
// upon by callers. E.g.:
//   includes = {false, true, true, false, true}
//   returned = {    X,    0,    1,     X,    2}
// Postcondition: returned.size() == includes.size()
std::vector<uint32_t> getMapping(const std::vector<bool>& includes) {
    std::vector<uint32_t> mapping(includes.size());
    uint32_t nextIndex = 0;
    for (size_t i = 0; i < includes.size(); ++i) {
        // Each slot receives the number of kept entries seen before it, which is exactly the new
        // index of a kept entry.
        mapping[i] = nextIndex;
        if (includes[i]) {
            ++nextIndex;
        }
    }
    return mapping;
}
46 
47 // Remap indexes in `indexes` by the mapping `mapping`.
48 // Precondition: indexes != nullptr
remapIndexes(std::vector<uint32_t> * indexes,const std::vector<uint32_t> & mapping)49 void remapIndexes(std::vector<uint32_t>* indexes, const std::vector<uint32_t>& mapping) {
50     CHECK(indexes != nullptr);
51     for (uint32_t& index : (*indexes)) {
52         index = mapping.at(index);
53     }
54 }
55 
// Compact `*elements` in place so that only the entries whose flag in `elementsToKeep` is `true`
// remain, in their original relative order. Kept entries are moved (not copied) toward the front,
// then the vector is shrunk to the number of kept entries.
// Precondition: elements != nullptr
// Precondition: elements->size() == elementsToKeep.size()
template <typename Type>
void keepSelectedElements(std::vector<Type>* elements, const std::vector<bool>& elementsToKeep) {
    CHECK(elements != nullptr);
    CHECK_EQ(elements->size(), elementsToKeep.size());

    auto& items = *elements;
    size_t writeIndex = 0;
    for (size_t readIndex = 0; readIndex < elementsToKeep.size(); ++readIndex) {
        if (!elementsToKeep[readIndex]) {
            continue;
        }
        // Skip the self-move while no element has been dropped yet.
        if (writeIndex != readIndex) {
            items[writeIndex] = std::move(items[readIndex]);
        }
        ++writeIndex;
    }
    items.resize(writeIndex);
}
75 
76 // Find which operands in model.main.operands are read or written by model.main.operations and
77 // model.main.inputIndexes.
78 // Postcondition: returned.size() == model.main.operands.size()
identifyUsedOperands(const Model & model)79 std::vector<bool> identifyUsedOperands(const Model& model) {
80     std::vector<bool> used(model.main.operands.size(), false);
81     auto markUsed = [&used](const std::vector<uint32_t>& indexes) {
82         std::for_each(indexes.begin(), indexes.end(),
83                       [&used](uint32_t index) { used.at(index) = true; });
84     };
85     for (const auto& operation : model.main.operations) {
86         markUsed(operation.inputs);
87         markUsed(operation.outputs);
88     }
89     markUsed(model.main.inputIndexes);
90     CHECK_EQ(used.size(), model.main.operands.size());
91     return used;
92 }
93 
94 // Forward declaration.
95 void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
96                            std::vector<bool>* used);
97 
98 // Helper function to find which subgraphs are reachable by `operands`.
99 // Precondition: used != nullptr
100 // Precondition: subgraphs.size() == used->size()
identifyUsedSubgraphs(const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs,std::vector<bool> * used)101 void identifyUsedSubgraphs(const std::vector<Operand>& operands,
102                            const std::vector<Model::Subgraph>& subgraphs, std::vector<bool>* used) {
103     for (const auto& operand : operands) {
104         if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
105             identifyUsedSubgraphs(operand.location.offset, subgraphs, used);
106         }
107     }
108 }
109 
110 // Helper function to find which subgraphs are reachable by the subgraph at the `current` index, and
111 // store when a subgraph is used in `used`. `used` also acts as a cache, ensuring each subgraph is
112 // processed at most once.
113 // Precondition: used != nullptr
114 // Precondition: subgraphs.size() == used->size()
115 // Precondition: current < subgraphs.size()
identifyUsedSubgraphs(uint32_t current,const std::vector<Model::Subgraph> & subgraphs,std::vector<bool> * used)116 void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
117                            std::vector<bool>* used) {
118     CHECK(used != nullptr);
119     CHECK_EQ(subgraphs.size(), used->size());
120     CHECK_LT(current, subgraphs.size());
121 
122     // If a subgraph was already marked as used, quickly return to avoid redundant processing.
123     if ((*used)[current]) {
124         return;
125     }
126 
127     // Mark the current subgraph as used, then process any subgraph it references recursively.
128     (*used)[current] = true;
129     identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used);
130 }
131 
132 // Find which subgraphs are reachable by the main operands of `model`.
133 // Postcondition: returned.size() == model.referenced.size()
identifyUsedSubgraphs(const Model & model)134 std::vector<bool> identifyUsedSubgraphs(const Model& model) {
135     std::vector<bool> used(model.referenced.size(), false);
136     identifyUsedSubgraphs(model.main.operands, model.referenced, &used);
137     CHECK_EQ(used.size(), model.referenced.size());
138     return used;
139 }
140 
141 // Helper function to find which pools are used by `subgraph`, and store when a pool is used in
142 // `used`.
143 // Precondition: used != nullptr
identifyUsedPools(const Model::Subgraph & subgraph,std::vector<bool> * used)144 void identifyUsedPools(const Model::Subgraph& subgraph, std::vector<bool>* used) {
145     CHECK(used != nullptr);
146     for (const auto& operand : subgraph.operands) {
147         if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
148             used->at(operand.location.poolIndex) = true;
149         }
150     }
151 }
152 
153 // Find which pools are used by `model`.
154 // Postcondition: returned.size() == model.pools.size()
identifyUsedPools(const Model & model)155 std::vector<bool> identifyUsedPools(const Model& model) {
156     std::vector<bool> used(model.pools.size(), false);
157     identifyUsedPools(model.main, &used);
158     for (const auto& subgraph : model.referenced) {
159         identifyUsedPools(subgraph, &used);
160     }
161     CHECK_EQ(used.size(), model.pools.size());
162     return used;
163 }
164 
165 // Fix the DataLocation in `operand` by either remapping an index or by copying constant data.
166 // Precondition: operand != nullptr
167 // Precondition: newOperandValues != nullptr
fixOperandDataLocation(Operand * operand,Model::OperandValues * newOperandValues,const Model::OperandValues & oldOperandValues,const std::vector<uint32_t> & remappedPoolIndex,const std::vector<uint32_t> & remappedSubgraphIndex)168 void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues,
169                             const Model::OperandValues& oldOperandValues,
170                             const std::vector<uint32_t>& remappedPoolIndex,
171                             const std::vector<uint32_t>& remappedSubgraphIndex) {
172     CHECK(operand != nullptr);
173     CHECK(newOperandValues != nullptr);
174 
175     switch (operand->lifetime) {
176         case Operand::LifeTime::CONSTANT_COPY: {
177             const uint8_t* data = oldOperandValues.data() + operand->location.offset;
178             const uint32_t length = operand->location.length;
179             operand->location = newOperandValues->append(data, length);
180             break;
181         }
182         case Operand::LifeTime::CONSTANT_REFERENCE:
183             operand->location.poolIndex = remappedPoolIndex.at(operand->location.poolIndex);
184             break;
185         case Operand::LifeTime::SUBGRAPH: {
186             uint32_t& subgraphIndex = operand->location.offset;
187             subgraphIndex = remappedSubgraphIndex.at(subgraphIndex);
188             break;
189         }
190         case Operand::LifeTime::TEMPORARY_VARIABLE:
191         case Operand::LifeTime::SUBGRAPH_INPUT:
192         case Operand::LifeTime::SUBGRAPH_OUTPUT:
193         case Operand::LifeTime::NO_VALUE:
194         case Operand::LifeTime::POINTER:
195             break;
196     }
197 }
198 
199 // Fix all DataLocations in `operands` by either remapping an index or by copying constant data.
200 // Precondition: operands != nullptr
201 // Precondition: newOperandValues != nullptr
fixOperandDataLocations(std::vector<Operand> * operands,Model::OperandValues * newOperandValues,const Model::OperandValues & oldOperandValues,const std::vector<uint32_t> & remappedPoolIndex,const std::vector<uint32_t> & remappedSubgraphIndex)202 void fixOperandDataLocations(std::vector<Operand>* operands, Model::OperandValues* newOperandValues,
203                              const Model::OperandValues& oldOperandValues,
204                              const std::vector<uint32_t>& remappedPoolIndex,
205                              const std::vector<uint32_t>& remappedSubgraphIndex) {
206     for (Operand& operand : (*operands)) {
207         fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex,
208                                remappedSubgraphIndex);
209     }
210 }
211 
212 // Fix all operands' DataLocations in `model` by either remapping an index or by copying constant
213 // data.
214 // Precondition: model != nullptr
fixOperandDataLocations(Model * model,const std::vector<uint32_t> & remappedPoolIndex,const std::vector<uint32_t> & remappedSubgraphIndex)215 void fixOperandDataLocations(Model* model, const std::vector<uint32_t>& remappedPoolIndex,
216                              const std::vector<uint32_t>& remappedSubgraphIndex) {
217     const auto operandValues = std::exchange(model->operandValues, Model::OperandValues{});
218     fixOperandDataLocations(&model->main.operands, &model->operandValues, operandValues,
219                             remappedPoolIndex, remappedSubgraphIndex);
220     for (auto& subgraph : model->referenced) {
221         fixOperandDataLocations(&subgraph.operands, &model->operandValues, operandValues,
222                                 remappedPoolIndex, remappedSubgraphIndex);
223     }
224 }
225 
226 // Find which extensions are used in `model`.
227 // Postcondition: returned.size() == model.extensionNameToPrefix.size()
identifyUsedExtensions(const Model & model)228 std::vector<bool> identifyUsedExtensions(const Model& model) {
229     std::unordered_set<uint16_t> prefixes;
230     const auto collectPrefix = [&prefixes](const auto& operandOrOperation) {
231         const auto prefix = getExtensionPrefix(static_cast<uint32_t>(operandOrOperation.type));
232         constexpr uint16_t kStandardPrefix = 0u;
233         if (prefix != kStandardPrefix) {
234             prefixes.insert(prefix);
235         }
236     };
237     const auto collectPrefixes = [collectPrefix](const Model::Subgraph& subgraph) {
238         std::for_each(subgraph.operands.begin(), subgraph.operands.end(), collectPrefix);
239         std::for_each(subgraph.operations.begin(), subgraph.operations.end(), collectPrefix);
240     };
241 
242     collectPrefixes(model.main);
243     for (const auto& subgraph : model.referenced) {
244         collectPrefixes(subgraph);
245     }
246 
247     std::vector<bool> used;
248     used.reserve(model.extensionNameToPrefix.size());
249     for (const auto& extension : model.extensionNameToPrefix) {
250         used.push_back(prefixes.count(extension.prefix) > 0);
251     }
252     CHECK_EQ(used.size(), model.extensionNameToPrefix.size());
253     return used;
254 }
255 
256 }  // anonymous namespace
257 
removeDeadOperands(Model * model)258 void removeDeadOperands(Model* model) {
259     CHECK(model != nullptr);
260 
261     // Keep only the operands which are used.
262     const auto operandsUsed = identifyUsedOperands(*model);
263     keepSelectedElements(&model->main.operands, operandsUsed);
264 
265     // Fix operand indexes.
266     const auto mappedOperandIndices = getMapping(operandsUsed);
267     for (auto& operation : model->main.operations) {
268         remapIndexes(&operation.inputs, mappedOperandIndices);
269         remapIndexes(&operation.outputs, mappedOperandIndices);
270     }
271     remapIndexes(&model->main.inputIndexes, mappedOperandIndices);
272     remapIndexes(&model->main.outputIndexes, mappedOperandIndices);
273 
274     // Keep only the subgraphs which are used.
275     const auto subgraphsUsed = identifyUsedSubgraphs(*model);
276     keepSelectedElements(&model->referenced, subgraphsUsed);
277 
278     // Keep only the pools which are used.
279     const auto poolsUsed = identifyUsedPools(*model);
280     keepSelectedElements(&model->pools, poolsUsed);
281 
282     // Fix operand locations.
283     const auto mappedPoolIndices = getMapping(poolsUsed);
284     const auto mappedSubgraphIndices = getMapping(subgraphsUsed);
285     fixOperandDataLocations(model, mappedPoolIndices, mappedSubgraphIndices);
286 
287     // Keep only the extensionNameToPrefixes which are used.
288     const auto extensionsUsed = identifyUsedExtensions(*model);
289     keepSelectedElements(&model->extensionNameToPrefix, extensionsUsed);
290 }
291 
292 }  // namespace android::nn
293