1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/dump_graphviz.h"
16
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/strip.h"
26 #include "re2/re2.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/lite/toco/model_flags.pb.h"
29 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
30 #include "tensorflow/lite/toco/toco_port.h"
31 #include "tensorflow/lite/toco/toco_types.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33
34 using toco::port::AppendF;
35 using toco::port::StringF;
36
37 namespace toco {
38 namespace {
39
// Templates for the pieces of the emitted dot file. Each is expanded with
// AppendF using printf-style directives (%s, %d, %f) and ends with a newline.
//
40 // 'nslimit' is a graphviz (dot) parameter that limits the iterations during
41 // the layout phase. Omitting it allows infinite iterations, causing some
42 // complex graphs to never finish. A value of 125 produces good graphs
43 // while allowing complex graphs to finish.
// kGraphFmt opens the top-level digraph. %s: HTML-like graph label. The
// matching closing "}" is appended separately, after all nodes and edges.
44 constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
45 nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
46 )CODE";
47 // Note: tooltips are only supported on SVGs in Chrome.
// kSubgraphFmt opens a cluster subgraph. Arguments: cluster name suffix,
// background color ("#RRGGBB"), HTML-like label. The caller closes it.
48 constexpr char kSubgraphFmt[] =
49 R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
50 )CODE";
// kArrayNodeFmt declares an array node. Arguments: array name (node id),
// HTML-like label, tooltip text, shape, fill color, font color. The "DD"
// appended to the font color turns "#RRGGBB" into "#RRGGBBAA", i.e. sets the
// text's alpha.
51 constexpr char kArrayNodeFmt[] =
52 R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
53 )CODE";
// kOpNodeFmt declares an operator node. Arguments: op node id (unquoted),
// HTML-like label, fill color, font color (with "DD" alpha as above).
54 constexpr char kOpNodeFmt[] =
55 R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
56 )CODE";
// kInputEdgeFmt: edge from an array into op input port "i<N>", entering the
// port from its north side. Arguments: array name, compass-point suffix for
// the array end (may be empty), op id, input index, pen width, edge weight.
57 constexpr char kInputEdgeFmt[] =
58 R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
59 )CODE";
// kOutputEdgeFmt: edge from op output port "o<N>" (leaving its south side)
// to an array. Arguments: op id, output index, array name, compass-point
// suffix for the array end (may be empty), pen width, edge weight.
60 constexpr char kOutputEdgeFmt[] =
61 R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
62 )CODE";
// kRNNBackEdgeFmt: green edge from an RNN back-edge source array to the RNN
// state array it feeds. constraint=false keeps it from affecting ranking.
// Arguments: back-edge source array name, state array name.
63 constexpr char kRNNBackEdgeFmt[] =
64 R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false];
65 )CODE";
// U+00D7 MULTIPLICATION SIGN, separating dimensions and stride components.
66 constexpr char kUnicodeMult[] = "\u00D7";
// U+2026 HORIZONTAL ELLIPSIS (space-padded), eliding the middle of large
// buffer samples.
67 constexpr char kUnicodeEllipsis[] = " \u2026 ";
68
69 class Color {
70 public:
Color()71 Color() {}
Color(uint8 r,uint8 g,uint8 b)72 Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
Color(uint32 word)73 explicit Color(uint32 word)
74 : r_((word & 0x00FF0000) >> 16),
75 g_((word & 0x0000FF00) >> 8),
76 b_((word & 0x000000FF) >> 0) {}
77
78 // Returns the string serialization of this color in graphviz format,
79 // for use as 'fillcolor' in boxes.
AsHexString() const80 std::string AsHexString() const {
81 return StringF("#%.2X%.2X%.2X", r_, g_, b_);
82 }
83 // The color to use for this node; will be used as 'fillcolor'
84 // for its box. See Color::AsHexString. A suitable, different
85 // color will be chosen for the 'fontcolor' for the inside text
86 // label, see Color::TextColorString.
87 // Returns the serialization in graphviz format of a suitable color to use
88 // 'fontcolor' in the same boxes. It should black or white, whichever offers
89 // the better contrast from AsHexString().
TextColorString() const90 std::string TextColorString() const {
91 // https://en.wikipedia.org/wiki/Relative_luminance
92 const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
93 const uint8 l = luminance > 128.f ? 0 : 255;
94 return StringF("#%.2X%.2X%.2X", l, l, l);
95 }
96
97 private:
98 uint8 r_ = 0, g_ = 0, b_ = 0;
99 };
100
HashStringToColor(std::string s)101 Color HashStringToColor(std::string s) {
102 // Return a unique color for a name.
103 //
104 // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
105 // the string to a uint_32, then twiddles some bits to get a light and subtle
106 // color. This seems to be a good heuristic for keeping enough of the name to
107 // hash to a unique color while still revealing structure through naming
108 // similarities.
109 //
110 // The regular expression "_\d+" matches any underscore followed by numbers,
111 // which we strip out. Examples:
112 //
113 // "Conv" -> "Conv"
114 // "Conv_2" -> "Conv"
115 // "Conv_72" -> "Conv"
116 // "Pad_1_bias -> "Pad_bias"
117 // "Conv_abc" -> "Conv_abc"
118
119 RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
120 uint32 color_word = std::hash<std::string>{}(s);
121 color_word |= 0x00E0E0E0;
122 return Color(color_word);
123 }
124
GetArrayColorAndShape(const Model & model,const std::string & array_name,Color * color,std::string * shape)125 void GetArrayColorAndShape(const Model& model, const std::string& array_name,
126 Color* color, std::string* shape) {
127 // All colors in this file are from:
128 // https://material.io/guidelines/style/color.html
129 // Arrays involved in RNN back-edges have a different color
130 for (const auto& rnn_state : model.flags.rnn_states()) {
131 // RNN state, fed by a back-edge. Bold color.
132 if (array_name == rnn_state.state_array()) {
133 *color = Color(0x0F, 0x9D, 0x58);
134 *shape = "invhouse";
135 return;
136 }
137 // RNN back-edge source, feeding a RNN state.
138 // Light tone of the same color as RNN states.
139 if (array_name == rnn_state.back_edge_source_array()) {
140 *color = Color(0xB7, 0xE1, 0xCD);
141 *shape = "house";
142 return;
143 }
144 }
145 // Constant parameter arrays have their own bold color
146 if (model.GetArray(array_name).buffer) {
147 *color = Color(0x42, 0x85, 0xF4);
148 *shape = "cylinder";
149 return;
150 }
151 // Remaining arrays are activations.
152 // We use gray colors for them because they are the majority
153 // of arrays so we want to highlight other arrays instead of them.
154 // First, we use a bolder gray for input/output arrays:
155 if (IsInputArray(model, array_name)) {
156 *color = Color(0x9E, 0x9E, 0x9E);
157 *shape = "invhouse";
158 return;
159 }
160 if (IsOutputArray(model, array_name)) {
161 *color = Color(0x9E, 0x9E, 0x9E);
162 *shape = "house";
163 return;
164 }
165 // Remaining arrays are intermediate activation arrays.
166 // Lighter tone of the same grey as for input/output arrays:
167 // We want these to be very discrete.
168 *color = Color(0xF5, 0xF5, 0xF5);
169 *shape = "box";
170 }
171
GetArrayCompassPt(const Model & model,const std::string & array_name)172 std::string GetArrayCompassPt(const Model& model,
173 const std::string& array_name) {
174 // The "compass point" is the point on the node where edge connections are
175 // made. For most arrays we don't care, but input's and outputs look better
176 // connected at the tip of the "house" and "invhouse" shapes used. So we
177 // append ":n" and ":s" respectively for those.
178 for (const auto& rnn_state : model.flags.rnn_states()) {
179 // RNN state is essentially an input
180 if (array_name == rnn_state.state_array()) {
181 return ":s";
182 }
183 // RNN back-edge source is essentially an output
184 if (array_name == rnn_state.back_edge_source_array()) {
185 return ":n";
186 }
187 }
188 if (IsInputArray(model, array_name)) {
189 return ":s";
190 }
191 if (IsOutputArray(model, array_name)) {
192 return ":n";
193 }
194 return "";
195 }
196
AppendArrayVal(std::string * string,Array const & array,int index)197 void AppendArrayVal(std::string* string, Array const& array, int index) {
198 if (array.buffer->type == ArrayDataType::kFloat) {
199 const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
200 if (index >= data.size()) {
201 return;
202 }
203 AppendF(string, "%.3f", data[index]);
204 } else if (array.buffer->type == ArrayDataType::kUint8) {
205 const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
206 if (index >= data.size()) {
207 return;
208 }
209 AppendF(string, "%d", data[index]);
210 } else if (array.buffer->type == ArrayDataType::kInt16) {
211 const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
212 if (index >= data.size()) {
213 return;
214 }
215 AppendF(string, "%d", data[index]);
216 } else if (array.buffer->type == ArrayDataType::kInt32) {
217 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
218 if (index >= data.size()) {
219 return;
220 }
221 AppendF(string, "%d", data[index]);
222 } else if (array.buffer->type == ArrayDataType::kInt64) {
223 const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
224 if (index >= data.size()) {
225 return;
226 }
227 AppendF(string, "%d", data[index]);
228 } else if (array.buffer->type == ArrayDataType::kBool) {
229 const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
230 if (index >= data.size()) {
231 return;
232 }
233 AppendF(string, "%d", data[index]);
234 }
235 }
236
// Attribute name -> rendered value, shown as key/value rows in the HTML-like
// node labels. std::map keeps the rows sorted by attribute name.
using Attributes = std::map<std::string, std::string>;

// Renders |attributes| as consecutive two-column HTML table rows, in
// ascending key order. Each row is "name:" (right-aligned) followed by the
// value (left-aligned). Returns an empty string for an empty map.
// Keys and values are inserted verbatim, without HTML escaping.
std::string AttributesToHtml(const Attributes& attributes) {
  std::string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }
  return html;
}
250
GetArrayLabel(const Model & model,const std::string & array_id)251 std::string GetArrayLabel(const Model& model, const std::string& array_id) {
252 std::string html;
253
254 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
255 html += "<";
256
257 // Begin Table
258 html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
259 html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";
260
261 auto& array = model.GetArray(array_id);
262 if (array.buffer) {
263 // "cylinder" shapes require some extra head room.
264 html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
265 }
266
267 // "Primary" name of array (last non-slash delimited group of characters).
268 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
269 html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
270 AppendF(&html, R"CODE(%s)CODE",
271 std::vector<std::string>(absl::StrSplit(array_id, '/')).back());
272 html += R"CODE(</I></FONT>)CODE";
273 html += "</TD></TR>";
274
275 // Array data type and dimensions
276 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
277 html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
278 // Type
279 html += ArrayDataTypeName(array.data_type);
280 // Shape
281 if (array.has_shape()) {
282 auto& array_shape = array.shape();
283 html += "[";
284 for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
285 AppendF(&html, "%d", array_shape.dims(dim));
286 if (dim + 1 < array_shape.dimensions_count()) {
287 html += kUnicodeMult;
288 }
289 }
290 html += "]";
291 }
292
293 // Small buffer sample
294 int buffer_size = 0;
295 if (array.buffer) {
296 buffer_size = RequiredBufferSizeForShape(array.shape());
297 }
298 if ((buffer_size > 0) && (buffer_size <= 4)) {
299 html += " = ";
300 if (array.shape().dimensions_count() > 0) {
301 html += "{";
302 }
303 for (int i = 0; i < buffer_size; i++) {
304 AppendArrayVal(&html, array, i);
305 if (i + 1 < buffer_size) {
306 html += ", ";
307 }
308 }
309 if (array.shape().dimensions_count() > 0) {
310 html += "}";
311 }
312 }
313 html += R"CODE(</B></FONT>)CODE";
314 html += "</TD></TR>";
315
316 // Large buffer samples get their own line
317 if (buffer_size > 4) {
318 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
319 AppendArrayVal(&html, array, 0);
320 html += ", ";
321 AppendArrayVal(&html, array, 1);
322 html += kUnicodeEllipsis;
323 AppendArrayVal(&html, array, buffer_size - 2);
324 html += ", ";
325 AppendArrayVal(&html, array, buffer_size - 1);
326 html += "}</TD></TR>";
327 }
328
329 // Other array properties
330 Attributes attrs;
331 if (array.minmax) {
332 attrs["minmax"] =
333 StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
334 }
335 if (array.quantization_params) {
336 attrs["quant"] = StringF("%7g\u00B7(x-%d)", // Unicode "cdot"
337 array.quantization_params->scale,
338 array.quantization_params->zero_point);
339 }
340 if (array.alloc) {
341 attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
342 }
343 html += AttributesToHtml(attrs);
344
345 // output array_id in ultra-small font so it can be searched and copied.
346 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
347 html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
348 AppendF(&html, R"CODE("%s")CODE", array_id);
349 html += R"CODE(</FONT>)CODE";
350 html += "</TD></TR>";
351
352 // End Table and HTML-like label
353 html += R"CODE(</TABLE></FONT>)CODE";
354 html += ">";
355 return html;
356 }
357
GetOpAttributes(const Model & model,const Operator & op)358 Attributes GetOpAttributes(const Model& model, const Operator& op) {
359 Attributes attrs;
360 switch (op.fused_activation_function) {
361 case FusedActivationFunctionType::kRelu:
362 attrs["func"] = "ReLU";
363 break;
364 case FusedActivationFunctionType::kRelu6:
365 attrs["func"] = "ReLU6";
366 break;
367 case FusedActivationFunctionType::kRelu1:
368 attrs["func"] = "ReLU1";
369 break;
370 default:
371 break;
372 }
373 // Output state of member vars on derived operators.
374 switch (op.type) {
375 case OperatorType::kConv: {
376 const auto& conv_op = static_cast<const ConvOperator&>(op);
377 std::string stride;
378 AppendF(&stride, "%d", conv_op.stride_width);
379 stride += kUnicodeMult;
380 AppendF(&stride, "%d", conv_op.stride_height);
381 attrs["stride"] = stride;
382 attrs["padding"] =
383 (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
384 break;
385 }
386 case OperatorType::kDepthwiseConv: {
387 const auto& depthconv_op = static_cast<const ConvOperator&>(op);
388 std::string stride;
389 AppendF(&stride, "%d", depthconv_op.stride_width);
390 stride += kUnicodeMult;
391 AppendF(&stride, "%d", depthconv_op.stride_height);
392 attrs["stride"] = stride;
393 attrs["padding"] =
394 (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
395 break;
396 }
397 case OperatorType::kFakeQuant: {
398 const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
399 attrs["bits"] = StringF("%d", fakequant_op.num_bits);
400 if (fakequant_op.minmax) {
401 attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
402 fakequant_op.minmax->max);
403 } else {
404 attrs["range"] = "[?,?]";
405 }
406 break;
407 }
408 default:
409 break;
410 }
411 int64 math_ops_count;
412 if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
413 (math_ops_count != 0)) {
414 attrs["math"] = FormattedNumber(math_ops_count) + "ops";
415 }
416
417 return attrs;
418 }
419
GetOpColor(const Operator & op)420 Color GetOpColor(const Operator& op) {
421 if ((op.type == OperatorType::kDepthwiseConv) ||
422 (op.type == OperatorType::kConv) ||
423 (op.type == OperatorType::kFullyConnected) ||
424 (op.type == OperatorType::kFakeQuant)) {
425 // Give some ops a bolder red
426 return Color(0xC5, 0x39, 0x29);
427 } else {
428 return Color(0xDB, 0x44, 0x37);
429 }
430 }
431
GetOpLabel(const Model & model,const Operator & op)432 std::string GetOpLabel(const Model& model, const Operator& op) {
433 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
434 std::string html;
435 html += "<";
436
437 // Begin Table
438 html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
439 html +=
440 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
441
442 // Input Ports
443 if (!op.inputs.empty()) {
444 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
445 // Distribute evenly using a sub-table
446 html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
447 html += R"CODE(<TR>)CODE";
448 for (int i = 0; i < op.inputs.size(); i++) {
449 html += R"CODE(<TD PORT=")CODE";
450 AppendF(&html, "i%d", i);
451 html += R"CODE(">)CODE";
452 if (op.inputs.size() > 1) {
453 // Only number inputs when op has two or more inputs
454 AppendF(&html, "%d", i);
455 }
456 html += "</TD>";
457 }
458 html += "</TR>";
459 html += R"CODE(</TABLE></TD></TR>)CODE";
460 }
461
462 // Name
463 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
464 html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
465 if (op.type == OperatorType::kUnsupported) {
466 html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
467 } else {
468 html +=
469 std::string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
470 }
471 html += R"CODE(</B></FONT>)CODE";
472 html += "</TD></TR>";
473
474 // Attributes
475 Attributes attrs = GetOpAttributes(model, op);
476 html += AttributesToHtml(attrs);
477
478 // Output Ports
479 if (!op.outputs.empty()) {
480 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
481 // Distribute evenly using a sub-table
482 html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
483 html += R"CODE(<TR>)CODE";
484 for (int i = 0; i < op.outputs.size(); i++) {
485 html += R"CODE(<TD PORT=")CODE";
486 AppendF(&html, "o%d", i);
487 html += R"CODE(">)CODE";
488 if (op.outputs.size() > 1) {
489 // Only number outputs when op has two or more outputs
490 AppendF(&html, "%d", i);
491 }
492 html += "</TD>";
493 }
494 html += "</TR>";
495 html += R"CODE(</TABLE></TD></TR>)CODE";
496 }
497
498 // End Table and HTML-like label
499 html += R"CODE(</TABLE></FONT>)CODE";
500 html += ">";
501
502 return html;
503 }
504
GetLog2BufferSize(const Model & model,const std::string & array_id)505 float GetLog2BufferSize(const Model& model, const std::string& array_id) {
506 auto& array = model.GetArray(array_id);
507 if (array.has_shape()) {
508 int buffer_size = 0;
509 if (IsNonEmpty(array.shape())) {
510 buffer_size = RequiredBufferSizeForShape(array.shape());
511 return std::log2(static_cast<float>(buffer_size));
512 }
513 }
514 return 0.0f;
515 }
516
GetOpId(int op_index)517 std::string GetOpId(int op_index) { return StringF("op%05d", op_index); }
518
DumpOperator(const Model & model,std::string * output_file,int op_index)519 void DumpOperator(const Model& model, std::string* output_file, int op_index) {
520 // Dump node for operator.
521 const Operator& op = *model.operators[op_index];
522 Color color = GetOpColor(op);
523 std::string label = GetOpLabel(model, op);
524 std::string op_id = GetOpId(op_index);
525 AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
526 color.TextColorString());
527 }
528
DumpOperatorEdges(const Model & model,std::string * output_file,int op_index)529 void DumpOperatorEdges(const Model& model, std::string* output_file,
530 int op_index) {
531 // Inputs
532 const Operator& op = *model.operators[op_index];
533 std::string op_id = GetOpId(op_index);
534 for (int i = 0; i < op.inputs.size(); i++) {
535 const auto& input = op.inputs[i];
536 if (!model.HasArray(input)) {
537 // Connected arrays should _always_ exist. Except, perhaps, during
538 // development.
539 continue;
540 }
541 float log2_buffer_size = GetLog2BufferSize(model, input);
542 // Draw lines that transport more data thicker (Otherwise, where would the
543 // data fit? right?).
544 float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
545 // Keep edges that transport more data shorter than those with less.
546 float weight = std::max(1.0f, log2_buffer_size);
547 if (!IsInputArray(model, input) &&
548 GetOpWithOutput(model, input) == nullptr) {
549 // Give the main line of data flow a straighter path by penalizing edges
550 // to standalone buffers. Weights are generally very large buffers that
551 // would otherwise skew the layout.
552 weight = 1.0f;
553 }
554 std::string compass_pt = GetArrayCompassPt(model, input);
555 AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
556 weight);
557 }
558 // Outputs
559 for (int i = 0; i < op.outputs.size(); i++) {
560 const auto& output = op.outputs[i];
561 if (!model.HasArray(output)) {
562 continue;
563 }
564 float log2_buffer_size = GetLog2BufferSize(model, output);
565 // See comments above regarding weight and line_width calculations.
566 float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
567 float weight = std::max(1.0f, log2_buffer_size);
568 if (!IsArrayConsumed(model, output)) {
569 weight = 1.0f;
570 }
571 std::string compass_pt = GetArrayCompassPt(model, output);
572 AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
573 line_width, weight);
574 }
575 }
576
577 struct Node {
Nodetoco::__anona87f18ca0111::Node578 Node() : math_ops(0) {}
579 // Name used as a key in the model's array map
580 std::string array_id;
581
582 // Estimated number of math ops incurred by this node (the sum of the op
583 // with this array as 1st output, plus all children nodes).
584 int64 math_ops;
585
586 // A map of child nodes keyed by name.
587 std::map<const std::string, std::unique_ptr<Node>> children;
588 };
589
GetSubgraphLabel(Node const & node,const std::string & subgraph)590 std::string GetSubgraphLabel(Node const& node, const std::string& subgraph) {
591 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
592 std::string html;
593 html += "<";
594
595 // Begin Table
596 html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
597 html +=
598 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
599
600 // Name
601 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
602 html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
603 html += subgraph;
604 html += R"CODE(</I></FONT>)CODE";
605 html += "</TD></TR>";
606
607 // Attributes
608 Attributes attrs;
609 if (node.math_ops > 0) {
610 attrs["math"] = FormattedNumber(node.math_ops) + "ops";
611 }
612 html += AttributesToHtml(attrs);
613
614 // End Table and HTML-like label
615 html += R"CODE(</TABLE></FONT>)CODE";
616 html += ">";
617
618 return html;
619 }
620
DumpSubgraphHeader(std::string * output_file,Node const & node,const std::string & node_name)621 void DumpSubgraphHeader(std::string* output_file, Node const& node,
622 const std::string& node_name) {
623 Color color = HashStringToColor(node_name);
624 std::string label = GetSubgraphLabel(node, node_name);
625 AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
626 }
627
DumpArray(const Model & model,std::string * output_file,const std::string & array_id)628 void DumpArray(const Model& model, std::string* output_file,
629 const std::string& array_id) {
630 Color color;
631 std::string shape;
632 GetArrayColorAndShape(model, array_id, &color, &shape);
633 std::string label = GetArrayLabel(model, array_id);
634 AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
635 color.AsHexString(), color.TextColorString());
636
637 // Ops are placed in the same subgraph as their first output.
638 for (int op_index = 0; op_index < model.operators.size(); op_index++) {
639 const Operator& op = *model.operators[op_index];
640 if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
641 DumpOperator(model, output_file, op_index);
642 }
643 }
644 }
645
DumpNode(const Model & model,std::string * output_file,const std::string & node_name,Node const & node)646 void DumpNode(const Model& model, std::string* output_file,
647 const std::string& node_name, Node const& node) {
648 bool not_root = !node_name.empty();
649 if (not_root) {
650 DumpSubgraphHeader(output_file, node, node_name);
651 }
652
653 for (const auto& child : node.children) {
654 if (!child.second->array_id.empty()) {
655 // Dump array if this node possesses one.
656 DumpArray(model, output_file, child.second->array_id);
657 }
658 // Note that it is always possible to have children. Unlike a filesystem,
659 // the existence of array "foo/bar" does _not_ prevent other arrays, such as
660 // and "foo/bar/baz", from being nested beneath it.
661 DumpNode(model, output_file, child.first, *child.second);
662 }
663
664 if (not_root) {
665 // End subgraph
666 AppendF(output_file, " }\n");
667 }
668 }
669
GetArithmeticOpsCount(const Model & model,const std::string & array_id)670 int64 GetArithmeticOpsCount(const Model& model, const std::string& array_id) {
671 for (const auto& op : model.operators) {
672 if (!op->outputs.empty() && op->outputs[0] == array_id) {
673 int64 count;
674 if (EstimateArithmeticOpsCount(model, *op, &count)) {
675 return count;
676 } else {
677 return 0;
678 }
679 }
680 }
681 return 0;
682 }
683
InsertNode(const Model & model,const std::string & array_id,Node * node,std::vector<std::string> prefixes,int64 * math_ops)684 void InsertNode(const Model& model, const std::string& array_id, Node* node,
685 std::vector<std::string> prefixes, int64* math_ops) {
686 if (prefixes.empty()) {
687 // Base case: store array in this node.
688 node->array_id = array_id;
689 *math_ops = GetArithmeticOpsCount(model, array_id);
690 } else {
691 // Insert into the sub-tree for that prefix.
692 std::string prefix = prefixes.back();
693 prefixes.pop_back();
694 if (node->children.count(prefix) == 0) {
695 // Create a new node if this prefix is unseen.
696 node->children[prefix] = absl::make_unique<Node>();
697 }
698 InsertNode(model, array_id, node->children[prefix].get(), prefixes,
699 math_ops);
700 }
701 // Sum estimated math ops into all nodes.
702 node->math_ops += *math_ops;
703 }
704
BuildArrayTree(const Model & model,Node * tree)705 void BuildArrayTree(const Model& model, Node* tree) {
706 // Delimit array names by path "/", then place into a tree based on this path.
707 for (const auto& array_id : model.GetArrayMap()) {
708 std::vector<std::string> prefixes = absl::StrSplit(array_id.first, '/');
709 std::reverse(prefixes.begin(), prefixes.end());
710 int64 math_ops; // Temporary storage for math ops used during recursion.
711 InsertNode(model, array_id.first, tree, prefixes, &math_ops);
712 }
713 }
714
GetGraphLabel(const Model & model,const std::string & graph_name)715 std::string GetGraphLabel(const Model& model, const std::string& graph_name) {
716 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
717 std::string html;
718 html += "<";
719
720 // Begin Table
721 html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
722 html +=
723 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
724
725 // Name
726 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
727 html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
728 html += graph_name;
729 html += R"CODE(</I></B></FONT>)CODE";
730 html += "</TD></TR>";
731
732 // Attributes
733 Attributes attrs;
734 attrs["arrays"] = StringF("%d", model.GetArrayMap().size());
735 if (!model.optional_arrays.empty()) {
736 attrs["optional arrays"] = StringF("%d", model.optional_arrays.size());
737 }
738 attrs["operators"] = StringF("%d", model.operators.size());
739 int64 ops_count;
740 if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
741 attrs["math"] = FormattedNumber(ops_count) + "ops";
742 }
743 if (model.transient_data_size > 0) {
744 attrs["transient data size"] =
745 StringF("%d KiB", model.transient_data_size / 1024);
746 }
747 if (model.transient_data_alignment > 0) {
748 attrs["transient data alignment"] =
749 StringF("%d bytes", model.transient_data_alignment);
750 }
751 html += AttributesToHtml(attrs);
752
753 // End Table and HTML-like label
754 html += R"CODE(</TABLE></FONT>)CODE";
755 html += ">";
756
757 return html;
758 }
759 } // namespace
760
DumpGraphviz(const Model & model,std::string * output_file,const std::string & graph_name)761 void DumpGraphviz(const Model& model, std::string* output_file,
762 const std::string& graph_name) {
763 // Start graphviz format
764 AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
765
766 // Organize arrays into a tree for subgraphing
767 Node tree;
768 BuildArrayTree(model, &tree);
769 DumpNode(model, output_file, "", tree);
770
771 // Dump edges outside all subgraphs (otherwise the referred-to nodes are
772 // implicitly included in that subgraph).
773 for (int op_index = 0; op_index < model.operators.size(); op_index++) {
774 DumpOperatorEdges(model, output_file, op_index);
775 }
776
777 // Dump RNN Backedges
778 for (const auto& rnn_state : model.flags.rnn_states()) {
779 AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
780 rnn_state.state_array());
781 }
782 // End graphviz format
783 AppendF(output_file, "}\n");
784 }
785 } // namespace toco
786