1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17
18 #include <unistd.h>
19 #include <algorithm>
20 #include <atomic>
21 #include <deque>
22 #include <map>
23 #include <memory>
24 #include <queue>
25 #include <string>
26 #include <tuple>
27 #include <vector>
28
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/strings/match.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/strings/str_replace.h"
35 #include "absl/types/optional.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/literal.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_module.h"
41 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
42 #include "tensorflow/compiler/xla/shape_util.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/compiler/xla/window_util.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/io/path.h"
49 #include "tensorflow/core/lib/strings/numbers.h"
50 #include "tensorflow/core/platform/env.h"
51 #include "tensorflow/core/platform/mutex.h"
52 #include "tensorflow/core/platform/protobuf.h"
53 #include "tensorflow/core/platform/regexp.h"
54
55 namespace xla {
56 namespace {
57
58 using absl::nullopt;
59 using absl::optional;
60 using absl::StrAppend;
61 using absl::StrCat;
62 using absl::StrFormat;
63 using absl::StrJoin;
64 using tensorflow::Env;
65 using tensorflow::WriteStringToFile;
66 using tensorflow::io::JoinPath;
67
// Per-instruction verdict produced by a NodeFilter: tells the graph dumper
// whether to draw an HloInstruction normally, hide it, highlight it, or draw
// it in a de-emphasized style.
enum NodeFilterResult {
  kNormalNode,     // Draw the node with its usual styling.
  kHideNode,       // Leave the node out of the graph.
  kHighlightNode,  // Make the node easy to find in the final graph.
  // Draw the node "grayed out" to signal that some of its operands were left
  // out of the graph.
  kSomeOperandsOmitted,
  // Same styling as kSomeOperandsOmitted, but additionally never draw edges
  // to the node's operands, even when those operands are present.
  kOmitNodeOperands,
  // Same styling as kSomeOperandsOmitted, but signals that some of the node's
  // *users* (rather than operands) were left out.
  kSomeUsersOmitted,
};
85
86 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
87 // It lets callers tell the graph-drawing routines which nodes they want to be
88 // shown, hidden, or highlighted.
89 class NodeFilter {
90 public:
__anon7c0932a70202(const HloInstruction*) 91 NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
92
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)93 explicit NodeFilter(
94 std::function<NodeFilterResult(const HloInstruction* instr)> filter)
95 : filter_(std::move(filter)) {}
96
Show(const HloInstruction * instr) const97 bool Show(const HloInstruction* instr) const {
98 return filter_(instr) != kHideNode;
99 }
Highlight(const HloInstruction * instr) const100 bool Highlight(const HloInstruction* instr) const {
101 return filter_(instr) == kHighlightNode;
102 }
OmitOperands(const HloInstruction * instr) const103 bool OmitOperands(const HloInstruction* instr) const {
104 return filter_(instr) == kOmitNodeOperands;
105 }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const106 bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
107 auto result = filter_(instr);
108 return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
109 }
Deemphasized(const HloInstruction * instr) const110 bool Deemphasized(const HloInstruction* instr) const {
111 auto result = filter_(instr);
112 return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
113 result == kSomeUsersOmitted;
114 }
115
116 private:
117 std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
118 };
119
120 // We arbitrarily set this as the boundary between "large" and "small"
121 // instructions.
IsSmall(const HloInstruction * instr)122 bool IsSmall(const HloInstruction* instr) {
123 if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
124 ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
125 return true;
126 }
127 return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
128 }
129
// Named color schemes for graph nodes; NodeColorsForScheme maps each one to
// concrete graphviz style/fill/stroke/font values.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,
  // Dashed border with gray text on a white background, to suggest an
  // "unimportant" node.
  kDashedBorder,
};
150
// Graphviz attributes/colors that make up a color scheme.
//
// NodeColorsForScheme brace-initializes this struct positionally as
// {style, fill, stroke, font}, so the declaration order below must not change.
struct NodeColors {
  const char* style;         // Emitted as the node's "style" attribute.
  const char* fill_color;    // Emitted as the node's "fillcolor" attribute.
  const char* stroke_color;  // Emitted as the node's "color" (border) attribute.
  const char* font_color;    // Emitted as the node's "fontcolor" attribute.
};
158
NodeColorsForScheme(ColorScheme color)159 NodeColors NodeColorsForScheme(ColorScheme color) {
160 switch (color) {
161 case kBlue:
162 return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
163 case kBrown:
164 return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
165 case kDarkBlue:
166 return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
167 case kDarkGreen:
168 return NodeColors{"filled", "#2e7d32", "#005005", "white"};
169 case kDarkOrange:
170 // This is more of a "medium" orange, made to look close to kOrange;
171 // there's probably room for a darker weight if desired.
172 return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
173 case kDarkRed:
174 return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
175 case kGray:
176 return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
177 case kGreen:
178 return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
179 case kOrange:
180 return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
181 case kPurple:
182 return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
183 case kRed:
184 return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
185 case kWhite:
186 return NodeColors{"filled", "white", "black", "black"};
187 case kYellow:
188 return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
189 case kDashedBorder:
190 // "filled,dashed" looks the same as "dashed", since we have a white
191 // background. But we use "filled,dashed" so that when you hover over
192 // any part of the node (not just the text inside the node), our css
193 // :hover rule is triggered.
194 return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
195 }
196 }
197
198 // Given a ColorScheme, returns an attribute string for a node of that color.
199 // Sets the node's style and fill/stroke/text colors.
200 //
201 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)202 string NodeColorAttributes(ColorScheme color) {
203 NodeColors node_colors = NodeColorsForScheme(color);
204
205 return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
206 node_colors.style, node_colors.font_color,
207 node_colors.stroke_color, node_colors.fill_color);
208 }
209
210 // Replaces <> with <>, so that this string is safe(er) for use in a
211 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)212 string HtmlLikeStringSanitize(absl::string_view s) {
213 return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}});
214 }
215
216 // Tries to generates a human-readable one-word description of the given
217 // computation.
218 //
219 // Currently we support:
220 //
221 // "return param0 + param1;" --> "add"
222 // "return param0 * param1;" --> "multiply"
223 // "return min(param0, param1);" --> "min"
224 // "return max(param0, param1);" --> "max"
225 // "return param0 <= param1;" --> "less-or-equal"
226 // "return param0 >= param1;" --> "greater-or-equal"
227 // "return param0 > param1;" --> "greater-than"
228 // "return param0 < param1;" --> "less-than"
229 // "return param0 == param1;" --> "equal-to"
230 // "return param0 != param1;" --> "not-equal-to"
231 //
232 // where param0 and param1 are effective scalars. For the ops that are
233 // commutative, we also support them with param0 and param1 swapped.
234 //
235 // This is useful primarily for reduce and map nodes. These take a
236 // subcomputation which is almost always one of the above, and pattern matching
237 // it to a short string lets us tell the user what the subcomputation is without
238 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)239 optional<string> MatchTrivialComputation(const HloComputation* computation) {
240 namespace m = match;
241
242 if (computation->instruction_count() != 3) {
243 return nullopt;
244 }
245 HloInstruction* root = computation->root_instruction();
246 const HloInstruction *param0, *param1;
247 if (!Match(root, m::Op()
248 .WithNumOperands(2)
249 .WithShape(m::Shape().IsEffectiveScalar())
250 .WithBinaryOperandsAnyOrder(
251 m::Parameter(¶m0, 0)
252 .WithShape(m::Shape().IsEffectiveScalar()),
253 m::Parameter(¶m1, 1)
254 .WithShape(m::Shape().IsEffectiveScalar())))) {
255 return nullopt;
256 }
257
258 // If the params are reversed (i.e. operand0 is param1 and operand1 is
259 // param0), check that the operation being performed is commutative.
260 if (root->operand(0) == param1) {
261 CHECK_EQ(root->operand(1), param0);
262 if (root->opcode() == HloOpcode()) {
263 switch (root->comparison_direction()) {
264 case ComparisonDirection::kLe:
265 case ComparisonDirection::kGe:
266 case ComparisonDirection::kGt:
267 case ComparisonDirection::kLt:
268 return nullopt;
269 default:
270 break;
271 }
272 }
273 }
274
275 // If we recognize the root's opcode, we've successfully pattern-matched!
276 switch (root->opcode()) {
277 case HloOpcode::kAdd:
278 return "add";
279 case HloOpcode::kMultiply:
280 return "multiply";
281 case HloOpcode::kMinimum:
282 return "min";
283 case HloOpcode::kMaximum:
284 return "max";
285 case HloOpcode::kCompare: {
286 switch (root->comparison_direction()) {
287 case ComparisonDirection::kLe:
288 return "less-or-equal";
289 case ComparisonDirection::kGe:
290 return "greater-or-equal";
291 case ComparisonDirection::kGt:
292 return "greater-than";
293 case ComparisonDirection::kLt:
294 return "less-than";
295 case ComparisonDirection::kEq:
296 return "equal-to";
297 case ComparisonDirection::kNe:
298 return "not-equal-to";
299 }
300 }
301 default:
302 return nullopt;
303 }
304 }
305
306 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
307 class HloDotDumper {
308 public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,bool show_backend_config,const HloExecutionProfile * profile,NodeFilter filter)309 HloDotDumper(const HloComputation* computation, absl::string_view label,
310 const DebugOptions& debug_options, bool show_backend_config,
311 const HloExecutionProfile* profile, NodeFilter filter)
312 : computation_(computation),
313 label_(label),
314 debug_options_(debug_options),
315 show_backend_config_(show_backend_config),
316 profile_(profile),
317 filter_(std::move(filter)) {}
318
319 string Dump();
320
321 private:
322 // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)323 string InstructionId(const HloInstruction* instruction) {
324 return StrCat(reinterpret_cast<uint64>(instruction));
325 }
326
327 // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)328 string SubcomputationId(const HloComputation* computation) {
329 return StrCat("cluster_", reinterpret_cast<uint64>(computation));
330 }
331
332 // Generates graph header/footer. These should be called *after* dumping all
333 // of the instructions and subcomputations for the graph, as they both use
334 // data generated while dumping the graph.
335 string Header();
336 string Footer();
337
338 bool ShouldShowSubcomputation(const HloComputation* subcomp);
339 bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
340
341 // We omit some nodes from the graph, instead drawing them inlined into the
342 // nodes that use them.
343 bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
344
345 string DumpSubcomputation(const HloComputation* subcomp,
346 const HloInstruction* parent_instr);
347 string DumpComputation(const HloComputation* comp);
348 string DumpRootTag();
349 string DumpInstruction(const HloInstruction* instr);
350 ColorScheme GetInstructionColor(const HloInstruction* instr);
351 string GetInstructionNodeShape(const HloInstruction* instr);
352 string GetInstructionNodeLabel(const HloInstruction* instr);
353 string GetInstructionNodeMetadata(const HloInstruction* instr);
354 string GetInstructionNodeBackendConfig(const HloInstruction* instr);
355 string GetInstructionNodeExtraInfo(const HloInstruction* instr);
356 string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
357 void AddInstructionIncomingEdges(const HloInstruction* instr);
358
359 // For most instructions, GetNodeForEdge(instr) returns instr.
360 //
361 // The exception is fusion nodes. For these, we walk up the chain of nested
362 // fusion nodes starting at instr until we reach a node that either (a) isn't
363 // a fusion node, or (b) is a fusion node for which
364 // ShouldShowFusionSubcomputation is false.
365 //
366 // We do this because fusion nodes are expanded inline -- if
367 // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
368 // the graph.
369 //
370 // In general when you want to draw an edge from A to B, you should actually
371 // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
372 const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
373
374 // If instr has just one computation and it's trivial (e.g. "return param0 +
375 // param1"), returns a string you can put into the node's body that names the
376 // subcomputation, e.g. "Subcomputation: <b>add</b>".
377 string GetInstructionTrivialComputationStr(const HloInstruction* instr);
378
379 const HloComputation* computation_; // never null
380 const string label_; // overall name for the graph
381 const DebugOptions& debug_options_;
382 const bool show_backend_config_;
383 const HloExecutionProfile* profile_; // may be null
384 const NodeFilter filter_;
385
386 // Each HloInstruction dumped gets a monotically-increasing node ID. This
387 // must start at 1, because that's where graphviz's accounting starts.
388 int64 next_node_id_ = 1;
389 absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
390
391 // The "root" tag doesn't have an associated HloInstruction pointer, so we
392 // need to store it outside the map.
393 int64 root_node_id_;
394
395 // Each (from, to) edge gets a monotonically-increasing ID. This is a
396 // multimap because it's possible for the same edge to appear multiple times
397 // in the graph (e.g. x^2 may be represented as mul(x, x)).
398 int64 next_edge_id_ = 1;
399 std::unordered_multimap<
400 std::pair<const HloInstruction*, const HloInstruction*>, int64,
401 tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
402 edge_ids_;
403
404 // Each HloComputation that's emitted gets a monotonically-increasing ID.
405 int64 next_cluster_id_ = 1;
406 absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
407
408 // Edges to print from Footer(). Edges come at the end because graphviz is
409 // unhappy if an edge from a subcomputation to a node in the outer computation
410 // appears before both the inner computation and the destination node are
411 // defined.
412 std::vector<string> edges_;
413
414 // When coloring by sharding information, we track the sharding string
415 // representation to color association, by round-robin the color schemes.
416 absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
417 sharding_colors_;
418 int64 next_shard_color_ = 0;
419 };
420
Dump()421 string HloDotDumper::Dump() {
422 string body;
423 StrAppend(&body, DumpComputation(computation_));
424 StrAppend(&body, DumpRootTag());
425
426 // By contract, Header() and Footer() have to be called after we've dumped all
427 // our instructions, because they use state generated during that process.
428 string g = Header();
429 StrAppend(&g, body);
430 StrAppend(&g, Footer());
431 return g;
432 }
433
Header()434 string HloDotDumper::Header() {
435 constexpr char fmt[] = R"(digraph G {
436 rankdir = TB;
437 compound = true;
438 label = <<b>%s</b>>;
439 labelloc = t;
440 // Disable the tooltip. Interestingly, "" doesn't work!
441 tooltip = " ";
442 // DOT graphs accept a stylesheet as a URI. So naturally, an inline
443 // stylesheet is a data URI!
444 stylesheet=<
445 data:text/css,
446 @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
447 svg text {
448 font-family: 'Roboto';
449 font-size: 12px;
450 }
451
452 %s
453 >
454
455 )";
456
457 VLOG(3) << "Generating Header";
458
459 string graph_label =
460 StrCat(label_, "<br/>Computation ", computation_->name());
461 if (computation_->IsFusionComputation()) {
462 StrAppend(&graph_label, " (in fusion instruction ",
463 computation_->FusionInstruction()->name(), ")");
464 }
465 if (profile_ != nullptr) {
466 auto cycles = profile_->total_cycles_executed(*computation_);
467 absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
468 tensorflow::strings::HumanReadableNum(cycles));
469 }
470
471 // Create CSS rules that say, when you hover over the given node or cluster,
472 // turn the given edge the given color.
473 //
474 // We rely on a few properties of how graphviz generates SVGs:
475 //
476 // - Nodes are named "nodeN", where N corresponds to the 1-based index of
477 // the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
478 // Edges are similarly named "edgeN", and clusters are named "clustN".
479 // - Nodes come before their in- and out-edges in the SVG. We need this
480 // because the "X ~ Y" CSS selector finds a sibling of X that *comes
481 // after X in the DOM* and matches Y.
482 std::vector<string> edge_css_rules;
483 const char* kBlue = "#1976d2";
484 const char* kRed = "#d32f2f";
485 for (const auto& kv : edge_ids_) {
486 const HloInstruction* from_node = kv.first.first;
487 const HloInstruction* to_node = kv.first.second;
488 int64 edge_id = kv.second;
489
490 auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
491 const char* color) {
492 // One could imagine other ways of writing this CSS rule that involve
493 // less duplication, but this way seems to be relatively performant.
494 edge_css_rules.push_back(
495 StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
496 " #%s%d:hover ~ #edge%d path { "
497 "stroke: %s; stroke-width: .2em; }\n"
498 " #%s%d:hover ~ #edge%d polygon { "
499 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
500 elem_type, elem_id, edge_id, color, //
501 elem_type, elem_id, edge_id, color, //
502 elem_type, elem_id, edge_id, color, color));
503 };
504
505 // The "to_node" value may be a NULL, indicating that this points to the
506 // "root" tag rather than a normal node.
507 int64 from_node_id =
508 tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
509 if (from_node_id == -1) {
510 LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
511 }
512 int64 to_node_id =
513 to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
514 : root_node_id_;
515 if (to_node != nullptr && to_node_id == -1) {
516 LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
517 }
518
519 add_hover_css_rule("node", from_node_id, kBlue);
520 add_hover_css_rule("node", to_node_id, kRed);
521
522 if (to_node) {
523 VLOG(3) << "Adding css for edge " << edge_id << " from node "
524 << from_node->name() << " to node " << to_node->name();
525 } else {
526 VLOG(3) << "Adding css for edge " << edge_id << " from node "
527 << from_node->name() << " to root tag";
528 }
529
530 // If this edge crosses a fusion cluster boundary, highlight it when the
531 // cluster is hovered over.
532 if (to_node) {
533 if (from_node->IsFused() &&
534 from_node->parent()->root_instruction() == from_node) {
535 int64 cluster_id = cluster_ids_.at(from_node->parent());
536 add_hover_css_rule("clust", cluster_id, kBlue);
537 }
538 if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
539 int64 cluster_id = cluster_ids_.at(to_node->parent());
540 add_hover_css_rule("clust", cluster_id, kRed);
541 }
542 }
543 }
544
545 // Browsers require that we URI-encode the contents of our data URI. (It
546 // seems this was a relatively recent change?) In practice, this means that we
547 // need to escape '#'.
548 return StrFormat(
549 fmt, graph_label,
550 absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
551 }
552
Footer()553 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
554
ShouldShowFusionSubcomputation(const HloInstruction * instr)555 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
556 CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
557 return ShouldShowSubcomputation(instr->fused_instructions_computation());
558 }
559
ShouldShowSubcomputation(const HloComputation * subcomp)560 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
561 if (subcomp->IsFusionComputation()) {
562 const HloInstruction* fusion = subcomp->FusionInstruction();
563 if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
564 return false;
565 }
566 }
567
568 // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
569 // into the graph.
570 if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
571 return false;
572 }
573
574 // Show the subcomputation if we're showing any of its members.
575 return absl::c_any_of(
576 subcomp->instructions(),
577 [&](const HloInstruction* instr) { return filter_.Show(instr); });
578 }
579
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)580 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
581 const HloInstruction* parent_instr) {
582 VLOG(2) << "Dumping subcomputation " << subcomp->name();
583 // Add an edge from the subcomputation to its parent node. If subcomp
584 // belongs to a fusion node, it's drawn in place of the fusion instruction,
585 // so there's no need to link those.
586 if (parent_instr->opcode() != HloOpcode::kFusion) {
587 const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
588 VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
589 << " as " << next_edge_id_;
590 edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
591 constexpr char edge_fmt[] =
592 R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
593 edges_.push_back(StrFormat(
594 edge_fmt, InstructionId(from), InstructionId(parent_instr),
595 SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
596 }
597
598 // Have we already dumped this subcomputation? If so, generating the edge
599 // linking it and parent_instr is all we want to do in this function.
600 if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
601 return "";
602 }
603
604 cluster_ids_[subcomp] = next_cluster_id_++;
605
606 string id = SubcomputationId(subcomp);
607
608 string subcomp_label, style;
609 if (parent_instr->opcode() == HloOpcode::kFusion) {
610 subcomp_label =
611 StrFormat("Fused expression for <b>%s</b><br/>%s",
612 HtmlLikeStringSanitize(parent_instr->name()),
613 HtmlLikeStringSanitize(parent_instr->ToCategory()));
614 string extra_info = GetInstructionNodeExtraInfo(parent_instr);
615 if (!extra_info.empty()) {
616 StrAppend(&subcomp_label, "<br/>", extra_info);
617 }
618 string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
619 if (!node_backend_config.empty()) {
620 StrAppend(&subcomp_label, "<br/>", node_backend_config);
621 }
622
623 bool highlight = filter_.Highlight(parent_instr);
624 const char* fillcolor;
625 const char* strokecolor;
626 if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
627 // Use the sharding color, if the node isn't highlighted.
628 NodeColors node_colors =
629 NodeColorsForScheme(GetInstructionColor(parent_instr));
630 fillcolor = node_colors.fill_color;
631 strokecolor = node_colors.stroke_color;
632 } else {
633 // Subcomputation's fill/stroke color is light/dark red/gray, depending on
634 // whether or not the subcomputation's fusion node is highlighted.
635 fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
636 strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
637 }
638 style =
639 StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
640 fillcolor, strokecolor);
641 } else {
642 subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
643 HtmlLikeStringSanitize(parent_instr->name()),
644 HtmlLikeStringSanitize(subcomp->name()));
645 style = "style=rounded; color=black;";
646 }
647
648 string comp_body = DumpComputation(subcomp);
649
650 constexpr char computation_fmt[] = R"(subgraph %s {
651 %s
652 label = <%s>;
653 labelloc = t;
654 tooltip = " ";
655 %s
656 } // %s
657
658 )";
659 return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
660 }
661
DumpComputation(const HloComputation * comp)662 string HloDotDumper::DumpComputation(const HloComputation* comp) {
663 string g;
664 for (const auto* instr : comp->instructions()) {
665 if (!filter_.Show(instr)) {
666 continue;
667 }
668
669 // Dump subcomputations within instr.
670 for (const HloComputation* subcomp : instr->called_computations()) {
671 if (ShouldShowSubcomputation(subcomp)) {
672 StrAppend(&g, DumpSubcomputation(subcomp, instr));
673 }
674 }
675
676 StrAppend(&g, DumpInstruction(instr));
677 }
678 return g;
679 }
680
DumpRootTag()681 string HloDotDumper::DumpRootTag() {
682 const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
683
684 // We didn't display constants as separate nodes; so if the root is a
685 // constant, we don't add root tag or edge for it.
686 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
687 return "";
688 }
689
690 auto from_id = InstructionId(from);
691
692 // The ID of the root computation is otherwise unused, so it makes a good ID
693 // to use for the root-tag node. However, the edge_ids_ map requires a
694 // HloInstruction* pointer for the 'to' value, so we use a NULL value there
695 // (rather than a pointer type-cast) to make it obvious if it is erroneously
696 // dereferenced.
697 HloInstruction* to = nullptr;
698 auto to_id = SubcomputationId(computation_);
699
700 string node_body = "ROOT";
701 string node_shape = "circle";
702 ColorScheme color = kBrown;
703
704 VLOG(2) << "Adding root tag as node " << next_node_id_;
705 root_node_id_ = next_node_id_++;
706
707 VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
708 << next_edge_id_;
709 edge_ids_.insert({{from, to}, next_edge_id_++});
710 edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
711
712 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
713 "\n",
714 to_id, node_body, node_shape, NodeColorAttributes(color));
715 }
716
TryGetFusionParameterConstant(const HloInstruction * instr)717 static const HloConstantInstruction* TryGetFusionParameterConstant(
718 const HloInstruction* instr) {
719 if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
720 return nullptr;
721 }
722 const HloInstruction* fusion = instr->parent()->FusionInstruction();
723 const HloInstruction* operand = fusion->operand(instr->parameter_number());
724 return DynCast<HloConstantInstruction>(operand);
725 }
726
ShouldMergeIntoUsers(const HloInstruction * instr) const727 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
728 // If a node:
729 //
730 // - is a parameter of a fusion node which is bound to a constant,
731 //
732 // or
733 //
734 // - is a tuple-shaped parameter, and
735 // - is not a parameter to a fusion node, and
736 // - has at least kMinUsersToOmit users shown, and
737 // - all of the shown users are get-tuple-elements,
738 //
739 // then we omit it from the graph, merging it with its users.
740 //
741 // This helps us handle the common case where a while loop body has one big
742 // tuple-shaped parameter.
743 if (TryGetFusionParameterConstant(instr) != nullptr) {
744 return true;
745 }
746 const int kMinUsersToOmit = 3;
747 return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
748 !instr->IsFused() &&
749 absl::c_count_if(instr->users(),
750 [&](const HloInstruction* user) {
751 return filter_.Show(user);
752 }) > kMinUsersToOmit &&
753 absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
754 return !filter_.Show(user) ||
755 user->opcode() == HloOpcode::kGetTupleElement;
756 });
757 }
758
DumpInstruction(const HloInstruction * instr)759 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
760 // We don't display constants as separate nodes; they're merged into their
761 // users.
762 if (instr->opcode() == HloOpcode::kConstant) {
763 return "";
764 }
765 // Skip this node if it's merged into its users.
766 if (ShouldMergeIntoUsers(instr)) {
767 return "";
768 }
769 // Omit the fusion node if its subcomputation is drawn, since the
770 // subcomputation will be drawn inline.
771 if (instr->opcode() == HloOpcode::kFusion &&
772 ShouldShowFusionSubcomputation(instr)) {
773 return "";
774 }
775
776 VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
777 node_ids_[instr] = next_node_id_++;
778
779 ColorScheme color = GetInstructionColor(instr);
780 string node_shape = GetInstructionNodeShape(instr);
781 string node_label = GetInstructionNodeLabel(instr);
782 string node_metadata = GetInstructionNodeMetadata(instr);
783 string node_backend_config = GetInstructionNodeBackendConfig(instr);
784 string extra_info = GetInstructionNodeExtraInfo(instr);
785 string inlined_constants = GetInstructionNodeInlinedOperands(instr);
786 string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
787 AddInstructionIncomingEdges(instr);
788
789 if (!debug_options_.xla_hlo_graph_sharding_color()) {
790 // Override the node's styling if it should be (de-)emphasized.
791 if (filter_.Deemphasized(instr)) {
792 color = kDashedBorder;
793 }
794 if (filter_.Highlight(instr)) {
795 node_shape = "diamond";
796 color = kDarkRed;
797 }
798 }
799 // Build the text that will be displayed inside the node.
800 string node_body = node_label;
801 for (const string& s : {trivial_subcomputation, node_backend_config,
802 extra_info, inlined_constants}) {
803 if (!s.empty()) {
804 StrAppend(&node_body, "<br/>", s);
805 }
806 }
807
808 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
809 "\n",
810 InstructionId(instr), node_body, node_shape, node_metadata,
811 NodeColorAttributes(color));
812 }
813
GetInstructionNodeInlinedOperands(const HloInstruction * instr)814 string HloDotDumper::GetInstructionNodeInlinedOperands(
815 const HloInstruction* instr) {
816 auto stringify_constant = [](const HloConstantInstruction* constant) {
817 const auto& shape = constant->shape();
818
819 // If the shape has a dimension of size zero, print it as e.g.
820 // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
821 // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
822 // is just noise.
823 if (ShapeUtil::IsZeroElementArray(shape)) {
824 return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
825 }
826
827 // Print the literal value of constants with <= K elements.
828 optional<int64> elem_count;
829 if (shape.IsArray()) {
830 elem_count = 1;
831 for (int64 dim : shape.dimensions()) {
832 *elem_count *= dim;
833 }
834 }
835 // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
836 // collected from profiling tools. Those constants may not have a valid
837 // literal.
838 if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
839 return constant->literal().ToString();
840 }
841
842 // Otherwise, print e.g. "%constant.42 (s32[100])".
843 string constant_name;
844 if (absl::StartsWith(constant->name(), "constant")) {
845 constant_name = constant->name();
846 } else {
847 constant_name = StrCat("constant ", constant->name());
848 }
849 return StrFormat("%s %s", constant_name,
850 ShapeUtil::HumanString(constant->shape()));
851 };
852
853 std::vector<string> lines;
854 for (int64 i = 0; i < instr->operand_count(); ++i) {
855 const HloInstruction* operand = instr->operand(i);
856 const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
857 optional<string> operand_str;
858 if (constant_operand != nullptr) {
859 operand_str = stringify_constant(constant_operand);
860 } else if (ShouldMergeIntoUsers(operand)) {
861 // Special case: If the operand is a parameter to a fusion node and it
862 // always has a constant value, display it like a regular constant.
863 //
864 // For other parameters, use the parameter number rather than the proper
865 // name, because that's generally how people think of the node.
866 if (operand->opcode() == HloOpcode::kParameter) {
867 if (const HloConstantInstruction* constant =
868 TryGetFusionParameterConstant(operand)) {
869 operand_str = stringify_constant(constant);
870 } else {
871 operand_str = StrFormat("Parameter %d", operand->parameter_number());
872 }
873 } else {
874 operand_str = operand->name();
875 }
876 }
877
878 if (operand_str) {
879 if (instr->operand_count() > 1) {
880 lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
881 } else {
882 lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
883 }
884 }
885 }
886 return StrJoin(lines, "<br/>");
887 }
888
// Picks the color scheme used to draw `instr`'s node.
//
// If xla_hlo_graph_sharding_color is set, the color encodes only the sharding:
// an instruction without a sharding gets kDashedBorder. Each distinct sharding
// gets the next enum value counting up from kBlue, wrapping before
// kDashedBorder. The choice is memoized in sharding_colors_, so equal
// shardings share a color.
//
// Otherwise the color is picked by opcode, except that instructions with a
// "real" parameter merged into them take the parameter color.
//
// Constants must never reach this function; they hit LOG(FATAL).
// NOTE(review): there is no return after the switch, so correctness relies on
// every HloOpcode having a case below (presumably enforced via -Wswitch).
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
  if (debug_options_.xla_hlo_graph_sharding_color()) {
    if (!instr->has_sharding()) {
      return kDashedBorder;
    }
    auto it = sharding_colors_.find(instr->sharding());
    if (it != sharding_colors_.end()) {
      return it->second;
    }
    // Cycle through the enum values kBlue .. kDashedBorder-1 in first-seen
    // order of shardings.
    ColorScheme color = static_cast<ColorScheme>(
        kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
    sharding_colors_.emplace(instr->sharding(), color);
    return color;
  }

  // Choose different weights of orange for small vs large parameters. This
  // distinction is often important, especially in fusion nodes.
  auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;

  // Special case: If this instruction has a parameter merged into it, paint it
  // the same color as a parameter. Unless the merged-in parameter is a
  // parameter to a fusion node that is bound to a constant -- these aren't
  // "real" parameters from the user's perspective.
  if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
        return operand->opcode() == HloOpcode::kParameter &&
               ShouldMergeIntoUsers(operand) &&
               TryGetFusionParameterConstant(operand) == nullptr;
      })) {
    return parameter_color;
  }

  // Pick different colors or shapes for instructions which are particularly
  // expensive (eg, dot) and those which are unusual in some way or unique
  // (eg, parameter).
  switch (instr->opcode()) {
    // Elementwise-style ops (plus a few cheap ops grouped with them).
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
      // De-emphasize scalar-shaped elementwise ops -- they're generally
      // uninteresting.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kYellow;
    // Bookkeeping ops that do no real computation.
    case HloOpcode::kBitcast:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTrace:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
    case HloOpcode::kTuple:
      return kWhite;
    case HloOpcode::kBroadcast:
      // De-emphasize nodes which broadcast a scalar within a fusion node --
      // these are essentially free.
      if (instr->IsFused() &&
          ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kConcatenate:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kTranspose:
      // De-emphasize scalar-shaped data movement ops and all data movement ops
      // inside fusion nodes, both of which are essentially free.
      if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kDynamicUpdateSlice:
      // Unlike the data-movement ops above, dynamic-update-slice is not ~free
      // inside of fusion nodes, so we de-emphasize it only if it's
      // scalar-shaped.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kScatter:
      // Do not de-emphasize Scatter, since it involves significant work.
      // (Intentionally falls through to share kCopy's color.)
    case HloOpcode::kCopy:
      // Emphasize copy nodes, which are either physical transposes (and thus
      // significant), or copies of read-only buffers (and thus dead weight).
      return kGreen;
    // Expensive compute ops.
    case HloOpcode::kConvolution:
    case HloOpcode::kDot:
    case HloOpcode::kFft:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      return kDarkBlue;
    case HloOpcode::kReducePrecision:
      return kRed;
    case HloOpcode::kParameter:
      return parameter_color;
    // Reductions and batch-norm ops.
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
      return kPurple;
    case HloOpcode::kDomain:
    case HloOpcode::kFusion:
    case HloOpcode::kMap:
    case HloOpcode::kGetDimensionSize:
      return kGray;
    // Communication / host-transfer ops.
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kReplicaId:
      return kBrown;
    // Control flow and opaque calls.
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
      return kDarkGreen;
    case HloOpcode::kConstant:
      LOG(FATAL) << "Constants don't get their own nodes in the graph.";
  }
}
1060
GetInstructionNodeShape(const HloInstruction * instr)1061 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1062 // Give while loops a different shape so they're easier to pick out.
1063 switch (instr->opcode()) {
1064 case HloOpcode::kWhile:
1065 return "ellipse";
1066 default:
1067 return "rect";
1068 }
1069 }
1070
GetInstructionNodeLabel(const HloInstruction * instr)1071 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1072 // If we have a parameter, put the param number in the name.
1073 if (instr->opcode() == HloOpcode::kParameter) {
1074 return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1075 }
1076
1077 // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1078 // an add instruction. In this case we render just the name.
1079 if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1080 return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1081 }
1082 string extended_opcode =
1083 StrCat(HloOpcodeString(instr->opcode()),
1084 instr->opcode() != HloOpcode::kFusion
1085 ? ""
1086 : StrCat(":", xla::ToString(instr->fusion_kind())));
1087 // If the name does not contain the opcode, render both.
1088 return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1089 HtmlLikeStringSanitize(instr->name()));
1090 }
1091
GetInstructionNodeMetadata(const HloInstruction * instr)1092 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1093 std::vector<string> lines;
1094 if (!instr->metadata().op_name().empty()) {
1095 lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1096 }
1097 if (!instr->metadata().op_type().empty()) {
1098 lines.push_back(StrFormat(
1099 "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1100 }
1101 if (!instr->metadata().source_file().empty() &&
1102 instr->metadata().source_line() != 0) {
1103 lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(),
1104 instr->metadata().source_line()));
1105 }
1106
1107 return StrJoin(lines, "\n");
1108 }
1109
GetInstructionNodeBackendConfig(const HloInstruction * instr)1110 string HloDotDumper::GetInstructionNodeBackendConfig(
1111 const HloInstruction* instr) {
1112 if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
1113 return "";
1114 }
1115
1116 return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1117 }
1118
GetInstructionNodeExtraInfo(const HloInstruction * instr)1119 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1120 std::vector<string> lines;
1121
1122 // Get the instruction's extra attributes excluding the names of its
1123 // subcomputations, since those are drawn explicitly in the graph.
1124 for (const auto& line : instr->ExtraAttributesToString(
1125 HloPrintOptions().set_print_subcomputation_mode(
1126 HloPrintOptions::PrintSubcomputationMode::kOff))) {
1127 lines.push_back(HtmlLikeStringSanitize(line));
1128 }
1129
1130 // Show the shape and layout of the instruction, unless it's an inlined fusion
1131 // node -- there the shape and layout is present in the output node.
1132 if (instr->opcode() != HloOpcode::kFusion ||
1133 !ShouldShowFusionSubcomputation(instr)) {
1134 // Show layout of instructions with more than one dimension. Don't show
1135 // layout on tuples or tensors with just one dimension (which only have one
1136 // possible layout) to avoid visual noise.
1137 bool shape_is_multidim = false;
1138 ShapeUtil::ForEachSubshape(instr->shape(),
1139 [&](const Shape& s, const ShapeIndex&) {
1140 shape_is_multidim |= s.dimensions_size() > 1;
1141 });
1142 string instr_shape;
1143 if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1144 instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1145 } else {
1146 instr_shape = ShapeUtil::HumanString(instr->shape());
1147 }
1148
1149 // Some instructions have giant tuples as their shapes, so truncate the
1150 // HLO's shape to kMaxShapeLen characters.
1151 constexpr int kMaxShapeLen = 64;
1152 if (instr_shape.length() > kMaxShapeLen) {
1153 instr_shape = StrCat(
1154 absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1155 }
1156 lines.push_back(instr_shape);
1157 }
1158 if (debug_options_.xla_hlo_graph_addresses()) {
1159 lines.push_back(StrFormat("[%p]", instr));
1160 }
1161 if (profile_ != nullptr) {
1162 double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1163 double total_cycles_executed =
1164 profile_->total_cycles_executed(*instr->parent());
1165 if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1166 lines.push_back(
1167 StrFormat("%% of cycles executed=%.2f",
1168 100 * hlo_cycles_executed / total_cycles_executed));
1169 }
1170 }
1171 return StrJoin(lines, "<br/>");
1172 }
1173
AddInstructionIncomingEdges(const HloInstruction * instr)1174 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1175 auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1176 int64 operand_num, bool control_edge = false) {
1177 from = GetNodeForEdge(from);
1178
1179 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1180 ShouldMergeIntoUsers(from)) {
1181 return;
1182 }
1183 VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1184 << " as " << next_edge_id_;
1185 edge_ids_.insert({{from, to}, next_edge_id_++});
1186
1187 string edge_label;
1188 if (instr->operand_count() > 1 && !control_edge) {
1189 edge_label =
1190 StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1191 } else if (control_edge) {
1192 edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1193 }
1194
1195 // We print "small" arrays using a hollow arrowhead and "large" arrays using
1196 // a filled arrowhead.
1197 constexpr char kEdgeFmt[] =
1198 R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1199 edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1200 (IsSmall(from) ? "empty" : "normal"),
1201 from->name(), to->name(), edge_label));
1202 };
1203
1204 // Add edges from instr's operands to instr. Parameters within fusion
1205 // expressions are handled specially -- we draw an edge from the corresponding
1206 // operand on the fusion node itself to the parameter.
1207 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1208 // Only add the edge if this is not the outermost computation; otherwise it
1209 // will lead from a node we're not drawing.
1210 if (instr->parent() != computation_) {
1211 const HloInstruction* fusion = instr->parent()->FusionInstruction();
1212 add_edge(fusion->operand(instr->parameter_number()), instr,
1213 /*operand_num=*/0);
1214 }
1215 } else {
1216 for (int64 i = 0; i < instr->operand_count(); ++i) {
1217 add_edge(instr->operand(i), instr, i);
1218 }
1219 for (const HloInstruction* pred : instr->control_predecessors()) {
1220 add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1221 }
1222 }
1223 }
1224
GetInstructionTrivialComputationStr(const HloInstruction * instr)1225 string HloDotDumper::GetInstructionTrivialComputationStr(
1226 const HloInstruction* instr) {
1227 // called_computations() on a fusion node "inherits" any called computations
1228 // of the fused root, which isn't what we want. Just ignore fusion nodes
1229 // here; they're handled separately.
1230 if (instr->opcode() == HloOpcode::kFusion) {
1231 return "";
1232 }
1233
1234 std::vector<string> lines;
1235 for (int64 i = 0; i < instr->called_computations().size(); ++i) {
1236 optional<string> computation_type =
1237 MatchTrivialComputation(instr->called_computations()[i]);
1238 if (!computation_type) {
1239 continue;
1240 }
1241 if (instr->called_computations().size() == 1) {
1242 lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1243 HtmlLikeStringSanitize(*computation_type)));
1244 } else {
1245 lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1246 HtmlLikeStringSanitize(*computation_type)));
1247 }
1248 }
1249 return StrJoin(lines, "<br/>");
1250 }
1251
// Maps `instr` to the node that edges from it should start at. A fusion node
// whose fused computation is drawn inline has no node of its own, so we
// descend to its fused root. The descent repeats for nested inline fusions.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;
  for (;;) {
    const bool is_inline_fusion = node->opcode() == HloOpcode::kFusion &&
                                  ShouldShowFusionSubcomputation(node);
    if (!is_inline_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1260
1261 // Gets a NodeFilter that includes roughly all instructions whose distance from
1262 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64 radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1263 NodeFilter MakeNodeRadiusAroundFilter(
1264 const HloInstruction* root, int64 radius,
1265 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1266 // First, find the neighborhood of nodes with distance from root <= radius.
1267 // These nodes are our initial set of "normal" nodes.
1268 absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1269 std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
1270 worklist.push_back({root, 0});
1271 while (!worklist.empty()) {
1272 const HloInstruction* instr;
1273 int64 depth;
1274 std::tie(instr, depth) = worklist.front();
1275 worklist.pop_front();
1276
1277 nodes[instr] = kNormalNode;
1278 if (depth == radius) {
1279 continue;
1280 }
1281 if (boundary.contains(instr)) {
1282 continue;
1283 }
1284
1285 // Traverse into instr's operands.
1286 //
1287 // Don't traverse into tuples' operands unless the tuple is the root.
1288 // Usually a tuple is the bottommost node in the graph, and so its operands
1289 // are not interesting to the graph at hand.
1290 if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1291 for (const HloInstruction* operand : instr->operands()) {
1292 if (!nodes.contains(operand)) {
1293 worklist.push_back({operand, depth + 1});
1294 }
1295 }
1296 }
1297
1298 // Traverse into instr's nested computations.
1299 for (const HloComputation* computation : instr->called_computations()) {
1300 worklist.push_back({computation->root_instruction(), depth + 1});
1301 }
1302
1303 // Traverse into instr's users, unless:
1304 //
1305 // - there are a ton of them, in which case they're probably not
1306 // interesting (and anyway, rendering them all would make the graph
1307 // unreadable), or
1308 // - instr is a constant, in which case its users are probably not
1309 // interesting.
1310 if (instr->opcode() == HloOpcode::kConstant) {
1311 continue;
1312 }
1313 constexpr int kMaxUsersToRender = 16;
1314 if (instr->user_count() > kMaxUsersToRender) {
1315 // If we're going to skip this node's users, style it as such.
1316 nodes[instr] = kSomeUsersOmitted;
1317 continue;
1318 }
1319 for (const HloInstruction* user : instr->users()) {
1320 if (!nodes.contains(user)) {
1321 worklist.push_back({user, depth + 1});
1322 }
1323 }
1324 }
1325
1326 auto is_displayed = [&](const HloInstruction* instr) {
1327 // Constants are displayed inline with their users; they're never omitted.
1328 // Nodes in subcomputations are always shown.
1329 return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1330 instr->parent() != root->parent();
1331 };
1332
1333 // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1334 // know which nodes will be included in the graph.
1335 for (auto& kv : nodes) {
1336 const HloInstruction* instr = kv.first;
1337 NodeFilterResult& filter_result = kv.second;
1338 const auto& operands = instr->operands();
1339
1340 if (absl::c_any_of(operands, is_displayed) &&
1341 !absl::c_all_of(operands, is_displayed)) {
1342 // Mark nodes with some operands omitted appropriately.
1343 filter_result = kSomeOperandsOmitted;
1344 } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1345 // Mark nodes with *all* operands omitted appropriately.
1346 filter_result = kOmitNodeOperands;
1347 }
1348
1349 // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1350 // users made it into the graph.
1351 if (filter_result == kSomeUsersOmitted &&
1352 absl::c_all_of(instr->users(), is_displayed)) {
1353 filter_result = kNormalNode;
1354 }
1355 }
1356
1357 // Highlight the root node.
1358 nodes[root] = kHighlightNode;
1359
1360 return NodeFilter([=](const HloInstruction* instr) {
1361 auto it = nodes.find(instr);
1362 if (it != nodes.end()) {
1363 return it->second;
1364 }
1365 // Show all nodes in subcomputations.
1366 if (instr->parent() != root->parent()) {
1367 return kNormalNode;
1368 }
1369 return kHideNode;
1370 });
1371 }
1372
1373 // Gets a node filter that includes nodes on all paths from `from` to `to`. If
1374 // the all-paths set contains more than max_nodes elements, includes the nodes
1375 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64 max_nodes,bool * hit_limit)1376 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1377 const HloInstruction* to, int64 max_nodes,
1378 bool* hit_limit) {
1379 *hit_limit = false;
1380
1381 // Elements in the queue are paths through the graph.
1382 std::deque<std::vector<const HloInstruction*>> queue;
1383 queue.push_front({from});
1384
1385 // Compute the set of nodes we want to show using a slightly-modified
1386 // Djikstra's algorithm. The only real difference is, rather than stopping
1387 // when we find a (shortest) path, we continue until we've found max_nodes
1388 // nodes on some path.
1389 std::unordered_set<const HloInstruction*> visited;
1390 std::unordered_set<const HloInstruction*> to_display = {from, to};
1391 while (!queue.empty() && to_display.size() < max_nodes) {
1392 std::vector<const HloInstruction*> path = std::move(queue.front());
1393 queue.pop_front();
1394 if (!visited.insert(path.back()).second) {
1395 continue;
1396 }
1397
1398 for (const auto* user : path.back()->users()) {
1399 if (user == to) {
1400 auto it = path.begin();
1401 for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1402 to_display.insert(*it);
1403 }
1404 if (it != path.end()) {
1405 *hit_limit = true;
1406 }
1407 } else if (!visited.count(user)) {
1408 auto new_path = path;
1409 new_path.push_back(user);
1410 queue.push_back(std::move(new_path));
1411 }
1412 }
1413 }
1414
1415 return NodeFilter([=](const HloInstruction* instr) {
1416 if (instr == from || instr == to) {
1417 return kHighlightNode;
1418 }
1419 return to_display.count(instr) ? kNormalNode : kHideNode;
1420 });
1421 }
1422
WrapDotInHtml(absl::string_view dot)1423 string WrapDotInHtml(absl::string_view dot) {
1424 static const char html_prefix[] = R"html(
1425 <!DOCTYPE html>
1426 <html>
1427 <head>
1428 <meta charset="utf-8">
1429 <style type="text/css">
1430 body {
1431 height: 100vh;
1432 margin: 0;
1433 }
1434 </style>
1435 </head>
1436 <body>
1437 <!-- Integrity hash is generated by https://www.srihash.org/ -->
1438 <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
1439 integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
1440 crossorigin="anonymous"></script>
1441 <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
1442 integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
1443 crossorigin="anonymous"></script>
1444 <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
1445 integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
1446 crossorigin="anonymous"></script>
1447 <div id="container" style="height:95vh; border:1px solid black; "></div>
1448 <script>
1449 var data = `
1450 )html";
1451
1452 static const char html_suffix[] = R"html(
1453 `;
1454 var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1455 var results = cssregex.exec(data)
1456 // graphviz has problem dealing with large stylesheets.
1457 // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1458 // In order to avoid the problem, remove the stylesheet from the dot and
1459 // insert it directly info the rendered SVG.
1460 var dot_data = data;
1461 var css_data = ''
1462 if (results !== null) {
1463 css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1464 // CSS inside DOT is URL-escaped, so we must unescape it
1465 // before we can insert it into SVG.
1466 css_data = unescape(css_data);
1467 dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1468 }
1469
1470 var render_start = performance.now()
1471 function add_controls(svg) {
1472 var htmlblob = new Blob([document.documentElement.innerHTML],
1473 {type: 'text/html'});
1474 var savehtml = document.createElement('a');
1475 savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1476 savehtml.setAttribute('download', 'graph.html');
1477 savehtml.innerHTML = " [Save HTML+SVG] ";
1478 document.body.append(savehtml);
1479 var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1480 var savesvg = document.createElement('a');
1481 savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1482 savesvg.setAttribute('download', 'graph.svg');
1483 savesvg.innerHTML = " [Save SVG] ";
1484 document.body.append(savesvg);
1485 var dotblob = new Blob([data], {type: 'text/dot'});
1486 var savedot = document.createElement('a');
1487 savedot.setAttribute('href', URL.createObjectURL(dotblob));
1488 savedot.setAttribute('download', 'graph.dot');
1489 savedot.innerHTML = " [Save DOT] ";
1490 document.body.append(savedot);
1491 // Will get called after embed element was loaded
1492 var panzoom = svgPanZoom(svg, {
1493 zoomEnabled: true,
1494 controlIconsEnabled: true,
1495 });
1496 document.getElementsByTagName("BODY")[0].onresize = function() {
1497 panzoom.resize();
1498 panzoom.fit();
1499 panzoom.center();
1500 };
1501 var render_end = performance.now();
1502 var render_note = document.createElement('div')
1503 render_note.innerHTML = 'Rendering took '
1504 + (render_end - render_start).toFixed(2) + "ms."
1505 document.body.append(render_note);
1506 }
1507 var svg = document.getElementById('graph')
1508 if (svg == null) {
1509 // Need to render SVG first.
1510 var viz = new Viz();
1511 viz.renderSVGElement(dot_data)
1512 .then(function(svg){
1513 var container = document.getElementById('container')
1514 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1515 var node = document.createTextNode(css_data);
1516 style.appendChild(node);
1517 svg.setAttribute('width', '100%');
1518 svg.setAttribute('height', '100%');
1519 svg.setAttribute('id', 'graph');
1520 svg.appendChild(style);
1521 container.appendChild(svg);
1522 add_controls(svg);
1523 })
1524 } else {
1525 // HTML already has rendered SVG embedded, so we just need to add
1526 // controls.
1527 add_controls(svg);
1528 }
1529 </script>
1530 </body>
1531 </html>
1532 )html";
1533
1534 return absl::StrCat(html_prefix, dot, html_suffix);
1535 }
1536
// Guards url_renderer. RegisterGraphToURLRenderer() and all three Render*
// entry points below take this lock.
tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
// Renderer used for RenderedGraphFormat::kUrl; nullptr until
// RegisterGraphToURLRenderer() installs one. It is heap-allocated, and the
// previous instance is deleted when a later registration replaces it.
std::function<StatusOr<string>(absl::string_view)>* url_renderer
    GUARDED_BY(url_renderer_mu) = nullptr;
1540
1541 // Precondition: url_renderer != nullptr.
1542 //
1543 // (We specify this as a precondition rather than checking it in here and
1544 // returning an error because we want to fail quickly when there's no URL
1545 // renderer available, and this function runs only after we've done all the work
1546 // of producing dot for the graph.)
1547 StatusOr<string> WrapDotInFormat(absl::string_view dot,
1548 RenderedGraphFormat format)
1549 EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1550 switch (format) {
1551 case RenderedGraphFormat::kUrl:
1552 CHECK(url_renderer != nullptr)
1553 << "Should have checked url_renderer != null before calling.";
1554 return (*url_renderer)(dot);
1555 case RenderedGraphFormat::kHtml:
1556 return WrapDotInHtml(dot);
1557 case RenderedGraphFormat::kDot:
1558 return string(dot);
1559 }
1560 }
1561
1562 } // namespace
1563
1564 void RegisterGraphToURLRenderer(
1565 std::function<StatusOr<string>(absl::string_view)> renderer) {
1566 tensorflow::mutex_lock lock(url_renderer_mu);
1567 if (url_renderer != nullptr) {
1568 LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call "
1569 "wins, but because order of initialization in C++ is "
1570 "nondeterministic, this may not be what you want.";
1571 }
1572 delete url_renderer;
1573 url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1574 std::move(renderer));
1575 }
1576
1577 StatusOr<string> RenderGraph(const HloComputation& computation,
1578 absl::string_view label,
1579 const DebugOptions& debug_options,
1580 RenderedGraphFormat format,
1581 const HloExecutionProfile* hlo_execution_profile,
1582 bool show_backend_config) {
1583 tensorflow::mutex_lock lock(url_renderer_mu);
1584 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1585 return Unavailable("Can't render as URL; no URL renderer was registered.");
1586 }
1587
1588 string rendered_dot =
1589 HloDotDumper(&computation, label, debug_options, show_backend_config,
1590 hlo_execution_profile, NodeFilter())
1591 .Dump();
1592 return WrapDotInFormat(rendered_dot, format);
1593 }
1594
1595 StatusOr<string> RenderNeighborhoodAround(
1596 const HloInstruction& node, int radius, RenderedGraphFormat format,
1597 bool show_backend_config,
1598 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1599 tensorflow::mutex_lock lock(url_renderer_mu);
1600 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1601 return FailedPrecondition(
1602 "Can't render as URL; no URL renderer was registered.");
1603 }
1604
1605 string label =
1606 StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1607 string rendered_dot =
1608 HloDotDumper(node.parent(), label,
1609 node.GetModule()->config().debug_options(),
1610 show_backend_config, /*profile=*/nullptr,
1611 MakeNodeRadiusAroundFilter(&node, radius, boundary))
1612 .Dump();
1613 return WrapDotInFormat(rendered_dot, format);
1614 }
1615
1616 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1617 const HloInstruction& to, int64 max_nodes,
1618 RenderedGraphFormat format,
1619 bool show_backend_config) {
1620 tensorflow::mutex_lock lock(url_renderer_mu);
1621 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1622 return FailedPrecondition(
1623 "Can't render as URL; no URL renderer was registered.");
1624 }
1625
1626 CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1627 auto debug_options = from.GetModule()->config().debug_options();
1628
1629 bool hit_limit = false;
1630 NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1631 string label;
1632 if (!hit_limit) {
1633 label = StrCat("All paths from ", from.name(), " to ", to.name());
1634 } else {
1635 label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1636 " to ", to.name(),
1637 "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1638 "NODES***<br/><br/>");
1639 }
1640 string rendered_dot =
1641 HloDotDumper(from.parent(), label, debug_options, show_backend_config,
1642 /*profile=*/nullptr, filter)
1643 .Dump();
1644 return WrapDotInFormat(rendered_dot, format);
1645 }
1646
1647 } // namespace xla
1648