1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_ 18 19 #include <algorithm> 20 #include <cstdint> 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <vector> 25 26 #include "absl/memory/memory.h" 27 #include "absl/types/any.h" 28 #include "absl/types/optional.h" 29 #include "tensorflow/lite/delegates/gpu/common/shape.h" 30 #include "tensorflow/lite/delegates/gpu/common/status.h" 31 #include "tensorflow/lite/delegates/gpu/common/tensor.h" 32 33 namespace tflite { 34 namespace gpu { 35 36 // There is yet another representation of CNN graph. The primary purpose of this 37 // representation is to simplify graph manipulation. 38 39 using ValueId = uint32_t; 40 41 using NodeId = uint32_t; 42 43 // Used to emulate quantized behavior. 44 struct QuantizationParams { 45 float min = 0; 46 float max = 0; 47 float scale = 0; 48 }; 49 50 // Connects tensor's producer and operation that depends on this tensor. 51 struct Value { 52 const ValueId id; 53 TensorRef<BHWC> tensor; 54 absl::optional<QuantizationParams> quant_params; 55 }; 56 57 struct Operation { 58 std::string type; 59 absl::any attributes; 60 }; 61 62 struct Node { 63 const NodeId id; 64 Operation operation; 65 }; 66 67 // A DAG that consists of nodes and values. Each value may have a single 68 // producer node and multiple consumer nodes. Therefore, each node may have 69 // multiple input and output values. 70 // 71 // Value that does not have a producer is a graph's input. Value that does not 72 // have a consumer is a graph's output. 73 // 74 // It keeps values and nodes referenced by their index in a vector. Therefore, 75 // nodes and values are never deleted, but rather erased, where corresponding 76 // index remains. 77 // 78 // It is possible to re-use removed indices, but it is not implemented yet. 79 class GraphFloat32 { 80 public: 81 // @return a collection of nodes in this graph. 82 std::vector<Node*> nodes() const; 83 84 // @return a collection of values in this graph. 85 std::vector<Value*> values() const; 86 87 // @return graph inputs, that are values without producers. 88 std::vector<Value*> inputs() const; 89 90 // @return graph outputs, that are values without consumers. 91 std::vector<Value*> outputs() const; 92 93 // @return values updated in place with a previously defined tensor reference. 94 std::vector<Value*> variable_inputs() const; 95 96 // @return inputs into the given node. Returns empty vector for deleted node. 97 std::vector<Value*> FindInputs(NodeId id) const; 98 99 // @return outputs from the given node. Returns empty vector for deleted node. 100 std::vector<Value*> FindOutputs(NodeId id) const; 101 102 bool IsGraphInput(ValueId id) const; 103 104 bool IsGraphOutput(ValueId id) const; 105 106 // @return producer of the given value. Returns nullptr for deleted value. 107 Node* FindProducer(ValueId id) const; 108 109 // @return consumers of the given value. Returns empty vector for deleted 110 // value. 111 std::vector<Node*> FindConsumers(ValueId id) const; 112 113 // @return a node or nullptr if node with the given id is not present. 114 Node* GetNode(NodeId id) const; 115 116 // @return a value or nullptr if value with the given id is not present. 117 Value* GetValue(ValueId id) const; 118 119 ////////////////////////////////////////////////////////////////////////////// 120 // Graph manipulation functions are below 121 ////////////////////////////////////////////////////////////////////////////// 122 123 // @return new node created in this graph 124 // NOTE: nodes should be created in the topological order, e.g. node A that 125 // depends on a value from node B should be created after node B. 126 Node* NewNode(); 127 128 // Insert Node after another in the execution plan. 129 absl::Status InsertNodeAfter(NodeId id, Node** new_node); 130 131 // @return new value created in this graph 132 Value* NewValue(); 133 134 // Sets a producer for the given value. There could be a single producer 135 // for a value. If a value had another producer, it will reassign producer 136 // appropriately. If a value didn't have a producer, it will be removed 137 // from a graph's input. 138 absl::Status SetProducer(NodeId producer, ValueId value); 139 140 // Removes a producer for the given value. Value becomes producer-less and 141 // therefore becomes graph's input. 142 absl::Status RemoveProducer(ValueId value); 143 144 // Sets a consumer for the given value. There could be multiple consumers 145 // for a value. 146 absl::Status AddConsumer(NodeId consumer, ValueId value); 147 148 // Replace input value for given node. 149 absl::Status ReplaceInput(NodeId node, ValueId old_value, ValueId new_value); 150 151 // Removes a consumer for the given value. If value does not have any 152 // consumers it becomes graph's output. 153 absl::Status RemoveConsumer(NodeId consumer, ValueId value); 154 155 // Removes node from this graph. For all input values this node will be 156 // removed from consumers and for all output values a producer will be 157 // removed. 158 absl::Status DeleteNode(NodeId id); 159 160 // Removes value from this graph. It will be removed from inputs for all 161 // dependent nodes. A node that was a producer of this value will loose its 162 // output. 163 absl::Status DeleteValue(ValueId id); 164 165 absl::Status MakeExactCopy(GraphFloat32* model) const; 166 167 private: 168 struct NodeDef { 169 std::vector<Value*> inputs; 170 std::vector<Value*> outputs; 171 std::unique_ptr<Node> node; 172 }; 173 174 struct ValueDef { 175 Node* producer = nullptr; 176 std::vector<Node*> consumers; 177 std::unique_ptr<Value> value; 178 }; 179 180 bool IsInput(NodeId node, ValueId value); 181 182 template <typename T> Erase(std::vector<T> * values,T value)183 static void Erase(std::vector<T>* values, T value) { 184 values->erase(std::find(values->begin(), values->end(), value)); 185 } 186 187 // @return non-nullptr NodeDef that has valid Node or an error 188 absl::Status LookupNode(NodeId id, NodeDef** node_def); 189 190 // @return non-nullptr ValueDef that has valid Value or an error 191 absl::Status LookupValue(ValueId id, ValueDef** value_def); 192 193 template <typename Pred> FilterValues(const Pred & predicate)194 std::vector<Value*> FilterValues(const Pred& predicate) const { 195 std::vector<Value*> values; 196 values.reserve(values_.size()); 197 for (auto& v : values_) { 198 if (v.value != nullptr && predicate(v)) { 199 values.push_back(v.value.get()); 200 } 201 } 202 return values; 203 } 204 205 template <typename Pred> FilterNodes(const Pred & predicate)206 std::vector<Node*> FilterNodes(const Pred& predicate) const { 207 std::vector<Node*> nodes; 208 nodes.reserve(nodes_.size()); 209 for (const auto id : execution_plan_) { 210 auto& n = nodes_.at(id); 211 if (n.node != nullptr && predicate(n)) { 212 nodes.push_back(n.node.get()); 213 } 214 } 215 return nodes; 216 } 217 218 // There are two approaches possible: wrap entire NodeDef and ValueDef into 219 // unique_ptr and store it in values_ and nodes_ or store it by value. 220 // We store it by value here to make introspection calls cheaper. 221 std::vector<ValueDef> values_; 222 223 std::map<NodeId, NodeDef> nodes_; 224 // Node Ids in order of execution. 225 std::vector<NodeId> execution_plan_; 226 }; 227 228 // Removes to_remove node that precedes to_keep node only if to_remove has 229 // outputs that are consumed only by to_keep. In such case to_keep inherits all 230 // to_remove inputs. 231 absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove, 232 const Node* to_keep); 233 234 // Removes to_remove node that follows to_keep node only if to_remove has inputs 235 // that are produced by to_keep. to_keep inherits all to_remove inputs. 236 absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove, 237 const Node* to_keep); 238 239 // Removes simple_node and its output value from the graph. Node is considered 240 // simple if it has only one input and one output value. Input value is kept. 241 absl::Status RemoveSimpleNodeKeepInput(GraphFloat32* graph, 242 const Node* simple_node); 243 244 // Removes simple_node and its input value from the graph. Node is considered 245 // simple if it has only one input and one output value. Output value is kept. 246 // simple_node should be an exclusive consumer of its input value. 247 absl::Status RemoveSimpleNodeKeepOutput(GraphFloat32* graph, 248 const Node* simple_node); 249 250 absl::Status AddOutput(GraphFloat32* graph, const Node* from_node, 251 Value** output); 252 253 // Makes a direct connection between from_node and to_node. All input parameters 254 // except output are expected to be initialized before passing to the function. 255 // If from_node already has an output value, which is not yet consumed by 256 // to_node, it may be passed as output parameter. 257 absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node, 258 const Node* to_node, Value** output); 259 260 // @return true if all tensors have same batch value or if model has no values. 261 bool IsBatchMatchesForAllValues(const GraphFloat32& model); 262 263 } // namespace gpu 264 } // namespace tflite 265 266 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_ 267