1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17 
18 #include <queue>
19 #include <set>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
26 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/gtl/flatset.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/env_var.h"
37 
38 #if GOOGLE_CUDA && GOOGLE_TENSORRT
39 
40 namespace tensorflow {
41 namespace tensorrt {
42 namespace segment {
43 namespace {
44 using absl::StrAppend;
45 using absl::StrAppendFormat;
46 using absl::StrCat;
47 using absl::StrJoin;
48 
49 // A simple graph representation to mirror Graph. This structure
50 // helps saving memory since segmenter modifies the graph in place, preventing
51 // the need to create a copy of the graph. It is composed of edges and nodes.
52 // Nodes keep pointers to original TF nodes.
class SimpleNode;
class SimpleGraph;

// A directed edge between two SimpleNodes. The edge stores its id, both
// endpoints, the output port on the source side, the input port on the
// destination side, and a flag telling whether it is a control edge. The edge
// only holds pointers to its endpoints.
class SimpleEdge {
 public:
  // Builds an edge 'src':'src_port' -> 'dst':'dst_port' with the given id.
  // 'is_control' defaults to false, i.e. a data edge.
  SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
             int dst_port, bool is_control = false)
      : id_(id),
        src_(src),
        src_port_(src_port),
        dst_(dst),
        dst_port_(dst_port),
        control_(is_control) {}
  ~SimpleEdge() = default;

  // Identity of the edge.
  int id() const { return id_; }
  bool IsControlEdge() const { return control_; }

  // Source side of the edge.
  SimpleNode* src() const { return src_; }
  int src_output() const { return src_port_; }

  // Destination side of the edge.
  SimpleNode* dst() const { return dst_; }
  int dst_input() const { return dst_port_; }

 private:
  int id_;
  SimpleNode* src_;
  int src_port_;
  SimpleNode* dst_;
  int dst_port_;
  bool control_;
};
82 
83 class SimpleNode {
84  public:
85   SimpleNode(const Node* node, const int id);
86 
in_edges() const87   const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
out_edges() const88   const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
89 
in_nodes() const90   std::vector<SimpleNode*> in_nodes() const {
91     std::vector<SimpleNode*> res;
92     res.reserve(in_edges_.size());
93     for (const auto e : in_edges_) {
94       if (e) res.push_back(e->src());
95     }
96     return res;
97   }
98 
out_nodes() const99   std::vector<SimpleNode*> out_nodes() const {
100     std::vector<SimpleNode*> res;
101     res.reserve(out_edges_.size());
102     for (const auto e : out_edges_) {
103       if (e) res.push_back(e->dst());
104     }
105     return res;
106   }
107 
name() const108   const string& name() const { return node_->name(); }
tf_node() const109   const Node* tf_node() const { return node_; }
id() const110   int id() const { return id_; }
111 
112  private:
113   const Node* node_;
114   std::vector<SimpleEdge*> in_edges_;
115   std::vector<SimpleEdge*> out_edges_;
116   int id_;
117 
118   friend class SimpleGraph;
119 };
120 
121 class SimpleGraph {
122  public:
123   explicit SimpleGraph(const Graph* g);
124   ~SimpleGraph();
125 
126   void AddControlEdge(SimpleNode* src, SimpleNode* dst);
127   void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
128   void RemoveEdge(const SimpleEdge*);
FindNodeId(int node_id)129   SimpleNode* FindNodeId(int node_id) {
130     if (node_id < 0 || node_id > static_cast<int>(nodes_.size())) {
131       return nullptr;
132     }
133     return nodes_[node_id];
134   }
num_node_ids() const135   int num_node_ids() const { return nodes_.size(); }
source_node() const136   const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; }
sink_node() const137   const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; }
138 
139  private:
140   const Graph* g_;
141   std::vector<SimpleNode*> nodes_;
142   std::vector<SimpleEdge*> edges_;
143   // free_edge_ids_ and free_node_ids_ contain freed indices.
144   std::set<int> free_edge_ids_;
145   std::set<int> free_node_ids_;
146 };
147 
SimpleNode(const Node * node,const int id)148 SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) {
149   if (node_) {
150     in_edges_.reserve(node_->in_edges().size());
151     out_edges_.reserve(node_->out_edges().size());
152   }
153 }
154 
SimpleGraph(const Graph * g)155 SimpleGraph::SimpleGraph(const Graph* g) : g_(g) {
156   int n_nodes = g_->num_node_ids();
157   nodes_.resize(n_nodes, nullptr);
158   nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
159   nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
160   int n_edges = g->num_edge_ids();
161   edges_.resize(n_edges, nullptr);
162   for (int i = 2; i < n_nodes; i++) {
163     const auto n = g->FindNodeId(i);
164     if (n) {
165       nodes_[i] = new SimpleNode(n, i);
166     } else {
167       free_node_ids_.insert(i);
168     }
169   }
170   for (int i = 0; i < n_edges; i++) {
171     const auto e = g->FindEdgeId(i);
172     if (e) {
173       const auto tfsrc = e->src();
174       const auto tfdst = e->dst();
175       bool is_control = e->IsControlEdge();
176       auto src = nodes_[tfsrc->id()];
177       auto dst = nodes_[tfdst->id()];
178       auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
179                                  is_control);
180       edges_[i] = edge;
181       src->out_edges_.push_back(edge);
182       dst->in_edges_.push_back(edge);
183     } else {
184       free_edge_ids_.insert(i);
185     }
186   }
187 }
188 
AddEdge(SimpleNode * src,int out_port,SimpleNode * dst,int in_port)189 void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
190                           int in_port) {
191   int i = edges_.size();
192   if (!free_edge_ids_.empty()) {
193     auto it = free_edge_ids_.begin();
194     i = *it;
195     free_edge_ids_.erase(it);
196   } else {
197     edges_.push_back(nullptr);
198   }
199   bool is_control = (out_port == Graph::kControlSlot);
200   is_control |= (in_port == Graph::kControlSlot);
201   auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
202   edges_[i] = edge;
203   src->out_edges_.push_back(edge);
204   dst->in_edges_.push_back(edge);
205 }
206 
AddControlEdge(SimpleNode * src,SimpleNode * dst)207 void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
208   AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
209 }
210 
RemoveEdge(const SimpleEdge * edge)211 void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
212   auto src = edge->src();
213   auto dst = edge->dst();
214   for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
215     if (*it == edge) {
216       src->out_edges_.erase(it);
217       break;
218     }
219   }
220   for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
221     if (*it == edge) {
222       dst->in_edges_.erase(it);
223       break;
224     }
225   }
226 }
227 
~SimpleGraph()228 SimpleGraph::~SimpleGraph() {
229   for (auto x : nodes_) delete x;
230   for (auto x : edges_) delete x;
231 }
232 
233 // Define comparison functions for std::set with pointer keys so that behavior
234 // is deterministic. When using std::set with pointer key types, the items are
235 // sorted by pointer address which is non-deterministic. This can cause issues
236 // for INT8 mode because the graph is converted twice and non-determinism may
237 // cause a mismatch between the calibration tables of the conversions.
238 struct SimpleEdgePtrCompare {
operator ()tensorflow::tensorrt::segment::__anon8f933a440111::SimpleEdgePtrCompare239   bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const {
240     return lhs->id() < rhs->id();
241   }
242 };
243 
244 // Copied from TF ReverseDFS, which only works for Graph.
StableDFS(const SimpleGraph & g,bool reverse,const std::vector<const SimpleNode * > & start,const std::function<bool (const SimpleNode *)> & enter,const std::function<bool (const SimpleNode *)> & leave)245 void StableDFS(const SimpleGraph& g, bool reverse,
246                const std::vector<const SimpleNode*>& start,
247                const std::function<bool(const SimpleNode*)>& enter,
248                const std::function<bool(const SimpleNode*)>& leave) {
249   // Stack of work to do.
250   struct Work {
251     const SimpleNode* node;
252     bool leave;  // Are we entering or leaving n?
253   };
254   std::vector<Work> stack(start.size());
255   for (int i = 0; i < start.size(); ++i) {
256     stack[i] = Work{start[i], false};
257   }
258 
259   auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
260                            : [](const SimpleNode* n) { return n->out_nodes(); };
261   std::vector<bool> visited(g.num_node_ids(), false);
262   while (!stack.empty()) {
263     Work w = stack.back();
264     stack.pop_back();
265 
266     auto n = w.node;
267     if (w.leave) {
268       if (leave && !leave(n)) return;
269       continue;
270     }
271 
272     if (visited[n->id()]) continue;
273     visited[n->id()] = true;
274     if (enter && !enter(n)) return;
275 
276     // Arrange to call leave(n) when all done with descendants.
277     if (leave) stack.push_back(Work{n, true});
278 
279     auto nodes = get_nodes(n);
280     std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
281     std::sort(nodes_sorted.begin(), nodes_sorted.end(),
282               [](const SimpleNode* lhs, const SimpleNode* rhs) {
283                 return lhs->name() < rhs->name();
284               });
285     for (const SimpleNode* node : nodes_sorted) {
286       if (!visited[node->id()]) {
287         stack.push_back(Work{node, false});
288       }
289     }
290   }
291 }
292 
CanContractEdge(const SimpleEdge * edge,const std::unique_ptr<SimpleGraph> & graph)293 bool CanContractEdge(const SimpleEdge* edge,
294                      const std::unique_ptr<SimpleGraph>& graph) {
295   const auto src = edge->src();
296   const auto dst = edge->dst();
297 
298   // Can't contract edge if doing so would cause a cycle in the
299   // graph. So, if there is a directed path from 'src' to 'dst', other
300   // than 'edge' (or any other direct edge from 'src' to 'dst'), then
301   // combining 'src' and 'dst' will cause a cycle along that path.
302   //
303   // In practice, to avoid modifying the graph and to take advantage
304   // of existing graph functions, we perform an equivalent.
305   //   1. Get all nodes incoming to 'dst', excluding 'src'
306   //   2. Reverse DFS from those nodes
307   //   3. If reverse DFS reaches 'src' then we have a cycle
308   //
309   // TODO(aaroey): there are several problems with the current approach:
310   // 1. src->dst->src, this is not detected but it should be;
311   // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
312   //    is detected but it should not be.
313   //
314   // Note that it's fine that dst connects back to src indirectly (i.e. through
315   // a path with length > 1 that consists of intermedia nodes other than src).
316   // While loops is one example.
317   //
318   // The goal is to make sure that the trt subgraph:
319   // 1. has no loops (i.e. is a DAG), and
320   // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
321   //    in the subgraph), then all paths from X to Y are in the subgraph.
322   //
323   // To achieve this goal, the correct way seems to be:
324   // 1. remove any direct edge from src->dst;
325   // 2. detect if src can reach dst, if so they cannot be merged.
326   std::vector<const SimpleNode*> dfs_start_nodes;
327   for (const SimpleNode* node : dst->in_nodes()) {
328     if (node != src) {
329       dfs_start_nodes.push_back(node);
330     }
331   }
332   bool has_cycle = false;
333   StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
334             [&has_cycle, src](const SimpleNode* n) {
335               if (n == src) {
336                 has_cycle = true;
337                 return false;
338               }
339               return true;
340             });
341   return !has_cycle;
342 }
343 
344 // TODO(bixia): put this to a common utility file.
TensorPropertiesToString(const OpInfo::TensorProperties & prop)345 string TensorPropertiesToString(const OpInfo::TensorProperties& prop) {
346   string s = StrCat(DataTypeString(prop.dtype()), ": ");
347   StrAppend(&s, "[");
348   if (prop.shape().unknown_rank()) {
349     StrAppend(&s, "?");
350   } else {
351     StrAppend(&s, StrJoin(prop.shape().dim(), ",",
352                           [](string* out, const TensorShapeProto_Dim& d) {
353                             StrAppendFormat(out, "%d", d.size());
354                           }));
355   }
356   StrAppend(&s, "]");
357   return s;
358 }
359 
TensorPropertiesToString(const std::vector<OpInfo::TensorProperties> & properties)360 string TensorPropertiesToString(
361     const std::vector<OpInfo::TensorProperties>& properties) {
362   return StrJoin(properties, "; ",
363                  [](string* out, const OpInfo::TensorProperties& prop) {
364                    StrAppend(out, TensorPropertiesToString(prop));
365                  });
366 }
367 
368 // From the given list of input properties, returns the leading shape, which is
369 // the shape that determines the batch size of the operation. The leading shape
370 // is selected from the group of input shapes with the highest rank as follows:
371 //  . If all of those shapes have non-negative values for the batch dimension,
372 //    the leading shape is the one with the largest value for the batch
373 //    dimension.
374 //  . If some or all of those shapes have negative values for the batch
375 //    dimension, and the rest of those shapes have 1 for the batch dimension,
376 //    the leading shape is the first of those shapes with a negative value for
377 //    the batch dimension.
378 //  . Otherwise, we can't determine the leading shape for the operation and
379 //    have to exclude the operation from TRT.
380 //
381 // Examples:
382 //    case-1: a[1,3,4] + b[2,3,4] => leading shape [2,3,4]
383 //    case-2: a[2,3,4] + b[scalar] => leading shape [2,3,4]
384 //    case-3: a[-1,3,4] + b[1,3,4] => leading shape [-1,3,4]
385 //    case-4: a[-1,3,4] + b[2,3,4] => no leading shape
386 //
387 // We have to return "no leading shape" for case-4 to exclude such operation
388 // from being translated for this reason:
389 //   The actually input for "a" have to be in the shape of [2,3,4] for the
390 //   operation to be valid. On the other hand, if we translate the operation
391 //   to implicit batch mode, it will becomes a[3,4]+b[3,4] which is valid for
392 //   any input shape of "a".
393 //
394 // This routine assumes the input program is valid. For example, we shouldn't
395 // see invalid operation like a[2,3,4] + b[3,3,4]. It also assumes the input
396 // properties is not empty and all input have known shapes.
397 //
398 // TODO(bixia): find a way to share this knowledge with the converter.
399 // TODO(bixia): investigate the use of symbolic shape analysis to improve
400 //   segmentation, such as by requiring the dynamic dimensions to have the same
401 //   negative value.
FindLeadingShape(absl::Span<const OpInfo::TensorProperties> properties)402 absl::optional<const TensorShapeProto*> FindLeadingShape(
403     absl::Span<const OpInfo::TensorProperties> properties) {
404   DCHECK(!properties.empty());
405   const TensorShapeProto* result;
406   int max_batch_dim_value;
407   auto choose_shape_with_higher_rank = [&](const TensorShapeProto* s) {
408     result = s;
409     max_batch_dim_value = s->dim_size() < 1 ? 1 : s->dim(0).size();
410   };
411 
412   DCHECK(!properties[0].shape().unknown_rank());
413   choose_shape_with_higher_rank(&properties[0].shape());
414 
415   for (const OpInfo::TensorProperties& p : properties.subspan(1)) {
416     DCHECK(!p.shape().unknown_rank());
417     if (p.shape().dim_size() < result->dim_size()) continue;
418 
419     if (p.shape().dim_size() > result->dim_size()) {
420       choose_shape_with_higher_rank(&p.shape());
421       continue;
422     }
423 
424     // Among the shapes with the same rank, choose the one with a dynamic batch
425     // size. If no shapes have a dynamic batch size, choose the one with the
426     // largest size.
427     if (result->dim_size() < 1) continue;
428 
429     if (p.shape().dim(0).size() < 0 || result->dim(0).size() < 0) {
430       if (p.shape().dim(0).size() < 0 && result->dim(0).size() >= 0) {
431         result = &p.shape();
432       } else {
433         max_batch_dim_value =
434             std::max<int>(max_batch_dim_value, p.shape().dim(0).size());
435       }
436 
437       continue;
438     }
439 
440     if (p.shape().dim(0).size() > result->dim(0).size()) {
441       result = &p.shape();
442       max_batch_dim_value = result->dim(0).size();
443     }
444   }
445 
446   if (result->dim_size() > 0 && result->dim(0).size() < 0) {
447     // dynamic batch size
448     if (max_batch_dim_value <= 1) {
449       return result;
450     } else {
451       return absl::nullopt;
452     }
453   }
454 
455   return result;
456 }
457 
458 // Returns the inputs that are relevant to determinate the batch size of the
459 // operation. This routine handles the following cases:
460 //   . Operations that support implicit boradcasting, such as operation mul.
461 //     In this case, we need to inspect all the inputs in order to determine the
462 //     batch size of the operation.
463 //   . Special cases. Such as "Conv2DBackpropInput", "Conv3DBackpropInputV2".
464 //   . The batch size of a operation is determined by the first input of the
465 //     operation.
GetInputsToDeterminateBatchSize(const Node * node,const std::vector<OpInfo::TensorProperties> & all_inputs)466 absl::Span<const OpInfo::TensorProperties> GetInputsToDeterminateBatchSize(
467     const Node* node, const std::vector<OpInfo::TensorProperties>& all_inputs) {
468   // TODO(bixia): Find a way to share this knowledge with the converter.
469   static std::set<string> broadcast_supporting_ops = {
470       // ops corresponding to ConvertBinary in the converter
471       "Add",
472       "AddV2",
473       "Mul",
474       "Sub",
475       "Div",
476       "FloorDiv",
477       "RealDiv",
478       "Minimum",
479       "Maximum",
480       "Pow",
481       // other ops that need to need GetTrtBroadcastShape to convert
482       "BiasAdd",
483       "SquaredDifference",
484       "BatchMatMul",
485       "BatchMatMulV2",
486   };
487   const string& op = node->def().op();
488 
489   if (op == "Conv2DBackpropInput" || op == "Conv3DBackpropInputV2") {
490     DCHECK_EQ(all_inputs.size(), 3);
491     return absl::MakeSpan(all_inputs).subspan(2, 1);
492   }
493 
494   if (broadcast_supporting_ops.count(op)) {
495     return absl::MakeSpan(all_inputs);
496   }
497 
498   // This is the common case for the operations that don't support implicit
499   // broadcasting: the first operand determines its batch size. All otherwise
500   // cases are handled before reaching here.
501   return absl::MakeSpan(all_inputs).subspan(0, 1);
502 }
503 
504 // Returns true if the operation we can remove the implicit batch of the
505 // operation.
506 //
507 // In particular, if the input shape has dynamic rank or the input shape rank
508 // is less than 2, we can't remove the implicit batch dimension and generate
509 // a new operation for TRT translation.
OperationCanBeTranslatedToImplicitBatch(const grappler::GraphProperties * graph_properties,const Node * node)510 bool OperationCanBeTranslatedToImplicitBatch(
511     const grappler::GraphProperties* graph_properties, const Node* node) {
512   VLOG(3) << "process node " << node->name();
513   if (node->num_inputs() == 0) return true;
514   if (!graph_properties || !graph_properties->HasInputProperties(node->name()))
515     return false;
516 
517   VLOG(3) << "input shapes "
518           << TensorPropertiesToString(
519                  graph_properties->GetInputProperties(node->name()));
520 
521   const std::vector<OpInfo::TensorProperties>& all_input_properties =
522       graph_properties->GetInputProperties(node->name());
523   absl::Span<const OpInfo::TensorProperties> input_properties =
524       GetInputsToDeterminateBatchSize(node, all_input_properties);
525   if (absl::c_any_of(input_properties, [](const OpInfo::TensorProperties& p) {
526         return p.shape().unknown_rank();
527       })) {
528     return false;
529   }
530 
531   absl::optional<const TensorShapeProto*> leading_shape =
532       FindLeadingShape(input_properties);
533   return leading_shape.has_value() && leading_shape.value()->dim_size() >= 2;
534 }
535 
536 // Returns true if we can't be sure that the operand with the given properties
537 // won't have negative values for non-batch dimensions.
538 //
HasDynamicNonBatchDimension(const OpInfo::TensorProperties & prop)539 bool HasDynamicNonBatchDimension(const OpInfo::TensorProperties& prop) {
540   const TensorShapeProto& shape = prop.shape();
541   if (shape.unknown_rank()) return true;
542 
543   // Scalar is a well specified shape, and TRT supports implicit broadcasting
544   // from scalar to other shapes.
545   if (shape.dim_size() == 0) return false;
546   for (int i = 1; i < shape.dim_size(); ++i) {
547     // The value of a dynamic dimension can be other negative values besides
548     // -1, representing the symbolic group of the dimension.
549     if (shape.dim(i).size() <= -1) {
550       return true;
551     }
552   }
553   return false;
554 }
555 
556 // Returns true if we can't be sure that the operation won't have dynamic
557 // non-batch dimension involved. We only check the shape of the first output
558 // assuming shape inference already propagates the shapes.
OperationHasDynamicNonBatchDimension(const grappler::GraphProperties * graph_properties,const Node * node)559 bool OperationHasDynamicNonBatchDimension(
560     const grappler::GraphProperties* graph_properties, const Node* node) {
561   VLOG(3) << "process node " << node->name();
562   // If the node doesn't have any input or output, not computation is involved.
563   if (node->num_inputs() == 0 || node->num_outputs() == 0) return false;
564 
565   // If the node doesn't have output properties, return true to be conservative.
566   if (!graph_properties->HasOutputProperties(node->name())) return true;
567   VLOG(3) << "output shapes "
568           << TensorPropertiesToString(
569                  graph_properties->GetOutputProperties(node->name()));
570   return HasDynamicNonBatchDimension(
571       graph_properties->GetOutputProperties(node->name()).at(0));
572 }
573 
ContractEdge(SimpleEdge * edge,SimpleGraph * graph,std::vector<const SimpleEdge * > * remove_edges)574 void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
575                   std::vector<const SimpleEdge*>* remove_edges) {
576   // Transfer all inputs and outputs of 'dst' to 'src' except edges
577   // connecting the two.
578   auto src = edge->src();
579   auto dst = edge->dst();
580 
581   // We can use '0' for input/output index because we don't need them
582   // to be accurate for the way we are using the graph.
583   std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
584                                           dst->in_edges().end());
585   for (const SimpleEdge* in_edge : in_edges) {
586     if (in_edge->IsControlEdge()) {
587       if (in_edge->src() != src) {
588         SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
589         graph->AddControlEdge(e->src(), src);
590       }
591     } else {
592       if (in_edge->src() != src) {
593         SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
594         if (e->src() == graph->source_node()) {
595           graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot);
596         } else {
597           graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
598         }
599       }
600     }
601   }
602 
603   std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
604                                            dst->out_edges().end());
605   for (const SimpleEdge* out_edge : out_edges) {
606     if (out_edge->IsControlEdge()) {
607       SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
608       graph->AddControlEdge(src, e->dst());
609     } else {
610       SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
611       if (e->dst() == graph->sink_node()) {
612         VLOG(1) << " edge to sink node " << src->name() << " -> "
613                 << e->dst()->name();
614         graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input());
615       } else {
616         graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
617       }
618     }
619   }
620 
621   // Return the edges that must be removed to disconnect 'dst' from
622   // the graph. We don't actually remove 'dst' since the caller holds
623   // references to all the nodes.
624   for (const auto& in_edge : dst->in_edges()) {
625     remove_edges->push_back(in_edge);
626   }
627   for (const auto& out_edge : dst->out_edges()) {
628     remove_edges->push_back(out_edge);
629   }
630 }
631 
632 // Returns a batch size representation for a segment that only contains the
633 // given node.
GetClusterBatchSizeForNode(const grappler::GraphProperties * graph_properties,const Node * node,bool use_implicit_batch)634 ClusterBatchSize GetClusterBatchSizeForNode(
635     const grappler::GraphProperties* graph_properties, const Node* node,
636     bool use_implicit_batch) {
637   ClusterBatchSize cluster_batch_size;
638   if (!use_implicit_batch || !node || node->num_inputs() == 0) {
639     return cluster_batch_size;
640   }
641 
642   const NodeDef& node_def = node->def();
643   if (node_def.attr().count(kTftrtOpMaxBatchSizeAttr)) {
644     cluster_batch_size.SetMaxBatchSize(
645         node_def.attr().at(kTftrtOpMaxBatchSizeAttr).i());
646   }
647 
648   // As shape inference cannot provide any useful information about the batch
649   // size, we keep it as missing.
650   if (!graph_properties ||
651       !graph_properties->HasInputProperties(node->name())) {
652     VLOG(3) << "doesn't have input property";
653     return cluster_batch_size;
654   }
655 
656   const std::vector<OpInfo::TensorProperties>& input_properties =
657       graph_properties->GetInputProperties(node->name());
658   absl::optional<const TensorShapeProto*> optional_leading_shape =
659       FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties));
660   DCHECK(optional_leading_shape.has_value());
661   const TensorShapeProto* leading_shape = optional_leading_shape.value();
662   DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2);
663   VLOG(3) << "set batch size as " << leading_shape->dim(0).size();
664   return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size());
665 }
666 
AddSegmentForNode(const grappler::GraphProperties * graph_properties,std::vector<UnionFind<SimpleNode * >> * segments,SimpleNode * node,const DeviceNameUtils::ParsedName & device_name,bool use_implicit_batch)667 void AddSegmentForNode(const grappler::GraphProperties* graph_properties,
668                        std::vector<UnionFind<SimpleNode*>>* segments,
669                        SimpleNode* node,
670                        const DeviceNameUtils::ParsedName& device_name,
671                        bool use_implicit_batch) {
672   ClusterProperty property(
673       GetClusterBatchSizeForNode(graph_properties,
674                                  node == nullptr ? nullptr : node->tf_node(),
675                                  use_implicit_batch),
676       device_name);
677   segments->emplace_back(node, std::move(property));
678 }
679 
680 }  // namespace
681 
SegmentGraph(const Graph * tf_graph,const grappler::GraphProperties * graph_properties,const std::function<Status (const Node *)> & candidate_fn,const std::function<bool (const Edge *)> & input_candidate_fn,const std::function<bool (const Edge *)> & output_candidate_fn,const SegmentOptions & options,SegmentVector * segments)682 Status SegmentGraph(const Graph* tf_graph,
683                     const grappler::GraphProperties* graph_properties,
684                     const std::function<Status(const Node*)>& candidate_fn,
685                     const std::function<bool(const Edge*)>& input_candidate_fn,
686                     const std::function<bool(const Edge*)>& output_candidate_fn,
687                     const SegmentOptions& options, SegmentVector* segments) {
688   if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
689     return errors::Internal(
690         "Explicit batch mode should allow dynamic non-batch dimensions");
691   }
692 
693   if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) {
694     return errors::Internal("Implicit batch mode requires maximum_batch_size");
695   }
696 
697   if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
698     return errors::Internal(
699         "Need graph propertities to disallow dynamic non-batch dimensions");
700   }
701 
702   // Steps:
703   // 1. run the segmentation algorithm to find all the segments, which uses
704   //    candidate_fn to determine the candidates segment nodes;
705   // 2. for each segments, remove the nodes that are inputs/outputs of the
706   //    segment but are not eligible, using input/output_candidate_fn to
707   //    determine the eligibilities;
708   // 3. convert the segment into expected return format and return the result.
709 
710   // --------------------------------- Step 1 ---------------------------------
711   auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
712   // Use a union-find to collect the nodes that belong to the same
713   // segment. A node value of nullptr indicates that the node is not a candidate
714   // for TRT.
715   std::unordered_set<string> unsupported_ops;
716   int num_unsupported_ops = 0;
717 
718   // Getting the operations denylisted for conversion
719   string tftrt_op_denylist_str;
720   TF_CHECK_OK(
721       ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str));
722 
723   auto tftrt_op_denylist = gtl::FlatSet<string>{};  // non-absl ok
724 
725   for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
726     tftrt_op_denylist.insert(x);
727   }
728 
729   // Parsing each node of the graph
730   std::vector<UnionFind<SimpleNode*>> node_segments;
731   for (int i = 0; i < graph->num_node_ids(); ++i) {
732     SimpleNode* node = graph->FindNodeId(i);
733     if (!node) {
734       VLOG(3) << "Node " << i << " doesn't exist in the graph";
735       continue;
736     }
737     auto exclude_node = [&](absl::string_view reason) {
738       VLOG(1) << "Not a TF-TRT candidate, "
739               << "(Op type: " << node->tf_node()->type_string() << "), "
740               << "(Op name: " << node->name() << "), "
741               << "(Reason: " << reason << ")";
742       unsupported_ops.emplace(node->tf_node()->type_string());
743       num_unsupported_ops++;
744       node = nullptr;
745     };
746     absl::optional<DeviceNameUtils::ParsedName> device_name =
747         GetDeviceParsedName(node->tf_node());
748     // GetDeviceParsedName capitalizes the device type.
749     if (!device_name.has_value() ||
750         (device_name->has_type && device_name->type != "GPU")) {
751       exclude_node("node can't be placed on GPU");
752     } else if (options.exclude_node_list.count(node->name()) != 0) {
753       exclude_node("excluded by segmenter option");
754     } else if (options.use_implicit_batch &&
755                !OperationCanBeTranslatedToImplicitBatch(graph_properties,
756                                                         node->tf_node())) {
757       exclude_node(
758           "implicit batch mode requires input shape with at least two "
759           "dimensions");
760     } else if (!options.allow_dynamic_non_batch_dim &&
761                OperationHasDynamicNonBatchDimension(graph_properties,
762                                                     node->tf_node())) {
763       exclude_node("dynamic non-batch dimensions not allowed");
764     } else {
765       const Status status = candidate_fn(node->tf_node());
766       if (!status.ok()) {
767         exclude_node(status.error_message());
768       } else if (tftrt_op_denylist.count(node->tf_node()->type_string())) {
769         // WARNING verbosity since the user explicitly requests this behavior.
770         LOG_WARNING_WITH_PREFIX
771             << "Denylisted as TF-TRT candidate, "
772             << "(Op type: " << node->tf_node()->type_string() << "), "
773             << "(Op name: " << node->name() << ")";
774         exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
775       } else {
776         VLOG(2) << "Accepted as a TF-TRT candidate, "
777                 << "(Op type: " << node->tf_node()->type_string() << "), "
778                 << "(Op name: " << node->name();
779       }
780     }
781     AddSegmentForNode(graph_properties, &node_segments, node, *device_name,
782                       options.use_implicit_batch);
783   }
784   string msg = StrCat(
785       "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(),
786       " different types in the graph that", " are not converted to TensorRT: ");
787   for (const auto& elem : unsupported_ops) {
788     StrAppend(&msg, elem, ", ");
789   }
790   LOG(INFO) << msg << "(For more information see "
791             << "https://docs.nvidia.com/deeplearning"
792             << "/frameworks/tf-trt-user-guide/index.html#supported-ops).";
793 
794   // The segmentation algorithm below visits nodes in reverse topological order
795   // and attempts to merge nodes along output edges. That means that subgraphs
796   // grow from the output-side of the network towards the inputs.
797   //
798   // In general this is not guaranteed to produce a globally optimal
799   // segmentation. For example, consider a graph with nodes {A, B, C, D} and
800   // edges {A->B, A->C, B->D, C->D}, where A, B, D are TRT-compatible but C is
801   // not. In theory we can choose to contract either A, B or B, D but not both,
802   // but here it always chooses to contract B, D.
803   //
804   // In the future if we have a measure of how beneficial it is to include a
805   // given node in a TRT subgraph then we can revisit this algorithm to take
806   // advantage of that information.
807   std::vector<const SimpleNode*> order;
808   order.reserve(graph->num_node_ids());
809   StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
810             /*enter=*/nullptr, [&order](const SimpleNode* n) {
811               order.push_back(n);
812               return true;
813             });
814   for (const SimpleNode* node : order) {
815     // All output nodes of 'node' have been visited.
816     VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
817     // 'node' must be a TRT candidate.
818     if (node_segments[node->id()].Value() == nullptr) {
819       VLOG(3) << "... not a TRT candidate";
820       continue;
821     }
822     // Contract output edges to combine 'node' with output nodes. Repeat this
823     // step until no output edges can be further contracted. This is because
824     // contracting an output edge may unblock new edges for contracting.
825     ClusterBatchSize expected_batch_size =
826         node_segments[node->id()].Property().BatchSize();
827     DeviceNameUtils::ParsedName expected_device_name =
828         node_segments[node->id()].Property().DeviceName();
829     VLOG(3) << "batch size " << expected_batch_size;
830     while (true) {
831       std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
832       // TODO(bixia): consider merging the loop to find the edges and the loop
833       // to contract the edges.
834       for (const SimpleEdge* out_edge : node->out_edges()) {
835         VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
836                 << out_edge->dst()->id() << " <- " << node->id() << " )";
837         if (out_edge->IsControlEdge()) {
838           VLOG(3) << "... ... Control Edge, Skipping";
839           continue;
840         }
841         UnionFind<SimpleNode*>* out_cluster =
842             &node_segments[out_edge->dst()->id()];
843         // Out node must be a TRT candidate.
844         if (out_cluster->Value() == nullptr) {
845           VLOG(3) << "... ... not a TRT candidate";
846           continue;
847         }
848         // Out node must have compatible batch size.
849         ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
850         ClusterBatchSize merged_batch_size = expected_batch_size;
851         if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
852           VLOG(3) << "... ... incompatible batch sizes "
853                   << expected_batch_size.ToString() << " "
854                   << out_batch_size.ToString();
855           continue;
856         }
857 
858         const DeviceNameUtils::ParsedName& out_device_name =
859             out_cluster->Property().DeviceName();
860         absl::optional<DeviceNameUtils::ParsedName> merged_device_name =
861             MergeIfCompatible(expected_device_name, out_device_name);
862         if (!merged_device_name.has_value()) {
863           VLOG(3) << "... ... incompatible device names "
864                   << expected_device_name << " " << out_device_name;
865           continue;
866         }
867 
868         if (CanContractEdge(out_edge, graph)) {
869           VLOG(3) << "... ... can contract. new batch size "
870                   << merged_batch_size.ToString();
871           contract_edges.insert(out_edge);
872           expected_batch_size = merged_batch_size;
873           expected_device_name = *merged_device_name;
874         } else {
875           VLOG(3) << "... ... cannot contract, would form cycle";
876         }
877       }
878       if (contract_edges.empty()) {
879         break;
880       }
881       // Contract edges and collect the adjacent nodes into the same
882       // segment/subgraph.
883       while (!contract_edges.empty()) {
884         const SimpleEdge* contract_edge = *contract_edges.begin();
885         const SimpleNode* src = contract_edge->src();
886         const SimpleNode* dst = contract_edge->dst();
887 
888         VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
889                 << src->id() << " <- " << dst->id();
890         TF_RETURN_IF_ERROR(
891             node_segments[src->id()].Merge(&node_segments[dst->id()]));
892 
893         // Contracting the edge leaves disconnected graph edges.
894         // Remove these from the graph and from 'contract_edges' so we
895         // don't visit them again.
896         SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
897         std::vector<const SimpleEdge*> remove_edges;
898         ContractEdge(e, graph.get(), &remove_edges);
899 
900         for (const SimpleEdge* r : remove_edges) {
901           contract_edges.erase(r);
902           graph->RemoveEdge(r);
903         }
904       }
905       if (expected_batch_size !=
906           node_segments[node->id()].Property().BatchSize()) {
907         return errors::Internal(
908             "expected batch size is not the same as the actual batch size");
909       }
910       if (expected_device_name !=
911           node_segments[node->id()].Property().DeviceName()) {
912         return errors::Internal(
913             "expected device name is not the same as the actual device name");
914       }
915     }
916   }
917 
918   // Collect the segments/subgraphs. Each subgraph is represented by a
919   // set of the names of the nodes in that subgraph.
920 
921   // A map from the segment identifier (currently the name of the root node of
922   // the segment tree) to the segment nodes set.
923   std::map<string, Segment> sg_map;
924 
925   for (auto& u : node_segments) {
926     if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
927       sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
928     }
929     if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
930       sg_map[u.Value()->name()].property = u.Property();
931     }
932   }
933 
934   // --------------------------------- Step 2 ---------------------------------
935   // Remove ineligible input/output nodes.
936   for (auto& itr : sg_map) {
937     std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
938     VLOG(1) << "Segment original size: " << segment_nodes.size();
939     while (true) {
940       std::deque<const Node*> in_nodes_que, out_nodes_que;
941       // Find an input node that is not eligible and add it to the queue.
942       // Nodes that have no incoming edges should not be treated as "input",
943       // as there are really no inputs to them. Similarly for output nodes.
944       for (auto node : segment_nodes) {
945         bool added = false;
946         for (const Edge* edge : node->in_edges()) {
947           if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
948               !segment_nodes.count(edge->src())) {  // 'node' is an input node.
949             if (!input_candidate_fn(edge)) {
950               in_nodes_que.push_back(node);
951               added = true;
952               break;
953             }
954           }
955         }
956         if (added) continue;  // Only adding the node once to either queue.
957         for (const Edge* edge : node->out_edges()) {
958           if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
959               !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
960             if (!output_candidate_fn(edge)) {
961               out_nodes_que.push_back(node);
962               break;
963             }
964           }
965         }
966       }
967       if (in_nodes_que.empty() && out_nodes_que.empty()) {
968         // No more ineligible input/output nodes.
969         break;
970       }
971       // Now for each ineligible node, remove all of its inputs or outputs from
972       // the subgraph.
973       //
974       // It can be proven that, if the original subgraph:
975       // 1. is a DAG, and
976       // 2. every path between two nodes in the subgraph lies entirely inside
977       //    the subgraph
978       // then after doing this operation the resulting subgraph will keep the
979       // same properties 1 and 2.
980       //
981       // For simplicity we use heuristics: for input and const output nodes
982       // remove all their inputs, and for non-const output nodes remove all
983       // their outputs. In this way, for common cases the number of removed
984       // nodes should be minimum.
985       auto remove_nodes = [&segment_nodes](bool is_input_nodes,
986                                            std::deque<const Node*>* que) {
987         // Run a BFS on the queue to find all the input/output nodes.
988         std::set<const Node*, NodePtrCompare> visited;
989         std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
990         while (!que->empty()) {
991           auto node = que->front();
992           que->pop_front();
993           if (!visited.insert(node).second) continue;
994           segment_nodes.erase(node);
995           for (auto in : (is_input_nodes || node->type_string() == "Const")
996                              ? node->in_nodes()
997                              : node->out_nodes()) {
998             if (segment_nodes.count(in)) {
999               que->push_back(in);
1000               if (VLOG_IS_ON(2)) {
1001                 if (!logged.count(in)) {
1002                   VLOG(2) << "----> Need to remove node " << in->name()
1003                           << " because one of its "
1004                           << (is_input_nodes ? "output" : "input")
1005                           << " nodes in the graph was removed: "
1006                           << node->name();
1007                   logged.insert(in);
1008                 }
1009               }
1010             }
1011           }
1012         }
1013       };
1014       remove_nodes(true, &in_nodes_que);
1015       remove_nodes(false, &out_nodes_que);
1016     }
1017     VLOG(1) << "Segment new size: " << segment_nodes.size();
1018   }
1019 
1020   // --------------------------------- Step 3 ---------------------------------
1021   // Convert the segments into the expected return format
1022   for (const auto& itr : sg_map) {
1023     const string& segment_root = itr.first;
1024     // Return format does not require set comparator.
1025     std::set<const Node*, NodePtrCompare> segment_nodes(
1026         itr.second.nodes.begin(), itr.second.nodes.end());
1027     if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
1028       string s;
1029       for (auto node : segment_nodes) {
1030         StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
1031       }
1032       VLOG(1) << "Nodes in segment " << segments->size()
1033               << " with parent=" << segment_root << ":" << s;
1034     }
1035 
1036     const int num_effective_nodes = std::count_if(
1037         segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
1038           static auto noops =
1039               new std::set<string>{"Identity", "Snapshot", "StopGradient"};
1040           return noops->count(node->type_string()) == 0;
1041         });
1042 
1043     // Don't use segments whose number of effective nodes is small.
1044     if (num_effective_nodes == 0 ||
1045         num_effective_nodes < options.minimum_segment_size) {
1046       VLOG(1) << "Segment " << segments->size() << " has only "
1047               << num_effective_nodes << " effective nodes, dropping";
1048       continue;
1049     }
1050     segments->emplace_back(itr.second.property, segment_nodes);
1051   }
1052 
1053   return Status::OK();
1054 }
1055 
1056 }  // namespace segment
1057 }  // namespace tensorrt
1058 }  // namespace tensorflow
1059 
1060 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
1061