1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17 
18 #include <unistd.h>
19 
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29 
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/types/optional.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/gtl/map_util.h"
50 #include "tensorflow/core/lib/io/path.h"
51 #include "tensorflow/core/lib/strings/numbers.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/protobuf.h"
55 #include "tensorflow/core/platform/regexp.h"
56 
57 namespace xla {
58 namespace {
59 
60 using absl::nullopt;
61 using absl::optional;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrFormat;
65 using absl::StrJoin;
66 
// Describes how a given HloInstruction should be treated when drawing the
// graph: drawn normally, hidden, highlighted, or grayed out.
enum NodeFilterResult {
  // No special treatment.
  kNormalNode,
  // The node is not shown (NodeFilter::Show returns false for it).
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};
84 
85 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
86 // It lets callers tell the graph-drawing routines which nodes they want to be
87 // shown, hidden, or highlighted.
88 class NodeFilter {
89  public:
__anon7c0932a70202(const HloInstruction*) 90   NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
91 
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)92   explicit NodeFilter(
93       std::function<NodeFilterResult(const HloInstruction* instr)> filter)
94       : filter_(std::move(filter)) {}
95 
Show(const HloInstruction * instr) const96   bool Show(const HloInstruction* instr) const {
97     return filter_(instr) != kHideNode;
98   }
Highlight(const HloInstruction * instr) const99   bool Highlight(const HloInstruction* instr) const {
100     return filter_(instr) == kHighlightNode;
101   }
OmitOperands(const HloInstruction * instr) const102   bool OmitOperands(const HloInstruction* instr) const {
103     return filter_(instr) == kOmitNodeOperands;
104   }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const105   bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
106     auto result = filter_(instr);
107     return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
108   }
Deemphasized(const HloInstruction * instr) const109   bool Deemphasized(const HloInstruction* instr) const {
110     auto result = filter_(instr);
111     return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
112            result == kSomeUsersOmitted;
113   }
114 
115  private:
116   std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
117 };
118 
119 // We arbitrarily set this as the boundary between "large" and "small"
120 // instructions.
IsSmall(const HloInstruction * instr)121 bool IsSmall(const HloInstruction* instr) {
122   if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
123       ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
124     return true;
125   }
126   return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
127 }
128 
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Graphviz attributes/colors that make up a color scheme.  All members point
// at string literals.
struct NodeColors {
  const char* style;         // value for the graphviz "style" attribute
  const char* fill_color;    // value for "fillcolor"
  const char* stroke_color;  // value for "color" (the node's outline)
  const char* font_color;    // value for "fontcolor"
};

// Returns the graphviz style and fill/stroke/font colors for `color`.
//
// Every named ColorScheme is handled by the switch.  The switch deliberately
// has no `default:` label so the compiler can still warn when a new scheme is
// added without a matching case.  A value that does not correspond to any
// named scheme falls back to the plain white scheme instead of running off
// the end of the function.
NodeColors NodeColorsForScheme(ColorScheme color) {
  switch (color) {
    case kBlue:
      return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
    case kBrown:
      return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
    case kDarkBlue:
      return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
    case kDarkGreen:
      return NodeColors{"filled", "#2e7d32", "#005005", "white"};
    case kDarkOrange:
      // This is more of a "medium" orange, made to look close to kOrange;
      // there's probably room for a darker weight if desired.
      return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
    case kDarkRed:
      return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
    case kGray:
      return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
    case kGreen:
      return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
    case kOrange:
      return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
    case kPurple:
      return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
    case kRed:
      return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
    case kWhite:
      return NodeColors{"filled", "white", "black", "black"};
    case kYellow:
      return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
    case kDashedBorder:
      // "filled,dashed" looks the same as "dashed", since we have a white
      // background.  But we use "filled,dashed" so that when you hover over
      // any part of the node (not just the text inside the node), our css
      // :hover rule is triggered.
      return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
  }
  // Only reached for a value with no named enumerator; same colors as kWhite.
  return NodeColors{"filled", "white", "black", "black"};
}
196 
197 // Given a ColorScheme, returns an attribute string for a node of that color.
198 // Sets the node's style and fill/stroke/text colors.
199 //
200 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)201 string NodeColorAttributes(ColorScheme color) {
202   NodeColors node_colors = NodeColorsForScheme(color);
203 
204   return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
205                    node_colors.style, node_colors.font_color,
206                    node_colors.stroke_color, node_colors.fill_color);
207 }
208 
209 // Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
210 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)211 string HtmlLikeStringSanitize(absl::string_view s) {
212   return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
213 }
214 
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)215 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
216   namespace m = match;
217   return instr->parent()->IsFusionComputation() &&
218          Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
219 }
220 
221 // Tries to generates a human-readable one-word description of the given
222 // computation.
223 //
224 // Currently we support:
225 //
226 //   "return param0 + param1;"      --> "add"
227 //   "return param0 * param1;"      --> "multiply"
228 //   "return min(param0, param1);"  --> "min"
229 //   "return max(param0, param1);"  --> "max"
230 //   "return param0 <= param1;"     --> "less-or-equal"
231 //   "return param0 >= param1;"     --> "greater-or-equal"
232 //   "return param0 >  param1;"     --> "greater-than"
233 //   "return param0 <  param1;"     --> "less-than"
234 //   "return param0 == param1;"     --> "equal-to"
235 //   "return param0 != param1;"     --> "not-equal-to"
236 //
237 // where param0 and param1 are effective scalars.  For the ops that are
238 // commutative, we also support them with param0 and param1 swapped.
239 //
240 // This is useful primarily for reduce and map nodes.  These take a
241 // subcomputation which is almost always one of the above, and pattern matching
242 // it to a short string lets us tell the user what the subcomputation is without
243 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)244 optional<string> MatchTrivialComputation(const HloComputation* computation) {
245   namespace m = match;
246 
247   if (computation->instruction_count() != 3) {
248     return nullopt;
249   }
250   HloInstruction* root = computation->root_instruction();
251   const HloInstruction *param0, *param1;
252   if (!Match(root, m::Op()
253                        .WithNumOperands(2)
254                        .WithShape(m::Shape().IsEffectiveScalar())
255                        .WithBinaryOperandsAnyOrder(
256                            m::Parameter(&param0, 0)
257                                .WithShape(m::Shape().IsEffectiveScalar()),
258                            m::Parameter(&param1, 1)
259                                .WithShape(m::Shape().IsEffectiveScalar())))) {
260     return nullopt;
261   }
262 
263   // If the params are reversed (i.e. operand0 is param1 and operand1 is
264   // param0), check that the operation being performed is commutative.
265   if (root->operand(0) == param1) {
266     CHECK_EQ(root->operand(1), param0);
267     if (root->opcode() == HloOpcode()) {
268       switch (root->comparison_direction()) {
269         case ComparisonDirection::kLe:
270         case ComparisonDirection::kGe:
271         case ComparisonDirection::kGt:
272         case ComparisonDirection::kLt:
273           return nullopt;
274         default:
275           break;
276       }
277     }
278   }
279 
280   // If we recognize the root's opcode, we've successfully pattern-matched!
281   switch (root->opcode()) {
282     case HloOpcode::kAdd:
283       return "add";
284     case HloOpcode::kMultiply:
285       return "multiply";
286     case HloOpcode::kMinimum:
287       return "min";
288     case HloOpcode::kMaximum:
289       return "max";
290     case HloOpcode::kCompare: {
291       switch (root->comparison_direction()) {
292         case ComparisonDirection::kLe:
293           return "less-or-equal";
294         case ComparisonDirection::kGe:
295           return "greater-or-equal";
296         case ComparisonDirection::kGt:
297           return "greater-than";
298         case ComparisonDirection::kLt:
299           return "less-than";
300         case ComparisonDirection::kEq:
301           return "equal-to";
302         case ComparisonDirection::kNe:
303           return "not-equal-to";
304       }
305     }
306     default:
307       return nullopt;
308   }
309 }
310 
// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
//
// Usage: construct with the computation to draw, then call Dump() to get the
// DOT text.  The Dump* members accumulate node/edge/cluster bookkeeping in the
// member maps below; Header() and Footer() read that bookkeeping.
class HloDotDumper {
 public:
  // `debug_options` is stored by reference, so it must outlive this object.
  // `profile` may be null.  `filter` decides which instructions are shown,
  // highlighted, or grayed out.
  HloDotDumper(const HloComputation* computation, absl::string_view label,
               const DebugOptions& debug_options,
               HloRenderOptions hlo_render_options,
               const HloExecutionProfile* profile, NodeFilter filter)
      : computation_(computation),
        label_(label),
        debug_options_(debug_options),
        hlo_render_options_(hlo_render_options),
        profile_(profile),
        filter_(std::move(filter)) {}

  // Returns the complete DOT text for computation_: the header, followed by
  // the dumped computation and root tag, followed by the footer.
  string Dump();

 private:
  // Returns the dot graph identifier for the given instruction.  The
  // identifier is the instruction's address printed as an integer.
  string InstructionId(const HloInstruction* instruction) {
    return StrCat(reinterpret_cast<uint64>(instruction));
  }

  // Returns the dot graph identifier for the given computation.  The
  // identifier is "cluster_" followed by the computation's address.
  string SubcomputationId(const HloComputation* computation) {
    return StrCat("cluster_", reinterpret_cast<uint64>(computation));
  }

  // Generates graph header/footer.  These should be called *after* dumping all
  // of the instructions and subcomputations for the graph, as they both use
  // data generated while dumping the graph.
  string Header();
  string Footer();

  // Whether `subcomp` should be drawn as its own cluster.
  bool ShouldShowSubcomputation(const HloComputation* subcomp);
  // Same as ShouldShowSubcomputation, applied to the fused computation of
  // `instr`, which must be a fusion instruction.
  bool ShouldShowFusionSubcomputation(const HloInstruction* instr);

  // We omit some nodes from the graph, instead drawing them inlined into the
  // nodes that use them.
  bool ShouldMergeIntoUsers(const HloInstruction* instr) const;

  // Emits the cluster for `subcomp`, which is called by `parent_instr`.
  // Returns "" if the cluster was already emitted.
  string DumpSubcomputation(const HloComputation* subcomp,
                            const HloInstruction* parent_instr);
  // Emits every shown instruction of `comp`, along with the subcomputations
  // those instructions call.
  string DumpComputation(const HloComputation* comp);
  // Emits the "ROOT" tag node and the edge pointing at it.
  string DumpRootTag();
  string DumpInstruction(const HloInstruction* instr);
  ColorScheme GetInstructionColor(const HloInstruction* instr);
  string GetInstructionNodeShape(const HloInstruction* instr);
  string GetInstructionNodeLabel(const HloInstruction* instr);
  string GetInstructionNodeMetadata(const HloInstruction* instr);
  string GetInstructionNodeBackendConfig(const HloInstruction* instr);
  string GetInstructionNodeExtraInfo(const HloInstruction* instr);
  string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
  void AddInstructionIncomingEdges(const HloInstruction* instr);

  // For most instructions, GetNodeForEdge(instr) returns instr.
  //
  // The exception is fusion nodes.  For these, we walk up the chain of nested
  // fusion nodes starting at instr until we reach a node that either (a) isn't
  // a fusion node, or (b) is a fusion node for which
  // ShouldShowFusionSubcomputation is false.
  //
  // We do this because fusion nodes are expanded inline -- if
  // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
  // the graph.
  //
  // In general when you want to draw an edge from A to B, you should actually
  // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
  const HloInstruction* GetNodeForEdge(const HloInstruction* instr);

  // If instr has just one computation and it's trivial (e.g. "return param0 +
  // param1"), returns a string you can put into the node's body that names the
  // subcomputation, e.g. "Subcomputation: <b>add</b>".
  string GetInstructionTrivialComputationStr(const HloInstruction* instr);

  const HloComputation* computation_;  // never null
  const string label_;                 // overall name for the graph
  const DebugOptions& debug_options_;  // not owned; see constructor comment
  const HloRenderOptions hlo_render_options_;
  const HloExecutionProfile* profile_;  // may be null
  const NodeFilter filter_;

  // Each HloInstruction dumped gets a monotonically-increasing node ID.  This
  // must start at 1, because that's where graphviz's accounting starts.
  int64 next_node_id_ = 1;
  absl::flat_hash_map<const HloInstruction*, int64> node_ids_;

  // The "root" tag doesn't have an associated HloInstruction pointer, so we
  // need to store it outside the map.  It has no initializer here; it is
  // assigned in DumpRootTag() when the root tag is emitted.
  int64 root_node_id_;

  // Each (from, to) edge gets a monotonically-increasing ID.  This is a
  // multimap because it's possible for the same edge to appear multiple times
  // in the graph (e.g. x^2 may be represented as mul(x, x)).  A null `to`
  // denotes the edge to the root tag.
  int64 next_edge_id_ = 1;
  std::unordered_multimap<
      std::pair<const HloInstruction*, const HloInstruction*>, int64,
      tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
      edge_ids_;

  // Each HloComputation that's emitted gets a monotonically-increasing ID.
  int64 next_cluster_id_ = 1;
  absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;

  // Edges to print from Footer().  Edges come at the end because graphviz is
  // unhappy if an edge from a subcomputation to a node in the outer computation
  // appears before both the inner computation and the destination node are
  // defined.
  std::vector<string> edges_;

  // When coloring by sharding information, we track the sharding string
  // representation to color association, by round-robin the color schemes.
  absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
      sharding_colors_;
  int64 next_shard_color_ = 0;
};
426 
Dump()427 string HloDotDumper::Dump() {
428   string body;
429   StrAppend(&body, DumpComputation(computation_));
430   StrAppend(&body, DumpRootTag());
431 
432   // By contract, Header() and Footer() have to be called after we've dumped all
433   // our instructions, because they use state generated during that process.
434   string g = Header();
435   StrAppend(&g, body);
436   StrAppend(&g, Footer());
437   return g;
438 }
439 
Header()440 string HloDotDumper::Header() {
441   constexpr char fmt[] = R"(digraph G {
442 rankdir = TB;
443 compound = true;
444 label = <<b>%s</b>>;
445 labelloc = t;
446 // Disable the tooltip.  Interestingly, "" doesn't work!
447 tooltip = " ";
448 // DOT graphs accept a stylesheet as a URI.  So naturally, an inline
449 // stylesheet is a data URI!
450 stylesheet=<
451   data:text/css,
452   @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
453   svg text {
454     font-family: 'Roboto';
455     font-size: 12px;
456   }
457 
458   %s
459 >
460 
461 )";
462 
463   VLOG(3) << "Generating Header";
464 
465   string graph_label =
466       StrCat(label_, "<br/>Computation ", computation_->name());
467   if (computation_->IsFusionComputation()) {
468     StrAppend(&graph_label, " (in fusion instruction ",
469               computation_->FusionInstruction()->name(), ")");
470   }
471   if (profile_ != nullptr) {
472     auto cycles = profile_->total_cycles_executed(*computation_);
473     absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
474                           tensorflow::strings::HumanReadableNum(cycles));
475   }
476 
477   // Create CSS rules that say, when you hover over the given node or cluster,
478   // turn the given edge the given color.
479   //
480   // We rely on a few properties of how graphviz generates SVGs:
481   //
482   //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
483   //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
484   //    Edges are similarly named "edgeN", and clusters are named "clustN".
485   //  - Nodes come before their in- and out-edges in the SVG.  We need this
486   //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
487   //    after X in the DOM* and matches Y.
488   std::vector<string> edge_css_rules;
489   const char* kBlue = "#1976d2";
490   const char* kRed = "#d32f2f";
491   for (const auto& kv : edge_ids_) {
492     const HloInstruction* from_node = kv.first.first;
493     const HloInstruction* to_node = kv.first.second;
494     int64 edge_id = kv.second;
495 
496     auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
497                                   const char* color) {
498       // One could imagine other ways of writing this CSS rule that involve
499       // less duplication, but this way seems to be relatively performant.
500       edge_css_rules.push_back(
501           StrFormat("  #%s%d:hover ~ #edge%d text { fill: %s; }\n"
502                     "  #%s%d:hover ~ #edge%d path { "
503                     "stroke: %s; stroke-width: .2em; }\n"
504                     "  #%s%d:hover ~ #edge%d polygon { "
505                     "fill: %s; stroke: %s; stroke-width: .2em; }\n",
506                     elem_type, elem_id, edge_id, color,  //
507                     elem_type, elem_id, edge_id, color,  //
508                     elem_type, elem_id, edge_id, color, color));
509     };
510 
511     // The "to_node" value may be a NULL, indicating that this points to the
512     // "root" tag rather than a normal node.
513     int64 from_node_id =
514         tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
515     if (from_node_id == -1) {
516       LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
517     }
518     int64 to_node_id =
519         to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
520                 : root_node_id_;
521     if (to_node != nullptr && to_node_id == -1) {
522       LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
523     }
524 
525     add_hover_css_rule("node", from_node_id, kBlue);
526     add_hover_css_rule("node", to_node_id, kRed);
527 
528     if (to_node) {
529       VLOG(3) << "Adding css for edge " << edge_id << " from node "
530               << from_node->name() << " to node " << to_node->name();
531     } else {
532       VLOG(3) << "Adding css for edge " << edge_id << " from node "
533               << from_node->name() << " to root tag";
534     }
535 
536     // If this edge crosses a fusion cluster boundary, highlight it when the
537     // cluster is hovered over.
538     if (to_node) {
539       if (from_node->IsFused() &&
540           from_node->parent()->root_instruction() == from_node) {
541         int64 cluster_id = cluster_ids_.at(from_node->parent());
542         add_hover_css_rule("clust", cluster_id, kBlue);
543       }
544       if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
545         int64 cluster_id = cluster_ids_.at(to_node->parent());
546         add_hover_css_rule("clust", cluster_id, kRed);
547       }
548     }
549   }
550 
551   // Browsers require that we URI-encode the contents of our data URI.  (It
552   // seems this was a relatively recent change?) In practice, this means that we
553   // need to escape '#'.
554   return StrFormat(
555       fmt, graph_label,
556       absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
557 }
558 
Footer()559 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
560 
ShouldShowFusionSubcomputation(const HloInstruction * instr)561 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
562   CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
563   return ShouldShowSubcomputation(instr->fused_instructions_computation());
564 }
565 
ShouldShowSubcomputation(const HloComputation * subcomp)566 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
567   if (subcomp->IsFusionComputation()) {
568     const HloInstruction* fusion = subcomp->FusionInstruction();
569     if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
570         !hlo_render_options_.show_fusion_subcomputations) {
571       return false;
572     }
573   }
574 
575   // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
576   // into the graph.
577   if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
578     return false;
579   }
580 
581   // Show the subcomputation if we're showing any of its members.
582   return absl::c_any_of(
583       subcomp->instructions(),
584       [&](const HloInstruction* instr) { return filter_.Show(instr); });
585 }
586 
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)587 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
588                                         const HloInstruction* parent_instr) {
589   VLOG(2) << "Dumping subcomputation " << subcomp->name();
590   // Add an edge from the subcomputation to its parent node.  If subcomp
591   // belongs to a fusion node, it's drawn in place of the fusion instruction,
592   // so there's no need to link those.
593   if (parent_instr->opcode() != HloOpcode::kFusion) {
594     const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
595     VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
596             << " as " << next_edge_id_;
597     edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
598     constexpr char edge_fmt[] =
599         R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
600     edges_.push_back(StrFormat(
601         edge_fmt, InstructionId(from), InstructionId(parent_instr),
602         SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
603   }
604 
605   // Have we already dumped this subcomputation?  If so, generating the edge
606   // linking it and parent_instr is all we want to do in this function.
607   if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
608     return "";
609   }
610 
611   cluster_ids_[subcomp] = next_cluster_id_++;
612 
613   string id = SubcomputationId(subcomp);
614 
615   string subcomp_label, style;
616   if (parent_instr->opcode() == HloOpcode::kFusion) {
617     subcomp_label =
618         StrFormat("Fused expression for <b>%s</b><br/>%s",
619                   HtmlLikeStringSanitize(parent_instr->name()),
620                   HtmlLikeStringSanitize(parent_instr->ToCategory()));
621     string extra_info = GetInstructionNodeExtraInfo(parent_instr);
622     if (!extra_info.empty()) {
623       StrAppend(&subcomp_label, "<br/>", extra_info);
624     }
625     string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
626     if (!node_backend_config.empty()) {
627       StrAppend(&subcomp_label, "<br/>", node_backend_config);
628     }
629 
630     bool highlight = filter_.Highlight(parent_instr);
631     const char* fillcolor;
632     const char* strokecolor;
633     if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
634       // Use the sharding color, if the node isn't highlighted.
635       NodeColors node_colors =
636           NodeColorsForScheme(GetInstructionColor(parent_instr));
637       fillcolor = node_colors.fill_color;
638       strokecolor = node_colors.stroke_color;
639     } else {
640       // Subcomputation's fill/stroke color is light/dark red/gray, depending on
641       // whether or not the subcomputation's fusion node is highlighted.
642       fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
643       strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
644     }
645     style =
646         StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
647                   fillcolor, strokecolor);
648   } else {
649     subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
650                               HtmlLikeStringSanitize(parent_instr->name()),
651                               HtmlLikeStringSanitize(subcomp->name()));
652     style = "style=rounded; color=black;";
653   }
654 
655   string comp_body = DumpComputation(subcomp);
656 
657   constexpr char computation_fmt[] = R"(subgraph %s {
658 %s
659 label = <%s>;
660 labelloc = t;
661 tooltip = " ";
662 %s
663 }  // %s
664 
665 )";
666   return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
667 }
668 
DumpComputation(const HloComputation * comp)669 string HloDotDumper::DumpComputation(const HloComputation* comp) {
670   string g;
671   for (const auto* instr : comp->instructions()) {
672     if (!filter_.Show(instr)) {
673       continue;
674     }
675 
676     // Dump subcomputations within instr.
677     for (const HloComputation* subcomp : instr->called_computations()) {
678       if (ShouldShowSubcomputation(subcomp)) {
679         StrAppend(&g, DumpSubcomputation(subcomp, instr));
680       }
681     }
682 
683     StrAppend(&g, DumpInstruction(instr));
684   }
685   return g;
686 }
687 
DumpRootTag()688 string HloDotDumper::DumpRootTag() {
689   const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
690 
691   // We didn't display constants or broadcasts of effective scalars within
692   // fusions as separate nodes; so if the root is a constant/broadcast of
693   // scalar, we don't add root tag or edge for it.
694   if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
695       IsFusedBroadcastOfConstantEffectiveScalar(from)) {
696     return "";
697   }
698 
699   auto from_id = InstructionId(from);
700 
701   // The ID of the root computation is otherwise unused, so it makes a good ID
702   // to use for the root-tag node.  However, the edge_ids_ map requires a
703   // HloInstruction* pointer for the 'to' value, so we use a NULL value there
704   // (rather than a pointer type-cast) to make it obvious if it is erroneously
705   // dereferenced.
706   HloInstruction* to = nullptr;
707   auto to_id = SubcomputationId(computation_);
708 
709   string node_body = "ROOT";
710   string node_shape = "circle";
711   ColorScheme color = kBrown;
712 
713   VLOG(2) << "Adding root tag as node " << next_node_id_;
714   root_node_id_ = next_node_id_++;
715 
716   VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
717           << next_edge_id_;
718   edge_ids_.insert({{from, to}, next_edge_id_++});
719   edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
720 
721   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
722                    "\n",
723                    to_id, node_body, node_shape, NodeColorAttributes(color));
724 }
725 
TryGetFusionParameterConstant(const HloInstruction * instr)726 static const HloConstantInstruction* TryGetFusionParameterConstant(
727     const HloInstruction* instr) {
728   if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
729     return nullptr;
730   }
731   const HloInstruction* fusion = instr->parent()->FusionInstruction();
732   const HloInstruction* operand = fusion->operand(instr->parameter_number());
733   return DynCast<HloConstantInstruction>(operand);
734 }
735 
ShouldMergeIntoUsers(const HloInstruction * instr) const736 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
737   // If a node:
738   //
739   //  - is a parameter of a fusion node which is bound to a constant,
740   //
741   // or
742   //
743   //  - is a tuple-shaped parameter, and
744   //  - is not a parameter to a fusion node, and
745   //  - has at least kMinUsersToOmit users shown, and
746   //  - all of the shown users are get-tuple-elements,
747   //
748   // then we omit it from the graph, merging it with its users.
749   //
750   // This helps us handle the common case where a while loop body has one big
751   // tuple-shaped parameter.
752   if (TryGetFusionParameterConstant(instr) != nullptr) {
753     return true;
754   }
755   const int kMinUsersToOmit = 3;
756   return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
757          !instr->IsFused() &&
758          absl::c_count_if(instr->users(),
759                           [&](const HloInstruction* user) {
760                             return filter_.Show(user);
761                           }) > kMinUsersToOmit &&
762          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
763            return !filter_.Show(user) ||
764                   user->opcode() == HloOpcode::kGetTupleElement;
765          });
766 }
767 
DumpInstruction(const HloInstruction * instr)768 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
769   // We don't display constants or broadcasts of effective scalar constants
770   // within fusions as separate nodes; they're merged into their users.
771   if (instr->opcode() == HloOpcode::kConstant ||
772       IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
773     return "";
774   }
775   // Skip this node if it's merged into its users.
776   if (ShouldMergeIntoUsers(instr)) {
777     return "";
778   }
779   // Omit the fusion node if its subcomputation is drawn, since the
780   // subcomputation will be drawn inline.
781   if (instr->opcode() == HloOpcode::kFusion &&
782       ShouldShowFusionSubcomputation(instr)) {
783     return "";
784   }
785 
786   VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
787   node_ids_[instr] = next_node_id_++;
788 
789   ColorScheme color = GetInstructionColor(instr);
790   string node_shape = GetInstructionNodeShape(instr);
791   string node_label = GetInstructionNodeLabel(instr);
792   string node_metadata = GetInstructionNodeMetadata(instr);
793   string node_backend_config = GetInstructionNodeBackendConfig(instr);
794   string extra_info = GetInstructionNodeExtraInfo(instr);
795   string inlined_constants = GetInstructionNodeInlinedOperands(instr);
796   string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
797   AddInstructionIncomingEdges(instr);
798 
799   if (!debug_options_.xla_hlo_graph_sharding_color()) {
800     // Override the node's styling if it should be (de-)emphasized.
801     if (filter_.Deemphasized(instr)) {
802       color = kDashedBorder;
803     }
804     if (filter_.Highlight(instr)) {
805       node_shape = "diamond";
806       color = kDarkRed;
807     }
808   }
809   // Build the text that will be displayed inside the node.
810   string node_body = node_label;
811   for (const string& s : {trivial_subcomputation, node_backend_config,
812                           extra_info, inlined_constants}) {
813     if (!s.empty()) {
814       StrAppend(&node_body, "<br/>", s);
815     }
816   }
817 
818   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
819                    "\n",
820                    InstructionId(instr), node_body, node_shape, node_metadata,
821                    NodeColorAttributes(color));
822 }
823 
GetInstructionNodeInlinedOperands(const HloInstruction * instr)824 string HloDotDumper::GetInstructionNodeInlinedOperands(
825     const HloInstruction* instr) {
826   // The constant's shape is a parameter because, in the case of a broadcasted
827   // scalar constant, we want to show the broadcasted shape, not the constant's
828   // scalar shape.
829   auto stringify_constant = [](const HloConstantInstruction* constant,
830                                const Shape& shape) {
831     // If the shape has a dimension of size zero, print it as e.g.
832     // "{} (f32[42, 0, 10])".  The alternative, calling Literal::ToString(),
833     // enumerates all of its empty dimensions (e.g.  "{ { {}, {} }, ..."), which
834     // is just noise.
835     if (ShapeUtil::IsZeroElementArray(shape)) {
836       return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
837     }
838 
839     // Print the literal value of constants with <= K elements.  Note that we
840     // use `constant->shape()` rather than `shape`, because if `constant` is a
841     // scalar that's broadcasted into `shape`, we want to print the constant.
842     optional<int64> elem_count;
843     if (shape.IsArray()) {
844       elem_count = ShapeUtil::ElementsIn(constant->shape());
845     }
846     // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
847     // collected from profiling tools. Those constants may not have a valid
848     // literal.
849     if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
850       return StrFormat("%s %s", shape.ToString(),
851                        constant->literal().ToStringWithoutShape());
852     }
853 
854     // Otherwise, print e.g. "%constant.42 (s32[100])".
855     string constant_name;
856     if (absl::StartsWith(constant->name(), "constant")) {
857       constant_name = constant->name();
858     } else {
859       constant_name = StrCat("constant ", constant->name());
860     }
861     return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
862   };
863 
864   std::vector<string> lines;
865   for (int64 i = 0; i < instr->operand_count(); ++i) {
866     const HloInstruction* operand = instr->operand(i);
867     optional<string> operand_str;
868     if (const auto* constant_operand =
869             DynCast<HloConstantInstruction>(operand)) {
870       operand_str =
871           stringify_constant(constant_operand, constant_operand->shape());
872     } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
873       operand_str = stringify_constant(
874           Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
875     } else if (ShouldMergeIntoUsers(operand)) {
876       // Special case: If the operand is a parameter to a fusion node and it
877       // always has a constant value, display it like a regular constant.
878       //
879       // For other parameters, use the parameter number rather than the proper
880       // name, because that's generally how people think of the node.
881       if (operand->opcode() == HloOpcode::kParameter) {
882         if (const HloConstantInstruction* constant =
883                 TryGetFusionParameterConstant(operand)) {
884           operand_str = stringify_constant(constant, constant->shape());
885         } else {
886           operand_str = StrFormat("Parameter %d", operand->parameter_number());
887         }
888       } else {
889         operand_str = operand->name();
890       }
891     }
892 
893     if (operand_str) {
894       if (instr->operand_count() > 1) {
895         lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
896       } else {
897         lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
898       }
899     }
900   }
901   return StrJoin(lines, "<br/>");
902 }
903 
GetInstructionColor(const HloInstruction * instr)904 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
905   if (debug_options_.xla_hlo_graph_sharding_color()) {
906     if (!instr->has_sharding()) {
907       return kDashedBorder;
908     }
909     auto it = sharding_colors_.find(instr->sharding());
910     if (it != sharding_colors_.end()) {
911       return it->second;
912     }
913     ColorScheme color = static_cast<ColorScheme>(
914         kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
915     sharding_colors_.emplace(instr->sharding(), color);
916     return color;
917   }
918 
919   // Choose different weights of orange for small vs large parameters.  This
920   // distinction is often important, especially in fusion nodes.
921   auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
922 
923   // Special case: If this instruction has a parameter merged into it, paint it
924   // the same color as a parameter.  Unless the merged-in parameter is a
925   // parameter to a fusion node that is bound to a constant -- these aren't
926   // "real" parameters from the user's perspective.
927   if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
928         return operand->opcode() == HloOpcode::kParameter &&
929                ShouldMergeIntoUsers(operand) &&
930                TryGetFusionParameterConstant(operand) == nullptr;
931       })) {
932     return parameter_color;
933   }
934 
935   // Pick different colors or shapes for instructions which are particularly
936   // expensive (eg, dot) and those which are unusual in some way or unique
937   // (eg, parameter).
938   switch (instr->opcode()) {
939     case HloOpcode::kAbs:
940     case HloOpcode::kAdd:
941     case HloOpcode::kAnd:
942     case HloOpcode::kAtan2:
943     case HloOpcode::kBitcastConvert:
944     case HloOpcode::kCeil:
945     case HloOpcode::kClamp:
946     case HloOpcode::kClz:
947     case HloOpcode::kCompare:
948     case HloOpcode::kComplex:
949     case HloOpcode::kConvert:
950     case HloOpcode::kCos:
951     case HloOpcode::kDivide:
952     case HloOpcode::kExp:
953     case HloOpcode::kExpm1:
954     case HloOpcode::kFloor:
955     case HloOpcode::kImag:
956     case HloOpcode::kIota:
957     case HloOpcode::kIsFinite:
958     case HloOpcode::kLog:
959     case HloOpcode::kLog1p:
960     case HloOpcode::kMaximum:
961     case HloOpcode::kMinimum:
962     case HloOpcode::kMultiply:
963     case HloOpcode::kNegate:
964     case HloOpcode::kNot:
965     case HloOpcode::kPopulationCount:
966     case HloOpcode::kOr:
967     case HloOpcode::kXor:
968     case HloOpcode::kPower:
969     case HloOpcode::kReal:
970     case HloOpcode::kRemainder:
971     case HloOpcode::kRng:
972     case HloOpcode::kRngGetAndUpdateState:
973     case HloOpcode::kRngBitGenerator:
974     case HloOpcode::kRoundNearestAfz:
975     case HloOpcode::kRsqrt:
976     case HloOpcode::kSelect:
977     case HloOpcode::kShiftLeft:
978     case HloOpcode::kShiftRightArithmetic:
979     case HloOpcode::kShiftRightLogical:
980     case HloOpcode::kLogistic:
981     case HloOpcode::kSign:
982     case HloOpcode::kSin:
983     case HloOpcode::kSlice:
984     case HloOpcode::kSort:
985     case HloOpcode::kSqrt:
986     case HloOpcode::kCbrt:
987     case HloOpcode::kSubtract:
988     case HloOpcode::kTanh:
989       // De-emphasize scalar-shaped elementwise ops -- they're generally
990       // uninteresting.
991       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
992         return kWhite;
993       }
994       return kYellow;
995     case HloOpcode::kBitcast:
996     case HloOpcode::kGetTupleElement:
997     case HloOpcode::kTrace:
998     case HloOpcode::kAfterAll:
999     case HloOpcode::kAddDependency:
1000     case HloOpcode::kTuple:
1001       return kWhite;
1002     case HloOpcode::kBroadcast:
1003       // De-emphasize nodes which broadcast a scalar within a fusion node --
1004       // these are essentially free.
1005       if (instr->IsFused() &&
1006           ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
1007         return kWhite;
1008       }
1009       return kGreen;
1010     case HloOpcode::kConcatenate:
1011     case HloOpcode::kDynamicSlice:
1012     case HloOpcode::kGather:
1013     case HloOpcode::kPad:
1014     case HloOpcode::kReshape:
1015     case HloOpcode::kDynamicReshape:
1016     case HloOpcode::kReverse:
1017     case HloOpcode::kTupleSelect:
1018     case HloOpcode::kTranspose:
1019       // De-emphasize scalar-shaped data movement ops and all data movement ops
1020       // inside fusion nodes, both of which are essentially free.
1021       if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1022         return kWhite;
1023       }
1024       return kGreen;
1025     case HloOpcode::kDynamicUpdateSlice:
1026       // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1027       // inside of fusion nodes, so we de-emphasize it only if it's
1028       // scalar-shaped.
1029       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1030         return kWhite;
1031       }
1032       return kGreen;
1033     case HloOpcode::kScatter:
1034       // Do not de-emphasize Scatter, since it involves significant work.
1035     case HloOpcode::kCopy:
1036     case HloOpcode::kCopyStart:
1037     case HloOpcode::kCopyDone:
1038       // Emphasize copy nodes, which are either physical transposes (and thus
1039       // significant), or copies of read-only buffers (and thus dead weight).
1040       return kGreen;
1041     case HloOpcode::kConvolution:
1042     case HloOpcode::kDot:
1043     case HloOpcode::kFft:
1044     case HloOpcode::kTriangularSolve:
1045     case HloOpcode::kCholesky:
1046       return kDarkBlue;
1047     case HloOpcode::kReducePrecision:
1048       return kRed;
1049     case HloOpcode::kParameter:
1050       return parameter_color;
1051     case HloOpcode::kBatchNormGrad:
1052     case HloOpcode::kBatchNormInference:
1053     case HloOpcode::kBatchNormTraining:
1054     case HloOpcode::kReduce:
1055     case HloOpcode::kReduceWindow:
1056     case HloOpcode::kSelectAndScatter:
1057       return kPurple;
1058     case HloOpcode::kDomain:
1059     case HloOpcode::kFusion:
1060     case HloOpcode::kMap:
1061     case HloOpcode::kGetDimensionSize:
1062     case HloOpcode::kSetDimensionSize:
1063       return kGray;
1064     case HloOpcode::kAllGather:
1065     case HloOpcode::kAllReduce:
1066     case HloOpcode::kAllToAll:
1067     case HloOpcode::kCollectivePermute:
1068     case HloOpcode::kCollectivePermuteStart:
1069     case HloOpcode::kCollectivePermuteDone:
1070     case HloOpcode::kInfeed:
1071     case HloOpcode::kOutfeed:
1072     case HloOpcode::kPartitionId:
1073     case HloOpcode::kRecv:
1074     case HloOpcode::kRecvDone:
1075     case HloOpcode::kSend:
1076     case HloOpcode::kSendDone:
1077     case HloOpcode::kReplicaId:
1078       return kBrown;
1079     case HloOpcode::kCall:
1080     case HloOpcode::kConditional:
1081     case HloOpcode::kCustomCall:
1082     case HloOpcode::kWhile:
1083       return kDarkGreen;
1084     case HloOpcode::kConstant:
1085       LOG(FATAL) << "Constants don't get their own nodes in the graph.";
1086   }
1087 }
1088 
GetInstructionNodeShape(const HloInstruction * instr)1089 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1090   // Give while loops a different shape so they're easier to pick out.
1091   switch (instr->opcode()) {
1092     case HloOpcode::kWhile:
1093       return "ellipse";
1094     default:
1095       return "rect";
1096   }
1097 }
1098 
GetInstructionNodeLabel(const HloInstruction * instr)1099 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1100   // If we have a parameter, put the param number in the name.
1101   if (instr->opcode() == HloOpcode::kParameter) {
1102     return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1103   }
1104 
1105   // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1106   // an add instruction.  In this case we render just the name.
1107   if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1108     return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1109   }
1110   string extended_opcode =
1111       StrCat(HloOpcodeString(instr->opcode()),
1112              instr->opcode() != HloOpcode::kFusion
1113                  ? ""
1114                  : StrCat(":", xla::ToString(instr->fusion_kind())));
1115   // If the name does not contain the opcode, render both.
1116   return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1117                    HtmlLikeStringSanitize(instr->name()));
1118 }
1119 
GetInstructionNodeMetadata(const HloInstruction * instr)1120 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1121   std::vector<string> lines;
1122   if (!instr->metadata().op_name().empty()) {
1123     lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1124   }
1125   if (!instr->metadata().op_type().empty()) {
1126     lines.push_back(StrFormat(
1127         "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1128   }
1129   if (!instr->metadata().source_file().empty() &&
1130       instr->metadata().source_line() != 0) {
1131     lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
1132                               instr->metadata().source_line()));
1133   }
1134 
1135   return StrJoin(lines, "\n");
1136 }
1137 
GetInstructionNodeBackendConfig(const HloInstruction * instr)1138 string HloDotDumper::GetInstructionNodeBackendConfig(
1139     const HloInstruction* instr) {
1140   if (!hlo_render_options_.show_backend_config ||
1141       instr->raw_backend_config_string().empty()) {
1142     return "";
1143   }
1144 
1145   return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1146 }
1147 
GetInstructionNodeExtraInfo(const HloInstruction * instr)1148 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1149   std::vector<string> lines;
1150 
1151   // Get the instruction's extra attributes excluding the names of its
1152   // subcomputations, since those are drawn explicitly in the graph.
1153   for (const auto& line : instr->ExtraAttributesToString(
1154            HloPrintOptions().set_print_subcomputation_mode(
1155                HloPrintOptions::PrintSubcomputationMode::kOff))) {
1156     // Some instructions have giant replica group fields, so truncate the
1157     // replica group line length to 128.
1158     constexpr int kMaxReplicaGroupLen = 128;
1159     if (absl::StartsWith(line, "replica_groups=") &&
1160         line.length() > kMaxReplicaGroupLen) {
1161       lines.push_back(HtmlLikeStringSanitize(
1162           StrCat(line.substr(0, kMaxReplicaGroupLen - 3), "...")));
1163     } else {
1164       lines.push_back(HtmlLikeStringSanitize(line));
1165     }
1166   }
1167 
1168   // Show the shape and layout of the instruction, unless it's an inlined fusion
1169   // node -- there the shape and layout is present in the output node.
1170   if (instr->opcode() != HloOpcode::kFusion ||
1171       !ShouldShowFusionSubcomputation(instr)) {
1172     // Show layout of instructions with more than one dimension.  Don't show
1173     // layout on tuples or tensors with just one dimension (which only have one
1174     // possible layout) to avoid visual noise.
1175     bool shape_is_multidim = false;
1176     ShapeUtil::ForEachSubshape(instr->shape(),
1177                                [&](const Shape& s, const ShapeIndex&) {
1178                                  shape_is_multidim |= s.dimensions_size() > 1;
1179                                });
1180     string instr_shape;
1181     if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1182       instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1183     } else {
1184       instr_shape = ShapeUtil::HumanString(instr->shape());
1185     }
1186 
1187     // Some instructions have giant tuples as their shapes, so truncate the
1188     // HLO's shape to kMaxShapeLen characters.
1189     constexpr int kMaxShapeLen = 64;
1190     if (instr_shape.length() > kMaxShapeLen) {
1191       instr_shape = StrCat(
1192           absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1193     }
1194     lines.push_back(HtmlLikeStringSanitize(instr_shape));
1195   }
1196   if (debug_options_.xla_hlo_graph_addresses()) {
1197     lines.push_back(StrFormat("[%p]", instr));
1198   }
1199   if (profile_ != nullptr) {
1200     double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1201     double total_cycles_executed =
1202         profile_->total_cycles_executed(*instr->parent());
1203     if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1204       lines.push_back(
1205           StrFormat("%% of cycles executed=%.2f",
1206                     100 * hlo_cycles_executed / total_cycles_executed));
1207     }
1208   }
1209   return StrJoin(lines, "<br/>");
1210 }
1211 
AddInstructionIncomingEdges(const HloInstruction * instr)1212 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1213   auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1214                       int64 operand_num, bool control_edge = false) {
1215     from = GetNodeForEdge(from);
1216 
1217     if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1218         IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1219         ShouldMergeIntoUsers(from)) {
1220       return;
1221     }
1222     VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1223             << " as " << next_edge_id_;
1224     edge_ids_.insert({{from, to}, next_edge_id_++});
1225 
1226     string edge_label;
1227     if (instr->operand_count() > 1 && !control_edge) {
1228       edge_label =
1229           StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1230     } else if (control_edge) {
1231       edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1232     }
1233 
1234     // We print "small" arrays using a hollow arrowhead and "large" arrays using
1235     // a filled arrowhead.
1236     constexpr char kEdgeFmt[] =
1237         R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1238     edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1239                                (IsSmall(from) ? "empty" : "normal"),
1240                                from->name(), to->name(), edge_label));
1241   };
1242 
1243   // Add edges from instr's operands to instr.  Parameters within fusion
1244   // expressions are handled specially -- we draw an edge from the corresponding
1245   // operand on the fusion node itself to the parameter.
1246   if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1247     // Only add the edge if this is not the outermost computation; otherwise it
1248     // will lead from a node we're not drawing.
1249     if (instr->parent() != computation_) {
1250       const HloInstruction* fusion = instr->parent()->FusionInstruction();
1251       add_edge(fusion->operand(instr->parameter_number()), instr,
1252                /*operand_num=*/0);
1253     }
1254   } else {
1255     for (int64 i = 0; i < instr->operand_count(); ++i) {
1256       add_edge(instr->operand(i), instr, i);
1257     }
1258     for (const HloInstruction* pred : instr->control_predecessors()) {
1259       add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1260     }
1261   }
1262 }
1263 
GetInstructionTrivialComputationStr(const HloInstruction * instr)1264 string HloDotDumper::GetInstructionTrivialComputationStr(
1265     const HloInstruction* instr) {
1266   // called_computations() on a fusion node "inherits" any called computations
1267   // of the fused root, which isn't what we want.  Just ignore fusion nodes
1268   // here; they're handled separately.
1269   if (instr->opcode() == HloOpcode::kFusion) {
1270     return "";
1271   }
1272 
1273   std::vector<string> lines;
1274   for (int64 i = 0; i < instr->called_computations().size(); ++i) {
1275     optional<string> computation_type =
1276         MatchTrivialComputation(instr->called_computations()[i]);
1277     if (!computation_type) {
1278       continue;
1279     }
1280     if (instr->called_computations().size() == 1) {
1281       lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1282                                 HtmlLikeStringSanitize(*computation_type)));
1283     } else {
1284       lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1285                                 HtmlLikeStringSanitize(*computation_type)));
1286     }
1287   }
1288   return StrJoin(lines, "<br/>");
1289 }
1290 
// Returns the instruction at which an edge touching `instr` should attach.
// A fusion instruction whose subcomputation is shown is replaced by its fused
// expression root, repeatedly, until some other kind of instruction is
// reached.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;
  for (;;) {
    const bool is_expanded_fusion = node->opcode() == HloOpcode::kFusion &&
                                    ShouldShowFusionSubcomputation(node);
    if (!is_expanded_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1299 
1300 // Gets a NodeFilter that includes roughly all instructions whose distance from
1301 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64 radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1302 NodeFilter MakeNodeRadiusAroundFilter(
1303     const HloInstruction* root, int64 radius,
1304     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1305   // First, find the neighborhood of nodes with distance from root <= radius.
1306   // These nodes are our initial set of "normal" nodes.
1307   absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1308   std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
1309   worklist.push_back({root, 0});
1310   while (!worklist.empty()) {
1311     const HloInstruction* instr;
1312     int64 depth;
1313     std::tie(instr, depth) = worklist.front();
1314     worklist.pop_front();
1315 
1316     nodes[instr] = kNormalNode;
1317     if (depth == radius) {
1318       continue;
1319     }
1320     if (boundary.contains(instr)) {
1321       continue;
1322     }
1323 
1324     // Traverse into instr's operands.
1325     //
1326     // Don't traverse into tuples' operands unless the tuple is the root.
1327     // Usually a tuple is the bottommost node in the graph, and so its operands
1328     // are not interesting to the graph at hand.
1329     if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1330       for (const HloInstruction* operand : instr->operands()) {
1331         if (!nodes.contains(operand)) {
1332           worklist.push_back({operand, depth + 1});
1333         }
1334       }
1335     }
1336 
1337     // Traverse into instr's nested computations.
1338     for (const HloComputation* computation : instr->called_computations()) {
1339       worklist.push_back({computation->root_instruction(), depth + 1});
1340     }
1341 
1342     // Traverse into instr's users, unless:
1343     //
1344     //  - there are a ton of them, in which case they're probably not
1345     //    interesting (and anyway, rendering them all would make the graph
1346     //    unreadable), or
1347     //  - instr is a constant, in which case its users are probably not
1348     //    interesting.
1349     if (instr->opcode() == HloOpcode::kConstant) {
1350       continue;
1351     }
1352     constexpr int kMaxUsersToRender = 16;
1353     if (instr->user_count() > kMaxUsersToRender) {
1354       // If we're going to skip this node's users, style it as such.
1355       nodes[instr] = kSomeUsersOmitted;
1356       continue;
1357     }
1358     for (const HloInstruction* user : instr->users()) {
1359       if (!nodes.contains(user)) {
1360         worklist.push_back({user, depth + 1});
1361       }
1362     }
1363   }
1364 
1365   auto is_displayed = [&](const HloInstruction* instr) {
1366     // Constants are displayed inline with their users; they're never omitted.
1367     // Nodes in subcomputations are always shown.
1368     return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1369            instr->parent() != root->parent();
1370   };
1371 
1372   // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1373   // know which nodes will be included in the graph.
1374   for (auto& kv : nodes) {
1375     const HloInstruction* instr = kv.first;
1376     NodeFilterResult& filter_result = kv.second;
1377     const auto& operands = instr->operands();
1378 
1379     if (absl::c_any_of(operands, is_displayed) &&
1380         !absl::c_all_of(operands, is_displayed)) {
1381       // Mark nodes with some operands omitted appropriately.
1382       filter_result = kSomeOperandsOmitted;
1383     } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1384       // Mark nodes with *all* operands omitted appropriately.
1385       filter_result = kOmitNodeOperands;
1386     }
1387 
1388     // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1389     // users made it into the graph.
1390     if (filter_result == kSomeUsersOmitted &&
1391         absl::c_all_of(instr->users(), is_displayed)) {
1392       filter_result = kNormalNode;
1393     }
1394   }
1395 
1396   // Highlight the root node.
1397   nodes[root] = kHighlightNode;
1398 
1399   return NodeFilter([=](const HloInstruction* instr) {
1400     auto it = nodes.find(instr);
1401     if (it != nodes.end()) {
1402       return it->second;
1403     }
1404     // Show all nodes in subcomputations.
1405     if (instr->parent() != root->parent()) {
1406       return kNormalNode;
1407     }
1408     return kHideNode;
1409   });
1410 }
1411 
1412 // Gets a node filter that includes nodes on all paths from `from` to `to`.  If
1413 // the all-paths set contains more than max_nodes elements, includes the nodes
1414 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64 max_nodes,bool * hit_limit)1415 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1416                                 const HloInstruction* to, int64 max_nodes,
1417                                 bool* hit_limit) {
1418   *hit_limit = false;
1419 
1420   // Elements in the queue are paths through the graph.
1421   std::deque<std::vector<const HloInstruction*>> queue;
1422   queue.push_front({from});
1423 
1424   // Compute the set of nodes we want to show using a slightly-modified
1425   // Djikstra's algorithm.  The only real difference is, rather than stopping
1426   // when we find a (shortest) path, we continue until we've found max_nodes
1427   // nodes on some path.
1428   std::unordered_set<const HloInstruction*> visited;
1429   std::unordered_set<const HloInstruction*> to_display = {from, to};
1430   while (!queue.empty() && to_display.size() < max_nodes) {
1431     std::vector<const HloInstruction*> path = std::move(queue.front());
1432     queue.pop_front();
1433     if (!visited.insert(path.back()).second) {
1434       continue;
1435     }
1436 
1437     for (const auto* user : path.back()->users()) {
1438       if (user == to) {
1439         auto it = path.begin();
1440         for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1441           to_display.insert(*it);
1442         }
1443         if (it != path.end()) {
1444           *hit_limit = true;
1445         }
1446       } else if (!visited.count(user)) {
1447         auto new_path = path;
1448         new_path.push_back(user);
1449         queue.push_back(std::move(new_path));
1450       }
1451     }
1452   }
1453 
1454   return NodeFilter([=](const HloInstruction* instr) {
1455     if (instr == from || instr == to) {
1456       return kHighlightNode;
1457     }
1458     return to_display.count(instr) ? kNormalNode : kHideNode;
1459   });
1460 }
1461 
WrapDotInHtml(absl::string_view dot)1462 string WrapDotInHtml(absl::string_view dot) {
1463   static const char html_prefix[] = R"html(
1464 <!DOCTYPE html>
1465 <html>
1466 <head>
1467   <meta charset="utf-8">
1468   <style type="text/css">
1469     body {
1470       height: 100vh;
1471       margin: 0;
1472     }
1473   </style>
1474 </head>
1475 <body>
1476   <!-- Integrity hash is generated by https://www.srihash.org/ -->
1477   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
1478      integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
1479      crossorigin="anonymous"></script>
1480   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
1481      integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
1482      crossorigin="anonymous"></script>
1483   <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
1484      integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
1485      crossorigin="anonymous"></script>
1486   <div id="container" style="height:95vh; border:1px solid black; "></div>
1487   <script>
1488     var data = `
1489 )html";
1490 
1491   static const char html_suffix[] = R"html(
1492 `;
1493     var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1494     var results = cssregex.exec(data)
1495     // graphviz has problem dealing with large stylesheets.
1496     // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1497     // In order to avoid the problem, remove the stylesheet from the dot and
1498     // insert it directly info the rendered SVG.
1499     var dot_data = data;
1500     var css_data = ''
1501     if (results !== null) {
1502         css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1503         // CSS inside DOT is URL-escaped, so we must unescape it
1504         // before we can insert it into SVG.
1505         css_data = unescape(css_data);
1506         dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1507     }
1508 
1509     var render_start = performance.now()
1510     function add_controls(svg) {
1511         var htmlblob = new Blob([document.documentElement.innerHTML],
1512                                 {type: 'text/html'});
1513         var savehtml = document.createElement('a');
1514         savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1515         savehtml.setAttribute('download', 'graph.html');
1516         savehtml.innerHTML = " [Save HTML+SVG] ";
1517         document.body.append(savehtml);
1518         var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1519         var savesvg = document.createElement('a');
1520         savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1521         savesvg.setAttribute('download', 'graph.svg');
1522         savesvg.innerHTML = " [Save SVG] ";
1523         document.body.append(savesvg);
1524         var dotblob =  new Blob([data], {type: 'text/dot'});
1525         var savedot = document.createElement('a');
1526         savedot.setAttribute('href', URL.createObjectURL(dotblob));
1527         savedot.setAttribute('download', 'graph.dot');
1528         savedot.innerHTML = " [Save DOT] ";
1529         document.body.append(savedot);
1530         // Will get called after embed element was loaded
1531         var panzoom = svgPanZoom(svg, {
1532             zoomEnabled: true,
1533             controlIconsEnabled: true,
1534         });
1535         document.getElementsByTagName("BODY")[0].onresize = function() {
1536             panzoom.resize();
1537             panzoom.fit();
1538             panzoom.center();
1539         };
1540         var render_end = performance.now();
1541         var render_note = document.createElement('div')
1542         render_note.innerHTML = 'Rendering took '
1543                                 + (render_end - render_start).toFixed(2) + "ms."
1544         document.body.append(render_note);
1545     }
1546     var svg = document.getElementById('graph')
1547     if (svg == null) {
1548         // Need to render SVG first.
1549         var viz = new Viz();
1550         viz.renderSVGElement(dot_data)
1551             .then(function(svg){
1552                 var container = document.getElementById('container')
1553                 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1554                 var node = document.createTextNode(css_data);
1555                 style.appendChild(node);
1556                 svg.setAttribute('width', '100%');
1557                 svg.setAttribute('height', '100%');
1558                 svg.setAttribute('id', 'graph');
1559                 svg.appendChild(style);
1560                 container.appendChild(svg);
1561                 add_controls(svg);
1562             })
1563     } else {
1564         // HTML already has rendered SVG embedded, so we just need to add
1565         // controls.
1566         add_controls(svg);
1567     }
1568   </script>
1569 </body>
1570 </html>
1571 )html";
1572 
1573   return absl::StrCat(html_prefix, dot, html_suffix);
1574 }
1575 
// Guards url_renderer.
tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
// Callback that converts dot text into a URL (or an error status).  Null until
// RegisterGraphToURLRenderer installs one; the rendering entry points in this
// file refuse RenderedGraphFormat::kUrl while it is null.
std::function<StatusOr<string>(absl::string_view)>* url_renderer
    TF_GUARDED_BY(url_renderer_mu) = nullptr;
1579 
// Storage for fusion visualization: (module_id, computation_id) -> sequence of
// dot dumps.
//
// RegisterFusionState appends to the per-computation sequence and
// WrapFusionExplorer reads it; every access must hold
// fusion_visualizer_state_mu.  The map is bound to a `new` allocation that is
// not freed here (presumably so it is never destroyed during static
// destruction -- TODO confirm).
tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED);
static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) =
    *new absl::flat_hash_map<std::pair<int64, int64>,
                             std::vector<std::string>>();
1586 
1587 // Generates a key to the fusion visualizer state mapping.
1588 std::pair<int, int> FusionVisualizerStateKey(
1589     const HloComputation& computation) {
1590   return std::make_pair(computation.parent()->unique_id(),
1591                         computation.unique_id());
1592 }
1593 
1594 // Generates a fusion explorer for the given computation using the data in
1595 // fusion_visualizer_state and the URL renderer. Precondition: url_renderer !=
1596 // nullptr.
1597 StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation)
1598     TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1599   CHECK(url_renderer != nullptr);
1600   tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1601   const std::vector<std::string>& dot_graphs =
1602       fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1603   std::vector<std::string> dot_urls;
1604   dot_urls.reserve(dot_graphs.size());
1605   for (const std::string& dot : dot_graphs) {
1606     TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot));
1607     dot_urls.push_back(url);
1608   }
1609 
1610   return absl::StrReplaceAll(
1611       R"(
1612   <!doctype html>
1613   <style>
1614     html, body {height: 100%; text-align: center;}
1615     #display {height: 80%; width: 80%;}
1616   </style>
1617   <title>Fusion Explorer: $TITLE</title>
1618   <iframe id='display' width=80% height=80%></iframe>
1619   <p id='description'></p>
1620   <p>
1621     <a id='prev' href='#'>Prev Step</a>
1622     <a id='next' href='#'>Next Step</a>
1623   </p>
1624   <p>
1625     Use j/k for keyboard navigation.
1626   </p>
1627   <script>
1628   var currId = -1;
1629   var urls = [$URLS];
1630 
1631   var setIframe = function() {
1632     document.getElementById('display').src = urls[currId];
1633   };
1634 
1635   var update = function(delta)  {
1636     currId = (currId + delta + urls.length) % urls.length;
1637     document.getElementById('description').innerHTML = "Frame #"
1638       + (currId + 1) + " / " + urls.length;
1639     setIframe();
1640   };
1641 
1642   document.getElementById('prev').onclick = function() {
1643     update(-1);
1644     return false;
1645   };
1646 
1647   document.getElementById('next').onclick = function() {
1648     update(1);
1649     return false;
1650   };
1651 
1652   window.addEventListener("keydown", function (event) {
1653     if (event.defaultPrevented) {
1654       return;
1655     }
1656     if (event.key == "j") {
1657       update(1);
1658     } else if (event.key == "k") {
1659       update(-1);
1660     } else {
1661       return;
1662     }
1663     event.preventDefault();
1664   }, true);
1665 
1666   document.addEventListener("DOMContentLoaded", function() {
1667     update(1);
1668   });
1669 
1670   </script>
1671   )",
1672       {{"$URLS", absl::StrJoin(dot_urls, ", ",
1673                                [&](std::string* out, const std::string& url) {
1674                                  absl::StrAppend(out, "\"", url, "\"");
1675                                })},
1676        {"$TITLE",
1677         absl::StrCat(computation.parent()->name(), "_", computation.name())}});
1678 }
1679 
1680 // Precondition: (url_renderer != nullptr || (format != kUrl
1681 //   && format != kFusionVisualization)).
1682 //
1683 // (We specify this as a precondition rather than checking it in here and
1684 // returning an error because we want to fail quickly when there's no URL
1685 // renderer available, and this function runs only after we've done all the work
1686 // of producing dot for the graph.)
WrapDotInFormat(const HloComputation & computation,absl::string_view dot,RenderedGraphFormat format)1687 StatusOr<string> WrapDotInFormat(const HloComputation& computation,
1688                                  absl::string_view dot,
1689                                  RenderedGraphFormat format)
1690     TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1691   switch (format) {
1692     case RenderedGraphFormat::kUrl:
1693       CHECK(url_renderer != nullptr)
1694           << "Should have checked url_renderer != null before calling.";
1695       return (*url_renderer)(dot);
1696     case RenderedGraphFormat::kHtml:
1697       return WrapDotInHtml(dot);
1698     case RenderedGraphFormat::kDot:
1699       return string(dot);
1700     case RenderedGraphFormat::kFusionVisualization:
1701       return WrapFusionExplorer(computation);
1702   }
1703 }
1704 
1705 }  // namespace
1706 
RegisterGraphToURLRenderer(std::function<StatusOr<string> (absl::string_view)> renderer)1707 void RegisterGraphToURLRenderer(
1708     std::function<StatusOr<string>(absl::string_view)> renderer) {
1709   tensorflow::mutex_lock lock(url_renderer_mu);
1710   if (url_renderer != nullptr) {
1711     LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer.  Last call "
1712                     "wins, but because order of initialization in C++ is "
1713                     "nondeterministic, this may not be what you want.";
1714   }
1715   delete url_renderer;
1716   url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1717       std::move(renderer));
1718 }
1719 
RegisterFusionState(const HloComputation & computation,absl::string_view label)1720 Status RegisterFusionState(const HloComputation& computation,
1721                            absl::string_view label) {
1722   tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1723   TF_ASSIGN_OR_RETURN(
1724       string dot_graph,
1725       RenderGraph(computation,
1726                   absl::StrCat(computation.parent()->name(), ", ",
1727                                computation.name(), ", ", label),
1728                   /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
1729                   /*hlo_execution_profile=*/nullptr,
1730                   /*hlo_render_options=*/{}));
1731   std::vector<std::string>& fusion_states =
1732       fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1733   if (fusion_states.empty() || fusion_states.back() != dot_graph) {
1734     fusion_states.push_back(dot_graph);
1735   }
1736   return Status::OK();
1737 }
1738 
RenderGraph(const HloComputation & computation,absl::string_view label,const DebugOptions & debug_options,RenderedGraphFormat format,const HloExecutionProfile * hlo_execution_profile,HloRenderOptions hlo_render_options)1739 StatusOr<string> RenderGraph(const HloComputation& computation,
1740                              absl::string_view label,
1741                              const DebugOptions& debug_options,
1742                              RenderedGraphFormat format,
1743                              const HloExecutionProfile* hlo_execution_profile,
1744                              HloRenderOptions hlo_render_options) {
1745   tensorflow::mutex_lock lock(url_renderer_mu);
1746   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1747     return Unavailable("Can't render as URL; no URL renderer was registered.");
1748   }
1749 
1750   string rendered_dot =
1751       HloDotDumper(&computation, label, debug_options, hlo_render_options,
1752                    hlo_execution_profile, NodeFilter())
1753           .Dump();
1754   return WrapDotInFormat(computation, rendered_dot, format);
1755 }
1756 
RenderNeighborhoodAround(const HloInstruction & node,int radius,RenderedGraphFormat format,HloRenderOptions hlo_render_options,const absl::flat_hash_set<const HloInstruction * > & boundary)1757 StatusOr<string> RenderNeighborhoodAround(
1758     const HloInstruction& node, int radius, RenderedGraphFormat format,
1759     HloRenderOptions hlo_render_options,
1760     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1761   tensorflow::mutex_lock lock(url_renderer_mu);
1762   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1763     return FailedPrecondition(
1764         "Can't render as URL; no URL renderer was registered.");
1765   }
1766 
1767   string label =
1768       StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1769   string rendered_dot =
1770       HloDotDumper(node.parent(), label,
1771                    node.GetModule()->config().debug_options(),
1772                    hlo_render_options, /*profile=*/nullptr,
1773                    MakeNodeRadiusAroundFilter(&node, radius, boundary))
1774           .Dump();
1775   return WrapDotInFormat(*node.parent(), rendered_dot, format);
1776 }
1777 
RenderAllPathsFromTo(const HloInstruction & from,const HloInstruction & to,int64 max_nodes,RenderedGraphFormat format,HloRenderOptions hlo_render_options)1778 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1779                                       const HloInstruction& to, int64 max_nodes,
1780                                       RenderedGraphFormat format,
1781                                       HloRenderOptions hlo_render_options) {
1782   tensorflow::mutex_lock lock(url_renderer_mu);
1783   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1784     return FailedPrecondition(
1785         "Can't render as URL; no URL renderer was registered.");
1786   }
1787 
1788   CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1789   auto debug_options = from.GetModule()->config().debug_options();
1790 
1791   bool hit_limit = false;
1792   NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1793   string label;
1794   if (!hit_limit) {
1795     label = StrCat("All paths from ", from.name(), " to ", to.name());
1796   } else {
1797     label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1798                    " to ", to.name(),
1799                    "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1800                    "NODES***<br/><br/>");
1801   }
1802   string rendered_dot =
1803       HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
1804                    /*profile=*/nullptr, filter)
1805           .Dump();
1806   return WrapDotInFormat(*from.parent(), rendered_dot, format);
1807 }
1808 
1809 }  // namespace xla
1810