1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/model.h"
17 
18 #include <stdint.h>
19 
20 #include <algorithm>
21 #include <iterator>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/lite/delegates/gpu/common/shape.h"
31 #include "tensorflow/lite/delegates/gpu/common/status.h"
32 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
33 
34 namespace tflite {
35 namespace gpu {
36 
nodes() const37 std::vector<Node*> GraphFloat32::nodes() const {
38   return FilterNodes([](const NodeDef&) { return true; });
39 }
40 
values() const41 std::vector<Value*> GraphFloat32::values() const {
42   return FilterValues([](const ValueDef&) { return true; });
43 }
44 
inputs() const45 std::vector<Value*> GraphFloat32::inputs() const {
46   return FilterValues([](const ValueDef& v) { return v.producer == nullptr; });
47 }
48 
variable_inputs() const49 std::vector<Value*> GraphFloat32::variable_inputs() const {
50   return FilterValues(
51       [](const ValueDef& v) { return v.value->tensor.is_variable_input; });
52 }
53 
outputs() const54 std::vector<Value*> GraphFloat32::outputs() const {
55   return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
56 }
57 
FindInputs(NodeId id) const58 std::vector<Value*> GraphFloat32::FindInputs(NodeId id) const {
59   if (id >= nodes_.size()) {
60     return {};
61   }
62   return nodes_.at(id).inputs;
63 }
64 
FindOutputs(NodeId id) const65 std::vector<Value*> GraphFloat32::FindOutputs(NodeId id) const {
66   if (id >= nodes_.size()) {
67     return {};
68   }
69   return nodes_.at(id).outputs;
70 }
71 
IsGraphInput(ValueId id) const72 bool GraphFloat32::IsGraphInput(ValueId id) const {
73   if (id >= values_.size()) {
74     return false;
75   }
76   return values_[id].producer == nullptr;
77 }
78 
IsGraphOutput(ValueId id) const79 bool GraphFloat32::IsGraphOutput(ValueId id) const {
80   if (id >= values_.size()) {
81     return false;
82   }
83   return values_[id].consumers.empty();
84 }
85 
FindProducer(ValueId id) const86 Node* GraphFloat32::FindProducer(ValueId id) const {
87   if (id >= values_.size()) {
88     return nullptr;
89   }
90   return values_[id].producer;
91 }
92 
FindConsumers(ValueId id) const93 std::vector<Node*> GraphFloat32::FindConsumers(ValueId id) const {
94   if (id >= values_.size()) {
95     return {};
96   }
97   return values_[id].consumers;
98 }
99 
GetNode(NodeId id) const100 Node* GraphFloat32::GetNode(NodeId id) const {
101   if (id >= nodes_.size()) {
102     return {};
103   }
104   return nodes_.at(id).node.get();
105 }
106 
GetValue(ValueId id) const107 Value* GraphFloat32::GetValue(ValueId id) const {
108   if (id >= values_.size()) {
109     return nullptr;
110   }
111   return values_[id].value.get();
112 }
113 
NewNode()114 Node* GraphFloat32::NewNode() {
115   const NodeId new_id = nodes_.size();
116   NodeDef def;
117   def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
118   Node* node = def.node.get();
119   nodes_[new_id] = std::move(def);
120   execution_plan_.push_back(new_id);
121   return node;
122 }
123 
InsertNodeAfter(NodeId id,Node ** new_node)124 absl::Status GraphFloat32::InsertNodeAfter(NodeId id, Node** new_node) {
125   if (id >= nodes_.size()) {
126     return absl::OutOfRangeError("NodeId is out of range");
127   }
128   int idx = 0;
129   while (idx < execution_plan_.size()) {
130     if (execution_plan_[idx] == id) break;
131     ++idx;
132   }
133   if (idx == execution_plan_.size()) {
134     return absl::OutOfRangeError("NodeId not in execution plan");
135   }
136 
137   const NodeId new_id = nodes_.size();
138   NodeDef def;
139   def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
140   *new_node = def.node.get();
141   nodes_[new_id] = std::move(def);
142   execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id);
143   return absl::OkStatus();
144 }
145 
NewValue()146 Value* GraphFloat32::NewValue() {
147   ValueDef def;
148   def.value =
149       absl::make_unique<Value>(Value{static_cast<ValueId>(values_.size()), {}});
150   Value* value = def.value.get();
151   values_.push_back(std::move(def));
152   return value;
153 }
154 
SetProducer(NodeId producer,ValueId value)155 absl::Status GraphFloat32::SetProducer(NodeId producer, ValueId value) {
156   ValueDef* v;
157   RETURN_IF_ERROR(LookupValue(value, &v));
158   Value* value_ptr = v->value.get();
159   NodeDef* n;
160   RETURN_IF_ERROR(LookupNode(producer, &n));
161   Node* node_ptr = n->node.get();
162 
163   // check if this value has the same producer already
164   if (node_ptr == v->producer) {
165     return absl::AlreadyExistsError(absl::StrCat(
166         "Node ", producer, " is already a producer of the value ", value));
167   }
168 
169   // Check if the node is a consumer of this value.
170   if (IsInput(producer, value)) {
171     return absl::InvalidArgumentError("Node is a consumer of the value");
172   }
173 
174   if (v->producer != nullptr) {
175     // value is no longer produced by it's previous producer.
176     Erase(&nodes_[v->producer->id].outputs, value_ptr);
177   }
178   v->producer = node_ptr;
179   n->outputs.push_back(value_ptr);
180   return absl::OkStatus();
181 }
182 
RemoveProducer(ValueId value)183 absl::Status GraphFloat32::RemoveProducer(ValueId value) {
184   ValueDef* v;
185   RETURN_IF_ERROR(LookupValue(value, &v));
186   Value* value_ptr = v->value.get();
187   if (v->producer == nullptr) {
188     return absl::InvalidArgumentError("Value does not have a producer");
189   }
190   Erase(&nodes_[v->producer->id].outputs, value_ptr);
191   v->producer = nullptr;
192   return absl::OkStatus();
193 }
194 
AddConsumer(NodeId consumer,ValueId value)195 absl::Status GraphFloat32::AddConsumer(NodeId consumer, ValueId value) {
196   ValueDef* v;
197   RETURN_IF_ERROR(LookupValue(value, &v));
198   Value* value_ptr = v->value.get();
199   NodeDef* n;
200   RETURN_IF_ERROR(LookupNode(consumer, &n));
201   Node* node_ptr = n->node.get();
202 
203   // check if this value has the same producer already
204   if (node_ptr == v->producer) {
205     return absl::InvalidArgumentError("Node is a producer of the value");
206   }
207 
208   // check if this value has the same consumer already
209   if (IsInput(consumer, value)) {
210     return absl::AlreadyExistsError(absl::StrCat(
211         "Node ", consumer, " is already a consumer of the value ", value));
212   }
213 
214   n->inputs.push_back(value_ptr);
215   v->consumers.push_back(node_ptr);
216   return absl::OkStatus();
217 }
218 
219 // Replace input value for given node.
ReplaceInput(NodeId node,ValueId old_value,ValueId new_value)220 absl::Status GraphFloat32::ReplaceInput(NodeId node, ValueId old_value,
221                                         ValueId new_value) {
222   ValueDef* v_old;
223   RETURN_IF_ERROR(LookupValue(old_value, &v_old));
224   Value* value_old_ptr = v_old->value.get();
225   ValueDef* v_new;
226   RETURN_IF_ERROR(LookupValue(new_value, &v_new));
227   Value* value_new_ptr = v_new->value.get();
228   NodeDef* n;
229   RETURN_IF_ERROR(LookupNode(node, &n));
230   Node* node_ptr = n->node.get();
231 
232   // Check if the node is a consumer of old_value.
233   if (!IsInput(node, old_value)) {
234     return absl::InvalidArgumentError("old_value must be input of node.");
235   }
236 
237   // Check if the node is not a consumer of new_value.
238   if (IsInput(node, new_value)) {
239     return absl::InvalidArgumentError("new_value can not be input of node.");
240   }
241 
242   // Check if this value has the same producer already
243   if (node_ptr == v_new->producer) {
244     return absl::InvalidArgumentError("new_value can not be output of node.");
245   }
246 
247   for (int i = 0; i < n->inputs.size(); ++i) {
248     if (n->inputs[i] == value_old_ptr) {
249       n->inputs[i] = value_new_ptr;
250       break;
251     }
252   }
253   v_new->consumers.push_back(node_ptr);
254   Erase(&v_old->consumers, node_ptr);
255   return absl::OkStatus();
256 }
257 
RemoveConsumer(NodeId consumer,ValueId value)258 absl::Status GraphFloat32::RemoveConsumer(NodeId consumer, ValueId value) {
259   ValueDef* v;
260   RETURN_IF_ERROR(LookupValue(value, &v));
261   Value* value_ptr = v->value.get();
262   NodeDef* n;
263   RETURN_IF_ERROR(LookupNode(consumer, &n));
264   Node* node_ptr = n->node.get();
265   if (!IsInput(consumer, value)) {
266     return absl::InvalidArgumentError("Node is not a consumer of the value");
267   }
268   Erase(&n->inputs, value_ptr);
269   Erase(&v->consumers, node_ptr);
270   return absl::OkStatus();
271 }
272 
DeleteNode(NodeId id)273 absl::Status GraphFloat32::DeleteNode(NodeId id) {
274   NodeDef* n;
275   RETURN_IF_ERROR(LookupNode(id, &n));
276   Node* node_ptr = n->node.get();
277   for (auto value : n->inputs) {
278     Erase(&values_[value->id].consumers, node_ptr);
279   }
280   for (auto value : n->outputs) {
281     values_[value->id].producer = nullptr;
282   }
283   n->inputs.clear();
284   n->outputs.clear();
285   n->node.reset();
286   return absl::OkStatus();
287 }
288 
DeleteValue(ValueId id)289 absl::Status GraphFloat32::DeleteValue(ValueId id) {
290   ValueDef* v;
291   RETURN_IF_ERROR(LookupValue(id, &v));
292   Value* value_ptr = v->value.get();
293   if (v->producer != nullptr) {
294     Erase(&nodes_[v->producer->id].outputs, value_ptr);
295   }
296   if (!v->consumers.empty()) {
297     for (auto node : v->consumers) {
298       Erase(&nodes_[node->id].inputs, value_ptr);
299     }
300   }
301   v->producer = nullptr;
302   v->consumers.clear();
303   v->value.reset();
304   return absl::OkStatus();
305 }
306 
MakeExactCopy(GraphFloat32 * model) const307 absl::Status GraphFloat32::MakeExactCopy(GraphFloat32* model) const {
308   model->nodes_.clear();
309   model->execution_plan_.clear();
310   model->values_.clear();
311   for (auto& value_def : values_) {
312     model->values_.push_back({});
313     if (value_def.value) {
314       model->values_.back().value = absl::make_unique<Value>(*value_def.value);
315     }
316   }
317   // Add all nodes first.
318   for (auto node_id : execution_plan_) {
319     model->execution_plan_.push_back(node_id);
320     model->nodes_[node_id] = {};
321     auto& node_def = nodes_.at(node_id);
322     if (node_def.node) {
323       model->nodes_[node_id].node = absl::make_unique<Node>(*node_def.node);
324     }
325   }
326   // Wire up dependencies between nodes.
327   for (auto node_id : execution_plan_) {
328     auto& node_def = nodes_.at(node_id);
329     if (node_def.node) {
330       for (auto output : node_def.outputs) {
331         RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
332       }
333       for (auto input : node_def.inputs) {
334         RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
335       }
336     }
337   }
338   return absl::OkStatus();
339 }
340 
IsInput(NodeId node,ValueId value)341 bool GraphFloat32::IsInput(NodeId node, ValueId value) {
342   if (node >= nodes_.size() || value >= values_.size()) {
343     return false;
344   }
345   const NodeDef& n = nodes_[node];
346   const ValueDef& v = values_[value];
347   if (!n.node || !v.value) {
348     return false;
349   }
350   return std::find(n.inputs.begin(), n.inputs.end(), v.value.get()) !=
351          n.inputs.end();
352 }
353 
LookupNode(NodeId id,NodeDef ** node_def)354 absl::Status GraphFloat32::LookupNode(NodeId id, NodeDef** node_def) {
355   if (id >= nodes_.size()) {
356     return absl::OutOfRangeError("NodeId is out of range");
357   }
358   auto& n = nodes_[id];
359   if (!n.node) {
360     return absl::OutOfRangeError("Node is already deleted");
361   }
362   *node_def = &n;
363   return absl::OkStatus();
364 }
365 
LookupValue(ValueId id,ValueDef ** value_def)366 absl::Status GraphFloat32::LookupValue(ValueId id, ValueDef** value_def) {
367   if (id >= values_.size()) {
368     return absl::OutOfRangeError("ValueId is out of range");
369   }
370   auto& v = values_[id];
371   if (!v.value) {
372     return absl::OutOfRangeError("Value is already deleted");
373   }
374   *value_def = &v;
375   return absl::OkStatus();
376 }
377 
RemovePrecedingNode(GraphFloat32 * graph,const Node * to_remove,const Node * to_keep)378 absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
379                                  const Node* to_keep) {
380   // Make sure all outputs from to_remove are consumed by to_keep.
381   for (auto output : graph->FindOutputs(to_remove->id)) {
382     auto consumers = graph->FindConsumers(output->id);
383     if (consumers.size() > 1 ||
384         (consumers.size() == 1 && consumers[0] != to_keep)) {
385       return absl::InvalidArgumentError(
386           "Output from to_remove node has other consumers");
387     }
388   }
389 
390   // Update all references
391   for (auto input : graph->FindInputs(to_remove->id)) {
392     RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
393   }
394   for (auto output : graph->FindOutputs(to_remove->id)) {
395     RETURN_IF_ERROR(graph->DeleteValue(output->id));
396   }
397   return graph->DeleteNode(to_remove->id);
398 }
399 
RemoveFollowingNode(GraphFloat32 * graph,const Node * to_remove,const Node * to_keep)400 absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
401                                  const Node* to_keep) {
402   // Make sure all inputs to to_remove are produced by to_keep.
403   for (auto input : graph->FindInputs(to_remove->id)) {
404     Node* producer = graph->FindProducer(input->id);
405     if (producer->id != to_keep->id) {
406       return absl::InvalidArgumentError("To_remove node has other inputs");
407     }
408   }
409 
410   for (auto input : graph->FindInputs(to_remove->id)) {
411     RETURN_IF_ERROR(graph->DeleteValue(input->id));
412   }
413   for (auto output : graph->FindOutputs(to_remove->id)) {
414     RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
415   }
416   return graph->DeleteNode(to_remove->id);
417 }
418 
RemoveSimpleNodeKeepInput(GraphFloat32 * graph,const Node * simple_node)419 absl::Status RemoveSimpleNodeKeepInput(GraphFloat32* graph,
420                                        const Node* simple_node) {
421   const auto inputs = graph->FindInputs(simple_node->id);
422   const auto outputs = graph->FindOutputs(simple_node->id);
423   if (inputs.size() != 1 || outputs.size() != 1) {
424     return absl::FailedPreconditionError(
425         "simple_node node must have 1 input and 1 output");
426   }
427   const auto input_id = inputs[0]->id;
428   const auto output_id = outputs[0]->id;
429   const Node* producer = graph->FindProducer(input_id);
430   const auto consumers = graph->FindConsumers(output_id);
431   RETURN_IF_ERROR(graph->DeleteNode(simple_node->id));
432   for (auto& consumer : consumers) {
433     RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id));
434   }
435   RETURN_IF_ERROR(graph->DeleteValue(output_id));
436   if (!producer && consumers.empty()) {
437     RETURN_IF_ERROR(graph->DeleteValue(input_id));
438   }
439   return absl::OkStatus();
440 }
441 
RemoveSimpleNodeKeepOutput(GraphFloat32 * graph,const Node * simple_node)442 absl::Status RemoveSimpleNodeKeepOutput(GraphFloat32* graph,
443                                         const Node* simple_node) {
444   const auto inputs = graph->FindInputs(simple_node->id);
445   const auto outputs = graph->FindOutputs(simple_node->id);
446   if (inputs.size() != 1 || outputs.size() != 1) {
447     return absl::FailedPreconditionError(
448         "simple_node must have 1 input and 1 output");
449   }
450   const auto input_id = inputs[0]->id;
451   const auto output_id = outputs[0]->id;
452   const Node* producer = graph->FindProducer(input_id);
453   const auto input_consumers = graph->FindConsumers(input_id);
454   if (input_consumers.size() != 1) {
455     return absl::FailedPreconditionError(
456         "simple_node should be the only consumer on the node.");
457   }
458 
459   RETURN_IF_ERROR(graph->DeleteNode(simple_node->id));
460   if (producer) {
461     RETURN_IF_ERROR(graph->RemoveProducer(input_id));
462     RETURN_IF_ERROR(graph->SetProducer(producer->id, output_id));
463   }
464 
465   RETURN_IF_ERROR(graph->DeleteValue(input_id));
466 
467   const auto output_consumers = graph->FindConsumers(output_id);
468   if (!producer && output_consumers.empty()) {
469     RETURN_IF_ERROR(graph->DeleteValue(output_id));
470   }
471   return absl::OkStatus();
472 }
473 
AddOutput(GraphFloat32 * graph,const Node * from_node,Value ** output)474 absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
475                        Value** output) {
476   auto link = graph->NewValue();
477   RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
478   *output = link;
479   return absl::OkStatus();
480 }
481 
ConnectTwoNodes(GraphFloat32 * graph,const Node * from_node,const Node * to_node,Value ** output)482 absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
483                              const Node* to_node, Value** output) {
484   const Node* output_producer =
485       *output ? graph->FindProducer((*output)->id) : nullptr;
486   // Output is already initialized, but producer is not from_node.
487   if (*output && output_producer && output_producer->id != from_node->id) {
488     return absl::InvalidArgumentError("Wrong output is passed.");
489   }
490   // Output is already initialized, and producer is from_node.
491   if (*output) {
492     RETURN_IF_ERROR(graph->AddConsumer(to_node->id, (*output)->id));
493   } else {
494     // Output is not initialized.
495     Value* link;
496     RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
497     RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
498     *output = link;
499   }
500   return absl::OkStatus();
501 }
502 
IsBatchMatchesForAllValues(const GraphFloat32 & model)503 bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
504   if (model.values().empty()) return true;
505   const int32_t b = model.values()[0]->tensor.shape.b;
506   for (auto value : model.values()) {
507     if (value->tensor.shape.b != b) {
508       return false;
509     }
510   }
511   return true;
512 }
513 
514 }  // namespace gpu
515 }  // namespace tflite
516