1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_def_builder.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/graph/tensor_id.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 
23 namespace tensorflow {
24 
Options(Graph * graph,Status * status)25 GraphDefBuilder::Options::Options(Graph* graph, Status* status)
26     : graph_(graph), status_(status) {}
~Options()27 GraphDefBuilder::Options::~Options() {}
28 
WithName(StringPiece name) const29 GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
30     StringPiece name) const {
31   return Options(*this).WithNameImpl(name);
32 }
WithDevice(StringPiece device) const33 GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
34     StringPiece device) const {
35   return Options(*this).WithDeviceImpl(device);
36 }
WithControlInput(Node * control_input) const37 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
38     Node* control_input) const {
39   return Options(*this).WithControlInputImpl(control_input);
40 }
WithControlInputs(gtl::ArraySlice<Node * > control_inputs) const41 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
42     gtl::ArraySlice<Node*> control_inputs) const {
43   return Options(*this).WithControlInputsImpl(control_inputs);
44 }
WithNameImpl(StringPiece name)45 GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
46     StringPiece name) {
47   name_ = string(name);
48   return *this;
49 }
WithDeviceImpl(StringPiece device)50 GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
51     StringPiece device) {
52   device_ = string(device);
53   return *this;
54 }
WithControlInputImpl(Node * control_input)55 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
56     Node* control_input) {
57   control_inputs_.push_back(control_input);
58   return *this;
59 }
WithControlInputsImpl(gtl::ArraySlice<Node * > control_inputs)60 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
61     gtl::ArraySlice<Node*> control_inputs) {
62   control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
63                          control_inputs.end());
64   return *this;
65 }
66 
ToGraphDef(GraphDef * graph_def) const67 Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
68   if (status_.ok()) {
69     graph_.ToGraphDef(graph_def);
70   }
71   return status_;
72 }
73 
GetNameForOp(StringPiece op) const74 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
75   if (name_.empty()) return graph_->NewName(op);
76   return name_;
77 }
78 
FinalizeBuilder(NodeBuilder * builder) const79 Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
80   builder->ControlInputs(control_inputs_);
81   if (!device_.empty()) builder->Device(device_);
82   for (const auto& attr : attrs_) {
83     builder->Attr(attr.first, attr.second);
84   }
85 
86   Node* returned_node;
87   UpdateStatus(builder->Finalize(graph_, &returned_node));
88   return returned_node;
89 }
90 
UpdateStatus(const Status & status) const91 void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
92   if (status_ == nullptr) {
93     TF_CHECK_OK(status);
94   } else {
95     status_->Update(status);
96   }
97 }
98 
99 namespace ops {
100 
SourceOp(const string & op_name,const GraphDefBuilder::Options & opts)101 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
102   if (opts.HaveError()) return nullptr;
103   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
104                            opts.op_registry());
105   return opts.FinalizeBuilder(&node_builder);
106 }
107 
UnaryOp(const string & op_name,NodeOut input,const GraphDefBuilder::Options & opts)108 Node* UnaryOp(const string& op_name, NodeOut input,
109               const GraphDefBuilder::Options& opts) {
110   if (opts.HaveError()) return nullptr;
111   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
112                            opts.op_registry());
113   node_builder.Input(std::move(input));
114   return opts.FinalizeBuilder(&node_builder);
115 }
116 
BinaryOp(const string & op_name,NodeOut a,NodeOut b,const GraphDefBuilder::Options & opts)117 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
118                const GraphDefBuilder::Options& opts) {
119   if (opts.HaveError()) return nullptr;
120   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
121                            opts.op_registry());
122   node_builder.Input(std::move(a)).Input(std::move(b));
123   return opts.FinalizeBuilder(&node_builder);
124 }
125 
126 }  // end namespace ops
127 }  // end namespace tensorflow
128