1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
17 
18 #include "tensorflow/core/framework/node_def_util.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/numbers.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 
24 namespace tensorflow {
25 namespace graph_transforms {
26 
27 namespace {
IsMerge(const NodeDef & node_def)28 inline bool IsMerge(const NodeDef& node_def) {
29   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
30          node_def.op() == "_XlaMerge";
31 }
32 
RecordMatchedNodes(const NodeMatch & match,std::set<string> * matched_nodes)33 void RecordMatchedNodes(const NodeMatch& match,
34                         std::set<string>* matched_nodes) {
35   matched_nodes->insert(match.node.name());
36   for (const NodeMatch& input_match : match.inputs) {
37     RecordMatchedNodes(input_match, matched_nodes);
38   }
39 }
40 
Hash64String(const string & input)41 inline uint64 Hash64String(const string& input) {
42   return Hash64(input.data(), input.size());
43 }
44 }  // namespace
45 
MatchedNodesAsArray(const NodeMatch & match,std::vector<NodeDef> * result)46 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
47   std::set<string> found_nodes;
48   std::vector<NodeMatch> current_matches = {match};
49   while (!current_matches.empty()) {
50     std::vector<NodeMatch> next_matches;
51     for (const NodeMatch& current_match : current_matches) {
52       if (found_nodes.count(current_match.node.name())) {
53         continue;
54       }
55       found_nodes.insert(current_match.node.name());
56       result->push_back(current_match.node);
57       for (const NodeMatch& input_match : current_match.inputs) {
58         next_matches.push_back(input_match);
59       }
60     }
61     current_matches = next_matches;
62   }
63 }
64 
MapNamesToNodes(const GraphDef & graph_def,std::map<string,const NodeDef * > * result)65 void MapNamesToNodes(const GraphDef& graph_def,
66                      std::map<string, const NodeDef*>* result) {
67   for (const NodeDef& node : graph_def.node()) {
68     (*result)[node.name()] = &node;
69   }
70 }
71 
MapNodesToOutputs(const GraphDef & graph_def,std::map<string,std::vector<const NodeDef * >> * result)72 void MapNodesToOutputs(const GraphDef& graph_def,
73                        std::map<string, std::vector<const NodeDef*>>* result) {
74   std::map<string, const NodeDef*> node_map;
75   MapNamesToNodes(graph_def, &node_map);
76   for (const NodeDef& node : graph_def.node()) {
77     for (const string& input : node.input()) {
78       string input_node_name = NodeNameFromInput(input);
79       (*result)[input_node_name].push_back(&node);
80     }
81   }
82 }
83 
NodeNamePartsFromInput(const string & input_name,string * prefix,string * node_name,string * suffix)84 void NodeNamePartsFromInput(const string& input_name, string* prefix,
85                             string* node_name, string* suffix) {
86   std::vector<string> input_parts = str_util::Split(input_name, ':');
87   if (input_parts.size() < 2) {
88     *suffix = "";
89   } else {
90     *suffix = ":" + input_parts[1];
91   }
92   StringPiece node_name_piece(input_parts[0]);
93   if (absl::ConsumePrefix(&node_name_piece, "^")) {
94     *prefix = "^";
95   } else {
96     *prefix = "";
97   }
98   *node_name = string(node_name_piece);
99 }
100 
// Returns just the producing node's name from an input string, dropping any
// "^" control prefix and ":<port>" suffix (e.g. "^foo:1" -> "foo").
string NodeNameFromInput(const string& input_name) {
  string unused_prefix;
  string name;
  string unused_suffix;
  NodeNamePartsFromInput(input_name, &unused_prefix, &name, &unused_suffix);
  return name;
}
108 
// Returns `input_name` with an explicit port: a missing port becomes ":0" and
// any "^" prefix is kept (e.g. "foo" -> "foo:0", "^foo" -> "^foo:0",
// "foo:2" -> "foo:2").
string CanonicalInputName(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  const string port = suffix.empty() ? string(":0") : suffix;
  return prefix + node_name + port;
}
119 
HashNodeDef(const NodeDef & node)120 uint64 HashNodeDef(const NodeDef& node) {
121   uint64 hash = Hash64String(node.op());
122   hash = Hash64Combine(hash, Hash64String(node.name()));
123   for (const string& input : node.input()) {
124     hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
125   }
126   hash = Hash64Combine(hash, Hash64String(node.device()));
127   std::vector<string> attr_names;
128   attr_names.reserve(node.attr().size());
129   for (const auto& attr : node.attr()) {
130     attr_names.push_back(attr.first);
131   }
132   std::sort(attr_names.begin(), attr_names.end());
133   string attr_serialized;
134   for (const string& attr_name : attr_names) {
135     auto attr = node.attr().at(attr_name);
136     attr.SerializeToString(&attr_serialized);
137     hash = Hash64Combine(hash, Hash64String(attr_serialized));
138   }
139   return hash;
140 }
141 
AddNodeInput(const string & input_name,NodeDef * node)142 void AddNodeInput(const string& input_name, NodeDef* node) {
143   *(node->mutable_input()->Add()) = input_name;
144 }
145 
CopyNodeAttr(const NodeDef & source,const string & source_key,const string & dest_key,NodeDef * dest)146 void CopyNodeAttr(const NodeDef& source, const string& source_key,
147                   const string& dest_key, NodeDef* dest) {
148   CHECK_NE(0, source.attr().count(source_key))
149       << "No key '" << source_key << "' found in " << source.DebugString();
150   (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key);
151 }
152 
GetNodeTensorAttr(const NodeDef & node,const string & key)153 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
154   TensorProto tensor_proto = node.attr().at(key).tensor();
155   Tensor tensor;
156   CHECK(tensor.FromProto(tensor_proto));
157   return tensor;
158 }
159 
FilterGraphDef(const GraphDef & input_graph_def,std::function<bool (const NodeDef &)> selector,GraphDef * output_graph_def)160 void FilterGraphDef(const GraphDef& input_graph_def,
161                     std::function<bool(const NodeDef&)> selector,
162                     GraphDef* output_graph_def) {
163   output_graph_def->mutable_node()->Clear();
164   for (const NodeDef& node : input_graph_def.node()) {
165     if (selector(node)) {
166       *output_graph_def->mutable_node()->Add() = node;
167     }
168   }
169 }
170 
RemoveAttributes(const GraphDef & input_graph_def,const std::vector<string> & attributes,GraphDef * output_graph_def)171 void RemoveAttributes(const GraphDef& input_graph_def,
172                       const std::vector<string>& attributes,
173                       GraphDef* output_graph_def) {
174   output_graph_def->mutable_node()->Clear();
175   for (const NodeDef& node : input_graph_def.node()) {
176     NodeDef* new_node = output_graph_def->mutable_node()->Add();
177     *new_node = node;
178     for (const string& attribute : attributes) {
179       new_node->mutable_attr()->erase(attribute);
180     }
181   }
182 }
183 
SortByExecutionOrder(const GraphDef & input_graph_def,GraphDef * output_graph_def)184 Status SortByExecutionOrder(const GraphDef& input_graph_def,
185                             GraphDef* output_graph_def) {
186   const int num_nodes = input_graph_def.node_size();
187   std::vector<int> ready;
188   std::vector<int> pending_count;
189   pending_count.reserve(num_nodes);
190   std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);
191 
192   std::map<string, int> name_index;
193   for (int i = 0; i < input_graph_def.node_size(); ++i) {
194     const NodeDef& node(input_graph_def.node(i));
195     name_index[node.name()] = i;
196   }
197 
198   // Parse the inputs for each node.
199   for (int n = 0; n < num_nodes; ++n) {
200     const NodeDef& node_def(input_graph_def.node(n));
201     if (IsMerge(node_def)) {
202       // for merge only wait for one non-control input.
203       int32 num_control_edges = 0;
204       for (int i = 0; i < node_def.input_size(); ++i) {
205         if (absl::StartsWith(node_def.input(i), "^")) {
206           num_control_edges++;
207         }
208       }
209       pending_count.push_back(num_control_edges + 1);
210     } else {
211       pending_count.push_back(node_def.input_size());
212     }
213     if (node_def.input_size() == 0) {
214       ready.push_back(n);
215       continue;
216     }
217     for (int i = 0; i < node_def.input_size(); ++i) {
218       const string& input_name = node_def.input(i);
219       const string& input_node_name = NodeNameFromInput(input_name);
220       if (!name_index.count(input_node_name)) {
221         return errors::InvalidArgument("Node '", node_def.name(),
222                                        "': Unknown input node '",
223                                        node_def.input(i), "'");
224       }
225       outputs[name_index[input_node_name]].push_back(n);
226     }
227   }
228 
229   int processed = 0;
230   output_graph_def->Clear();
231   // Process the NodeDefs in topological order.
232   // Code above sets this up by filling in ready_ with nodes that have no
233   // inputs, pending_counts_ with the number of inputs for each node and
234   // outputs_ with the outputs of each node.
235   while (!ready.empty()) {
236     int o = ready.back();
237     ready.pop_back();
238     ++processed;
239     const NodeDef& node_def(input_graph_def.node(o));
240     *output_graph_def->mutable_node()->Add() = node_def;
241 
242     // Update pending_count for outputs.
243     for (size_t i = 0; i < outputs[o].size(); ++i) {
244       const int output = outputs[o][i];
245       pending_count[output]--;
246       if (pending_count[output] == 0) {
247         ready.push_back(output);
248       }
249     }
250   }
251 
252   if (processed < num_nodes) {
253     LOG(WARNING) << "IN " << __func__ << (num_nodes - processed)
254                  << " NODES IN A CYCLE";
255     for (int64 i = 0; i < num_nodes; i++) {
256       if (pending_count[i] != 0) {
257         LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i))
258                      << "WITH PENDING COUNT = " << pending_count[i];
259       }
260     }
261     return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle");
262   }
263   return Status::OK();
264 }
265 
DebugString() const266 string OpTypePattern::DebugString() const {
267   string result = "{" + op + ", {";
268   for (const OpTypePattern& input : inputs) {
269     result += input.DebugString() + ",";
270   }
271   result += "}}";
272   return result;
273 }
274 
DebugString() const275 string NodeMatch::DebugString() const {
276   string result = "{";
277   result += node.DebugString();
278   result += ", {";
279   for (const NodeMatch& input : inputs) {
280     result += input.DebugString() + ",";
281   }
282   result += "}}";
283   return result;
284 }
285 
GraphMatcher(const GraphDef & graph_def)286 GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
287   SortByExecutionOrder(graph_def, &graph_def_).IgnoreError();
288   MapNamesToNodes(graph_def_, &node_map_);
289 }
290 
GetOpTypeMatches(const OpTypePattern & pattern,std::vector<NodeMatch> * matches)291 Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
292                                       std::vector<NodeMatch>* matches) {
293   std::set<string> matched_nodes;
294   for (const NodeDef& node : graph_def_.node()) {
295     // Skip any nodes that are already part of a match.
296     if (matched_nodes.count(node.name())) {
297       continue;
298     }
299     NodeMatch match;
300     if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) {
301       RecordMatchedNodes(match, &matched_nodes);
302       matches->push_back(match);
303     }
304   }
305   return Status::OK();
306 }
307 
DoesOpTypeMatch(const NodeDef & node,const OpTypePattern & pattern,const std::set<string> & previously_matched_nodes,NodeMatch * match)308 bool GraphMatcher::DoesOpTypeMatch(
309     const NodeDef& node, const OpTypePattern& pattern,
310     const std::set<string>& previously_matched_nodes, NodeMatch* match) {
311   VLOG(1) << "Looking at node " << node.DebugString();
312   VLOG(1) << "pattern=" << pattern.DebugString();
313   VLOG(1) << "match=" << match->DebugString();
314   if (previously_matched_nodes.count(node.name())) {
315     VLOG(1) << "node " << node.name() << " has been previously matched";
316     return false;
317   }
318   bool pattern_matched = false;
319   if (pattern.op == "*") {
320     pattern_matched = true;
321   } else {
322     std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
323     for (const string& pattern_op : pattern_ops) {
324       if (node.op() == pattern_op) {
325         pattern_matched = true;
326       }
327     }
328   }
329   if (!pattern_matched) {
330     VLOG(1) << "node.op() != pattern.op()";
331     return false;
332   }
333   match->node = node;
334   // Ignore any control inputs for pattern-matching purposes
335   std::vector<string> non_control_inputs;
336   for (const string& input : node.input()) {
337     if (!input.empty() && (input[0] != '^')) {
338       non_control_inputs.push_back(input);
339     }
340   }
341   if (pattern.inputs.empty()) {
342     // If there are no inputs, assume that's the end of the pattern.
343     return true;
344   }
345   if (non_control_inputs.size() != pattern.inputs.size()) {
346     VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
347     return false;
348   }
349   for (int i = 0; i < pattern.inputs.size(); ++i) {
350     const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
351     const NodeDef& input_node = *(node_map_[input_node_name]);
352     const OpTypePattern& input_pattern = pattern.inputs[i];
353     match->inputs.push_back(NodeMatch());
354     NodeMatch* input_match = &(match->inputs.back());
355     if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
356                          input_match)) {
357       return false;
358     }
359   }
360   return true;
361 }
362 
ReplaceMatchingOpTypes(const GraphDef & input_graph_def,const OpTypePattern & pattern,const std::function<Status (const NodeMatch &,const std::set<string> &,const std::set<string> &,std::vector<NodeDef> *)> & node_generator,const ReplaceMatchingOpTypesOptions & options,GraphDef * output_graph_def)363 Status ReplaceMatchingOpTypes(
364     const GraphDef& input_graph_def, const OpTypePattern& pattern,
365     const std::function<Status(const NodeMatch&, const std::set<string>&,
366                                const std::set<string>&, std::vector<NodeDef>*)>&
367         node_generator,
368     const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
369   // Start off by retrieving all the matching subgraphs.
370   GraphMatcher matcher(input_graph_def);
371   std::vector<NodeMatch> matches;
372   TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches));
373 
374   // Do some housekeeping so we can easily look up the resulting matches given
375   // a node name.
376   std::set<string> matched_nodes;
377   std::map<string, const NodeMatch*> matches_by_head_name;
378   for (const NodeMatch& match : matches) {
379     matches_by_head_name[match.node.name()] = &match;
380     RecordMatchedNodes(match, &matched_nodes);
381   }
382   std::map<string, std::vector<const NodeDef*>> outputs_map;
383   MapNodesToOutputs(input_graph_def, &outputs_map);
384 
385   // Go through all the nodes in the input graph, see if they are part of a
386   // match or if they can be left untouched.
387   output_graph_def->Clear();
388   for (const NodeDef& input_node : input_graph_def.node()) {
389     if (matches_by_head_name.count(input_node.name())) {
390       // This node is the beginning of a match, so call the replacement function
391       // after setting up some information it will need.
392       const NodeMatch* match = matches_by_head_name[input_node.name()];
393       std::vector<NodeDef> matched_nodes_array;
394       MatchedNodesAsArray(*match, &matched_nodes_array);
395       // This tells us whether a node is part of the current match.
396       std::set<string> matched_nodes_lookup;
397       for (const NodeDef& matched_node : matched_nodes_array) {
398         matched_nodes_lookup.insert(matched_node.name());
399       }
400       // These are helper arrays that the replacement function can use to tell
401       // whether it can safely remove an internal node (because nothing outside
402       // of the match uses it) or whether external nodes depend on it.
403       std::set<string> input_nodes;
404       std::set<string> output_nodes;
405       for (const NodeDef& matched_node : matched_nodes_array) {
406         // Look through all of this node's inputs, and if any of them come from
407         // outside the match, then this should be noted as one of the external
408         // inputs of the subgraph.
409         for (const string& input_name : matched_node.input()) {
410           string input_node_name = NodeNameFromInput(input_name);
411           if (!matched_nodes_lookup.count(input_node_name)) {
412             input_nodes.insert(matched_node.name());
413           }
414         }
415         // Do a reverse input lookup, to see which other nodes use the current
416         // one as an input. If any of those nodes are outside the match
417         // subgraph, then the current node is marked as an output node that
418         // shouldn't be removed.
419         if (outputs_map.count(matched_node.name())) {
420           for (const NodeDef* dependent_node :
421                outputs_map[matched_node.name()]) {
422             if (!matched_nodes_lookup.count(dependent_node->name())) {
423               output_nodes.insert(matched_node.name());
424             }
425           }
426         }
427       }
428       // Call the generator function and add all the returned nodes to the
429       // graph.
430       std::vector<NodeDef> new_nodes;
431       TF_RETURN_IF_ERROR(
432           node_generator(*match, input_nodes, output_nodes, &new_nodes));
433       std::set<string> new_node_names;
434       for (const NodeDef& new_node : new_nodes) {
435         new_node_names.insert(new_node.name());
436       }
437       // Check to make sure the generator function preserved all of the nodes
438       // that are used elsewhere in the graph, and add them back in if not.
439       bool abort_replacement = false;
440       if (!options.allow_inconsistencies) {
441         for (const string& expected_output : output_nodes) {
442           if (!new_node_names.count(expected_output)) {
443             LOG(WARNING) << "Expected " << expected_output
444                          << " to be preserved.";
445             abort_replacement = true;
446           }
447         }
448       }
449       if (abort_replacement) {
450         LOG(WARNING) << "Generator function didn't preserve needed nodes, "
451                      << "copying old replacements back in instead.";
452         std::vector<NodeDef> old_nodes;
453         MatchedNodesAsArray(*match, &old_nodes);
454         for (const NodeDef& old_node : old_nodes) {
455           NodeDef* added_node = output_graph_def->mutable_node()->Add();
456           *added_node = old_node;
457         }
458       } else {
459         for (const NodeDef& new_node : new_nodes) {
460           NodeDef* added_node = output_graph_def->mutable_node()->Add();
461           *added_node = new_node;
462         }
463       }
464     } else if (!matched_nodes.count(input_node.name())) {
465       // This node isn't part of any match, so just copy it over.
466       NodeDef* added_node = output_graph_def->mutable_node()->Add();
467       *added_node = input_node;
468     } else {
469       // Do nothing, because this is an internal part of a matching subgraph,
470       // and so will have been replaced by a new replacement subgraph.
471     }
472   }
473 
474   return Status::OK();
475 }
476 
RenameNodeInputs(const GraphDef & input_graph_def,const std::map<string,string> & inputs_to_rename,const std::unordered_set<string> & nodes_to_ignore,GraphDef * output_graph_def)477 Status RenameNodeInputs(const GraphDef& input_graph_def,
478                         const std::map<string, string>& inputs_to_rename,
479                         const std::unordered_set<string>& nodes_to_ignore,
480                         GraphDef* output_graph_def) {
481   std::map<string, std::vector<std::pair<string, string>>>
482       canonical_inputs_to_rename;
483   for (const auto& input_to_rename : inputs_to_rename) {
484     canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)]
485         .push_back({input_to_rename.first, input_to_rename.second});
486   }
487 
488   output_graph_def->Clear();
489   for (const NodeDef& node : input_graph_def.node()) {
490     NodeDef* new_node = output_graph_def->mutable_node()->Add();
491     *new_node = node;
492     new_node->mutable_input()->Clear();
493     for (const string& input_name : node.input()) {
494       std::set<string> already_visited;
495       string new_input_name = input_name;
496       while (
497           canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) {
498         string input_node_name = NodeNameFromInput(new_input_name);
499         if (already_visited.count(input_node_name)) {
500           return errors::InvalidArgument(
501               "RenameNodeInputs argument contains a cycle for ",
502               input_node_name);
503         }
504         already_visited.insert(input_node_name);
505         if (nodes_to_ignore.count(node.name())) {
506           break;
507         }
508         bool any_match_found = false;
509         for (const std::pair<string, string>& input_to_rename :
510              canonical_inputs_to_rename.at(input_node_name)) {
511           const string& source_name = input_to_rename.first;
512           const string& dest_name = input_to_rename.second;
513           bool is_match;
514           string match_name;
515           if (str_util::EndsWith(source_name, ":*")) {
516             is_match = true;
517             string prefix;
518             string unused_node_name;
519             string suffix;
520             NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name,
521                                    &suffix);
522             match_name = prefix + dest_name + suffix;
523           } else {
524             is_match = (CanonicalInputName(source_name) ==
525                         CanonicalInputName(new_input_name));
526             match_name = dest_name;
527           }
528           if (is_match) {
529             new_input_name = match_name;
530             any_match_found = true;
531           }
532         }
533         if (!any_match_found) {
534           break;
535         }
536       }
537       *(new_node->mutable_input()->Add()) = new_input_name;
538     }
539   }
540   return Status::OK();
541 }
542 
CopyOriginalMatch(const NodeMatch & match,std::vector<NodeDef> * new_nodes)543 void CopyOriginalMatch(const NodeMatch& match,
544                        std::vector<NodeDef>* new_nodes) {
545   std::vector<NodeDef> old_nodes;
546   MatchedNodesAsArray(match, &old_nodes);
547   for (const NodeDef& old_node : old_nodes) {
548     new_nodes->push_back(old_node);
549   }
550 }
551 
GetTransformRegistry()552 TransformRegistry* GetTransformRegistry() {
553   static TransformRegistry transform_registry;
554   return &transform_registry;
555 }
556 
FindInvalidInputs(const GraphDef & graph_def,std::vector<std::pair<string,string>> * invalid_inputs)557 void FindInvalidInputs(const GraphDef& graph_def,
558                        std::vector<std::pair<string, string>>* invalid_inputs) {
559   std::map<string, const NodeDef*> node_map;
560   MapNamesToNodes(graph_def, &node_map);
561 
562   for (const NodeDef& node : graph_def.node()) {
563     for (const string& input : node.input()) {
564       string input_node = NodeNameFromInput(input);
565       if (!node_map.count(input_node)) {
566         invalid_inputs->push_back({node.name(), input_node});
567       }
568     }
569   }
570 }
571 
IsGraphValid(const GraphDef & graph_def)572 Status IsGraphValid(const GraphDef& graph_def) {
573   std::vector<std::pair<string, string>> invalid_inputs;
574   FindInvalidInputs(graph_def, &invalid_inputs);
575   if (!invalid_inputs.empty()) {
576     std::map<string, const NodeDef*> node_map;
577     MapNamesToNodes(graph_def, &node_map);
578     for (const std::pair<string, string>& invalid_input : invalid_inputs) {
579       LOG(ERROR) << "Invalid input " << invalid_input.second << " for node "
580                  << invalid_input.first << " - "
581                  << node_map[invalid_input.first]->DebugString();
582     }
583     return errors::Internal(
584         "Invalid graph with inputs referring to nonexistent nodes");
585   }
586   return Status::OK();
587 }
588 
GetInOutTypes(const NodeDef & node_def,DataTypeVector * inputs,DataTypeVector * outputs)589 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
590                      DataTypeVector* outputs) {
591   const OpDef* op_def;
592   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
593   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
594   return Status::OK();
595 }
596 
TensorShapeFromString(const string & shape_string,TensorShape * result)597 Status TensorShapeFromString(const string& shape_string, TensorShape* result) {
598   if (shape_string.empty()) {
599     return errors::InvalidArgument("Specified shape is empty.");
600   }
601   std::vector<string> dims_as_str = str_util::Split(shape_string, ",");
602   std::vector<int64> dims;
603   for (const string& dim : dims_as_str) {
604     int64 tmp;
605     if (strings::safe_strto64(dim, &tmp)) {
606       dims.push_back(tmp);
607     } else {
608       return errors::InvalidArgument("Could parse as shape: '", shape_string,
609                                      "'");
610     }
611   }
612   *result = TensorShape(dims);
613   return Status::OK();
614 }
615 
CountParameters(const string & name) const616 int TransformFuncContext::CountParameters(const string& name) const {
617   if (params.count(name)) {
618     return params.at(name).size();
619   } else {
620     return 0;
621   }
622 }
623 
GetOneStringParameter(const string & name,const string & default_value,string * result) const624 Status TransformFuncContext::GetOneStringParameter(const string& name,
625                                                    const string& default_value,
626                                                    string* result) const {
627   const int params_count = CountParameters(name);
628   if (params_count == 0) {
629     *result = default_value;
630     return Status::OK();
631   } else if (params_count == 1) {
632     *result = params.at(name).at(0);
633     return Status::OK();
634   } else {
635     return errors::InvalidArgument("Expected a single '", name,
636                                    "' parameter, but found ", params_count,
637                                    " occurrences");
638   }
639 }
640 
GetOneInt32Parameter(const string & name,int32 default_value,int32 * result) const641 Status TransformFuncContext::GetOneInt32Parameter(const string& name,
642                                                   int32 default_value,
643                                                   int32* result) const {
644   const int params_count = CountParameters(name);
645   if (params_count == 0) {
646     *result = default_value;
647     return Status::OK();
648   }
649   string string_value;
650   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
651   if (!strings::safe_strto32(StringPiece(string_value), result)) {
652     return errors::InvalidArgument("Couldn't interpret the ", name,
653                                    " argument as a number:", string_value);
654   }
655   return Status::OK();
656 }
657 
GetOneInt64Parameter(const string & name,int64 default_value,int64 * result) const658 Status TransformFuncContext::GetOneInt64Parameter(const string& name,
659                                                   int64 default_value,
660                                                   int64* result) const {
661   const int params_count = CountParameters(name);
662   if (params_count == 0) {
663     *result = default_value;
664     return Status::OK();
665   }
666   string string_value;
667   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
668   if (!strings::safe_strto64(StringPiece(string_value), result)) {
669     return errors::InvalidArgument("Couldn't interpret the ", name,
670                                    " argument as a number:", string_value);
671   }
672   return Status::OK();
673 }
674 
GetOneFloatParameter(const string & name,float default_value,float * result) const675 Status TransformFuncContext::GetOneFloatParameter(const string& name,
676                                                   float default_value,
677                                                   float* result) const {
678   const int params_count = CountParameters(name);
679   if (params_count == 0) {
680     *result = default_value;
681     return Status::OK();
682   }
683   string string_value;
684   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
685   if (!strings::safe_strtof(string_value.c_str(), result)) {
686     return errors::InvalidArgument(
687         "Couldn't interpret the ", name,
688         " argument as a float number:", string_value);
689   }
690   return Status::OK();
691 }
692 
GetOneBoolParameter(const string & name,bool default_value,bool * result) const693 Status TransformFuncContext::GetOneBoolParameter(const string& name,
694                                                  bool default_value,
695                                                  bool* result) const {
696   const int params_count = CountParameters(name);
697   if (params_count == 0) {
698     *result = default_value;
699     return Status::OK();
700   }
701   string string_value;
702   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
703   if (string_value == "true" || string_value == "1") {
704     *result = true;
705   } else if (string_value == "false" || string_value == "0") {
706     *result = false;
707   } else {
708     return errors::InvalidArgument("Couldn't interpret the ", name,
709                                    " argument as a boolean:", string_value,
710                                    " (expected true, false, 0 or 1)");
711   }
712   return Status::OK();
713 }
714 
715 }  // namespace graph_transforms
716 }  // namespace tensorflow
717