1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api_internal.h"
17 
18 #include <algorithm>
19 #include <unordered_map>
20 #include <unordered_set>
21 
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/graph/graph.h"
28 #include "tensorflow/core/lib/strings/base64.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 
31 using tensorflow::errors::InvalidArgument;
32 
33 namespace tensorflow {
34 namespace {
35 
36 // Class that maintains a one-to-one original node name -> new node name
37 // mapping. We normalize the names used as input and output arguments to match
38 // regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
39 // Once we rename them, we risk creating a name collision with the other
40 // node names, so if necessary we add a suffix to make
41 // names unique. If we have an input named "A" and a node in the function
42 // body named "a", they will be renamed to "a" and "a_0".
43 class NodeNameMapping {
44  public:
45   NodeNameMapping() = default;
46 
47   // Normalize the input name and make it unique. This is the same as the
48   // function for output, expect that it adds a name mapping for the name.
49   string GetInputName(const string& name);
50 
51   // Normalize the output name and make it unique.
52   string GetOutputName(const string& name);
53 
54   // Make the node name unique.
55   string Uniquify(const string& name);
56 
57   // Records name as a used name. If this name is already used,
58   // returns an error status.
59   Status UseOutputName(const string& name);
60 
61   // Look up how a node name was previously normalized/uniquified.
62   // Returns empty if name was never seen.
63   string Lookup(const string& name) const;
64 
65  private:
66   string UniquifyHelper(const string& name) const;
67   static string Normalize(string name);
68 
69   // The normalized/uniquified names already used as
70   // input names (in signature), output names (in signature), and node names
71   // (in node_def).
72   // This is a superset of values in name_mapping_.
73   std::unordered_set<string> used_names_;
74   // Mapping from original node name from the graph to the normalized
75   // and uniquified version of it.
76   std::unordered_map<string, string> name_mapping_;
77 };
78 
Normalize(string name)79 string NodeNameMapping::Normalize(string name) {
80   // Convert letters to lowercase and non-alphanumeric characters to '_'.
81   if (name.empty()) return "unknown";
82   const int n = name.size();
83   for (int i = 0; i < n; ++i) {
84     char c = name[i];
85     if (isalnum(c)) {
86       if (isupper(c)) {
87         name[i] = tolower(c);
88       }
89     } else {
90       name[i] = '_';
91     }
92   }
93 
94   // Find the first letter and start with it.
95   int i = 0;
96   for (; i < n; ++i) {
97     if (isalpha(name[i])) break;
98   }
99 
100   // Return "unknown" if none of the name's chars were letters.
101   return i == n ? "unknown" : name.substr(i);
102 }
103 
// Returns `name` unchanged if nobody has used it yet; otherwise returns the
// first "<name>_<k>" (k = 0, 1, 2, ...) that is still free. Does not record
// the returned name.
string NodeNameMapping::UniquifyHelper(const string& name) const {
  auto is_free = [this](const string& candidate) {
    return used_names_.count(candidate) == 0;
  };
  if (is_free(name)) return name;

  int suffix = 0;
  string candidate = strings::StrCat(name, "_", suffix);
  while (!is_free(candidate)) {
    ++suffix;
    candidate = strings::StrCat(name, "_", suffix);
  }
  return candidate;
}
113 
// Produces a normalized, unique name for an input argument. Unlike output
// names, an input name stands for a graph node, so the original -> new
// mapping is remembered for later Lookup() calls.
string NodeNameMapping::GetInputName(const string& name) {
  string normalized = GetOutputName(name);
  name_mapping_[name] = normalized;
  return normalized;
}
119 
// Produces a normalized, unique name for an output argument and reserves it.
string NodeNameMapping::GetOutputName(const string& name) {
  string unique_name = UniquifyHelper(Normalize(name));
  // Reserve the name so no later call can hand it out again. It is
  // deliberately not added to name_mapping_: it does not name a node.
  used_names_.insert(unique_name);
  return unique_name;
}
127 
// Returns a unique (but not normalized) version of the node name `name`,
// reserving it and recording the original -> new mapping.
string NodeNameMapping::Uniquify(const string& name) {
  string result = UniquifyHelper(name);
  used_names_.insert(result);
  name_mapping_[name] = result;
  return result;
}
134 
UseOutputName(const string & name)135 Status NodeNameMapping::UseOutputName(const string& name) {
136   const auto& iter = used_names_.find(name);
137   if (iter != used_names_.end()) {
138     return InvalidArgument("Cannot have duplicate output names. Name '", name,
139                            "' appears more than once in 'output_names' array.");
140   }
141   used_names_.insert(iter, name);
142   return Status::OK();
143 }
144 
// Returns the name recorded for `name` by GetInputName()/Uniquify(), or an
// empty string when there is no such record.
string NodeNameMapping::Lookup(const string& name) const {
  auto it = name_mapping_.find(name);
  return it != name_mapping_.end() ? it->second : string();
}
150 
ValidateNonRefOutput(const Node * node,int idx)151 Status ValidateNonRefOutput(const Node* node, int idx) {
152   const DataType& dt = node->output_type(idx);
153   return IsRefType(dt)
154              ? InvalidArgument("Output ", idx, " of node '", node->name(),
155                                "' has a reference type ", DataTypeString(dt))
156              : Status::OK();
157 }
158 
// Appends one NodeDef to `fdef` for every node in `body_nodes`, in order.
//
// Each NodeDef starts as a copy of the node's graph NodeDef and is then patched:
//  - the device is overwritten with the node's assigned device, if it has one;
//  - the name is replaced with the one recorded in `node_names`;
//  - regular inputs are rewritten via `tensor_renaming` (flat "node:idx" graph
//    tensor names -> names usable inside the function);
//  - control inputs are rewritten to "^<mapped source name>".
// If any body node's op is stateful, the signature is marked stateful.
//
// Returns InvalidArgument if a regular input has neither an edge nor a
// requested input, if an input tensor has no entry in `tensor_renaming`, or
// if the source of a control edge has no mapping in `node_names`. On error,
// `fdef` may already contain NodeDefs for earlier nodes.
Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    FunctionDef* fdef) {
  // Scratch buffers, declared outside the loop so they are reused
  // (cleared) for each node instead of being reallocated.
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order.
    in_edges.clear();
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      // Edges from the graph's source node are skipped entirely.
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs. The source of every control edge must be in the
    // body or be one of the inputs; otherwise conversion fails (the edge is
    // not silently dropped).
    for (const Edge* edge : control_edges) {
      // node_names only has mappings for names recorded through
      // GetInputName() or Uniquify(); Lookup() returns "" for anything else.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body, and not an input. Raise an error.
      if (normalized.empty()) {
        return InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      node_def->add_input(strings::StrCat("^", normalized));
    }

    // A function is stateful if any of its nodes are stateful.
    if (node->op_def().is_stateful()) {
      fdef->mutable_signature()->set_is_stateful(true);
    }
  }
  return Status::OK();
}
248 
249 // Graph to FunctionDef conversion. This code is closely modeled on the Python
250 // code in tensorflow/python/framework/function.py.
GraphToFunctionDef(const Graph & fn_body,const string & fn_name,bool append_hash_to_fn_name,const std::vector<const Node * > & body_nodes,const std::vector<OutputTensor> & inputs,const std::vector<OutputTensor> & outputs,const std::vector<string> & output_names,const char * description,FunctionDef * fdef)251 Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
252                           bool append_hash_to_fn_name,
253                           const std::vector<const Node*>& body_nodes,
254                           const std::vector<OutputTensor>& inputs,
255                           const std::vector<OutputTensor>& outputs,
256                           const std::vector<string>& output_names,
257                           const char* description, FunctionDef* fdef) {
258   if (!output_names.empty()) {
259     DCHECK_EQ(output_names.size(), outputs.size());
260   }
261 
262   if (description != nullptr) {
263     fdef->mutable_signature()->set_description(description);
264   }
265 
266   // Keep track of names we used and how we normalized them.
267   NodeNameMapping node_names;
268 
269   // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
270   // name we used in the function:
271   //  - For input tensors:
272   //    {flat_tensor_name -> normalized_name_of_src_node}
273   //    e.g. {In:3 -> in}
274   //  - For tensors produced by nodes in function's body:
275   //    {flat_tensor_name -> nested_tensor_name}
276   //    e.g. {Add:3 -> add_0:z:1}
277   std::unordered_map<string, string> tensor_renaming;
278 
279   // Fill outputs in function's signature.
280   // We fill the outputs first to prevent output_names from colliding
281   // with the input names we pick below. With this order, no names are used in
282   // node_names yet, and output_names won't collide with anything (except
283   // potentially with themselves).
284   for (size_t i = 0; i < outputs.size(); ++i) {
285     const Node* node = outputs[i].node;
286     int idx = outputs[i].index;
287     OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
288     argdef->set_type(node->output_type(idx));
289     if (!output_names.empty()) {
290       TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
291       argdef->set_name(output_names[i]);
292     } else {
293       argdef->set_name(node_names.GetOutputName(node->name()));
294     }
295   }
296 
297   // Fill inputs in function's signature.
298   for (size_t i = 0; i < inputs.size(); ++i) {
299     const Node* node = inputs[i].node;
300     int idx = inputs[i].index;
301     OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
302     argdef->set_type(node->output_type(idx));
303     const string& input_name = node_names.GetInputName(node->name());
304     argdef->set_name(input_name);
305     tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
306   }
307 
308   // Populate tensor_renaming and node_names.
309   // Generate the new output names for every node in the function.
310   // The NodeDefs in FunctionDefs use a different naming scheme for
311   // their inputs than the NodeDefs in a graph (see the comment for
312   // FunctionDef.node_def in function.proto). We do the
313   // graph tensor name -> function tensor name conversion for every
314   // possible input (i.e. every node's outputs) and store the result
315   // in tensor_renaming.
316   for (const Node* node : body_nodes) {
317     // Make sure node_name does not collide with an input or output name.
318     const string& node_name = node_names.Uniquify(node->name());
319     // For each output_arg in the op_def, the output_ranges
320     // map will have [start, end] range of indices that this arg produces
321     // among all the output tensors of this op.
322     NameRangeMap output_ranges;
323     TF_RETURN_IF_ERROR(
324         NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
325     for (const auto& output : output_ranges) {
326       const StringPiece& output_name = output.first;
327       int index_start = output.second.first;
328       int index_end = output.second.second;
329       for (int i = index_start; i < index_end; ++i) {
330         const string& original_name = strings::StrCat(node->name(), ":", i);
331         const string& new_name =
332             strings::StrCat(node_name, ":", output_name, ":", i - index_start);
333         // Record the mapping if this tensor is not already mapped.
334         // Tensor can be already mapped if it is used as an input.
335         if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
336           tensor_renaming[original_name] = new_name;
337         }
338       }
339     }
340   }
341 
342   TF_RETURN_IF_ERROR(
343       FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));
344 
345   // Remap return values.
346   for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
347     const string& ret_name = fdef->signature().output_arg(r).name();
348     // We convert this flat tensor name to the nested value
349     // (e.g. `add:z:1`) that we stored in tensor_renaming.
350     const string& return_value =
351         strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
352     const auto iter = tensor_renaming.find(return_value);
353     if (iter == tensor_renaming.end()) {
354       return InvalidArgument(
355           "TF_Output ", return_value, " is neither in the function body ",
356           "nor among function inputs. Encountered while creating function '",
357           fn_name, "'");
358     }
359     (*fdef->mutable_ret())[ret_name] = iter->second;
360   }
361 
362   if (append_hash_to_fn_name) {
363     const uint64 hash = FunctionDefHash(*fdef);
364     string encoded;
365     TF_RETURN_IF_ERROR(Base64Encode(
366         StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
367         &encoded));
368     // Besides letters and digits our Base64 encoding uses '_' and '-'.
369     // Dash is invalid in operation names and multiple underscores in random
370     // places look strange. Since we never need to decode the hash back,
371     // replace these chars with with 'a' and 'A'. Replacing with different
372     // letters keeps more entropy.
373     std::replace(encoded.begin(), encoded.end(), '-', 'a');
374     std::replace(encoded.begin(), encoded.end(), '_', 'A');
375     fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
376   } else {
377     fdef->mutable_signature()->set_name(fn_name);
378   }
379 
380   return Status::OK();
381 }
382 
383 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
384 // does various checks while doing so. `input_nodes` will contain the same
385 // information as input_tensors just in a different structure to make
386 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)387 Status ProcessInputs(
388     const TF_Graph* fn_body, const char* fn_name, int ninputs,
389     const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
390     std::unordered_map<const Node*, std::vector<int>>* input_nodes)
391     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
392   input_tensors->reserve(ninputs);
393   for (int i = 0; i < ninputs; ++i) {
394     const Node& node = inputs[i].oper->node;
395     int idx = inputs[i].index;
396 
397     TF_RETURN_WITH_CONTEXT_IF_ERROR(
398         fn_body->graph.IsValidOutputTensor(&node, idx),
399         "Encountered while processing input ", i, " into function '", fn_name,
400         "'");
401     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
402                                     "Encountered while processing input ", i,
403                                     " into function '", fn_name, "'");
404 
405     input_tensors->emplace_back(&node, idx);
406 
407     const auto& iter = input_nodes->find(&node);
408     if (iter == input_nodes->end()) {
409       input_nodes->insert({&node, {idx}});
410     } else {
411       auto& indices = iter->second;
412       if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
413         return InvalidArgument("TF_Output ", node.name(), ":", idx,
414                                " appears more than once in the input list");
415       }
416       indices.push_back(idx);
417     }
418   }
419   return Status::OK();
420 }
421 
422 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
423 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)424 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
425                       int noutputs, const TF_Output* outputs,
426                       std::vector<OutputTensor>* output_tensors)
427     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
428   output_tensors->reserve(noutputs);
429   for (int i = 0; i < noutputs; ++i) {
430     const Node& node = outputs[i].oper->node;
431     int idx = outputs[i].index;
432     TF_RETURN_WITH_CONTEXT_IF_ERROR(
433         fn_body->graph.IsValidOutputTensor(&node, idx),
434         "Encountered while processing output ", i, " from function '", fn_name,
435         "'");
436     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
437                                     "Encountered while creating function '",
438                                     fn_name, "'");
439     output_tensors->emplace_back(&node, idx);
440   }
441   return Status::OK();
442 }
443 
444 // Populates `body_nodes` with the nodes that will become function's body.
445 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)446 Status ComputeBodyNodes(
447     const TF_Graph* fn_body, const char* fn_name, int num_opers,
448     const TF_Operation* const* opers,
449     const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
450     std::vector<const Node*>* body_nodes)
451     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
452   if (num_opers == -1) {
453     for (const Node* node : fn_body->graph.op_nodes()) {
454       const auto& iter = input_nodes.find(node);
455       if (iter == input_nodes.end()) {
456         // This node is not referenced in inputs. Add it to the body.
457         body_nodes->push_back(node);
458       } else {
459         // This node is referenced in inputs. Currently, we place an
460         // artificial restriction and require that when num_opers=-1, such
461         // nodes must have a single output.
462         if (node->num_outputs() != 1) {
463           return InvalidArgument(
464               "When `num_opers` is set to -1, nodes referenced in `inputs` "
465               "must have a single output. Node ",
466               node->name(), " has ", node->num_outputs(),
467               " outputs. Encountered while creating function '", fn_name, "'");
468         }
469       }
470     }
471   } else {
472     body_nodes->reserve(num_opers);
473     for (int i = 0; i < num_opers; ++i) {
474       const Node* node = &opers[i]->node;
475       body_nodes->push_back(node);
476     }
477   }
478   return Status::OK();
479 }
480 
481 }  // namespace
482 }  // namespace tensorflow
483 
484 using tensorflow::Node;
485 using tensorflow::string;
486 
TF_GraphToFunction(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)487 TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
488                                 unsigned char append_hash_to_fn_name,
489                                 int num_opers, const TF_Operation* const* opers,
490                                 int ninputs, const TF_Output* inputs,
491                                 int noutputs, const TF_Output* outputs,
492                                 const char* const* output_names,
493                                 const TF_FunctionOptions* opts,
494                                 const char* description, TF_Status* status) {
495   tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
496 
497   // Process inputs.
498   std::vector<tensorflow::OutputTensor> input_tensors;
499   std::unordered_map<const Node*, std::vector<int>> input_nodes;
500   status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
501                                              &input_tensors, &input_nodes);
502   if (!status->status.ok()) return nullptr;
503 
504   // Process outputs.
505   std::vector<tensorflow::OutputTensor> output_tensors;
506   status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
507                                               outputs, &output_tensors);
508   if (!status->status.ok()) return nullptr;
509 
510   // Process output names.
511   std::vector<string> output_names_vec;
512   if (output_names) {
513     output_names_vec.reserve(noutputs);
514     for (int i = 0; i < noutputs; ++i) {
515       output_names_vec.push_back(string(output_names[i]));
516     }
517   }
518 
519   // Compute body nodes.
520   std::vector<const Node*> body_nodes;
521   status->status = tensorflow::ComputeBodyNodes(
522       fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
523   if (!status->status.ok()) return nullptr;
524 
525   // Do the actual function creation.
526   TF_Function* tf_function = new TF_Function();
527   DCHECK(append_hash_to_fn_name <= 1);
528   status->status = tensorflow::GraphToFunctionDef(
529       fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
530       input_tensors, output_tensors, output_names_vec, description,
531       &tf_function->fdef);
532   if (!status->status.ok()) {
533     TF_DeleteFunction(tf_function);
534     return nullptr;
535   }
536   return tf_function;
537 }
538 
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)539 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
540                           const TF_Function* grad, TF_Status* status) {
541   if (func == nullptr) {
542     status->status = InvalidArgument(
543         "'func' argument to TF_GraphCopyFunction cannot be null");
544     return;
545   }
546 
547   // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
548   // to avoid the extra copy here.
549   tensorflow::FunctionDefLibrary fdef_lib;
550   *fdef_lib.add_function() = func->fdef;
551   if (grad) {
552     *fdef_lib.add_function() = grad->fdef;
553     tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
554     gdef->set_function_name(func->fdef.signature().name());
555     gdef->set_gradient_func(grad->fdef.signature().name());
556   }
557 
558   tensorflow::mutex_lock l(g->mu);
559   status->status = g->graph.AddFunctionLibrary(fdef_lib);
560 }
561 
TF_GraphNumFunctions(TF_Graph * g)562 int TF_GraphNumFunctions(TF_Graph* g) {
563   tensorflow::mutex_lock l(g->mu);
564   return g->graph.flib_def().num_functions();
565 }
566 
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)567 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
568                          TF_Status* status) {
569   tensorflow::FunctionDefLibrary lib;
570   {
571     tensorflow::mutex_lock l(g->mu);
572     lib = g->graph.flib_def().ToProto();
573   }
574   const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
575   for (int i = 0; i < len; ++i) {
576     TF_Function* func = new TF_Function();
577     func->fdef = lib.function(i);
578     funcs[i] = func;
579   }
580   status->status = tensorflow::Status::OK();
581   return len;
582 }
583 
TF_FunctionToFunctionDef(TF_Function * func,TF_Buffer * output_func_def,TF_Status * status)584 void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
585                               TF_Status* status) {
586   status->status = MessageToBuffer(func->fdef, output_func_def);
587 }
588 
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)589 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
590                                           TF_Status* status) {
591   TF_Function* func = new TF_Function();
592   if (!func->fdef.ParseFromArray(proto, proto_len)) {
593     status->status = InvalidArgument(
594         "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
595     TF_DeleteFunction(func);
596     return nullptr;
597   }
598   status->status = tensorflow::Status::OK();
599   return func;
600 }
601 
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)602 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
603                                   const void* proto, size_t proto_len,
604                                   TF_Status* status) {
605   tensorflow::AttrValue attr_value;
606   if (!attr_value.ParseFromArray(proto, proto_len)) {
607     status->status = InvalidArgument(
608         "Unparseable AttrValue proto passed to "
609         "TF_FunctionSetAttrValueProto");
610     return;
611   }
612   (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
613   status->status = tensorflow::Status::OK();
614 }
615 
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)616 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
617                                   TF_Buffer* output_attr_value,
618                                   TF_Status* status) {
619   const auto& it = func->fdef.attr().find(attr_name);
620   if (it == func->fdef.attr().end()) {
621     status->status =
622         InvalidArgument("Function '", func->fdef.signature().name(),
623                         "' has no attr named '", attr_name, "'.");
624     return;
625   }
626   status->status = MessageToBuffer(it->second, output_attr_value);
627 }
628 
TF_DeleteFunction(TF_Function * func)629 void TF_DeleteFunction(TF_Function* func) { delete func; }
630