1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/testlib.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def_builder.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/node_builder.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace tensorflow {
32 namespace test {
33 namespace graph {
34 
Send(Graph * g,Node * input,const string & tensor,const string & sender,const uint64 sender_incarnation,const string & receiver)35 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
36            const uint64 sender_incarnation, const string& receiver) {
37   Node* ret;
38   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
39                   .Input(input, 0)
40                   .Attr("tensor_name", tensor)
41                   .Attr("send_device", sender)
42                   .Attr("send_device_incarnation",
43                         static_cast<int64>(sender_incarnation))
44                   .Attr("recv_device", receiver)
45                   .Finalize(g, &ret));
46   return ret;
47 }
48 
Recv(Graph * g,const string & tensor,const string & type,const string & sender,const uint64 sender_incarnation,const string & receiver)49 Node* Recv(Graph* g, const string& tensor, const string& type,
50            const string& sender, const uint64 sender_incarnation,
51            const string& receiver) {
52   Node* ret;
53   DataType dtype;
54   CHECK(DataTypeFromString(type, &dtype));
55   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
56                   .Attr("tensor_type", dtype)
57                   .Attr("tensor_name", tensor)
58                   .Attr("send_device", sender)
59                   .Attr("send_device_incarnation",
60                         static_cast<int64>(sender_incarnation))
61                   .Attr("recv_device", receiver)
62                   .Finalize(g, &ret));
63   return ret;
64 }
65 
Constant(Graph * g,const Tensor & tensor)66 Node* Constant(Graph* g, const Tensor& tensor) {
67   Node* ret;
68   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
69                   .Attr("dtype", tensor.dtype())
70                   .Attr("value", tensor)
71                   .Finalize(g, &ret));
72   return ret;
73 }
74 
Constant(Graph * g,const Tensor & tensor,const string & name)75 Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
76   Node* ret;
77   TF_CHECK_OK(NodeBuilder(name, "Const")
78                   .Attr("dtype", tensor.dtype())
79                   .Attr("value", tensor)
80                   .Finalize(g, &ret));
81   return ret;
82 }
83 
HostConstant(Graph * g,const Tensor & tensor)84 Node* HostConstant(Graph* g, const Tensor& tensor) {
85   return HostConstant(g, tensor, g->NewName("n"));
86 }
87 
HostConstant(Graph * g,const Tensor & tensor,const string & name)88 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
89   Node* ret;
90   TF_CHECK_OK(NodeBuilder(name, "HostConst")
91                   .Attr("dtype", tensor.dtype())
92                   .Attr("value", tensor)
93                   .Finalize(g, &ret));
94   return ret;
95 }
96 
Var(Graph * g,const DataType dtype,const TensorShape & shape)97 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
98   Node* ret;
99   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
100                   .Attr("dtype", dtype)
101                   .Attr("shape", shape)
102                   .Finalize(g, &ret));
103   return ret;
104 }
105 
Var(Graph * g,const DataType dtype,const TensorShape & shape,const string & name)106 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
107           const string& name) {
108   Node* ret;
109   TF_CHECK_OK(NodeBuilder(name, "Variable")
110                   .Attr("dtype", dtype)
111                   .Attr("shape", shape)
112                   .Finalize(g, &ret));
113   return ret;
114 }
115 
Assign(Graph * g,Node * var,Node * val)116 Node* Assign(Graph* g, Node* var, Node* val) {
117   Node* ret;
118   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
119                   .Input(var)
120                   .Input(val)
121                   .Attr("use_locking", true)
122                   .Finalize(g, &ret));
123   return ret;
124 }
125 
Cumsum(Graph * g,Node * data,Node * axes,bool exclusive,bool reverse)126 Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive, bool reverse) {
127   Node* ret;
128   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cumsum")
129                   .Input(data)
130                   .Input(axes)
131                   .Attr("exclusive", exclusive)
132                   .Attr("reverse", reverse)
133                   .Finalize(g, &ret));
134   return ret;
135 }
136 
Reduce(Graph * g,const string & reduce,Node * data,Node * axes,bool keep_dims)137 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
138              bool keep_dims) {
139   Node* ret;
140   TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
141                   .Input(data)
142                   .Input(axes)
143                   .Attr("keep_dims", keep_dims)
144                   .Finalize(g, &ret));
145   return ret;
146 }
147 
QuantizeToUINT8(Graph * g,Node * data)148 Node* QuantizeToUINT8(Graph* g, Node* data) {
149   Node* ret;
150   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
151                   .Input(data)
152                   .Attr("T", DT_QUINT8)
153                   .Attr("max_range", 1.0f)
154                   .Attr("min_range", -1.0f)
155                   .Finalize(g, &ret));
156   return ret;
157 }
158 
Matmul(Graph * g,Node * in0,Node * in1,bool transpose_a,bool transpose_b)159 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
160              bool transpose_b) {
161   Node* ret;
162   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
163                   .Input(in0)
164                   .Input(in1)
165                   .Attr("transpose_a", transpose_a)
166                   .Attr("transpose_b", transpose_b)
167                   .Finalize(g, &ret));
168   return ret;
169 }
170 
BatchMatmul(Graph * g,Node * in0,Node * in1,bool adj_x,bool adj_y)171 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
172   Node* ret;
173   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
174                   .Input(in0)
175                   .Input(in1)
176                   .Attr("adj_x", adj_x)
177                   .Attr("adj_y", adj_y)
178                   .Finalize(g, &ret));
179   return ret;
180 }
181 
RandomNumberGenerator(const string & op,Graph * g,Node * input,DataType dtype)182 Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
183                             DataType dtype) {
184   Node* ret;
185   TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
186                   .Input(input)
187                   .Attr("dtype", dtype)
188                   .Attr("seed", 0)
189                   .Finalize(g, &ret));
190   return ret;
191 }
192 
RandomUniform(Graph * g,Node * input,DataType dtype)193 Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
194   return RandomNumberGenerator("RandomUniform", g, input, dtype);
195 }
196 
RandomGaussian(Graph * g,Node * input,DataType dtype)197 Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
198   return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
199 }
200 
TruncatedNormal(Graph * g,Node * input,DataType dtype)201 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
202   return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
203 }
204 
RandomGamma(Graph * g,Node * shape,Node * alpha)205 Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
206   Node* ret;
207   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
208                   .Input(shape)
209                   .Input(alpha)
210                   .Attr("seed", 0)
211                   .Finalize(g, &ret));
212   return ret;
213 }
214 
RandomPoisson(Graph * g,Node * shape,Node * lam)215 Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
216   Node* ret;
217   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
218                   .Input(shape)
219                   .Input(lam)
220                   .Attr("seed", 0)
221                   .Finalize(g, &ret));
222   return ret;
223 }
224 
Unary(Graph * g,const string & func,Node * input,int index)225 Node* Unary(Graph* g, const string& func, Node* input, int index) {
226   Node* ret;
227   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
228                   .Input(input, index)
229                   .Finalize(g, &ret));
230   return ret;
231 }
232 
Binary(Graph * g,const string & func,Node * in0,Node * in1)233 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
234   Node* ret;
235   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
236                   .Input(in0)
237                   .Input(in1)
238                   .Finalize(g, &ret));
239   return ret;
240 }
241 
Multi(Graph * g,const string & func,gtl::ArraySlice<Node * > ins)242 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
243   Node* ret;
244   auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
245   for (Node* n : ins) b = b.Input(n);
246   TF_CHECK_OK(b.Finalize(g, &ret));
247   return ret;
248 }
249 
Identity(Graph * g,Node * input,int index)250 Node* Identity(Graph* g, Node* input, int index) {
251   Node* ret;
252   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
253                   .Input(input, index)
254                   .Finalize(g, &ret));
255   return ret;
256 }
257 
Add(Graph * g,Node * in0,Node * in1)258 Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
259 
Reverse(Graph * g,Node * tensor,Node * axis)260 Node* Reverse(Graph* g, Node* tensor, Node* axis) {
261   return Binary(g, "ReverseV2", tensor, axis);
262 }
263 
Roll(Graph * g,Node * input,Node * shift,Node * axis)264 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
265   Node* ret;
266   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
267                   .Input(input)
268                   .Input(shift)
269                   .Input(axis)
270                   .Finalize(g, &ret));
271   return ret;
272 }
273 
Error(Graph * g,Node * input,const string & errmsg)274 Node* Error(Graph* g, Node* input, const string& errmsg) {
275   Node* ret;
276   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
277                   .Input(input)
278                   .Attr("message", errmsg)
279                   .Finalize(g, &ret));
280   return ret;
281 }
282 
InvalidRefType(Graph * g,DataType out_type,DataType invalid_type)283 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
284   DCHECK(out_type != invalid_type);
285   Node* ret;
286   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
287                   .Attr("TIn", out_type)
288                   .Attr("TOut", invalid_type)
289                   .Finalize(g, &ret));
290   return ret;
291 }
292 
Delay(Graph * g,Node * input,Microseconds delay_micros)293 Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
294   Node* ret;
295   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
296                   .Input(input)
297                   .Attr("micros", delay_micros.value())
298                   .Finalize(g, &ret));
299   return ret;
300 }
301 
NoOp(Graph * g,const std::vector<Node * > & control_inputs)302 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
303   Node* ret;
304   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
305                   .ControlInputs(control_inputs)
306                   .Finalize(g, &ret));
307   return ret;
308 }
309 
Switch(Graph * g,Node * in0,Node * in1)310 Node* Switch(Graph* g, Node* in0, Node* in1) {
311   Node* ret;
312   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
313                   .Input(in0)
314                   .Input(in1)
315                   .Finalize(g, &ret));
316   return ret;
317 }
318 
Enter(Graph * g,Node * input,const string & frame_name)319 Node* Enter(Graph* g, Node* input, const string& frame_name) {
320   Node* ret;
321   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
322                   .Input(input)
323                   .Attr("frame_name", frame_name)
324                   .Finalize(g, &ret));
325   return ret;
326 }
327 
Exit(Graph * g,Node * input)328 Node* Exit(Graph* g, Node* input) {
329   Node* ret;
330   TF_CHECK_OK(
331       NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
332   return ret;
333 }
334 
Merge(Graph * g,Node * in0,Node * in1)335 Node* Merge(Graph* g, Node* in0, Node* in1) {
336   Node* ret;
337   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
338                   .Input({in0, in1})
339                   .Finalize(g, &ret));
340   return ret;
341 }
342 
Merge(Graph * g,Node * in0,gtl::ArraySlice<string> remaining_in)343 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
344   std::vector<NodeBuilder::NodeOut> inputs;
345   inputs.reserve(remaining_in.size() + 1);
346   inputs.emplace_back(in0);
347   for (const string& in_name : remaining_in) {
348     inputs.emplace_back(in_name, 0, inputs[0].dt);
349   }
350 
351   Node* ret;
352   TF_CHECK_OK(
353       NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
354   return ret;
355 }
356 
Concat(Graph * g,Node * concat_dim,gtl::ArraySlice<Node * > tensors)357 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
358   std::vector<NodeBuilder::NodeOut> nodeouts;
359   nodeouts.reserve(tensors.size());
360   for (auto const t : tensors) {
361     nodeouts.emplace_back(t);
362   }
363   Node* ret;
364   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
365                   .Input(concat_dim)
366                   .Input(nodeouts)
367                   .Finalize(g, &ret));
368   return ret;
369 }
370 
ConcatV2(Graph * g,gtl::ArraySlice<Node * > tensors,Node * concat_dim)371 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
372   std::vector<NodeBuilder::NodeOut> nodeouts;
373   nodeouts.reserve(tensors.size());
374   for (auto const t : tensors) {
375     nodeouts.emplace_back(t);
376   }
377   Node* ret;
378   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
379                   .Input(nodeouts)
380                   .Input(concat_dim)
381                   .Finalize(g, &ret));
382   return ret;
383 }
384 
Next(Graph * g,const string & name,Node * input)385 Node* Next(Graph* g, const string& name, Node* input) {
386   Node* ret;
387   TF_CHECK_OK(
388       NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
389   return ret;
390 }
391 
LoopCond(Graph * g,Node * input)392 Node* LoopCond(Graph* g, Node* input) {
393   Node* ret;
394   TF_CHECK_OK(
395       NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
396   return ret;
397 }
398 
Less(Graph * g,Node * in0,Node * in1)399 Node* Less(Graph* g, Node* in0, Node* in1) {
400   return Binary(g, "Less", in0, in1);
401 }
402 
Select(Graph * g,Node * c,Node * inx,Node * iny)403 Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
404   Node* ret;
405   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
406                   .Input(c)
407                   .Input(inx)
408                   .Input(iny)
409                   .Finalize(g, &ret));
410   return ret;
411 }
412 
Cast(Graph * g,Node * in,DataType dst)413 Node* Cast(Graph* g, Node* in, DataType dst) {
414   Node* ret;
415   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
416                   .Input(in)
417                   .Attr("DstT", dst)
418                   .Finalize(g, &ret));
419   return ret;
420 }
421 
Gather(Graph * g,Node * in0,Node * in1,Node * axis)422 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) {
423   Node* ret;
424   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2")
425                   .Input(in0)
426                   .Input(in1)
427                   .Input(axis)
428                   .Finalize(g, &ret));
429   return ret;
430 }
431 
GetSessionTensor(Graph * g,Node * in)432 Node* GetSessionTensor(Graph* g, Node* in) {
433   Node* ret;
434   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
435                   .Input(in, 0)
436                   .Attr("dtype", DT_FLOAT)
437                   .Finalize(g, &ret));
438   return ret;
439 }
440 
Relu(Graph * g,Node * in)441 Node* Relu(Graph* g, Node* in) {
442   Node* ret;
443   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
444                   .Input(in, 0)
445                   .Attr("T", DT_FLOAT)
446                   .Finalize(g, &ret));
447   return ret;
448 }
449 
Relu6(Graph * g,Node * in)450 Node* Relu6(Graph* g, Node* in) {
451   Node* ret;
452   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
453                   .Input(in, 0)
454                   .Attr("T", DT_FLOAT)
455                   .Finalize(g, &ret));
456   return ret;
457 }
458 
BiasAdd(Graph * g,Node * value,Node * bias)459 Node* BiasAdd(Graph* g, Node* value, Node* bias) {
460   Node* ret;
461   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
462                   .Input(value)
463                   .Input(bias)
464                   .Attr("T", DT_FLOAT)
465                   .Finalize(g, &ret));
466   return ret;
467 }
468 
Conv2D(Graph * g,Node * in0,Node * in1)469 Node* Conv2D(Graph* g, Node* in0, Node* in1) {
470   Node* ret;
471   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
472                   .Input(in0)
473                   .Input(in1)
474                   .Attr("T", DT_FLOAT)
475                   .Attr("strides", {1, 1, 1, 1})
476                   .Attr("padding", "SAME")
477                   .Finalize(g, &ret));
478   return ret;
479 }
480 
Diag(Graph * g,Node * in,DataType type)481 Node* Diag(Graph* g, Node* in, DataType type) {
482   Node* ret;
483   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag")
484                   .Input(in)
485                   .Attr("T", type)
486                   .Finalize(g, &ret));
487   return ret;
488 }
489 
DiagPart(Graph * g,Node * in,DataType type)490 Node* DiagPart(Graph* g, Node* in, DataType type) {
491   Node* ret;
492   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart")
493                   .Input(in)
494                   .Attr("T", type)
495                   .Finalize(g, &ret));
496   return ret;
497 }
498 
CheckNumerics(Graph * g,Node * in,const string & message)499 Node* CheckNumerics(Graph* g, Node* in, const string& message) {
500   Node* ret;
501   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
502                   .Input(in)
503                   .Attr("message", message)
504                   .Finalize(g, &ret));
505   return ret;
506 }
507 
Arg(Graph * g,int64 index,DataType type)508 Node* Arg(Graph* g, int64 index, DataType type) {
509   Node* ret;
510   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
511                   .Attr("T", type)
512                   .Attr("index", index)
513                   .Finalize(g, &ret));
514   return ret;
515 }
516 
Retval(Graph * g,int64 index,Node * in)517 Node* Retval(Graph* g, int64 index, Node* in) {
518   Node* ret;
519   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
520                   .Input(in)
521                   .Attr("index", index)
522                   .Finalize(g, &ret));
523   return ret;
524 }
525 
ToGraphDef(Graph * g,GraphDef * gdef)526 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
527 
528 }  // end namespace graph
529 }  // end namespace test
530 }  // end namespace tensorflow
531