1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/dump_graphviz.h"
16 
17 #include <cmath>
18 #include <functional>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/strip.h"
26 #include "re2/re2.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/lite/toco/model_flags.pb.h"
29 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
30 #include "tensorflow/lite/toco/toco_port.h"
31 #include "tensorflow/lite/toco/toco_types.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33 
34 using toco::port::AppendF;
35 using toco::port::StringF;
36 
37 namespace toco {
38 namespace {
39 
// printf-style templates for each piece of the emitted graphviz (dot) file.
// They are expanded with AppendF by the Dump* functions in this file.

// Graph header; the single %s is the graph label.
//
// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
    nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Subgraph (cluster) header. Arguments: cluster name, background color,
// label. The matching closing brace is emitted separately by DumpNode.
// Note: tooltips are only supported on SVGs in Chrome.
constexpr char kSubgraphFmt[] =
    R"CODE(    subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
// Array node. Arguments: node id, label, tooltip, shape, fill color, text
// color. The literal "DD" is appended directly after the text color string.
constexpr char kArrayNodeFmt[] =
    R"CODE(        "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Operator node. Arguments: node id, label, fill color, text color (again
// followed by the literal "DD").
constexpr char kOpNodeFmt[] =
    R"CODE(        %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Edge from an array into the north side of an operator input port.
// Arguments: array id, array compass point, op id, input index, pen width,
// edge weight.
constexpr char kInputEdgeFmt[] =
    R"CODE(        "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
// Edge from the south side of an operator output port to an array.
// Arguments: op id, output index, array id, array compass point, pen width,
// edge weight.
constexpr char kOutputEdgeFmt[] =
    R"CODE(        %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
// RNN back-edge, drawn with constraint=false. Arguments: source array id,
// destination (state) array id.
constexpr char kRNNBackEdgeFmt[] =
    R"CODE(        "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
// U+00D7 MULTIPLICATION SIGN, used between dimensions and strides.
constexpr char kUnicodeMult[] = "\u00D7";
// U+2026 HORIZONTAL ELLIPSIS padded with spaces, used in buffer samples.
constexpr char kUnicodeEllipsis[] = " \u2026 ";
68 
69 class Color {
70  public:
Color()71   Color() {}
Color(uint8 r,uint8 g,uint8 b)72   Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
Color(uint32 word)73   explicit Color(uint32 word)
74       : r_((word & 0x00FF0000) >> 16),
75         g_((word & 0x0000FF00) >> 8),
76         b_((word & 0x000000FF) >> 0) {}
77 
78   // Returns the string serialization of this color in graphviz format,
79   // for use as 'fillcolor' in boxes.
AsHexString() const80   std::string AsHexString() const {
81     return StringF("#%.2X%.2X%.2X", r_, g_, b_);
82   }
83   // The color to use for this node; will be used as 'fillcolor'
84   // for its box. See Color::AsHexString. A suitable, different
85   // color will be chosen for the 'fontcolor' for the inside text
86   // label, see Color::TextColorString.
87   // Returns the serialization in graphviz format of a suitable color to use
88   // 'fontcolor' in the same boxes. It should black or white, whichever offers
89   // the better contrast from AsHexString().
TextColorString() const90   std::string TextColorString() const {
91     // https://en.wikipedia.org/wiki/Relative_luminance
92     const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
93     const uint8 l = luminance > 128.f ? 0 : 255;
94     return StringF("#%.2X%.2X%.2X", l, l, l);
95   }
96 
97  private:
98   uint8 r_ = 0, g_ = 0, b_ = 0;
99 };
100 
HashStringToColor(std::string s)101 Color HashStringToColor(std::string s) {
102   // Return a unique color for a name.
103   //
104   // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
105   // the string to a uint_32, then twiddles some bits to get a light and subtle
106   // color. This seems to be a good heuristic for keeping enough of the name to
107   // hash to a unique color while still revealing structure through naming
108   // similarities.
109   //
110   // The regular expression "_\d+" matches any underscore followed by numbers,
111   // which we strip out. Examples:
112   //
113   //     "Conv"      -> "Conv"
114   //     "Conv_2"    -> "Conv"
115   //     "Conv_72"   -> "Conv"
116   //     "Pad_1_bias -> "Pad_bias"
117   //     "Conv_abc"  -> "Conv_abc"
118 
119   RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
120   uint32 color_word = std::hash<std::string>{}(s);
121   color_word |= 0x00E0E0E0;
122   return Color(color_word);
123 }
124 
GetArrayColorAndShape(const Model & model,const std::string & array_name,Color * color,std::string * shape)125 void GetArrayColorAndShape(const Model& model, const std::string& array_name,
126                            Color* color, std::string* shape) {
127   // All colors in this file are from:
128   // https://material.io/guidelines/style/color.html
129   // Arrays involved in RNN back-edges have a different color
130   for (const auto& rnn_state : model.flags.rnn_states()) {
131     // RNN state, fed by a back-edge. Bold color.
132     if (array_name == rnn_state.state_array()) {
133       *color = Color(0x0F, 0x9D, 0x58);
134       *shape = "invhouse";
135       return;
136     }
137     // RNN back-edge source, feeding a RNN state.
138     // Light tone of the same color as RNN states.
139     if (array_name == rnn_state.back_edge_source_array()) {
140       *color = Color(0xB7, 0xE1, 0xCD);
141       *shape = "house";
142       return;
143     }
144   }
145   // Constant parameter arrays have their own bold color
146   if (model.GetArray(array_name).buffer) {
147     *color = Color(0x42, 0x85, 0xF4);
148     *shape = "cylinder";
149     return;
150   }
151   // Remaining arrays are activations.
152   // We use gray colors for them because they are the majority
153   // of arrays so we want to highlight other arrays instead of them.
154   // First, we use a bolder gray for input/output arrays:
155   if (IsInputArray(model, array_name)) {
156     *color = Color(0x9E, 0x9E, 0x9E);
157     *shape = "invhouse";
158     return;
159   }
160   if (IsOutputArray(model, array_name)) {
161     *color = Color(0x9E, 0x9E, 0x9E);
162     *shape = "house";
163     return;
164   }
165   // Remaining arrays are intermediate activation arrays.
166   // Lighter tone of the same grey as for input/output arrays:
167   // We want these to be very discrete.
168   *color = Color(0xF5, 0xF5, 0xF5);
169   *shape = "box";
170 }
171 
GetArrayCompassPt(const Model & model,const std::string & array_name)172 std::string GetArrayCompassPt(const Model& model,
173                               const std::string& array_name) {
174   // The "compass point" is the point on the node where edge connections are
175   // made. For most arrays we don't care, but input's and outputs look better
176   // connected at the tip of the "house" and "invhouse" shapes used. So we
177   // append ":n" and ":s" respectively for those.
178   for (const auto& rnn_state : model.flags.rnn_states()) {
179     // RNN state is essentially an input
180     if (array_name == rnn_state.state_array()) {
181       return ":s";
182     }
183     // RNN back-edge source is essentially an output
184     if (array_name == rnn_state.back_edge_source_array()) {
185       return ":n";
186     }
187   }
188   if (IsInputArray(model, array_name)) {
189     return ":s";
190   }
191   if (IsOutputArray(model, array_name)) {
192     return ":n";
193   }
194   return "";
195 }
196 
AppendArrayVal(std::string * string,Array const & array,int index)197 void AppendArrayVal(std::string* string, Array const& array, int index) {
198   if (array.buffer->type == ArrayDataType::kFloat) {
199     const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
200     if (index >= data.size()) {
201       return;
202     }
203     AppendF(string, "%.3f", data[index]);
204   } else if (array.buffer->type == ArrayDataType::kUint8) {
205     const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
206     if (index >= data.size()) {
207       return;
208     }
209     AppendF(string, "%d", data[index]);
210   } else if (array.buffer->type == ArrayDataType::kInt16) {
211     const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
212     if (index >= data.size()) {
213       return;
214     }
215     AppendF(string, "%d", data[index]);
216   } else if (array.buffer->type == ArrayDataType::kInt32) {
217     const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
218     if (index >= data.size()) {
219       return;
220     }
221     AppendF(string, "%d", data[index]);
222   } else if (array.buffer->type == ArrayDataType::kInt64) {
223     const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
224     if (index >= data.size()) {
225       return;
226     }
227     AppendF(string, "%d", data[index]);
228   } else if (array.buffer->type == ArrayDataType::kBool) {
229     const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
230     if (index >= data.size()) {
231       return;
232     }
233     AppendF(string, "%d", data[index]);
234   }
235 }
236 
// Name -> value pairs shown as extra rows in a node label. Being a std::map,
// the entries are iterated in ascending key order.
using Attributes = std::map<std::string, std::string>;

// Renders `attributes` as consecutive two-cell HTML table rows of the form
// "<name>:" (right-aligned) and "<value>" (left-aligned). Returns an empty
// string when there are no attributes. Names and values are inserted
// verbatim, without any HTML escaping.
//
// The map is taken by const reference so that callers' maps are not copied.
std::string AttributesToHtml(const Attributes& attributes) {
  std::string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }
  return html;
}
250 
GetArrayLabel(const Model & model,const std::string & array_id)251 std::string GetArrayLabel(const Model& model, const std::string& array_id) {
252   std::string html;
253 
254   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
255   html += "<";
256 
257   // Begin Table
258   html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
259   html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";
260 
261   auto& array = model.GetArray(array_id);
262   if (array.buffer) {
263     // "cylinder" shapes require some extra head room.
264     html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
265   }
266 
267   // "Primary" name of array (last non-slash delimited group of characters).
268   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
269   html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
270   AppendF(&html, R"CODE(%s)CODE",
271           std::vector<std::string>(absl::StrSplit(array_id, '/')).back());
272   html += R"CODE(</I></FONT>)CODE";
273   html += "</TD></TR>";
274 
275   // Array data type and dimensions
276   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
277   html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
278   // Type
279   html += ArrayDataTypeName(array.data_type);
280   // Shape
281   if (array.has_shape()) {
282     auto& array_shape = array.shape();
283     html += "[";
284     for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
285       AppendF(&html, "%d", array_shape.dims(dim));
286       if (dim + 1 < array_shape.dimensions_count()) {
287         html += kUnicodeMult;
288       }
289     }
290     html += "]";
291   }
292 
293   // Small buffer sample
294   int buffer_size = 0;
295   if (array.buffer) {
296     buffer_size = RequiredBufferSizeForShape(array.shape());
297   }
298   if ((buffer_size > 0) && (buffer_size <= 4)) {
299     html += " = ";
300     if (array.shape().dimensions_count() > 0) {
301       html += "{";
302     }
303     for (int i = 0; i < buffer_size; i++) {
304       AppendArrayVal(&html, array, i);
305       if (i + 1 < buffer_size) {
306         html += ", ";
307       }
308     }
309     if (array.shape().dimensions_count() > 0) {
310       html += "}";
311     }
312   }
313   html += R"CODE(</B></FONT>)CODE";
314   html += "</TD></TR>";
315 
316   // Large buffer samples get their own line
317   if (buffer_size > 4) {
318     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
319     AppendArrayVal(&html, array, 0);
320     html += ", ";
321     AppendArrayVal(&html, array, 1);
322     html += kUnicodeEllipsis;
323     AppendArrayVal(&html, array, buffer_size - 2);
324     html += ", ";
325     AppendArrayVal(&html, array, buffer_size - 1);
326     html += "}</TD></TR>";
327   }
328 
329   // Other array properties
330   Attributes attrs;
331   if (array.minmax) {
332     attrs["minmax"] =
333         StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
334   }
335   if (array.quantization_params) {
336     attrs["quant"] = StringF("%7g\u00B7(x-%d)",  // Unicode "cdot"
337                              array.quantization_params->scale,
338                              array.quantization_params->zero_point);
339   }
340   if (array.alloc) {
341     attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
342   }
343   html += AttributesToHtml(attrs);
344 
345   // output array_id in ultra-small font so it can be searched and copied.
346   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
347   html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
348   AppendF(&html, R"CODE("%s")CODE", array_id);
349   html += R"CODE(</FONT>)CODE";
350   html += "</TD></TR>";
351 
352   // End Table and HTML-like label
353   html += R"CODE(</TABLE></FONT>)CODE";
354   html += ">";
355   return html;
356 }
357 
GetOpAttributes(const Model & model,const Operator & op)358 Attributes GetOpAttributes(const Model& model, const Operator& op) {
359   Attributes attrs;
360   switch (op.fused_activation_function) {
361     case FusedActivationFunctionType::kRelu:
362       attrs["func"] = "ReLU";
363       break;
364     case FusedActivationFunctionType::kRelu6:
365       attrs["func"] = "ReLU6";
366       break;
367     case FusedActivationFunctionType::kRelu1:
368       attrs["func"] = "ReLU1";
369       break;
370     default:
371       break;
372   }
373   // Output state of member vars on derived operators.
374   switch (op.type) {
375     case OperatorType::kConv: {
376       const auto& conv_op = static_cast<const ConvOperator&>(op);
377       std::string stride;
378       AppendF(&stride, "%d", conv_op.stride_width);
379       stride += kUnicodeMult;
380       AppendF(&stride, "%d", conv_op.stride_height);
381       attrs["stride"] = stride;
382       attrs["padding"] =
383           (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
384       break;
385     }
386     case OperatorType::kDepthwiseConv: {
387       const auto& depthconv_op = static_cast<const ConvOperator&>(op);
388       std::string stride;
389       AppendF(&stride, "%d", depthconv_op.stride_width);
390       stride += kUnicodeMult;
391       AppendF(&stride, "%d", depthconv_op.stride_height);
392       attrs["stride"] = stride;
393       attrs["padding"] =
394           (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
395       break;
396     }
397     case OperatorType::kFakeQuant: {
398       const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
399       attrs["bits"] = StringF("%d", fakequant_op.num_bits);
400       if (fakequant_op.minmax) {
401         attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
402                                  fakequant_op.minmax->max);
403       } else {
404         attrs["range"] = "[?,?]";
405       }
406       break;
407     }
408     default:
409       break;
410   }
411   int64 math_ops_count;
412   if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
413       (math_ops_count != 0)) {
414     attrs["math"] = FormattedNumber(math_ops_count) + "ops";
415   }
416 
417   return attrs;
418 }
419 
GetOpColor(const Operator & op)420 Color GetOpColor(const Operator& op) {
421   if ((op.type == OperatorType::kDepthwiseConv) ||
422       (op.type == OperatorType::kConv) ||
423       (op.type == OperatorType::kFullyConnected) ||
424       (op.type == OperatorType::kFakeQuant)) {
425     // Give some ops a bolder red
426     return Color(0xC5, 0x39, 0x29);
427   } else {
428     return Color(0xDB, 0x44, 0x37);
429   }
430 }
431 
GetOpLabel(const Model & model,const Operator & op)432 std::string GetOpLabel(const Model& model, const Operator& op) {
433   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
434   std::string html;
435   html += "<";
436 
437   // Begin Table
438   html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
439   html +=
440       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
441 
442   // Input Ports
443   if (!op.inputs.empty()) {
444     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
445     // Distribute evenly using a sub-table
446     html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
447     html += R"CODE(<TR>)CODE";
448     for (int i = 0; i < op.inputs.size(); i++) {
449       html += R"CODE(<TD PORT=")CODE";
450       AppendF(&html, "i%d", i);
451       html += R"CODE(">)CODE";
452       if (op.inputs.size() > 1) {
453         // Only number inputs when op has two or more inputs
454         AppendF(&html, "%d", i);
455       }
456       html += "</TD>";
457     }
458     html += "</TR>";
459     html += R"CODE(</TABLE></TD></TR>)CODE";
460   }
461 
462   // Name
463   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
464   html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
465   if (op.type == OperatorType::kUnsupported) {
466     html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
467   } else {
468     html +=
469         std::string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
470   }
471   html += R"CODE(</B></FONT>)CODE";
472   html += "</TD></TR>";
473 
474   // Attributes
475   Attributes attrs = GetOpAttributes(model, op);
476   html += AttributesToHtml(attrs);
477 
478   // Output Ports
479   if (!op.outputs.empty()) {
480     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
481     // Distribute evenly using a sub-table
482     html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
483     html += R"CODE(<TR>)CODE";
484     for (int i = 0; i < op.outputs.size(); i++) {
485       html += R"CODE(<TD PORT=")CODE";
486       AppendF(&html, "o%d", i);
487       html += R"CODE(">)CODE";
488       if (op.outputs.size() > 1) {
489         // Only number outputs when op has two or more outputs
490         AppendF(&html, "%d", i);
491       }
492       html += "</TD>";
493     }
494     html += "</TR>";
495     html += R"CODE(</TABLE></TD></TR>)CODE";
496   }
497 
498   // End Table and HTML-like label
499   html += R"CODE(</TABLE></FONT>)CODE";
500   html += ">";
501 
502   return html;
503 }
504 
GetLog2BufferSize(const Model & model,const std::string & array_id)505 float GetLog2BufferSize(const Model& model, const std::string& array_id) {
506   auto& array = model.GetArray(array_id);
507   if (array.has_shape()) {
508     int buffer_size = 0;
509     if (IsNonEmpty(array.shape())) {
510       buffer_size = RequiredBufferSizeForShape(array.shape());
511       return std::log2(static_cast<float>(buffer_size));
512     }
513   }
514   return 0.0f;
515 }
516 
GetOpId(int op_index)517 std::string GetOpId(int op_index) { return StringF("op%05d", op_index); }
518 
DumpOperator(const Model & model,std::string * output_file,int op_index)519 void DumpOperator(const Model& model, std::string* output_file, int op_index) {
520   // Dump node for operator.
521   const Operator& op = *model.operators[op_index];
522   Color color = GetOpColor(op);
523   std::string label = GetOpLabel(model, op);
524   std::string op_id = GetOpId(op_index);
525   AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
526           color.TextColorString());
527 }
528 
DumpOperatorEdges(const Model & model,std::string * output_file,int op_index)529 void DumpOperatorEdges(const Model& model, std::string* output_file,
530                        int op_index) {
531   // Inputs
532   const Operator& op = *model.operators[op_index];
533   std::string op_id = GetOpId(op_index);
534   for (int i = 0; i < op.inputs.size(); i++) {
535     const auto& input = op.inputs[i];
536     if (!model.HasArray(input)) {
537       // Connected arrays should _always_ exist. Except, perhaps, during
538       // development.
539       continue;
540     }
541     float log2_buffer_size = GetLog2BufferSize(model, input);
542     // Draw lines that transport more data thicker (Otherwise, where would the
543     // data fit? right?).
544     float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
545     // Keep edges that transport more data shorter than those with less.
546     float weight = std::max(1.0f, log2_buffer_size);
547     if (!IsInputArray(model, input) &&
548         GetOpWithOutput(model, input) == nullptr) {
549       // Give the main line of data flow a straighter path by penalizing edges
550       // to standalone buffers. Weights are generally very large buffers that
551       // would otherwise skew the layout.
552       weight = 1.0f;
553     }
554     std::string compass_pt = GetArrayCompassPt(model, input);
555     AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
556             weight);
557   }
558   // Outputs
559   for (int i = 0; i < op.outputs.size(); i++) {
560     const auto& output = op.outputs[i];
561     if (!model.HasArray(output)) {
562       continue;
563     }
564     float log2_buffer_size = GetLog2BufferSize(model, output);
565     // See comments above regarding weight and line_width calculations.
566     float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
567     float weight = std::max(1.0f, log2_buffer_size);
568     if (!IsArrayConsumed(model, output)) {
569       weight = 1.0f;
570     }
571     std::string compass_pt = GetArrayCompassPt(model, output);
572     AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
573             line_width, weight);
574   }
575 }
576 
577 struct Node {
Nodetoco::__anona87f18ca0111::Node578   Node() : math_ops(0) {}
579   // Name used as a key in the model's array map
580   std::string array_id;
581 
582   // Estimated number of math ops incurred by this node (the sum of the op
583   // with this array as 1st output, plus all children nodes).
584   int64 math_ops;
585 
586   // A map of child nodes keyed by name.
587   std::map<const std::string, std::unique_ptr<Node>> children;
588 };
589 
GetSubgraphLabel(Node const & node,const std::string & subgraph)590 std::string GetSubgraphLabel(Node const& node, const std::string& subgraph) {
591   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
592   std::string html;
593   html += "<";
594 
595   // Begin Table
596   html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
597   html +=
598       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
599 
600   // Name
601   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
602   html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
603   html += subgraph;
604   html += R"CODE(</I></FONT>)CODE";
605   html += "</TD></TR>";
606 
607   // Attributes
608   Attributes attrs;
609   if (node.math_ops > 0) {
610     attrs["math"] = FormattedNumber(node.math_ops) + "ops";
611   }
612   html += AttributesToHtml(attrs);
613 
614   // End Table and HTML-like label
615   html += R"CODE(</TABLE></FONT>)CODE";
616   html += ">";
617 
618   return html;
619 }
620 
DumpSubgraphHeader(std::string * output_file,Node const & node,const std::string & node_name)621 void DumpSubgraphHeader(std::string* output_file, Node const& node,
622                         const std::string& node_name) {
623   Color color = HashStringToColor(node_name);
624   std::string label = GetSubgraphLabel(node, node_name);
625   AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
626 }
627 
DumpArray(const Model & model,std::string * output_file,const std::string & array_id)628 void DumpArray(const Model& model, std::string* output_file,
629                const std::string& array_id) {
630   Color color;
631   std::string shape;
632   GetArrayColorAndShape(model, array_id, &color, &shape);
633   std::string label = GetArrayLabel(model, array_id);
634   AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
635           color.AsHexString(), color.TextColorString());
636 
637   // Ops are placed in the same subgraph as their first output.
638   for (int op_index = 0; op_index < model.operators.size(); op_index++) {
639     const Operator& op = *model.operators[op_index];
640     if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
641       DumpOperator(model, output_file, op_index);
642     }
643   }
644 }
645 
DumpNode(const Model & model,std::string * output_file,const std::string & node_name,Node const & node)646 void DumpNode(const Model& model, std::string* output_file,
647               const std::string& node_name, Node const& node) {
648   bool not_root = !node_name.empty();
649   if (not_root) {
650     DumpSubgraphHeader(output_file, node, node_name);
651   }
652 
653   for (const auto& child : node.children) {
654     if (!child.second->array_id.empty()) {
655       // Dump array if this node possesses one.
656       DumpArray(model, output_file, child.second->array_id);
657     }
658     // Note that it is always possible to have children. Unlike a filesystem,
659     // the existence of array "foo/bar" does _not_ prevent other arrays, such as
660     // and "foo/bar/baz", from being nested beneath it.
661     DumpNode(model, output_file, child.first, *child.second);
662   }
663 
664   if (not_root) {
665     // End subgraph
666     AppendF(output_file, "    }\n");
667   }
668 }
669 
GetArithmeticOpsCount(const Model & model,const std::string & array_id)670 int64 GetArithmeticOpsCount(const Model& model, const std::string& array_id) {
671   for (const auto& op : model.operators) {
672     if (!op->outputs.empty() && op->outputs[0] == array_id) {
673       int64 count;
674       if (EstimateArithmeticOpsCount(model, *op, &count)) {
675         return count;
676       } else {
677         return 0;
678       }
679     }
680   }
681   return 0;
682 }
683 
InsertNode(const Model & model,const std::string & array_id,Node * node,std::vector<std::string> prefixes,int64 * math_ops)684 void InsertNode(const Model& model, const std::string& array_id, Node* node,
685                 std::vector<std::string> prefixes, int64* math_ops) {
686   if (prefixes.empty()) {
687     // Base case: store array in this node.
688     node->array_id = array_id;
689     *math_ops = GetArithmeticOpsCount(model, array_id);
690   } else {
691     // Insert into the sub-tree for that prefix.
692     std::string prefix = prefixes.back();
693     prefixes.pop_back();
694     if (node->children.count(prefix) == 0) {
695       // Create a new node if this prefix is unseen.
696       node->children[prefix] = absl::make_unique<Node>();
697     }
698     InsertNode(model, array_id, node->children[prefix].get(), prefixes,
699                math_ops);
700   }
701   // Sum estimated math ops into all nodes.
702   node->math_ops += *math_ops;
703 }
704 
BuildArrayTree(const Model & model,Node * tree)705 void BuildArrayTree(const Model& model, Node* tree) {
706   // Delimit array names by path "/", then place into a tree based on this path.
707   for (const auto& array_id : model.GetArrayMap()) {
708     std::vector<std::string> prefixes = absl::StrSplit(array_id.first, '/');
709     std::reverse(prefixes.begin(), prefixes.end());
710     int64 math_ops;  // Temporary storage for math ops used during recursion.
711     InsertNode(model, array_id.first, tree, prefixes, &math_ops);
712   }
713 }
714 
GetGraphLabel(const Model & model,const std::string & graph_name)715 std::string GetGraphLabel(const Model& model, const std::string& graph_name) {
716   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
717   std::string html;
718   html += "<";
719 
720   // Begin Table
721   html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
722   html +=
723       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
724 
725   // Name
726   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
727   html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
728   html += graph_name;
729   html += R"CODE(</I></B></FONT>)CODE";
730   html += "</TD></TR>";
731 
732   // Attributes
733   Attributes attrs;
734   attrs["arrays"] = StringF("%d", model.GetArrayMap().size());
735   if (!model.optional_arrays.empty()) {
736     attrs["optional arrays"] = StringF("%d", model.optional_arrays.size());
737   }
738   attrs["operators"] = StringF("%d", model.operators.size());
739   int64 ops_count;
740   if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
741     attrs["math"] = FormattedNumber(ops_count) + "ops";
742   }
743   if (model.transient_data_size > 0) {
744     attrs["transient data size"] =
745         StringF("%d KiB", model.transient_data_size / 1024);
746   }
747   if (model.transient_data_alignment > 0) {
748     attrs["transient data alignment"] =
749         StringF("%d bytes", model.transient_data_alignment);
750   }
751   html += AttributesToHtml(attrs);
752 
753   // End Table and HTML-like label
754   html += R"CODE(</TABLE></FONT>)CODE";
755   html += ">";
756 
757   return html;
758 }
759 }  // namespace
760 
DumpGraphviz(const Model & model,std::string * output_file,const std::string & graph_name)761 void DumpGraphviz(const Model& model, std::string* output_file,
762                   const std::string& graph_name) {
763   // Start graphviz format
764   AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
765 
766   // Organize arrays into a tree for subgraphing
767   Node tree;
768   BuildArrayTree(model, &tree);
769   DumpNode(model, output_file, "", tree);
770 
771   // Dump edges outside all subgraphs (otherwise the referred-to nodes are
772   // implicitly included in that subgraph).
773   for (int op_index = 0; op_index < model.operators.size(); op_index++) {
774     DumpOperatorEdges(model, output_file, op_index);
775   }
776 
777   // Dump RNN Backedges
778   for (const auto& rnn_state : model.flags.rnn_states()) {
779     AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
780             rnn_state.state_array());
781   }
782   // End graphviz format
783   AppendF(output_file, "}\n");
784 }
785 }  // namespace toco
786