1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "absl/memory/memory.h"
27 #include "absl/types/any.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/lite/delegates/gpu/common/shape.h"
30 #include "tensorflow/lite/delegates/gpu/common/status.h"
31 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
32 
33 namespace tflite {
34 namespace gpu {
35 
// This is yet another representation of a CNN graph. Its primary purpose is
// to simplify graph manipulation.
38 
// Identifier of a Value; a Value stores its own id in `Value::id`.
using ValueId = std::uint32_t;

// Identifier of a Node; a Node stores its own id in `Node::id`.
using NodeId = std::uint32_t;
42 
43 // Used to emulate quantized behavior.
44 struct QuantizationParams {
45   float min = 0;
46   float max = 0;
47   float scale = 0;
48 };
49 
50 // Connects tensor's producer and operation that depends on this tensor.
51 struct Value {
52   const ValueId id;
53   TensorRef<BHWC> tensor;
54   absl::optional<QuantizationParams> quant_params;
55 };
56 
57 struct Operation {
58   std::string type;
59   absl::any attributes;
60 };
61 
62 struct Node {
63   const NodeId id;
64   Operation operation;
65 };
66 
67 // A DAG that consists of nodes and values. Each value may have a single
68 // producer node and multiple consumer nodes. Therefore, each node may have
69 // multiple input and output values.
70 //
71 // Value that does not have a producer is a graph's input. Value that does not
72 // have a consumer is a graph's output.
73 //
74 // It keeps values and nodes referenced by their index in a vector. Therefore,
75 // nodes and values are never deleted, but rather erased, where corresponding
76 // index remains.
77 //
78 // It is possible to re-use removed indices, but it is not implemented yet.
79 class GraphFloat32 {
80  public:
81   // @return a collection of nodes in this graph.
82   std::vector<Node*> nodes() const;
83 
84   // @return a collection of values in this graph.
85   std::vector<Value*> values() const;
86 
87   // @return graph inputs, that are values without producers.
88   std::vector<Value*> inputs() const;
89 
90   // @return graph outputs, that are values without consumers.
91   std::vector<Value*> outputs() const;
92 
93   // @return values updated in place with a previously defined tensor reference.
94   std::vector<Value*> variable_inputs() const;
95 
96   // @return inputs into the given node. Returns empty vector for deleted node.
97   std::vector<Value*> FindInputs(NodeId id) const;
98 
99   // @return outputs from the given node. Returns empty vector for deleted node.
100   std::vector<Value*> FindOutputs(NodeId id) const;
101 
102   bool IsGraphInput(ValueId id) const;
103 
104   bool IsGraphOutput(ValueId id) const;
105 
106   // @return producer of the given value. Returns nullptr for deleted value.
107   Node* FindProducer(ValueId id) const;
108 
109   // @return consumers of the given value. Returns empty vector for deleted
110   // value.
111   std::vector<Node*> FindConsumers(ValueId id) const;
112 
113   // @return a node or nullptr if node with the given id is not present.
114   Node* GetNode(NodeId id) const;
115 
116   // @return a value or nullptr if value with the given id is not present.
117   Value* GetValue(ValueId id) const;
118 
119   //////////////////////////////////////////////////////////////////////////////
120   // Graph manipulation functions are below
121   //////////////////////////////////////////////////////////////////////////////
122 
123   // @return new node created in this graph
124   // NOTE: nodes should be created in the topological order, e.g. node A that
125   // depends on a value from node B should be created after node B.
126   Node* NewNode();
127 
128   // Insert Node after another in the execution plan.
129   absl::Status InsertNodeAfter(NodeId id, Node** new_node);
130 
131   // @return new value created in this graph
132   Value* NewValue();
133 
134   // Sets a producer for the given value. There could be a single producer
135   // for a value. If a value had another producer, it will reassign producer
136   // appropriately. If a value didn't have a producer, it will be removed
137   // from a graph's input.
138   absl::Status SetProducer(NodeId producer, ValueId value);
139 
140   // Removes a producer for the given value. Value becomes producer-less and
141   // therefore becomes graph's input.
142   absl::Status RemoveProducer(ValueId value);
143 
144   // Sets a consumer for the given value. There could be multiple consumers
145   // for a value.
146   absl::Status AddConsumer(NodeId consumer, ValueId value);
147 
148   // Replace input value for given node.
149   absl::Status ReplaceInput(NodeId node, ValueId old_value, ValueId new_value);
150 
151   // Removes a consumer for the given value. If value does not have any
152   // consumers it becomes graph's output.
153   absl::Status RemoveConsumer(NodeId consumer, ValueId value);
154 
155   // Removes node from this graph. For all input values this node will be
156   // removed from consumers and for all output values a producer will be
157   // removed.
158   absl::Status DeleteNode(NodeId id);
159 
160   // Removes value from this graph. It will be removed from inputs for all
161   // dependent nodes. A node that was a producer of this value will loose its
162   // output.
163   absl::Status DeleteValue(ValueId id);
164 
165   absl::Status MakeExactCopy(GraphFloat32* model) const;
166 
167  private:
168   struct NodeDef {
169     std::vector<Value*> inputs;
170     std::vector<Value*> outputs;
171     std::unique_ptr<Node> node;
172   };
173 
174   struct ValueDef {
175     Node* producer = nullptr;
176     std::vector<Node*> consumers;
177     std::unique_ptr<Value> value;
178   };
179 
180   bool IsInput(NodeId node, ValueId value);
181 
182   template <typename T>
Erase(std::vector<T> * values,T value)183   static void Erase(std::vector<T>* values, T value) {
184     values->erase(std::find(values->begin(), values->end(), value));
185   }
186 
187   // @return non-nullptr NodeDef that has valid Node or an error
188   absl::Status LookupNode(NodeId id, NodeDef** node_def);
189 
190   // @return non-nullptr ValueDef that has valid Value or an error
191   absl::Status LookupValue(ValueId id, ValueDef** value_def);
192 
193   template <typename Pred>
FilterValues(const Pred & predicate)194   std::vector<Value*> FilterValues(const Pred& predicate) const {
195     std::vector<Value*> values;
196     values.reserve(values_.size());
197     for (auto& v : values_) {
198       if (v.value != nullptr && predicate(v)) {
199         values.push_back(v.value.get());
200       }
201     }
202     return values;
203   }
204 
205   template <typename Pred>
FilterNodes(const Pred & predicate)206   std::vector<Node*> FilterNodes(const Pred& predicate) const {
207     std::vector<Node*> nodes;
208     nodes.reserve(nodes_.size());
209     for (const auto id : execution_plan_) {
210       auto& n = nodes_.at(id);
211       if (n.node != nullptr && predicate(n)) {
212         nodes.push_back(n.node.get());
213       }
214     }
215     return nodes;
216   }
217 
218   // There are two approaches possible: wrap entire NodeDef and ValueDef into
219   // unique_ptr and store it in values_ and nodes_ or store it by value.
220   // We store it by value here to make introspection calls cheaper.
221   std::vector<ValueDef> values_;
222 
223   std::map<NodeId, NodeDef> nodes_;
224   // Node Ids in order of execution.
225   std::vector<NodeId> execution_plan_;
226 };
227 
228 // Removes to_remove node that precedes to_keep node only if to_remove has
229 // outputs that are consumed only by to_keep. In such case to_keep inherits all
230 // to_remove inputs.
231 absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
232                                  const Node* to_keep);
233 
234 // Removes to_remove node that follows to_keep node only if to_remove has inputs
235 // that are produced by to_keep. to_keep inherits all to_remove inputs.
236 absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
237                                  const Node* to_keep);
238 
239 // Removes simple_node and its output value from the graph. Node is considered
240 // simple if it has only one input and one output value. Input value is kept.
241 absl::Status RemoveSimpleNodeKeepInput(GraphFloat32* graph,
242                                        const Node* simple_node);
243 
244 // Removes simple_node and its input value from the graph. Node is considered
245 // simple if it has only one input and one output value. Output value is kept.
246 // simple_node should be an exclusive consumer of its input value.
247 absl::Status RemoveSimpleNodeKeepOutput(GraphFloat32* graph,
248                                         const Node* simple_node);
249 
250 absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
251                        Value** output);
252 
253 // Makes a direct connection between from_node and to_node. All input parameters
254 // except output are expected to be initialized before passing to the function.
255 // If from_node already has an output value, which is not yet consumed by
256 // to_node, it may be passed as output parameter.
257 absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
258                              const Node* to_node, Value** output);
259 
260 // @return true if all tensors have same batch value or if model has no values.
261 bool IsBatchMatchesForAllValues(const GraphFloat32& model);
262 
263 }  // namespace gpu
264 }  // namespace tflite
265 
266 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_
267