1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17 
#include <algorithm>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
22 
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/graph_constructor.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 #if GOOGLE_CUDA
34 #if GOOGLE_TENSORRT
35 
36 namespace tensorflow {
37 namespace tensorrt {
38 namespace segment {
39 using absl::StrAppend;
40 using absl::StrCat;
41 
// A lightweight mirror of the TF Graph used by the segmenter. The segmenter
// rewrites the graph in place while contracting edges; working on this
// structure avoids having to copy the TF Graph itself. It consists of
// SimpleNodes and SimpleEdges, and each SimpleNode points to the original TF
// Node it mirrors.
class SimpleNode;
class SimpleGraph;

// A directed edge src:src_output -> dst:dst_input between two SimpleNodes,
// identified by an integer id and flagged as control or data edge.
class SimpleEdge {
 public:
  SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
             int dst_port, bool is_control = false)
      : edge_id_(id),
        source_(src),
        source_port_(src_port),
        dest_(dst),
        dest_port_(dst_port),
        is_control_(is_control) {}
  ~SimpleEdge() {}

  // Endpoints of the edge.
  SimpleNode* src() const { return source_; }
  SimpleNode* dst() const { return dest_; }

  // Port numbers on the source (output) and destination (input) side.
  int src_output() const { return source_port_; }
  int dst_input() const { return dest_port_; }

  // Identifier of the edge, used for deterministic ordering.
  int id() const { return edge_id_; }

  // True if the edge was created as a control edge.
  bool IsControlEdge() const { return is_control_; }

 private:
  int edge_id_;
  SimpleNode* source_;
  int source_port_;
  SimpleNode* dest_;
  int dest_port_;
  bool is_control_;
};
75 
76 class SimpleNode {
77  public:
78   SimpleNode(const Node* node, const int id);
79 
in_edges() const80   const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
out_edges() const81   const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
82 
in_nodes() const83   std::vector<SimpleNode*> in_nodes() const {
84     std::vector<SimpleNode*> res;
85     res.reserve(in_edges_.size());
86     for (const auto e : in_edges_) {
87       if (e) res.push_back(e->src());
88     }
89     return res;
90   }
91 
out_nodes() const92   std::vector<SimpleNode*> out_nodes() const {
93     std::vector<SimpleNode*> res;
94     res.reserve(out_edges_.size());
95     for (const auto e : out_edges_) {
96       if (e) res.push_back(e->dst());
97     }
98     return res;
99   }
100 
name() const101   const string& name() const { return node_->name(); }
tf_node() const102   const Node* tf_node() const { return node_; }
id() const103   int id() const { return id_; }
104 
105  private:
106   const Node* node_;
107   std::vector<SimpleEdge*> in_edges_;
108   std::vector<SimpleEdge*> out_edges_;
109   int id_;
110 
111   friend class SimpleGraph;
112 };
113 
114 class SimpleGraph {
115  public:
116   explicit SimpleGraph(const Graph* g);
117   ~SimpleGraph();
118 
119   void AddControlEdge(SimpleNode* src, SimpleNode* dst);
120   void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
121   void RemoveEdge(const SimpleEdge*);
FindNodeId(int node_id)122   SimpleNode* FindNodeId(int node_id) {
123     if (node_id < 0 || node_id > static_cast<int>(nodes_.size())) {
124       return nullptr;
125     }
126     return nodes_[node_id];
127   }
num_node_ids() const128   int num_node_ids() const { return nodes_.size(); }
source_node() const129   const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; }
sink_node() const130   const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; }
131 
132  private:
133   const Graph* g_;
134   std::vector<SimpleNode*> nodes_;
135   std::vector<SimpleEdge*> edges_;
136   // free_edge_ids_ and free_node_ids_ contain freed indices.
137   std::set<int> free_edge_ids_;
138   std::set<int> free_node_ids_;
139 };
140 
SimpleNode(const Node * node,const int id)141 SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) {
142   if (node_) {
143     in_edges_.reserve(node_->in_edges().size());
144     out_edges_.reserve(node_->out_edges().size());
145   }
146 }
147 
SimpleGraph(const Graph * g)148 SimpleGraph::SimpleGraph(const Graph* g) : g_(g) {
149   int n_nodes = g_->num_node_ids();
150   nodes_.resize(n_nodes, nullptr);
151   nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
152   nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
153   int n_edges = g->num_edge_ids();
154   edges_.resize(n_edges, nullptr);
155   for (int i = 2; i < n_nodes; i++) {
156     const auto n = g->FindNodeId(i);
157     if (n) {
158       nodes_[i] = new SimpleNode(n, i);
159     } else {
160       free_node_ids_.insert(i);
161     }
162   }
163   for (int i = 0; i < n_edges; i++) {
164     const auto e = g->FindEdgeId(i);
165     if (e) {
166       const auto tfsrc = e->src();
167       const auto tfdst = e->dst();
168       bool is_control = e->IsControlEdge();
169       auto src = nodes_[tfsrc->id()];
170       auto dst = nodes_[tfdst->id()];
171       auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
172                                  is_control);
173       edges_[i] = edge;
174       src->out_edges_.push_back(edge);
175       dst->in_edges_.push_back(edge);
176     } else {
177       free_edge_ids_.insert(i);
178     }
179   }
180 }
181 
AddEdge(SimpleNode * src,int out_port,SimpleNode * dst,int in_port)182 void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
183                           int in_port) {
184   int i = edges_.size();
185   if (!free_edge_ids_.empty()) {
186     auto it = free_edge_ids_.begin();
187     i = *it;
188     free_edge_ids_.erase(it);
189   } else {
190     edges_.push_back(nullptr);
191   }
192   bool is_control = (out_port == Graph::kControlSlot);
193   is_control |= (in_port == Graph::kControlSlot);
194   auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
195   edges_[i] = edge;
196   src->out_edges_.push_back(edge);
197   dst->in_edges_.push_back(edge);
198 }
199 
AddControlEdge(SimpleNode * src,SimpleNode * dst)200 void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
201   AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
202 }
203 
RemoveEdge(const SimpleEdge * edge)204 void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
205   auto src = edge->src();
206   auto dst = edge->dst();
207   for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
208     if (*it == edge) {
209       src->out_edges_.erase(it);
210       break;
211     }
212   }
213   for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
214     if (*it == edge) {
215       dst->in_edges_.erase(it);
216       break;
217     }
218   }
219 }
220 
~SimpleGraph()221 SimpleGraph::~SimpleGraph() {
222   for (auto x : nodes_) delete x;
223   for (auto x : edges_) delete x;
224 }
225 
226 // Define comparison functions for std::set with pointer keys so that behavior
227 // is deterministic. When using std::set with pointer key types, the items are
228 // sorted by pointer address which is non-deterministic. This can cause issues
229 // for INT8 mode because the graph is converted twice and non-determinism may
230 // cause a mismatch between the calibration tables of the conversions.
231 struct SimpleEdgePtrCompare {
operator ()tensorflow::tensorrt::segment::SimpleEdgePtrCompare232   bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const {
233     return lhs->id() < rhs->id();
234   }
235 };
236 
237 struct NodePtrCompare {
operator ()tensorflow::tensorrt::segment::NodePtrCompare238   bool operator()(const Node* lhs, const Node* rhs) const {
239     return lhs->name() < rhs->name();
240   }
241 };
242 
243 namespace {
244 
245 // Copied from TF ReverseDFS, which only works for Graph.
StableDFS(const SimpleGraph & g,bool reverse,const std::vector<const SimpleNode * > & start,const std::function<bool (const SimpleNode *)> & enter,const std::function<bool (const SimpleNode *)> & leave)246 void StableDFS(const SimpleGraph& g, bool reverse,
247                const std::vector<const SimpleNode*>& start,
248                const std::function<bool(const SimpleNode*)>& enter,
249                const std::function<bool(const SimpleNode*)>& leave) {
250   // Stack of work to do.
251   struct Work {
252     const SimpleNode* node;
253     bool leave;  // Are we entering or leaving n?
254   };
255   std::vector<Work> stack(start.size());
256   for (int i = 0; i < start.size(); ++i) {
257     stack[i] = Work{start[i], false};
258   }
259 
260   auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
261                            : [](const SimpleNode* n) { return n->out_nodes(); };
262   std::vector<bool> visited(g.num_node_ids(), false);
263   while (!stack.empty()) {
264     Work w = stack.back();
265     stack.pop_back();
266 
267     auto n = w.node;
268     if (w.leave) {
269       if (leave && !leave(n)) return;
270       continue;
271     }
272 
273     if (visited[n->id()]) continue;
274     visited[n->id()] = true;
275     if (enter && !enter(n)) return;
276 
277     // Arrange to call leave(n) when all done with descendants.
278     if (leave) stack.push_back(Work{n, true});
279 
280     auto nodes = get_nodes(n);
281     std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
282     std::sort(nodes_sorted.begin(), nodes_sorted.end(),
283               [](const SimpleNode* lhs, const SimpleNode* rhs) {
284                 return lhs->name() < rhs->name();
285               });
286     for (const SimpleNode* node : nodes_sorted) {
287       if (!visited[node->id()]) {
288         stack.push_back(Work{node, false});
289       }
290     }
291   }
292 }
293 
CanContractEdge(const SimpleEdge * edge,const std::unique_ptr<SimpleGraph> & graph)294 bool CanContractEdge(const SimpleEdge* edge,
295                      const std::unique_ptr<SimpleGraph>& graph) {
296   const auto src = edge->src();
297   const auto dst = edge->dst();
298 
299   // Can't contract edge if doing so would cause a cycle in the
300   // graph. So, if there is a directed path from 'src' to 'dst', other
301   // than 'edge' (or any other direct edge from 'src' to 'dst'), then
302   // combining 'src' and 'dst' will cause a cycle along that path.
303   //
304   // In practice, to avoid modifying the graph and to take advantage
305   // of existing graph functions, we perform an equivalent.
306   //   1. Get all nodes incoming to 'dst', excluding 'src'
307   //   2. Reverse DFS from those nodes
308   //   3. If reverse DFS reaches 'src' then we have a cycle
309   //
310   // TODO(aaroey): there are several problems with the current approach:
311   // 1. src->dst->src, this is not detected but it should be;
312   // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
313   //    is detected but it should not be.
314   //
315   // Note that it's fine that dst connects back to src indirectly (i.e. through
316   // a path with length > 1 that consists of intermedia nodes other than src).
317   // While loops is one example.
318   //
319   // The goal is to make sure that the trt subgraph:
320   // 1. has no loops (i.e. is a DAG), and
321   // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
322   //    in the subgraph), then all paths from X to Y are in the subgraph.
323   //
324   // To achieve this goal, the correct way seems to be:
325   // 1. remove any direct edge from src->dst;
326   // 2. detect if src can reach dst, if so they cannot be merged.
327   std::vector<const SimpleNode*> dfs_start_nodes;
328   for (const SimpleNode* node : dst->in_nodes()) {
329     if (node != src) {
330       dfs_start_nodes.push_back(node);
331     }
332   }
333   bool has_cycle = false;
334   StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
335             [&has_cycle, src](const SimpleNode* n) {
336               if (n == src) {
337                 has_cycle = true;
338                 return false;
339               }
340               return true;
341             });
342   return !has_cycle;
343 }
344 }  // namespace
345 
ContractEdge(SimpleEdge * edge,SimpleGraph * graph,std::vector<const SimpleEdge * > * remove_edges)346 void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
347                   std::vector<const SimpleEdge*>* remove_edges) {
348   // Transfer all inputs and outputs of 'dst' to 'src' except edges
349   // connecting the two.
350   auto src = edge->src();
351   auto dst = edge->dst();
352 
353   // We can use '0' for input/output index because we don't need them
354   // to be accurate for the way we are using the graph.
355   std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
356                                           dst->in_edges().end());
357   for (const SimpleEdge* in_edge : in_edges) {
358     if (in_edge->IsControlEdge()) {
359       if (in_edge->src() != src) {
360         SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
361         graph->AddControlEdge(e->src(), src);
362       }
363     } else {
364       if (in_edge->src() != src) {
365         SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
366         if (e->src() == graph->source_node()) {
367           graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot);
368         } else {
369           graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
370         }
371       }
372     }
373   }
374 
375   std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
376                                            dst->out_edges().end());
377   for (const SimpleEdge* out_edge : out_edges) {
378     if (out_edge->IsControlEdge()) {
379       SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
380       graph->AddControlEdge(src, e->dst());
381     } else {
382       SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
383       if (e->dst() == graph->sink_node()) {
384         VLOG(1) << " edge to sink node " << src->name() << " -> "
385                 << e->dst()->name();
386         graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input());
387       } else {
388         graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
389       }
390     }
391   }
392 
393   // Return the edges that must be removed to disconnect 'dst' from
394   // the graph. We don't actually remove 'dst' since the caller holds
395   // references to all the nodes.
396   for (const auto& in_edge : dst->in_edges()) {
397     remove_edges->push_back(in_edge);
398   }
399   for (const auto& out_edge : dst->out_edges()) {
400     remove_edges->push_back(out_edge);
401   }
402 }
403 
SegmentGraph(const Graph * tf_graph,const std::function<Status (const Node *)> & candidate_fn,const std::function<bool (const Edge *)> & input_candidate_fn,const std::function<bool (const Edge *)> & output_candidate_fn,const SegmentOptions & options,SegmentNodesVector * segments)404 Status SegmentGraph(const Graph* tf_graph,
405                     const std::function<Status(const Node*)>& candidate_fn,
406                     const std::function<bool(const Edge*)>& input_candidate_fn,
407                     const std::function<bool(const Edge*)>& output_candidate_fn,
408                     const SegmentOptions& options,
409                     SegmentNodesVector* segments) {
410   // Steps:
411   // 1. run the segmentation algorithm to find all the segments, which uses
412   //    candidate_fn to determine the candidates segment nodes;
413   // 2. for each segments, remove the nodes that are inputs/outputs of the
414   //    segment but are not eligible, using input/output_candidate_fn to
415   //    determine the eligibilities;
416   // 3. convert the segment into expected return format and return the result.
417 
418   // --------------------------------- Step 1 ---------------------------------
419   auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
420   // Use a union-find to collect the nodes that belong to the same
421   // segment. A node value of nullptr indicates that the node is not a candidate
422   // for TRT.
423   std::unordered_set<string> unsupported_ops;
424   int num_unsupported_ops = 0;
425   std::vector<UnionFind<SimpleNode*>> node_segments;
426   for (int i = 0; i < graph->num_node_ids(); ++i) {
427     SimpleNode* node = graph->FindNodeId(i);
428     if (options.exclude_node_list.count(node->name()) != 0) {
429       VLOG(1) << "Not a TF-TRT candidate, "
430               << "(Op type: " << node->tf_node()->type_string() << "), "
431               << "(Op name: " << node->name() << "), "
432               << "(Reason: excluded by segmenter option)";
433       unsupported_ops.emplace(node->tf_node()->type_string());
434       num_unsupported_ops++;
435       node = nullptr;
436     } else {
437       const Status status = candidate_fn(node->tf_node());
438       if (!status.ok()) {
439         VLOG(1) << "Not a TF-TRT candidate, "
440                 << "(Op type: " << node->tf_node()->type_string() << "), "
441                 << "(Op name: " << node->name() << "), "
442                 << "(Reason: " << status << ")";
443         unsupported_ops.emplace(node->tf_node()->type_string());
444         num_unsupported_ops++;
445         node = nullptr;
446       }
447     }
448     node_segments.emplace_back(node);
449   }
450   string msg = StrCat(
451       "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(),
452       " different types in the graph that", " are not converted to TensorRT: ");
453   for (const auto& elem : unsupported_ops) {
454     StrAppend(&msg, elem, ", ");
455   }
456   LOG(INFO) << msg << "(For more information see "
457             << "https://docs.nvidia.com/deeplearning"
458             << "/dgx/integrate-tf-trt/index.html#support-ops).";
459 
460   // The segmentation algorithm below visits nodes in reverse topological order
461   // and attempts to merge nodes along output edges. That means that subgraphs
462   // grow from the output-side of the network towards the inputs.
463   //
464   // In general this is not guaranteed to produce a globally optimal
465   // segmentation. For exaample, consider graph with node {A, B, C, D} and edges
466   // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so
467   // in theory we can choose to contract either A, B or B, D but not both, but
468   // here it always choose to contract B, D.
469   //
470   // In the future if we have a measure of how beneficial it is to include a
471   // given node in a TRT subgraph then we can revisit this algorithm to take
472   // advantage of that information.
473   std::vector<const SimpleNode*> order;
474   order.reserve(graph->num_node_ids());
475   StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
476             /*enter=*/nullptr, [&order](const SimpleNode* n) {
477               order.push_back(n);
478               return true;
479             });
480   for (const SimpleNode* node : order) {
481     // All output nodes of 'node' have been visited...
482     VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
483     // 'node' must be a TRT candidate...
484     if (node_segments[node->id()].Value() == nullptr) {
485       VLOG(3) << "... not a TRT candidate";
486       continue;
487     }
488     // Contract output edges to combine 'node' with output
489     // nodes. Iterate since combining two nodes may unblock other
490     // combining.
491     while (true) {
492       std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
493       for (const SimpleEdge* out_edge : node->out_edges()) {
494         VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
495                 << out_edge->dst()->id() << " <- " << node->id() << " )";
496         if (out_edge->IsControlEdge()) {
497           VLOG(3) << "... ... Control Edge, Skipping";
498           continue;
499         }
500         // Out node must be TRT candidate...
501         if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
502           VLOG(3) << "... ... not a TRT candidate";
503           continue;
504         }
505         if (CanContractEdge(out_edge, graph)) {
506           VLOG(3) << "... ... can contract";
507           contract_edges.insert(out_edge);
508         } else {
509           VLOG(3) << "... ... cannot contract, would form cycle";
510         }
511       }
512       if (contract_edges.empty()) {
513         break;
514       }
515       // Contract edges and collect the adjacent nodes into the same
516       // segment/subgraph.
517       while (!contract_edges.empty()) {
518         const SimpleEdge* contract_edge = *contract_edges.begin();
519         const SimpleNode* src = contract_edge->src();
520         const SimpleNode* dst = contract_edge->dst();
521 
522         VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
523                 << src->id() << " <- " << dst->id();
524         node_segments[src->id()].Merge(&node_segments[dst->id()]);
525 
526         // Contracting the edge leaves disconnected graph edges.
527         // Remove these from the graph and from 'contract_edges' so we
528         // don't visit them again.
529         SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
530         std::vector<const SimpleEdge*> remove_edges;
531         ContractEdge(e, graph.get(), &remove_edges);
532 
533         for (const SimpleEdge* r : remove_edges) {
534           contract_edges.erase(r);
535           graph->RemoveEdge(r);
536         }
537       }
538     }
539   }
540 
541   // Collect the segments/subgraphs. Each subgraph is represented by a
542   // set of the names of the nodes in that subgraph.
543 
544   // A map from the segment identifier (currently the name of the root node of
545   // the segment tree) to the segment nodes set.
546   std::map<string, std::set<const Node*, NodePtrCompare>> sg_map;
547 
548   // A map from the segment identifier (currently the name of the root node of
549   // the segment tree) to the device names that the nodes in the segment are
550   // assigned to.
551   //
552   // TODO(aaroey): nodes assigned to different devices should not be merged,
553   // fix this.
554   std::unordered_map<string, std::set<string>> device_maps;
555 
556   for (auto& u : node_segments) {
557     if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
558       sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node());
559       auto tf_node = u.Value()->tf_node();
560       // has_assigned_device_name() is expected to return true
561       // when called from optimization pass. However, since graph
562       // is converted back and forth between graph and graphdef,
563       // assigned devices demoted to requested devices. If the graph
564       // is passed directly to this module, assigned devices will be set.
565       if (tf_node->has_assigned_device_name()) {
566         device_maps[u.ParentValue()->name()].insert(
567             tf_node->assigned_device_name());
568       } else if (!tf_node->requested_device().empty()) {
569         device_maps[u.ParentValue()->name()].insert(
570             tf_node->requested_device());
571       } else {
572         VLOG(2) << "Node " << tf_node->name()
573                 << " has no device assigned requested device is: "
574                 << tf_node->requested_device();
575       }
576     }
577   }
578 
579   // --------------------------------- Step 2 ---------------------------------
580   // Remove ineligible input/output nodes.
581   for (auto& itr : sg_map) {
582     std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second;
583     VLOG(1) << "Segment original size: " << segment_nodes.size();
584     while (true) {
585       std::deque<const Node*> in_nodes_que, out_nodes_que;
586       // Find an input node that is not eligible and add it to the queue.
587       // Nodes that has no incoming edges should not be treated as "input",
588       // as there are really no inputs to them. Similar for output nodes.
589       for (auto node : segment_nodes) {
590         bool added = false;
591         for (const Edge* edge : node->in_edges()) {
592           if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
593               !segment_nodes.count(edge->src())) {  // 'node' is an input node.
594             if (!input_candidate_fn(edge)) {
595               in_nodes_que.push_back(node);
596               added = true;
597               break;
598             }
599           }
600         }
601         if (added) continue;  // Only adding the node once to either queue.
602         for (const Edge* edge : node->out_edges()) {
603           if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
604               !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
605             if (!output_candidate_fn(edge)) {
606               out_nodes_que.push_back(node);
607               break;
608             }
609           }
610         }
611       }
612       if (in_nodes_que.empty() && out_nodes_que.empty()) {
613         // No more ineligible input/output nodes.
614         break;
615       }
616       // Now for each ineligible node, remove all of its inputs or outputs from
617       // the subgraph.
618       //
619       // It can be proven that, if the original subgraph:
620       // 1. is a DAG, and
621       // 2. all paths between two nodes in the subgraph are all inside the
622       //    subgraph
623       // then after doing this operation the resulting subgraph will keep the
624       // same properties 1 and 2.
625       //
626       // For simplicity we use heuristics: for input and const output nodes
627       // remove all their inputs, and for non-const output nodes remove all
628       // their outputs. In this way, for common cases the number of removed
629       // nodes should be minimum.
630       auto remove_nodes = [&segment_nodes](bool is_input_nodes,
631                                            std::deque<const Node*>* que) {
632         // Run a BFS on the queue to find all the input/output nodes.
633         std::set<const Node*, NodePtrCompare> visited;
634         std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
635         while (!que->empty()) {
636           auto node = que->front();
637           que->pop_front();
638           if (!visited.insert(node).second) continue;
639           segment_nodes.erase(node);
640           for (auto in : (is_input_nodes || node->type_string() == "Const")
641                              ? node->in_nodes()
642                              : node->out_nodes()) {
643             if (segment_nodes.count(in)) {
644               que->push_back(in);
645               if (VLOG_IS_ON(2)) {
646                 if (!logged.count(in)) {
647                   VLOG(2) << "----> Need to remove node " << in->name()
648                           << " because one of its "
649                           << (is_input_nodes ? "output" : "input")
650                           << " nodes in the graph was removed: "
651                           << node->name();
652                   logged.insert(in);
653                 }
654               }
655             }
656           }
657         }
658       };
659       remove_nodes(true, &in_nodes_que);
660       remove_nodes(false, &out_nodes_que);
661     }
662     VLOG(1) << "Segment new size: " << segment_nodes.size();
663   }
664 
665   // --------------------------------- Step 3 ---------------------------------
666   // Convert the segments into the expected return format
667   for (const auto& itr : sg_map) {
668     const string& segment_root = itr.first;
669     // Return format does not require set comparator.
670     std::set<const Node*> segment_nodes(itr.second.begin(), itr.second.end());
671     if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
672       string s;
673       for (auto node : segment_nodes) {
674         StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
675       }
676       VLOG(1) << "Nodes in segment " << segments->size()
677               << " with parent=" << segment_root << ":" << s;
678     }
679 
680     // Don't use small segments.
681     if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) {
682       VLOG(1) << "Segment " << segments->size() << " has only "
683               << segment_nodes.size() << " nodes, dropping";
684       continue;
685     }
686 
687     // TODO(sami): Make segmenter placement aware once trtscopes are in place
688     const auto& dev_itr = device_maps.find(segment_root);
689     if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
690       VLOG(1) << "No device assigned to segment " << segments->size();
691       segments->emplace_back(std::make_pair(segment_nodes, string()));
692     } else if (dev_itr->second.size() > 1) {
693       string s("Segment ");
694       StrAppend(&s, segments->size(), " has multiple devices attached: ");
695       for (const auto& dev : dev_itr->second) {
696         StrAppend(&s, dev, ", ");
697       }
698       LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin());
699       segments->emplace_back(
700           std::make_pair(segment_nodes, *(dev_itr->second.begin())));
701     } else {
702       segments->emplace_back(
703           std::make_pair(segment_nodes, *(dev_itr->second.begin())));
704     }
705   }
706   if (VLOG_IS_ON(1)) {
707     for (const auto& d : device_maps) {
708       string s("Segment ");
709       StrAppend(&s, ": '", d.first, "' ");
710       for (const auto& dd : d.second) {
711         StrAppend(&s, dd, ", ");
712       }
713       VLOG(1) << "Devices " << s;
714     }
715   }
716   return Status::OK();
717 }
718 
719 }  // namespace segment
720 }  // namespace tensorrt
721 }  // namespace tensorflow
722 
723 #endif  // GOOGLE_TENSORRT
724 #endif  // GOOGLE_CUDA
725