1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
18 
19 #include <functional>
20 #include <vector>
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/graph/graph.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/array_slice.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 
32 namespace tensorflow {
33 
34 class NodeDefBuilder;
35 typedef std::function<Status(const OpDef&, int, const NodeDef&,
36                              NodeDefBuilder*)>
37     FakeInputFunctor;
38 
39 // This is a helper for creating a NodeDef.  Automatically sets attrs
40 // that can be inferred from the inputs, and uses default values
41 // (where they exist) for unspecified attrs.  Example usage:
42 //
43 //  NodeDef node_def;
44 //  Status status = NodeDefBuilder(node_name, op_name)
45 //                           .Input(...)
46 //                           .Attr(...)
47 //                           .Finalize(&node_def);
48 //  if (!status.ok()) return status;
49 //  // Use node_def here.
50 class NodeDefBuilder {
51  public:
52   // To specify an output to be consumed by one of the Input() methods below.
53   struct NodeOut {
54     NodeOut(StringPiece n, int i, DataType dt);
55     NodeOut();  // uninitialized, call Reset() before use.
56     void Reset(StringPiece n, int i, DataType dt);
57     string node;
58     int index;
59     DataType data_type;
60   };
61 
62   // Specify the name and the Op (either via an OpDef or the name of
63   // the Op plus a registry) for the NodeDef.  Other fields are
64   // specified by calling the methods below.
65   // REQUIRES: The OpDef must satisfy ValidateOpDef().
66   NodeDefBuilder(StringPiece name, StringPiece op_name,
67                  const OpRegistryInterface* op_registry = OpRegistry::Global(),
68                  const NodeDebugInfo* debug = nullptr);
69   NodeDefBuilder(StringPiece name, StringPiece op_name,
70                  const NodeDebugInfo& debug);
71   // REQUIRES: in addition, *op_def must outlive *this.
72   NodeDefBuilder(StringPiece name, const OpDef* op_def);
73 
74   // You must call one Input() function per input_arg in the Op,
75   // *and in the same order as the input_args appear in the OpDef.*
76 
77   // For inputs that take a single tensor.
78   NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt);
79   NodeDefBuilder& Input(const NodeOut& src);
80 
81   // For inputs that take a list of tensors.
82   NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
83 
84   // To create inputs in tests, see fake_input.h.
85   NodeDefBuilder& Input(FakeInputFunctor fake_input);
86 
87   // Specify that this node must only run after src_node.
88   NodeDefBuilder& ControlInput(StringPiece src_node);
89 
90   // Constrains what devices this node may be scheduled on.
91   NodeDefBuilder& Device(StringPiece device_spec);
92 
93   // Sets the attr, if not already set.  If already set with a different
94   // value, an error will be returned from Finalize().
95   NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
96   NodeDefBuilder& Attr(StringPiece name, StringPiece value);
97   NodeDefBuilder& Attr(StringPiece name, const char* value);
98   NodeDefBuilder& Attr(StringPiece name, int32 value);
99   NodeDefBuilder& Attr(StringPiece name, int64 value);
100   NodeDefBuilder& Attr(StringPiece name, float value);
101   NodeDefBuilder& Attr(StringPiece name, double value);
102   NodeDefBuilder& Attr(StringPiece name, bool value);
103   NodeDefBuilder& Attr(StringPiece name, DataType value);
104   NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value);
105   NodeDefBuilder& Attr(StringPiece name, const Tensor& value);
106   NodeDefBuilder& Attr(StringPiece name, const TensorProto& value);
107   NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value);
108   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
109   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
110   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
111   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
112   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
113   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
114   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value);
115   NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value);
116   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value);
117   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value);
118   NodeDefBuilder& Attr(StringPiece name,
119                        gtl::ArraySlice<PartialTensorShape> value);
120   NodeDefBuilder& Attr(StringPiece name,
121                        gtl::ArraySlice<TensorShapeProto> value);
122   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value);
123   NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value);
124 
125   template <class T>
Attr(StringPiece name,std::initializer_list<T> value)126   NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) {
127     return Attr(name, gtl::ArraySlice<T>(value));
128   }
129 
130   // Finish building the NodeDef, returning any errors or setting
131   // *node_def if none.
132   // WARNING: Not all problems are detected!  The resulting NodeDef may
133   // not be valid!  Call ValidateNodeDef() from node_def_utils to be sure.
134   Status Finalize(NodeDef* node_def) const;
135 
136   // Accessors for the values set in the constructor.
node_name()137   const string& node_name() const { return node_def_.name(); }
op_def()138   const OpDef& op_def() const { return *op_def_; }
139 
140  private:
141   // Called in the constructors.
142   void Initialize();
143 
144   // Get the current ArgDef and advance to the next one. Returns nullptr
145   // if no more inputs are available.
146   const OpDef::ArgDef* NextArgDef();
147 
148   // Returns true if there is still an input_arg available in *op_def_,
149   // otherwise adds to error_ and returns false.
150   bool NextArgAvailable();
151 
152   // These do the main work of the Input() methods.
153   void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node,
154                    int src_index, DataType dt);
155   void ListInput(const OpDef::ArgDef* input_arg,
156                  gtl::ArraySlice<NodeOut> src_list);
157 
158   // Add "src_node:src_index" to the list of inputs in the node_def_.
159   void AddInput(StringPiece src_node, int src_index);
160 
161   // Generate an error if you can't pass dt when expected is expected.
162   void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
163                        DataType dt);
164 
165   // If input_arg->is_ref() is true, generate an error if dt is not a ref.
166   void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt);
167 
168   // Makes dt a ref type if that is what the input_arg specifies.
MaybeAddRef(const OpDef::ArgDef * input_arg,DataType dt)169   DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) {
170     return input_arg->is_ref() ? MakeRefType(dt) : dt;
171   }
172 
173   const OpDef* op_def_;
174   NodeDef node_def_;
175   int inputs_specified_;
176   std::vector<string> control_inputs_;
177   std::vector<string> errors_;
178 };
179 
180 }  // namespace tensorflow
181 
182 #endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
183