/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include <cmath>  // for std::sqrt in MulConvPushDownTest

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"

namespace tensorflow {
namespace grappler {
namespace {

class ConstantFoldingTest : public GrapplerTest {
 protected:
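  // Builds x (op) zeros and x (op) ones graphs for the given element type and
  // checks that constant folding rewrites them to Const, Identity, or
  // Snapshot nodes as appropriate.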
  template <DataType DTYPE>
  void SimpleNeutralElementTest() {
    for (bool use_snapshot : {false, true}) {
      typedef typename EnumToDataType<DTYPE>::Type T;
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
      Tensor zeros_t(DTYPE, TensorShape({2, 2}));
      Tensor ones_t(DTYPE, TensorShape({2, 2}));
      Tensor x_t(DTYPE, TensorShape({2, 2}));
      for (int i = 0; i < 4; ++i) {
        zeros_t.flat<T>()(i) = T(0);
        ones_t.flat<T>()(i) = T(1);
        x_t.flat<T>()(i) = T(i + 1);
      }
      Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
      Output ones = ops::Const(s.WithOpName("ones"), ones_t);
      Output mul1;
      Output mul2;
      Output add1;
      Output add2;
      if (DTYPE == DT_BOOL) {
        mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
        mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
        add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
        add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
      } else {
        mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
        mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
        add1 = ops::Add(s.WithOpName("add1"), x, zeros);
        add2 = ops::Add(s.WithOpName("add2"), x, ones);
      }
      if (use_snapshot) {
        // Add an op with ref input to prevent Snapshot from being
        // turned into Identity.
        ops::Assign(s.WithOpName("assign"), v, ones);
      }
      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"mul1", "mul2", "add1", "add2"};
      ConstantFolding optimizer(/*cpu_device=*/nullptr);
      GraphDef output;
      Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
      TF_EXPECT_OK(status);

      EXPECT_EQ(7, output.node_size());
      const string snapshot_or_identity =
          use_snapshot ? "Snapshot" : "Identity";
      for (int i = 0; i < output.node_size(); ++i) {
        const NodeDef& node = output.node(i);
        const string& name = node.name();
        if (name == "mul1") {
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ("^x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "mul2") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^ones", node.input(1));
        } else if (name == "add1") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "add2") {
          if (DTYPE == DT_BOOL) {
            EXPECT_EQ("Const", node.op());
            EXPECT_EQ("^x", node.input(0));
            EXPECT_EQ("^ones", node.input(1));
          } else {
            EXPECT_EQ("Add", node.op());
            EXPECT_EQ("x", node.input(0));
            EXPECT_EQ("ones", node.input(1));
          }
        }
      }
      auto tensors_expected =
          EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
      auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
      EXPECT_EQ(4, tensors_expected.size());
      EXPECT_EQ(4, tensors.size());
      for (int i = 0; i < item.fetch.size(); ++i) {
        test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
      }
    }
  }

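  // Builds mul(c, conv(x, filter)) and checks whether constant folding pushes
  // the multiplication down into the filter, i.e. rewrites the graph to
  // conv(x, c * filter), exactly when `expect_folded` is true.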
  void MulConvPushDownTest(const TensorShape& input_shape,
                           const TensorShape& filter_shape,
                           const TensorShape& mul_const_input_shape,
                           const bool use_3d_conv, const char* padding,
                           const char* data_format, const bool expect_folded) {
    // Tests if the following rewrite is performed:
    //
    //         *                       Conv2D
    //        / \                       / \
    //       c  Conv2D        -->      x  (c * filter)
    //          / \
    //         x  filter
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Tensor filter_values(DT_FLOAT, filter_shape);
    for (int i = 0; i < filter_values.NumElements(); ++i) {
      filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
    }
    Output filter =
        ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));

    Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                    ops::Placeholder::Shape(input_shape));

    Output conv;
    if (use_3d_conv) {
      conv = ops::Conv3D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1, 1},
                         padding, ops::Conv3D::DataFormat(data_format));
    } else {
      conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
                         padding, ops::Conv2D::DataFormat(data_format));
    }
    Tensor mul_const_input(DT_FLOAT, mul_const_input_shape);
    for (int i = 0; i < mul_const_input.NumElements(); ++i) {
      mul_const_input.flat<float>()(i) = static_cast<float>(i + 3);
    }
    Output c =
        ops::Const(s.WithOpName("c"), Input::Initializer(mul_const_input));
    Output mul = ops::Mul(s.WithOpName("mul"), c, conv);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(5, output.node_size());
    int found = 0;
    if (expect_folded) {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("conv/merged_input", node.input(1));
        } else if (node.name() == "conv/merged_input") {
          found++;
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ(0, node.input_size());
        }
      }
    } else {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ("Mul", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("c", node.input(0));
          EXPECT_EQ("conv", node.input(1));
        } else if (node.name() == "conv") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("filter", node.input(1));
        }
      }
    }
    EXPECT_EQ(2, found);

    // Check that const folded multiplication node has the expected value.
    std::vector<string> fetch = {"mul"};
    Tensor value(DT_FLOAT, input_shape);
    for (int i = 0; i < value.NumElements(); ++i) {
      value.flat<float>()(i) = i;
    }
    auto actual = EvaluateNodes(output, fetch, {{"x", value}});
    auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
    test::ExpectTensorEqual<float>(expected[0], actual[0]);
  }
};

TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
  Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b});
  Output d = ops::AddN(s.WithOpName("d"), {b, c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("d", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, AddTree) {
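  // Trees of the form op(c1, op(c2, x)) built from a single foldable op
  // should be reassociated so that the two constant leaves fold together
  // (see the diagrams below); trees mixing Add and Mul must be left alone.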
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
  Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
  Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
                         1.0f, {1});
  Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);

  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
  Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
  Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
  Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
  Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
  Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
  Output addmul_parent =
      ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);

  GrapplerItem item;
  item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //    +                +             +
  //   / \              / \           / \
  // 1.0  +     -->    x   +    -->  x  3.0
  //     / \              / \
  //   2.0  x           1.0 2.0
  //
  //    *                *             *
  //   / \              / \           / \
  // 4.0  *     -->    y   *    -->  y  20.0
  //     / \              / \
  //   5.0  y           4.0 5.0

  EXPECT_EQ(11, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "add_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("add_child", node.input(1));
    } else if (node.name() == "mul_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "mul_parent") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("mul_child", node.input(1));
    } else if (node.name() == "addmul_child") {
      // Unchanged.
      EXPECT_EQ("Add", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("c4", node.input(0));
      EXPECT_EQ("x", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  std::vector<string> fetch = {"c3", "c20"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_child", "mul_child"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA
         "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/true);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA
         "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    for (auto mul_const_input_shape :
         {TensorShape{1}, TensorShape{1, 1, 1, 1}}) {
      MulConvPushDownTest(
          /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                                : TensorShape{4, 3, 10, 10},
          /*filter_shape=*/{2, 2, 3, 5}, mul_const_input_shape,
          /*use_3d_conv=*/false,
          /*padding=*/"VALID", data_format.c_str(),
          /*expect_folded=*/true);
    }
  }
}

TEST_F(ConstantFoldingTest,
       MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA
         "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{1, 1, 1, 1, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) {
  for (auto data_format : {
         "NHWC",
#if GOOGLE_CUDA
         "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1, 3},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{1, 3}, TensorShape{1, 1, 1, 3}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NHWC",
        /*expect_folded=*/true);
  }
}

#if GOOGLE_CUDA
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NCHW",
        // TODO(laigd): optimization should happen in this case.
        /*expect_folded=*/false);
  }
}
#endif  // GOOGLE_CUDA

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) {
  for (auto data_format : {
         "NHWC",
#if GOOGLE_CUDA
         "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{1, 1, 3},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      /*expect_folded=*/true);
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{3, 1, 1, 1},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      // TODO(laigd): optimization should happen in this case.
      /*expect_folded=*/false);
}

TEST_F(ConstantFoldingTest, NeutralElement) {
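  // Exercises the neutral/absorbing-element rewrites (x*0, 0*y, x*1, 1*y,
  // x/1, 1/y, x+0, 0+y, x-0, 0-y, and matmuls with an all-zeros operand),
  // with the zeros/ones produced by Const, ZerosLike/OnesLike, and Fill.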
  int kConst = 0;
  int kLike = 1;
  int kFill = 2;
  for (int const_type : {kConst, kLike, kFill}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({3, 2})));
    Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 3})));
    Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                   ops::Placeholder::Shape(TensorShape({2})));
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
    Output zeros = const_type == kConst
                       ? zeros_const
                       : (const_type == kLike ? zeros_like : zeros_fill);
    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
    Output ones = const_type == kConst
                      ? ones_const
                      : (const_type == kLike ? ones_like : ones_fill);
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::Mul(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::Mul(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    Output concat =
        ops::Stack(s.WithOpName("stack"),
                   {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
                    matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack", "matmul3", "matmul4"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    const string suffix =
        (const_type == kConst ? "_const"
                              : (const_type == kLike ? "_like" : "_fill"));
    const string zeros_name = strings::StrCat("zeros", suffix);
    const string ones_name = strings::StrCat("ones", suffix);
    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
    const string ctrl_ones_name = strings::StrCat("^ones", suffix);
    EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      if (name == "mul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul3") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "div2") {
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        // We don't eliminate this one, because it requires broadcasting.
        EXPECT_EQ("BiasAdd", node.op());
        EXPECT_EQ(zeros_name, node.input(0));
        EXPECT_EQ("bias", node.input(1));
      } else if (name == "sub1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "sub2") {
        EXPECT_EQ("Neg", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      }
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
    auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
    auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));

    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors_expected.size());
    auto tensors = EvaluateNodes(
        output, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
    }
  }
}

TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
  SimpleNeutralElementTest<DT_BOOL>();
  SimpleNeutralElementTest<DT_HALF>();
  SimpleNeutralElementTest<DT_BFLOAT16>();
}

TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
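  // Division by a floating-point constant should be strength-reduced to
  // multiplication by its reciprocal; integer division must stay untouched.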
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div_f", "div_i", "realdiv"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name == "div_i") {
      // Integer division is unchanged.
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("ci", node.input(1));
    } else if (name == "div_f") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    } else if (name == "realdiv") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    } else if (name == "ConstantFolding/div_f_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    } else if (name == "ConstantFolding/realdiv_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    }
  }

  // Check that the reciprocals have the expected value.
  std::vector<string> fetch = {"cf_half"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
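  // Mul-by-zeros rewrites depend on shape knowledge: with fully known shapes
  // the Mul becomes a Const, with only partially known shapes an Identity,
  // and with unknown input shapes it is left untouched.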
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_known =
      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // Multiplies without any additional ops to supply the output shape.
  int count = 0;
  std::vector<Output> muls;
  std::unordered_set<string> not_converted;
  std::unordered_set<string> to_const;
  std::unordered_set<string> to_identity;
  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
    for (const auto* zeros :
         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
      if (x == &x_partially_known && zeros == &zeros_partially_known) {
        to_identity.insert(name);
      } else if (x == &x_unknown || zeros == &zeros_unknown) {
        not_converted.insert(name);
      } else {
        to_const.insert(name);
      }
    }
  }

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(15, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
    } else if (to_identity.count(name) > 0) {
      EXPECT_EQ("Identity", node.op()) << node.name();
    } else if (not_converted.count(name) > 0) {
      EXPECT_EQ("Mul", node.op()) << node.name();
    }
  }

  const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
  auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_known", x_known_t},
                     {"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_known", x_known_t},
                                {"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate the shape back to the inputs of AddN, making the
  // output shapes of all its inputs known.
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
  const std::vector<string> fetch = {"addn1"};
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
}

TEST_F(ConstantFoldingTest, CreateConstNodes) {
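  // Folds an elementwise op on constants for each supported element type and
  // checks that the folded Const stores its value in the proto field matching
  // that type (float_val, double_val, int64_val, int_val, bool_val).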
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

#define MAKE_TEST_GRAPH(TYPE)                                                \
  Output TYPE##_const =                                                      \
      ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5});  \
  Output TYPE##_mul =                                                        \
      ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);      \
  Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)

  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}

TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a graph around a Unique node, which has two outputs, and check
  // that both outputs are folded to constants.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {5});
  auto b = ops::Unique(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
  Output e = ops::Identity(s.WithOpName("e"), {c});
  Output f = ops::Identity(s.WithOpName("f"), {d});

  GrapplerItem item;
  item.fetch.push_back("e");
  item.fetch.push_back("f");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(2, output.node_size());

  const NodeDef& new_c = output.node(0);
  EXPECT_EQ("e", new_c.name());
  EXPECT_EQ("Const", new_c.op());

  const NodeDef& new_d = output.node(1);
  EXPECT_EQ("f", new_d.name());
  EXPECT_EQ("Const", new_d.op());

  std::vector<string> fetch = {"e", "f"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors_expected.size());
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, ControlDependencies) {
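  // Control dependencies of nodes that get folded away must be preserved by
  // hoisting them onto the resulting Const node.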
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  item.fetch.push_back("e");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "e"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "e") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"e"});
      auto expected = EvaluateNodes(item.graph, {"e"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(1, found);
}

TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
                                        "i1",   "i2", "e"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i1") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i1"});
      auto expected = EvaluateNodes(item.graph, {"i1"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
    }
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i2"});
      auto expected = EvaluateNodes(item.graph, {"i2"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(2, found);
}

TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
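  // Control dependencies accumulated from multiple folded nodes should be
  // deduplicated on the resulting Const node.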
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1")
                                .WithControlDependencies(p2)
                                .WithControlDependencies(p1),
                            {c});
  Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});

  GrapplerItem item;
  item.fetch.push_back("i2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Add a DynamicPartition node to the graph.
  Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
  Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
  int num_partitions = 4;
  ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
                             num_partitions);

  std::vector<string> outputs;
  for (int i = 0; i < num_partitions; ++i) {
    string part_out_name = strings::StrCat("part_out", i);
    ops::Identity partition_out(scope.WithOpName(part_out_name),
                                {part.outputs[i]});
    outputs.push_back(part_out_name);
  }

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Add a ConcatOffset node to the graph.
  Tensor initial_val(DT_INT32, TensorShape({3}));
  test::FillIota<int>(&initial_val, 7);
  for (int i = 1; i < 5; ++i) {
    TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
                    .Attr("dtype", DT_INT32)
                    .Attr("value", initial_val)
                    .Finalize(item.graph.add_node()));
  }
  Tensor concat_dim(DT_INT32, TensorShape({}));
  test::FillIota<int>(&concat_dim, 0);
  TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
                  .Attr("dtype", DT_INT32)
                  .Attr("value", concat_dim)
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
                  .Input("concat_dim", 0, DT_INT32)
                  .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
                  .Finalize(item.graph.add_node()));

  for (int i = 0; i < 4; ++i) {
    string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
    TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
                    .Attr("T", DT_INT32)
                    .Input("concat_offsets", i, DT_INT32)
                    .Finalize(item.graph.add_node()));
    outputs.push_back(concat_offset_out_name);
  }

  item.fetch = outputs;
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int constant_folded = 0;
  for (const auto& node : output.node()) {
    if (node.name().find("part_out") != string::npos ||
        node.name().find("concat_offset_out") != string::npos) {
      ++constant_folded;
      EXPECT_EQ("Const", node.op());
    }
  }
  EXPECT_EQ(8, constant_folded);

  auto expected = EvaluateNodes(item.graph, outputs);
  auto optimized = EvaluateNodes(output, outputs);
  ASSERT_EQ(expected.size(), optimized.size());
  for (int i = 0; i < expected.size(); ++i) {
    test::ExpectTensorEqual<int>(expected[i], optimized[i]);
  }
}

TEST_F(ConstantFoldingTest, ShapeMaterialization) {
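  // Rank, Shape, and Size of variables with statically known shapes should be
  // materialized as constants and folded into downstream arithmetic.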
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  item.fetch.push_back("p2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      EXPECT_EQ("^v1", node.input(1));
      EXPECT_EQ("^v2", node.input(2));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // rank = 1, shape = (5, 7), size = 143 = 11 * 13
      // p2 = (715, 1001) = (5 * 143, 7 * 143)
      EXPECT_EQ(715, value.flat<int>()(0));
      EXPECT_EQ(1001, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(1, found);
  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));

  auto tensors_expected = EvaluateNodes(
      item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(11 * 13, value.flat<int>()(0));
    } else if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v1", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(1, value.flat<int>()(0));
    } else if (node.name() == "shape") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v2", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(5, value.flat<int>()(0));
      EXPECT_EQ(7, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(3, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
  std::vector<string> fetch_nodes = {"p2"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
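  // Checks which ShapeN outputs get materialized as constants: here only the
  // output for v3, whose shape is fully known, is replaced; the other
  // consumers keep reading from the ShapeN node.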
1319 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1320 Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1321 Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
1322 Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
1323 auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
1324 Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
1325 Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
1326 Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
1327 Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
1328 Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
1329 Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
1330 Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
1331
1332 GrapplerItem item;
1333 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1334
1335 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1336 GraphDef output;
1337 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1338 TF_EXPECT_OK(status);
1339 int found = 0;
1340 for (const auto& node : output.node()) {
1341 EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1342 node.name());
1343 EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1344 node.name());
1345 if (node.name() == "i1a" || node.name() == "i1b") {
1346 ++found;
1347 EXPECT_EQ("s", node.input(0));
1348 }
1349 if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
1350 ++found;
1351 EXPECT_EQ("s:1", node.input(0));
1352 }
1353 if (node.name() == "i3a" || node.name() == "i3b") {
1354 ++found;
1355 EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
1356 node.input(0));
1357 }
1358 if (node.name() == "s") {
1359 ++found;
1360 EXPECT_EQ("ShapeN", node.op());
1361 EXPECT_EQ("v1", node.input(0));
1362 EXPECT_EQ("v2", node.input(1));
1363 EXPECT_EQ("v3", node.input(2));
1364 }
1365 if (node.name() ==
1366 AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
1367 ++found;
1368 EXPECT_EQ("Const", node.op());
1369 EXPECT_EQ("^s", node.input(0));
1370 Tensor value;
1371 CHECK(value.FromProto(node.attr().at("value").tensor()));
1372 EXPECT_EQ(4, value.flat<int>()(0));
1373 EXPECT_EQ(6, value.flat<int>()(1));
1374 }
1375 }
1376 EXPECT_EQ(9, found);
1377
1378 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1379 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
1380 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1381 const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
1382 "i2c", "i3a", "i3b"};
1383 auto tensors_expected = EvaluateNodes(
1384 item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1385 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
1386 auto tensors = EvaluateNodes(output, fetch_nodes,
1387 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1388 EXPECT_EQ(fetch_nodes.size(), tensors.size());
1389 for (int i = 0; i < fetch_nodes.size(); i++)
1390 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1391 }
1392
TEST_F(ConstantFoldingTest,ShapeMaterializationShapeN_MultipleOutputs)1393 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
1394 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1395 Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1396 Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
1397 auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
1398 auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
1399 Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
1400 Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);
1401
1402 GrapplerItem item;
1403 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1404 item.fetch.push_back("ia");
1405 item.fetch.push_back("ib");
1406
1407 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1408 GraphDef output;
1409 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1410 TF_EXPECT_OK(status);
1411
1412 int found = 0;
1413 for (const auto& node : output.node()) {
1414 EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1415 node.name());
1416 if (node.name() == "s") {
1417 ++found;
1418 EXPECT_EQ("ShapeN", node.op());
1419 EXPECT_EQ("v1", node.input(0));
1420 EXPECT_EQ("v2", node.input(1));
1421 }
1422 if (node.name() == "id_n") {
1423 ++found;
1424 EXPECT_EQ("IdentityN", node.op());
1425 EXPECT_EQ("s", node.input(0));
1426 EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1427 node.input(1));
1428 }
1429 if (node.name() == "ia") {
1430 ++found;
1431 EXPECT_EQ("id_n", node.input(0));
1432 }
1433 if (node.name() == "ib") {
1434 ++found;
1435 EXPECT_EQ("Const", node.op());
1436 EXPECT_EQ("^s", node.input(0));
1437 EXPECT_EQ("^id_n", node.input(1));
1438 }
1439 }
1440 EXPECT_EQ(4, found);
1441
1442 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1443 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1444 auto tensors_expected =
1445 EvaluateNodes(item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1446 EXPECT_EQ(2, tensors_expected.size());
1447 auto tensors =
1448 EvaluateNodes(output, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1449 EXPECT_EQ(2, tensors.size());
1450 for (int i = 0; i < tensors.size(); i++)
1451 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1452 }
1453
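// With an empty fetch set every node may still be needed, so nothing can be
// pruned: the statically known "rank" and "size" are folded to constants in
// place, anchored by control dependencies on their switch branches, and the
// never-taken branch of switch2 ("i3") is left untouched.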
1454 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
1455 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1456 ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1457 ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1458 ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1459 ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1460 ops::Identity i(scope.WithOpName("i"), s1.output_true);
1461 ops::Size size(scope.WithOpName("size"), i);
1462 ops::Square p1(scope.WithOpName("p1"), rank);
1463 ops::Square p2(scope.WithOpName("p2"), size);
1464 ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1465
1466 Output predicate =
1467 ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1468 Output constant =
1469 ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1470 ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1471 ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1472 ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1473 ops::Merge m2(scope.WithOpName("m2"),
1474 {statically_known.output, never_generated.output});
1475
1476 GrapplerItem item;
1477 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1478
1479 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1480 GraphDef output;
1481 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1482 TF_EXPECT_OK(status);
1483
1484 std::set<string> present_nodes = {"v_in", "v_ctrl",
1485 "switch", "i",
1486 "p1", "p2",
1487 "m", "false",
1488 "constant", "switch2",
1489 "i2", "i3",
1490 "m2", "ConstantFoldingCtrl/switch_0",
1491 "rank", "size"};
1492 std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
1493 EXPECT_EQ(present_nodes.size(), output.node_size());
1494 int found = 0;
1495 for (const auto& node : output.node()) {
1496 EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
1497 << node.name();
1498 EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
1499 << node.name();
1500 present_nodes.erase(node.name());
1501 not_present_nodes.erase(node.name());
1502 if (node.name() == "rank") {
1503 ++found;
1504 EXPECT_EQ("Const", node.op());
1505 EXPECT_EQ(1, node.input_size());
1506 EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
1507 }
1508 if (node.name() == "size") {
1509 ++found;
1510 EXPECT_EQ("Const", node.op());
1511 EXPECT_EQ(1, node.input_size());
1512 EXPECT_EQ("^i", node.input(0));
1513 }
1514 if (node.name() == "i2") {
1515 ++found;
1516 EXPECT_EQ("Const", node.op());
1517 EXPECT_EQ(0, node.input_size());
1518 }
1519 if (node.name() == "i3") {
1520 ++found;
1521 EXPECT_EQ("Identity", node.op());
1522 EXPECT_EQ(1, node.input_size());
1523 EXPECT_EQ("switch2:1", node.input(0));
1524 }
1525 }
1526 EXPECT_EQ(4, found);
1527
1528 auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1529 Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1530
1531 v_ctrl_t.flat<bool>()(0) = true;
1532 std::vector<string> fetch_nodes = {"m", "m2"};
1533 auto tensors_expected = EvaluateNodes(
1534 item.graph, fetch_nodes, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1535 EXPECT_EQ(2, tensors_expected.size());
1536 auto tensors = EvaluateNodes(output, fetch_nodes,
1537 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1538 EXPECT_EQ(2, tensors.size());
1539 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1540 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1541
1542 v_ctrl_t.flat<bool>()(0) = false;
1543 tensors_expected = EvaluateNodes(item.graph, fetch_nodes,
1544 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1545 EXPECT_EQ(2, tensors_expected.size());
1546 tensors = EvaluateNodes(output, fetch_nodes,
1547 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1548 EXPECT_EQ(2, tensors.size());
1549 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1550 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1551 }
1552
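// Same graph as SwitchNodesEmptyFetch, but with "m" and "m2" fetched: the
// folded "rank" and "size" nodes are no longer needed to produce the fetch
// outputs and are expected to be pruned from the optimized graph.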
1553 TEST_F(ConstantFoldingTest, SwitchNodes) {
1554 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1555 ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1556 ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1557 ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1558 ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1559 ops::Identity i(scope.WithOpName("i"), s1.output_true);
1560 ops::Size size(scope.WithOpName("size"), i);
1561 ops::Square p1(scope.WithOpName("p1"), rank);
1562 ops::Square p2(scope.WithOpName("p2"), size);
1563 ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1564
1565 Output predicate =
1566 ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1567 Output constant =
1568 ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1569 ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1570 ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1571 ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1572 ops::Merge m2(scope.WithOpName("m2"),
1573 {statically_known.output, never_generated.output});
1574
1575 GrapplerItem item;
1576 item.fetch.push_back("m");
1577 item.fetch.push_back("m2");
1578
1579 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1580
1581 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1582 GraphDef output;
1583 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1584 TF_EXPECT_OK(status);
1585 std::set<string> present_nodes = {"v_in", "v_ctrl",
1586 "switch", "i",
1587 "p1", "p2",
1588 "m", "false",
1589 "constant", "switch2",
1590 "i2", "i3",
1591 "m2", "ConstantFoldingCtrl/switch_0"};
1592 std::set<string> not_present_nodes = {"rank", "size",
1593 "ConstantFolding/switch2-0"};
1594 EXPECT_EQ(present_nodes.size(), output.node_size());
1595
1596 int found = 0;
1597 for (const auto& node : output.node()) {
1598 EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
1599 EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
1600 present_nodes.erase(node.name());
1601 not_present_nodes.erase(node.name());
1602 if (node.name() == "i2") {
1603 ++found;
1604 EXPECT_EQ("Const", node.op());
1605 EXPECT_EQ(0, node.input_size());
1606 }
1607 if (node.name() == "i3") {
1608 ++found;
1609 EXPECT_EQ("Identity", node.op());
1610 EXPECT_EQ(1, node.input_size());
1611 EXPECT_EQ("switch2:1", node.input(0));
1612 }
1613 }
1614 EXPECT_EQ(2, found);
1615
1616 auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1617 Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1618 v_ctrl_t.flat<bool>()(0) = true;
1619 auto tensors_expected = EvaluateNodes(
1620 item.graph, item.fetch, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1621 EXPECT_EQ(2, tensors_expected.size());
1622 auto tensors = EvaluateNodes(output, item.fetch,
1623 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1624 EXPECT_EQ(2, tensors.size());
1625 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1626 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1627
1628 v_ctrl_t.flat<bool>()(0) = false;
1629 tensors_expected = EvaluateNodes(item.graph, item.fetch,
1630 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1631 EXPECT_EQ(2, tensors_expected.size());
1632 tensors = EvaluateNodes(output, item.fetch,
1633 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1634 EXPECT_EQ(2, tensors.size());
1635 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1636 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1637 }
1638
1639 TEST_F(ConstantFoldingTest, MergeNodes) {
1640 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1641
1642 Output x =
1643 ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
1644 Output y =
1645 ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
1646 Output const1 =
1647 ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
1648 TensorShape({3, 5}));
1649 Output const2 =
1650 ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
1651 Output const3 =
1652 ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
1653 TensorShape({3, 5}));
1654
1655 // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
1656 ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
1657 ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
1658 ops::Merge m3(scope.WithOpName("m3"), {x, y});
1659 // m4 is not foldable because the only constant input
1660 // has a control input, so we cannot know if it will be
1661 // triggered.
1662 ops::Merge m4(scope.WithOpName("m4"), {x, const1});
1663
1664 ops::Identity out1(scope.WithOpName("out1"), m1.output);
1665 ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
1666 ops::Identity out2(scope.WithOpName("out2"), m2.output);
1667 ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
1668 ops::Identity out3(scope.WithOpName("out3"), m3.output);
1669 ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
1670 ops::Identity out4(scope.WithOpName("out4"), m4.output);
1671 ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
1672
1673 GrapplerItem item;
1674 item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
1675 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1676
1677 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1678 GraphDef output;
1679 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1680 TF_EXPECT_OK(status);
1681
1682 EXPECT_EQ(19, output.node_size());
1683 int found_nodes = 0;
1684 for (const auto& node : output.node()) {
1685 if (node.name() == "out1") {
1686 EXPECT_EQ(1, node.input_size());
1687 EXPECT_EQ("^m1", node.input(0));
1688 ++found_nodes;
1689 } else if (node.name() == "idx1") {
1690 EXPECT_EQ(1, node.input_size());
1691 EXPECT_EQ("^m1", node.input(0));
1692 ++found_nodes;
1693 } else if (node.name() == "ConstantFolding/m1") {
1694 EXPECT_EQ("Const", node.op());
1695 EXPECT_EQ(1, node.input_size());
1696 EXPECT_EQ("^m1", node.input(0));
1697 ++found_nodes;
1698 } else if (node.name() == "ConstantFolding/m1_index") {
1699 EXPECT_EQ("Const", node.op());
1700 EXPECT_EQ(1, node.input_size());
1701 EXPECT_EQ("^m1", node.input(0));
1702 ++found_nodes;
1703 } else if (node.name() == "out2") {
1704 EXPECT_EQ(1, node.input_size());
1705 EXPECT_EQ("m2", node.input(0));
1706 ++found_nodes;
1707 } else if (node.name() == "idx2") {
1708 EXPECT_EQ(1, node.input_size());
1709 EXPECT_EQ("m2:1", node.input(0));
1710 ++found_nodes;
1711 } else if (node.name() == "out3") {
1712 EXPECT_EQ(1, node.input_size());
1713 EXPECT_EQ("m3", node.input(0));
1714 ++found_nodes;
1715 } else if (node.name() == "idx3") {
1716 EXPECT_EQ(1, node.input_size());
1717 EXPECT_EQ("m3:1", node.input(0));
1718 ++found_nodes;
1719 } else if (node.name() == "out4") {
1720 EXPECT_EQ(1, node.input_size());
1721 EXPECT_EQ("m4", node.input(0));
1722 ++found_nodes;
1723 } else if (node.name() == "idx4") {
1724 EXPECT_EQ(1, node.input_size());
1725 EXPECT_EQ("m4:1", node.input(0));
1726 ++found_nodes;
1727 }
1728 }
1729 // Make sure the graph contains all the nodes we're expecting.
1730 EXPECT_EQ(8, found_nodes);
1731
1732 std::vector<string> fetch = {"out1", "idx1"};
1733 auto tensors = EvaluateNodes(output, fetch);
1734 EXPECT_EQ(2, tensors.size());
1735 const Tensor& out_value = tensors[0];
1736 EXPECT_EQ(3 * 5, out_value.NumElements());
1737 for (int i = 0; i < 3 * 5; ++i) {
1738 EXPECT_EQ(3.14f, out_value.flat<float>()(i));
1739 }
1740 const Tensor& out_idx = tensors[1];
1741 EXPECT_EQ(1, out_idx.NumElements());
1742 EXPECT_EQ(2, out_idx.flat<int32>()(0));
1743 }
1744
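// A Split with num_split == 1 is a no-op and is expected to be rewritten as
// an Identity of its input, keeping the former split_dim input alive as a
// control dependency, e.g.:
//   Before: s1 = Split(split_dim, in1, num_split=1)
//   After:  s1 = Identity(in1) with control input ^split_dim
// The genuine two-way split s2 must be left untouched.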
1745 TEST_F(ConstantFoldingTest, SplitRemoval) {
1746 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1747
1748 Output in1 =
1749 ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
1750 Output in2 =
1751 ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
1752 auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
1753 ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
1754 ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);
1755
1756 ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
1757
1758 GrapplerItem item;
1759 item.fetch = {"out"};
1760 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1761
1762 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1763 GraphDef got;
1764 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1765 TF_EXPECT_OK(status);
1766
1767 GraphDef want;
1768 AddNode("in1", "VariableV2", {}, {}, &want);
1769 AddNode("in2", "VariableV2", {}, {}, &want);
1770 AddNode("split_dim", "Const", {}, {}, &want);
1771 AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
1772 &want);
1773 AddNode("s2", "Split", {"split_dim", "in2"}, {}, &want);
1774 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
1775
1776 CompareGraphs(want, got);
1777
1778 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
1779 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4}));
1780 auto tensors_expected =
1781 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1782 EXPECT_EQ(1, tensors_expected.size());
1783 auto tensors =
1784 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1785 EXPECT_EQ(1, tensors.size());
1786 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
1787 }
1788
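// The SplitV analogue of SplitRemoval: a SplitV producing a single output is
// rewritten to an Identity with control dependencies on the now-unused
// size_splits and split_dim constants, while the two-way SplitV s2 is kept.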
1789 TEST_F(ConstantFoldingTest, SplitVRemoval) {
1790 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1791
1792 Output in1 =
1793 ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
1794 Output in2 =
1795 ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
1796 auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
1797 auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
1798 auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
1799 ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
1800 ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
1801
1802 ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
1803
1804 GrapplerItem item;
1805 item.fetch = {"out"};
1806 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1807
1808 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1809 GraphDef got;
1810 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1811 TF_EXPECT_OK(status);
1812
1813 GraphDef want;
1814 AddNode("in1", "VariableV2", {}, {}, &want);
1815 AddNode("in2", "VariableV2", {}, {}, &want);
1816 AddNode("split_dim", "Const", {}, {}, &want);
1817 AddNode("size_splits1", "Const", {}, {}, &want);
1818 AddNode("size_splits2", "Const", {}, {}, &want);
1819 AddNode("s1", "Identity",
1820 {"in1", AsControlDependency("size_splits1"),
1821 AsControlDependency("split_dim")},
1822 {}, &want);
1823 AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
1824 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
1825
1826 CompareGraphs(want, got);
1827
1828 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
1829 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5}));
1830 auto tensors_expected =
1831 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1832 EXPECT_EQ(1, tensors_expected.size());
1833 auto tensors =
1834 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1835 EXPECT_EQ(1, tensors.size());
1836 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
1837 }
1838
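// A Transpose whose permutation only moves dimensions of size 1 (t2 permutes
// a {1, 4, 2, 1} tensor with perm {3, 1, 2, 0}) does not reorder any data
// and is expected to become an Identity; t1 swaps the two non-unit
// dimensions and must remain a Transpose.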
1839 TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
1840 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1841
1842 Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
1843 DT_FLOAT);
1844 Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
1845 Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
1846 DT_FLOAT);
1847 Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
1848 ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
1849 ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
1850 p2);
1851
1852 ops::Add out1(scope.WithOpName("out1"), t1, t2);
1853
1854 GrapplerItem item;
1855 item.fetch = {"out1"};
1856 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1857
1858 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1859 GraphDef got;
1860 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1861 TF_EXPECT_OK(status);
1862
1863 GraphDef want;
1864 AddNode("in1", "VariableV2", {}, {}, &want);
1865 AddNode("in2", "VariableV2", {}, {}, &want);
1866 AddNode("p1", "Const", {}, {}, &want);
1867 AddNode("p2", "Const", {}, {}, &want);
1868 AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
1869 AddNode("t2", "Identity",
1870 {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
1871 &want);
1872 AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
1873
1874 CompareGraphs(want, got);
1875 }
1876
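// RandomShuffle on a scalar has nothing to shuffle, so both shuffle nodes
// are expected to be replaced with Identity, preserving the original control
// dependencies.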
1877 TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
1878 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1879
1880 Output in1 =
1881 ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
1882 Output in2 =
1883 ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
1884 ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
1885 ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
1886 in2);
1887
1888 ops::Add out1(scope.WithOpName("out1"), s1, s2);
1889 ops::Identity out2(scope.WithOpName("out2"), s2);
1890
1891 GrapplerItem item;
1892 item.fetch = {"out1", "out2"};
1893 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1894
1895 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1896 GraphDef got;
1897 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1898 TF_EXPECT_OK(status);
1899
1900 GraphDef want;
1901 AddNode("in1", "VariableV2", {}, {}, &want);
1902 AddNode("in2", "VariableV2", {}, {}, &want);
1903 AddNode("s1", "Identity", {"in1"}, {}, &want);
1904 AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, {}, &want);
1905 AddNode("out1", "Add", {"s1", "s2"}, {}, &want);
1906 AddNode("out2", "Identity", {"s2"}, {}, &want);
1907
1908 CompareGraphs(want, got);
1909
1910 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
1911 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
1912 auto tensors_expected =
1913 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1914 EXPECT_EQ(2, tensors_expected.size());
1915 auto tensors =
1916 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1917 EXPECT_EQ(2, tensors.size());
1918 for (int i = 0; i < tensors.size(); i++)
1919 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
1920 }
1921
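// Reversing only along dimensions of size 1 is a no-op: r2 reverses axes
// {0, 3} of a {1, 2, 4, 1} tensor and is expected to become an Identity,
// while r1 reverses non-unit axes and must remain a ReverseV2.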
1922 TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
1923 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1924
1925 Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
1926 DT_FLOAT);
1927 Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
1928 Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
1929 DT_FLOAT);
1930 Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
1931 ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
1932 ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
1933 a2);
1934
1935 ops::Add out1(scope.WithOpName("out1"), r1, r2);
1936
1937 GrapplerItem item;
1938 item.fetch = {"out1"};
1939 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1940
1941 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1942 GraphDef got;
1943 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1944 TF_EXPECT_OK(status);
1945
1946 GraphDef want;
1947 AddNode("in1", "VariableV2", {}, {}, &want);
1948 AddNode("in2", "VariableV2", {}, {}, &want);
1949 AddNode("a1", "Const", {}, {}, &want);
1950 AddNode("a2", "Const", {}, {}, &want);
1951 AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
1952 AddNode("r2", "Identity",
1953 {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
1954 &want);
1955 AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
1956
1957 CompareGraphs(want, got);
1958 }
1959
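// A Slice that spans its input's full extent is a no-op. Both sub-cases
// below (explicit sizes equal to the input shape, and sizes of -1 starting
// at begin 0) expect the slice to be rewritten as an Identity with control
// dependencies on its former begin/size inputs.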
1960 TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
1961 { // size = {3, 5}
1962 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1963
1964 auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
1965 auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
1966 auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
1967 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
1968 ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
1969 ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);
1970
1971 ops::Add out(scope.WithOpName("out"), s1, s2);
1972
1973 GrapplerItem item;
1974 item.fetch = {"out"};
1975 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1976
1977 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1978 GraphDef got;
1979 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1980 TF_EXPECT_OK(status);
1981
1982 GraphDef want;
1983 AddNode("in1", "VariableV2", {}, {}, &want);
1984 AddNode("in2", "VariableV2", {}, {}, &want);
1985 AddNode("begin", "Const", {}, {}, &want);
1986 AddNode("size", "Const", {}, {}, &want);
1987 AddNode("s1", "Identity",
1988 {"in1", AsControlDependency("begin"), AsControlDependency("size")},
1989 {}, &want);
1990 AddNode("s2", "Slice", {"in2", "begin", "size"}, {}, &want);
1991 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
1992
1993 CompareGraphs(want, got);
1994
1995 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
1996 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1997 auto tensors_expected =
1998 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1999 EXPECT_EQ(1, tensors_expected.size());
2000 auto tensors =
2001 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2002 EXPECT_EQ(1, tensors.size());
2003 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2004 }
2005 { // size = {-1, -1}
2006 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2007
2008 auto in1 =
2009 ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
2010 auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
2011 auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
2012 auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
2013 Output in2 =
2014 ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
2015 ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
2016 ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);
2017
2018 ops::Add out(scope.WithOpName("out"), s1, s2);
2019
2020 GrapplerItem item;
2021 item.fetch = {"out"};
2022 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2023
2024 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2025 GraphDef got;
2026 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2027 TF_EXPECT_OK(status);
2028
2029 GraphDef want;
2030 AddNode("in1", "VariableV2", {}, {}, &want);
2031 AddNode("in2", "VariableV2", {}, {}, &want);
2032 AddNode("begin1", "Const", {}, {}, &want);
2033 AddNode("begin2", "Const", {}, {}, &want);
2034 AddNode("size", "Const", {}, {}, &want);
2035 AddNode("s1", "Identity",
2036 {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
2037 {}, &want);
2038 AddNode("s2", "Slice", {"in2", "begin2", "size"}, {}, &want);
2039 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2040
2041 CompareGraphs(want, got);
2042
2043 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
2044 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
2045 auto tensors_expected =
2046 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2047 EXPECT_EQ(1, tensors_expected.size());
2048 auto tensors =
2049 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2050 EXPECT_EQ(1, tensors.size());
2051 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2052 }
2053 }
2054
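// The StridedSlice analogue of SliceWithSameDimensionRemoval: a stride-1
// slice covering the full extent of the sliced dimensions is a no-op and is
// expected to become an Identity, both without masks (first block) and with
// begin/end/ellipsis masks (second block).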
2055 TEST_F(ConstantFoldingTest, StridedSliceWithSameDimensionRemoval) {
2056 { // no mask
2057 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2058
2059 auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5, 2}, DT_FLOAT);
2060 auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2061 auto end = ops::Const(scope.WithOpName("end"), {3, 5}, {2});
2062 auto strides = ops::Const(scope.WithOpName("strides"), {1, 1}, {2});
2063 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6, 2}, DT_FLOAT);
2064 ops::StridedSlice s1(scope.WithOpName("s1"), in1, begin, end, strides);
2065 ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin, end, strides);
2066
2067 ops::Add out(scope.WithOpName("out"), s1, s2);
2068
2069 GrapplerItem item;
2070 item.fetch = {"out"};
2071 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2072
2073 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2074 GraphDef got;
2075 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2076 TF_EXPECT_OK(status);
2077
2078 GraphDef want;
2079 AddNode("in1", "VariableV2", {}, {}, &want);
2080 AddNode("in2", "VariableV2", {}, {}, &want);
2081 AddNode("begin", "Const", {}, {}, &want);
2082 AddNode("end", "Const", {}, {}, &want);
2083 AddNode("strides", "Const", {}, {}, &want);
2084 AddNode("s1", "Identity",
2085 {"in1", AsControlDependency("begin"), AsControlDependency("end"),
2086 AsControlDependency("strides")},
2087 {}, &want);
2088 AddNode("s2", "StridedSlice", {"in2", "begin", "end", "strides"}, {},
2089 &want);
2090 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2091
2092 CompareGraphs(want, got);
2093
2094 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 2}));
2095 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6, 2}));
2096 auto tensors_expected =
2097 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2098 EXPECT_EQ(1, tensors_expected.size());
2099 auto tensors =
2100 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2101 EXPECT_EQ(1, tensors.size());
2102 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2103 }
2104 { // with begin/end/ellipsis mask
2105 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2106
2107 // s1 = in1[:, ..., 0:5, 0:6]
2108 auto in1 =
2109 ops::Variable(scope.WithOpName("in1"), {2, 3, 4, 5, 6}, DT_FLOAT);
2110 auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0, 0}, {3});
2111 auto end1 = ops::Const(scope.WithOpName("end1"), {0, 5, 6}, {3});
2112 auto strides1 = ops::Const(scope.WithOpName("strides1"), {1, 1, 1}, {3});
2113 ops::StridedSlice s1(
2114 scope.WithOpName("s1"), in1, begin1, end1, strides1,
2115 ops::StridedSlice::Attrs().BeginMask(1).EndMask(1).EllipsisMask(2));
2116
2117 Output in2 =
2118 ops::Variable(scope.WithOpName("in2"), {5, 8, 5, 6, 9}, DT_FLOAT);
2119 auto begin2 = ops::Const(scope.WithOpName("begin2"), {0, 0, 0, 0, 0}, {5});
2120 auto end2 = ops::Const(scope.WithOpName("end2"), {2, 3, 4, 5, 6}, {5});
2121 auto strides2 =
2122 ops::Const(scope.WithOpName("strides2"), {1, 1, 1, 1, 1}, {5});
2123 ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin2, end2, strides2);
2124
2125 ops::Add out(scope.WithOpName("out"), s1, s2);
2126
2127 GrapplerItem item;
2128 item.fetch = {"out"};
2129 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2130
2131 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2132 GraphDef got;
2133 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2134 TF_EXPECT_OK(status);
2135
2136 GraphDef want;
2137 AddNode("in1", "VariableV2", {}, {}, &want);
2138 AddNode("in2", "VariableV2", {}, {}, &want);
2139 AddNode("begin1", "Const", {}, {}, &want);
2140 AddNode("end1", "Const", {}, {}, &want);
2141 AddNode("strides1", "Const", {}, {}, &want);
2142 AddNode("s1", "Identity",
2143 {"in1", AsControlDependency("begin1"), AsControlDependency("end1"),
2144 AsControlDependency("strides1")},
2145 {}, &want);
2146 AddNode("begin2", "Const", {}, {}, &want);
2147 AddNode("end2", "Const", {}, {}, &want);
2148 AddNode("strides2", "Const", {}, {}, &want);
2149 AddNode("s2", "StridedSlice", {"in2", "begin2", "end2", "strides2"}, {},
2150 &want);
2151 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2152
2153 CompareGraphs(want, got);
2154
2155 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3, 4, 5, 6}));
2156 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 8, 5, 6, 9}));
2157 auto tensors_expected =
2158 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2159 EXPECT_EQ(1, tensors_expected.size());
2160 auto tensors =
2161 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2162 EXPECT_EQ(1, tensors.size());
2163 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2164 }
2165 }
2166
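// Tiling with all multiples equal to 1 leaves the input unchanged, so t1 is
// expected to be rewritten as an Identity; t2 doubles its second dimension
// and must remain a Tile.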
2167 TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
2168 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2169
2170 auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2171 auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
2172 auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
2173 auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});
2174
2175 ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
2176 ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);
2177
2178 ops::Add out(scope.WithOpName("out"), t1, t2);
2179
2180 GrapplerItem item;
2181 item.fetch = {"out"};
2182 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2183
2184 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2185 GraphDef got;
2186 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2187 TF_EXPECT_OK(status);
2188
2189 GraphDef want;
2190 AddNode("in1", "VariableV2", {}, {}, &want);
2191 AddNode("in2", "VariableV2", {}, {}, &want);
2192 AddNode("multiplies1", "Const", {}, {}, &want);
2193 AddNode("multiplies2", "Const", {}, {}, &want);
2194 AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, {},
2195 &want);
2196 AddNode("t2", "Tile", {"in2", "multiplies2"}, {}, &want);
2197 AddNode("out", "Add", {"t1", "t2"}, {}, &want);
2198
2199 CompareGraphs(want, got);
2200 }
2201
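// Two chained ConcatV2 nodes concatenating along the same axis can be merged
// into one. Since c1 only feeds c2, the expected result is a single ConcatV2
// over {in1, in2, in3}.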
2202 TEST_F(ConstantFoldingTest, MergeConcat) {
2203 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2204
2205 Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2206 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2207 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2208 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2209
2210 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2211 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2212
2213 GrapplerItem item;
2214 item.fetch = {"c2"};
2215 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2216
2217 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2218 GraphDef got;
2219 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2220 TF_EXPECT_OK(status);
2221
2222 GraphDef want;
2223 AddNode("in1", "VariableV2", {}, {}, &want);
2224 AddNode("in2", "VariableV2", {}, {}, &want);
2225 AddNode("in3", "VariableV2", {}, {}, &want);
2226 AddNode("axis", "Const", {}, {}, &want);
2227 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2228
2229 CompareGraphs(want, got);
2230 }
2231
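// The merge must also handle the inner concat feeding the outer one more
// than once: each occurrence of c1 is expanded in place, yielding the input
// list {in1, in2, in3, in1, in2}.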
2232 TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
2233 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2234
2235 Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2236 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2237 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2238 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2239
2240 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2241 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
2242
2243 GrapplerItem item;
2244 item.fetch = {"c2"};
2245 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2246
2247 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2248 GraphDef got;
2249 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2250 TF_EXPECT_OK(status);
2251
2252 GraphDef want;
2253 AddNode("in1", "VariableV2", {}, {}, &want);
2254 AddNode("in2", "VariableV2", {}, {}, &want);
2255 AddNode("in3", "VariableV2", {}, {}, &want);
2256 AddNode("axis", "Const", {}, {}, &want);
2257 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
2258 &want);
2259
2260 CompareGraphs(want, got);
2261 }
2262
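// As in MergeConcat, the two concats are expected to collapse into a single
// ConcatV2 over {in1, in2, in3}; the differing input shapes do not block the
// merge.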
2263 TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
2264 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2265
2266 Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
2267 Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2268 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2269 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2270
2271 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2272 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2273
2274 GrapplerItem item;
2275 item.fetch = {"c2"};
2276 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2277
2278 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2279 GraphDef got;
2280 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2281 TF_EXPECT_OK(status);
2282
2283 GraphDef want;
2284 AddNode("in1", "VariableV2", {}, {}, &want);
2285 AddNode("in2", "VariableV2", {}, {}, &want);
2286 AddNode("in3", "VariableV2", {}, {}, &want);
2287 AddNode("axis", "Const", {}, {}, &want);
2288 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2289
2290 CompareGraphs(want, got);
2291 }
2292
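// Negative case: c1 and c2 concatenate along different axes, so the two
// ConcatV2 nodes must not be merged and the graph is expected to come out
// structurally unchanged.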
2293 TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
2294 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2295
2296 Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
2297 Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2298 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2299 Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
2300 Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
2301
2302 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
2303 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
2304
2305 GrapplerItem item;
2306 item.fetch = {"c2"};
2307 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2308
2309 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2310 GraphDef got;
2311 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2312 TF_EXPECT_OK(status);
2313
2314 GraphDef want;
2315 AddNode("in1", "VariableV2", {}, {}, &want);
2316 AddNode("in2", "VariableV2", {}, {}, &want);
2317 AddNode("in3", "VariableV2", {}, {}, &want);
2318 AddNode("axis1", "Const", {}, {}, &want);
2319 AddNode("axis2", "Const", {}, {}, &want);
2320 AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
2321 AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
2322
2323 CompareGraphs(want, got);
2324 }
2325
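// PadV2 with an all-zero paddings tensor adds nothing, so p1 is expected to
// become an Identity with control dependencies on its former inputs, while
// p2 pads by nonzero amounts and must remain a PadV2.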
2326 TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
2327 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2328
2329 auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
2330 auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
2331 auto paddings1 =
2332 ops::Const(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
2333 auto paddings2 =
2334 ops::Const(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
2335 auto c1 = ops::Const(scope.WithOpName("c1"), 1);
2336 auto c2 = ops::Const(scope.WithOpName("c2"), 1);
2337
2338 ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
2339 ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);
2340
2341 ops::Add out(scope.WithOpName("out"), p1, p2);
2342
2343 GrapplerItem item;
2344 item.fetch = {"out"};
2345 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2346
2347 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2348 GraphDef got;
2349 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2350 TF_EXPECT_OK(status);
2351
2352 GraphDef want;
2353 AddNode("in1", "VariableV2", {}, {}, &want);
2354 AddNode("in2", "VariableV2", {}, {}, &want);
2355 AddNode("paddings1", "Const", {}, {}, &want);
2356 AddNode("paddings2", "Const", {}, {}, &want);
2357 AddNode("c1", "Const", {}, {}, &want);
2358 AddNode("c2", "Const", {}, {}, &want);
2359 AddNode("p1", "Identity",
2360 {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
2361 {}, &want);
2362 AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, {}, &want);
2363 AddNode("out", "Add", {"p1", "p2"}, {}, &want);
2364
2365 CompareGraphs(want, got);
2366
2367 auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({4, 6}));
2368 auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 2}));
2369 auto tensors_expected =
2370 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2371 EXPECT_EQ(1, tensors_expected.size());
2372 auto tensors =
2373 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2374 EXPECT_EQ(1, tensors.size());
2375 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
2376 }
2377
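// Squeeze only removes dimensions of size 1. s1's input has none, so it is
// expected to become an Identity; s2's input has two squeezable dimensions
// and must remain a Squeeze.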
2378 TEST_F(ConstantFoldingTest, SqueezeWithAllDimensionsGreaterThanOne) {
2379 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2380
2381 auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
2382 auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);
2383
2384 ops::Squeeze s1(scope.WithOpName("s1"), in1);
2385 ops::Squeeze s2(scope.WithOpName("s2"), in2);
2386
2387 ops::Add out(scope.WithOpName("out"), s1, s2);
2388
2389 GrapplerItem item;
2390 item.fetch = {"out"};
2391 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2392
2393 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2394 GraphDef got;
2395 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2396 TF_EXPECT_OK(status);
2397
2398 GraphDef want;
2399 AddNode("in1", "VariableV2", {}, {}, &want);
2400 AddNode("in2", "VariableV2", {}, {}, &want);
2401 AddNode("s1", "Identity", {"in1"}, {}, &want);
2402 AddNode("s2", "Squeeze", {"in2"}, {}, &want);
2403 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2404
2405 CompareGraphs(want, got);
2406
2407 auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 3}));
2408 auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({1, 2, 3, 1}));
2409 auto tensors_expected =
2410 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2411 EXPECT_EQ(1, tensors_expected.size());
2412 auto tensors =
2413 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2414 EXPECT_EQ(1, tensors.size());
2415 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
2416 }
2417
2418 TEST_F(ConstantFoldingTest, NoOpReduction) {
2419 // Build a simple graph with reductions that can be reduced to the
2420 // identity.
2421 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2422
2423 Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
2424 Output c =
2425 ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
2426 Output i = ops::Identity(scope.WithOpName("i"), c);
2427 Output p = ops::Prod(scope.WithOpName("p"), v, i);
2428 Output s = ops::Square(scope.WithOpName("s"), p);
2429
2430 Output v2 = ops::Variable(scope.WithOpName("v2"), {3, 5, 1}, DT_FLOAT);
2431 Output c2 =
2432 ops::Const(scope.WithOpName("c2").WithControlDependencies(v), 2, {1});
2433 ops::Prod::Attrs attr;
2434 attr = attr.KeepDims(true);
2435 Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);
2436
2437 GrapplerItem item;
2438 item.fetch = {"s", "p2"};
2439 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2440
2441 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2442 GraphDef output;
2443 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2444 TF_EXPECT_OK(status);
2445
2446 int found = 0;
2447 for (const auto& node : output.node()) {
2448 if (node.name() == "p") {
2449 found++;
2450 EXPECT_EQ("Identity", node.op());
2451 EXPECT_EQ(2, node.input_size());
2452 EXPECT_EQ("v", node.input(0));
2453 EXPECT_EQ("^i", node.input(1));
2454 } else if (node.name() == "p2") {
2455 found++;
2456 EXPECT_EQ("Identity", node.op());
2457 EXPECT_EQ(2, node.input_size());
2458 EXPECT_EQ("v2", node.input(0));
2459 EXPECT_EQ("^c2", node.input(1));
2460 }
2461 }
2462 EXPECT_EQ(2, found);
2463
2464 auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2465 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
2466 auto tensors_expected =
2467 EvaluateNodes(item.graph, item.fetch, {{"v", v_t}, {"v2", v2_t}});
2468 EXPECT_EQ(2, tensors_expected.size());
2469 auto tensors = EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}});
2470 EXPECT_EQ(2, tensors.size());
2471 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2472 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
2473 }
2474
2475 TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
2476 // Build a simple graph with reductions that involve a single-element input
2477 // and no axes to reduce along.
2478 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2479
2480 Output input_var_three_dim = ops::Variable(
2481 scope.WithOpName("input_var_three_dim"), {1, 1, 1}, DT_FLOAT);
2482 Output input_var_one_dim =
2483 ops::Variable(scope.WithOpName("input_var_one_dim"), {1}, DT_FLOAT);
2484 Output one_axis = ops::Const(scope.WithOpName("one_axis"), {0}, {1});
2485 Output multiple_axes =
2486 ops::Const(scope.WithOpName("multiple_axes"), {1, 0}, {2});
2487 Output variable_axis =
2488 ops::Variable(scope.WithOpName("input_var_axis"), {1}, DT_INT32);
2489 ops::Mean::Attrs attr;
2490 attr = attr.KeepDims(false);
2491 // Should be optimized to Reshape.
2492 Output mean_1 = ops::Mean(scope.WithOpName("mean_1"), input_var_three_dim,
2493 one_axis, attr.KeepDims(false));
2494 Output mean_2 = ops::Mean(scope.WithOpName("mean_2"), input_var_three_dim,
2495 multiple_axes, attr.KeepDims(false));
2496 // Should remain as-is, since OutputProperties will not be known for this node.
2497 Output mean_3 = ops::Mean(scope.WithOpName("mean_3"), input_var_one_dim,
2498 one_axis, attr.KeepDims(false));
2499 // Should remain as-is.
2500 Output mean_4 = ops::Mean(scope.WithOpName("mean_4"), input_var_three_dim,
2501 variable_axis, attr.KeepDims(false));
2502 // Should be optimized to Identity, since KeepDims=true.
2503 Output mean_5 = ops::Mean(scope.WithOpName("mean_5"), input_var_three_dim,
2504 multiple_axes, attr.KeepDims(true));
2505
2506 GrapplerItem item;
2507 item.fetch = {"mean_1", "mean_2", "mean_3", "mean_4", "mean_5"};
2508 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2509
2510 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2511 GraphDef output;
2512 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2513 TF_EXPECT_OK(status);
2514
2515 // Ensure each Mean node is optimized (or left alone) as expected.
2516 int found = 0;
2517 for (const auto& node : output.node()) {
2518 if (node.name() == "mean_1" || node.name() == "mean_2") {
2519 found++;
2520 EXPECT_EQ("Reshape", node.op());
2521 EXPECT_EQ(2, node.input_size());
2522 EXPECT_EQ("input_var_three_dim", node.input(0));
2523 } else if (node.name() == "mean_3") {
2524 found++;
2525 EXPECT_EQ("Mean", node.op());
2526 EXPECT_EQ(2, node.input_size());
2527 EXPECT_EQ("input_var_one_dim", node.input(0));
2528 } else if (node.name() == "mean_4") {
2529 found++;
2530 EXPECT_EQ("Mean", node.op());
2531 EXPECT_EQ(2, node.input_size());
2532 EXPECT_EQ("input_var_three_dim", node.input(0));
2533 } else if (node.name() == "mean_5") {
2534 found++;
2535 EXPECT_EQ("Identity", node.op());
2536 EXPECT_EQ(2, node.input_size());
2537 EXPECT_EQ("^multiple_axes", node.input(1));
2538 }
2539 }
2540 EXPECT_EQ(5, found);
2541
2542 // Ensure the resulting values from Mean and Reshape are the same.
2543 auto input_var_three_dim_t =
2544 GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
2545 auto input_var_one_dim_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2546 Tensor input_var_axis_t(DT_INT32, TensorShape({1}));
2547 input_var_axis_t.flat<int32>()(0) = 0;
2548 auto tensors_expected =
2549 EvaluateNodes(item.graph, item.fetch,
2550 {{"input_var_three_dim", input_var_three_dim_t},
2551 {"input_var_one_dim", input_var_one_dim_t},
2552 {"input_var_axis", input_var_axis_t}});
2553 EXPECT_EQ(5, tensors_expected.size());
2554 auto tensors = EvaluateNodes(output, item.fetch,
2555 {{"input_var_three_dim", input_var_three_dim_t},
2556 {"input_var_one_dim", input_var_one_dim_t},
2557 {"input_var_axis", input_var_axis_t}});
2558 EXPECT_EQ(5, tensors.size());
2559 for (int i = 0; i < 5; ++i) {
2560 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2561 }
2562 }
2563
2564 TEST_F(ConstantFoldingTest, NoOpReshape) {
2565 // Build a simple graph with a reshape that can be reduced to the identity.
2566 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2567
2568 // A reshape that can be optimized
2569 Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
2570 Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
2571 Output c1 =
2572 ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
2573 Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
2574 Output r1 =
2575 ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
2576 Output s1 = ops::Square(scope.WithOpName("s1"), r1);
2577
2578 // A multi-dimensional reshape that can be optimized
2579 Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
2580 Output c3 =
2581 ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
2582 Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
2583 Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
2584 Output s3 = ops::Square(scope.WithOpName("s3"), r3);
2585
2586 // A multi-dimensional, partially defined reshape that can be optimized
2587 Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
2588 Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
2589 {5, -1, 5}, {3});
2590 Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
2591 Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
2592 Output s4 = ops::Square(scope.WithOpName("s4"), r4);
2593
2594 // A reshape that can't be optimized
2595 Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
2596 Output c2 =
2597 ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
2598 Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
2599 Output s2 = ops::Square(scope.WithOpName("s2"), r2);
2600
2601 GrapplerItem item;
2602 item.fetch = {"s1", "s2", "s3", "s4"};
2603 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2604
2605 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2606 GraphDef output;
2607 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2608 TF_EXPECT_OK(status);
2609
2610 int found = 0;
2611 for (const auto& node : output.node()) {
2612 if (node.name() == "r1") {
2613 ++found;
2614 EXPECT_EQ("Identity", node.op());
2615 ASSERT_EQ(3, node.input_size());
2616 EXPECT_EQ("v1", node.input(0));
2617 EXPECT_EQ("^i1", node.input(1));
2618 EXPECT_EQ("^d1", node.input(2));
2619 } else if (node.name() == "r3") {
2620 ++found;
2621 EXPECT_EQ("Identity", node.op());
2622 ASSERT_EQ(2, node.input_size());
2623 EXPECT_EQ("v3", node.input(0));
2624 EXPECT_EQ("^i3", node.input(1));
2625 } else if (node.name() == "r4") {
2626 ++found;
2627 EXPECT_EQ("Identity", node.op());
2628 ASSERT_EQ(2, node.input_size());
2629 EXPECT_EQ("v4", node.input(0));
2630 EXPECT_EQ("^i4", node.input(1));
2631 } else if (node.name() == "r2") {
2632 ++found;
2633 EXPECT_EQ("Reshape", node.op());
2634 ASSERT_EQ(2, node.input_size());
2635 EXPECT_EQ("v2", node.input(0));
2636 EXPECT_EQ("c2", node.input(1));
2637 }
2638 }
2639 EXPECT_EQ(4, found);
2640
2641 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17}));
2642 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17, 1}));
2643 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2644 auto v4_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2645 auto tensors_expected =
2646 EvaluateNodes(item.graph, item.fetch,
2647 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2648 EXPECT_EQ(4, tensors_expected.size());
2649 auto tensors =
2650 EvaluateNodes(output, item.fetch,
2651 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2652 EXPECT_EQ(4, tensors.size());
2653 for (int i = 0; i < tensors.size(); i++)
2654 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2655 }
2656
2657 TEST_F(ConstantFoldingTest, Packing) {
2658 // Build a simple graph with a large constant that can be folded.
2659 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2660 Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
2661 Output i1 = ops::Identity(scope.WithOpName("i1"), c);
2662 Output i2 = ops::Identity(scope.WithOpName("i2"), c);
2663
2664 GrapplerItem item;
2665 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2666
2667 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2668 GraphDef output;
2669 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2670 TF_EXPECT_OK(status);
2671
2672 const std::vector<string> fetch_nodes = {"i1", "i2"};
2673 auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
2674 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2675 auto tensors = EvaluateNodes(output, fetch_nodes);
2676 EXPECT_EQ(fetch_nodes.size(), tensors.size());
2677 for (int i = 0; i < fetch_nodes.size(); i++)
2678 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2679
2680 // Make sure that the representation of the folded constant is space
2681 // efficient: in particular, the whole message should be smaller than 8k
2682 // (the size needed to naively encode 1000 floats folded twice).
2683 EXPECT_GT(8000, output.ByteSizeLong());
2684 }
2685
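// Even with unknown dimension sizes, the shapes of "a" and "b" are
// symbolically identical, so the reduction indices produced by "f" are known
// to be empty and can be materialized as constants; for "i" the broadcast
// between the unknown shape of "a" and the shape of "g" cannot be resolved
// statically, so its outputs must keep flowing through the op.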
2686 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
2687 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2688 Output a =
2689 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
2690 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
2691 Output b = ops::Square(s.WithOpName("b"), a);
2692 Output c = ops::Mul(s.WithOpName("c"), a, b);
2693 Output d = ops::Shape(s.WithOpName("d"), a);
2694 Output e = ops::Shape(s.WithOpName("e"), b);
2695
2696 auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
2697 Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
2698 Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
2699
2700 Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
2701 ops::Placeholder::Shape(PartialTensorShape({1})));
2702 Output h = ops::Shape(s.WithOpName("h"), g);
2703 auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
2704 Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
2705 Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
2706
2707 GrapplerItem item;
2708 TF_CHECK_OK(s.ToGraphDef(&item.graph));
2709
2710 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2711 GraphDef output;
2712 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2713 TF_EXPECT_OK(status);
2714
2715 std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
2716 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
2717 auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2718 auto tensors_expected =
2719 EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
2720 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2721
2722 // Run a second time to make sure the optimization is idempotent.
2723 item.graph.Swap(&output);
2724 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2725 TF_EXPECT_OK(status);
2726
2727 int found = 0;
2728 for (const auto& node : output.node()) {
2729 if (node.name() == "o1") {
2730 ++found;
2731 EXPECT_EQ(1, node.input_size());
2732 EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
2733 } else if (node.name() == "o2") {
2734 ++found;
2735 EXPECT_EQ(1, node.input_size());
2736 EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
2737 } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
2738 ++found;
2739 EXPECT_EQ("Const", node.op());
2740 EXPECT_EQ(1, node.input_size());
2741 EXPECT_EQ("^f", node.input(0));
2742 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2743 .num_elements());
2744 } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
2745 ++found;
2746 EXPECT_EQ("Const", node.op());
2747 EXPECT_EQ(1, node.input_size());
2748 EXPECT_EQ("^f", node.input(0));
2749 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2750 .num_elements());
2751 } else if (node.name() == "p1") {
2752 ++found;
2753 EXPECT_EQ(1, node.input_size());
2754 EXPECT_EQ("i", node.input(0));
2755 } else if (node.name() == "p2") {
2756 ++found;
2757 EXPECT_EQ(1, node.input_size());
2758 EXPECT_EQ("i:1", node.input(0));
2759 }
2760 }
2761 EXPECT_EQ(6, found);
2762
2763 auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
2764 EXPECT_EQ(fetch_nodes.size(), tensors.size());
2765 for (int i = 0; i < fetch_nodes.size(); i++)
2766 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
2767 }
2768
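// Like MaterializeBroadcastGradientArgs, but with fully known shapes, so the
// whole BroadcastGradientArgs subgraph folds to constants. This guards
// against the optimizer endlessly re-folding its own generated nodes; the
// second pass below must leave the graph stable.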
2769 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
2770 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2771 Output a =
2772 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
2773 ops::Placeholder::Shape(PartialTensorShape({2, 2})));
2774 Output b = ops::Square(s.WithOpName("b"), a);
2775 Output c = ops::Mul(s.WithOpName("c"), a, b);
2776 Output d = ops::Shape(s.WithOpName("d"), a);
2777 Output e = ops::Shape(s.WithOpName("e"), b);
2778
2779 auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
2780 Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
2781 Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
2782
2783 GrapplerItem item;
2784 TF_CHECK_OK(s.ToGraphDef(&item.graph));
2785
2786 std::vector<string> fetch_nodes = {"o1", "o2"};
2787 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2788 auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
2789 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2790
2791 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2792 GraphDef output;
2793 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2794 TF_EXPECT_OK(status);
2795
2796 // Run a second time to make sure the optimization is idempotent.
2797 item.graph.Swap(&output);
2798 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2799 TF_EXPECT_OK(status);
2800
2801 EXPECT_EQ(11, output.node_size());
2802 int found = 0;
2803 for (const auto& node : output.node()) {
2804 if (node.name() == "ConstantFolding/f-folded-1") {
2805 ++found;
2806 EXPECT_EQ("Const", node.op());
2807 EXPECT_EQ(2, node.input_size());
2808 EXPECT_EQ("^a", node.input(0));
2809 EXPECT_EQ("^b", node.input(1));
2810 } else if (node.name() == "d") {
2811 ++found;
2812 EXPECT_EQ("Const", node.op());
2813 EXPECT_EQ(1, node.input_size());
2814 EXPECT_EQ("^a", node.input(0));
2815 } else if (node.name() == "e") {
2816 ++found;
2817 EXPECT_EQ("Const", node.op());
2818 EXPECT_EQ(1, node.input_size());
2819 EXPECT_EQ("^b", node.input(0));
2820 } else if (node.name() == "o1") {
2821 ++found;
2822 EXPECT_EQ(1, node.input_size());
2823 EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
2824 } else if (node.name() == "o2") {
2825 ++found;
2826 EXPECT_EQ(1, node.input_size());
2827 EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
2828 } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
2829 ++found;
2830 EXPECT_EQ("Const", node.op());
2831 EXPECT_EQ(1, node.input_size());
2832 EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
2833 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2834 .num_elements());
2835 } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
2836 ++found;
2837 EXPECT_EQ("Const", node.op());
2838 EXPECT_EQ(1, node.input_size());
2839 EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
2840 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2841 .num_elements());
2842 }
2843 }
2844 EXPECT_EQ(7, found);
2845 auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
2846 EXPECT_EQ(fetch_nodes.size(), tensors.size());
2847 for (int i = 0; i < fetch_nodes.size(); i++)
2848 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
2849 }
2850
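// Verifies that the reduction indices feeding a full reduction (a Sum over
// every dimension of the input) are materialized as a Const, which replaces
// the placeholder-fed indices input of the Sum node.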
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
  for (bool use_reshape : {true, false}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output input =
        ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                         ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
    // If use_reshape is false, we need to know the number of indices to apply
    // the rewrite.
    Output indices = ops::Placeholder(
        s.WithOpName("indices"), DT_INT32,
        ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
    Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
    if (use_reshape) {
      Output size = ops::Const(s.WithOpName("size"), 1, {1});
      Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
    }

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch.push_back(use_reshape ? "reshape" : "sum");

    auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
    Tensor indices_t(DT_INT32, TensorShape({2}));
    indices_t.flat<int>()(0) = 0;
    indices_t.flat<int>()(1) = 1;
    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
    EXPECT_EQ(1, tensors_expected.size());

    // Use aggressive mode to force the shape inference to propagate
    // placeholder shapes.
    ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                              /*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    // Run a second time to make sure the optimization is idempotent.
    item.graph.Swap(&output);
    status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    int found = 0;
    for (const auto& node : output.node()) {
      if (node.name() == "ConstantFolding/sum-reduction_indices") {
        ++found;
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^indices", node.input(0));
        EXPECT_EQ(2,
                  TensorShape(node.attr().at("value").tensor().tensor_shape())
                      .num_elements());
      } else if (node.name() == "sum") {
        ++found;
        EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
      } else if (node.name() == "indices") {
        ++found;
      }
    }
    EXPECT_EQ(3, found);

    auto tensors = EvaluateNodes(output, item.fetch,
                                 {{"input", input_t}, {"indices", indices_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
}

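// Verifies that the reduction-indices rewrite is skipped when the indices
// cannot be proven to cover all input dimensions: the graph must come out of
// the optimizer unchanged.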
TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
  for (bool input_rank_known : {true, false}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output input =
        (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                             ops::Placeholder::Shape(
                                                 PartialTensorShape({-1, -1})))
                          : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
    Output indices =
        ops::Placeholder(s.WithOpName("indices"), DT_INT32,
                         ops::Placeholder::Shape(
                             PartialTensorShape({input_rank_known ? 1 : 2})));
    Output sum = ops::Sum(s.WithOpName("sum"), input, indices);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch.push_back("sum");

    // Use aggressive mode to force the shape inference to propagate
    // placeholder shapes.
    ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                              /*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    CompareGraphs(item.graph, output);
  }
}

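// Verifies that a node whose folded value would be huge (a 4k x 4k Diag
// output) is left unfolded so the optimized GraphDef stays small.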
TEST_F(ConstantFoldingTest, LargeConstant) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Generate a 4k by 4k constant matrix.
  Output mat_diag =
      ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
  Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
  Output out = ops::Identity(scope.WithOpName("out"), mat);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("out");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Make sure the diag node hasn't been folded, since it would use too much
  // memory to encode the corresponding constant.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("mat", node.input(0));
      ++found;
    } else if (node.name() == "mat") {
      EXPECT_EQ("Diag", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("mat_diag", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(2, found);

  EXPECT_GT(1024 * 1024, output.ByteSizeLong());

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

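// Verifies folding of a Switch whose predicate and data inputs are the same
// tensor: the value on each live branch is then known, so both consumers fold
// to constants guarded by control dependencies on Identity nodes anchored to
// the corresponding Switch ports.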
TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
                              ops::Placeholder::Shape(TensorShape({})));
  ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
  Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
  Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);

  GrapplerItem item;
  item.fetch.push_back("id_false");
  item.fetch.push_back("id_true");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(6, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "switch" || node.name() == "x") {
      ++found;
    }
    if (node.name() == "id_false") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
      ++found;
    }
    if (node.name() == "id_true") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      ++found;
    }
    if (node.name() == "ConstantFoldingCtrl/switch_0") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch", node.input(0));
      ++found;
    }
    if (node.name() == "ConstantFoldingCtrl/switch_1") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(6, found);

  // Evaluate id_true when input tensor x is true.
  Tensor x_t(DT_BOOL, TensorShape({}));
  x_t.flat<bool>()(0) = true;
  auto tensors_expected = EvaluateNodes(item.graph, {"id_true"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, {"id_true"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);

  // Evaluate id_false when input tensor x is false.
  x_t.flat<bool>()(0) = false;
  tensors_expected = EvaluateNodes(item.graph, {"id_false"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors_expected.size());
  tensors = EvaluateNodes(output, {"id_false"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
}

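// Verifies partial folding of AddN/AccumulateNV2: when at least two inputs
// are constant, they are pre-summed into a single materialized constant and
// the op's input list shrinks accordingly; with fewer than two constant
// inputs (acc1, acc2) the node is left alone.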
TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
  std::function<Output(const Scope&, InputList)> addn_fun =
      [](const Scope& scope, InputList inputs) {
        return ops::AddN(scope, inputs);
      };
  std::function<Output(const Scope&, InputList)> accumulate_fun =
      [](const Scope& scope, InputList inputs) {
        return ops::AccumulateNV2(scope, inputs, TensorShape({2, 2}));
      };
  for (bool use_add_n : {true, false}) {
    auto fun = use_add_n ? addn_fun : accumulate_fun;
    const string op_name = use_add_n ? "AddN" : "AccumulateNV2";
    Scope s = Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
    Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
    Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2, 2});
    Output acc0 = fun(s.WithOpName("acc0"), {c1, c2, c3});
    Output acc1 = fun(s.WithOpName("acc1"), {x, y, z});
    Output acc2 = fun(s.WithOpName("acc2"), {c1, x, y});
    Output acc3 = fun(s.WithOpName("acc3"), {c1, c2, z});
    Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
    Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
    Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
    Output stack = ops::Stack(s.WithOpName("stack"),
                              {acc0, acc1, acc2, acc3, acc4, acc5, acc6});

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(16, output.node_size());
    for (const NodeDef& node : output.node()) {
      if (node.name() == "acc0") {
        EXPECT_EQ("Const", node.op());
      }
      if (node.name() == "acc1") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(3, node.input_size());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("y", node.input(1));
        EXPECT_EQ("z", node.input(2));
      }
      if (node.name() == "acc2") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(3, node.input_size());
        EXPECT_EQ("c1", node.input(0));
        EXPECT_EQ("x", node.input(1));
        EXPECT_EQ("y", node.input(2));
      }
      if (node.name() == "acc3") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("ConstantFolding/acc3_partial_split_2", node.input(0));
        EXPECT_EQ("z", node.input(1));
      }
      if (node.name() == "acc4") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("ConstantFolding/acc4_partial_split_2", node.input(0));
        EXPECT_EQ("y", node.input(1));
      }
      if (node.name() == "acc5") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("ConstantFolding/acc5_partial_split_2", node.input(1));
      }
      if (node.name() == "acc6") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(3, node.input_size());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("ConstantFolding/acc6_partial_split_2", node.input(1));
        EXPECT_EQ("y", node.input(2));
      }
      if (str_util::StartsWith(node.name(), "ConstantFolding/")) {
        EXPECT_EQ("Const", node.op());
      }
    }

    std::vector<string> fetch = {"acc0"};
    auto tensors_expected = EvaluateNodes(item.graph, fetch);
    auto tensors = EvaluateNodes(output, fetch);
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  }
}

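// Verifies partial folding of Concat: a run of two or more adjacent constant
// inputs is collapsed into one materialized constant, while non-adjacent
// constants (e.g. concat4) and single constant inputs are left untouched.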
TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
  Scope s = Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
  Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
  Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
  Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
  Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
  Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
  Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
  Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
  Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
  Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
  Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
                "concat5", "concat6", "concat7", "concat8", "concat9"};

  auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
  EXPECT_EQ(1, tensors_expected.size());
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(21, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "concat0") {
      EXPECT_EQ("Const", node.op());
    } else if (node.name() == "concat3") {
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
      EXPECT_EQ("z", node.input(1));
      EXPECT_EQ("axis", node.input(2));
    } else if (node.name() == "concat5") {
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
      EXPECT_EQ("axis", node.input(2));
    } else if (node.name() == "concat7") {
      EXPECT_EQ(4, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
      EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
      EXPECT_EQ("axis", node.input(3));
    } else if (node.name() == "concat8") {
      EXPECT_EQ(4, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
      EXPECT_EQ("y", node.input(2));
      EXPECT_EQ("axis", node.input(3));
    } else if (node.name() == "concat9") {
      EXPECT_EQ(4, node.input_size());
      EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
      EXPECT_EQ("x", node.input(1));
      EXPECT_EQ("y", node.input(2));
      EXPECT_EQ("axis", node.input(3));
    } else if (str_util::StartsWith(node.name(), "ConstantFolding/")) {
      EXPECT_EQ("Const", node.op());
    } else {
      EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
    }
  }

  auto tensors = EvaluateNodes(output, {"concat0"});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}

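// Verifies partial folding around IdentityN: consumers reading only constant
// ports fold to constants with a control dependency on id_n, the IdentityN
// node itself and consumers of the non-constant port stay unchanged, and a
// consumer mixing ports (add0) has its constant operand rewired directly to
// the underlying constant.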
TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({})));
  Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
  Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
  auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {c1, x, c2});
  auto id0 = ops::Identity(scope.WithOpName("id0"), id_n[0]);
  auto id1 = ops::Identity(scope.WithOpName("id1"), id_n[1]);
  auto add0 = ops::Add(scope.WithOpName("add0"), id_n[0], id_n[1]);
  auto add1 = ops::Add(scope.WithOpName("add1"), id_n[0], id_n[2]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("id0");
  item.fetch.push_back("id1");
  item.fetch.push_back("add0");
  item.fetch.push_back("add1");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(8, output.node_size());
  for (const auto& node : output.node()) {
    // id_n should remain unchanged.
    if (node.name() == "id_n") {
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("x", node.input(1));
      EXPECT_EQ("c2", node.input(2));
    }
    // id0 should be constant folded and have a control dependency on id_n.
    if (node.name() == "id0") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^id_n", node.input(0));
    }
    // id1 is unchanged.
    if ("id1" == node.name()) {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("id_n:1", node.input(0));
    }

    if ("add0" == node.name()) {
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("id_n:1", node.input(1));
    }
    // add1 should be constant folded and have a control dependency on id_n.
    if ("add1" == node.name()) {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^id_n", node.input(0));
    }
  }

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
  EXPECT_EQ(4, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  EXPECT_EQ(4, tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
  }
}

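// Verifies that a Stack (pack) of a single tensor is rewritten into an
// ExpandDims with a materialized constant axis, preserving existing control
// dependencies on the original Stack node.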
TEST_F(ConstantFoldingTest, TrivialPack) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
  Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
  auto stack =
      ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
                 ops::Stack::Axis(1));
  auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"stack", "stack_no_axis"};

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(7, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "stack") {
      EXPECT_EQ("ExpandDims", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
      EXPECT_EQ("^y", node.input(2));
      ++found;
    } else if (node.name() == "stack_no_axis") {
      EXPECT_EQ("ExpandDims", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
      ++found;
    } else if (node.name() == "ConstantFolding/stack_const_axis") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^x", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(found, 3);

  std::vector<string> fetch = {"stack", "stack_no_axis"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors_expected.size());
  EXPECT_EQ(2, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
  EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}

// The test does not evaluate the optimized and original graphs to check if
// their outputs are the same. See b/78233179.
TEST_F(ConstantFoldingTest, Enter) {
  GrapplerItem item;
  AttrValue frame_name;
  frame_name.set_s("foo");
  AttrValue is_constant_true;
  is_constant_true.set_b(true);
  AttrValue is_constant_false;
  is_constant_false.set_b(false);
  AttrValue type;
  type.set_type(DT_FLOAT);
  AttrValue value;
  Tensor value_tensor(DT_FLOAT, TensorShape({}));
  value_tensor.flat<float>()(0) = 1;
  value_tensor.AsProtoTensorContent(value.mutable_tensor());

  GraphDef& graph = item.graph;
  AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
  AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
  AddNode("enter1", "Enter", {"x"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_true}},
          &graph);
  AddNode("enter2", "Enter", {"c1"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_true}},
          &graph);
  AddNode("enter3", "Enter", {"c1"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_false}},
          &graph);
  AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
  AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
  AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
  AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");
  item.fetch.push_back("id3");
  item.fetch.push_back("id4");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(9, output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("enter1", node.input(0));
    }
    if (node.name() == "id2" || node.name() == "id3") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^enter2", node.input(0));
    }
    if (node.name() == "id4") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("enter3", node.input(0));
    }
  }
}

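// Verifies that TensorArraySize folds to a constant only for a TensorArray
// with a statically known, non-dynamic size; dynamically sized arrays and
// handles coming from a placeholder are left alone.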
TEST_F(ConstantFoldingTest, TensorArraySize) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
  Output placeholder =
      ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
                       ops::Placeholder::Shape(TensorShape({2})));
  Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
  auto dynamic_array =
      ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(true));
  auto static_array =
      ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(false));
  auto dynamic_sz =
      ops::TensorArraySize(scope.WithOpName("dynamic_sz"),
                           dynamic_array.handle, dynamic_array.flow);
  auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
                                        static_array.handle, static_array.flow);
  auto placeholder_sz = ops::TensorArraySize(
      scope.WithOpName("placeholder_sz"), placeholder, foo);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected =
      EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  EXPECT_EQ("dynamic_sz", output.node(5).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
  EXPECT_EQ("static_sz", output.node(6).name());
  EXPECT_EQ("Const", output.node(6).op());
  EXPECT_EQ("placeholder_sz", output.node(7).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(7).op());

  auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
  EXPECT_EQ(2, tensors_expected.size());
  EXPECT_EQ(2, tensors_actual.size());
  test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
  test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}

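// Note: std::numeric_limits<float>::min() is the smallest *normalized* float
// (~1.18e-38), so the exact product with 0.1f (~1.18e-39) is representable
// only as a denormal and flushes to 0.0f under FTZ.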
TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
  // Multiplying min() by 0.1 gives a denormal without FTZ and zero with FTZ.
  // Make sure constant folding behaves the same way as TensorFlow.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
  Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
  Output c = ops::Mul(s.WithOpName("c"), a, b);

  GrapplerItem item;
  item.fetch.push_back("c");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("c", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"c"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

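// Presumably a regression test for the interaction of constant evaluation
// with partial Concat folding on large inputs: the optimization must succeed
// without looping, and only the output shape is checked here.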
TEST_F(ConstantFoldingTest, EvaluatingLargeConstantNoFoldingMergingLoop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  int size = 10 * 1024 * 1024 / 4 / 2;
  Output nonconst =
      ops::RandomUniform(s.WithOpName("nonconst"), {size, 1}, DT_FLOAT);
  Output const1 = ops::Const(s.WithOpName("const1"), 0.0f, {size, 1});
  Output const2 = ops::Const(s.WithOpName("const2"), 1.0f, {size, 1});
  Output axis = ops::Const(s.WithOpName("axis"), -1, {});
  Output concat1 =
      ops::Concat(s.WithOpName("concat1"), {nonconst, const1}, axis);
  Output result = ops::Concat(s.WithOpName("result"), {concat1, const2}, axis);

  GrapplerItem item;
  item.fetch.push_back("result");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> fetch = {"result"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
}

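// Exercises Const-feeding-Cast folding under every combination of fetched
// nodes; each fetched node should fold to (or already be) a single constant,
// so the optimized graph has exactly one node per fetch.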
class ConstantFoldingCastConstTest : public GrapplerTest {
 protected:
  void ConstantFoldingCastConst(bool fetch_const, bool fetch_cast,
                                bool fetch_const_child, bool fetch_cast_child) {
    if (!fetch_const && !fetch_cast && !fetch_const_child &&
        !fetch_cast_child) {
      return;
    }

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    CreateCastConstGraph(s);
    GrapplerItem item;
    int expected_output_size = SetFetch(&item, fetch_const, fetch_cast,
                                        fetch_const_child, fetch_cast_child);
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    GraphDef output = ConstantFoldingOptimize(item);
    EXPECT_EQ(expected_output_size, output.node_size());

    EvaluateAndCompareUnoptimized(item.graph, output, item.fetch);
  }

 private:
  void CreateCastConstGraph(const tensorflow::Scope& s) {
    Output const1 = ops::Const(s.WithOpName("const1"), 2, {5, 5});
    Output cast = ops::Cast(s.WithOpName("cast"), const1, DT_FLOAT);
    Output const1_child = ops::Identity(s.WithOpName("const1_child"), const1);
    Output cast_child = ops::Identity(s.WithOpName("cast_child"), cast);
  }

  int SetFetch(GrapplerItem* item, bool fetch_const, bool fetch_cast,
               bool fetch_const_child, bool fetch_cast_child) {
    int expected_output_size = 0;
    if (fetch_const) {
      item->fetch.push_back("const1");
      expected_output_size++;
    }
    if (fetch_cast) {
      item->fetch.push_back("cast");
      expected_output_size++;
    }
    if (fetch_const_child) {
      item->fetch.push_back("const1_child");
      expected_output_size++;
    }
    if (fetch_cast_child) {
      item->fetch.push_back("cast_child");
      expected_output_size++;
    }
    return expected_output_size;
  }

  GraphDef ConstantFoldingOptimize(const GrapplerItem& item) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    return output;
  }

  void EvaluateAndCompareUnoptimized(const GraphDef& unoptimized_graph,
                                     const GraphDef& optimized_graph,
                                     const std::vector<string>& fetch_nodes) {
    auto tensors_expected = EvaluateNodes(unoptimized_graph, fetch_nodes);
    auto tensors = EvaluateNodes(optimized_graph, fetch_nodes);
    ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
    ASSERT_EQ(fetch_nodes.size(), tensors.size());
    for (int i = 0; i < fetch_nodes.size(); i++) {
      if (fetch_nodes[i] == "const1" || fetch_nodes[i] == "const1_child") {
        test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
      } else {
        test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
      }
    }
  }
};

TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
  for (bool fetch_const : {false, true}) {
    for (bool fetch_cast : {false, true}) {
      for (bool fetch_const_child : {false, true}) {
        for (bool fetch_cast_child : {false, true}) {
          ConstantFoldingCastConst(fetch_const, fetch_cast, fetch_const_child,
                                   fetch_cast_child);
        }
      }
    }
  }
}

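// Verifies that constant-valued nodes (OnesLike, ZerosLike, Fill with
// constant arguments) are materialized as Const nodes whose data inputs are
// replaced by control dependencies.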
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 6);
  for (const auto& node : output.node()) {
    if (node.name() != "x") {
      EXPECT_EQ(node.op(), "Const");
    }
    if (node.name() == "ones_like" || node.name() == "zeros_like") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "^x");
    }
    if (node.name() == "fill") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  ASSERT_EQ(item.fetch.size(), tensors.size());
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    if (item.fetch[i] == "fill") {
      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
    }
  }
}

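// Verifies that Fill is materialized as a constant-valued node even when its
// scalar value is encoded in the tensor_content field instead of int_val.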
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeHugeFill) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output value = ops::Const(scope.WithOpName("value"), 42, {});
  Output fill_huge = ops::Fill(scope.WithOpName("fill_huge"),
                               {1024, 1024, 1024, 1024, 1024}, value);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Manually convert the input value format to tensor_content to test this
  // case.
  NodeDef* node = item.graph.mutable_node(0);
  ASSERT_EQ(node->name(), "value");
  TensorProto* t = (*node->mutable_attr())["value"].mutable_tensor();
  t->clear_int_val();
  int val = 42;
  port::CopyFromArray(t->mutable_tensor_content(),
                      reinterpret_cast<const char*>(&val), sizeof(int));
  item.fetch = {"fill_huge"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 3);
  for (const auto& node : output.node()) {
    EXPECT_EQ(node.op(), "Const");
    if (node.name() == "fill_huge") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }
}

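// Verifies that folding a Bitcast round trip through DT_FLOAT is bit-exact,
// including bit patterns whose float interpretation is denormal or NaN.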
TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Tensor x_t(DT_INT64, TensorShape({2, 2}));
  x_t.flat<int64>()(0) = 9223372036854775807L;
  x_t.flat<int64>()(1) = 1L;
  x_t.flat<int64>()(2) = 9223372036854775807L;
  x_t.flat<int64>()(3) = 1L;
  Output x = ops::Const(scope.WithOpName("x"), x_t);
  Output y = ops::Bitcast(scope.WithOpName("y"), x, DT_FLOAT);
  Output z = ops::Bitcast(scope.WithOpName("z"), y, DT_INT64);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"z"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  ASSERT_EQ(output.node_size(), 1);
  const NodeDef& node = output.node(0);
  EXPECT_EQ(node.name(), "z");
  EXPECT_EQ(node.op(), "Const");

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 1);
  ASSERT_EQ(tensors_expected.size(), 1);
  test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
}

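// Verifies that large constants with a single repeated value are compressed
// to one float_val entry in their TensorProto, for Const and HostConst alike.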
TEST_F(ConstantFoldingTest, CompressConstants) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Tensor zeros_t(DT_FLOAT, TensorShape({64}));
  Tensor ones_t(DT_FLOAT, TensorShape({64}));
  for (int i = 0; i < 64; ++i) {
    zeros_t.flat<float>()(i) = 0.0f;
    ones_t.flat<float>()(i) = 1.0f;
  }
  Output zeros = ops::Const(scope.WithOpName("zeros"), zeros_t);
  Output host_ones = ops::Const(scope.WithOpName("host_ones"), ones_t);
  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  ASSERT_EQ(item.graph.node(1).name(), "host_ones");
  // There is no C++ API for HostConst, so we manually change the node type
  // here.
  item.graph.mutable_node(1)->set_op("HostConst");
  item.fetch = {"zeros", "host_ones"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

  {
    ASSERT_EQ(output.node_size(), 2);
    const NodeDef& node = output.node(0);
    EXPECT_EQ(node.name(), "zeros");
    EXPECT_EQ(node.op(), "Const");
    const TensorProto& zeroes_t = node.attr().at("value").tensor();
    EXPECT_EQ(zeroes_t.float_val_size(), 1);
    EXPECT_EQ(zeroes_t.float_val(0), 0.0f);
  }
  {
    const NodeDef& node = output.node(1);
    EXPECT_EQ(node.name(), "host_ones");
    EXPECT_EQ(node.op(), "HostConst");
    const TensorProto& ones_t = node.attr().at("value").tensor();
    EXPECT_EQ(ones_t.float_val_size(), 1);
    EXPECT_EQ(ones_t.float_val(0), 1.0f);
  }

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  for (int i = 0; i < 2; ++i) {
    test::ExpectTensorEqual<float>(tensors[i], tensors_expected[i]);
  }
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow