1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function_testlib.h"
17
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/versions.pb.h"
22 #include "tensorflow/core/lib/core/threadpool.h"
23 #include "tensorflow/core/public/version.h"
24
25 namespace tensorflow {
26 namespace test {
27 namespace function {
28
29 typedef FunctionDefHelper FDH;
30
GDef(gtl::ArraySlice<NodeDef> nodes,gtl::ArraySlice<FunctionDef> funcs)31 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
32 gtl::ArraySlice<FunctionDef> funcs) {
33 GraphDef g;
34 VersionDef* versions = g.mutable_versions();
35 versions->set_producer(TF_GRAPH_DEF_VERSION);
36 versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
37 for (const auto& n : nodes) {
38 *(g.add_node()) = n;
39 }
40 auto lib = g.mutable_library();
41 for (const auto& f : funcs) {
42 *(lib->add_function()) = f;
43 }
44 return g;
45 }
46
47 // Helper to construct a NodeDef.
NDef(StringPiece name,StringPiece op,gtl::ArraySlice<string> inputs,gtl::ArraySlice<std::pair<string,FDH::AttrValueWrapper>> attrs,const string & device)48 NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
49 gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
50 const string& device) {
51 NodeDef n;
52 n.set_name(string(name));
53 n.set_op(string(op));
54 for (const auto& in : inputs) n.add_input(in);
55 n.set_device(device);
56 for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
57 return n;
58 }
59
NonZero()60 FunctionDef NonZero() {
61 return FDH::Define(
62 // Name
63 "NonZero",
64 // Args
65 {"x:T"},
66 // Return values
67 {"y:T"},
68 // Attr def
69 {"T:{float, double, int32, int64, string}"},
70 // Nodes
71 {
72 {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
73 });
74 }
75
IsZero()76 FunctionDef IsZero() {
77 const Tensor kZero = test::AsScalar<int64>(0);
78 return FDH::Define(
79 // Name
80 "IsZero",
81 // Args
82 {"x: T"},
83 // Return values
84 {"equal: T"},
85 // Attr def
86 {"T:{float, double, int32, int64, string}"},
87 {
88 {{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}},
89 {{"cast"}, "Cast", {"zero"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
90 {{"equal"}, "Equal", {"x", "cast"}, {{"T", "$T"}}},
91 });
92 }
93
RandomUniform()94 FunctionDef RandomUniform() {
95 const Tensor kZero = test::AsScalar<int64>(0);
96
97 return FDH::Define(
98 // Name
99 "RandomUniform",
100 // Args
101 {"x: T"},
102 // Return values
103 {"random_uniform: int64"},
104 // Attr def
105 {"T:{float, double, int32, int64, string}"},
106 {{{"random_uniform/shape"},
107 "Const",
108 {},
109 {{"value", kZero}, {"dtype", DT_INT64}}},
110 {{"random_uniform"},
111 "RandomUniform",
112 {"random_uniform/shape"},
113 {{"T", DT_INT32},
114 {"Tout", DT_FLOAT},
115 {"seed", 87654321},
116 {"seed2", 42}}}});
117 }
118
XTimesTwo()119 FunctionDef XTimesTwo() {
120 const Tensor kTwo = test::AsScalar<int64>(2);
121 return FDH::Define(
122 // Name
123 "XTimesTwo",
124 // Args
125 {"x: T"},
126 // Return values
127 {"y: T"},
128 // Attr def
129 {"T: {float, double, int32, int64}"},
130 // Nodes
131 {
132 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
133 {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
134 {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
135 });
136 }
137
TwoDeviceMult()138 FunctionDef TwoDeviceMult() {
139 const Tensor kTwo = test::AsScalar<int64>(2);
140 const Tensor kThree = test::AsScalar<int64>(3);
141 return FDH::Create(
142 // Name
143 "TwoDeviceMult",
144 // Args
145 {"x: T"},
146 // Return values
147 {"y_cpu: T", "y_gpu: T"},
148 // Attr def
149 {"T: {float, double, int32, int64}"},
150 // Nodes
151 {
152 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
153 {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_INT64}}},
154 {{"factor_2"},
155 "Cast",
156 {"num_2:output:0"},
157 {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
158 {{"factor_3"},
159 "Cast",
160 {"num_3:output:0"},
161 {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
162 {{"y_cpu"},
163 "Mul",
164 {"x", "factor_2:y:0"},
165 {{"T", "$T"}},
166 {},
167 "/device:CPU:0"},
168 {{"y_gpu"},
169 "Mul",
170 {"x", "factor_3:y:0"},
171 {{"T", "$T"}},
172 {},
173 "/device:GPU:0"},
174 },
175 {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
176 }
177
TwoDeviceInputOutput()178 FunctionDef TwoDeviceInputOutput() {
179 const Tensor kTwo = test::AsScalar<float>(2);
180 const Tensor kThree = test::AsScalar<float>(3);
181 return FDH::Create(
182 // Name
183 "TwoDeviceInputOutput",
184 // Args
185 {"x1: T", "x2: T"},
186 // Return values
187 {"y_cpu: T", "y_gpu: T"},
188 // Attr def
189 {"T: {float}"},
190 // Nodes
191 {
192 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
193 {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_FLOAT}}},
194 {{"y_cpu"},
195 "Mul",
196 {"x1", "num_2:output:0"},
197 {{"T", "$T"}},
198 {},
199 "/device:CPU:0"},
200 {{"y_gpu"},
201 "Mul",
202 {"x2", "num_3:output:0"},
203 {{"T", "$T"}},
204 {},
205 "/device:GPU:0"},
206 },
207 {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
208 }
209
FuncWithListInput()210 FunctionDef FuncWithListInput() {
211 const Tensor kTwo = test::AsScalar<float>(2);
212 return FDH::Create(
213 // Name
214 "FuncWithListInput",
215 // Args
216 {"x1: N * T"},
217 // Return values
218 {},
219 // Attr def
220 {"T: {float}", "N: int >= 1"},
221 // Nodes
222 {
223 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
224 },
225 {});
226 }
227
FuncWithListOutput()228 FunctionDef FuncWithListOutput() {
229 const Tensor kTwo = test::AsScalar<float>(2);
230 return FDH::Create(
231 // Name
232 "FuncWithListOutput",
233 // Args
234 {},
235 // Return values
236 {"y: N * T"},
237 // Attr def
238 {"T: {float}", "N: int >= 1"},
239 // Nodes
240 {
241 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
242 },
243 {{"y", "num_2:output:0"}});
244 }
245
XAddX()246 FunctionDef XAddX() {
247 return FDH::Define(
248 // Name
249 "XAddX",
250 // Args
251 {"x: T"},
252 // Return values
253 {"y: T"},
254 // Attr def
255 {"T: {float, double, int32, int64}"},
256 // Nodes
257 {
258 {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
259 });
260 }
261
XTimesTwoInt32()262 FunctionDef XTimesTwoInt32() {
263 const Tensor kTwo = test::AsScalar<int64>(2);
264 return FDH::Define(
265 // Name
266 "XTimesTwoInt32",
267 // Args
268 {"x: int32"},
269 // Return values
270 {"y: int32"}, {},
271 // Nodes
272 {
273 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
274 {{"scale"},
275 "Cast",
276 {"two"},
277 {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
278 {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
279 });
280 }
281
XTimesFour()282 FunctionDef XTimesFour() {
283 return FDH::Create(
284 // Name
285 "XTimesFour",
286 // Args
287 {"x: T"},
288 // Return values
289 {"y: T"},
290 // Attr def
291 {"T: {float, double, int32, int64}"},
292 // Nodes
293 {
294 {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
295 {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
296 },
297 {{"y", "y:y:0"}});
298 }
299
XTimes16()300 FunctionDef XTimes16() {
301 return FDH::Create(
302 // Name
303 "XTimes16",
304 // Args
305 {"x: T"},
306 // Return values
307 {"y: T"},
308 // Attr def
309 {"T: {float, double, int32, int64}"},
310 // Nodes
311 {
312 {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
313 {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
314 },
315 {{"y", "y:y:0"}});
316 }
317
WXPlusB()318 FunctionDef WXPlusB() {
319 return FDH::Define(
320 // Name
321 "WXPlusB",
322 // Args
323 {"w: T", "x: T", "b: T"},
324 // Return values
325 {"y: T"},
326 // Attr def
327 {"T: {float, double}"},
328 // Nodes
329 {{{"mm"},
330 "MatMul",
331 {"w", "x"},
332 {{"T", "$T"},
333 {"transpose_a", false},
334 {"transpose_b", false},
335 {"_kernel", "eigen"}}},
336 {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
337 }
338
Swap()339 FunctionDef Swap() {
340 return FDH::Define(
341 // Name
342 "Swap",
343 // Args
344 {"i0: T", "i1: T"},
345 // Return values
346 {"o0: T", "o1: T"},
347 // Attr def
348 {"T: {float, double}"},
349 // Nodes
350 {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
351 {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
352 }
353
EmptyBodySwap()354 FunctionDef EmptyBodySwap() {
355 return FDH::Create(
356 // Name
357 "EmptyBodySwap",
358 // Args
359 {"i0: T", "i1: T"},
360 // Return values
361 {"o0: T", "o1: T"},
362 // Attr def
363 {"T: {float, double}"},
364 // Nodes
365 {},
366 // Output mapping
367 {{"o0", "i1"}, {"o1", "i0"}});
368 }
369
ResourceOutput()370 FunctionDef ResourceOutput() {
371 const Tensor kTwo = test::AsScalar<float>(2);
372 return FDH::Create(
373 // Name
374 "ResourceOutput",
375 // Args
376 {"x: float", "y: resource"},
377 // Return values
378 {"y_out: resource", "two_x: float"},
379 // Attr def
380 {},
381 // Nodes
382 {
383 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
384 {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}, {}},
385 },
386 {{"y_out", "y"}, {"two_x", "mul:z:0"}});
387 }
388
ReadResourceVariable()389 FunctionDef ReadResourceVariable() {
390 return FDH::Create(
391 // Name
392 "ReadResourceVariable",
393 // Args
394 {"x: resource"},
395 // Return values
396 {"y: float"},
397 // Attr def
398 {},
399 // Nodes
400 {
401 {{"read"}, "ReadVariableOp", {"x"}, {{"dtype", DT_FLOAT}}, {}},
402 },
403 {{"y", "read:value:0"}});
404 }
405
InvalidControlFlow()406 FunctionDef InvalidControlFlow() {
407 return FDH::Create(
408 // Name
409 "InvalidControlFlow",
410 // Args
411 {"i: int32"},
412 // Return values
413 {"o: int32"},
414 // Attr def
415 {},
416 // Nodes
417 {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
418 {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
419 // Output mapping
420 {{"o", "add:z"}});
421 }
422
LessThanOrEqualToN(int64 N)423 FunctionDef LessThanOrEqualToN(int64 N) {
424 const Tensor kN = test::AsScalar<int64>(N);
425 return FDH::Define(
426 // Name
427 "LessThanOrEqualToN",
428 // Args
429 {"x: T"},
430 // Return values
431 {"z: bool"},
432 // Attr def
433 {"T: {float, double, int32, int64}"},
434 // Nodes
435 {
436 {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
437 {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
438 {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
439 });
440 }
441
XPlusOneXTimesY()442 FunctionDef XPlusOneXTimesY() {
443 const Tensor kOne = test::AsScalar<int64>(1);
444 return FDH::Define(
445 // Name
446 "XPlusOneXTimesY",
447 // Args
448 {"x: T", "y: T"},
449 // Return values
450 {"s: T", "t: T"},
451 // Attr def
452 {"T: {float, double, int32, int64}"},
453 // Nodes
454 {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
455 {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
456 {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
457 {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
458 }
459
XYXLessThanOrEqualToN(int64 N)460 FunctionDef XYXLessThanOrEqualToN(int64 N) {
461 const Tensor kN = test::AsScalar<int64>(N);
462 return FDH::Define(
463 // Name
464 "XYXLessThanOrEqualToN",
465 // Args
466 {"x: T", "y: T"},
467 // Return values
468 {"z: bool"},
469 // Attr def
470 {"T: {float, double, int32, int64}"},
471 // Nodes
472 {
473 {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
474 {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
475 {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
476 });
477 }
478
FunctionTestSchedClosure(std::function<void ()> fn)479 void FunctionTestSchedClosure(std::function<void()> fn) {
480 static thread::ThreadPool* w =
481 new thread::ThreadPool(Env::Default(), "Test", 8);
482 w->Schedule(std::move(fn));
483 }
484
485 } // end namespace function
486 } // end namespace test
487 } // end namespace tensorflow
488