/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"

namespace tensorflow {
namespace grappler {
namespace {

class ConstantFoldingTest : public GrapplerTest {
 protected:
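  // Builds a graph computing x*0, x*1, x+0 and x+1 (LogicalAnd/LogicalOr for
  // DT_BOOL), runs constant folding, and verifies that each expression is
  // simplified to a Const, Identity or Snapshot node where possible.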
  template <DataType DTYPE>
  void SimpleNeutralElementTest() {
    for (bool use_snapshot : {false, true}) {
      typedef typename EnumToDataType<DTYPE>::Type T;
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
      Tensor zeros_t(DTYPE, TensorShape({2, 2}));
      Tensor ones_t(DTYPE, TensorShape({2, 2}));
      Tensor x_t(DTYPE, TensorShape({2, 2}));
      for (int i = 0; i < 4; ++i) {
        zeros_t.flat<T>()(i) = T(0);
        ones_t.flat<T>()(i) = T(1);
        x_t.flat<T>()(i) = T(i + 1);
      }
      Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
      Output ones = ops::Const(s.WithOpName("ones"), ones_t);
      Output mul1;
      Output mul2;
      Output add1;
      Output add2;
      if (DTYPE == DT_BOOL) {
        mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
        mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
        add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
        add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
      } else {
        if (DTYPE == DT_FLOAT) {
          mul1 = ops::MulNoNan(s.WithOpName("mul1"), x, zeros);
        } else {
          mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
        }
        mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
        add1 = ops::Add(s.WithOpName("add1"), x, zeros);
        add2 = ops::Add(s.WithOpName("add2"), x, ones);
      }
      if (use_snapshot) {
        // Add an op with ref input to prevent Snapshot from being
        // turned into Identity.
        ops::Assign(s.WithOpName("assign"), v, ones);
      }
      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"mul1", "mul2", "add1", "add2"};
      ConstantFolding optimizer(/*cpu_device=*/nullptr);
      GraphDef output;
      Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
      TF_EXPECT_OK(status);

      EXPECT_EQ(7, output.node_size());
      const string snapshot_or_identity =
          use_snapshot ? "Snapshot" : "Identity";
      for (int i = 0; i < output.node_size(); ++i) {
        const NodeDef& node = output.node(i);
        const string& name = node.name();
        if (name == "mul1") {
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ("^x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "mul2") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^ones", node.input(1));
        } else if (name == "add1") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "add2") {
          if (DTYPE == DT_BOOL) {
            EXPECT_EQ("Const", node.op());
            EXPECT_EQ("^x", node.input(0));
            EXPECT_EQ("^ones", node.input(1));
          } else {
            EXPECT_EQ("Add", node.op());
            EXPECT_EQ("x", node.input(0));
            EXPECT_EQ("ones", node.input(1));
          }
        }
      }
      auto tensors_expected =
          EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
      auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
      EXPECT_EQ(4, tensors_expected.size());
      EXPECT_EQ(4, tensors.size());
      for (int i = 0; i < item.fetch.size(); ++i) {
        test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
      }
    }
  }

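  // Builds mul(c, conv(x, filter)) for the given shapes, padding and data
  // format, runs constant folding, and checks whether the multiplication was
  // merged into the filter constant (expect_folded) or left in place.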
  void MulConvPushDownTest(const TensorShape& input_shape,
                           const TensorShape& filter_shape,
                           const TensorShape& mul_const_input_shape,
                           const bool use_3d_conv, const char* padding,
                           const char* data_format, const bool expect_folded) {
    // Tests if the following rewrite is performed:
    //
    //         *                       Conv2D
    //        / \                       /  \
    //       c  Conv2D        -->      x  (c * filter)
    //           / \
    //          x  filter
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Tensor filter_values(DT_FLOAT, filter_shape);
    for (int i = 0; i < filter_values.NumElements(); ++i) {
      filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
    }
    Output filter =
        ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));

    Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                    ops::Placeholder::Shape(input_shape));

    Output conv;
    if (use_3d_conv) {
      conv = ops::Conv3D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1, 1},
                         padding, ops::Conv3D::DataFormat(data_format));
    } else {
      conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
                         padding, ops::Conv2D::DataFormat(data_format));
    }
    Tensor mul_const_input(DT_FLOAT, mul_const_input_shape);
    for (int i = 0; i < mul_const_input.NumElements(); ++i) {
      mul_const_input.flat<float>()(i) = static_cast<float>(i + 3);
    }
    Output c =
        ops::Const(s.WithOpName("c"), Input::Initializer(mul_const_input));
    Output mul = ops::Mul(s.WithOpName("mul"), c, conv);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(5, output.node_size());
    int found = 0;
    if (expect_folded) {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("conv/merged_input", node.input(1));
        } else if (node.name() == "conv/merged_input") {
          found++;
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ(0, node.input_size());
        }
      }
    } else {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ("Mul", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("c", node.input(0));
          EXPECT_EQ("conv", node.input(1));
        } else if (node.name() == "conv") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("filter", node.input(1));
        }
      }
    }
    EXPECT_EQ(2, found);

    // Check that const folded multiplication node has the expected value.
    std::vector<string> fetch = {"mul"};
    Tensor value(DT_FLOAT, input_shape);
    for (int i = 0; i < value.NumElements(); ++i) {
      value.flat<float>()(i) = i;
    }
    auto actual = EvaluateNodes(output, fetch, {{"x", value}});
    auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
    test::ExpectTensorEqual<float>(expected[0], actual[0]);
  }

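  // Verifies that a PadV2 with all-zero paddings is replaced by an Identity
  // of its input (keeping control dependencies on the now-unused inputs),
  // while a pad with non-zero paddings is left untouched.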
  template <typename T>
  void PaddingWithZeroSize() {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
    auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
    auto paddings1 =
        ops::Const<T>(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
    auto paddings2 =
        ops::Const<T>(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
    auto c1 = ops::Const(scope.WithOpName("c1"), 1);
    auto c2 = ops::Const(scope.WithOpName("c2"), 1);

    ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
    ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);

    ops::Add out(scope.WithOpName("out"), p1, p2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("paddings1", "Const", {}, {}, &want);
    AddNode("paddings2", "Const", {}, {}, &want);
    AddNode("c1", "Const", {}, {}, &want);
    AddNode("c2", "Const", {}, {}, &want);
    AddNode(
        "p1", "Identity",
        {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
        {}, &want);
    AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, {}, &want);
    AddNode("out", "Add", {"p1", "p2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({4, 6}));
    auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 2}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  }
};

TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
  Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b});
  Output d = ops::AddN(s.WithOpName("d"), {b, c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("d", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, AddTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
  Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);

  Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
  Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
  Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
  Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
  Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
  Output addmul_parent =
      ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);

  GrapplerItem item;
  item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //    +                +             +
  //   / \              / \           / \
  // 1.0  +     -->    x   +    -->  x  3.0
  //     / \              / \
  //   2.0  x           1.0 2.0
  //
  //    *                *             *
  //   / \              / \           / \
  // 4.0  *     -->    y   *    -->  y  20.0
  //     / \              / \
  //   5.0  y           4.0 5.0

  EXPECT_EQ(10, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "add_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      ASSERT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("add_child", node.input(1));
    } else if (node.name() == "mul_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "mul_parent") {
      EXPECT_EQ("Mul", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("mul_child", node.input(1));
    } else if (node.name() == "addmul_child") {
      // Unchanged.
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c4", node.input(0));
      EXPECT_EQ("x", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent", "mul_parent"};
  auto tensor_expected =
      EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent", "mul_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, AddSubtractTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output sub_child = ops::Sub(s.WithOpName("sub_child"), x, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), sub_child, c1);

  GrapplerItem item;
  item.fetch = {"add_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //     +                +
  //    / \              / \
  //   -   1    -->     -   x
  //  / \              / \
  // x   x            1   x

  EXPECT_EQ(4, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "sub_child") {
      EXPECT_EQ("Sub", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("x", node.input(1));
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("sub_child", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

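// Exhaustively enumerates parent/child pairs of commutative (Add, Mul) and
// non-commutative (Sub, Div) ops with the constant operand on either side,
// and checks that the optimized graph computes the same value as the
// original.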
TEST_F(ConstantFoldingTest, ConstantPushDown) {
  for (bool is_add : {true, false}) {
    for (bool is_parent_commutative : {true, false}) {
      for (bool is_child_commutative : {true, false}) {
        for (bool is_left_child_const : {true, false}) {
          for (bool is_left_leaf_const : {true, false}) {
            tensorflow::Scope s = tensorflow::Scope::NewRootScope();
            Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
            Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
            Output x =
                ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                 ops::Placeholder::Shape(TensorShape({2, 2})));

            auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                              const string& name, const Output& const_arg,
                              const Output& non_const_arg) -> Output {
              if (is_add) {
                if (is_commutative) {
                  return ops::Add(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Sub(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              } else {
                if (is_commutative) {
                  return ops::Mul(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Div(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              }
            };

            Output child = get_op(is_child_commutative, is_left_leaf_const,
                                  "child", c2, x);
            Output parent = get_op(is_parent_commutative, is_left_child_const,
                                   "parent", c3, child);
            GrapplerItem item;
            item.fetch = {"parent"};
            TF_CHECK_OK(s.ToGraphDef(&item.graph));

            ConstantFolding optimizer(/*cpu_device=*/nullptr);
            GraphDef output;
            Status status =
                optimizer.Optimize(/*cluster=*/nullptr, item, &output);
            TF_EXPECT_OK(status);

            // Check that the result nodes have the expected value.
            auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
            std::vector<string> fetch = {"parent"};
            auto tensor_expected =
                EvaluateNodes(item.graph, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensor_expected.size());
            fetch = {"parent"};
            auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensors.size());
            for (int i = 0; i < fetch.size(); i++) {
              test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
            }
          }
        }
      }
    }
  }
}

TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
  Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2})));
  // Rewrite expected for cases 1 through 3 and their symmetric equivalents,
  // and case 4.
  Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
  Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
  Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
  Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);

  Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
  Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
  Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
  Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);

  Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
  Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
  Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
  Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);

  Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
  Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);

  // No rewrite expected.
  Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
  Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
  Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
  Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
  Output child7 = ops::Add(s.WithOpName("child7"), x_mat, c_vec);
  Output parent7 = ops::BiasAdd(s.WithOpName("parent7"), child7, c_vec);

  GrapplerItem item;
  item.fetch = {"parent1",  "parent2", "parent3", "parent1a", "parent2a",
                "parent3a", "parent4", "parent5", "parent6",  "parent7"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "child1" || node.name() == "child1a" ||
        node.name() == "child2" || node.name() == "child2a" ||
        node.name() == "child3" || node.name() == "child3a" ||
        node.name() == "child4") {
      EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
    } else if (node.name() != "c_mat" && node.name() != "c_vec") {
      EXPECT_NE(node.op(), "Const") << " node: " << node.name();
    }
  }
  // Check that the result nodes have the expected value.
  auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  std::vector<string> fetch = item.fetch;
  auto tensor_expected = EvaluateNodes(
      item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  auto tensors =
      EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

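// The MulConvPushDownTest cases below vary the shape of the multiplier
// constant: whether mul(c, conv(x, filter)) can be folded into the filter
// depends on how that shape broadcasts against the convolution output for
// the given data format.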
// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
  for (string data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/true);
  }
}
#endif

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) {
  for (string data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    for (auto mul_const_input_shape :
         {TensorShape{1}, TensorShape{1, 1, 1, 1}}) {
      MulConvPushDownTest(
          /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                                : TensorShape{4, 3, 10, 10},
          /*filter_shape=*/{2, 2, 3, 5}, mul_const_input_shape,
          /*use_3d_conv=*/false,
          /*padding=*/"VALID", data_format.c_str(),
          /*expect_folded=*/true);
    }
  }
}
#endif

TEST_F(ConstantFoldingTest,
       MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) {
  for (string data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{1, 1, 1, 1, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) {
  for (auto data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1, 3},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{1, 3}, TensorShape{1, 1, 1, 3}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NHWC",
        /*expect_folded=*/true);
  }
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NCHW",
        // TODO(laigd): optimization should happen in this case.
        /*expect_folded=*/false);
  }
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) {
  for (auto data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{1, 1, 3},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      /*expect_folded=*/true);
}
#endif

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{3, 1, 1, 1},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      // TODO(laigd): optimization should happen in this case.
      /*expect_folded=*/false);
}

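// Checks simplification of ops with neutral elements (x*1, x+0, x-0, x/1) and
// absorbing elements (x*0, MatMul with a zero operand), with the constants
// materialized as Const, ZerosLike/OnesLike, or Fill nodes.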
TEST_F(ConstantFoldingTest, NeutralElement) {
  int kConst = 0;
  int kLike = 1;
  int kFill = 2;
  for (int const_type : {kConst, kLike, kFill}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({3, 2})));
    Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 3})));
    Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                   ops::Placeholder::Shape(TensorShape({2})));
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
    Output zeros_const_bcast =
        ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
    Output zeros = const_type == kConst
                       ? zeros_const
                       : (const_type == kLike ? zeros_like : zeros_fill);
    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
    Output ones_const_bcast =
        ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
    Output ones = const_type == kConst
                      ? ones_const
                      : (const_type == kLike ? ones_like : ones_fill);
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul1_bcast =
        ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
    Output mul2_bcast =
        ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::MulNoNan(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output add1_bcast =
        ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
    Output add2_bcast =
        ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    Output concat =
        ops::Stack(s.WithOpName("stack"),
                   {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
                    matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack",      "matmul3",    "matmul4",   "mul1_bcast",
                  "mul2_bcast", "add1_bcast", "add2_bcast"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    const string suffix =
        (const_type == kConst ? "_const"
                              : (const_type == kLike ? "_like" : "_fill"));
    const string zeros_name = strings::StrCat("zeros", suffix);
    const string ones_name = strings::StrCat("ones", suffix);
    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
    const string ctrl_ones_name = strings::StrCat("^ones", suffix);

    EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      if (name == "mul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul3") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "div2") {
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "add2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("bias", node.input(0));
        EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
                  node.input(1));
        EXPECT_EQ(ctrl_zeros_name, node.input(2));
      } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(DT_INT32, t.dtype());
        EXPECT_EQ(1, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
      } else if (name == "sub1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "sub2") {
        EXPECT_EQ("Neg", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      }
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
    auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
    auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));

    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors_expected.size());
    auto tensors = EvaluateNodes(
        output, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
    }
  }
}

TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
  SimpleNeutralElementTest<DT_BOOL>();
  SimpleNeutralElementTest<DT_HALF>();
  SimpleNeutralElementTest<DT_BFLOAT16>();
}

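// Division by a floating-point constant should be strength-reduced to
// multiplication by its reciprocal; integer division must stay untouched.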
TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div_f", "div_i", "realdiv"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name == "div_i") {
      // Integer division is unchanged.
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("ci", node.input(1));
    } else if (name == "div_f") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    } else if (name == "realdiv") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    } else if (name == "ConstantFolding/div_f_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    } else if (name == "ConstantFolding/realdiv_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    }
  }

  // Check that the reciprocals have the expected value.
  std::vector<string> fetch = {"cf_half"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
  }
}

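// Depending on how much shape information is available, x*0 is folded to a
// Const, reduced to an Identity, or left untouched.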
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_known =
      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // Multiplies without any additional ops to supply the output shape.
  int count = 0;
  std::vector<Output> muls;
  std::unordered_set<string> not_converted;
  std::unordered_set<string> to_const;
  std::unordered_set<string> to_identity;
  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
    for (const auto* zeros :
         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
      if (x == &x_partially_known && zeros == &zeros_partially_known) {
        to_identity.insert(name);
      } else if (x == &x_unknown || zeros == &zeros_unknown) {
        not_converted.insert(name);
      } else {
        to_const.insert(name);
      }
    }
  }

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(15, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
    } else if (to_identity.count(name) > 0) {
      EXPECT_EQ("Identity", node.op()) << node.name();
    } else if (not_converted.count(name) > 0) {
      EXPECT_EQ("Mul", node.op()) << node.name();
    }
  }

  const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
  auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_known", x_known_t},
                     {"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_known", x_known_t},
                                {"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate the shape back to the inputs of AddN, making the
  // output shapes of all its inputs known.
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
  const std::vector<string> fetch = {"addn1"};
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
}

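// Folding Mul(c, c) should materialize Const nodes whose value is stored in
// the correctly typed proto field for each numeric dtype, plus LogicalAnd
// for bool.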
TEST_F(ConstantFoldingTest, CreateConstNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

#define MAKE_TEST_GRAPH(TYPE)                                               \
  Output TYPE##_const =                                                     \
      ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
  Output TYPE##_mul =                                                       \
      ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
  Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)

  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}

TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a graph in which both outputs of a Unique node reach the fetch
  // nodes, so folding must materialize a separate Const for each output.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {5});
  auto b = ops::Unique(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
  Output e = ops::Identity(s.WithOpName("e"), {c});
  Output f = ops::Identity(s.WithOpName("f"), {d});

  GrapplerItem item;
  item.fetch.push_back("e");
  item.fetch.push_back("f");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(2, output.node_size());

  const NodeDef& new_c = output.node(0);
  EXPECT_EQ("e", new_c.name());
  EXPECT_EQ("Const", new_c.op());

  const NodeDef& new_d = output.node(1);
  EXPECT_EQ("f", new_d.name());
  EXPECT_EQ("Const", new_d.op());

  std::vector<string> fetch = {"e", "f"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors_expected.size());
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}

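// The following tests check that control dependencies of folded nodes are
// preserved by hoisting them onto the materialized constants.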
TEST_F(ConstantFoldingTest,ControlDependencies)1301 TEST_F(ConstantFoldingTest, ControlDependencies) {
1302 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1303 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
1304 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
1305 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
1306 Output c =
1307 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
1308 Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
1309 Output i2 =
1310 ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
1311 Output i3 = ops::Identity(scope.WithOpName("i3"), {i2});
1312
1313 GrapplerItem item;
1314 item.fetch.push_back("i3");
1315 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1316
1317 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1318 GraphDef output;
1319 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1320 TF_EXPECT_OK(status);
1321
1322 std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i3"};
1323 EXPECT_EQ(output.node_size(), expected_nodes.size());
1324 int i = 0;
1325 int found = 0;
1326 for (const auto& node : output.node()) {
1327 EXPECT_EQ(expected_nodes[i], output.node(i).name());
1328 i++;
1329 if (node.name() == "i3") {
1330 EXPECT_EQ("Const", node.op());
1331 ++found;
1332 auto folded = EvaluateNodes(output, {"i3"});
1333 auto expected = EvaluateNodes(item.graph, {"i3"});
1334 EXPECT_EQ(1, expected.size());
1335 EXPECT_EQ(1, folded.size());
1336 test::ExpectTensorEqual<int>(folded[0], expected[0]);
1337 EXPECT_EQ(2, node.input_size());
1338 EXPECT_EQ("^p1", node.input(0));
1339 EXPECT_EQ("^p2", node.input(1));
1340 }
1341 }
1342 EXPECT_EQ(1, found);
1343 }
1344
TEST_F(ConstantFoldingTest,ControlDependenciesEmptyFetch)1345 TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
1346 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1347 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
1348 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
1349 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
1350 Output c =
1351 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
1352 Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
1353 Output i2 =
1354 ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
1355 Output i3 = ops::Identity(scope.WithOpName("e"), {i2});
1356
1357 GrapplerItem item;
1358 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1359
1360 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1361 GraphDef output;
1362 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1363 TF_EXPECT_OK(status);
1364
1365 std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
1366 "i1", "i2", "e"};
1367 EXPECT_EQ(output.node_size(), expected_nodes.size());
1368 int i = 0;
1369 int found = 0;
1370 for (const auto& node : output.node()) {
1371 EXPECT_EQ(expected_nodes[i], output.node(i).name());
1372 i++;
1373 if (node.name() == "i1") {
1374 EXPECT_EQ("Const", node.op());
1375 ++found;
1376 auto folded = EvaluateNodes(output, {"i1"});
1377 auto expected = EvaluateNodes(item.graph, {"i1"});
1378 EXPECT_EQ(1, expected.size());
1379 EXPECT_EQ(1, folded.size());
1380 test::ExpectTensorEqual<int>(folded[0], expected[0]);
1381 EXPECT_EQ(1, node.input_size());
1382 EXPECT_EQ("^p1", node.input(0));
1383 }
1384 if (node.name() == "i2") {
1385 EXPECT_EQ("Const", node.op());
1386 ++found;
1387 auto folded = EvaluateNodes(output, {"i2"});
1388 auto expected = EvaluateNodes(item.graph, {"i2"});
1389 EXPECT_EQ(1, expected.size());
1390 EXPECT_EQ(1, folded.size());
1391 test::ExpectTensorEqual<int>(folded[0], expected[0]);
1392 EXPECT_EQ(2, node.input_size());
1393 EXPECT_EQ("^p1", node.input(0));
1394 EXPECT_EQ("^p2", node.input(1));
1395 }
1396 }
1397 EXPECT_EQ(2, found);
1398 }
1399
TEST_F(ConstantFoldingTest,ControlDependenciesDeduplicate)1400 TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
1401 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1402 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
1403 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
1404 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
1405 Output c =
1406 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
1407 Output i1 = ops::Identity(scope.WithOpName("i1")
1408 .WithControlDependencies(p2)
1409 .WithControlDependencies(p1),
1410 {c});
1411 Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});
1412
1413 GrapplerItem item;
1414 item.fetch.push_back("i2");
1415 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1416 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1417 EXPECT_EQ(1, tensors_expected.size());
1418 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1419 GraphDef output;
1420 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1421 TF_EXPECT_OK(status);
1422
1423 std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
1424 EXPECT_EQ(output.node_size(), expected_nodes.size());
1425 int i = 0;
1426 for (const auto& node : output.node()) {
1427 EXPECT_EQ(expected_nodes[i], output.node(i).name());
1428 i++;
1429 if (node.name() == "i2") {
1430 EXPECT_EQ("Const", node.op());
1431 EXPECT_EQ(2, node.input_size());
1432 EXPECT_EQ("^p1", node.input(0));
1433 EXPECT_EQ("^p2", node.input(1));
1434 }
1435 }
1436 auto tensors = EvaluateNodes(output, item.fetch);
1437 EXPECT_EQ(1, tensors.size());
1438 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1439 }
1440
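// Ops with a variable number of outputs (DynamicPartition, ConcatOffset)
// should fold when all of their inputs are constant, turning all eight
// fetched Identity nodes into Consts.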
1441 TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
1442 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1443 // Add a DynamicPartition node to the graph
1444 Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
1445 Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
1446 int num_partitions = 4;
1447 ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
1448 num_partitions);
1449
1450 std::vector<string> outputs;
1451 for (int i = 0; i < num_partitions; ++i) {
1452 string part_out_name = strings::StrCat("part_out", i);
1453 ops::Identity partition_out(scope.WithOpName(part_out_name),
1454 {part.outputs[i]});
1455 outputs.push_back(part_out_name);
1456 }
1457
1458 GrapplerItem item;
1459 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1460
1461 // Add a ConcatOffset node to the graph
1462 Tensor initial_val(DT_INT32, TensorShape({3}));
1463 test::FillIota<int>(&initial_val, 7);
1464 for (int i = 1; i < 5; ++i) {
1465 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
1466 .Attr("dtype", DT_INT32)
1467 .Attr("value", initial_val)
1468 .Finalize(item.graph.add_node()));
1469 }
1470 Tensor concat_dim(DT_INT32, TensorShape({}));
1471 test::FillIota<int>(&concat_dim, 0);
1472 TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
1473 .Attr("dtype", DT_INT32)
1474 .Attr("value", concat_dim)
1475 .Finalize(item.graph.add_node()));
1476
1477 TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
1478 .Input("concat_dim", 0, DT_INT32)
1479 .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
1480 NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
1481 NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
1482 NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
1483 .Finalize(item.graph.add_node()));
1484
1485 for (int i = 0; i < 4; ++i) {
1486 string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
1487 TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
1488 .Attr("T", DT_INT32)
1489 .Input("concat_offsets", i, DT_INT32)
1490 .Finalize(item.graph.add_node()));
1491 outputs.push_back(concat_offset_out_name);
1492 }
1493
1494 item.fetch = outputs;
1495 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1496 GraphDef output;
1497 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1498 TF_EXPECT_OK(status);
1499
1500 int constant_folded = 0;
1501 for (const auto& node : output.node()) {
1502 if (node.name().find("part_out") != string::npos ||
1503 node.name().find("concat_offset_out") != string::npos) {
1504 ++constant_folded;
1505 EXPECT_EQ("Const", node.op());
1506 }
1507 }
1508 EXPECT_EQ(8, constant_folded);
1509
1510 auto expected = EvaluateNodes(item.graph, outputs);
1511 auto optimized = EvaluateNodes(output, outputs);
1512 ASSERT_EQ(expected.size(), optimized.size());
1513 for (int i = 0; i < expected.size(); ++i) {
1514 test::ExpectTensorEqual<int>(expected[i], optimized[i]);
1515 }
1516 }
1517
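// Rank, Shape, and Size are statically known here, so p2 should be
// materialized as the Const [715, 1001], anchored by control dependencies
// on the variables whose shapes were consumed.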
1518 TEST_F(ConstantFoldingTest, ShapeMaterialization) {
1519 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1520 Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
1521 Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
1522 Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
1523 Output rank = ops::Rank(scope.WithOpName("rank"), v1);
1524 Output shape = ops::Shape(scope.WithOpName("shape"), v2);
1525 Output size = ops::Size(scope.WithOpName("size"), v3);
1526 Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
1527 Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
1528
1529 GrapplerItem item;
1530 item.fetch.push_back("p2");
1531 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1532
1533 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1534 GraphDef output;
1535 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1536 TF_EXPECT_OK(status);
1537
1538 int found = 0;
1539 for (const auto& node : output.node()) {
1540 if (node.name() == "p2") {
1541 ++found;
1542 EXPECT_EQ("Const", node.op());
1543 EXPECT_EQ(3, node.input_size());
1544 EXPECT_EQ("^v3", node.input(0));
1545 EXPECT_EQ("^v1", node.input(1));
1546 EXPECT_EQ("^v2", node.input(2));
1547 Tensor value;
1548 CHECK(value.FromProto(node.attr().at("value").tensor()));
1549 // rank = 1, shape = (5, 7), size = 143 = 11*13
1550 // p2 = (715, 1001) = (5*143, 7*143)
1551 EXPECT_EQ(715, value.flat<int>()(0));
1552 EXPECT_EQ(1001, value.flat<int>()(1));
1553 }
1554 }
1555 EXPECT_EQ(1, found);
1556 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1557 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
1558 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
1559
1560 auto tensors_expected = EvaluateNodes(
1561 item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1562 EXPECT_EQ(1, tensors_expected.size());
1563 auto tensors = EvaluateNodes(output, item.fetch,
1564 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1565 EXPECT_EQ(1, tensors.size());
1566 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1567 }
1568
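// With no fetch nodes, the shape ops themselves should be replaced by
// Consts, each keeping a control dependency on the variable it inspected.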
1569 TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
1570 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1571 Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
1572 Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
1573 Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
1574 Output rank = ops::Rank(scope.WithOpName("rank"), v1);
1575 Output shape = ops::Shape(scope.WithOpName("shape"), v2);
1576 Output size = ops::Size(scope.WithOpName("size"), v3);
1577 Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
1578 Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
1579
1580 GrapplerItem item;
1581 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1582
1583 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1584 GraphDef output;
1585 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1586 TF_EXPECT_OK(status);
1587
1588 int found = 0;
1589 for (const auto& node : output.node()) {
1590 if (node.name() == "size") {
1591 ++found;
1592 EXPECT_EQ("Const", node.op());
1593 EXPECT_EQ(1, node.input_size());
1594 EXPECT_EQ("^v3", node.input(0));
1595 Tensor value;
1596 CHECK(value.FromProto(node.attr().at("value").tensor()));
1597 EXPECT_EQ(11 * 13, value.flat<int>()(0));
1598 } else if (node.name() == "rank") {
1599 ++found;
1600 EXPECT_EQ("Const", node.op());
1601 EXPECT_EQ(1, node.input_size());
1602 EXPECT_EQ("^v1", node.input(0));
1603 Tensor value;
1604 CHECK(value.FromProto(node.attr().at("value").tensor()));
1605 EXPECT_EQ(1, value.flat<int>()(0));
1606 } else if (node.name() == "shape") {
1607 ++found;
1608 EXPECT_EQ("Const", node.op());
1609 EXPECT_EQ(1, node.input_size());
1610 EXPECT_EQ("^v2", node.input(0));
1611 Tensor value;
1612 CHECK(value.FromProto(node.attr().at("value").tensor()));
1613 EXPECT_EQ(5, value.flat<int>()(0));
1614 EXPECT_EQ(7, value.flat<int>()(1));
1615 }
1616 }
1617 EXPECT_EQ(3, found);
1618
1619 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1620 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
1621 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
1622 std::vector<string> fetch_nodes = {"p2"};
1623 auto tensors_expected = EvaluateNodes(
1624 item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1625 EXPECT_EQ(1, tensors_expected.size());
1626 auto tensors = EvaluateNodes(output, fetch_nodes,
1627 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1628 EXPECT_EQ(1, tensors.size());
1629 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1630 }
1631
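// Only the third ShapeN output (the fully known shape {4, 6} of v3) is
// expected to be materialized as a separate Const; consumers of the other
// outputs keep reading from the ShapeN node itself.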
1632 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
1633 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1634 Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1635 Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
1636 Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
1637 auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
1638 Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
1639 Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
1640 Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
1641 Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
1642 Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
1643 Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
1644 Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
1645
1646 GrapplerItem item;
1647 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1648
1649 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1650 GraphDef output;
1651 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1652 TF_EXPECT_OK(status);
1653 int found = 0;
1654 for (const auto& node : output.node()) {
1655 EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1656 node.name());
1657 EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1658 node.name());
1659 if (node.name() == "i1a" || node.name() == "i1b") {
1660 ++found;
1661 EXPECT_EQ("s", node.input(0));
1662 }
1663 if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
1664 ++found;
1665 EXPECT_EQ("s:1", node.input(0));
1666 }
1667 if (node.name() == "i3a" || node.name() == "i3b") {
1668 ++found;
1669 EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
1670 node.input(0));
1671 }
1672 if (node.name() == "s") {
1673 ++found;
1674 EXPECT_EQ("ShapeN", node.op());
1675 EXPECT_EQ("v1", node.input(0));
1676 EXPECT_EQ("v2", node.input(1));
1677 EXPECT_EQ("v3", node.input(2));
1678 }
1679 if (node.name() ==
1680 AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
1681 ++found;
1682 EXPECT_EQ("Const", node.op());
1683 EXPECT_EQ("^s", node.input(0));
1684 Tensor value;
1685 CHECK(value.FromProto(node.attr().at("value").tensor()));
1686 EXPECT_EQ(4, value.flat<int>()(0));
1687 EXPECT_EQ(6, value.flat<int>()(1));
1688 }
1689 }
1690 EXPECT_EQ(9, found);
1691
1692 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1693 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
1694 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1695 const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
1696 "i2c", "i3a", "i3b"};
1697 auto tensors_expected = EvaluateNodes(
1698 item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1699 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
1700 auto tensors = EvaluateNodes(output, fetch_nodes,
1701 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1702 EXPECT_EQ(fetch_nodes.size(), tensors.size());
1703 for (int i = 0; i < fetch_nodes.size(); i++)
1704 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1705 }
1706
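// Materialization should reach through the multi-output IdentityN: ib, which
// ultimately carries the fully known shape of v2, folds to a Const, while ia
// (the partially unknown shape of v1) does not.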
1707 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
1708 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1709 Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1710 Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
1711 auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
1712 auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
1713 Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
1714 Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);
1715
1716 GrapplerItem item;
1717 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1718 item.fetch.push_back("ia");
1719 item.fetch.push_back("ib");
1720
1721 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1722 GraphDef output;
1723 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1724 TF_EXPECT_OK(status);
1725
1726 int found = 0;
1727 for (const auto& node : output.node()) {
1728 EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1729 node.name());
1730 if (node.name() == "s") {
1731 ++found;
1732 EXPECT_EQ("ShapeN", node.op());
1733 EXPECT_EQ("v1", node.input(0));
1734 EXPECT_EQ("v2", node.input(1));
1735 }
1736 if (node.name() == "id_n") {
1737 ++found;
1738 EXPECT_EQ("IdentityN", node.op());
1739 EXPECT_EQ("s", node.input(0));
1740 EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1741 node.input(1));
1742 }
1743 if (node.name() == "ia") {
1744 ++found;
1745 EXPECT_EQ("id_n", node.input(0));
1746 }
1747 if (node.name() == "ib") {
1748 ++found;
1749 EXPECT_EQ("Const", node.op());
1750 EXPECT_EQ("^s", node.input(0));
1751 EXPECT_EQ("^id_n", node.input(1));
1752 }
1753 }
1754 EXPECT_EQ(4, found);
1755
1756 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1757 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1758 auto tensors_expected =
1759 EvaluateNodes(item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1760 EXPECT_EQ(2, tensors_expected.size());
1761 auto tensors =
1762 EvaluateNodes(output, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1763 EXPECT_EQ(2, tensors.size());
1764 for (int i = 0; i < tensors.size(); i++)
1765 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1766 }
1767
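// Rank and Size depend only on shapes, so they fold even behind a Switch;
// with a constant false predicate, the statically known branch of switch2
// folds i2 to a Const while the never-taken branch i3 stays an Identity.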
1768 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
1769 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1770 ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1771 ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1772 ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1773 ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1774 ops::Identity i(scope.WithOpName("i"), s1.output_true);
1775 ops::Size size(scope.WithOpName("size"), i);
1776 ops::Square p1(scope.WithOpName("p1"), rank);
1777 ops::Square p2(scope.WithOpName("p2"), size);
1778 ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1779
1780 Output predicate =
1781 ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1782 Output constant =
1783 ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1784 ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1785 ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1786 ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1787 ops::Merge m2(scope.WithOpName("m2"),
1788 {statically_known.output, never_generated.output});
1789
1790 GrapplerItem item;
1791 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1792
1793 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1794 GraphDef output;
1795 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1796 TF_EXPECT_OK(status);
1797
1798 std::set<string> present_nodes = {"v_in", "v_ctrl",
1799 "switch", "i",
1800 "p1", "p2",
1801 "m", "false",
1802 "constant", "switch2",
1803 "i2", "i3",
1804 "m2", "ConstantFoldingCtrl/switch_0",
1805 "rank", "size"};
1806 std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
1807 EXPECT_EQ(present_nodes.size(), output.node_size());
1808 int found = 0;
1809 for (const auto& node : output.node()) {
1810 EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
1811 << node.name();
1812 EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
1813 << node.name();
1814 present_nodes.erase(node.name());
1815 not_present_nodes.erase(node.name());
1816 if (node.name() == "rank") {
1817 ++found;
1818 EXPECT_EQ("Const", node.op());
1819 EXPECT_EQ(1, node.input_size());
1820 EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
1821 }
1822 if (node.name() == "size") {
1823 ++found;
1824 EXPECT_EQ("Const", node.op());
1825 EXPECT_EQ(1, node.input_size());
1826 EXPECT_EQ("^i", node.input(0));
1827 }
1828 if (node.name() == "i2") {
1829 ++found;
1830 EXPECT_EQ("Const", node.op());
1831 EXPECT_EQ(0, node.input_size());
1832 }
1833 if (node.name() == "i3") {
1834 ++found;
1835 EXPECT_EQ("Identity", node.op());
1836 EXPECT_EQ(1, node.input_size());
1837 EXPECT_EQ("switch2:1", node.input(0));
1838 }
1839 }
1840 EXPECT_EQ(4, found);
1841
1842 auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1843 Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1844
1845 v_ctrl_t.flat<bool>()(0) = true;
1846 std::vector<string> fetch_nodes = {"m", "m2"};
1847 auto tensors_expected = EvaluateNodes(
1848 item.graph, fetch_nodes, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1849 EXPECT_EQ(2, tensors_expected.size());
1850 auto tensors = EvaluateNodes(output, fetch_nodes,
1851 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1852 EXPECT_EQ(2, tensors.size());
1853 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1854 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1855
1856 v_ctrl_t.flat<bool>()(0) = false;
1857 tensors_expected = EvaluateNodes(item.graph, fetch_nodes,
1858 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1859 EXPECT_EQ(2, tensors_expected.size());
1860 tensors = EvaluateNodes(output, fetch_nodes,
1861 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1862 EXPECT_EQ(2, tensors.size());
1863 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1864 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1865 }
1866
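// Same graph as above, but with m and m2 fetched: the folded rank and size
// nodes should now be pruned from the output graph entirely.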
1867 TEST_F(ConstantFoldingTest, SwitchNodes) {
1868 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1869 ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1870 ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1871 ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1872 ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1873 ops::Identity i(scope.WithOpName("i"), s1.output_true);
1874 ops::Size size(scope.WithOpName("size"), i);
1875 ops::Square p1(scope.WithOpName("p1"), rank);
1876 ops::Square p2(scope.WithOpName("p2"), size);
1877 ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1878
1879 Output predicate =
1880 ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1881 Output constant =
1882 ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1883 ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1884 ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1885 ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1886 ops::Merge m2(scope.WithOpName("m2"),
1887 {statically_known.output, never_generated.output});
1888
1889 GrapplerItem item;
1890 item.fetch.push_back("m");
1891 item.fetch.push_back("m2");
1892
1893 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1894
1895 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1896 GraphDef output;
1897 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1898 TF_EXPECT_OK(status);
1899 std::set<string> present_nodes = {"v_in", "v_ctrl",
1900 "switch", "i",
1901 "p1", "p2",
1902 "m", "false",
1903 "constant", "switch2",
1904 "i2", "i3",
1905 "m2", "ConstantFoldingCtrl/switch_0"};
1906 std::set<string> not_present_nodes = {"rank", "size",
1907 "ConstantFolding/switch2-0"};
1908 EXPECT_EQ(present_nodes.size(), output.node_size());
1909
1910 int found = 0;
1911 for (const auto& node : output.node()) {
1912 EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
1913 EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
1914 present_nodes.erase(node.name());
1915 not_present_nodes.erase(node.name());
1916 if (node.name() == "i2") {
1917 ++found;
1918 EXPECT_EQ("Const", node.op());
1919 EXPECT_EQ(0, node.input_size());
1920 }
1921 if (node.name() == "i3") {
1922 ++found;
1923 EXPECT_EQ("Identity", node.op());
1924 EXPECT_EQ(1, node.input_size());
1925 EXPECT_EQ("switch2:1", node.input(0));
1926 }
1927 }
1928 EXPECT_EQ(2, found);
1929
1930 auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1931 Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1932 v_ctrl_t.flat<bool>()(0) = true;
1933 auto tensors_expected = EvaluateNodes(
1934 item.graph, item.fetch, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1935 EXPECT_EQ(2, tensors_expected.size());
1936 auto tensors = EvaluateNodes(output, item.fetch,
1937 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1938 EXPECT_EQ(2, tensors.size());
1939 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1940 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1941
1942 v_ctrl_t.flat<bool>()(0) = false;
1943 tensors_expected = EvaluateNodes(item.graph, item.fetch,
1944 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1945 EXPECT_EQ(2, tensors_expected.size());
1946 tensors = EvaluateNodes(output, item.fetch,
1947 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1948 EXPECT_EQ(2, tensors.size());
1949 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1950 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1951 }
1952
1953 TEST_F(ConstantFoldingTest, MergeNodes) {
1954 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1955
1956 Output x =
1957 ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
1958 Output y =
1959 ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
1960 Output const1 =
1961 ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
1962 TensorShape({3, 5}));
1963 Output const2 =
1964 ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
1965 Output const3 =
1966 ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
1967 TensorShape({3, 5}));
1968
1969 // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
1970 ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
1971 ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
1972 ops::Merge m3(scope.WithOpName("m3"), {x, y});
1973 // m4 is not foldable because the only constant input
1974 // has a control input, so we cannot know if it will be
1975 // triggered.
1976 ops::Merge m4(scope.WithOpName("m4"), {x, const1});
1977
1978 ops::Identity out1(scope.WithOpName("out1"), m1.output);
1979 ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
1980 ops::Identity out2(scope.WithOpName("out2"), m2.output);
1981 ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
1982 ops::Identity out3(scope.WithOpName("out3"), m3.output);
1983 ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
1984 ops::Identity out4(scope.WithOpName("out4"), m4.output);
1985 ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
1986
1987 GrapplerItem item;
1988 item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
1989 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1990
1991 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1992 GraphDef output;
1993 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1994 TF_EXPECT_OK(status);
1995
1996 EXPECT_EQ(19, output.node_size());
1997 int found_nodes = 0;
1998 for (const auto& node : output.node()) {
1999 if (node.name() == "out1") {
2000 EXPECT_EQ(1, node.input_size());
2001 EXPECT_EQ("^m1", node.input(0));
2002 ++found_nodes;
2003 } else if (node.name() == "idx1") {
2004 EXPECT_EQ(1, node.input_size());
2005 EXPECT_EQ("^m1", node.input(0));
2006 ++found_nodes;
2007 } else if (node.name() == "ConstantFolding/m1") {
2008 EXPECT_EQ("Const", node.op());
2009 EXPECT_EQ(1, node.input_size());
2010 EXPECT_EQ("^m1", node.input(0));
2011 ++found_nodes;
2012 } else if (node.name() == "ConstantFolding/m1_index") {
2013 EXPECT_EQ("Const", node.op());
2014 EXPECT_EQ(1, node.input_size());
2015 EXPECT_EQ("^m1", node.input(0));
2016 ++found_nodes;
2017 } else if (node.name() == "out2") {
2018 EXPECT_EQ(1, node.input_size());
2019 EXPECT_EQ("m2", node.input(0));
2020 ++found_nodes;
2021 } else if (node.name() == "idx2") {
2022 EXPECT_EQ(1, node.input_size());
2023 EXPECT_EQ("m2:1", node.input(0));
2024 ++found_nodes;
2025 } else if (node.name() == "out3") {
2026 EXPECT_EQ(1, node.input_size());
2027 EXPECT_EQ("m3", node.input(0));
2028 ++found_nodes;
2029 } else if (node.name() == "idx3") {
2030 EXPECT_EQ(1, node.input_size());
2031 EXPECT_EQ("m3:1", node.input(0));
2032 ++found_nodes;
2033 } else if (node.name() == "out4") {
2034 EXPECT_EQ(1, node.input_size());
2035 EXPECT_EQ("m4", node.input(0));
2036 ++found_nodes;
2037 } else if (node.name() == "idx4") {
2038 EXPECT_EQ(1, node.input_size());
2039 EXPECT_EQ("m4:1", node.input(0));
2040 ++found_nodes;
2041 }
2042 }
2043 // Make sure the graph contains all the nodes we're expecting.
2044 EXPECT_EQ(8, found_nodes);
2045
2046 std::vector<string> fetch = {"out1", "idx1"};
2047 auto tensors = EvaluateNodes(output, fetch);
2048 EXPECT_EQ(2, tensors.size());
2049 const Tensor& out_value = tensors[0];
2050 EXPECT_EQ(3 * 5, out_value.NumElements());
2051 for (int i = 0; i < 3 * 5; ++i) {
2052 EXPECT_EQ(3.14f, out_value.flat<float>()(i));
2053 }
2054 const Tensor& out_idx = tensors[1];
2055 EXPECT_EQ(1, out_idx.NumElements());
2056 EXPECT_EQ(2, out_idx.flat<int32>()(0));
2057 }
2058
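// A Split with num_split=1 is a no-op: s1 should be rewritten to an Identity
// that keeps a control dependency on split_dim, while s2 (two splits) must
// be left alone.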
2059 TEST_F(ConstantFoldingTest, SplitRemoval) {
2060 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2061
2062 Output in1 =
2063 ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
2064 Output in2 =
2065 ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
2066 auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
2067 ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
2068 ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);
2069
2070 ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
2071
2072 GrapplerItem item;
2073 item.fetch = {"out"};
2074 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2075
2076 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2077 GraphDef got;
2078 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2079 TF_EXPECT_OK(status);
2080
2081 GraphDef want;
2082 AddNode("in1", "VariableV2", {}, {}, &want);
2083 AddNode("in2", "VariableV2", {}, {}, &want);
2084 AddNode("split_dim", "Const", {}, {}, &want);
2085 AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
2086 &want);
2087 AddNode("s2", "Split", {"split_dim", "in2"}, {}, &want);
2088 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2089
2090 CompareGraphs(want, got);
2091
2092 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
2093 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4}));
2094 auto tensors_expected =
2095 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2096 EXPECT_EQ(1, tensors_expected.size());
2097 auto tensors =
2098 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2099 EXPECT_EQ(1, tensors.size());
2100 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2101 }
2102
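// Likewise for SplitV: a single size_splits entry makes s1 a no-op, so it
// should become an Identity with control dependencies on its folded inputs.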
2103 TEST_F(ConstantFoldingTest, SplitVRemoval) {
2104 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2105
2106 Output in1 =
2107 ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
2108 Output in2 =
2109 ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
2110 auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
2111 auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
2112 auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
2113 ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
2114 ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
2115
2116 ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
2117
2118 GrapplerItem item;
2119 item.fetch = {"out"};
2120 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2121
2122 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2123 GraphDef got;
2124 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2125 TF_EXPECT_OK(status);
2126
2127 GraphDef want;
2128 AddNode("in1", "VariableV2", {}, {}, &want);
2129 AddNode("in2", "VariableV2", {}, {}, &want);
2130 AddNode("split_dim", "Const", {}, {}, &want);
2131 AddNode("size_splits1", "Const", {}, {}, &want);
2132 AddNode("size_splits2", "Const", {}, {}, &want);
2133 AddNode("s1", "Identity",
2134 {"in1", AsControlDependency("size_splits1"),
2135 AsControlDependency("split_dim")},
2136 {}, &want);
2137 AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
2138 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2139
2140 CompareGraphs(want, got);
2141
2142 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
2143 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5}));
2144 auto tensors_expected =
2145 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2146 EXPECT_EQ(1, tensors_expected.size());
2147 auto tensors =
2148 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2149 EXPECT_EQ(1, tensors.size());
2150 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2151 }
2152
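// t2's permutation only moves dimensions of size 1, so it is a no-op and
// should become an Identity; t1 genuinely transposes and must be kept.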
2153 TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
2154 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2155
2156 Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
2157 DT_FLOAT);
2158 Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
2159 Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
2160 DT_FLOAT);
2161 Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
2162 ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
2163 ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
2164 p2);
2165
2166 ops::Add out1(scope.WithOpName("out1"), t1, t2);
2167
2168 GrapplerItem item;
2169 item.fetch = {"out1"};
2170 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2171
2172 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2173 GraphDef got;
2174 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2175 TF_EXPECT_OK(status);
2176
2177 GraphDef want;
2178 AddNode("in1", "VariableV2", {}, {}, &want);
2179 AddNode("in2", "VariableV2", {}, {}, &want);
2180 AddNode("p1", "Const", {}, {}, &want);
2181 AddNode("p2", "Const", {}, {}, &want);
2182 AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
2183 AddNode("t2", "Identity",
2184 {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
2185 &want);
2186 AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
2187
2188 CompareGraphs(want, got);
2189 }
2190
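// Shuffling a scalar is a no-op, so both RandomShuffle nodes should be
// rewritten to Identity while existing control dependencies are preserved.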
2191 TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
2192 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2193
2194 Output in1 =
2195 ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
2196 Output in2 =
2197 ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
2198 ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
2199 ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
2200 in2);
2201
2202 ops::Add out1(scope.WithOpName("out1"), s1, s2);
2203 ops::Identity out2(scope.WithOpName("out2"), s2);
2204
2205 GrapplerItem item;
2206 item.fetch = {"out1", "out2"};
2207 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2208
2209 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2210 GraphDef got;
2211 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2212 TF_EXPECT_OK(status);
2213
2214 GraphDef want;
2215 AddNode("in1", "VariableV2", {}, {}, &want);
2216 AddNode("in2", "VariableV2", {}, {}, &want);
2217 AddNode("s1", "Identity", {"in1"}, {}, &want);
2218 AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, {}, &want);
2219 AddNode("out1", "Add", {"s1", "s2"}, {}, &want);
2220 AddNode("out2", "Identity", {"s2"}, {}, &want);
2221
2222 CompareGraphs(want, got);
2223
2224 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
2225 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
2226 auto tensors_expected =
2227 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2228 EXPECT_EQ(2, tensors_expected.size());
2229 auto tensors =
2230 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2231 EXPECT_EQ(2, tensors.size());
2232 for (int i = 0; i < tensors.size(); i++)
2233 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2234 }
2235
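// r2 reverses only size-1 dimensions, so it is a no-op and should become an
// Identity; r1 reverses dimensions of size > 1 and must be kept.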
2236 TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
2237 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2238
2239 Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
2240 DT_FLOAT);
2241 Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
2242 Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
2243 DT_FLOAT);
2244 Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
2245 ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
2246 ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
2247 a2);
2248
2249 ops::Add out1(scope.WithOpName("out1"), r1, r2);
2250
2251 GrapplerItem item;
2252 item.fetch = {"out1"};
2253 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2254
2255 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2256 GraphDef got;
2257 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2258 TF_EXPECT_OK(status);
2259
2260 GraphDef want;
2261 AddNode("in1", "VariableV2", {}, {}, &want);
2262 AddNode("in2", "VariableV2", {}, {}, &want);
2263 AddNode("a1", "Const", {}, {}, &want);
2264 AddNode("a2", "Const", {}, {}, &want);
2265 AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
2266 AddNode("r2", "Identity",
2267 {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
2268 &want);
2269 AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
2270
2271 CompareGraphs(want, got);
2272 }
2273
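// A Slice that covers its whole input, either with an explicit full size or
// with -1 sizes starting at offset 0, is a no-op and should be rewritten to
// an Identity; slices of a different extent must be kept.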
2274 TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
2275 { // size = {3, 5}
2276 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2277
2278 auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
2279 auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2280 auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
2281 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2282 ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
2283 ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);
2284
2285 ops::Add out(scope.WithOpName("out"), s1, s2);
2286
2287 GrapplerItem item;
2288 item.fetch = {"out"};
2289 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2290
2291 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2292 GraphDef got;
2293 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2294 TF_EXPECT_OK(status);
2295
2296 GraphDef want;
2297 AddNode("in1", "VariableV2", {}, {}, &want);
2298 AddNode("in2", "VariableV2", {}, {}, &want);
2299 AddNode("begin", "Const", {}, {}, &want);
2300 AddNode("size", "Const", {}, {}, &want);
2301 AddNode("s1", "Identity",
2302 {"in1", AsControlDependency("begin"), AsControlDependency("size")},
2303 {}, &want);
2304 AddNode("s2", "Slice", {"in2", "begin", "size"}, {}, &want);
2305 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2306
2307 CompareGraphs(want, got);
2308
2309 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
2310 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
2311 auto tensors_expected =
2312 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2313 EXPECT_EQ(1, tensors_expected.size());
2314 auto tensors =
2315 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2316 EXPECT_EQ(1, tensors.size());
2317 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2318 }
2319 { // size = {-1, -1}
2320 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2321
2322 auto in1 =
2323 ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
2324 auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
2325 auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
2326 auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
2327 Output in2 =
2328 ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
2329 ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
2330 ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);
2331
2332 ops::Add out(scope.WithOpName("out"), s1, s2);
2333
2334 GrapplerItem item;
2335 item.fetch = {"out"};
2336 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2337
2338 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2339 GraphDef got;
2340 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2341 TF_EXPECT_OK(status);
2342
2343 GraphDef want;
2344 AddNode("in1", "VariableV2", {}, {}, &want);
2345 AddNode("in2", "VariableV2", {}, {}, &want);
2346 AddNode("begin1", "Const", {}, {}, &want);
2347 AddNode("begin2", "Const", {}, {}, &want);
2348 AddNode("size", "Const", {}, {}, &want);
2349 AddNode("s1", "Identity",
2350 {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
2351 {}, &want);
2352 AddNode("s2", "Slice", {"in2", "begin2", "size"}, {}, &want);
2353 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2354
2355 CompareGraphs(want, got);
2356
2357 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
2358 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
2359 auto tensors_expected =
2360 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2361 EXPECT_EQ(1, tensors_expected.size());
2362 auto tensors =
2363 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2364 EXPECT_EQ(1, tensors.size());
2365 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2366 }
2367 }
2368
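// A stride-1 StridedSlice spanning every dimension, with or without
// begin/end/ellipsis masks, is a no-op and should become an Identity.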
2369 TEST_F(ConstantFoldingTest, StridedSliceWithSameDimensionRemoval) {
2370 { // no mask
2371 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2372
2373 auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5, 2}, DT_FLOAT);
2374 auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2375 auto end = ops::Const(scope.WithOpName("end"), {3, 5}, {2});
2376 auto strides = ops::Const(scope.WithOpName("strides"), {1, 1}, {2});
2377 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6, 2}, DT_FLOAT);
2378 ops::StridedSlice s1(scope.WithOpName("s1"), in1, begin, end, strides);
2379 ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin, end, strides);
2380
2381 ops::Add out(scope.WithOpName("out"), s1, s2);
2382
2383 GrapplerItem item;
2384 item.fetch = {"out"};
2385 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2386
2387 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2388 GraphDef got;
2389 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2390 TF_EXPECT_OK(status);
2391
2392 GraphDef want;
2393 AddNode("in1", "VariableV2", {}, {}, &want);
2394 AddNode("in2", "VariableV2", {}, {}, &want);
2395 AddNode("begin", "Const", {}, {}, &want);
2396 AddNode("end", "Const", {}, {}, &want);
2397 AddNode("strides", "Const", {}, {}, &want);
2398 AddNode("s1", "Identity",
2399 {"in1", AsControlDependency("begin"), AsControlDependency("end"),
2400 AsControlDependency("strides")},
2401 {}, &want);
2402 AddNode("s2", "StridedSlice", {"in2", "begin", "end", "strides"}, {},
2403 &want);
2404 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2405
2406 CompareGraphs(want, got);
2407
2408 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 2}));
2409 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6, 2}));
2410 auto tensors_expected =
2411 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2412 EXPECT_EQ(1, tensors_expected.size());
2413 auto tensors =
2414 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2415 EXPECT_EQ(1, tensors.size());
2416 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2417 }
2418 { // with begin/end/ellipsis mask
2419 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2420
2421 // s1 = in1[:, ..., 0:5, 0:6]
2422 auto in1 =
2423 ops::Variable(scope.WithOpName("in1"), {2, 3, 4, 5, 6}, DT_FLOAT);
2424 auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0, 0}, {3});
2425 auto end1 = ops::Const(scope.WithOpName("end1"), {0, 5, 6}, {3});
2426 auto strides1 = ops::Const(scope.WithOpName("strides1"), {1, 1, 1}, {3});
2427 ops::StridedSlice s1(
2428 scope.WithOpName("s1"), in1, begin1, end1, strides1,
2429 ops::StridedSlice::Attrs().BeginMask(1).EndMask(1).EllipsisMask(2));
2430
2431 Output in2 =
2432 ops::Variable(scope.WithOpName("in2"), {5, 8, 5, 6, 9}, DT_FLOAT);
2433 auto begin2 = ops::Const(scope.WithOpName("begin2"), {0, 0, 0, 0, 0}, {5});
2434 auto end2 = ops::Const(scope.WithOpName("end2"), {2, 3, 4, 5, 6}, {5});
2435 auto strides2 =
2436 ops::Const(scope.WithOpName("strides2"), {1, 1, 1, 1, 1}, {5});
2437 ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin2, end2, strides2);
2438
2439 ops::Add out(scope.WithOpName("out"), s1, s2);
2440
2441 GrapplerItem item;
2442 item.fetch = {"out"};
2443 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2444
2445 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2446 GraphDef got;
2447 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2448 TF_EXPECT_OK(status);
2449
2450 GraphDef want;
2451 AddNode("in1", "VariableV2", {}, {}, &want);
2452 AddNode("in2", "VariableV2", {}, {}, &want);
2453 AddNode("begin1", "Const", {}, {}, &want);
2454 AddNode("end1", "Const", {}, {}, &want);
2455 AddNode("strides1", "Const", {}, {}, &want);
2456 AddNode("s1", "Identity",
2457 {"in1", AsControlDependency("begin1"), AsControlDependency("end1"),
2458 AsControlDependency("strides1")},
2459 {}, &want);
2460 AddNode("begin2", "Const", {}, {}, &want);
2461 AddNode("end2", "Const", {}, {}, &want);
2462 AddNode("strides2", "Const", {}, {}, &want);
2463 AddNode("s2", "StridedSlice", {"in2", "begin2", "end2", "strides2"}, {},
2464 &want);
2465 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2466
2467 CompareGraphs(want, got);
2468
2469 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3, 4, 5, 6}));
2470 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 8, 5, 6, 9}));
2471 auto tensors_expected =
2472 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2473 EXPECT_EQ(1, tensors_expected.size());
2474 auto tensors =
2475 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2476 EXPECT_EQ(1, tensors.size());
2477 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2478 }
2479 }
2480
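// Tiling with all multiples equal to 1 is a no-op: t1 should become an
// Identity with a control dependency on its multiples input.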
2481 TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
2482 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2483
2484 auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2485 auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
2486 auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
2487 auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});
2488
2489 ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
2490 ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);
2491
2492 ops::Add out(scope.WithOpName("out"), t1, t2);
2493
2494 GrapplerItem item;
2495 item.fetch = {"out"};
2496 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2497
2498 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2499 GraphDef got;
2500 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2501 TF_EXPECT_OK(status);
2502
2503 GraphDef want;
2504 AddNode("in1", "VariableV2", {}, {}, &want);
2505 AddNode("in2", "VariableV2", {}, {}, &want);
2506 AddNode("multiplies1", "Const", {}, {}, &want);
2507 AddNode("multiplies2", "Const", {}, {}, &want);
2508 AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, {},
2509 &want);
2510 AddNode("t2", "Tile", {"in2", "multiplies2"}, {}, &want);
2511 AddNode("out", "Add", {"t1", "t2"}, {}, &want);
2512
2513 CompareGraphs(want, got);
2514 }
2515
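// Two ConcatV2 nodes chained along the same axis should be merged into a
// single ConcatV2 over all of the inputs.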
2516 TEST_F(ConstantFoldingTest, MergeConcat) {
2517 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2518
2519 Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2520 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2521 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2522 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2523
2524 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2525 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2526
2527 GrapplerItem item;
2528 item.fetch = {"c2"};
2529 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2530
2531 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2532 GraphDef got;
2533 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2534 TF_EXPECT_OK(status);
2535
2536 GraphDef want;
2537 AddNode("in1", "VariableV2", {}, {}, &want);
2538 AddNode("in2", "VariableV2", {}, {}, &want);
2539 AddNode("in3", "VariableV2", {}, {}, &want);
2540 AddNode("axis", "Const", {}, {}, &want);
2541 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2542
2543 CompareGraphs(want, got);
2544 }
2545
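// When the inner concat feeds the outer one more than once, merging should
// splice the inner inputs into the outer input list at each occurrence.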
2546 TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
2547 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2548
2549 Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2550 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2551 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2552 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2553
2554 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2555 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
2556
2557 GrapplerItem item;
2558 item.fetch = {"c2"};
2559 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2560
2561 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2562 GraphDef got;
2563 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2564 TF_EXPECT_OK(status);
2565
2566 GraphDef want;
2567 AddNode("in1", "VariableV2", {}, {}, &want);
2568 AddNode("in2", "VariableV2", {}, {}, &want);
2569 AddNode("in3", "VariableV2", {}, {}, &want);
2570 AddNode("axis", "Const", {}, {}, &want);
2571 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
2572 &want);
2573
2574 CompareGraphs(want, got);
2575 }
2576
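// Chained same-axis concats should still be merged when the inner concat's
// inputs have differing shapes.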
2577 TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
2578 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2579
2580 Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
2581 Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2582 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2583 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2584
2585 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2586 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2587
2588 GrapplerItem item;
2589 item.fetch = {"c2"};
2590 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2591
2592 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2593 GraphDef got;
2594 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2595 TF_EXPECT_OK(status);
2596
2597 GraphDef want;
2598 AddNode("in1", "VariableV2", {}, {}, &want);
2599 AddNode("in2", "VariableV2", {}, {}, &want);
2600 AddNode("in3", "VariableV2", {}, {}, &want);
2601 AddNode("axis", "Const", {}, {}, &want);
2602 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2603
2604 CompareGraphs(want, got);
2605 }
2606
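// Concats along different axes must not be merged; both ConcatV2 nodes are
// expected to survive unchanged.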
2607 TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
2608 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2609
2610 Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
2611 Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2612 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2613 Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
2614 Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
2615
2616 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
2617 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
2618
2619 GrapplerItem item;
2620 item.fetch = {"c2"};
2621 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2622
2623 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2624 GraphDef got;
2625 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2626 TF_EXPECT_OK(status);
2627
2628 GraphDef want;
2629 AddNode("in1", "VariableV2", {}, {}, &want);
2630 AddNode("in2", "VariableV2", {}, {}, &want);
2631 AddNode("in3", "VariableV2", {}, {}, &want);
2632 AddNode("axis1", "Const", {}, {}, &want);
2633 AddNode("axis2", "Const", {}, {}, &want);
2634 AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
2635 AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
2636
2637 CompareGraphs(want, got);
2638 }
2639
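// The leading constant inputs of the chained concats should be folded into a
// single Const that feeds one ConcatV2 together with the placeholder.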
2640 TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
2641 Scope scope = Scope::NewRootScope();
2642 Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
2643 Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
2644 Output c3 = ops::Const(scope.WithOpName("c3"), 3.0f, {2, 2});
2645 Output c4 = ops::Const(scope.WithOpName("c4"), 4.0f, {2, 2});
2646 Output ph = ops::Placeholder(scope.WithOpName("ph"), DT_FLOAT,
2647 ops::Placeholder::Shape(TensorShape({2, 2})));
2648 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2649
2650 ops::Concat concat1(scope.WithOpName("concat1"), {c1, c2, ph}, axis);
2651 ops::Concat concat2(scope.WithOpName("concat2"), {c3, c4, Output(concat1)},
2652 axis);
2653
2654 GrapplerItem item;
2655 item.fetch = {"concat2"};
2656 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2657
2658 ConstantFolding optimizer(nullptr);
2659 GraphDef got;
2660 Status status = optimizer.Optimize(nullptr, item, &got);
2661 TF_EXPECT_OK(status);
2662
2663 GraphDef want;
2664 AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
2665 AddNode("axis", "Const", {}, {}, &want);
2666 AddNode("ph", "Placeholder", {}, {}, &want);
2667 AddNode("concat2", "ConcatV2",
2668 {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
2669
2670 CompareGraphs(want, got);
2671 }
2672
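// Dispatches to the fixture's templated PaddingWithZeroSize helper for both
// index types.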
2673 TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
2674 PaddingWithZeroSize<int32>();
2675 PaddingWithZeroSize<int64>();
2676 }
2677
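// A Squeeze on a tensor with no size-1 dimensions is a no-op and should
// become an Identity; s2 does have size-1 dimensions and must stay a
// Squeeze.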
2678 TEST_F(ConstantFoldingTest, SqueezeWithAllDimensionsGreaterThanOne) {
2679 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2680
2681 auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
2682 auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);
2683
2684 ops::Squeeze s1(scope.WithOpName("s1"), in1);
2685 ops::Squeeze s2(scope.WithOpName("s2"), in2);
2686
2687 ops::Add out(scope.WithOpName("out"), s1, s2);
2688
2689 GrapplerItem item;
2690 item.fetch = {"out"};
2691 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2692
2693 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2694 GraphDef got;
2695 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2696 TF_EXPECT_OK(status);
2697
2698 GraphDef want;
2699 AddNode("in1", "VariableV2", {}, {}, &want);
2700 AddNode("in2", "VariableV2", {}, {}, &want);
2701 AddNode("s1", "Identity", {"in1"}, {}, &want);
2702 AddNode("s2", "Squeeze", {"in2"}, {}, &want);
2703 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2704
2705 CompareGraphs(want, got);
2706
2707 auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 3}));
2708 auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({1, 2, 3, 1}));
2709 auto tensors_expected =
2710 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2711 EXPECT_EQ(1, tensors_expected.size());
2712 auto tensors =
2713 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2714 EXPECT_EQ(1, tensors.size());
2715 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
2716 }
2717
2718 TEST_F(ConstantFoldingTest, NoOpReduction) {
2719 // Build a simple graph with reductions that are no-ops and can be
2720 // rewritten as the identity.
2721 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2722
2723 Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
2724 Output c =
2725 ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
2726 Output i = ops::Identity(scope.WithOpName("i"), c);
2727 Output p = ops::Prod(scope.WithOpName("p"), v, i);
2728 Output s = ops::Square(scope.WithOpName("s"), p);
2729
2730 Output v2 = ops::Variable(scope.WithOpName("v2"), {3, 5, 1}, DT_FLOAT);
2731 Output c2 =
2732 ops::Const(scope.WithOpName("c2").WithControlDependencies(v), 2, {1});
2733 ops::Prod::Attrs attr;
2734 attr = attr.KeepDims(true);
2735 Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);
2736
2737 // Test with unknown input shape.
2738 Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT);
2739 Output p3 = ops::Prod(scope.WithOpName("p3"), a, i, attr);
2740
2741 GrapplerItem item;
2742 item.fetch = {"s", "p2", "p3"};
2743 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2744
2745 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2746 GraphDef output;
2747 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2748 TF_EXPECT_OK(status);
2749
2750 int found = 0;
2751 for (const auto& node : output.node()) {
2752 if (node.name() == "p") {
2753 found++;
2754 EXPECT_EQ("Identity", node.op());
2755 EXPECT_EQ(2, node.input_size());
2756 EXPECT_EQ("v", node.input(0));
2757 EXPECT_EQ("^i", node.input(1));
2758 } else if (node.name() == "p2") {
2759 found++;
2760 EXPECT_EQ("Identity", node.op());
2761 EXPECT_EQ(2, node.input_size());
2762 EXPECT_EQ("v2", node.input(0));
2763 EXPECT_EQ("^c2", node.input(1));
2764 } else if (node.name() == "p3") {
2765 found++;
2766 EXPECT_EQ("Identity", node.op());
2767 EXPECT_EQ(2, node.input_size());
2768 EXPECT_EQ("a", node.input(0));
2769 EXPECT_EQ("^i", node.input(1));
2770 }
2771 }
2772 EXPECT_EQ(3, found);
2773
2774 auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2775 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
2776 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2777 auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
2778 {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
2779 EXPECT_EQ(3, tensors_expected.size());
2780 auto tensors =
2781 EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
2782 EXPECT_EQ(3, tensors.size());
2783 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2784 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
2785 test::ExpectTensorNear<float>(tensors_expected[2], tensors[2], 1e-5);
2786 }
2787
2788 TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
2789 // Build a simple graph with reductions that involve single-element input and
2790 // no axes to reduce along.
2791 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2792
2793 Output input_var_three_dim = ops::Variable(
2794 scope.WithOpName("input_var_three_dim"), {1, 1, 1}, DT_FLOAT);
2795 Output input_var_one_dim =
2796 ops::Variable(scope.WithOpName("input_var_one_dim"), {1}, DT_FLOAT);
2797 Output one_axis = ops::Const(scope.WithOpName("one_axis"), {0}, {1});
2798 Output multiple_axes =
2799 ops::Const(scope.WithOpName("multiple_axes"), {1, 0}, {2});
2800 Output variable_axis =
2801 ops::Variable(scope.WithOpName("input_var_axis"), {1}, DT_INT32);
2802 ops::Mean::Attrs attr;
2803 attr = attr.KeepDims(false);
2804 // Should be optimized to Reshape.
2805 Output mean_1 = ops::Mean(scope.WithOpName("mean_1"), input_var_three_dim,
2806 one_axis, attr.KeepDims(false));
2807 Output mean_2 = ops::Mean(scope.WithOpName("mean_2"), input_var_three_dim,
2808 multiple_axes, attr.KeepDims(false));
2809   // Should remain as-is, since OutputProperties will not be known for this node.
2810 Output mean_3 = ops::Mean(scope.WithOpName("mean_3"), input_var_one_dim,
2811 one_axis, attr.KeepDims(false));
2812 // Should remain as-is.
2813 Output mean_4 = ops::Mean(scope.WithOpName("mean_4"), input_var_three_dim,
2814 variable_axis, attr.KeepDims(false));
2815 // Should be optimized to Identity, since KeepDims=true.
2816 Output mean_5 = ops::Mean(scope.WithOpName("mean_5"), input_var_three_dim,
2817 multiple_axes, attr.KeepDims(true));
2818
2819 GrapplerItem item;
2820 item.fetch = {"mean_1", "mean_2", "mean_3", "mean_4", "mean_5"};
2821 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2822
2823 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2824 GraphDef output;
2825 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2826 TF_EXPECT_OK(status);
2827
2828 // Ensure Mean node is optimized to Reshape.
2829 int found = 0;
2830 for (const auto& node : output.node()) {
2831 if (node.name() == "mean_1" || node.name() == "mean_2") {
2832 found++;
2833 EXPECT_EQ("Reshape", node.op());
2834 EXPECT_EQ(2, node.input_size());
2835 EXPECT_EQ("input_var_three_dim", node.input(0));
2836 } else if (node.name() == "mean_3") {
2837 found++;
2838 EXPECT_EQ("Mean", node.op());
2839 EXPECT_EQ(2, node.input_size());
2840 EXPECT_EQ("input_var_one_dim", node.input(0));
2841 } else if (node.name() == "mean_4") {
2842 found++;
2843 EXPECT_EQ("Mean", node.op());
2844 EXPECT_EQ(2, node.input_size());
2845 EXPECT_EQ("input_var_three_dim", node.input(0));
2846 } else if (node.name() == "mean_5") {
2847 found++;
2848 EXPECT_EQ("Identity", node.op());
2849 EXPECT_EQ(2, node.input_size());
2850 EXPECT_EQ("^multiple_axes", node.input(1));
2851 }
2852 }
2853 EXPECT_EQ(5, found);
2854
2855 // Ensure resultant values from Mean and Reshape are the same.
2856 auto input_var_three_dim_t =
2857 GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
2858 auto input_var_one_dim_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2859 Tensor input_var_axis_t(DT_INT32, TensorShape({1}));
2860 input_var_axis_t.flat<int32>()(0) = 0;
2861 auto tensors_expected =
2862 EvaluateNodes(item.graph, item.fetch,
2863 {{"input_var_three_dim", input_var_three_dim_t},
2864 {"input_var_one_dim", input_var_one_dim_t},
2865 {"input_var_axis", input_var_axis_t}});
2866 EXPECT_EQ(5, tensors_expected.size());
2867 auto tensors = EvaluateNodes(output, item.fetch,
2868 {{"input_var_three_dim", input_var_three_dim_t},
2869 {"input_var_one_dim", input_var_one_dim_t},
2870 {"input_var_axis", input_var_axis_t}});
2871 EXPECT_EQ(5, tensors.size());
2872 for (int i = 0; i < 5; ++i) {
2873 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2874 }
2875 }
2876
2877 TEST_F(ConstantFoldingTest, NoOpReshape) {
2878 // Build a simple graph with a reshape that can be reduced to the identity.
2879 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2880
2881   // A reshape that can be optimized
2882 Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
2883 Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
2884 Output c1 =
2885 ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
2886 Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
2887 Output r1 =
2888 ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
2889 Output s1 = ops::Square(scope.WithOpName("s1"), r1);
2890
2891   // A multi-dimensional reshape that can be optimized
2892 Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
2893 Output c3 =
2894 ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
2895 Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
2896 Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
2897 Output s3 = ops::Square(scope.WithOpName("s3"), r3);
2898
2899   // A multi-dimensional, partially defined reshape that can be optimized
2900 Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
2901 Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
2902 {5, -1, 5}, {3});
2903 Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
2904 Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
2905 Output s4 = ops::Square(scope.WithOpName("s4"), r4);
2906
2907 // A reshape that can't be optimized
2908 Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
2909 Output c2 =
2910 ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
2911 Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
2912 Output s2 = ops::Square(scope.WithOpName("s2"), r2);
2913
2914 GrapplerItem item;
2915 item.fetch = {"s1", "s2", "s3", "s4"};
2916 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2917
2918 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2919 GraphDef output;
2920 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2921 TF_EXPECT_OK(status);
2922
2923 int found = 0;
2924 for (const auto& node : output.node()) {
2925 if (node.name() == "r1") {
2926 ++found;
2927 EXPECT_EQ("Identity", node.op());
2928 ASSERT_EQ(3, node.input_size());
2929 EXPECT_EQ("v1", node.input(0));
2930 EXPECT_EQ("^i1", node.input(1));
2931 EXPECT_EQ("^d1", node.input(2));
2932 } else if (node.name() == "r3") {
2933 ++found;
2934 EXPECT_EQ("Identity", node.op());
2935 ASSERT_EQ(2, node.input_size());
2936 EXPECT_EQ("v3", node.input(0));
2937 EXPECT_EQ("^i3", node.input(1));
2938 } else if (node.name() == "r4") {
2939 ++found;
2940 EXPECT_EQ("Identity", node.op());
2941 ASSERT_EQ(2, node.input_size());
2942 EXPECT_EQ("v4", node.input(0));
2943 EXPECT_EQ("^i4", node.input(1));
2944 } else if (node.name() == "r2") {
2945 ++found;
2946 EXPECT_EQ("Reshape", node.op());
2947 ASSERT_EQ(2, node.input_size());
2948 EXPECT_EQ("v2", node.input(0));
2949 EXPECT_EQ("c2", node.input(1));
2950 }
2951 }
2952 EXPECT_EQ(4, found);
2953
2954 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17}));
2955 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17, 1}));
2956 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2957 auto v4_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2958 auto tensors_expected =
2959 EvaluateNodes(item.graph, item.fetch,
2960 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2961 EXPECT_EQ(4, tensors_expected.size());
2962 auto tensors =
2963 EvaluateNodes(output, item.fetch,
2964 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2965 EXPECT_EQ(4, tensors.size());
2966 for (int i = 0; i < tensors.size(); i++)
2967 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2968 }
2969
2970 TEST_F(ConstantFoldingTest, Packing) {
2971 // Build a simple graph with a large constant that can be folded.
2972 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2973 Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
2974 Output i1 = ops::Identity(scope.WithOpName("i1"), c);
2975 Output i2 = ops::Identity(scope.WithOpName("i2"), c);
2976
2977 GrapplerItem item;
2978 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2979
2980 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2981 GraphDef output;
2982 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2983 TF_EXPECT_OK(status);
2984
2985 const std::vector<string> fetch_nodes = {"i1", "i2"};
2986 auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
2987 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2988 auto tensors = EvaluateNodes(output, fetch_nodes);
2989 EXPECT_EQ(fetch_nodes.size(), tensors.size());
2990 for (int i = 0; i < fetch_nodes.size(); i++)
2991 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2992
2993 // Make sure that the representation of the folded constant is space
2994 // efficient: in particular, the whole message should be smaller than 8k
2995 // (the size needed to naively encode 1000 floats folded twice).
2996 EXPECT_GT(8000, output.ByteSizeLong());
2997 }
2998
2999 TEST_F(ConstantFoldingTest, LargeConstantNoSizeIncrease) {
3000 // Build a simple graph with a large constant with size greater than
3001 // kMaxConstantSize that can be folded because the resulting size does not
3002 // increase.
3003 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3004 const int64 large_constant_size = kMaxConstantSize + 1;
3005 Output a = ops::Variable(scope.WithOpName("a"), {1, 1}, DT_FLOAT);
3006 Output b_const =
3007 ops::Const(scope.WithOpName("b_const"), 3.14f, {1, large_constant_size});
3008 Output b = ops::Identity(scope.WithOpName("b"), b_const);
3009 Output matmul = ops::MatMul(scope.WithOpName("matmul"), a, b);
3010
3011 GrapplerItem item;
3012 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3013
3014 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3015 GraphDef output;
3016 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3017 TF_EXPECT_OK(status);
3018
3019 item.graph.Swap(&output);
3020 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3021 TF_EXPECT_OK(status);
3022
3023 for (const auto& node : output.node()) {
3024 if (node.name() == "b") {
3025 EXPECT_EQ("Const", node.op());
3026 }
3027 }
3028 EXPECT_EQ(4, output.node_size());
3029 EXPECT_LT(output.ByteSizeLong(), sizeof(float) * large_constant_size + 500);
3030 }
3031
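// BroadcastGradientArgs computes the reduction indices needed to sum
// gradients over broadcast dimensions. When the optimizer can prove the two
// input shapes identical (here both are shapes of tensors derived from the
// same placeholder), the outputs are empty, so it should materialize them as
// empty Const nodes anchored on the original op via a control dependency,
// roughly:
//   o1 = Identity(f:0)  =>  o1 = Identity(ConstantFolding/f-bcastargs-0)
// The second BroadcastGradientArgs ("i") mixes a known shape with one that
// may broadcast, so it must be left untouched.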
3032 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
3033 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3034 Output a =
3035 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
3036 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
3037 Output b = ops::Square(s.WithOpName("b"), a);
3038 Output c = ops::Mul(s.WithOpName("c"), a, b);
3039 Output d = ops::Shape(s.WithOpName("d"), a);
3040 Output e = ops::Shape(s.WithOpName("e"), b);
3041
3042 auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
3043 Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
3044 Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
3045
3046 Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
3047 ops::Placeholder::Shape(PartialTensorShape({1})));
3048 Output h = ops::Shape(s.WithOpName("h"), g);
3049 auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
3050 Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
3051 Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
3052
3053 GrapplerItem item;
3054 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3055
3056 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3057 GraphDef output;
3058 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3059 TF_EXPECT_OK(status);
3060
3061 std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
3062 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
3063 auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
3064 auto tensors_expected =
3065 EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
3066 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
3067
3068 // Run a second time to make sure the optimization is idempotent.
3069 item.graph.Swap(&output);
3070 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3071 TF_EXPECT_OK(status);
3072
3073 int found = 0;
3074 for (const auto& node : output.node()) {
3075 if (node.name() == "o1") {
3076 ++found;
3077 EXPECT_EQ(1, node.input_size());
3078 EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
3079 } else if (node.name() == "o2") {
3080 ++found;
3081 EXPECT_EQ(1, node.input_size());
3082 EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
3083 } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
3084 ++found;
3085 EXPECT_EQ("Const", node.op());
3086 EXPECT_EQ(1, node.input_size());
3087 EXPECT_EQ("^f", node.input(0));
3088 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3089 .num_elements());
3090 } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
3091 ++found;
3092 EXPECT_EQ("Const", node.op());
3093 EXPECT_EQ(1, node.input_size());
3094 EXPECT_EQ("^f", node.input(0));
3095 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3096 .num_elements());
3097 } else if (node.name() == "p1") {
3098 ++found;
3099 EXPECT_EQ(1, node.input_size());
3100 EXPECT_EQ("i", node.input(0));
3101 } else if (node.name() == "p2") {
3102 ++found;
3103 EXPECT_EQ(1, node.input_size());
3104 EXPECT_EQ("i:1", node.input(0));
3105 }
3106 }
3107 EXPECT_EQ(6, found);
3108
3109 auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
3110 EXPECT_EQ(fetch_nodes.size(), tensors.size());
3111 for (int i = 0; i < fetch_nodes.size(); i++)
3112 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3113 }
3114
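// Same rewrite as above, but with fully known input shapes: the Shape nodes
// themselves fold to constants, and BroadcastGradientArgs folds into
// "ConstantFolding/f-folded-1". The second optimization pass guards against
// the optimizer re-materializing its own outputs, which would otherwise grow
// the graph on every pass.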
3115 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
3116 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3117 Output a =
3118 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
3119 ops::Placeholder::Shape(PartialTensorShape({2, 2})));
3120 Output b = ops::Square(s.WithOpName("b"), a);
3121 Output c = ops::Mul(s.WithOpName("c"), a, b);
3122 Output d = ops::Shape(s.WithOpName("d"), a);
3123 Output e = ops::Shape(s.WithOpName("e"), b);
3124
3125 auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
3126 Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
3127 Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
3128
3129 GrapplerItem item;
3130 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3131
3132 std::vector<string> fetch_nodes = {"o1", "o2"};
3133 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
3134 auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
3135 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
3136
3137 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3138 GraphDef output;
3139 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3140 TF_EXPECT_OK(status);
3141
3142 // Run a second time to make sure the optimization is idempotent.
3143 item.graph.Swap(&output);
3144 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3145 TF_EXPECT_OK(status);
3146
3147 EXPECT_EQ(11, output.node_size());
3148 int found = 0;
3149 for (const auto& node : output.node()) {
3150 if (node.name() == "ConstantFolding/f-folded-1") {
3151 ++found;
3152 EXPECT_EQ("Const", node.op());
3153 EXPECT_EQ(2, node.input_size());
3154 EXPECT_EQ("^a", node.input(0));
3155 EXPECT_EQ("^b", node.input(1));
3156 } else if (node.name() == "d") {
3157 ++found;
3158 EXPECT_EQ("Const", node.op());
3159 EXPECT_EQ(1, node.input_size());
3160 EXPECT_EQ("^a", node.input(0));
3161 } else if (node.name() == "e") {
3162 ++found;
3163 EXPECT_EQ("Const", node.op());
3164 EXPECT_EQ(1, node.input_size());
3165 EXPECT_EQ("^b", node.input(0));
3166 } else if (node.name() == "o1") {
3167 ++found;
3168 EXPECT_EQ(1, node.input_size());
3169 EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
3170 } else if (node.name() == "o2") {
3171 ++found;
3172 EXPECT_EQ(1, node.input_size());
3173 EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
3174 } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
3175 ++found;
3176 EXPECT_EQ("Const", node.op());
3177 EXPECT_EQ(1, node.input_size());
3178 EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
3179 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3180 .num_elements());
3181 } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
3182 ++found;
3183 EXPECT_EQ("Const", node.op());
3184 EXPECT_EQ(1, node.input_size());
3185 EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
3186 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3187 .num_elements());
3188 }
3189 }
3190 EXPECT_EQ(7, found);
3191 auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
3192 EXPECT_EQ(fetch_nodes.size(), tensors.size());
3193 for (int i = 0; i < fetch_nodes.size(); i++)
3194 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3195 }
3196
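// When a reduction provably collapses its input to a single element (deduced
// here either from the reshape to shape [1] or from the known number of
// indices), the unknown reduction indices can be replaced by a materialized
// constant covering all dimensions. Roughly:
//   sum = Sum(input, indices)
//     =>  sum = Sum(input, ConstantFolding/sum-reduction_indices)
// where the materialized constant holds [0, 1] and carries a control
// dependency on the original "indices" placeholder.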
3197 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
3198 for (bool use_reshape : {true, false}) {
3199 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3200 Output input =
3201 ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
3202 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
3203     // If use_reshape is false, we need to know the number of indices to apply
3204 // the rewrite.
3205 Output indices = ops::Placeholder(
3206 s.WithOpName("indices"), DT_INT32,
3207 ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
3208 Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
3209 if (use_reshape) {
3210 Output size = ops::Const(s.WithOpName("size"), 1, {1});
3211 Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
3212 }
3213
3214 GrapplerItem item;
3215 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3216 item.fetch.push_back(use_reshape ? "reshape" : "sum");
3217
3218 auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
3219 Tensor indices_t(DT_INT32, TensorShape({2}));
3220 indices_t.flat<int>()(0) = 0;
3221 indices_t.flat<int>()(1) = 1;
3222 auto tensors_expected = EvaluateNodes(
3223 item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
3224 EXPECT_EQ(1, tensors_expected.size());
3225
3226 // Use aggressive mode to force the shape inference to propagate placeholder
3227 // shapes.
3228 ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
3229 /*cpu_device=*/nullptr);
3230 GraphDef output;
3231 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3232 TF_EXPECT_OK(status);
3233
3234 // Run a second time to make sure the optimization is idempotent.
3235 item.graph.Swap(&output);
3236 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3237 TF_EXPECT_OK(status);
3238
3239 int found = 0;
3240 for (const auto& node : output.node()) {
3241 if (node.name() == "ConstantFolding/sum-reduction_indices") {
3242 ++found;
3243 EXPECT_EQ("Const", node.op());
3244 EXPECT_EQ("^indices", node.input(0));
3245 EXPECT_EQ(2,
3246 TensorShape(node.attr().at("value").tensor().tensor_shape())
3247 .num_elements());
3248 } else if (node.name() == "sum") {
3249 ++found;
3250 EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
3251 } else if (node.name() == "indices") {
3252 ++found;
3253 }
3254 }
3255 EXPECT_EQ(3, found);
3256
3257 auto tensors = EvaluateNodes(output, item.fetch,
3258 {{"input", input_t}, {"indices", indices_t}});
3259 EXPECT_EQ(1, tensors.size());
3260 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
3261 }
3262 }
3263
3264 TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
3265 for (bool input_rank_known : {true, false}) {
3266 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3267 Output input =
3268 (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
3269 ops::Placeholder::Shape(
3270 PartialTensorShape({-1, -1})))
3271 : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
3272 Output indices =
3273 ops::Placeholder(s.WithOpName("indices"), DT_INT32,
3274 ops::Placeholder::Shape(
3275 PartialTensorShape({input_rank_known ? 1 : 2})));
3276 Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
3277
3278 GrapplerItem item;
3279 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3280 item.fetch.push_back("sum");
3281
3282 // Use aggressive mode to force the shape inference to propagate placeholder
3283 // shapes.
3284 ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
3285 /*cpu_device=*/nullptr);
3286 GraphDef output;
3287 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3288 TF_EXPECT_OK(status);
3289
3290 CompareGraphs(item.graph, output);
3291 }
3292 }
3293
3294 TEST_F(ConstantFoldingTest, LargeConstant) {
3295 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3296   // Generate a 4k by 4k non-compressible constant matrix.
3297 Output mat_diag =
3298 ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
3299 Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
3300 Output out = ops::Identity(scope.WithOpName("out"), mat);
3301
3302 GrapplerItem item;
3303 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3304 item.fetch.push_back("out");
3305
3306 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3307 GraphDef output;
3308 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3309 TF_EXPECT_OK(status);
3310
3311 // Make sure the diag node hasn't been folded, since it would use too much
3312 // memory to encode the corresponding constant.
3313 int found = 0;
3314 for (const NodeDef& node : output.node()) {
3315 if (node.name() == "out") {
3316 EXPECT_EQ(node.op(), "Identity");
3317 ASSERT_EQ(node.input_size(), 1);
3318 EXPECT_EQ(node.input(0), "mat");
3319 ++found;
3320 } else if (node.name() == "mat") {
3321 EXPECT_EQ(node.op(), "Diag");
3322 ASSERT_EQ(node.input_size(), 1);
3323 EXPECT_EQ(node.input(0), "mat_diag");
3324 ++found;
3325 }
3326 }
3327 EXPECT_EQ(found, 2);
3328   // The output should be no larger than the size of the constant "mat_diag"
3329 // plus a small constant amount for the remaining nodes.
3330 EXPECT_LT(output.ByteSizeLong(), sizeof(int) * 4 * 1024 + 500);
3331
3332 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3333 ASSERT_EQ(tensors_expected.size(), 1);
3334 auto tensors = EvaluateNodes(output, item.fetch);
3335 ASSERT_EQ(tensors.size(), 1);
3336 test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
3337 }
3338
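// When a Switch receives the same tensor as both data and predicate, each
// output branch carries a known value: the false branch only fires when the
// tensor is false, and the true branch only when it is true, so LogicalNot of
// either branch is a constant. The optimizer is expected to anchor each
// folded constant on a per-port Identity node
// ("ConstantFoldingCtrl/switch_<port>") so that the constant is only alive
// when the corresponding branch is taken.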
3339 TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
3340 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3341 Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
3342 ops::Placeholder::Shape(TensorShape({})));
3343 ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
3344 Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
3345 Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);
3346
3347 GrapplerItem item;
3348 item.fetch.push_back("id_false");
3349 item.fetch.push_back("id_true");
3350 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3351
3352 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3353 GraphDef output;
3354 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3355 TF_EXPECT_OK(status);
3356
3357 EXPECT_EQ(6, output.node_size());
3358 int found = 0;
3359 for (const auto& node : output.node()) {
3360 if (node.name() == "switch" || node.name() == "x") {
3361 ++found;
3362 }
3363 if (node.name() == "id_false") {
3364 EXPECT_EQ("Const", node.op());
3365 EXPECT_EQ(1, node.input_size());
3366 EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
3367 ++found;
3368 }
3369 if (node.name() == "id_true") {
3370 EXPECT_EQ("Const", node.op());
3371 EXPECT_EQ(1, node.input_size());
3372 EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
3373 ++found;
3374 }
3375 if (node.name() == "ConstantFoldingCtrl/switch_0") {
3376 EXPECT_EQ("Identity", node.op());
3377 EXPECT_EQ(1, node.input_size());
3378 EXPECT_EQ("switch", node.input(0));
3379 ++found;
3380 }
3381 if (node.name() == "ConstantFoldingCtrl/switch_1") {
3382 EXPECT_EQ("Identity", node.op());
3383 EXPECT_EQ(1, node.input_size());
3384 EXPECT_EQ("switch:1", node.input(0));
3385 ++found;
3386 }
3387 }
3388 EXPECT_EQ(6, found);
3389
3390 // Evaluate id_true when input tensor x is true.
3391 Tensor x_t(DT_BOOL, TensorShape({}));
3392 x_t.flat<bool>()(0) = true;
3393 auto tensors_expected = EvaluateNodes(item.graph, {"id_true"}, {{"x", x_t}});
3394 EXPECT_EQ(1, tensors_expected.size());
3395 auto tensors = EvaluateNodes(output, {"id_true"}, {{"x", x_t}});
3396 EXPECT_EQ(1, tensors.size());
3397 test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3398
3399 // Evaluate id_false when input tensor is false.
3400 x_t.flat<bool>()(0) = false;
3401 tensors_expected = EvaluateNodes(item.graph, {"id_false"}, {{"x", x_t}});
3402 EXPECT_EQ(1, tensors_expected.size());
3403 tensors = EvaluateNodes(output, {"id_false"}, {{"x", x_t}});
3404 EXPECT_EQ(1, tensors.size());
3405 test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3406 }
3407
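// AddN and AccumulateNV2 are associative and commutative, so any subset of at
// least two constant inputs can be pre-summed into a single constant even
// when the remaining inputs are unknown. A sketch of the expected rewrite
// (names follow the "ConstantFolding/<node>_partial_split_<n>" convention
// checked below):
//   acc3 = AddN(c1, c2, z)
//     =>  acc3 = AddN(ConstantFolding/acc3_partial_split_2, z)
// Nodes with fewer than two constant inputs (acc1, acc2) must be left alone,
// since there is nothing to pre-compute.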
3408 TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
3409 std::function<Output(const Scope&, InputList)> addn_fun =
3410 [](const Scope& scope, InputList inputs) {
3411 return ops::AddN(scope, inputs);
3412 };
3413 std::function<Output(const Scope&, InputList)> accumulate_fun =
3414 [](const Scope& scope, InputList inputs) {
3415 return ops::AccumulateNV2(scope, inputs, TensorShape({2, 2}));
3416 };
3417 for (bool use_add_n : {true, false}) {
3418 auto fun = use_add_n ? addn_fun : accumulate_fun;
3419 const string op_name = use_add_n ? "AddN" : "AccumulateNV2";
3420 Scope s = Scope::NewRootScope();
3421 Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3422 ops::Placeholder::Shape(TensorShape({2, 2})));
3423 Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3424 ops::Placeholder::Shape(TensorShape({2, 2})));
3425 Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3426 ops::Placeholder::Shape(TensorShape({2, 2})));
3427 Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3428 Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3429 Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2, 2});
3430 Output acc0 = fun(s.WithOpName("acc0"), {c1, c2, c3});
3431 Output acc1 = fun(s.WithOpName("acc1"), {x, y, z});
3432 Output acc2 = fun(s.WithOpName("acc2"), {c1, x, y});
3433 Output acc3 = fun(s.WithOpName("acc3"), {c1, c2, z});
3434 Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
3435 Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
3436 Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
3437 Output stack = ops::Stack(s.WithOpName("stack"),
3438 {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
3439
3440 GrapplerItem item;
3441 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3442 item.fetch = {"stack"};
3443
3444 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3445 GraphDef output;
3446 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3447 TF_EXPECT_OK(status);
3448
3449 EXPECT_EQ(16, output.node_size());
3450 for (const NodeDef& node : output.node()) {
3451 if (node.name() == "acc0") {
3452 EXPECT_EQ("Const", node.op());
3453 }
3454 if (node.name() == "acc1") {
3455 EXPECT_EQ(op_name, node.op());
3456 EXPECT_EQ(3, node.input_size());
3457 EXPECT_EQ("x", node.input(0));
3458 EXPECT_EQ("y", node.input(1));
3459 EXPECT_EQ("z", node.input(2));
3460 }
3461 if (node.name() == "acc2") {
3462 EXPECT_EQ(op_name, node.op());
3463 EXPECT_EQ(3, node.input_size());
3464 EXPECT_EQ("c1", node.input(0));
3465 EXPECT_EQ("x", node.input(1));
3466 EXPECT_EQ("y", node.input(2));
3467 }
3468 if (node.name() == "acc3") {
3469 EXPECT_EQ(op_name, node.op());
3470 EXPECT_EQ(2, node.input_size());
3471 EXPECT_EQ("ConstantFolding/acc3_partial_split_2", node.input(0));
3472 EXPECT_EQ("z", node.input(1));
3473 }
3474 if (node.name() == "acc4") {
3475 EXPECT_EQ(op_name, node.op());
3476 EXPECT_EQ(2, node.input_size());
3477 EXPECT_EQ("ConstantFolding/acc4_partial_split_2", node.input(0));
3478 EXPECT_EQ("y", node.input(1));
3479 }
3480 if (node.name() == "acc5") {
3481 EXPECT_EQ(op_name, node.op());
3482 EXPECT_EQ(2, node.input_size());
3483 EXPECT_EQ("x", node.input(0));
3484 EXPECT_EQ("ConstantFolding/acc5_partial_split_2", node.input(1));
3485 }
3486 if (node.name() == "acc6") {
3487 EXPECT_EQ(op_name, node.op());
3488 EXPECT_EQ(3, node.input_size());
3489 EXPECT_EQ("x", node.input(0));
3490 EXPECT_EQ("ConstantFolding/acc6_partial_split_2", node.input(1));
3491 EXPECT_EQ("y", node.input(2));
3492 }
3493 if (absl::StartsWith(node.name(), "ConstantFolding/")) {
3494 EXPECT_EQ("Const", node.op());
3495 }
3496 }
3497
3498 std::vector<string> fetch = {"acc0"};
3499 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3500 auto tensors = EvaluateNodes(output, fetch);
3501 EXPECT_EQ(1, tensors_expected.size());
3502 EXPECT_EQ(1, tensors.size());
3503 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3504 }
3505 }
3506
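// Concat is order-sensitive, so unlike AddN only *adjacent* constant inputs
// can be merged into a single constant. concat4 ({c1, y, c2}) and concat6
// ({x, c1, y, c2}) interleave constants with unknowns and should stay
// unchanged, while the {c1, c2} runs in concat3, concat5, concat7, concat8
// and concat9 each collapse into one
// "ConstantFolding/concat*_partial_split_<n>" constant.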
3507 TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
3508 Scope s = Scope::NewRootScope();
3509 Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3510 ops::Placeholder::Shape(TensorShape({2, 2})));
3511 Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3512 ops::Placeholder::Shape(TensorShape({2, 2})));
3513 Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3514 ops::Placeholder::Shape(TensorShape({2, 2})));
3515 Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3516 Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3517 Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3518 Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
3519 Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
3520 Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
3521 Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
3522 Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
3523 Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
3524 Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
3525 Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
3526 Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
3527 Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
3528
3529 GrapplerItem item;
3530 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3531 item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
3532 "concat5", "concat6", "concat7", "concat8", "concat9"};
3533
3534 auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
3535 EXPECT_EQ(1, tensors_expected.size());
3536 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3537 GraphDef output;
3538 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3539 TF_EXPECT_OK(status);
3540 // Run the optimizer twice to make sure the rewrite is idempotent.
3541 item.graph.Swap(&output);
3542 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3543 TF_EXPECT_OK(status);
3544
3545 EXPECT_EQ(21, output.node_size());
3546 for (int i = 0; i < output.node_size(); ++i) {
3547 const NodeDef& node = output.node(i);
3548 if (node.name() == "concat0") {
3549 EXPECT_EQ("Const", node.op());
3550 } else if (node.name() == "concat3") {
3551 EXPECT_EQ(3, node.input_size());
3552 EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
3553 EXPECT_EQ("z", node.input(1));
3554 EXPECT_EQ("axis", node.input(2));
3555 } else if (node.name() == "concat5") {
3556 EXPECT_EQ(3, node.input_size());
3557 EXPECT_EQ("x", node.input(0));
3558 EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
3559 EXPECT_EQ("axis", node.input(2));
3560 } else if (node.name() == "concat7") {
3561 EXPECT_EQ(4, node.input_size());
3562 EXPECT_EQ("x", node.input(0));
3563 EXPECT_EQ("y", node.input(1));
3564 EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
3565 EXPECT_EQ("axis", node.input(3));
3566 } else if (node.name() == "concat8") {
3567 EXPECT_EQ(4, node.input_size());
3568 EXPECT_EQ("x", node.input(0));
3569 EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
3570 EXPECT_EQ("y", node.input(2));
3571 EXPECT_EQ("axis", node.input(3));
3572 } else if (node.name() == "concat9") {
3573 EXPECT_EQ(4, node.input_size());
3574 EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
3575 EXPECT_EQ("x", node.input(1));
3576 EXPECT_EQ("y", node.input(2));
3577 EXPECT_EQ("axis", node.input(3));
3578 } else if (absl::StartsWith(node.name(), "ConstantFolding/")) {
3579 EXPECT_EQ("Const", node.op());
3580 } else {
3581 EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
3582 }
3583 }
3584
3585 auto tensors = EvaluateNodes(output, {"concat0"});
3586 EXPECT_EQ(1, tensors.size());
3587 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3588 }
3589
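// IdentityN itself is kept (it may anchor control dependencies for several
// tensors at once), but consumers of its constant outputs can be rewired:
// reading id_n:0 is the same as reading c1 directly, and ops whose inputs all
// become constant (id0, add1) fold completely, retaining "^id_n" as a control
// input to preserve ordering.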
3590 TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
3591 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3592 Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
3593 ops::Placeholder::Shape(TensorShape({})));
3594 Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
3595 Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
3596 auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {c1, x, c2});
3597 auto id0 = ops::Identity(scope.WithOpName("id0"), id_n[0]);
3598 auto id1 = ops::Identity(scope.WithOpName("id1"), id_n[1]);
3599 auto add0 = ops::Add(scope.WithOpName("add0"), id_n[0], id_n[1]);
3600 auto add1 = ops::Add(scope.WithOpName("add1"), id_n[0], id_n[2]);
3601
3602 GrapplerItem item;
3603 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3604 item.fetch.push_back("id0");
3605 item.fetch.push_back("id1");
3606 item.fetch.push_back("add0");
3607 item.fetch.push_back("add1");
3608
3609 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3610 GraphDef output;
3611 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3612 TF_EXPECT_OK(status);
3613 EXPECT_EQ(8, output.node_size());
3614 for (const auto& node : output.node()) {
3615 // id_n should remain unchanged.
3616 if (node.name() == "id_n") {
3617 EXPECT_EQ(3, node.input_size());
3618 EXPECT_EQ("c1", node.input(0));
3619 EXPECT_EQ("x", node.input(1));
3620 EXPECT_EQ("c2", node.input(2));
3621 }
3622     // id0 should be constant folded and have a control dependency from id_n.
3623 if (node.name() == "id0") {
3624 EXPECT_EQ("Const", node.op());
3625 EXPECT_EQ(1, node.input_size());
3626 EXPECT_EQ("^id_n", node.input(0));
3627 }
3628 // id1 is unchanged.
3629 if ("id1" == node.name()) {
3630 EXPECT_EQ(1, node.input_size());
3631 EXPECT_EQ("id_n:1", node.input(0));
3632 }
3633
3634 if ("add0" == node.name()) {
3635 EXPECT_EQ(2, node.input_size());
3636 EXPECT_EQ("c1", node.input(0));
3637 EXPECT_EQ("id_n:1", node.input(1));
3638 }
3639     // add1 should be constant folded and have a control dependency from id_n.
3640 if ("add1" == node.name()) {
3641 EXPECT_EQ("Const", node.op());
3642 EXPECT_EQ(1, node.input_size());
3643 EXPECT_EQ("^id_n", node.input(0));
3644 }
3645 }
3646
3647 auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
3648 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
3649 EXPECT_EQ(4, tensors_expected.size());
3650 auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
3651 EXPECT_EQ(4, tensors.size());
3652 for (int i = 0; i < tensors.size(); i++) {
3653 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
3654 }
3655 }
3656
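// Stacking a single tensor just inserts a new unit dimension, so a Pack with
// one input should be rewritten to ExpandDims with a materialized constant
// axis, roughly:
//   stack = Pack(x, axis=1)
//     =>  stack = ExpandDims(x, ConstantFolding/stack_const_axis)
// Pre-existing control dependencies (here on "y") must survive the rewrite.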
3657 TEST_F(ConstantFoldingTest, TrivialPack) {
3658 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3659 Output x =
3660 ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
3661 Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
3662 auto stack =
3663 ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
3664 ops::Stack::Axis(1));
3665 auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});
3666
3667 GrapplerItem item;
3668 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3669 item.fetch = {"stack", "stack_no_axis"};
3670
3671 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3672 GraphDef output;
3673 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3674 TF_EXPECT_OK(status);
3675 EXPECT_EQ(7, output.node_size());
3676 int found = 0;
3677 for (const auto& node : output.node()) {
3678 if (node.name() == "stack") {
3679 EXPECT_EQ("ExpandDims", node.op());
3680 EXPECT_EQ(3, node.input_size());
3681 EXPECT_EQ("x", node.input(0));
3682 EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
3683 EXPECT_EQ("^y", node.input(2));
3684 ++found;
3685 } else if (node.name() == "stack_no_axis") {
3686 EXPECT_EQ("ExpandDims", node.op());
3687 EXPECT_EQ(2, node.input_size());
3688 EXPECT_EQ("x", node.input(0));
3689 EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
3690 ++found;
3691 } else if (node.name() == "ConstantFolding/stack_const_axis") {
3692 EXPECT_EQ("Const", node.op());
3693 EXPECT_EQ(1, node.input_size());
3694 EXPECT_EQ("^x", node.input(0));
3695 ++found;
3696 }
3697 }
3698 EXPECT_EQ(found, 3);
3699
3700 std::vector<string> fetch = {"stack", "stack_no_axis"};
3701 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3702 auto tensors = EvaluateNodes(output, fetch);
3703 EXPECT_EQ(2, tensors_expected.size());
3704 EXPECT_EQ(2, tensors.size());
3705 EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
3706 EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
3707 }
3708
3709 // The test does not evaluate the optimized and original graphs to check if
3710 // their outputs are the same. See b/78233179.
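// An Enter with is_constant=true promises that the same value is seen on
// every iteration of the loop frame, so an Identity reading a constant
// through such an Enter (id2, id3) can be folded to the constant itself.
// Non-constant Enters (enter3) and non-constant inputs (enter1) must be
// passed through untouched.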
3711 TEST_F(ConstantFoldingTest, Enter) {
3712 GrapplerItem item;
3713 AttrValue frame_name;
3714 frame_name.set_s("foo");
3715 AttrValue is_constant_true;
3716 is_constant_true.set_b(true);
3717 AttrValue is_constant_false;
3718 is_constant_false.set_b(false);
3719 AttrValue type;
3720 type.set_type(DT_FLOAT);
3721 AttrValue value;
3722 Tensor value_tensor(DT_FLOAT, TensorShape({}));
3723 value_tensor.flat<float>()(0) = 1;
3724 value_tensor.AsProtoTensorContent(value.mutable_tensor());
3725
3726 GraphDef& graph = item.graph;
3727 AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
3728 AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
3729 AddNode("enter1", "Enter", {"x"},
3730 {{"T", type},
3731 {"frame_name", frame_name},
3732 {"is_constant", is_constant_true}},
3733 &graph);
3734 AddNode("enter2", "Enter", {"c1"},
3735 {{"T", type},
3736 {"frame_name", frame_name},
3737 {"is_constant", is_constant_true}},
3738 &graph);
3739 AddNode("enter3", "Enter", {"c1"},
3740 {{"T", type},
3741 {"frame_name", frame_name},
3742 {"is_constant", is_constant_false}},
3743 &graph);
3744 AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
3745 AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
3746 AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
3747 AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
3748 item.fetch.push_back("id1");
3749 item.fetch.push_back("id2");
3750 item.fetch.push_back("id3");
3751 item.fetch.push_back("id4");
3752
3753 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3754 GraphDef output;
3755 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3756 TF_EXPECT_OK(status);
3757 // Run the optimizer twice to make sure the rewrite is idempotent.
3758 item.graph.Swap(&output);
3759 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3760 TF_EXPECT_OK(status);
3761
3762 EXPECT_EQ(9, output.node_size());
3763 for (const NodeDef& node : output.node()) {
3764 if (node.name() == "id1") {
3765 EXPECT_EQ("Identity", node.op());
3766 EXPECT_EQ(1, node.input_size());
3767 EXPECT_EQ("enter1", node.input(0));
3768 }
3769 if (node.name() == "id2" || node.name() == "id3") {
3770 EXPECT_EQ("Const", node.op());
3771 EXPECT_EQ(1, node.input_size());
3772 EXPECT_EQ("^enter2", node.input(0));
3773 }
3774 if (node.name() == "id4") {
3775 EXPECT_EQ("Identity", node.op());
3776 EXPECT_EQ(1, node.input_size());
3777 EXPECT_EQ("enter3", node.input(0));
3778 }
3779 }
3780 }
3781
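// TensorArraySize of an array created with a constant size and
// dynamic_size=false can be folded to that constant. If the array can grow
// (dynamic_size=true) or the handle is opaque (a placeholder resource), the
// size is only known at runtime and the op has to stay.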
3782 TEST_F(ConstantFoldingTest, TensorArraySize) {
3783 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3784 Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
3785 Output placeholder =
3786 ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
3787 ops::Placeholder::Shape(TensorShape({2})));
3788 Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
3789 auto dynamic_array =
3790 ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
3791 ops::TensorArray::DynamicSize(true));
3792 auto static_array =
3793 ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
3794 ops::TensorArray::DynamicSize(false));
3795 auto dynamic_sz = ops::TensorArraySize(
3796 scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
3797 auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
3798 static_array.handle, static_array.flow);
3799 auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
3800 placeholder, foo);
3801
3802 GrapplerItem item;
3803 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3804
3805 auto tensors_expected =
3806 EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});
3807
3808 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3809 GraphDef output;
3810 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3811 TF_EXPECT_OK(status);
3812 // Run the optimizer twice to make sure the rewrite is idempotent.
3813 item.graph.Swap(&output);
3814 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3815 TF_EXPECT_OK(status);
3816
3817 EXPECT_EQ(8, output.node_size());
3818 EXPECT_EQ("dynamic_sz", output.node(5).name());
3819 EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
3820 EXPECT_EQ("static_sz", output.node(6).name());
3821 EXPECT_EQ("Const", output.node(6).op());
3822 EXPECT_EQ("placeholder_sz", output.node(7).name());
3823 EXPECT_EQ("TensorArraySizeV3", output.node(7).op());
3824
3825 auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
3826 EXPECT_EQ(2, tensors_expected.size());
3827 EXPECT_EQ(2, tensors_actual.size());
3828 test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
3829 test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
3830 }
3831
3832 TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
3833   // Multiplying min() by 0.1 gives a denormal without FTZ and zero with FTZ.
3834 // Make sure constant folding behaves the same way as TensorFlow.
3835 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3836
3837 Output a =
3838 ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
3839 Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
3840 Output c = ops::Mul(s.WithOpName("c"), a, b);
3841
3842 GrapplerItem item;
3843 item.fetch.push_back("c");
3844 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3845
3846 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3847 GraphDef output;
3848 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3849 TF_EXPECT_OK(status);
3850
3851 EXPECT_EQ(1, output.node_size());
3852
3853 const NodeDef& node_d = output.node(0);
3854 EXPECT_EQ("c", node_d.name());
3855 EXPECT_EQ("Const", node_d.op());
3856
3857 std::vector<string> fetch = {"c"};
3858 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3859 auto tensors = EvaluateNodes(output, fetch);
3860 EXPECT_EQ(1, tensors_expected.size());
3861 EXPECT_EQ(1, tensors.size());
3862 test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
3863 }
3864
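// const1 and const2 are each roughly 5MB of splat data and the concat
// results are twice that, so folding the concats would materialize constants
// above the optimizer's size threshold. The test checks that the result is
// still computed correctly without the optimizer getting stuck folding and
// merging these large candidates; only the output shape is compared, since
// "nonconst" is random.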
3865 TEST_F(ConstantFoldingTest, EvaluatingLargeConstantNoFoldingMergingLoop) {
3866 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3867
3868 int size = 10 * 1024 * 1024 / 4 / 2;
3869 Output nonconst =
3870 ops::RandomUniform(s.WithOpName("nonconst"), {size, 1}, DT_FLOAT);
3871 Output const1 = ops::Const(s.WithOpName("const1"), 0.0f, {size, 1});
3872 Output const2 = ops::Const(s.WithOpName("const2"), 1.0f, {size, 1});
3873 Output axis = ops::Const(s.WithOpName("axis"), -1, {});
3874 Output concat1 =
3875 ops::Concat(s.WithOpName("concat1"), {nonconst, const1}, axis);
3876 Output result = ops::Concat(s.WithOpName("result"), {concat1, const2}, axis);
3877
3878 GrapplerItem item;
3879 item.fetch.push_back("result");
3880 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3881
3882 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3883 GraphDef output;
3884 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3885 TF_EXPECT_OK(status);
3886
3887 std::vector<string> fetch = {"result"};
3888 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3889 auto tensors = EvaluateNodes(output, fetch);
3890 EXPECT_EQ(1, tensors_expected.size());
3891 EXPECT_EQ(1, tensors.size());
3892 EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
3893 }
3894
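// Exercises folding of a Const feeding a Cast under every combination of
// fetched nodes. Whatever subset of {const1, cast, const1_child, cast_child}
// is fetched, the optimizer should be able to collapse the graph to one node
// per fetched output, and the folded values must match the unoptimized
// evaluation.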
3895 class ConstantFoldingCastConstTest : public GrapplerTest {
3896 protected:
3897   void ConstantFoldingCastConst(bool fetch_const, bool fetch_cast,
3898 bool fetch_const_child, bool fetch_cast_child) {
3899 if (!fetch_const && !fetch_cast && !fetch_const_child &&
3900 !fetch_cast_child) {
3901 return;
3902 }
3903
3904 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3905 CreateCastConstGraph(s);
3906 GrapplerItem item;
3907 int expected_output_size = SetFetch(&item, fetch_const, fetch_cast,
3908 fetch_const_child, fetch_cast_child);
3909 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3910
3911 GraphDef output = ConstantFoldingOptimize(item);
3912 EXPECT_EQ(expected_output_size, output.node_size());
3913
3914 EvaluateAndCompareUnoptimized(item.graph, output, item.fetch);
3915 }
3916
3917 private:
3918   void CreateCastConstGraph(const tensorflow::Scope& s) {
3919 Output const1 = ops::Const(s.WithOpName("const1"), 2, {5, 5});
3920 Output cast = ops::Cast(s.WithOpName("cast"), const1, DT_FLOAT);
3921 Output const1_child = ops::Identity(s.WithOpName("const1_child"), const1);
3922 Output cast_child = ops::Identity(s.WithOpName("cast_child"), cast);
3923 }
3924
3925   int SetFetch(GrapplerItem* item, bool fetch_const, bool fetch_cast,
3926 bool fetch_const_child, bool fetch_cast_child) {
3927 int expected_output_size = 0;
3928 if (fetch_const) {
3929 item->fetch.push_back("const1");
3930 expected_output_size++;
3931 }
3932 if (fetch_cast) {
3933 item->fetch.push_back("cast");
3934 expected_output_size++;
3935 }
3936 if (fetch_const_child) {
3937 item->fetch.push_back("const1_child");
3938 expected_output_size++;
3939 }
3940 if (fetch_cast_child) {
3941 item->fetch.push_back("cast_child");
3942 expected_output_size++;
3943 }
3944 return expected_output_size;
3945 }
3946
3947   GraphDef ConstantFoldingOptimize(const GrapplerItem& item) {
3948 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3949 GraphDef output;
3950 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3951 TF_EXPECT_OK(status);
3952 return output;
3953 }
3954
3955   void EvaluateAndCompareUnoptimized(const GraphDef& unoptimized_graph,
3956 const GraphDef& optimized_graph,
3957 const std::vector<string>& fetch_nodes) {
3958 auto tensors_expected = EvaluateNodes(unoptimized_graph, fetch_nodes);
3959 auto tensors = EvaluateNodes(optimized_graph, fetch_nodes);
3960 ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
3961 ASSERT_EQ(fetch_nodes.size(), tensors.size());
3962 for (int i = 0; i < fetch_nodes.size(); i++) {
3963 if (fetch_nodes[i] == "const1" || fetch_nodes[i] == "const1_child") {
3964 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3965 } else {
3966 test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
3967 }
3968 }
3969 }
3970 };
3971
3972 TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
3973 for (bool fetch_const : {false, true}) {
3974 for (bool fetch_cast : {false, true}) {
3975 for (bool fetch_const_child : {false, true}) {
3976 for (bool fetch_cast_child : {false, true}) {
3977 ConstantFoldingCastConst(fetch_const, fetch_cast, fetch_const_child,
3978 fetch_cast_child);
3979 }
3980 }
3981 }
3982 }
3983 }
3984
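// OnesLike and ZerosLike of a tensor with a fully known shape, and Fill with
// constant arguments, are constant-valued and should be materialized as Const
// nodes. The original inputs survive only as control dependencies ("^x", and
// the two "^"-prefixed inputs on fill).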
3985 TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
3986 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3987
3988 Output x =
3989 ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
3990 ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
3991 Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
3992 Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
3993 Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);
3994
3995 GrapplerItem item;
3996 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3997 item.fetch = {"ones_like", "zeros_like", "fill"};
3998 auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
3999 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
4000
4001 ConstantFolding optimizer(/*cpu_device=*/nullptr);
4002 GraphDef output;
4003 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4004 TF_EXPECT_OK(status);
4005
4006 EXPECT_EQ(output.node_size(), 6);
4007 for (const auto& node : output.node()) {
4008 if (node.name() != "x") {
4009 EXPECT_EQ(node.op(), "Const");
4010 }
4011 if (node.name() == "ones_like" || node.name() == "zeros_like") {
4012 ASSERT_EQ(node.input_size(), 1);
4013 EXPECT_EQ(node.input(0), "^x");
4014 }
4015 if (node.name() == "fill") {
4016 ASSERT_EQ(node.input_size(), 2);
4017 EXPECT_EQ(node.input(0)[0], '^');
4018 EXPECT_EQ(node.input(1)[0], '^');
4019 }
4020 }
4021 auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
4022 ASSERT_EQ(item.fetch.size(), tensors.size());
4023 ASSERT_EQ(tensors_expected.size(), tensors.size());
4024 for (int i = 0; i < tensors.size(); i++) {
4025 if (item.fetch[i] == "fill") {
4026 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
4027 } else {
4028 test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
4029 }
4030 }
4031 }
4032
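// Same graph as above, but the optimizer is constructed with its compressed
// tensor support disabled (the boolean second constructor argument). Without
// a compact splat representation, materializing these constant-valued nodes
// could blow up the graph size, so ones_like, zeros_like and fill are
// expected to remain unchanged.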
4033 TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeDisableCompression) {
4034 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
4035
4036 Output x =
4037 ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
4038 ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
4039 Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
4040 Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
4041 Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);
4042
4043 GrapplerItem item;
4044 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4045 item.fetch = {"ones_like", "zeros_like", "fill"};
4046 auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
4047 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
4048
4049 ConstantFolding optimizer(/*cpu_device=*/nullptr, true);
4050 GraphDef output;
4051 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4052 TF_EXPECT_OK(status);
4053
4054 EXPECT_EQ(output.node_size(), 6);
4055 for (const auto& node : output.node()) {
4056 if (node.name() == "ones_like") {
4057 EXPECT_EQ(node.op(), "OnesLike");
4058 ASSERT_EQ(node.input_size(), 1);
4059 EXPECT_EQ(node.input(0), "x");
4060 }
4061 if (node.name() == "zeros_like") {
4062 EXPECT_EQ(node.op(), "ZerosLike");
4063 ASSERT_EQ(node.input_size(), 1);
4064 EXPECT_EQ(node.input(0), "x");
4065 }
4066 if (node.name() == "fill") {
4067 EXPECT_EQ(node.op(), "Fill");
4068 ASSERT_EQ(node.input_size(), 2);
4069 EXPECT_EQ(node.input(0), "Const/Const");
4070 EXPECT_EQ(node.input(1), "Const_1/Const");
4071 }
4072 }
4073 auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
4074 ASSERT_EQ(item.fetch.size(), tensors.size());
4075 ASSERT_EQ(tensors_expected.size(), tensors.size());
4076 for (int i = 0; i < tensors.size(); i++) {
4077 if (item.fetch[i] == "fill") {
4078 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
4079 } else {
4080 test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
4081 }
4082 }
4083 }
4084
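// A Fill of 1024^5 elements cannot be expanded in memory, but a
// constant-valued tensor can be stored as a single splat value, so the
// optimizer should still rewrite fill_huge to a Const. The shape and value
// inputs again degrade to control dependencies.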
4085 TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeHugeFill) {
4086 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
4087 Output value = ops::Const(scope.WithOpName("value"), 42, {});
4088 Output shape_const = ops::Const(scope.WithOpName("shape"),
4089 {1024, 1024, 1024, 1024, 1024}, {5});
4090 Output fill_huge =
4091 ops::Fill(scope.WithOpName("fill_huge"), shape_const, value);
4092
4093 GrapplerItem item;
4094 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4095 // Manually convert the input value format to tensor_content to test this
4096 // case.
4097 NodeDef* node = item.graph.mutable_node(0);
4098 ASSERT_EQ(node->name(), "value");
4099 TensorProto* t = (*node->mutable_attr())["value"].mutable_tensor();
4100 t->clear_int_val();
4101 int val = 42;
4102 port::CopyFromArray(t->mutable_tensor_content(),
4103 reinterpret_cast<const char*>(&val), sizeof(int));
4104 item.fetch = {"fill_huge"};
4105 ConstantFolding optimizer(/*cpu_device=*/nullptr);
4106 GraphDef output;
4107 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4108 TF_EXPECT_OK(status);
4109
4110 EXPECT_EQ(output.node_size(), 3);
4111 for (const auto& node : output.node()) {
4112 EXPECT_EQ(node.op(), "Const");
4113 if (node.name() == "fill_huge") {
4114 ASSERT_EQ(node.input_size(), 2);
4115 EXPECT_EQ(node.input(0), "^shape");
4116 EXPECT_EQ(node.input(1), "^value");
4117 }
4118 }
4119 }
4120
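// Reinterpreted as float32, these int64 bit patterns include denormal
// values. Folding the Bitcast pair must preserve the exact bits (no
// flush-to-zero), which the round trip int64 -> float -> int64 verifies: the
// output must be bit-identical to the input.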
4121 TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
4122 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
4123
4124 Tensor x_t(DT_INT64, TensorShape({2, 2}));
4125 x_t.flat<int64>()(0) = 9223372036854775807L;
4126 x_t.flat<int64>()(1) = 1L;
4127 x_t.flat<int64>()(2) = 9223372036854775807L;
4128 x_t.flat<int64>()(3) = 1L;
4129 Output x = ops::Const(scope.WithOpName("x"), x_t);
4130 Output y = ops::Bitcast(scope.WithOpName("y"), x, DT_FLOAT);
4131 Output z = ops::Bitcast(scope.WithOpName("z"), y, DT_INT64);
4132
4133 GrapplerItem item;
4134 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4135 item.fetch = {"z"};
4136 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
4137
4138 ConstantFolding optimizer(/*cpu_device=*/nullptr);
4139 GraphDef output;
4140 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4141 TF_EXPECT_OK(status);
4142
4143 ASSERT_EQ(output.node_size(), 1);
4144 const NodeDef& node = output.node(0);
4145 EXPECT_EQ(node.name(), "z");
4146 EXPECT_EQ(node.op(), "Const");
4147
4148 auto tensors = EvaluateNodes(output, item.fetch, {});
4149 ASSERT_EQ(tensors.size(), 1);
4150 ASSERT_EQ(tensors_expected.size(), 1);
4151 test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
4152 }
4153
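// A Case op whose branch index is a constant should be rewritten into a
// direct PartitionedCall of the selected branch function, with the matching
// entry of output_shapes forwarded to the call's _output_shapes attribute.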
TEST_F(ConstantFoldingTest, SimplifyCase) {
  using test::function::NDef;

  for (int index = 0; index < 2; ++index) {
    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
    GrapplerItem item;
    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
    AttrValue branches;
    auto* f = branches.mutable_list()->add_func();
    f->set_name("XTimesTwo");
    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
    auto* g = branches.mutable_list()->add_func();
    *g = *f;
    g->set_name("NonZero");
    // Add a pair of somewhat arbitrary output shapes to test that they are
    // correctly propagated to the _output_shapes attribute.
    AttrValue output_shapes;
    // The first shape is a scalar.
    output_shapes.mutable_list()->add_shape();
    // The second shape is unknown.
    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
    g_shape->set_unknown_rank(true);

    const Tensor kZero = test::AsScalar<int32>(0);
    const Tensor kOne = test::AsScalar<int32>(1);
    item.graph = test::function::GDef(
        {NDef("one", "Const", {},
              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
              kDevice),
         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("case", "Case", {"one", "x"},
              {{"Tin", DataTypeSlice{DT_FLOAT}},
               {"Tout", DataTypeSlice{DT_FLOAT}},
               {"branches", branches},
               {"output_shapes", output_shapes}},
              kDevice),
         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
        // FunctionLib
        {
            test::function::XTimesTwo(),
            test::function::NonZero(),
        });
    VLOG(1) << "Before: " << item.graph.DebugString();

    item.fetch = {"y"};
    const Tensor kTwo = test::AsScalar<float>(2.0f);
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef optimized_graph;
    TF_ASSERT_OK(
        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
    VLOG(1) << "After: " << optimized_graph.DebugString();

    int pco_count = 0;
    for (const auto& node : optimized_graph.node()) {
      EXPECT_NE(node.op(), "Case");
      if (node.op() == "PartitionedCall") {
        ++pco_count;
        const auto& shape_list = node.attr().at("_output_shapes").list();
        ASSERT_EQ(shape_list.shape_size(), 1);
        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
        if (index == 0) {
          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
        } else {
          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
        }
      }
    }
    EXPECT_EQ(pco_count, 1);

    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}

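// A SelectV2 whose predicate is a constant (scalar or all-true/all-false
// tensor) should turn into an Identity of the taken branch, keeping the
// untaken inputs only as control dependencies.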
TEST_F(ConstantFoldingTest, SimplifySelect) {
  for (bool scalar_pred : {true, false}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if (scalar_pred) {
        if_t.reset(new Tensor(DT_BOOL, TensorShape()));
      } else {
        if_t.reset(new Tensor(DT_BOOL, TensorShape({2, 2})));
      }
      for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2}));
      const Tensor kTwo =
          test::AsTensor<float>({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

      ASSERT_EQ(optimized_graph.node_size(), 5);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "Identity");
          ASSERT_EQ(node.input_size(), 3);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then");
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

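// Same as above, but when the taken branch's shape is narrower than the
// select's output shape, the node should instead be rewritten into a
// BroadcastTo of the taken branch, using a constant shape materialized by
// the optimizer.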
TEST_F(ConstantFoldingTest, SimplifySelect_BroadcastTo) {
  for (TensorShape pred_shape : {TensorShape{2, 1}, TensorShape{2, 2, 1}}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if_t.reset(new Tensor(DT_BOOL, pred_shape));
      for (int i = 0; i < pred_shape.num_elements(); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 1})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 4})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f}, TensorShape({2, 1}));
      const Tensor kTwo = test::AsTensor<float>(
          {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f},
          TensorShape({2, 4}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

      ASSERT_EQ(optimized_graph.node_size(), 6);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "BroadcastTo");
          ASSERT_EQ(node.input_size(), 4);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1),
                    strings::StrCat("ConstantFolding/select-broadcastto_shape-",
                                    pred_val ? 1 : 2));
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
          EXPECT_EQ(node.input(3), pred_val ? "^if" : "^then");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      ASSERT_EQ(tensors[0].shape(), pred_shape.num_elements() == 2
                                        ? TensorShape({2, 4})
                                        : TensorShape({2, 2, 4}));
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

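// QuantizeAndDequantize* ops emulate quantization effects and deliberately
// perturb the numerical result, so they should only be folded when the
// optimizer is explicitly constructed with fold_quantization_emulation set
// to true.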
TEST_F(ConstantFoldingTest, QuantizationEmulation) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {0.0f, 1.0f, 2.0f, 3.0f}, {4});
  Output min_range = ops::Const(scope.WithOpName("min_range"), 0.0f, {});
  Output max_range = ops::Const(scope.WithOpName("max_range"), 3.0f, {});
  Output y = ops::QuantizeAndDequantizeV2(scope.WithOpName("y"), x, min_range,
                                          max_range);
  Output id = ops::Identity(scope.WithOpName("id"), y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};

  std::vector<Tensor> expected_tensors = EvaluateNodes(item.graph, item.fetch);

  for (const bool fold_quantization_emulation : {false, true}) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr,
                              /*disable_compressed_tensor_optimization=*/false,
                              fold_quantization_emulation);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    int num_quantization_emulation_ops = 0;
    for (const NodeDef& node : output.node()) {
      if (node.op() == "QuantizeAndDequantizeV2") {
        num_quantization_emulation_ops++;
      }
    }
    EXPECT_EQ(fold_quantization_emulation ? 0 : 1,
              num_quantization_emulation_ops);

    std::vector<Tensor> actual_tensors = EvaluateNodes(output, item.fetch);
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorEqual<float>(expected_tensors[i], actual_tensors[i]);
    }
  }
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow