1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/testlib.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def_builder.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/node_builder.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/logging.h"
30
31 namespace tensorflow {
32 namespace test {
33 namespace graph {
34
Send(Graph * g,Node * input,const string & tensor,const string & sender,const uint64 sender_incarnation,const string & receiver)35 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
36 const uint64 sender_incarnation, const string& receiver) {
37 Node* ret;
38 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
39 .Input(input, 0)
40 .Attr("tensor_name", tensor)
41 .Attr("send_device", sender)
42 .Attr("send_device_incarnation",
43 static_cast<int64>(sender_incarnation))
44 .Attr("recv_device", receiver)
45 .Finalize(g, &ret));
46 return ret;
47 }
48
Recv(Graph * g,const string & tensor,const string & type,const string & sender,const uint64 sender_incarnation,const string & receiver)49 Node* Recv(Graph* g, const string& tensor, const string& type,
50 const string& sender, const uint64 sender_incarnation,
51 const string& receiver) {
52 Node* ret;
53 DataType dtype;
54 CHECK(DataTypeFromString(type, &dtype));
55 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
56 .Attr("tensor_type", dtype)
57 .Attr("tensor_name", tensor)
58 .Attr("send_device", sender)
59 .Attr("send_device_incarnation",
60 static_cast<int64>(sender_incarnation))
61 .Attr("recv_device", receiver)
62 .Finalize(g, &ret));
63 return ret;
64 }
65
Constant(Graph * g,const Tensor & tensor)66 Node* Constant(Graph* g, const Tensor& tensor) {
67 Node* ret;
68 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
69 .Attr("dtype", tensor.dtype())
70 .Attr("value", tensor)
71 .Finalize(g, &ret));
72 return ret;
73 }
74
Constant(Graph * g,const Tensor & tensor,const string & name)75 Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
76 Node* ret;
77 TF_CHECK_OK(NodeBuilder(name, "Const")
78 .Attr("dtype", tensor.dtype())
79 .Attr("value", tensor)
80 .Finalize(g, &ret));
81 return ret;
82 }
83
HostConstant(Graph * g,const Tensor & tensor)84 Node* HostConstant(Graph* g, const Tensor& tensor) {
85 return HostConstant(g, tensor, g->NewName("n"));
86 }
87
HostConstant(Graph * g,const Tensor & tensor,const string & name)88 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
89 Node* ret;
90 TF_CHECK_OK(NodeBuilder(name, "HostConst")
91 .Attr("dtype", tensor.dtype())
92 .Attr("value", tensor)
93 .Finalize(g, &ret));
94 return ret;
95 }
96
Var(Graph * g,const DataType dtype,const TensorShape & shape)97 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
98 Node* ret;
99 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
100 .Attr("dtype", dtype)
101 .Attr("shape", shape)
102 .Finalize(g, &ret));
103 return ret;
104 }
105
Var(Graph * g,const DataType dtype,const TensorShape & shape,const string & name)106 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
107 const string& name) {
108 Node* ret;
109 TF_CHECK_OK(NodeBuilder(name, "Variable")
110 .Attr("dtype", dtype)
111 .Attr("shape", shape)
112 .Finalize(g, &ret));
113 return ret;
114 }
115
Assign(Graph * g,Node * var,Node * val)116 Node* Assign(Graph* g, Node* var, Node* val) {
117 Node* ret;
118 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
119 .Input(var)
120 .Input(val)
121 .Attr("use_locking", true)
122 .Finalize(g, &ret));
123 return ret;
124 }
125
Cumsum(Graph * g,Node * data,Node * axes,bool exclusive,bool reverse)126 Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive, bool reverse) {
127 Node* ret;
128 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cumsum")
129 .Input(data)
130 .Input(axes)
131 .Attr("exclusive", exclusive)
132 .Attr("reverse", reverse)
133 .Finalize(g, &ret));
134 return ret;
135 }
136
Reduce(Graph * g,const string & reduce,Node * data,Node * axes,bool keep_dims)137 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
138 bool keep_dims) {
139 Node* ret;
140 TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
141 .Input(data)
142 .Input(axes)
143 .Attr("keep_dims", keep_dims)
144 .Finalize(g, &ret));
145 return ret;
146 }
147
QuantizeToUINT8(Graph * g,Node * data)148 Node* QuantizeToUINT8(Graph* g, Node* data) {
149 Node* ret;
150 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
151 .Input(data)
152 .Attr("T", DT_QUINT8)
153 .Attr("max_range", 1.0f)
154 .Attr("min_range", -1.0f)
155 .Finalize(g, &ret));
156 return ret;
157 }
158
Matmul(Graph * g,Node * in0,Node * in1,bool transpose_a,bool transpose_b)159 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
160 bool transpose_b) {
161 Node* ret;
162 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
163 .Input(in0)
164 .Input(in1)
165 .Attr("transpose_a", transpose_a)
166 .Attr("transpose_b", transpose_b)
167 .Finalize(g, &ret));
168 return ret;
169 }
170
BatchMatmul(Graph * g,Node * in0,Node * in1,bool adj_x,bool adj_y)171 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
172 Node* ret;
173 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
174 .Input(in0)
175 .Input(in1)
176 .Attr("adj_x", adj_x)
177 .Attr("adj_y", adj_y)
178 .Finalize(g, &ret));
179 return ret;
180 }
181
RandomNumberGenerator(const string & op,Graph * g,Node * input,DataType dtype)182 Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
183 DataType dtype) {
184 Node* ret;
185 TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
186 .Input(input)
187 .Attr("dtype", dtype)
188 .Attr("seed", 0)
189 .Finalize(g, &ret));
190 return ret;
191 }
192
RandomUniform(Graph * g,Node * input,DataType dtype)193 Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
194 return RandomNumberGenerator("RandomUniform", g, input, dtype);
195 }
196
RandomGaussian(Graph * g,Node * input,DataType dtype)197 Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
198 return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
199 }
200
TruncatedNormal(Graph * g,Node * input,DataType dtype)201 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
202 return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
203 }
204
RandomGamma(Graph * g,Node * shape,Node * alpha)205 Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
206 Node* ret;
207 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
208 .Input(shape)
209 .Input(alpha)
210 .Attr("seed", 0)
211 .Finalize(g, &ret));
212 return ret;
213 }
214
RandomPoisson(Graph * g,Node * shape,Node * lam)215 Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
216 Node* ret;
217 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
218 .Input(shape)
219 .Input(lam)
220 .Attr("seed", 0)
221 .Finalize(g, &ret));
222 return ret;
223 }
224
Unary(Graph * g,const string & func,Node * input,int index)225 Node* Unary(Graph* g, const string& func, Node* input, int index) {
226 Node* ret;
227 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
228 .Input(input, index)
229 .Finalize(g, &ret));
230 return ret;
231 }
232
Binary(Graph * g,const string & func,Node * in0,Node * in1)233 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
234 Node* ret;
235 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
236 .Input(in0)
237 .Input(in1)
238 .Finalize(g, &ret));
239 return ret;
240 }
241
Multi(Graph * g,const string & func,gtl::ArraySlice<Node * > ins)242 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
243 Node* ret;
244 auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
245 for (Node* n : ins) b = b.Input(n);
246 TF_CHECK_OK(b.Finalize(g, &ret));
247 return ret;
248 }
249
Identity(Graph * g,Node * input,int index)250 Node* Identity(Graph* g, Node* input, int index) {
251 Node* ret;
252 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
253 .Input(input, index)
254 .Finalize(g, &ret));
255 return ret;
256 }
257
Add(Graph * g,Node * in0,Node * in1)258 Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
259
Reverse(Graph * g,Node * tensor,Node * axis)260 Node* Reverse(Graph* g, Node* tensor, Node* axis) {
261 return Binary(g, "ReverseV2", tensor, axis);
262 }
263
Roll(Graph * g,Node * input,Node * shift,Node * axis)264 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
265 Node* ret;
266 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
267 .Input(input)
268 .Input(shift)
269 .Input(axis)
270 .Finalize(g, &ret));
271 return ret;
272 }
273
Error(Graph * g,Node * input,const string & errmsg)274 Node* Error(Graph* g, Node* input, const string& errmsg) {
275 Node* ret;
276 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
277 .Input(input)
278 .Attr("message", errmsg)
279 .Finalize(g, &ret));
280 return ret;
281 }
282
InvalidRefType(Graph * g,DataType out_type,DataType invalid_type)283 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
284 DCHECK(out_type != invalid_type);
285 Node* ret;
286 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
287 .Attr("TIn", out_type)
288 .Attr("TOut", invalid_type)
289 .Finalize(g, &ret));
290 return ret;
291 }
292
Delay(Graph * g,Node * input,Microseconds delay_micros)293 Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
294 Node* ret;
295 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
296 .Input(input)
297 .Attr("micros", delay_micros.value())
298 .Finalize(g, &ret));
299 return ret;
300 }
301
NoOp(Graph * g,const std::vector<Node * > & control_inputs)302 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
303 Node* ret;
304 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
305 .ControlInputs(control_inputs)
306 .Finalize(g, &ret));
307 return ret;
308 }
309
Switch(Graph * g,Node * in0,Node * in1)310 Node* Switch(Graph* g, Node* in0, Node* in1) {
311 Node* ret;
312 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
313 .Input(in0)
314 .Input(in1)
315 .Finalize(g, &ret));
316 return ret;
317 }
318
Enter(Graph * g,Node * input,const string & frame_name)319 Node* Enter(Graph* g, Node* input, const string& frame_name) {
320 Node* ret;
321 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
322 .Input(input)
323 .Attr("frame_name", frame_name)
324 .Finalize(g, &ret));
325 return ret;
326 }
327
Exit(Graph * g,Node * input)328 Node* Exit(Graph* g, Node* input) {
329 Node* ret;
330 TF_CHECK_OK(
331 NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
332 return ret;
333 }
334
Merge(Graph * g,Node * in0,Node * in1)335 Node* Merge(Graph* g, Node* in0, Node* in1) {
336 Node* ret;
337 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
338 .Input({in0, in1})
339 .Finalize(g, &ret));
340 return ret;
341 }
342
Merge(Graph * g,Node * in0,gtl::ArraySlice<string> remaining_in)343 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
344 std::vector<NodeBuilder::NodeOut> inputs;
345 inputs.reserve(remaining_in.size() + 1);
346 inputs.emplace_back(in0);
347 for (const string& in_name : remaining_in) {
348 inputs.emplace_back(in_name, 0, inputs[0].dt);
349 }
350
351 Node* ret;
352 TF_CHECK_OK(
353 NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
354 return ret;
355 }
356
Concat(Graph * g,Node * concat_dim,gtl::ArraySlice<Node * > tensors)357 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
358 std::vector<NodeBuilder::NodeOut> nodeouts;
359 nodeouts.reserve(tensors.size());
360 for (auto const t : tensors) {
361 nodeouts.emplace_back(t);
362 }
363 Node* ret;
364 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
365 .Input(concat_dim)
366 .Input(nodeouts)
367 .Finalize(g, &ret));
368 return ret;
369 }
370
ConcatV2(Graph * g,gtl::ArraySlice<Node * > tensors,Node * concat_dim)371 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
372 std::vector<NodeBuilder::NodeOut> nodeouts;
373 nodeouts.reserve(tensors.size());
374 for (auto const t : tensors) {
375 nodeouts.emplace_back(t);
376 }
377 Node* ret;
378 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
379 .Input(nodeouts)
380 .Input(concat_dim)
381 .Finalize(g, &ret));
382 return ret;
383 }
384
Next(Graph * g,const string & name,Node * input)385 Node* Next(Graph* g, const string& name, Node* input) {
386 Node* ret;
387 TF_CHECK_OK(
388 NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
389 return ret;
390 }
391
LoopCond(Graph * g,Node * input)392 Node* LoopCond(Graph* g, Node* input) {
393 Node* ret;
394 TF_CHECK_OK(
395 NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
396 return ret;
397 }
398
Less(Graph * g,Node * in0,Node * in1)399 Node* Less(Graph* g, Node* in0, Node* in1) {
400 return Binary(g, "Less", in0, in1);
401 }
402
Select(Graph * g,Node * c,Node * inx,Node * iny)403 Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
404 Node* ret;
405 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
406 .Input(c)
407 .Input(inx)
408 .Input(iny)
409 .Finalize(g, &ret));
410 return ret;
411 }
412
Cast(Graph * g,Node * in,DataType dst)413 Node* Cast(Graph* g, Node* in, DataType dst) {
414 Node* ret;
415 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
416 .Input(in)
417 .Attr("DstT", dst)
418 .Finalize(g, &ret));
419 return ret;
420 }
421
Gather(Graph * g,Node * in0,Node * in1,Node * axis)422 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) {
423 Node* ret;
424 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2")
425 .Input(in0)
426 .Input(in1)
427 .Input(axis)
428 .Finalize(g, &ret));
429 return ret;
430 }
431
GetSessionTensor(Graph * g,Node * in)432 Node* GetSessionTensor(Graph* g, Node* in) {
433 Node* ret;
434 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
435 .Input(in, 0)
436 .Attr("dtype", DT_FLOAT)
437 .Finalize(g, &ret));
438 return ret;
439 }
440
Relu(Graph * g,Node * in)441 Node* Relu(Graph* g, Node* in) {
442 Node* ret;
443 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
444 .Input(in, 0)
445 .Attr("T", DT_FLOAT)
446 .Finalize(g, &ret));
447 return ret;
448 }
449
Relu6(Graph * g,Node * in)450 Node* Relu6(Graph* g, Node* in) {
451 Node* ret;
452 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
453 .Input(in, 0)
454 .Attr("T", DT_FLOAT)
455 .Finalize(g, &ret));
456 return ret;
457 }
458
BiasAdd(Graph * g,Node * value,Node * bias)459 Node* BiasAdd(Graph* g, Node* value, Node* bias) {
460 Node* ret;
461 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
462 .Input(value)
463 .Input(bias)
464 .Attr("T", DT_FLOAT)
465 .Finalize(g, &ret));
466 return ret;
467 }
468
Conv2D(Graph * g,Node * in0,Node * in1)469 Node* Conv2D(Graph* g, Node* in0, Node* in1) {
470 Node* ret;
471 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
472 .Input(in0)
473 .Input(in1)
474 .Attr("T", DT_FLOAT)
475 .Attr("strides", {1, 1, 1, 1})
476 .Attr("padding", "SAME")
477 .Finalize(g, &ret));
478 return ret;
479 }
480
Diag(Graph * g,Node * in,DataType type)481 Node* Diag(Graph* g, Node* in, DataType type) {
482 Node* ret;
483 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag")
484 .Input(in)
485 .Attr("T", type)
486 .Finalize(g, &ret));
487 return ret;
488 }
489
DiagPart(Graph * g,Node * in,DataType type)490 Node* DiagPart(Graph* g, Node* in, DataType type) {
491 Node* ret;
492 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart")
493 .Input(in)
494 .Attr("T", type)
495 .Finalize(g, &ret));
496 return ret;
497 }
498
CheckNumerics(Graph * g,Node * in,const string & message)499 Node* CheckNumerics(Graph* g, Node* in, const string& message) {
500 Node* ret;
501 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
502 .Input(in)
503 .Attr("message", message)
504 .Finalize(g, &ret));
505 return ret;
506 }
507
Arg(Graph * g,int64 index,DataType type)508 Node* Arg(Graph* g, int64 index, DataType type) {
509 Node* ret;
510 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
511 .Attr("T", type)
512 .Attr("index", index)
513 .Finalize(g, &ret));
514 return ret;
515 }
516
Retval(Graph * g,int64 index,Node * in)517 Node* Retval(Graph* g, int64 index, Node* in) {
518 Node* ret;
519 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
520 .Input(in)
521 .Attr("index", index)
522 .Finalize(g, &ret));
523 return ret;
524 }
525
ToGraphDef(Graph * g,GraphDef * gdef)526 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
527
528 } // end namespace graph
529 } // end namespace test
530 } // end namespace tensorflow
531