1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
17 
18 #include "tensorflow/core/common_runtime/constant_folding.h"
19 #include "tensorflow/core/graph/graph_constructor.h"
20 #include "tensorflow/core/graph/node_builder.h"
21 #include "tensorflow/core/graph/subgraph.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/platform/init_main.h"
24 #include "tensorflow/core/public/session.h"
25 #include "tensorflow/tools/graph_transforms/transform_utils.h"
26 
27 namespace tensorflow {
28 namespace graph_transforms {
29 
30 // Clears the device field of all ops in the graph.
InsertLogging(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)31 Status InsertLogging(const GraphDef& input_graph_def,
32                      const TransformFuncContext& context,
33                      GraphDef* output_graph_def) {
34   std::unordered_set<string> ops;
35   bool has_ops;
36   if (context.params.count("op")) {
37     has_ops = true;
38     for (const string& op : context.params.at("op")) {
39       ops.insert(op);
40     }
41   } else {
42     has_ops = false;
43   }
44 
45   std::unordered_set<string> prefixes;
46   bool has_prefixes;
47   if (context.params.count("prefix")) {
48     has_prefixes = true;
49     for (const string& prefix : context.params.at("prefix")) {
50       prefixes.insert(prefix);
51     }
52   } else {
53     has_prefixes = false;
54   }
55 
56   string message;
57   TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
58 
59   bool show_name;
60   TF_RETURN_IF_ERROR(
61       context.GetOneBoolParameter("show_name", false, &show_name));
62 
63   bool show_op;
64   TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
65 
66   int32 first_n;
67   TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
68 
69   int32 summarize;
70   TF_RETURN_IF_ERROR(
71       context.GetOneInt32Parameter("summarize", 1024, &summarize));
72 
73   std::unordered_map<string, std::set<int>> node_outputs;
74   for (const NodeDef& node : input_graph_def.node()) {
75     for (const string& input : node.input()) {
76       const string canonical_input = CanonicalInputName(input);
77       string prefix;
78       string name;
79       string suffix;
80       NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
81       const string output_index_string = suffix.substr(1, suffix.size() - 1);
82       int32 output_index;
83       if (!strings::safe_strto32(output_index_string, &output_index)) {
84         return errors::InvalidArgument("Couldn't understand output number in ",
85                                        input);
86       }
87       node_outputs[name].insert(output_index);
88     }
89   }
90 
91   std::map<string, string> inputs_to_rename;
92   std::unordered_set<string> ignore_when_renaming;
93   GraphDef logged_graph_def;
94   for (const NodeDef& node : input_graph_def.node()) {
95     NodeDef* new_node = logged_graph_def.mutable_node()->Add();
96     *new_node = node;
97     if (node_outputs[node.name()].empty()) {
98       // There were no outputs found to this node, so skip it.
99       continue;
100     }
101     const bool op_matches = (ops.count(node.op()) > 0);
102     bool prefix_matches = false;
103     for (const string& prefix : prefixes) {
104       if (str_util::StartsWith(node.name(), prefix)) {
105         prefix_matches = true;
106       }
107     }
108     // If we're not looking for ops, or we found the right op, and if we're not
109     // looking for prefixes or we found the right prefix, then add logging here.
110     if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
111       const string name_suffix = "__print__";
112       DataTypeVector input_types;
113       DataTypeVector output_types;
114       TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
115       NodeDef* print_node = logged_graph_def.mutable_node()->Add();
116       print_node->set_op("Print");
117       print_node->set_name(strings::StrCat(node.name(), name_suffix));
118       string node_message;
119       if (show_op) {
120         node_message += ";" + node.op() + ";";
121       }
122       if (show_name) {
123         node_message += ";" + print_node->name() + ";";
124       }
125       node_message += message;
126       SetNodeAttr("message", node_message, print_node);
127       SetNodeAttr("first_n", first_n, print_node);
128       SetNodeAttr("summarize", summarize, print_node);
129       print_node->add_input(node.name() + ":0");
130       SetNodeAttr("T", output_types[0], print_node);
131       for (int output_index : node_outputs[node.name()]) {
132         print_node->add_input(strings::StrCat(node.name(), ":", output_index));
133       }
134       SetNodeAttr("U", output_types, print_node);
135       ignore_when_renaming.insert(print_node->name());
136       // Rewrite the graph so all references to the first input of the original
137       // op now pull from the print op instead, so it's executed.
138       inputs_to_rename[node.name() + ":0"] =
139           strings::StrCat(node.name(), name_suffix, ":0");
140     }
141   }
142 
143   output_graph_def->Clear();
144   return RenameNodeInputs(logged_graph_def, inputs_to_rename,
145                           ignore_when_renaming, output_graph_def);
146 }
147 
// Registers InsertLogging as the graph transform named "insert_logging"
// (presumably the name used to select it in the transform tool's transform
// list — confirm against the registration macro's definition).
REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
149 
150 }  // namespace graph_transforms
151 }  // namespace tensorflow
152