1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
// Attr set on _Send/_Recv ops created by SchedulerState::CreateSendRecv(); it
// holds the consumer's original input string, which may carry a "^" control
// prefix or a ":<port>" suffix.
const char kAttrInputSrc[] = "input_source_";
// Attrs recording the source and destination devices of a created _Send.
const char kAttrSrcDevice[] = "send_device";
const char kAttrDstDevice[] = "recv_device";
// Attr naming the transferred tensor; copied from the input node when present.
const char kAttrTensorName[] = "tensor_name";
// Name prefix / device type used to model the inter-device transfer channel.
const char kChannelDevice[] = "Channel";
// Node attr holding a per-output-port list of booleans marking streaming
// outputs.
const char kStreaming[] = "_streaming";
45 
46 namespace {
47 
48 using ::tensorflow::strings::HumanReadableNumBytes;
49 
float Round2(const float x) {
  // std::round from <cmath> is deliberately avoided here: it is not available
  // on every supported platform (Android in particular).
  const double hundredths = ::round(100.0 * x);
  return hundredths / 100.0;
}
55 
FindOrCreateZero(const string & op_name,std::map<string,Costs> * op_cost)56 Costs& FindOrCreateZero(const string& op_name,
57                         std::map<string, Costs>* op_cost) {
58   auto it = op_cost->find(op_name);
59   if (it == op_cost->end()) {
60     // Note that default constructor of Costs sets some memory related fields
61     // to unknown values so we should explicitly initialize it with ZeroCosts.
62     it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
63   }
64   return it->second;
65 }
66 
67 // Key to the cached _Recv ops map, and its hash and predicate structures.
68 struct RecvNodeDescriptor {
69   const NodeDef* node;
70   const int port_num;
71   const string device;
72 
RecvNodeDescriptortensorflow::grappler::__anon277c9aad0111::RecvNodeDescriptor73   RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
74                      const string& device_)
75       : node(node_), port_num(port_num_), device(device_) {}
76 };
77 
78 struct RecvNodeDescriptorHash {
operator ()tensorflow::grappler::__anon277c9aad0111::RecvNodeDescriptorHash79   std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
80     return std::hash<const NodeDef*>()(recv_node.node) ^
81            std::hash<int>()(recv_node.port_num) ^
82            std::hash<string>()(recv_node.device);
83   }
84 };
85 
86 struct RecvNodeDescriptorEqual {
operator ()tensorflow::grappler::__anon277c9aad0111::RecvNodeDescriptorEqual87   bool operator()(const RecvNodeDescriptor& a,
88                   const RecvNodeDescriptor& b) const {
89     return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
90   }
91 };
92 
UpdateDeviceAnnotationState(const NodeDef * node,const NodeState & node_state,DeviceState * device)93 void UpdateDeviceAnnotationState(const NodeDef* node,
94                                  const NodeState& node_state,
95                                  DeviceState* device) {
96   if (node->attr().count(kOutputShapes) == 0) return;
97 
98   int64 execution_count = node->attr().count(kExecutionCount) == 0
99                               ? 1
100                               : node->attr().at(kExecutionCount).i();
101 
102   auto& shape_annotation_stats = device->shape_annotation_stats;
103   shape_annotation_stats.num_ops_annotated += 1;
104   shape_annotation_stats.num_ops_executed += execution_count;
105   shape_annotation_stats.num_ops_executed_more_than_once +=
106       execution_count > 1 ? 1 : 0;
107   shape_annotation_stats.num_ops_with_incompatible_shapes +=
108       node_state.shape_incompatible ? 1 : 0;
109   shape_annotation_stats.num_ops_with_dynamic_shapes +=
110       (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0;
111 }
112 
IsStreamingPort(const NodeDef & node,const int port)113 bool IsStreamingPort(const NodeDef& node, const int port) {
114   if (!node.attr().contains(kStreaming)) return false;
115 
116   auto& attr_list = node.attr().at(kStreaming).list();
117   bool is_streaming_port = false;
118   if (port >= 0 && port < attr_list.b().size()) {
119     is_streaming_port = attr_list.b(port);
120   }
121   return is_streaming_port;
122 }
123 
124 }  // namespace
125 
AddNode(const NodeDef * node)126 void LIFOManager::AddNode(const NodeDef* node) {
127   // Merge nodes are scheduled with the lowest priority in LIFO manager; virtual
128   // scheduler may run multiple input nodes of Merge (when we don't have
129   // annotation, which is quite common); simply scheduling Merge after one of
130   // its input may break scheduling constraints; some inputs of Merge may be
131   // scheduled after the Merge. So, we place Merge at the beginning of the queue
132   // to guarantee all the inputs of Merge are scheduled before the Merge.
133   if (IsMerge(*node)) {
134     nodes_.push_front(node);
135   } else {
136     nodes_.push_back(node);
137   }
138 }
139 
GetCurrNode()140 const NodeDef* LIFOManager::GetCurrNode() {
141   CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
142   if (curr_pos_ == nodes_.end()) {
143     curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
144   }
145   // Once curr_pos_ is set to a valid entry in the list, we keep using the
146   // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
147   // change the GetCurrNode() return value.
148   return *curr_pos_;
149 }
150 
RemoveCurrNode()151 void LIFOManager::RemoveCurrNode() {
152   // Make sure we have curr_pos_ ready to be removed.
153   GetCurrNode();
154   // Note curr_pos_ may not be pointing the last element if some nodes are
155   // added.
156   nodes_.erase(curr_pos_);
157 
158   curr_pos_ = nodes_.end();  // Reset curr_pos_.
159 }
160 
HeapReadyManager()161 HeapReadyManager::HeapReadyManager() : ReadyNodeManager() {
162   std::make_heap(nodes_.begin(), nodes_.end());
163 }
164 
Init(const std::unordered_map<const NodeDef *,NodeState> * node_map)165 Status HeapReadyManager::Init(
166     const std::unordered_map<const NodeDef*, NodeState>* node_map) {
167   // Resets the node state since different instances of the scheduler can reuse
168   // the same node_manager.
169   node_map_ = node_map;
170   nodes_.clear();
171   curr_node_ = nullptr;
172 
173   // Sets up the comparator for the heap.
174   greater_ = Greater();
175 
176   return Status::OK();
177 }
178 
AddNode(const NodeDef * node)179 void HeapReadyManager::AddNode(const NodeDef* node) {
180   // push_heap in AddNode and pop_heap in RemoveCurrNode() guarantees that the
181   // first element is the node with minimum time_ready.
182   nodes_.push_back(node);
183   std::push_heap(nodes_.begin(), nodes_.end(), greater_);
184 }
185 
GetCurrNode()186 const NodeDef* HeapReadyManager::GetCurrNode() {
187   if (curr_node_) return curr_node_;
188   if (nodes_.empty()) {
189     CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
190   }
191   const std::string node_name = nodes_.front()->name();
192   // Next time we call GetCurrNode(), it just returns the cached copy
193   // curr_node_, until we call the RemoveCurrNode().
194   curr_node_ = nodes_.front();
195   // Remove current node from the heap immediately. Because if we wait until
196   // later, the heap could have gotten re-organized if AddNode is called. The
197   // current node is anyways cached, incase GetCurrNode() is called again.
198   std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
199   nodes_.pop_back();
200   return curr_node_;
201 }
202 
RemoveCurrNode()203 void HeapReadyManager::RemoveCurrNode() {
204   if (curr_node_) {
205     // If cached copy exists, remove that.
206     // Reset curr_node_ so that GetCurrNode() finds another node.
207     curr_node_ = nullptr;
208   } else {
209     // If cached copy not present, then remove entry from the heap queue.
210     std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
211     nodes_.pop_back();
212   }
213 }
214 
Empty() const215 bool HeapReadyManager::Empty() const {
216   return nodes_.empty() && curr_node_ == nullptr;
217 }
218 
FirstReadyCmp(const std::unordered_map<const NodeDef *,NodeState> * node_map,const NodeDef * a,const NodeDef * b)219 bool FirstReadyCmp(
220     const std::unordered_map<const NodeDef*, NodeState>* node_map,
221     const NodeDef* a, const NodeDef* b) {
222   if (node_map->at(a).time_ready == node_map->at(b).time_ready) {
223     // Use Node name as tie-breaker for deterministic node scheduling.
224     return a->name().compare(b->name()) > 0;
225   } else {
226     // Note: we need a node with minimum time_ready, not maximum; hence, using
227     // a > b for comparison function.
228     return node_map->at(a).time_ready > node_map->at(b).time_ready;
229   }
230 }
231 
232 std::function<bool(const NodeDef*, const NodeDef*)>
Greater()233 FirstReadyManager::Greater() {
234   auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
235     return FirstReadyCmp(node_map_, a, b);
236   };
237   return greater;
238 }
239 
240 std::function<bool(const NodeDef*, const NodeDef*)>
Greater()241 PriorityReadyManager::Greater() {
242   auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
243     auto pri_a = node_priority_.at(a->name());
244     auto pri_b = node_priority_.at(b->name());
245     if (pri_a == pri_b) {
246       // Fallback to default (FirstReady) behaviour.
247       return FirstReadyCmp(node_map_, a, b);
248     }
249     return pri_a > pri_b;
250   };
251   return greater;
252 }
253 
AddNode(const NodeDef * node)254 void PriorityReadyManager::AddNode(const NodeDef* node) {
255   if (node_priority_.count(node->name()) == 0) {
256     VLOG(3) << "Priority of node " << node->name() << " not found.";
257     node_priority_[node->name()] = 0;
258   }
259   HeapReadyManager::AddNode(node);
260 }
261 
SetPriority(const std::unordered_map<string,int> & node_priority)262 Status PriorityReadyManager::SetPriority(
263     const std::unordered_map<string, int>& node_priority) {
264   node_priority_ = node_priority;
265   return Status::OK();
266 }
267 
// Constructs the composite manager. The base and the _Send / _Recv
// sub-managers are explicitly value-initialized in the initializer list.
CompositeNodeManager::CompositeNodeManager()
    : ReadyNodeManager(), send_manager_(), recv_manager_() {}
270 
Init(const std::unordered_map<const NodeDef *,NodeState> * node_map)271 Status CompositeNodeManager::Init(
272     const std::unordered_map<const NodeDef*, NodeState>* node_map) {
273   node_map_ = node_map;
274   TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
275   TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
276   curr_node_ = nullptr;
277   return Status::OK();
278 }
279 
AddNode(const NodeDef * node)280 void CompositeNodeManager::AddNode(const NodeDef* node) {
281   if (IsSend(*node)) {
282     send_manager_.AddNode(node);
283   } else if (IsRecv(*node)) {
284     recv_manager_.AddNode(node);
285   } else {
286     const auto& device = node_map_->at(node).device_name;
287     ops_lifo_map_[device].AddNode(node);
288   }
289 }
290 
GetCurrNode()291 const NodeDef* CompositeNodeManager::GetCurrNode() {
292   if (curr_node_) return curr_node_;
293 
294   // Per-device LIFO for normal ops (not _Send / _Recv),
295   // FirstReady for _Send and _Recv (separately),
296   // Globally (among the LIFO-selected ops from each device and _Send and
297   // _Recv) FirstReady,
298   // Priority order: _Send, _Recv, and then the rest, if time_ready is equal.
299   std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
300   for (auto& ops_lifo : ops_lifo_map_) {
301     if (!ops_lifo.second.Empty()) {
302       const auto* op = ops_lifo.second.GetCurrNode();
303       candidates.emplace_back(op, node_map_->at(op).time_ready);
304     }
305   }
306   if (!send_manager_.Empty()) {
307     const auto* send = send_manager_.GetCurrNode();
308     candidates.emplace_back(send, node_map_->at(send).time_ready);
309   }
310   if (!recv_manager_.Empty()) {
311     const auto* recv = recv_manager_.GetCurrNode();
312     candidates.emplace_back(recv, node_map_->at(recv).time_ready);
313   }
314   CHECK(!candidates.empty());
315   auto first_ready = std::min_element(
316       candidates.begin(), candidates.end(),
317       [](const std::pair<const NodeDef*, Costs::Duration>& a,
318          const std::pair<const NodeDef*, Costs::Duration>& b) {
319         if (a.second == b.second) {
320           // Note that there can be only 1 Send and only 1 Recv in candidates,
321           // at most; hence, score is 2 for Send, 1 for Recv, and 0 for a
322           // normap op, and a_score and b_score are equal only if both are
323           // normal ops.
324           int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
325           int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
326           if (a_score == b_score) {
327             // Both are normal ops; use node name as tie breaker.
328             return a.first->name().compare(b.first->name()) < 0;
329           } else {
330             // Prioritize by op type: _Send, _Recv, and normap ops.
331             return a_score > b_score;
332           }
333         } else {
334           return a.second < b.second;
335         }
336       });
337   // Next time we call GetCurrNode(), it just returns the cached one,
338   // curr_node_ until we call RemovCurrNode().
339   curr_node_ = first_ready->first;
340 
341   return curr_node_;
342 }
343 
RemoveCurrNode()344 void CompositeNodeManager::RemoveCurrNode() {
345   const auto* node = GetCurrNode();
346   if (IsSend(*node)) {
347     send_manager_.RemoveCurrNode();
348   } else if (IsRecv(*node)) {
349     recv_manager_.RemoveCurrNode();
350   } else {
351     const auto device = node_map_->at(node).device_name;
352     ops_lifo_map_[device].RemoveCurrNode();
353   }
354   // Reset curr_node_ so that GetCurrNode() finds another node.
355   curr_node_ = nullptr;
356 }
357 
Empty() const358 bool CompositeNodeManager::Empty() const {
359   // Empty if all the ready managers are empty.
360   bool empty = true;
361   for (const auto& ops_lifo : ops_lifo_map_) {
362     empty &= ops_lifo.second.Empty();
363   }
364   return empty && send_manager_.Empty() && recv_manager_.Empty();
365 }
366 
ReadyNodeManagerFactory(const string & ready_node_manager)367 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
368     const string& ready_node_manager) {
369   if (ready_node_manager == "FIFO") {
370     return absl::make_unique<FIFOManager>();
371   } else if (ready_node_manager == "LIFO") {
372     return absl::make_unique<LIFOManager>();
373   } else if (ready_node_manager == "FirstReady") {
374     return absl::make_unique<FirstReadyManager>();
375   } else if (ready_node_manager == "Composite") {
376     return absl::make_unique<CompositeNodeManager>();
377   }
378   LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
379   return nullptr;
380 }
381 
~SchedulerState()382 SchedulerState::~SchedulerState() {}
383 
SchedulerState(const bool use_static_shapes,const bool use_aggressive_shape_inference,Cluster * cluster,std::unique_ptr<VirtualPlacer> placer)384 SchedulerState::SchedulerState(const bool use_static_shapes,
385                                const bool use_aggressive_shape_inference,
386                                Cluster* cluster,
387                                std::unique_ptr<VirtualPlacer> placer)
388     : graph_costs_(Costs::ZeroCosts()),
389       cluster_(cluster),
390       use_static_shapes_(use_static_shapes),
391       use_aggressive_shape_inference_(use_aggressive_shape_inference),
392       placer_(std::move(placer)) {
393   DCHECK(placer_);  // check if the pointer is valid.
394   graph_costs_.num_ops_total = 0;
395   initialized_ = false;
396   track_mem_usage_snapshot_ = VLOG_IS_ON(1);
397 }
398 
Init(const GrapplerItem * item,std::vector<const NodeDef * > * initial_nodes,bool create_explicit_channel_device)399 Status SchedulerState::Init(const GrapplerItem* item,
400                             std::vector<const NodeDef*>* initial_nodes,
401                             bool create_explicit_channel_device) {
402   initialized_ = false;
403 
404   // Clear all internal states so that the SchedulerState is reusable for
405   // different GrapplerItems
406   node_map_.clear();
407   device_.clear();
408   additional_nodes_.clear();
409 
410   graph_costs_ = Costs::ZeroCosts();
411   graph_costs_.num_ops_total = 0;
412   op_to_cost_.clear();
413 
414   op_counts_.clear();
415   op_costs_.clear();
416 
417   initial_nodes->clear();
418 
419   // Constructs graph properties and performs shape inference.
420   graph_properties_ = absl::make_unique<GraphProperties>(*item);
421   // TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
422   // to get rid of use_static_shapes_ and cluster_.
423   if (use_static_shapes_) {
424     TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
425         true, use_aggressive_shape_inference_, true));
426   } else {
427     TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
428   }
429 
430   grappler_item_ = item;
431   const auto& graph = grappler_item_->graph;
432   const auto& fetch_nodes = grappler_item_->fetch;
433   std::set<string> feed_nodes;
434 
435   for (const auto& f : grappler_item_->feed) {
436     auto iter_and_inserted_flag = feed_nodes.insert(f.first);
437     QCHECK(iter_and_inserted_flag.second)
438         << "Duplicate feed node found: " << f.first;
439   }
440 
441   // Get the nodes that would run to output fetch_nodes.
442   std::unordered_map<string, const NodeDef*> name_to_node;
443   std::vector<const NodeDef*> fetch_fanin_nodes;
444   TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node,
445                                             &fetch_fanin_nodes));
446 
447   // Once ComputeTransitiveFanin is complete, only the nodes that can be reached
448   // from the fetch nodes are scheduled. So the scheduled nodes should be
449   // exactly the same as those executed for real. One possible discrepancy could
450   // be the control flow nodes, where tf only executes one path.
451 
452   // Traverses the graph to record _Send nodes.
453   // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
454   // to _Recv as control dependency when creating GrapplerItem.
455   std::unordered_map<string, const NodeDef*> name_to_send;
456   for (const auto& node : graph.node()) {
457     if (IsSend(node)) {
458       const auto& attr = node.attr();
459       name_to_send[attr.at("tensor_name").s()] = &node;
460     }
461   }
462 
463   // To reuse _Recv ops.
464   std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash,
465                      RecvNodeDescriptorEqual>
466       cached_recv_nodes;
467 
468   // Build node_map; for each node, create its NodeState and connect its inputs
469   // and outputs.
470   for (const auto* curr_node : fetch_fanin_nodes) {
471     auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
472     const string curr_node_device = DeviceName(curr_node);
473     std::vector<string> inputs;
474     if (IsRecv(*curr_node)) {
475       const auto& attr = curr_node->attr();
476       if (attr.count("tensor_name")) {
477         const auto& send_node_name = attr.at("tensor_name").s();
478         auto it = name_to_send.find(send_node_name);
479         // If there is a _Send associated with the curr_node (_Recv), add it as
480         // input.
481         if (it != name_to_send.end()) {
482           const NodeDef* send = it->second;
483           inputs = {send->name()};
484         }
485       }
486     } else {
487       for (const string& input : curr_node->input()) {
488         inputs.push_back(input);
489       }
490     }
491     for (const string& input_node_name : inputs) {
492       // Note that input_node_name may be in <prefix><node_name>:<port_num>
493       // format, where <prefix> (e.g., "^" for control dependency) and
494       // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
495       const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
496 
497       CHECK(input_node);
498       const string in_device = DeviceName(input_node);
499       const auto input_node_port_num = NodePosition(input_node_name);
500 
501       // Control dependencies should be treated as high priority. Current
502       // Channel device doesn't model a separate virual channel for control v/s
503       // data transfers. So in the interim, it may be okay to let control
504       // dependencies magically flow across devices bypassing the channel
505       // device.
506       if (curr_node_device == in_device || IsControlInput(input_node_name)) {
507         // Same device: connect input_node and curr_node directly.
508         curr_node_state.inputs.push_back(
509             std::make_pair(input_node, input_node_port_num));
510         auto& input_node_state = GetNodeStateOrCreateIt(input_node);
511         input_node_state.outputs[input_node_port_num].push_back(curr_node);
512       } else {
513         RecvNodeDescriptor recv_node(input_node, input_node_port_num,
514                                      curr_node_device);
515         auto it = cached_recv_nodes.find(recv_node);
516         if (it != cached_recv_nodes.end()) {
517           // Different device, but found an already-cached copy (a _Recv op);
518           // connect the _Recv to curr_node.
519           const NodeDef* recv_op = it->second;
520           // recv_op's output port is hard-coded to zero.
521           curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
522           auto& input_node_state = node_map_.at(recv_op);
523           input_node_state.outputs[0].push_back(curr_node);
524         } else {
525           // Different device, no cached copy; transfer input_node to the
526           // curr_node's device.
527           auto send_and_recv =
528               CreateSendRecv(input_node, curr_node, input_node, input_node_name,
529                              create_explicit_channel_device);
530           // Note that CreateSendRecv() already connected input/output between
531           // _Send and _Recv ops.
532           const auto* send = send_and_recv.first;
533           const auto* recv = send_and_recv.second;
534           // recv_op's output port is hard-coded to zero.
535           curr_node_state.inputs.push_back(std::make_pair(recv, 0));
536           auto& input_node_state = GetNodeStateOrCreateIt(input_node);
537           input_node_state.outputs[input_node_port_num].push_back(send);
538 
539           // Cache the _Recv op for future use.
540           cached_recv_nodes[recv_node] = recv;
541         }
542       }
543     }
544 
545     // Special case: given feed nodes are ready at time 0.
546     const bool given_as_feed =
547         feed_nodes.find(curr_node->name()) != feed_nodes.end();
548 
549     // Default case: node without inputs are ready at time 0.
550     // Note that we check inputs vector which may be different to
551     // curr_node->input(); e.g., we add Send as input to Recv.
552     const bool has_no_inputs = inputs.empty();
553 
554     if (given_as_feed || has_no_inputs) {
555       curr_node_state.time_ready = Costs::Duration();
556       initial_nodes->push_back(curr_node);
557       VLOG(3) << "Added ready node: " << curr_node->name();
558     }
559     feed_nodes.erase(curr_node->name());
560 
561     if (IsPersistent(*curr_node)) {
562       auto& device_state = device_[curr_node_device];
563       for (int port_num = 0,
564                port_num_end = curr_node_state.output_properties.size();
565            port_num < port_num_end; ++port_num) {
566         device_state.persistent_nodes.insert(
567             std::make_pair(curr_node, port_num));
568       }
569     }
570   }
571 
572   if (initial_nodes->empty()) {
573     return errors::InvalidArgument("No ready nodes in the graph.");
574   }
575 
576   if (!feed_nodes.empty()) {
577     // This isn't always a bug: when the caller hasn't specified the exact list
578     // of feed and fetch nodes, by default we consider all placeholders as feed
579     // nodes, but some of them may not be needed for the default fetch node.
580     VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
581             << absl::StrJoin(feed_nodes, ",");
582   }
583 
584   initialized_ = true;
585   return Status::OK();
586 }
587 
MaybeUpdateInputOutput(const NodeDef * node)588 void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
589   CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
590   // This method is called when NodeState is created and adds input and output
591   // properties for a few exceptional cases that GraphProperties cannot provide
592   // input/output properties.
593   if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
594     // _Send and _Recv ops created from SchedulerState have kAttrInputSrc
595     // attr; normal _Send and _Recv ops (from the input graph) do not have that
596     // attr.
597     auto& node_state = node_map_[node];
598     auto& inputs = node_state.input_properties;
599     auto& outputs = node_state.output_properties;
600 
601     // _Send and _Recv ops are created from SchedulerState, so
602     // there should be no inputs TensorProperties.
603     CHECK(inputs.empty());
604     CHECK(outputs.empty());
605     const auto& attr = node->attr();
606     // This is the original input source to the _Send and _Recv, and this
607     // string includes "^" if it was control dependency, and output port
608     /// (e.g., ":2") if the input source had multiple outputs.
609     const auto& input_source_name = attr.at(kAttrInputSrc).s();
610     if (IsControlInput(input_source_name)) {
611       // Control dependency; regardless of the input source tensor size,
612       // send 4B.
613       OpInfo::TensorProperties control_message;
614       control_message.set_dtype(DT_FLOAT);
615       control_message.mutable_shape()->add_dim()->set_size(1);
616       auto* value = control_message.mutable_value();
617       value->add_float_val(1);
618       inputs.push_back(control_message);
619       outputs.push_back(control_message);
620     } else {
621       const auto& output_properties =
622           graph_properties_->GetOutputProperties(NodeName(input_source_name));
623       // Like with HasInputProperties, if a node does not have output
624       // properties, it's likely it was pruned during the shape inference run.
625       if (!output_properties.empty()) {
626         const auto input_node_port_num = NodePosition(input_source_name);
627         // Use the input source's output property as _Send and _Recv's input
628         // property.
629         CHECK_GT(output_properties.size(), input_node_port_num);
630         inputs.push_back(output_properties[input_node_port_num]);
631         outputs.push_back(output_properties[input_node_port_num]);
632       }
633     }
634   }
635 }
636 
DeviceName(const NodeDef * node) const637 string SchedulerState::DeviceName(const NodeDef* node) const {
638   return placer_->get_canonical_device_name(*node);
639 }
640 
SanitizedDeviceName(const NodeDef * node) const641 string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
642   // Replace the ":" characters that may be present in the device name with "_".
643   // This makes it possible to then use the resulting string in a node name.
644   return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
645                              {{":", "_"}});
646 }
647 
ChannelDeviceName(const NodeDef * from,const NodeDef * to) const648 string SchedulerState::ChannelDeviceName(const NodeDef* from,
649                                          const NodeDef* to) const {
650   CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
651   return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
652                       "_to_", SanitizedDeviceName(to));
653 }
654 
CreateSendRecv(const NodeDef * from,const NodeDef * to,const NodeDef * input_node,const string & input_name,bool create_channel_device)655 std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
656     const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
657     const string& input_name, bool create_channel_device) {
658   CHECK(!initialized_) << "CreateSendRecv is called after Init().";
659 
660   // Connect "from" node to "to" node with _Send and _Recv such that
661   // from -> _Send -> _Recv -> to.
662   // _Send is placed on "Channel" device, and _Recv is on the same device
663   // as "to" node.
664   // input_node_name is the string from the "to" node to identify which output
665   // we get from the "from" node.
666 
667   // Note that we use NodeState for scheduling, so _Send and _Recv
668   // NodeDefs created here need not be correct: in terms of name,
669   // input names, attrs, etc.
670 
671   auto input_node_port_num = NodePosition(input_name);
672   string src_name;
673   bool control_input = false;
674   if (input_node_port_num >= 0) {
675     src_name = absl::StrCat(from->name(), "_", input_node_port_num);
676   } else {
677     src_name = absl::StrCat(from->name(), "_minus1");
678     control_input = true;
679   }
680 
681   // _Send op.
682   auto* send = new NodeDef();
683   send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
684                  "_to_" + SanitizedDeviceName(to));
685   send->set_op("_Send");
686   send->add_input(from->name());
687   auto send_device =
688       create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
689   send->set_device(send_device);
690   auto& send_attr = *(send->mutable_attr());
691   send_attr[kAttrInputSrc].set_s(input_name);
692   send_attr[kAttrSrcDevice].set_s(DeviceName(from));
693   send_attr[kAttrDstDevice].set_s(DeviceName(to));
694   // GraphDef generated by AutoGrappler has tensor_name field when removing
695   // _Send/_Recv nodes.
696   if (input_node->attr().count(kAttrTensorName)) {
697     send_attr[kAttrTensorName].set_s(
698         input_node->attr().at(kAttrTensorName).s());
699   }
700 
701   // _Recv op.
702   auto* recv = new NodeDef();
703   recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
704   recv->set_op("_Recv");
705   recv->add_input(send->name());
706   recv->set_device(DeviceName(to));
707   auto& recv_attr = *(recv->mutable_attr());
708   recv_attr[kAttrInputSrc].set_s(input_name);
709   if (input_node->attr().count(kAttrTensorName)) {
710     recv_attr[kAttrTensorName].set_s(
711         input_node->attr().at(kAttrTensorName).s());
712   }
713 
714   // Propagate the streaming attribute to the send/recv nodes.
715   if (from->attr().contains(kStreaming) && !control_input) {
716     if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) {
717       LOG(ERROR)
718           << from->name()
719           << " port index larger than length of _streaming attribute list.";
720     } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) {
721       send_attr[kStreaming].mutable_list()->add_b(true);
722       recv_attr[kStreaming].mutable_list()->add_b(true);
723     }
724   }
725 
726   // NodeState for _Send op.
727   auto& send_node_state = GetNodeStateOrCreateIt(send);
728   send_node_state.device_name = send->device();  // Set Channel device.
729   send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
730   send_node_state.outputs[0].push_back(recv);
731 
732   // NodeState for _Recv op.
733   auto& recv_node_state = GetNodeStateOrCreateIt(recv);
734   recv_node_state.inputs.push_back(std::make_pair(send, 0));
735   recv_node_state.outputs[0].push_back(to);
736 
737   // Keep the created nodes.
738   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
739   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));
740 
741   // Return _Send and _Recv.
742   return std::make_pair(send, recv);
743 }
744 
CreateOpContext(const NodeDef * node) const745 OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
746   // Get the device from the placer.
747   DeviceProperties device;
748   device = placer_->get_device(*node);
749 
750   // Special case for _Send op.
751   if (IsSend(*node)) {
752     device.set_type(kChannelDevice);
753   }
754 
755   // Construct OpContext.
756   OpContext op_context;
757   const auto& node_state = node_map_.at(node);
758   op_context.name = node->name();
759   op_context.device_name = node_state.device_name;
760   auto& op_info = op_context.op_info;
761   op_info.set_op(node->op());
762   *op_info.mutable_attr() = node->attr();
763   for (auto& input : node_state.input_properties) {
764     *op_info.add_inputs() = input;
765   }
766   for (auto& output : node_state.output_properties) {
767     *op_info.add_outputs() = output;
768   }
769   op_info.mutable_device()->Swap(&device);
770 
771   if (grappler_item_->graph.has_library()) {
772     op_context.function_library = &grappler_item_->graph.library();
773   }
774   return op_context;
775 }
776 
GetNodeStateOrCreateIt(const NodeDef * node)777 NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
778   CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
779 
780   auto it = node_map_.find(node);
781   if (it != node_map_.end()) {
782     return it->second;
783   }
784 
785   // Not found; create a NodeState for this node.
786   it = node_map_.emplace(node, NodeState()).first;
787   auto& node_state = it->second;
788   node_state.input_properties =
789       graph_properties_->GetInputProperties(node->name());
790   node_state.output_properties =
791       graph_properties_->GetOutputProperties(node->name());
792   node_state.shape_incompatible =
793       graph_properties_->CheckShapeIncompatible(node->name());
794 
795   // Some ops may need further processing to the input / output properties:
796   // _Send and _Recv.
797   MaybeUpdateInputOutput(node);
798 
799   if (!IsSend(*node)) {
800     node_state.device_name = DeviceName(node);
801     // For _Send op, device_name will be set to Channel in CreateSendRecv().
802   }
803 
804   // Initialize output port related data:
805   // Assume the size of OutputProperties represents the number of output ports
806   // of this node.
807   for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
808     node_state.time_no_references[i] = Costs::Duration::max();
809     node_state.num_outputs_executed[i] = 0;
810     // Populate an empty vector for each port. The caller will add nodes
811     // that use this port as input.
812     node_state.outputs[i] = {};
813   }
814   // Port_num -1 is for control dependency.
815   node_state.time_no_references[-1] = Costs::Duration::max();
816   node_state.num_outputs_executed[-1] = 0;
817   node_state.outputs[-1] = {};
818 
819   // Initialize time_scheduled to infinity, so we know whether it has been
820   // assigned a non-default value later.
821   node_state.time_scheduled = Costs::Duration().infinity();
822 
823   return it->second;
824 }
825 
GetOutputNodes(const NodeDef * node,const Costs::Duration & curr_time,std::vector<const NodeDef * > * output_nodes)826 void SchedulerState::GetOutputNodes(const NodeDef* node,
827                                     const Costs::Duration& curr_time,
828                                     std::vector<const NodeDef*>* output_nodes) {
829   // Checks whether the Switch's output slots change over iterations.
830   int slot = -1;
831   if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
832       node->attr().at(kOutputSlots).list().i_size() > 0) {
833     slot = node->attr().at(kOutputSlots).list().i(0);
834     for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
835       if (slot != node->attr().at(kOutputSlots).list().i(i)) {
836         slot = -1;
837         break;
838       }
839     }
840   }
841   // Increment num_inputs_ready of the output nodes and maybe add to ready
842   // nodes.
843   auto& node_state = node_map_[node];
844   for (const auto& port_num_output_pair : node_state.outputs) {
845     // If Switch is annotated and its output slots are always the same, we only
846     // schedule the slot that was executed. Otherwise, scheduler both slots.
847     if (slot >= 0 && port_num_output_pair.first != slot) continue;
848 
849     for (auto* output_node : port_num_output_pair.second) {
850       auto& output_state = node_map_[output_node];
851       output_state.num_inputs_ready++;
852       // Execute a node as soon as all its inputs are ready. Merge nodes are
853       // special since they run as soon as one of their inputs becomes
854       // available.
855       int output_state_inputs_size = output_state.inputs.size();
856       if (output_state.num_inputs_ready == output_state_inputs_size ||
857           IsMerge(*output_node)) {
858         // This output node is now ready.
859         output_state.time_ready = curr_time;
860         output_nodes->push_back(output_node);
861         VLOG(3) << "  Add output: " << output_node->name();
862       }
863     }
864   }
865 }
866 
MarkNodeExecuted(const NodeDef * node,const Costs & node_costs,const OpContext & op_context)867 std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
868     const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
869   auto& node_state = node_map_[node];
870   // TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
871   // Merge -- it can create a loop which may include loop-carried dependency,
872   // diverge-merge, and other complex execution patterns.
873   bool previously_executed_merge =
874       IsMerge(*node) && (node_state.time_finished != Costs::Duration::max());
875 
876   // If there is annotation in the graph about execution times, we use that
877   // number, otherwise, we assume the node is executed once.
878   node_state.execution_count = node->attr().count(kExecutionCount) == 0
879                                    ? 1
880                                    : node->attr().at(kExecutionCount).i();
881 
882   node_state.node_costs = node_costs;
883   // TotalNodeCosts() Should be called after node_costs and execution_count.
884   Costs total_node_costs = node_state.TotalNodeCosts();
885 
886   graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
887   const string& op_name = node->op();
888 
889   auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
890   op_cost = CombineCosts(op_cost, total_node_costs);
891 
892   if (VLOG_IS_ON(2)) {
893     // Also keep track of op counts and costs per op (with their shapes).
894     string node_description = GetOpDescription(op_context.op_info);
895     op_counts_[node_description] += 1;
896     op_costs_[node_description] =
897         std::make_pair(total_node_costs.execution_time.asMicroSeconds().count(),
898                        !node_costs.inaccurate);
899   }
900 
901   // Update node and device states.
902   auto& device = device_[node_state.device_name];
903   device.nodes_executed.push_back(node);
904   // Node is scheduled when the device is available AND all the inputs are
905   // ready; hence, time_scheduled is time_ready if time_ready > device curr
906   // time.
907   // NodeState times are assigned infinity at initialization. If they are
908   // still infinity here, we need to assign them. If not, it has been assigned
909   // already, so skip. This latter case may occur when a scheduler in-lines
910   // function calls, and thus schedules only function sub-nodes.
911   if (node_state.time_scheduled == Costs::Duration().infinity()) {
912     node_state.time_scheduled =
913         std::max(device.GetCurrTime(), node_state.time_ready);
914     // Override device curr time with the time_scheduled.
915     device.device_costs.execution_time = node_state.time_scheduled;
916   }
917   device.device_costs = CombineCosts(device.device_costs, total_node_costs);
918   auto curr_time = device.GetCurrTime();
919   node_state.time_finished = curr_time;
920 
921   // Update shape annotation states.
922   UpdateDeviceAnnotationState(node, node_state, &device);
923 
924   // Update device memory usage.
925   if (!IsPersistent(*node)) {
926     for (const auto& port_num_output_pair : node_state.outputs) {
927       int port_num = port_num_output_pair.first;
928 
929       // There's a chance that a specific output is not used at all.
930       if (node_state.outputs[port_num].empty()) {
931         node_state.time_no_references[port_num] = curr_time;
932       } else {
933         // Streaming outputs do not allocate memory, they are directly consumed
934         // by the target node.
935         if (!IsStreamingPort(*node, port_num)) {
936           device.memory_usage +=
937               CalculateOutputSize(node_state.output_properties, port_num) *
938               node_state.execution_count;
939         }
940         device.nodes_in_memory.insert(std::make_pair(node, port_num));
941       }
942     }
943   }
944 
945   // Update device's per-op cost.
946   auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
947   device_op_cost = CombineCosts(device_op_cost, total_node_costs);
948 
949   VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
950           << ", device: " << node->device()
951           << ", execution_count: " << node_state.execution_count
952           << ", ready: " << node_state.time_ready.count()
953           << ", scheduled: " << node_state.time_scheduled.count()
954           << ", finished: " << node_state.time_finished.count();
955   std::vector<const NodeDef*> new_nodes;
956   if (previously_executed_merge) {
957     // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
958     VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
959             << "is executed more than once. "
960             << "Skip scheduling its output nodes.";
961   } else {
962     // Checks outputs, and adds ready nodes to queue.
963     GetOutputNodes(node, curr_time, &new_nodes);
964   }
965 
966   // When op is scheduled, both input and output tensors must be allocated in
967   // memory. Now that output memory is added, check max memory usage.
968   if (!IsPersistent(*node)) {
969     if (device.memory_usage > device.max_memory_usage) {
970       device.max_memory_usage = device.memory_usage;
971 
972       if (track_mem_usage_snapshot_) {
973         device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
974       }
975     }
976   }
977 
978   // Increment num_outputs_executed of the input nodes and maybe update memory.
979   for (const auto& input_port : node_state.inputs) {
980     auto* input = input_port.first;
981     auto port = input_port.second;
982 
983     auto& input_state = node_map_[input];
984     input_state.num_outputs_executed[port]++;
985     int input_state_outputs_size_ = input_state.outputs[port].size();
986     if (input_state.num_outputs_executed[port] == input_state_outputs_size_ &&
987         !IsPersistent(*input)) {
988       // All the outputs are executed; no reference to this output port of
989       // input node.
990       input_state.time_no_references[port] = curr_time;
991       auto& input_device = device_[input_state.device_name];
992       // If the node input is marked as streaming, then it wasn't allocated
993       // in memory. A streaming input is still reference counted, but it doesn't
994       // de-allocate memory.
995       if (!IsStreamingPort(*input, port)) {
996         input_device.memory_usage -=
997             CalculateOutputSize(input_state.output_properties, port) *
998             node_state.execution_count;
999       }
1000 
1001       input_device.nodes_in_memory.erase(std::make_pair(input, port));
1002     }
1003   }
1004 
1005   return new_nodes;
1006 }
1007 
Summary() const1008 Costs SchedulerState::Summary() const {
1009   // Overall statement about accuracy
1010   VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
1011           << graph_costs_.num_ops_with_unknown_shapes
1012           << " having unknown shapes";
1013 
1014   // Print out basic execution summary.
1015   VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
1016   VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
1017   VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
1018   VLOG(1) << "Expected intermediate memory time: "
1019           << graph_costs_.intermediate_memory_time.count();
1020   VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
1021   VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
1022   VLOG(1) << "Expected max per-op streaming buffers: "
1023           << graph_costs_.max_per_op_streaming;
1024 
1025   VLOG(1) << "Per-op execution time / compute time / memory time"
1026           << " / intermediate memory time:";
1027   for (const auto& op_cost_pair : op_to_cost_) {
1028     const auto& op = op_cost_pair.first;
1029     const auto& cost = op_cost_pair.second.execution_time.count();
1030     const auto& compute_cost = op_cost_pair.second.compute_time.count();
1031     const auto& memory_cost = op_cost_pair.second.memory_time.count();
1032     const auto& intermediate_memory_cost =
1033         op_cost_pair.second.intermediate_memory_time.count();
1034     const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
1035     if (cost) {  // Skip printing out zero-cost ops.
1036       VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op,
1037                                  (is_op_cost_accurate ? ' ' : '~'), cost,
1038                                  compute_cost, memory_cost,
1039                                  intermediate_memory_cost);
1040     }
1041   }
1042 
1043   // Print per device summary
1044   VLOG(1) << "Devices:";
1045   Costs critical_path_costs = Costs::ZeroCosts();
1046   std::vector<string> device_names;
1047   device_names.reserve(device_.size());
1048   for (auto& it : device_) {
1049     device_names.push_back(it.first);
1050   }
1051   std::sort(device_names.begin(), device_names.end());
1052 
1053   for (const auto& name : device_names) {
1054     const auto& state = device_.at(name);
1055 
1056     std::map<string, int64> op_to_memory;
1057     // First profile only persistent memory usage.
1058     int64 persistent_memory_usage = 0;
1059     std::set<string> persistent_ops;
1060     for (const auto& node_port : state.persistent_nodes) {
1061       const auto* node = node_port.first;
1062       const auto port = node_port.second;
1063       auto output_size = 0;
1064       // Check if the node is in the node_map. It may be that the node executed
1065       // on this device was executed by a different Scheduler.
1066       if (node_map_.find(node) != node_map_.end()) {
1067         output_size =
1068             CalculateOutputSize(node_map_.at(node).output_properties, port);
1069       }
1070       persistent_memory_usage += output_size;
1071       op_to_memory[node->op()] += output_size;
1072       persistent_ops.insert(node->op());
1073     }
1074     int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
1075     critical_path_costs.estimated_max_memory_per_device[name] =
1076         max_memory_usage;
1077 
1078     const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
1079     VLOG(1) << "Device = " << name
1080             << ", num_nodes = " << state.nodes_executed.size()
1081             << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
1082             << "persistent = " << HumanReadableNumBytes(persistent_memory_usage)
1083             << ", peak = " << HumanReadableNumBytes(state.max_memory_usage)
1084             << ", total = " << HumanReadableNumBytes(max_memory_usage)
1085             << ", at the end: " << HumanReadableNumBytes(state.memory_usage);
1086 
1087     // Overall statement about accuracy
1088     VLOG(1) << state.device_costs.num_ops_total
1089             << " ops processed in total, with "
1090             << state.device_costs.num_ops_with_unknown_shapes
1091             << " having unknown shapes";
1092 
1093     // Device shape annotation statistics.
1094     const auto& device_annotation_stats = state.shape_annotation_stats;
1095     if (device_annotation_stats.num_ops_annotated > 0) {
1096       VLOG(1) << device_annotation_stats.num_ops_annotated
1097               << " ops with shape annotation, with "
1098               << device_annotation_stats.num_ops_executed_more_than_once
1099               << " executed more than once, "
1100               << device_annotation_stats.num_ops_with_dynamic_shapes
1101               << " with dynamic shapes, "
1102               << device_annotation_stats.num_ops_with_incompatible_shapes
1103               << " with incompatible shapes, "
1104               << device_annotation_stats.num_ops_executed
1105               << " ops executed in total.";
1106     }
1107 
1108     VLOG(1) << "Per-op execution time / compute time / memory time "
1109             << " / intermediate memory time"
1110             << " (and memory usage at peak memory usage):";
1111 
1112     // Profile non-persistent op memory usage.
1113     for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
1114       const auto* node = node_port.first;
1115       const auto port = node_port.second;
1116       // Check if the node is in the node_map. It may be that the node executed
1117       // on this device was executed by a different Scheduler.
1118       if (node_map_.find(node) != node_map_.end()) {
1119         op_to_memory[node->op()] +=
1120             CalculateOutputSize(node_map_.at(node).output_properties, port);
1121       }
1122     }
1123     Costs::NanoSeconds total_compute_time_ns;
1124     bool is_total_cost_accurate = true;
1125     for (const auto& op_cost_pair : state.op_to_cost) {
1126       const auto& op = op_cost_pair.first;
1127       const auto& cost = op_cost_pair.second.execution_time.count();
1128       const auto& compute_cost = op_cost_pair.second.compute_time.count();
1129       const auto& memory_cost = op_cost_pair.second.memory_time.count();
1130       const auto& intermediate_memory_cost =
1131           op_cost_pair.second.intermediate_memory_time.count();
1132       total_compute_time_ns += op_cost_pair.second.execution_time;
1133       const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
1134       if (!is_op_cost_accurate) {
1135         is_total_cost_accurate = false;
1136       }
1137 
1138       int64 op_mem_usage = 0;
1139       auto it = op_to_memory.find(op);
1140       if (it != op_to_memory.end()) {
1141         op_mem_usage = it->second;
1142       }
1143 
1144       const float mem_usage_percent =
1145           max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
1146                                : 0.0;
1147       if (cost || mem_usage_percent > 1.0) {
1148         // Print out only non-zero cost ops or ops with > 1% memory usage.
1149         VLOG(1) << absl::StrFormat(
1150                        " + %30s : %c %10d / %10d / %10d / %10d", op.c_str(),
1151                        (is_op_cost_accurate ? ' ' : '~'), cost, compute_cost,
1152                        memory_cost, intermediate_memory_cost)
1153                 << " (" << HumanReadableNumBytes(op_mem_usage) << " ["
1154                 << mem_usage_percent << "%] "
1155                 << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
1156       }
1157     }
1158 
1159     int utilization = 0;
1160     if (wall_time_ns.count() > 0) {
1161       utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
1162     }
1163     VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
1164             << (is_total_cost_accurate ? "" : "~")
1165             << total_compute_time_ns.count()
1166             << ", utilization = " << utilization << "%";
1167 
1168     if (critical_path_costs.execution_time <= state.GetCurrTime()) {
1169       critical_path_costs = state.device_costs;
1170     }
1171   }
1172 
1173   if (VLOG_IS_ON(2)) {
1174     // Also log the op description and their corresponding counts.
1175     VLOG(2) << "Node description, counts, cost:";
1176     for (const auto& item : op_counts_) {
1177       int cost;
1178       bool is_cost_accurate;
1179       std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
1180       VLOG(2) << "Node: " << item.first << ", Count: " << item.second
1181               << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
1182               << " us";
1183     }
1184   }
1185 
1186   VLOG(1) << "Critical path execution time: "
1187           << critical_path_costs.execution_time.count();
1188   return critical_path_costs;
1189 }
1190 
Summary(RunMetadata * metadata)1191 Costs SchedulerState::Summary(RunMetadata* metadata) {
1192   if (metadata) GenerateRunMetadata(metadata);
1193   return Summary();
1194 }
1195 
GenerateRunMetadata(RunMetadata * metadata)1196 void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
1197   // Fill RunMetadata's step_stats and partition_graphs fields.
1198   StepStats* stepstats = metadata->mutable_step_stats();
1199   for (const auto& device : device_) {
1200     GraphDef* device_partition_graph = metadata->add_partition_graphs();
1201     DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
1202     device_stepstats->set_device(device.first);
1203     for (const auto& node_def : device.second.nodes_executed) {
1204       // Only proceed if the node is in the node_map. This is to cover the case
1205       // where a device has executed a node that is not in the node_map of
1206       // this scheduler.
1207       if (node_map_.find(node_def) == node_map_.end()) {
1208         continue;
1209       }
1210       const NodeState& nodestate = node_map_.at(node_def);
1211       NodeExecStats* node_stats = device_stepstats->add_node_stats();
1212       uint64 total_output_size = 0;
1213       for (int slot = 0, slot_end = nodestate.output_properties.size();
1214            slot < slot_end; slot++) {
1215         const auto& properties = nodestate.output_properties[slot];
1216         NodeOutput* no = node_stats->add_output();
1217         no->set_slot(slot);
1218         TensorDescription* tensor_descr = no->mutable_tensor_description();
1219         tensor_descr->set_dtype(properties.dtype());
1220         *tensor_descr->mutable_shape() = properties.shape();
1221         // Optional allocation description.
1222         const auto tensor_size =
1223             CalculateOutputSize(nodestate.output_properties, slot);
1224         total_output_size += tensor_size;
1225         tensor_descr->mutable_allocation_description()->set_requested_bytes(
1226             tensor_size);
1227         tensor_descr->mutable_allocation_description()->set_allocated_bytes(
1228             tensor_size);
1229       }
1230       if (node_def->op() != "HloGenericOp") {
1231         node_stats->set_timeline_label(node_def->op());
1232       } else {
1233         // For HloGenericOp, display hlo_opcode as timeline label.
1234         string timeline_label;
1235         if (node_def->attr().count("hlo_opcode") > 0) {
1236           absl::StrAppend(&timeline_label,
1237                           node_def->attr().at("hlo_opcode").s());
1238         }
1239         if (node_def->attr().count("_hlo_metadata_op_type") > 0) {
1240           absl::StrAppend(&timeline_label, "/",
1241                           node_def->attr().at("_hlo_metadata_op_type").s());
1242         }
1243         node_stats->set_timeline_label(timeline_label);
1244       }
1245       node_stats->set_node_name(node_def->name());
1246       // Timestamps in microseconds (can be used by timeline_server).
1247       node_stats->set_op_start_rel_micros(0);
1248       node_stats->set_all_start_micros(
1249           nodestate.time_scheduled.asMicroSeconds().count());
1250       node_stats->set_op_end_rel_micros(
1251           nodestate.time_finished.asMicroSeconds().count() -
1252           nodestate.time_scheduled.asMicroSeconds().count());
1253       node_stats->set_all_end_rel_micros(
1254           nodestate.time_finished.asMicroSeconds().count() -
1255           nodestate.time_scheduled.asMicroSeconds().count());
1256       // Timestamps in nanoseconds (can be used by xprof trace).
1257       node_stats->set_op_start_rel_nanos(0);
1258       node_stats->set_all_start_nanos(nodestate.time_scheduled.count());
1259       node_stats->set_op_end_rel_nanos(nodestate.time_finished.count() -
1260                                        nodestate.time_scheduled.count());
1261       node_stats->set_all_end_rel_nanos(nodestate.time_finished.count() -
1262                                         nodestate.time_scheduled.count());
1263 
1264       auto* mem_stats = node_stats->mutable_memory_stats();
1265       // SchedulerState does not specify scratch pad memory usage.
1266       mem_stats->set_temp_memory_size(0);
1267       int64 persistent_memory_size = 0;
1268       if (IsPersistent(*node_def)) {
1269         persistent_memory_size = total_output_size;
1270       }
1271       mem_stats->set_persistent_memory_size(persistent_memory_size);
1272       *device_partition_graph->add_node() = *node_def;
1273     }
1274   }
1275 }
1276 
GetPeakMemoryUsage() const1277 const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
1278     const {
1279   std::unordered_map<string, int64> result;
1280   for (const auto& device : device_) {
1281     const string& name = device.first;
1282     const DeviceState& state = device.second;
1283     result[name] = state.max_memory_usage;
1284   }
1285   return result;
1286 }
1287 
1288 const std::unordered_map<string, int64>
GetPersistentMemoryUsage() const1289 SchedulerState::GetPersistentMemoryUsage() const {
1290   std::unordered_map<string, int64> result;
1291   for (const auto& device : device_) {
1292     const string& name = device.first;
1293     const DeviceState& state = device.second;
1294     int64 persistent_memory_usage = 0;
1295     for (const auto& node_port : state.persistent_nodes) {
1296       const auto* node = node_port.first;
1297       const auto port = node_port.second;
1298       const auto output_size =
1299           CalculateOutputSize(node_map_.at(node).output_properties, port);
1300       persistent_memory_usage += output_size;
1301     }
1302     result[name] = persistent_memory_usage;
1303   }
1304   return result;
1305 }
1306 
SetNodeStateTimeScheduled(const NodeDef * node)1307 void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
1308   auto& node_state = node_map_.at(node);
1309   auto& device = device_[node_state.device_name];
1310   node_state.time_scheduled = device.GetCurrTime();
1311 }
1312 
~VirtualScheduler()1313 VirtualScheduler::~VirtualScheduler() {}
1314 
VirtualScheduler(const bool use_static_shapes,const bool use_aggressive_shape_inference,Cluster * cluster,ReadyNodeManager * ready_nodes,std::unique_ptr<VirtualPlacer> placer)1315 VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
1316                                    const bool use_aggressive_shape_inference,
1317                                    Cluster* cluster,
1318                                    ReadyNodeManager* ready_nodes,
1319                                    std::unique_ptr<VirtualPlacer> placer)
1320     : scheduler_state_(absl::make_unique<SchedulerState>(
1321           use_static_shapes, use_aggressive_shape_inference, cluster,
1322           std::move(placer))),
1323       ready_nodes_(ready_nodes) {}
1324 
VirtualScheduler(ReadyNodeManager * ready_nodes,std::unique_ptr<SchedulerState> scheduler_state)1325 VirtualScheduler::VirtualScheduler(
1326     ReadyNodeManager* ready_nodes,
1327     std::unique_ptr<SchedulerState> scheduler_state)
1328     : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}
1329 
Init(const GrapplerItem * item)1330 Status VirtualScheduler::Init(const GrapplerItem* item) {
1331   // SchedulerState::Init() preprocesses the input grappler_item and
1332   // graph_properties to extract necessary information for emulating tensorflow
1333   // op scheduling and construct internal data structures (NodeState and
1334   // DeviceState) for virtual scheduling.
1335   TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
1336   std::vector<const NodeDef*> initial_nodes;
1337   auto status = scheduler_state_->Init(item, &initial_nodes);
1338   if (status.ok()) {
1339     // Add the set of initial nodes to ready_nodes_
1340     for (auto node : initial_nodes) {
1341       ready_nodes_->AddNode(node);
1342     }
1343   }
1344   return status;
1345 }
1346 
GetCurrNode()1347 OpContext VirtualScheduler::GetCurrNode() {
1348   const NodeDef* node = ready_nodes_->GetCurrNode();
1349   return scheduler_state_->CreateOpContext(node);
1350 }
1351 
MarkCurrNodeExecuted(const Costs & node_costs)1352 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
1353   // Update graph_costs_ and per-op costs.
1354   const NodeDef* node = ready_nodes_->GetCurrNode();
1355   auto new_nodes = scheduler_state_->MarkNodeExecuted(
1356       node, node_costs,
1357       scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
1358   // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
1359   for (auto node : new_nodes) {
1360     ready_nodes_->AddNode(node);
1361   }
1362   ready_nodes_->RemoveCurrNode();
1363   return !ready_nodes_->Empty();
1364 }
1365 
1366 }  // end namespace grappler
1367 }  // end namespace tensorflow
1368