1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api_internal.h"
17
18 #include <algorithm>
19 #include <unordered_map>
20 #include <unordered_set>
21
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/graph/graph.h"
28 #include "tensorflow/core/lib/strings/base64.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30
31 using tensorflow::errors::InvalidArgument;
32
33 namespace tensorflow {
34 namespace {
35
// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalize the input name and make it unique. This is the same as the
  // function for output, except that it adds a name mapping for the name.
  string GetInputName(const string& name);

  // Normalize the output name and make it unique.
  string GetOutputName(const string& name);

  // Make the node name unique.
  string Uniquify(const string& name);

  // Records name as a used name. If this name is already used,
  // returns an error status.
  Status UseOutputName(const string& name);

  // Look up how a node name was previously normalized/uniquified.
  // Returns empty if name was never seen.
  string Lookup(const string& name) const;

 private:
  // Returns `name` as-is if unused, otherwise `name` with a numeric suffix
  // appended to make it unique. Does not record the result.
  string UniquifyHelper(const string& name) const;
  // Lowercases `name` and maps it onto the "[a-z][a-z0-9_]*" character set
  // required by ArgDef.name. Returns "unknown" if no letters remain.
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  std::unordered_set<string> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
};
78
Normalize(string name)79 string NodeNameMapping::Normalize(string name) {
80 // Convert letters to lowercase and non-alphanumeric characters to '_'.
81 if (name.empty()) return "unknown";
82 const int n = name.size();
83 for (int i = 0; i < n; ++i) {
84 char c = name[i];
85 if (isalnum(c)) {
86 if (isupper(c)) {
87 name[i] = tolower(c);
88 }
89 } else {
90 name[i] = '_';
91 }
92 }
93
94 // Find the first letter and start with it.
95 int i = 0;
96 for (; i < n; ++i) {
97 if (isalpha(name[i])) break;
98 }
99
100 // Return "unknown" if none of the name's chars were letters.
101 return i == n ? "unknown" : name.substr(i);
102 }
103
// Returns `name` unchanged if it has not been used yet; otherwise returns
// "name_<i>" for the smallest non-negative i that is still unused.
string NodeNameMapping::UniquifyHelper(const string& name) const {
  if (used_names_.count(name) == 0) return name;
  // Probe increasing numeric suffixes until an unused candidate is found.
  int suffix = 0;
  while (true) {
    const string candidate = strings::StrCat(name, "_", suffix);
    if (used_names_.count(candidate) == 0) return candidate;
    ++suffix;
  }
}
113
// Normalizes and uniquifies an input-argument name, and additionally records
// the original -> normalized mapping so later lookups by node name succeed.
string NodeNameMapping::GetInputName(const string& name) {
  string input_name = GetOutputName(name);
  name_mapping_[name] = input_name;
  return input_name;
}
119
// Normalizes and uniquifies an output-argument name. The result is marked as
// used but intentionally NOT added to name_mapping_, because it does not name
// a node.
string NodeNameMapping::GetOutputName(const string& name) {
  string result = UniquifyHelper(Normalize(name));
  used_names_.insert(result);
  return result;
}
127
// Makes a node name unique, records it as used, and remembers the
// original -> unique mapping for later Lookup() calls.
string NodeNameMapping::Uniquify(const string& name) {
  string unique_name = UniquifyHelper(name);
  name_mapping_[name] = unique_name;
  used_names_.insert(unique_name);
  return unique_name;
}
134
UseOutputName(const string & name)135 Status NodeNameMapping::UseOutputName(const string& name) {
136 const auto& iter = used_names_.find(name);
137 if (iter != used_names_.end()) {
138 return InvalidArgument("Cannot have duplicate output names. Name '", name,
139 "' appears more than once in 'output_names' array.");
140 }
141 used_names_.insert(iter, name);
142 return Status::OK();
143 }
144
// Returns the normalized/uniquified name previously recorded for `name`, or
// an empty string if `name` was never registered.
string NodeNameMapping::Lookup(const string& name) const {
  auto iter = name_mapping_.find(name);
  return iter == name_mapping_.end() ? string() : iter->second;
}
150
ValidateNonRefOutput(const Node * node,int idx)151 Status ValidateNonRefOutput(const Node* node, int idx) {
152 const DataType& dt = node->output_type(idx);
153 return IsRefType(dt)
154 ? InvalidArgument("Output ", idx, " of node '", node->name(),
155 "' has a reference type ", DataTypeString(dt))
156 : Status::OK();
157 }
158
// Copies every node in `body_nodes` into fdef->node_def and rewrites each
// copy for use inside a function: the node name is replaced with its
// normalized form from `node_names`, regular inputs are translated through
// `tensor_renaming` (flat "node:idx" -> nested function tensor name), and
// control inputs become "^<normalized_src>" entries. Also marks the function
// signature stateful if any body node's op is stateful. Returns an error if
// an input or control edge references something outside the body/inputs.
Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    FunctionDef* fdef) {
  // Reused across loop iterations to avoid per-node reallocation.
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order.
    in_edges.clear();
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      // Translate the flat graph tensor name into the nested function-body
      // tensor name recorded by GraphToFunctionDef.
      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs.
    for (const Edge* edge : control_edges) {
      // Add this control input only if the src node is in the body or a part of
      // the inputs.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body, and not an input. Raise an error.
      if (normalized.empty()) {
        return InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      node_def->add_input(strings::StrCat("^", normalized));
    }

    // A function is stateful if any of its nodes are stateful.
    if (node->op_def().is_stateful()) {
      fdef->mutable_signature()->set_is_stateful(true);
    }
  }
  return Status::OK();
}
248
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in tensorflow/python/framework/function.py.
//
// Builds `fdef` from the subgraph given by `body_nodes` with signature inputs
// `inputs` and outputs `outputs`. `output_names`, when non-empty, must be
// parallel to `outputs` and supplies the signature's output-arg names. When
// `append_hash_to_fn_name` is true, a base64-encoded hash of the resulting
// FunctionDef is appended to `fn_name` to form the final function name.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
                          bool append_hash_to_fn_name,
                          const std::vector<const Node*>& body_nodes,
                          const std::vector<OutputTensor>& inputs,
                          const std::vector<OutputTensor>& outputs,
                          const std::vector<string>& output_names,
                          const char* description, FunctionDef* fdef) {
  if (!output_names.empty()) {
    DCHECK_EQ(output_names.size(), outputs.size());
  }

  if (description != nullptr) {
    fdef->mutable_signature()->set_description(description);
  }

  // Keep track of names we used and how we normalized them.
  NodeNameMapping node_names;

  // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
  // name we used in the function:
  //  - For input tensors:
  //    {flat_tensor_name -> normalized_name_of_src_node}
  //    e.g. {In:3 -> in}
  //  - For tensors produced by nodes in function's body:
  //    {flat_tensor_name -> nested_tensor_name}
  //    e.g. {Add:3 -> add_0:z:1}
  std::unordered_map<string, string> tensor_renaming;

  // Fill outputs in function's signature.
  // We fill the outputs first to prevent output_names from colliding
  // with the input names we pick below. With this order, no names are used in
  // node_names yet, and output_names won't collide with anything (except
  // potentially with themselves).
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Node* node = outputs[i].node;
    int idx = outputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
    argdef->set_type(node->output_type(idx));
    if (!output_names.empty()) {
      // Caller-provided names are used verbatim but must be unique.
      TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
      argdef->set_name(output_names[i]);
    } else {
      argdef->set_name(node_names.GetOutputName(node->name()));
    }
  }

  // Fill inputs in function's signature.
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Node* node = inputs[i].node;
    int idx = inputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
    argdef->set_type(node->output_type(idx));
    const string& input_name = node_names.GetInputName(node->name());
    argdef->set_name(input_name);
    // Input tensors map directly to the normalized input-arg name.
    tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
  }

  // Populate tensor_renaming and node_names.
  // Generate the new output names for every node in the function.
  // The NodeDefs in FunctionDefs use a different naming scheme for
  // their inputs than the NodeDefs in a graph (see the comment for
  // FunctionDef.node_def in function.proto). We do the
  // graph tensor name -> function tensor name conversion for every
  // possible input (i.e. every node's outputs) and store the result
  // in tensor_renaming.
  for (const Node* node : body_nodes) {
    // Make sure node_name does not collide with an input or output name.
    const string& node_name = node_names.Uniquify(node->name());
    // For each output_arg in the op_def, the output_ranges
    // map will have [start, end] range of indices that this arg produces
    // among all the output tensors of this op.
    NameRangeMap output_ranges;
    TF_RETURN_IF_ERROR(
        NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
    for (const auto& output : output_ranges) {
      const StringPiece& output_name = output.first;
      int index_start = output.second.first;
      int index_end = output.second.second;
      for (int i = index_start; i < index_end; ++i) {
        const string& original_name = strings::StrCat(node->name(), ":", i);
        const string& new_name =
            strings::StrCat(node_name, ":", output_name, ":", i - index_start);
        // Record the mapping if this tensor is not already mapped.
        // Tensor can be already mapped if it is used as an input.
        if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
          tensor_renaming[original_name] = new_name;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(
      FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));

  // Remap return values.
  for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
    const string& ret_name = fdef->signature().output_arg(r).name();
    // We convert this flat tensor name to the nested value
    // (e.g. `add:z:1`) that we stored in tensor_renaming.
    const string& return_value =
        strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
    const auto iter = tensor_renaming.find(return_value);
    if (iter == tensor_renaming.end()) {
      return InvalidArgument(
          "TF_Output ", return_value, " is neither in the function body ",
          "nor among function inputs. Encountered while creating function '",
          fn_name, "'");
    }
    (*fdef->mutable_ret())[ret_name] = iter->second;
  }

  if (append_hash_to_fn_name) {
    const uint64 hash = FunctionDefHash(*fdef);
    string encoded;
    TF_RETURN_IF_ERROR(Base64Encode(
        StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
        &encoded));
    // Besides letters and digits our Base64 encoding uses '_' and '-'.
    // Dash is invalid in operation names and multiple underscores in random
    // places look strange. Since we never need to decode the hash back,
    // replace these chars with with 'a' and 'A'. Replacing with different
    // letters keeps more entropy.
    std::replace(encoded.begin(), encoded.end(), '-', 'a');
    std::replace(encoded.begin(), encoded.end(), '_', 'A');
    fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
  } else {
    fdef->mutable_signature()->set_name(fn_name);
  }

  return Status::OK();
}
382
383 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
384 // does various checks while doing so. `input_nodes` will contain the same
385 // information as input_tensors just in a different structure to make
386 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)387 Status ProcessInputs(
388 const TF_Graph* fn_body, const char* fn_name, int ninputs,
389 const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
390 std::unordered_map<const Node*, std::vector<int>>* input_nodes)
391 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
392 input_tensors->reserve(ninputs);
393 for (int i = 0; i < ninputs; ++i) {
394 const Node& node = inputs[i].oper->node;
395 int idx = inputs[i].index;
396
397 TF_RETURN_WITH_CONTEXT_IF_ERROR(
398 fn_body->graph.IsValidOutputTensor(&node, idx),
399 "Encountered while processing input ", i, " into function '", fn_name,
400 "'");
401 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
402 "Encountered while processing input ", i,
403 " into function '", fn_name, "'");
404
405 input_tensors->emplace_back(&node, idx);
406
407 const auto& iter = input_nodes->find(&node);
408 if (iter == input_nodes->end()) {
409 input_nodes->insert({&node, {idx}});
410 } else {
411 auto& indices = iter->second;
412 if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
413 return InvalidArgument("TF_Output ", node.name(), ":", idx,
414 " appears more than once in the input list");
415 }
416 indices.push_back(idx);
417 }
418 }
419 return Status::OK();
420 }
421
422 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
423 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)424 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
425 int noutputs, const TF_Output* outputs,
426 std::vector<OutputTensor>* output_tensors)
427 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
428 output_tensors->reserve(noutputs);
429 for (int i = 0; i < noutputs; ++i) {
430 const Node& node = outputs[i].oper->node;
431 int idx = outputs[i].index;
432 TF_RETURN_WITH_CONTEXT_IF_ERROR(
433 fn_body->graph.IsValidOutputTensor(&node, idx),
434 "Encountered while processing output ", i, " from function '", fn_name,
435 "'");
436 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
437 "Encountered while creating function '",
438 fn_name, "'");
439 output_tensors->emplace_back(&node, idx);
440 }
441 return Status::OK();
442 }
443
444 // Populates `body_nodes` with the nodes that will become function's body.
445 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)446 Status ComputeBodyNodes(
447 const TF_Graph* fn_body, const char* fn_name, int num_opers,
448 const TF_Operation* const* opers,
449 const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
450 std::vector<const Node*>* body_nodes)
451 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
452 if (num_opers == -1) {
453 for (const Node* node : fn_body->graph.op_nodes()) {
454 const auto& iter = input_nodes.find(node);
455 if (iter == input_nodes.end()) {
456 // This node is not referenced in inputs. Add it to the body.
457 body_nodes->push_back(node);
458 } else {
459 // This node is referenced in inputs. Currently, we place an
460 // artificial restriction and require that when num_opers=-1, such
461 // nodes must have a single output.
462 if (node->num_outputs() != 1) {
463 return InvalidArgument(
464 "When `num_opers` is set to -1, nodes referenced in `inputs` "
465 "must have a single output. Node ",
466 node->name(), " has ", node->num_outputs(),
467 " outputs. Encountered while creating function '", fn_name, "'");
468 }
469 }
470 }
471 } else {
472 body_nodes->reserve(num_opers);
473 for (int i = 0; i < num_opers; ++i) {
474 const Node* node = &opers[i]->node;
475 body_nodes->push_back(node);
476 }
477 }
478 return Status::OK();
479 }
480
481 } // namespace
482 } // namespace tensorflow
483
484 using tensorflow::Node;
485 using tensorflow::string;
486
// C API entry point: creates a TF_Function from a subgraph of `fn_body`.
// Returns a newly-allocated TF_Function owned by the caller, or nullptr with
// `status` set on failure.
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
                                unsigned char append_hash_to_fn_name,
                                int num_opers, const TF_Operation* const* opers,
                                int ninputs, const TF_Output* inputs,
                                int noutputs, const TF_Output* outputs,
                                const char* const* output_names,
                                const TF_FunctionOptions* opts,
                                const char* description, TF_Status* status) {
  // `fn_body` is logically const, but mutex_lock requires a non-const mutex.
  // NOTE: `opts` is not referenced anywhere below.
  tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));

  // Process inputs.
  std::vector<tensorflow::OutputTensor> input_tensors;
  std::unordered_map<const Node*, std::vector<int>> input_nodes;
  status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
                                             &input_tensors, &input_nodes);
  if (!status->status.ok()) return nullptr;

  // Process outputs.
  std::vector<tensorflow::OutputTensor> output_tensors;
  status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
                                              outputs, &output_tensors);
  if (!status->status.ok()) return nullptr;

  // Process output names. An empty vector means "auto-generate names".
  std::vector<string> output_names_vec;
  if (output_names) {
    output_names_vec.reserve(noutputs);
    for (int i = 0; i < noutputs; ++i) {
      output_names_vec.push_back(string(output_names[i]));
    }
  }

  // Compute body nodes.
  std::vector<const Node*> body_nodes;
  status->status = tensorflow::ComputeBodyNodes(
      fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
  if (!status->status.ok()) return nullptr;

  // Do the actual function creation.
  TF_Function* tf_function = new TF_Function();
  // append_hash_to_fn_name is a C-style boolean; anything other than 0/1 is a
  // caller bug.
  DCHECK(append_hash_to_fn_name <= 1);
  status->status = tensorflow::GraphToFunctionDef(
      fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
      input_tensors, output_tensors, output_names_vec, description,
      &tf_function->fdef);
  if (!status->status.ok()) {
    // Clean up the partially-built function before reporting failure.
    TF_DeleteFunction(tf_function);
    return nullptr;
  }
  return tf_function;
}
538
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)539 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
540 const TF_Function* grad, TF_Status* status) {
541 if (func == nullptr) {
542 status->status = InvalidArgument(
543 "'func' argument to TF_GraphCopyFunction cannot be null");
544 return;
545 }
546
547 // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
548 // to avoid the extra copy here.
549 tensorflow::FunctionDefLibrary fdef_lib;
550 *fdef_lib.add_function() = func->fdef;
551 if (grad) {
552 *fdef_lib.add_function() = grad->fdef;
553 tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
554 gdef->set_function_name(func->fdef.signature().name());
555 gdef->set_gradient_func(grad->fdef.signature().name());
556 }
557
558 tensorflow::mutex_lock l(g->mu);
559 status->status = g->graph.AddFunctionLibrary(fdef_lib);
560 }
561
TF_GraphNumFunctions(TF_Graph * g)562 int TF_GraphNumFunctions(TF_Graph* g) {
563 tensorflow::mutex_lock l(g->mu);
564 return g->graph.flib_def().num_functions();
565 }
566
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)567 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
568 TF_Status* status) {
569 tensorflow::FunctionDefLibrary lib;
570 {
571 tensorflow::mutex_lock l(g->mu);
572 lib = g->graph.flib_def().ToProto();
573 }
574 const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
575 for (int i = 0; i < len; ++i) {
576 TF_Function* func = new TF_Function();
577 func->fdef = lib.function(i);
578 funcs[i] = func;
579 }
580 status->status = tensorflow::Status::OK();
581 return len;
582 }
583
// Serializes `func`'s FunctionDef proto into `output_func_def`; any
// serialization failure is reported through `status`.
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
                              TF_Status* status) {
  status->status = MessageToBuffer(func->fdef, output_func_def);
}
588
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)589 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
590 TF_Status* status) {
591 TF_Function* func = new TF_Function();
592 if (!func->fdef.ParseFromArray(proto, proto_len)) {
593 status->status = InvalidArgument(
594 "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
595 TF_DeleteFunction(func);
596 return nullptr;
597 }
598 status->status = tensorflow::Status::OK();
599 return func;
600 }
601
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)602 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
603 const void* proto, size_t proto_len,
604 TF_Status* status) {
605 tensorflow::AttrValue attr_value;
606 if (!attr_value.ParseFromArray(proto, proto_len)) {
607 status->status = InvalidArgument(
608 "Unparseable AttrValue proto passed to "
609 "TF_FunctionSetAttrValueProto");
610 return;
611 }
612 (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
613 status->status = tensorflow::Status::OK();
614 }
615
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)616 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
617 TF_Buffer* output_attr_value,
618 TF_Status* status) {
619 const auto& it = func->fdef.attr().find(attr_name);
620 if (it == func->fdef.attr().end()) {
621 status->status =
622 InvalidArgument("Function '", func->fdef.signature().name(),
623 "' has no attr named '", attr_name, "'.");
624 return;
625 }
626 status->status = MessageToBuffer(it->second, output_attr_value);
627 }
628
// Releases the memory owned by `func`; passing nullptr is a no-op (delete).
void TF_DeleteFunction(TF_Function* func) { delete func; }
630