1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function_testlib.h"
17 
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/versions.pb.h"
22 #include "tensorflow/core/lib/core/threadpool.h"
23 #include "tensorflow/core/public/version.h"
24 
25 namespace tensorflow {
26 namespace test {
27 namespace function {
28 
29 typedef FunctionDefHelper FDH;
30 
GDef(gtl::ArraySlice<NodeDef> nodes,gtl::ArraySlice<FunctionDef> funcs)31 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
32               gtl::ArraySlice<FunctionDef> funcs) {
33   GraphDef g;
34   VersionDef* versions = g.mutable_versions();
35   versions->set_producer(TF_GRAPH_DEF_VERSION);
36   versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
37   for (const auto& n : nodes) {
38     *(g.add_node()) = n;
39   }
40   auto lib = g.mutable_library();
41   for (const auto& f : funcs) {
42     *(lib->add_function()) = f;
43   }
44   return g;
45 }
46 
47 // Helper to construct a NodeDef.
NDef(StringPiece name,StringPiece op,gtl::ArraySlice<string> inputs,gtl::ArraySlice<std::pair<string,FDH::AttrValueWrapper>> attrs,const string & device)48 NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
49              gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
50              const string& device) {
51   NodeDef n;
52   n.set_name(string(name));
53   n.set_op(string(op));
54   for (const auto& in : inputs) n.add_input(in);
55   n.set_device(device);
56   for (const auto& na : attrs)
57     n.mutable_attr()->insert({na.first, na.second.proto});
58   return n;
59 }
60 
NonZero()61 FunctionDef NonZero() {
62   return FDH::Define(
63       // Name
64       "NonZero",
65       // Args
66       {"x:T"},
67       // Return values
68       {"y:T"},
69       // Attr def
70       {"T:{float, double, int32, int64, string}"},
71       // Nodes
72       {
73           {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
74       });
75 }
76 
IsZero()77 FunctionDef IsZero() {
78   const Tensor kZero = test::AsScalar<int64>(0);
79   return FDH::Define(
80       // Name
81       "IsZero",
82       // Args
83       {"x: T"},
84       // Return values
85       {"equal: bool"},
86       // Attr def
87       {"T:{float, double, int32, int64, string}"},
88       {
89           {{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}},
90           {{"cast"}, "Cast", {"zero"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
91           {{"equal"}, "Equal", {"x", "cast"}, {{"T", "$T"}}},
92       });
93 }
94 
RandomUniform()95 FunctionDef RandomUniform() {
96   const Tensor kZero = test::AsScalar<int64>(0);
97 
98   return FDH::Define(
99       // Name
100       "RandomUniform",
101       // Args
102       {"x: T"},
103       // Return values
104       {"random_uniform: int64"},
105       // Attr def
106       {"T:{float, double, int32, int64, string}"},
107       {{{"random_uniform/shape"},
108         "Const",
109         {},
110         {{"value", kZero}, {"dtype", DT_INT64}}},
111        {{"random_uniform"},
112         "RandomUniform",
113         {"random_uniform/shape"},
114         {{"T", DT_INT32},
115          {"Tout", DT_FLOAT},
116          {"seed", 87654321},
117          {"seed2", 42}}}});
118 }
119 
XTimesTwo()120 FunctionDef XTimesTwo() {
121   const Tensor kTwo = test::AsScalar<int64>(2);
122   return FDH::Define(
123       // Name
124       "XTimesTwo",
125       // Args
126       {"x: T"},
127       // Return values
128       {"y: T"},
129       // Attr def
130       {"T: {float, double, int32, int64}"},
131       // Nodes
132       {
133           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
134           {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
135           {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
136       });
137 }
138 
TwoDeviceMult()139 FunctionDef TwoDeviceMult() {
140   const Tensor kTwo = test::AsScalar<int64>(2);
141   const Tensor kThree = test::AsScalar<int64>(3);
142   return FDH::Create(
143       // Name
144       "TwoDeviceMult",
145       // Args
146       {"x: T"},
147       // Return values
148       {"y_cpu: T", "y_gpu: T"},
149       // Attr def
150       {"T: {float, double, int32, int64}"},
151       // Nodes
152       {
153           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
154           {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_INT64}}},
155           {{"factor_2"},
156            "Cast",
157            {"num_2:output:0"},
158            {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
159           {{"factor_3"},
160            "Cast",
161            {"num_3:output:0"},
162            {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
163           {{"y_cpu"},
164            "Mul",
165            {"x", "factor_2:y:0"},
166            {{"T", "$T"}},
167            {},
168            "/device:CPU:0"},
169           {{"y_gpu"},
170            "Mul",
171            {"x", "factor_3:y:0"},
172            {{"T", "$T"}},
173            {},
174            "/device:GPU:0"},
175       },
176       {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
177 }
178 
TwoDeviceInputOutput()179 FunctionDef TwoDeviceInputOutput() {
180   const Tensor kTwo = test::AsScalar<float>(2);
181   const Tensor kThree = test::AsScalar<float>(3);
182   return FDH::Create(
183       // Name
184       "TwoDeviceInputOutput",
185       // Args
186       {"x1: T", "x2: T"},
187       // Return values
188       {"y_cpu: T", "y_gpu: T"},
189       // Attr def
190       {"T: {float}"},
191       // Nodes
192       {
193           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
194           {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_FLOAT}}},
195           {{"y_cpu"},
196            "Mul",
197            {"x1", "num_2:output:0"},
198            {{"T", "$T"}},
199            {},
200            "/device:CPU:0"},
201           {{"y_gpu"},
202            "Mul",
203            {"x2", "num_3:output:0"},
204            {{"T", "$T"}},
205            {},
206            "/device:GPU:0"},
207       },
208       {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
209 }
210 
FuncWithListInput()211 FunctionDef FuncWithListInput() {
212   const Tensor kTwo = test::AsScalar<float>(2);
213   return FDH::Create(
214       // Name
215       "FuncWithListInput",
216       // Args
217       {"x1: N * T"},
218       // Return values
219       {},
220       // Attr def
221       {"T: {float}", "N: int >= 1"},
222       // Nodes
223       {
224           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
225       },
226       {});
227 }
228 
FuncWithListOutput()229 FunctionDef FuncWithListOutput() {
230   const Tensor kTwo = test::AsScalar<float>(2);
231   return FDH::Create(
232       // Name
233       "FuncWithListOutput",
234       // Args
235       {},
236       // Return values
237       {"y: N * T"},
238       // Attr def
239       {"T: {float}", "N: int >= 1"},
240       // Nodes
241       {
242           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
243       },
244       {{"y", "num_2:output:0"}});
245 }
246 
XAddX()247 FunctionDef XAddX() {
248   return FDH::Define(
249       // Name
250       "XAddX",
251       // Args
252       {"x: T"},
253       // Return values
254       {"y: T"},
255       // Attr def
256       {"T: {float, double, int32, int64}"},
257       // Nodes
258       {
259           {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
260       });
261 }
262 
XAddY()263 FunctionDef XAddY() {
264   return FDH::Define(
265       // Name
266       "XAddY",
267       // Args
268       {"x: T", "y: T"},
269       // Return values
270       {"z: T"},
271       // Attr def
272       {"T: {float, double, int32, int64}"},
273       // Nodes
274       {
275           {{"z"}, "Add", {"x", "y"}, {{"T", "$T"}}},
276       });
277 }
278 
XTimesTwoInt32()279 FunctionDef XTimesTwoInt32() {
280   const Tensor kTwo = test::AsScalar<int64>(2);
281   return FDH::Define(
282       // Name
283       "XTimesTwoInt32",
284       // Args
285       {"x: int32"},
286       // Return values
287       {"y: int32"}, {},
288       // Nodes
289       {
290           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
291           {{"scale"},
292            "Cast",
293            {"two"},
294            {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
295           {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
296       });
297 }
298 
XTimesFour()299 FunctionDef XTimesFour() {
300   return FDH::Create(
301       // Name
302       "XTimesFour",
303       // Args
304       {"x: T"},
305       // Return values
306       {"y: T"},
307       // Attr def
308       {"T: {float, double, int32, int64}"},
309       // Nodes
310       {
311           {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
312           {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
313       },
314       {{"y", "y:y:0"}});
315 }
316 
XTimes16()317 FunctionDef XTimes16() {
318   return FDH::Create(
319       // Name
320       "XTimes16",
321       // Args
322       {"x: T"},
323       // Return values
324       {"y: T"},
325       // Attr def
326       {"T: {float, double, int32, int64}"},
327       // Nodes
328       {
329           {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
330           {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
331       },
332       {{"y", "y:y:0"}});
333 }
334 
WXPlusB()335 FunctionDef WXPlusB() {
336   return FDH::Define(
337       // Name
338       "WXPlusB",
339       // Args
340       {"w: T", "x: T", "b: T"},
341       // Return values
342       {"y: T"},
343       // Attr def
344       {"T: {float, double}"},
345       // Nodes
346       {{{"mm"},
347         "MatMul",
348         {"w", "x"},
349         {{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
350        {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
351 }
352 
Swap()353 FunctionDef Swap() {
354   return FDH::Define(
355       // Name
356       "Swap",
357       // Args
358       {"i0: T", "i1: T"},
359       // Return values
360       {"o0: T", "o1: T"},
361       // Attr def
362       {"T: {float, double, resource}"},
363       // Nodes
364       {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
365        {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
366 }
367 
EmptyBodySwap()368 FunctionDef EmptyBodySwap() {
369   return FDH::Create(
370       // Name
371       "EmptyBodySwap",
372       // Args
373       {"i0: T", "i1: T"},
374       // Return values
375       {"o0: T", "o1: T"},
376       // Attr def
377       {"T: {float, double, resource}"},
378       // Nodes
379       {},
380       // Output mapping
381       {{"o0", "i1"}, {"o1", "i0"}});
382 }
383 
ResourceOutput()384 FunctionDef ResourceOutput() {
385   const Tensor kTwo = test::AsScalar<float>(2);
386   return FDH::Create(
387       // Name
388       "ResourceOutput",
389       // Args
390       {"x: float", "y: resource"},
391       // Return values
392       {"y_out: resource", "two_x: float"},
393       // Attr def
394       {},
395       // Nodes
396       {
397           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
398           {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}, {}},
399       },
400       {{"y_out", "y"}, {"two_x", "mul:z:0"}});
401 }
402 
ResourceIdentity()403 FunctionDef ResourceIdentity() {
404   return FDH::Create(
405       // Name
406       "ResourceIdentity",
407       // Args
408       {"x: resource"},
409       // Return values
410       {"y: resource"},
411       // Attr def
412       {},
413       // Nodes
414       {},
415       // Output mapping
416       {{"y", "x"}});
417 }
418 
ReadResourceVariable()419 FunctionDef ReadResourceVariable() {
420   return FDH::Create(
421       // Name
422       "ReadResourceVariable",
423       // Args
424       {"x: resource"},
425       // Return values
426       {"y: float"},
427       // Attr def
428       {},
429       // Nodes
430       {
431           {{"read"}, "ReadVariableOp", {"x"}, {{"dtype", DT_FLOAT}}, {}},
432       },
433       {{"y", "read:value:0"}});
434 }
435 
InvalidControlFlow()436 FunctionDef InvalidControlFlow() {
437   return FDH::Create(
438       // Name
439       "InvalidControlFlow",
440       // Args
441       {"i: int32"},
442       // Return values
443       {"o: int32"},
444       // Attr def
445       {},
446       // Nodes
447       {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
448        {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
449       // Output mapping
450       {{"o", "add:z"}});
451 }
452 
LessThanOrEqualToN(int64 N)453 FunctionDef LessThanOrEqualToN(int64 N) {
454   const Tensor kN = test::AsScalar<int64>(N);
455   return FDH::Define(
456       // Name
457       "LessThanOrEqualToN",
458       // Args
459       {"x: T"},
460       // Return values
461       {"z: bool"},
462       // Attr def
463       {"T: {float, double, int32, int64}"},
464       // Nodes
465       {
466           {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
467           {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
468           {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
469       });
470 }
471 
XPlusOneXTimesY()472 FunctionDef XPlusOneXTimesY() {
473   const Tensor kOne = test::AsScalar<int64>(1);
474   return FDH::Define(
475       // Name
476       "XPlusOneXTimesY",
477       // Args
478       {"x: T", "y: T"},
479       // Return values
480       {"s: T", "t: T"},
481       // Attr def
482       {"T: {float, double, int32, int64}"},
483       // Nodes
484       {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
485        {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
486        {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
487        {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
488 }
489 
XYXLessThanOrEqualToN(int64 N)490 FunctionDef XYXLessThanOrEqualToN(int64 N) {
491   const Tensor kN = test::AsScalar<int64>(N);
492   return FDH::Define(
493       // Name
494       "XYXLessThanOrEqualToN",
495       // Args
496       {"x: T", "y: T"},
497       // Return values
498       {"z: bool"},
499       // Attr def
500       {"T: {float, double, int32, int64}"},
501       // Nodes
502       {
503           {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
504           {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
505           {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
506       });
507 }
508 
RandomUniformLess()509 FunctionDef RandomUniformLess() {
510   const Tensor kZero = test::AsScalar<int32>(0);
511   const Tensor kOne = test::AsScalar<int32>(1);
512   const Tensor k005 = test::AsScalar<float>(0.05);
513 
514   return FDH::Define(
515       // Name
516       "RandomUniformLess",
517       // Args
518       {"arg0: int64"},
519       // Return values
520       {"strided_slice: bool"},
521       // Attr def
522       {"T:{float, double, int32, int64, string}"},
523       {{{"random_uniform/shape"},
524         "Const",
525         {},
526         {{"value", kZero}, {"dtype", DT_INT32}}},
527 
528        {{"random_uniform/RandomUniform"},
529         "RandomUniform",
530         {"random_uniform/shape"},
531         {{"T", DT_INT32}, {"Tout", DT_FLOAT}, {"seed", 0}, {"seed2", 0}}},
532 
533        {{"Less/y"}, "Const", {}, {{"value", k005}, {"dtype", DT_FLOAT}}},
534 
535        {{"Less"},
536         "Less",
537         {"random_uniform/RandomUniform", "Less/y"},
538         {{"T", DT_FLOAT}}},
539 
540        {{"strided_slice/stack"},
541         "Const",
542         {},
543         {{"value", kZero}, {"dtype", DT_INT32}}},
544 
545        {{"strided_slice/stack_1"},
546         "Const",
547         {},
548         {{"value", kOne}, {"dtype", DT_INT32}}},
549 
550        {{"strided_slice/stack_2"},
551         "Const",
552         {},
553         {{"value", kOne}, {"dtype", DT_INT32}}},
554 
555        {{"strided_slice"},
556         "StridedSlice",
557         {"Less", "strided_slice/stack", "strided_slice/stack_1",
558          "strided_slice/stack_2"},
559         {{"Index", DT_INT32},
560          {"T", DT_BOOL},
561          {"begin_mask", 0},
562          {"ellipsis_mask", 0},
563          {"end_mask", 0},
564          {"new_axis_mask", 0},
565          {"shrink_axis_mask", 0}}}});
566 }
567 
MakeRangeDataset()568 FunctionDef MakeRangeDataset() {
569   return FDH::Define(
570       /*name=*/"MakeRangeDataset",
571       /*arg_def=*/{"start: int64", "stop: int64", "step: int64"},
572       /*ret_def=*/{"y:variant"},
573       /*attr_def=*/
574       {"output_types: list(type) >= 1", "output_shapes: list(shape) >= 1"},
575       /*node_def=*/
576       {{/*ret=*/{"y"},
577         /*op=*/"RangeDataset",
578         /*arg=*/{"start", "stop", "step"},
579         /*attr=*/
580         {{"output_types", "$output_types"},
581          {"output_shapes", "$output_shapes"}}}});
582 }
583 
MakeBatchDataset()584 FunctionDef MakeBatchDataset() {
585   return FDH::Define(
586       /*name=*/"MakeBatchDataset",
587       /*arg_def=*/
588       {"input_dataset: variant", "batch_size: int64", "drop_remainder: bool"},
589       /*ret_def=*/{"y: variant"},
590       /*attr_def=*/
591       {"parallel_copy: bool = false", "output_types: list(type) >= 1",
592        "output_shapes: list(shape) >= 1"},
593       /*node_def=*/
594       {{/*ret=*/{"y"},
595         /*op=*/"BatchDatasetV2",
596         /*arg=*/{"input_dataset", "batch_size", "drop_remainder"},
597         /*attr=*/
598         {{"parallel_copy", "$parallel_copy"},
599          {"output_types", "$output_types"},
600          {"output_shapes", "$output_shapes"}}}});
601 }
602 
MakeMapDataset(bool has_other_args)603 FunctionDef MakeMapDataset(bool has_other_args) {
604   std::vector<string> args = {"input_dataset: variant"};
605   std::vector<string> inputs = {"input_dataset"};
606   if (has_other_args) {
607     args.emplace_back("other_arguments: Targuments");
608     inputs.emplace_back("other_arguments");
609   }
610 
611   return FDH::Define(
612       /*name=*/"MakeMapDataset",
613       /*arg_def=*/args,
614       /*ret_def=*/
615       {"y: variant"},
616       /*attr_def=*/
617       {"f: func", "Targuments: list(type) >= 0",
618        "output_types: list(type) >= 1", "output_shapes: list(shape) >= 1",
619        "use_inter_op_parallelism: bool = true",
620        "preserve_cardinality: bool = false"},
621       /*node_def=*/
622       {{/*ret=*/{"y"},
623         /*op=*/"MapDataset",
624         /*arg=*/inputs,
625         /*attr=*/
626         {{"f", "$f"},
627          {"Targuments", "$Targuments"},
628          {"output_types", "$output_types"},
629          {"output_shapes", "$output_shapes"},
630          {"use_inter_op_parallelism", "$use_inter_op_parallelism"},
631          {"preserve_cardinality", "$preserve_cardinality"}}}});
632 }
633 
MakeTakeDataset()634 FunctionDef MakeTakeDataset() {
635   return FDH::Define(
636       // Name
637       "TakeDataset",
638       // Args
639       {"input_dataset: variant", "count: int64"},
640       // Return values
641       {"y:variant"},
642       // Attr def
643       {"output_types: list(type) >= 1", "output_shapes: list(shape) >= 1"},
644       // Nodes
645       {{{"y"},
646         "TakeDataset",
647         {"input_dataset", "count"},
648         {{"output_types", "$output_types"},
649          {"output_shapes", "$output_shapes"}}}});
650 }
651 
MakeTensorSliceDataset()652 FunctionDef MakeTensorSliceDataset() {
653   return FDH::Define(
654       // Name
655       "MakeTensorSliceDataset",
656       // Args
657       {"x: Toutput_types"},
658       // Return values
659       {"y: variant"},
660       // Attr def
661       {"Toutput_types: list(type) >= 1", "output_shapes: list(shape) >= 1"},
662       // Nodes
663       {{{"y"},
664         "TensorSliceDataset",
665         {"x"},
666         {{"Toutput_types", "$Toutput_types"},
667          {"output_shapes", "$output_shapes"}}}});
668 }
669 
Unique()670 FunctionDef Unique() {
671   return FDH::Create(
672       // Name
673       "GetUnique",
674       // Args
675       {"x:T"},
676       // Return values
677       {"y:T", "idx: out_idx"},
678       // Attr def
679       {"T: type", "out_idx: {int32, int64} = DT_INT32"},
680       // Nodes
681       {
682           {{"result"}, "Unique", {"x"}, {{"T", "$T"}, {"out_idx", "$out_idx"}}},
683       },
684       {{"y", "result:y:0"}, {"idx", "result:idx:0"}});
685 }
686 
FunctionTestSchedClosure(std::function<void ()> fn)687 void FunctionTestSchedClosure(std::function<void()> fn) {
688   static thread::ThreadPool* w =
689       new thread::ThreadPool(Env::Default(), "Test", 8);
690   w->Schedule(std::move(fn));
691 }
692 
693 }  // end namespace function
694 }  // end namespace test
695 }  // end namespace tensorflow
696