1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17
18 #include <queue>
19 #include <set>
20 #include <unordered_map>
21 #include <vector>
22
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/graph_constructor.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/types.h"
32
33 #if GOOGLE_CUDA
34 #if GOOGLE_TENSORRT
35
36 namespace tensorflow {
37 namespace tensorrt {
38 namespace segment {
39 using absl::StrAppend;
40 using absl::StrCat;
41
// Lightweight mirror of a TF Graph used by the segmenter. The segmenter edits
// the graph in place while it works; doing that on this small structure
// instead of on a copy of the Graph saves memory. The structure consists of
// SimpleNodes and SimpleEdges; each SimpleNode points back at the original
// TF Node.
class SimpleNode;
class SimpleGraph;

// A directed edge src:src_output -> dst:dst_input between two SimpleNodes.
// The edge does not own its endpoints.
class SimpleEdge {
 public:
  SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
             int dst_port, bool is_control = false)
      : id_(id),
        src_(src),
        src_port_(src_port),
        dst_(dst),
        dst_port_(dst_port),
        control_(is_control) {}
  ~SimpleEdge() = default;

  // Producer side of the edge.
  SimpleNode* src() const { return src_; }
  int src_output() const { return src_port_; }

  // Consumer side of the edge.
  SimpleNode* dst() const { return dst_; }
  int dst_input() const { return dst_port_; }

  // Id given to this edge at construction.
  int id() const { return id_; }

  // Value of `is_control` passed at construction.
  bool IsControlEdge() const { return control_; }

 private:
  int id_;
  SimpleNode* src_;
  int src_port_;
  SimpleNode* dst_;
  int dst_port_;
  bool control_;
};
75
76 class SimpleNode {
77 public:
78 SimpleNode(const Node* node, const int id);
79
in_edges() const80 const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
out_edges() const81 const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
82
in_nodes() const83 std::vector<SimpleNode*> in_nodes() const {
84 std::vector<SimpleNode*> res;
85 res.reserve(in_edges_.size());
86 for (const auto e : in_edges_) {
87 if (e) res.push_back(e->src());
88 }
89 return res;
90 }
91
out_nodes() const92 std::vector<SimpleNode*> out_nodes() const {
93 std::vector<SimpleNode*> res;
94 res.reserve(out_edges_.size());
95 for (const auto e : out_edges_) {
96 if (e) res.push_back(e->dst());
97 }
98 return res;
99 }
100
name() const101 const string& name() const { return node_->name(); }
tf_node() const102 const Node* tf_node() const { return node_; }
id() const103 int id() const { return id_; }
104
105 private:
106 const Node* node_;
107 std::vector<SimpleEdge*> in_edges_;
108 std::vector<SimpleEdge*> out_edges_;
109 int id_;
110
111 friend class SimpleGraph;
112 };
113
114 class SimpleGraph {
115 public:
116 explicit SimpleGraph(const Graph* g);
117 ~SimpleGraph();
118
119 void AddControlEdge(SimpleNode* src, SimpleNode* dst);
120 void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
121 void RemoveEdge(const SimpleEdge*);
FindNodeId(int node_id)122 SimpleNode* FindNodeId(int node_id) {
123 if (node_id < 0 || node_id > static_cast<int>(nodes_.size())) {
124 return nullptr;
125 }
126 return nodes_[node_id];
127 }
num_node_ids() const128 int num_node_ids() const { return nodes_.size(); }
source_node() const129 const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; }
sink_node() const130 const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; }
131
132 private:
133 const Graph* g_;
134 std::vector<SimpleNode*> nodes_;
135 std::vector<SimpleEdge*> edges_;
136 // free_edge_ids_ and free_node_ids_ contain freed indices.
137 std::set<int> free_edge_ids_;
138 std::set<int> free_node_ids_;
139 };
140
SimpleNode(const Node * node,const int id)141 SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) {
142 if (node_) {
143 in_edges_.reserve(node_->in_edges().size());
144 out_edges_.reserve(node_->out_edges().size());
145 }
146 }
147
SimpleGraph(const Graph * g)148 SimpleGraph::SimpleGraph(const Graph* g) : g_(g) {
149 int n_nodes = g_->num_node_ids();
150 nodes_.resize(n_nodes, nullptr);
151 nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
152 nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
153 int n_edges = g->num_edge_ids();
154 edges_.resize(n_edges, nullptr);
155 for (int i = 2; i < n_nodes; i++) {
156 const auto n = g->FindNodeId(i);
157 if (n) {
158 nodes_[i] = new SimpleNode(n, i);
159 } else {
160 free_node_ids_.insert(i);
161 }
162 }
163 for (int i = 0; i < n_edges; i++) {
164 const auto e = g->FindEdgeId(i);
165 if (e) {
166 const auto tfsrc = e->src();
167 const auto tfdst = e->dst();
168 bool is_control = e->IsControlEdge();
169 auto src = nodes_[tfsrc->id()];
170 auto dst = nodes_[tfdst->id()];
171 auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
172 is_control);
173 edges_[i] = edge;
174 src->out_edges_.push_back(edge);
175 dst->in_edges_.push_back(edge);
176 } else {
177 free_edge_ids_.insert(i);
178 }
179 }
180 }
181
AddEdge(SimpleNode * src,int out_port,SimpleNode * dst,int in_port)182 void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
183 int in_port) {
184 int i = edges_.size();
185 if (!free_edge_ids_.empty()) {
186 auto it = free_edge_ids_.begin();
187 i = *it;
188 free_edge_ids_.erase(it);
189 } else {
190 edges_.push_back(nullptr);
191 }
192 bool is_control = (out_port == Graph::kControlSlot);
193 is_control |= (in_port == Graph::kControlSlot);
194 auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
195 edges_[i] = edge;
196 src->out_edges_.push_back(edge);
197 dst->in_edges_.push_back(edge);
198 }
199
AddControlEdge(SimpleNode * src,SimpleNode * dst)200 void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
201 AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
202 }
203
RemoveEdge(const SimpleEdge * edge)204 void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
205 auto src = edge->src();
206 auto dst = edge->dst();
207 for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
208 if (*it == edge) {
209 src->out_edges_.erase(it);
210 break;
211 }
212 }
213 for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
214 if (*it == edge) {
215 dst->in_edges_.erase(it);
216 break;
217 }
218 }
219 }
220
~SimpleGraph()221 SimpleGraph::~SimpleGraph() {
222 for (auto x : nodes_) delete x;
223 for (auto x : edges_) delete x;
224 }
225
226 // Define comparison functions for std::set with pointer keys so that behavior
227 // is deterministic. When using std::set with pointer key types, the items are
228 // sorted by pointer address which is non-deterministic. This can cause issues
229 // for INT8 mode because the graph is converted twice and non-determinism may
230 // cause a mismatch between the calibration tables of the conversions.
231 struct SimpleEdgePtrCompare {
operator ()tensorflow::tensorrt::segment::SimpleEdgePtrCompare232 bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const {
233 return lhs->id() < rhs->id();
234 }
235 };
236
237 struct NodePtrCompare {
operator ()tensorflow::tensorrt::segment::NodePtrCompare238 bool operator()(const Node* lhs, const Node* rhs) const {
239 return lhs->name() < rhs->name();
240 }
241 };
242
243 namespace {
244
245 // Copied from TF ReverseDFS, which only works for Graph.
StableDFS(const SimpleGraph & g,bool reverse,const std::vector<const SimpleNode * > & start,const std::function<bool (const SimpleNode *)> & enter,const std::function<bool (const SimpleNode *)> & leave)246 void StableDFS(const SimpleGraph& g, bool reverse,
247 const std::vector<const SimpleNode*>& start,
248 const std::function<bool(const SimpleNode*)>& enter,
249 const std::function<bool(const SimpleNode*)>& leave) {
250 // Stack of work to do.
251 struct Work {
252 const SimpleNode* node;
253 bool leave; // Are we entering or leaving n?
254 };
255 std::vector<Work> stack(start.size());
256 for (int i = 0; i < start.size(); ++i) {
257 stack[i] = Work{start[i], false};
258 }
259
260 auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
261 : [](const SimpleNode* n) { return n->out_nodes(); };
262 std::vector<bool> visited(g.num_node_ids(), false);
263 while (!stack.empty()) {
264 Work w = stack.back();
265 stack.pop_back();
266
267 auto n = w.node;
268 if (w.leave) {
269 if (leave && !leave(n)) return;
270 continue;
271 }
272
273 if (visited[n->id()]) continue;
274 visited[n->id()] = true;
275 if (enter && !enter(n)) return;
276
277 // Arrange to call leave(n) when all done with descendants.
278 if (leave) stack.push_back(Work{n, true});
279
280 auto nodes = get_nodes(n);
281 std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
282 std::sort(nodes_sorted.begin(), nodes_sorted.end(),
283 [](const SimpleNode* lhs, const SimpleNode* rhs) {
284 return lhs->name() < rhs->name();
285 });
286 for (const SimpleNode* node : nodes_sorted) {
287 if (!visited[node->id()]) {
288 stack.push_back(Work{node, false});
289 }
290 }
291 }
292 }
293
CanContractEdge(const SimpleEdge * edge,const std::unique_ptr<SimpleGraph> & graph)294 bool CanContractEdge(const SimpleEdge* edge,
295 const std::unique_ptr<SimpleGraph>& graph) {
296 const auto src = edge->src();
297 const auto dst = edge->dst();
298
299 // Can't contract edge if doing so would cause a cycle in the
300 // graph. So, if there is a directed path from 'src' to 'dst', other
301 // than 'edge' (or any other direct edge from 'src' to 'dst'), then
302 // combining 'src' and 'dst' will cause a cycle along that path.
303 //
304 // In practice, to avoid modifying the graph and to take advantage
305 // of existing graph functions, we perform an equivalent.
306 // 1. Get all nodes incoming to 'dst', excluding 'src'
307 // 2. Reverse DFS from those nodes
308 // 3. If reverse DFS reaches 'src' then we have a cycle
309 //
310 // TODO(aaroey): there are several problems with the current approach:
311 // 1. src->dst->src, this is not detected but it should be;
312 // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
313 // is detected but it should not be.
314 //
315 // Note that it's fine that dst connects back to src indirectly (i.e. through
316 // a path with length > 1 that consists of intermedia nodes other than src).
317 // While loops is one example.
318 //
319 // The goal is to make sure that the trt subgraph:
320 // 1. has no loops (i.e. is a DAG), and
321 // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
322 // in the subgraph), then all paths from X to Y are in the subgraph.
323 //
324 // To achieve this goal, the correct way seems to be:
325 // 1. remove any direct edge from src->dst;
326 // 2. detect if src can reach dst, if so they cannot be merged.
327 std::vector<const SimpleNode*> dfs_start_nodes;
328 for (const SimpleNode* node : dst->in_nodes()) {
329 if (node != src) {
330 dfs_start_nodes.push_back(node);
331 }
332 }
333 bool has_cycle = false;
334 StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
335 [&has_cycle, src](const SimpleNode* n) {
336 if (n == src) {
337 has_cycle = true;
338 return false;
339 }
340 return true;
341 });
342 return !has_cycle;
343 }
344 } // namespace
345
ContractEdge(SimpleEdge * edge,SimpleGraph * graph,std::vector<const SimpleEdge * > * remove_edges)346 void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
347 std::vector<const SimpleEdge*>* remove_edges) {
348 // Transfer all inputs and outputs of 'dst' to 'src' except edges
349 // connecting the two.
350 auto src = edge->src();
351 auto dst = edge->dst();
352
353 // We can use '0' for input/output index because we don't need them
354 // to be accurate for the way we are using the graph.
355 std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
356 dst->in_edges().end());
357 for (const SimpleEdge* in_edge : in_edges) {
358 if (in_edge->IsControlEdge()) {
359 if (in_edge->src() != src) {
360 SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
361 graph->AddControlEdge(e->src(), src);
362 }
363 } else {
364 if (in_edge->src() != src) {
365 SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
366 if (e->src() == graph->source_node()) {
367 graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot);
368 } else {
369 graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
370 }
371 }
372 }
373 }
374
375 std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
376 dst->out_edges().end());
377 for (const SimpleEdge* out_edge : out_edges) {
378 if (out_edge->IsControlEdge()) {
379 SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
380 graph->AddControlEdge(src, e->dst());
381 } else {
382 SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
383 if (e->dst() == graph->sink_node()) {
384 VLOG(1) << " edge to sink node " << src->name() << " -> "
385 << e->dst()->name();
386 graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input());
387 } else {
388 graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
389 }
390 }
391 }
392
393 // Return the edges that must be removed to disconnect 'dst' from
394 // the graph. We don't actually remove 'dst' since the caller holds
395 // references to all the nodes.
396 for (const auto& in_edge : dst->in_edges()) {
397 remove_edges->push_back(in_edge);
398 }
399 for (const auto& out_edge : dst->out_edges()) {
400 remove_edges->push_back(out_edge);
401 }
402 }
403
SegmentGraph(const Graph * tf_graph,const std::function<Status (const Node *)> & candidate_fn,const std::function<bool (const Edge *)> & input_candidate_fn,const std::function<bool (const Edge *)> & output_candidate_fn,const SegmentOptions & options,SegmentNodesVector * segments)404 Status SegmentGraph(const Graph* tf_graph,
405 const std::function<Status(const Node*)>& candidate_fn,
406 const std::function<bool(const Edge*)>& input_candidate_fn,
407 const std::function<bool(const Edge*)>& output_candidate_fn,
408 const SegmentOptions& options,
409 SegmentNodesVector* segments) {
410 // Steps:
411 // 1. run the segmentation algorithm to find all the segments, which uses
412 // candidate_fn to determine the candidates segment nodes;
413 // 2. for each segments, remove the nodes that are inputs/outputs of the
414 // segment but are not eligible, using input/output_candidate_fn to
415 // determine the eligibilities;
416 // 3. convert the segment into expected return format and return the result.
417
418 // --------------------------------- Step 1 ---------------------------------
419 auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
420 // Use a union-find to collect the nodes that belong to the same
421 // segment. A node value of nullptr indicates that the node is not a candidate
422 // for TRT.
423 std::unordered_set<string> unsupported_ops;
424 int num_unsupported_ops = 0;
425 std::vector<UnionFind<SimpleNode*>> node_segments;
426 for (int i = 0; i < graph->num_node_ids(); ++i) {
427 SimpleNode* node = graph->FindNodeId(i);
428 if (options.exclude_node_list.count(node->name()) != 0) {
429 VLOG(1) << "Not a TF-TRT candidate, "
430 << "(Op type: " << node->tf_node()->type_string() << "), "
431 << "(Op name: " << node->name() << "), "
432 << "(Reason: excluded by segmenter option)";
433 unsupported_ops.emplace(node->tf_node()->type_string());
434 num_unsupported_ops++;
435 node = nullptr;
436 } else {
437 const Status status = candidate_fn(node->tf_node());
438 if (!status.ok()) {
439 VLOG(1) << "Not a TF-TRT candidate, "
440 << "(Op type: " << node->tf_node()->type_string() << "), "
441 << "(Op name: " << node->name() << "), "
442 << "(Reason: " << status << ")";
443 unsupported_ops.emplace(node->tf_node()->type_string());
444 num_unsupported_ops++;
445 node = nullptr;
446 }
447 }
448 node_segments.emplace_back(node);
449 }
450 string msg = StrCat(
451 "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(),
452 " different types in the graph that", " are not converted to TensorRT: ");
453 for (const auto& elem : unsupported_ops) {
454 StrAppend(&msg, elem, ", ");
455 }
456 LOG(INFO) << msg << "(For more information see "
457 << "https://docs.nvidia.com/deeplearning"
458 << "/dgx/integrate-tf-trt/index.html#support-ops).";
459
460 // The segmentation algorithm below visits nodes in reverse topological order
461 // and attempts to merge nodes along output edges. That means that subgraphs
462 // grow from the output-side of the network towards the inputs.
463 //
464 // In general this is not guaranteed to produce a globally optimal
465 // segmentation. For exaample, consider graph with node {A, B, C, D} and edges
466 // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so
467 // in theory we can choose to contract either A, B or B, D but not both, but
468 // here it always choose to contract B, D.
469 //
470 // In the future if we have a measure of how beneficial it is to include a
471 // given node in a TRT subgraph then we can revisit this algorithm to take
472 // advantage of that information.
473 std::vector<const SimpleNode*> order;
474 order.reserve(graph->num_node_ids());
475 StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
476 /*enter=*/nullptr, [&order](const SimpleNode* n) {
477 order.push_back(n);
478 return true;
479 });
480 for (const SimpleNode* node : order) {
481 // All output nodes of 'node' have been visited...
482 VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
483 // 'node' must be a TRT candidate...
484 if (node_segments[node->id()].Value() == nullptr) {
485 VLOG(3) << "... not a TRT candidate";
486 continue;
487 }
488 // Contract output edges to combine 'node' with output
489 // nodes. Iterate since combining two nodes may unblock other
490 // combining.
491 while (true) {
492 std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
493 for (const SimpleEdge* out_edge : node->out_edges()) {
494 VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
495 << out_edge->dst()->id() << " <- " << node->id() << " )";
496 if (out_edge->IsControlEdge()) {
497 VLOG(3) << "... ... Control Edge, Skipping";
498 continue;
499 }
500 // Out node must be TRT candidate...
501 if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
502 VLOG(3) << "... ... not a TRT candidate";
503 continue;
504 }
505 if (CanContractEdge(out_edge, graph)) {
506 VLOG(3) << "... ... can contract";
507 contract_edges.insert(out_edge);
508 } else {
509 VLOG(3) << "... ... cannot contract, would form cycle";
510 }
511 }
512 if (contract_edges.empty()) {
513 break;
514 }
515 // Contract edges and collect the adjacent nodes into the same
516 // segment/subgraph.
517 while (!contract_edges.empty()) {
518 const SimpleEdge* contract_edge = *contract_edges.begin();
519 const SimpleNode* src = contract_edge->src();
520 const SimpleNode* dst = contract_edge->dst();
521
522 VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
523 << src->id() << " <- " << dst->id();
524 node_segments[src->id()].Merge(&node_segments[dst->id()]);
525
526 // Contracting the edge leaves disconnected graph edges.
527 // Remove these from the graph and from 'contract_edges' so we
528 // don't visit them again.
529 SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
530 std::vector<const SimpleEdge*> remove_edges;
531 ContractEdge(e, graph.get(), &remove_edges);
532
533 for (const SimpleEdge* r : remove_edges) {
534 contract_edges.erase(r);
535 graph->RemoveEdge(r);
536 }
537 }
538 }
539 }
540
541 // Collect the segments/subgraphs. Each subgraph is represented by a
542 // set of the names of the nodes in that subgraph.
543
544 // A map from the segment identifier (currently the name of the root node of
545 // the segment tree) to the segment nodes set.
546 std::map<string, std::set<const Node*, NodePtrCompare>> sg_map;
547
548 // A map from the segment identifier (currently the name of the root node of
549 // the segment tree) to the device names that the nodes in the segment are
550 // assigned to.
551 //
552 // TODO(aaroey): nodes assigned to different devices should not be merged,
553 // fix this.
554 std::unordered_map<string, std::set<string>> device_maps;
555
556 for (auto& u : node_segments) {
557 if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
558 sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node());
559 auto tf_node = u.Value()->tf_node();
560 // has_assigned_device_name() is expected to return true
561 // when called from optimization pass. However, since graph
562 // is converted back and forth between graph and graphdef,
563 // assigned devices demoted to requested devices. If the graph
564 // is passed directly to this module, assigned devices will be set.
565 if (tf_node->has_assigned_device_name()) {
566 device_maps[u.ParentValue()->name()].insert(
567 tf_node->assigned_device_name());
568 } else if (!tf_node->requested_device().empty()) {
569 device_maps[u.ParentValue()->name()].insert(
570 tf_node->requested_device());
571 } else {
572 VLOG(2) << "Node " << tf_node->name()
573 << " has no device assigned requested device is: "
574 << tf_node->requested_device();
575 }
576 }
577 }
578
579 // --------------------------------- Step 2 ---------------------------------
580 // Remove ineligible input/output nodes.
581 for (auto& itr : sg_map) {
582 std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second;
583 VLOG(1) << "Segment original size: " << segment_nodes.size();
584 while (true) {
585 std::deque<const Node*> in_nodes_que, out_nodes_que;
586 // Find an input node that is not eligible and add it to the queue.
587 // Nodes that has no incoming edges should not be treated as "input",
588 // as there are really no inputs to them. Similar for output nodes.
589 for (auto node : segment_nodes) {
590 bool added = false;
591 for (const Edge* edge : node->in_edges()) {
592 if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
593 !segment_nodes.count(edge->src())) { // 'node' is an input node.
594 if (!input_candidate_fn(edge)) {
595 in_nodes_que.push_back(node);
596 added = true;
597 break;
598 }
599 }
600 }
601 if (added) continue; // Only adding the node once to either queue.
602 for (const Edge* edge : node->out_edges()) {
603 if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
604 !segment_nodes.count(edge->dst())) { // 'node' is an output node.
605 if (!output_candidate_fn(edge)) {
606 out_nodes_que.push_back(node);
607 break;
608 }
609 }
610 }
611 }
612 if (in_nodes_que.empty() && out_nodes_que.empty()) {
613 // No more ineligible input/output nodes.
614 break;
615 }
616 // Now for each ineligible node, remove all of its inputs or outputs from
617 // the subgraph.
618 //
619 // It can be proven that, if the original subgraph:
620 // 1. is a DAG, and
621 // 2. all paths between two nodes in the subgraph are all inside the
622 // subgraph
623 // then after doing this operation the resulting subgraph will keep the
624 // same properties 1 and 2.
625 //
626 // For simplicity we use heuristics: for input and const output nodes
627 // remove all their inputs, and for non-const output nodes remove all
628 // their outputs. In this way, for common cases the number of removed
629 // nodes should be minimum.
630 auto remove_nodes = [&segment_nodes](bool is_input_nodes,
631 std::deque<const Node*>* que) {
632 // Run a BFS on the queue to find all the input/output nodes.
633 std::set<const Node*, NodePtrCompare> visited;
634 std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
635 while (!que->empty()) {
636 auto node = que->front();
637 que->pop_front();
638 if (!visited.insert(node).second) continue;
639 segment_nodes.erase(node);
640 for (auto in : (is_input_nodes || node->type_string() == "Const")
641 ? node->in_nodes()
642 : node->out_nodes()) {
643 if (segment_nodes.count(in)) {
644 que->push_back(in);
645 if (VLOG_IS_ON(2)) {
646 if (!logged.count(in)) {
647 VLOG(2) << "----> Need to remove node " << in->name()
648 << " because one of its "
649 << (is_input_nodes ? "output" : "input")
650 << " nodes in the graph was removed: "
651 << node->name();
652 logged.insert(in);
653 }
654 }
655 }
656 }
657 }
658 };
659 remove_nodes(true, &in_nodes_que);
660 remove_nodes(false, &out_nodes_que);
661 }
662 VLOG(1) << "Segment new size: " << segment_nodes.size();
663 }
664
665 // --------------------------------- Step 3 ---------------------------------
666 // Convert the segments into the expected return format
667 for (const auto& itr : sg_map) {
668 const string& segment_root = itr.first;
669 // Return format does not require set comparator.
670 std::set<const Node*> segment_nodes(itr.second.begin(), itr.second.end());
671 if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
672 string s;
673 for (auto node : segment_nodes) {
674 StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
675 }
676 VLOG(1) << "Nodes in segment " << segments->size()
677 << " with parent=" << segment_root << ":" << s;
678 }
679
680 // Don't use small segments.
681 if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) {
682 VLOG(1) << "Segment " << segments->size() << " has only "
683 << segment_nodes.size() << " nodes, dropping";
684 continue;
685 }
686
687 // TODO(sami): Make segmenter placement aware once trtscopes are in place
688 const auto& dev_itr = device_maps.find(segment_root);
689 if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
690 VLOG(1) << "No device assigned to segment " << segments->size();
691 segments->emplace_back(std::make_pair(segment_nodes, string()));
692 } else if (dev_itr->second.size() > 1) {
693 string s("Segment ");
694 StrAppend(&s, segments->size(), " has multiple devices attached: ");
695 for (const auto& dev : dev_itr->second) {
696 StrAppend(&s, dev, ", ");
697 }
698 LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin());
699 segments->emplace_back(
700 std::make_pair(segment_nodes, *(dev_itr->second.begin())));
701 } else {
702 segments->emplace_back(
703 std::make_pair(segment_nodes, *(dev_itr->second.begin())));
704 }
705 }
706 if (VLOG_IS_ON(1)) {
707 for (const auto& d : device_maps) {
708 string s("Segment ");
709 StrAppend(&s, ": '", d.first, "' ");
710 for (const auto& dd : d.second) {
711 StrAppend(&s, dd, ", ");
712 }
713 VLOG(1) << "Devices " << s;
714 }
715 }
716 return Status::OK();
717 }
718
719 } // namespace segment
720 } // namespace tensorrt
721 } // namespace tensorflow
722
723 #endif // GOOGLE_TENSORRT
724 #endif // GOOGLE_CUDA
725