/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"

namespace tensorflow {
namespace grappler {
namespace {

class ConstantFoldingTest : public GrapplerTest {
 protected:
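  // Builds a 2x2 elementwise graph with neutral/absorbing constant operands
  // and verifies the expected rewrites, e.g. for numeric types:
  //
  //   mul1 = x * 0  -->  Const(0)      mul2 = x * 1  -->  Identity/Snapshot
  //   add1 = x + 0  -->  Identity/Snapshot    add2 = x + 1  -->  unchanged
  //
  // For DT_BOOL the same patterns are built from LogicalAnd/LogicalOr, where
  // x || true also folds to a constant.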
  template <DataType DTYPE>
  void SimpleNeutralElementTest() {
    for (bool use_snapshot : {false, true}) {
      typedef typename EnumToDataType<DTYPE>::Type T;
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
      Tensor zeros_t(DTYPE, TensorShape({2, 2}));
      Tensor ones_t(DTYPE, TensorShape({2, 2}));
      Tensor x_t(DTYPE, TensorShape({2, 2}));
      for (int i = 0; i < 4; ++i) {
        zeros_t.flat<T>()(i) = T(0);
        ones_t.flat<T>()(i) = T(1);
        x_t.flat<T>()(i) = T(i + 1);
      }
      Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
      Output ones = ops::Const(s.WithOpName("ones"), ones_t);
      Output mul1;
      Output mul2;
      Output add1;
      Output add2;
      if (DTYPE == DT_BOOL) {
        mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
        mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
        add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
        add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
      } else {
        if (DTYPE == DT_FLOAT) {
          mul1 = ops::MulNoNan(s.WithOpName("mul1"), x, zeros);
        } else {
          mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
        }
        mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
        add1 = ops::Add(s.WithOpName("add1"), x, zeros);
        add2 = ops::Add(s.WithOpName("add2"), x, ones);
      }
      if (use_snapshot) {
        // Add an op with ref input to prevent Snapshot from being
        // turned into Identity.
        ops::Assign(s.WithOpName("assign"), v, ones);
      }
      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"mul1", "mul2", "add1", "add2"};
      ConstantFolding optimizer(/*cpu_device=*/nullptr);
      GraphDef output;
      Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
      TF_EXPECT_OK(status);

      EXPECT_EQ(7, output.node_size());
      const string snapshot_or_identity =
          use_snapshot ? "Snapshot" : "Identity";
      for (int i = 0; i < output.node_size(); ++i) {
        const NodeDef& node = output.node(i);
        const string& name = node.name();
        if (name == "mul1") {
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ("^x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "mul2") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^ones", node.input(1));
        } else if (name == "add1") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "add2") {
          if (DTYPE == DT_BOOL) {
            EXPECT_EQ("Const", node.op());
            EXPECT_EQ("^x", node.input(0));
            EXPECT_EQ("^ones", node.input(1));
          } else {
            EXPECT_EQ("Add", node.op());
            EXPECT_EQ("x", node.input(0));
            EXPECT_EQ("ones", node.input(1));
          }
        }
      }
      auto tensors_expected =
          EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
      auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
      EXPECT_EQ(4, tensors_expected.size());
      EXPECT_EQ(4, tensors.size());
      for (int i = 0; i < item.fetch.size(); ++i) {
        test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
      }
    }
  }

  void MulConvPushDownTest(const TensorShape& input_shape,
                           const TensorShape& filter_shape,
                           const TensorShape& mul_const_input_shape,
                           const bool use_3d_conv, const char* padding,
                           const char* data_format, const bool expect_folded) {
    // Tests if the following rewrite is performed:
    //
    //         *                       Conv2D
    //        / \                       / \
    //       c  Conv2D        -->      x  (c * filter)
    //           / \
    //          x  filter
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Tensor filter_values(DT_FLOAT, filter_shape);
    for (int i = 0; i < filter_values.NumElements(); ++i) {
      filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
    }
    Output filter =
        ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));

    Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                    ops::Placeholder::Shape(input_shape));

    Output conv;
    if (use_3d_conv) {
      conv = ops::Conv3D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1, 1},
                         padding, ops::Conv3D::DataFormat(data_format));
    } else {
      conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
                         padding, ops::Conv2D::DataFormat(data_format));
    }
    Tensor mul_const_input(DT_FLOAT, mul_const_input_shape);
    for (int i = 0; i < mul_const_input.NumElements(); ++i) {
      mul_const_input.flat<float>()(i) = static_cast<float>(i + 3);
    }
    Output c =
        ops::Const(s.WithOpName("c"), Input::Initializer(mul_const_input));
    Output mul = ops::Mul(s.WithOpName("mul"), c, conv);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(5, output.node_size());
    int found = 0;
    if (expect_folded) {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("conv/merged_input", node.input(1));
        } else if (node.name() == "conv/merged_input") {
          found++;
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ(0, node.input_size());
        }
      }
    } else {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ("Mul", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("c", node.input(0));
          EXPECT_EQ("conv", node.input(1));
        } else if (node.name() == "conv") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("filter", node.input(1));
        }
      }
    }
    EXPECT_EQ(2, found);

    // Check that const folded multiplication node has the expected value.
    std::vector<string> fetch = {"mul"};
    Tensor value(DT_FLOAT, input_shape);
    for (int i = 0; i < value.NumElements(); ++i) {
      value.flat<float>()(i) = i;
    }
    auto actual = EvaluateNodes(output, fetch, {{"x", value}});
    auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
    test::ExpectTensorEqual<float>(expected[0], actual[0]);
  }

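  // Checks that PadV2 with all-zero paddings is rewritten to Identity (with
  // its former inputs kept as control dependencies), while PadV2 with
  // nonzero paddings is left untouched:
  //
  //   PadV2(in1, [[0, 0], [0, 0]], c)  -->  Identity(in1)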
  template <typename T>
  void PaddingWithZeroSize() {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
    auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
    auto paddings1 =
        ops::Const<T>(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
    auto paddings2 =
        ops::Const<T>(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
    auto c1 = ops::Const(scope.WithOpName("c1"), 1);
    auto c2 = ops::Const(scope.WithOpName("c2"), 1);

    ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
    ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);

    ops::Add out(scope.WithOpName("out"), p1, p2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("paddings1", "Const", {}, {}, &want);
    AddNode("paddings2", "Const", {}, {}, &want);
    AddNode("c1", "Const", {}, {}, &want);
    AddNode("c2", "Const", {}, {}, &want);
    AddNode(
        "p1", "Identity",
        {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
        {}, &want);
    AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, {}, &want);
    AddNode("out", "Add", {"p1", "p2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({4, 6}));
    auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 2}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  }
};

TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
  Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b});
  Output d = ops::AddN(s.WithOpName("d"), {b, c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("d", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, AddTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
  Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);

  Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
  Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
  Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
  Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
  Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
  Output addmul_parent =
      ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);

  GrapplerItem item;
  item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //    +                +             +
  //   / \              / \           / \
  // 1.0  +     -->    x   +    -->  x  3.0
  //     / \              / \
  //   2.0  x           1.0 2.0
  //
  //    *                *             *
  //   / \              / \           / \
  // 4.0  *     -->    y   *    -->  y  20.0
  //     / \              / \
  //   5.0  y           4.0 5.0

  EXPECT_EQ(10, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "add_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      ASSERT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("add_child", node.input(1));
    } else if (node.name() == "mul_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "mul_parent") {
      EXPECT_EQ("Mul", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("mul_child", node.input(1));
    } else if (node.name() == "addmul_child") {
      // Unchanged.
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c4", node.input(0));
      EXPECT_EQ("x", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent", "mul_parent"};
  auto tensor_expected =
      EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent", "mul_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, AddSubtractTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output sub_child = ops::Sub(s.WithOpName("sub_child"), x, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), sub_child, c1);

  GrapplerItem item;
  item.fetch = {"add_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //     +                +
  //    / \              / \
  //   -   1     -->    -   x
  //  / \              / \
  // x   x            1   x

  EXPECT_EQ(4, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "sub_child") {
      EXPECT_EQ("Sub", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("x", node.input(1));
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("sub_child", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

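// Exhaustively exercises the constant push-down rewrite over all 32
// combinations of parent/child op ({Add, Sub} vs. {Mul, Div}) and the
// position (left or right input) of the constant in each node. No graph
// structure is asserted here; the test only checks that the optimized graph
// is numerically equivalent to the original.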
TEST_F(ConstantFoldingTest, ConstantPushDown) {
  for (bool is_add : {true, false}) {
    for (bool is_parent_commutative : {true, false}) {
      for (bool is_child_commutative : {true, false}) {
        for (bool is_left_child_const : {true, false}) {
          for (bool is_left_leaf_const : {true, false}) {
            tensorflow::Scope s = tensorflow::Scope::NewRootScope();
            Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
            Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
            Output x =
                ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                 ops::Placeholder::Shape(TensorShape({2, 2})));

            auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                              const string& name, const Output& const_arg,
                              const Output& non_const_arg) -> Output {
              if (is_add) {
                if (is_commutative) {
                  return ops::Add(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Sub(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              } else {
                if (is_commutative) {
                  return ops::Mul(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Div(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              }
            };

            Output child = get_op(is_child_commutative, is_left_leaf_const,
                                  "child", c2, x);
            Output parent = get_op(is_parent_commutative, is_left_child_const,
                                   "parent", c3, child);
            GrapplerItem item;
            item.fetch = {"parent"};
            TF_CHECK_OK(s.ToGraphDef(&item.graph));

            ConstantFolding optimizer(/*cpu_device=*/nullptr);
            GraphDef output;
            Status status =
                optimizer.Optimize(/*cluster=*/nullptr, item, &output);
            TF_EXPECT_OK(status);

            // Check that the result nodes have the expected value.
            auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
            std::vector<string> fetch = {"parent"};
            auto tensor_expected =
                EvaluateNodes(item.graph, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensor_expected.size());
            fetch = {"parent"};
            auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensors.size());
            for (int i = 0; i < fetch.size(); i++) {
              test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
            }
          }
        }
      }
    }
  }
}

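// Exercises constant push-down across BiasAdd with mixed matrix/vector
// operands. In the rewritable cases (children 1 through 4 and their
// symmetric "a" variants) the child node must fold to a Const; the
// remaining cases (children 5 through 7) must stay non-constant.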
TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
  Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2})));
  // Rewrite expected for cases 1 through 3 and their symmetric equivalents,
  // and case 4.
  Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
  Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
  Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
  Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);

  Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
  Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
  Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
  Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);

  Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
  Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
  Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
  Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);

  Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
  Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);

  // No rewrite expected.
  Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
  Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
  Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
  Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
  Output child7 = ops::Add(s.WithOpName("child7"), x_mat, c_vec);
  Output parent7 = ops::BiasAdd(s.WithOpName("parent7"), child7, c_vec);

  GrapplerItem item;
  item.fetch = {"parent1",  "parent2", "parent3", "parent1a", "parent2a",
                "parent3a", "parent4", "parent5", "parent6",  "parent7"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "child1" || node.name() == "child1a" ||
        node.name() == "child2" || node.name() == "child2a" ||
        node.name() == "child3" || node.name() == "child3a" ||
        node.name() == "child4") {
      EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
    } else if (node.name() != "c_mat" && node.name() != "c_vec") {
      EXPECT_NE(node.op(), "Const") << " node: " << node.name();
    }
  }
  // Check that the result nodes have the expected value.
  auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  std::vector<string> fetch = item.fetch;
  auto tensor_expected = EvaluateNodes(
      item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  auto tensors =
      EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

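// The MulConvPushDownTest_* cases below exercise the Mul(c, Conv(x, filter))
// rewrite from the helper above. Folding is expected only when the constant's
// shape can be absorbed into the filter without changing broadcast semantics;
// see the expect_folded argument of each case.
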
// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/true);
  }
}
#endif

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    for (auto mul_const_input_shape :
         {TensorShape{1}, TensorShape{1, 1, 1, 1}}) {
      MulConvPushDownTest(
          /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                                : TensorShape{4, 3, 10, 10},
          /*filter_shape=*/{2, 2, 3, 5}, mul_const_input_shape,
          /*use_3d_conv=*/false,
          /*padding=*/"VALID", data_format.c_str(),
          /*expect_folded=*/true);
    }
  }
}
#endif

TEST_F(ConstantFoldingTest,
       MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{1, 1, 1, 1, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) {
  for (auto data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1, 3},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{1, 3}, TensorShape{1, 1, 1, 3}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NHWC",
        /*expect_folded=*/true);
  }
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NCHW",
        // TODO(laigd): optimization should happen in this case.
        /*expect_folded=*/false);
  }
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) {
  for (auto data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{1, 1, 3},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      /*expect_folded=*/true);
}
#endif

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{3, 1, 1, 1},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      // TODO(laigd): optimization should happen in this case.
      /*expect_folded=*/false);
}

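// Verifies the neutral/absorbing-element rewrites for float graphs
// (Mul, MulNoNan, Div, MatMul, Add, BiasAdd, Sub against zeros and ones),
// with the constant operand materialized three ways: as a literal Const,
// via ZerosLike/OnesLike, and via Fill.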
TEST_F(ConstantFoldingTest, NeutralElement) {
  int kConst = 0;
  int kLike = 1;
  int kFill = 2;
  for (int const_type : {kConst, kLike, kFill}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({3, 2})));
    Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 3})));
    Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                   ops::Placeholder::Shape(TensorShape({2})));
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
    Output zeros_const_bcast =
        ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
    Output zeros = const_type == kConst
                       ? zeros_const
                       : (const_type == kLike ? zeros_like : zeros_fill);
    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
    Output ones_const_bcast =
        ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
    Output ones = const_type == kConst
                      ? ones_const
                      : (const_type == kLike ? ones_like : ones_fill);
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul1_bcast =
        ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
    Output mul2_bcast =
        ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::MulNoNan(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output add1_bcast =
        ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
    Output add2_bcast =
        ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    Output concat =
        ops::Stack(s.WithOpName("stack"),
                   {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
                    matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack",      "matmul3",    "matmul4",   "mul1_bcast",
                  "mul2_bcast", "add1_bcast", "add2_bcast"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    const string suffix =
        (const_type == kConst ? "_const"
                              : (const_type == kLike ? "_like" : "_fill"));
    const string zeros_name = strings::StrCat("zeros", suffix);
    const string ones_name = strings::StrCat("ones", suffix);
    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
    const string ctrl_ones_name = strings::StrCat("^ones", suffix);

    EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      if (name == "mul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul3") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "div2") {
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "add2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("bias", node.input(0));
        EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
                  node.input(1));
        EXPECT_EQ(ctrl_zeros_name, node.input(2));
      } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(DT_INT32, t.dtype());
        EXPECT_EQ(1, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
      } else if (name == "sub1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "sub2") {
        EXPECT_EQ("Neg", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      }
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
    auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
    auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));

    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors_expected.size());
    auto tensors = EvaluateNodes(
        output, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
    }
  }
}

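// Runs the templated helper for bool and the 16-bit float types, covering
// the LogicalAnd/LogicalOr path and the Mul/Add path respectively.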
TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
  SimpleNeutralElementTest<DT_BOOL>();
  SimpleNeutralElementTest<DT_HALF>();
  SimpleNeutralElementTest<DT_BFLOAT16>();
}

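// Division by a nonzero float constant should be strength-reduced to
// multiplication by its reciprocal; integer division must be left unchanged:
//
//   xf / 2.0f  -->  xf * 0.5f      xi / 2  -->  xi / 2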
TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div_f", "div_i", "realdiv"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name == "div_i") {
      // Integer division is unchanged.
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("ci", node.input(1));
    } else if (name == "div_f") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    } else if (name == "realdiv") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    } else if (name == "ConstantFolding/div_f_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    } else if (name == "ConstantFolding/realdiv_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    }
  }

  // Check that the reciprocals have the expected value.
  std::vector<string> fetch = {"cf_half"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
  }
}

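// With no downstream ops to pin down the output shape, Mul-by-zeros may only
// fold to a Const when the optimizer can prove the full output shape. Pairs
// involving a fully unknown shape must stay Mul, and the partially-known *
// partially-known pair is only downgraded to Identity.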
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_known =
      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // Multiplies without any additional ops to supply the output shape.
  int count = 0;
  std::vector<Output> muls;
  std::unordered_set<string> not_converted;
  std::unordered_set<string> to_const;
  std::unordered_set<string> to_identity;
  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
    for (const auto* zeros :
         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
      if (x == &x_partially_known && zeros == &zeros_partially_known) {
        to_identity.insert(name);
      } else if (x == &x_unknown || zeros == &zeros_unknown) {
        not_converted.insert(name);
      } else {
        to_const.insert(name);
      }
    }
  }

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(15, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
    } else if (to_identity.count(name) > 0) {
      EXPECT_EQ("Identity", node.op()) << node.name();
    } else if (not_converted.count(name) > 0) {
      EXPECT_EQ("Mul", node.op()) << node.name();
    }
  }

  const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
  auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_known", x_known_t},
                     {"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_known", x_known_t},
                                {"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate the shape back to the inputs of AddN, making the
  // output shapes of all its inputs known.
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
  const std::vector<string> fetch = {"addn1"};
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
}

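// Folds Mul (and LogicalAnd for bool) of two identical Const inputs for each
// supported dtype and checks that the resulting Const holds a single splat
// value (10 * 10) in the dtype-appropriate proto field.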
TEST_F(ConstantFoldingTest, CreateConstNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

#define MAKE_TEST_GRAPH(TYPE)                                               \
  Output TYPE##_const =                                                     \
      ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
  Output TYPE##_mul =                                                       \
      ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
  Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)

  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}

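// Folding an op with two outputs (Unique): every consumer of each output
// should be replaced by its own constant, and after pruning only the fetched
// constants remain.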
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {5});
  auto b = ops::Unique(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
  Output e = ops::Identity(s.WithOpName("e"), {c});
  Output f = ops::Identity(s.WithOpName("f"), {d});

  GrapplerItem item;
  item.fetch.push_back("e");
  item.fetch.push_back("f");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(2, output.node_size());

  const NodeDef& new_c = output.node(0);
  EXPECT_EQ("e", new_c.name());
  EXPECT_EQ("Const", new_c.op());

  const NodeDef& new_d = output.node(1);
  EXPECT_EQ("f", new_d.name());
  EXPECT_EQ("Const", new_d.op());

  std::vector<string> fetch = {"e", "f"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors_expected.size());
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}

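// Control dependencies of folded nodes must be preserved: the constant that
// replaces "i3" should carry control inputs from both placeholders.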
TEST_F(ConstantFoldingTest, ControlDependencies) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("i3"), {i2});

  GrapplerItem item;
  item.fetch.push_back("i3");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i3"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i3") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i3"});
      auto expected = EvaluateNodes(item.graph, {"i3"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(1, found);
}

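// With an empty fetch set every foldable node is a candidate; the folded
// constants should still accumulate the control dependencies of the nodes
// they replace.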
TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
                                        "i1",   "i2", "e"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i1") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i1"});
      auto expected = EvaluateNodes(item.graph, {"i1"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
    }
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i2"});
      auto expected = EvaluateNodes(item.graph, {"i2"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(2, found);
}

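// Duplicate control dependencies picked up during folding should be
// deduplicated on the resulting constant.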
TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1")
                                .WithControlDependencies(p2)
                                .WithControlDependencies(p1),
                            {c});
  Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});

  GrapplerItem item;
  item.fetch.push_back("i2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

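// Ops with a variable number of outputs (DynamicPartition, ConcatOffset)
// should have all of their outputs folded to constants.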
TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Add a DynamicPartition node to the graph
  Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
  Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
  int num_partitions = 4;
  ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
                             num_partitions);

  std::vector<string> outputs;
  for (int i = 0; i < num_partitions; ++i) {
    string part_out_name = strings::StrCat("part_out", i);
    ops::Identity partition_out(scope.WithOpName(part_out_name),
                                {part.outputs[i]});
    outputs.push_back(part_out_name);
  }

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Add a ConcatOffset node to the graph
  Tensor initial_val(DT_INT32, TensorShape({3}));
  test::FillIota<int>(&initial_val, 7);
  for (int i = 1; i < 5; ++i) {
    TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
                    .Attr("dtype", DT_INT32)
                    .Attr("value", initial_val)
                    .Finalize(item.graph.add_node()));
  }
  Tensor concat_dim(DT_INT32, TensorShape({}));
  test::FillIota<int>(&concat_dim, 0);
  TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
                  .Attr("dtype", DT_INT32)
                  .Attr("value", concat_dim)
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
                  .Input("concat_dim", 0, DT_INT32)
                  .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
                  .Finalize(item.graph.add_node()));

  for (int i = 0; i < 4; ++i) {
    string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
    TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
                    .Attr("T", DT_INT32)
                    .Input("concat_offsets", i, DT_INT32)
                    .Finalize(item.graph.add_node()));
    outputs.push_back(concat_offset_out_name);
  }

  item.fetch = outputs;
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int constant_folded = 0;
  for (const auto& node : output.node()) {
    if (node.name().find("part_out") != string::npos ||
        node.name().find("concat_offset_out") != string::npos) {
      ++constant_folded;
      EXPECT_EQ("Const", node.op());
    }
  }
  EXPECT_EQ(8, constant_folded);

  auto expected = EvaluateNodes(item.graph, outputs);
  auto optimized = EvaluateNodes(output, outputs);
  ASSERT_EQ(expected.size(), optimized.size());
  for (int i = 0; i < expected.size(); ++i) {
    test::ExpectTensorEqual<int>(expected[i], optimized[i]);
  }
}

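// Rank, Shape, and Size can be materialized from statically known variable
// shapes, allowing the arithmetic downstream of them to be folded.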
TEST_F(ConstantFoldingTest, ShapeMaterialization) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  item.fetch.push_back("p2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      EXPECT_EQ("^v1", node.input(1));
      EXPECT_EQ("^v2", node.input(2));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // rank = 1, shape = (5, 7), size = 143 = 11*13
      // p2 = (715, 1001) = (5*143, 7*143)
      EXPECT_EQ(715, value.flat<int>()(0));
      EXPECT_EQ(1001, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(1, found);
  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));

  auto tensors_expected = EvaluateNodes(
      item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

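// As above, but without fetch nodes: the shape ops themselves should be
// replaced by constants anchored to their inputs via control dependencies.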
TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(11 * 13, value.flat<int>()(0));
    } else if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v1", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(1, value.flat<int>()(0));
    } else if (node.name() == "shape") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v2", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(5, value.flat<int>()(0));
      EXPECT_EQ(7, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(3, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
  std::vector<string> fetch_nodes = {"p2"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

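// For ShapeN, only the outputs whose input shapes are fully known should be
// materialized as constants; outputs with unknown or partially known shapes
// keep their edges to the ShapeN node.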
TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
  Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
  Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
  Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
  Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
  Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
  Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
  Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
              node.name());
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
              node.name());
    if (node.name() == "i1a" || node.name() == "i1b") {
      ++found;
      EXPECT_EQ("s", node.input(0));
    }
    if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
      ++found;
      EXPECT_EQ("s:1", node.input(0));
    }
    if (node.name() == "i3a" || node.name() == "i3b") {
      ++found;
      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
                node.input(0));
    }
    if (node.name() == "s") {
      ++found;
      EXPECT_EQ("ShapeN", node.op());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("v2", node.input(1));
      EXPECT_EQ("v3", node.input(2));
    }
    if (node.name() ==
        AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ("^s", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(4, value.flat<int>()(0));
      EXPECT_EQ(6, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(9, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
  const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
                                           "i2c", "i3a", "i3b"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors.size());
  for (int i = 0; i < fetch_nodes.size(); i++)
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}

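// Materializing one output of a ShapeN that feeds a multi-output IdentityN
// should only rewrite the consumers of the fully known shape.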
TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
  auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
  Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
  Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("ia");
  item.fetch.push_back("ib");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
              node.name());
    if (node.name() == "s") {
      ++found;
      EXPECT_EQ("ShapeN", node.op());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("v2", node.input(1));
    }
    if (node.name() == "id_n") {
      ++found;
      EXPECT_EQ("IdentityN", node.op());
      EXPECT_EQ("s", node.input(0));
      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
                node.input(1));
    }
    if (node.name() == "ia") {
      ++found;
      EXPECT_EQ("id_n", node.input(0));
    }
    if (node.name() == "ib") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ("^s", node.input(0));
      EXPECT_EQ("^id_n", node.input(1));
    }
  }
  EXPECT_EQ(4, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors =
      EvaluateNodes(output, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
  EXPECT_EQ(2, tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}

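// With no fetch nodes, Rank and Size computed on Switch outputs fold to
// constants anchored by control dependencies, the statically selected branch
// of switch2 folds to a constant, and the dead branch stays an Identity.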
TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0",
                                    "rank",     "size"};
  std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
        << node.name();
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
        << node.name();
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
    }
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^i", node.input(0));
    }
    if (node.name() == "i2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(4, found);

  auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  Tensor v_ctrl_t(DT_BOOL, TensorShape({}));

  v_ctrl_t.flat<bool>()(0) = true;
  std::vector<string> fetch_nodes = {"m", "m2"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);

  v_ctrl_t.flat<bool>()(0) = false;
  tensors_expected = EvaluateNodes(item.graph, fetch_nodes,
                                   {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  tensors = EvaluateNodes(output, fetch_nodes,
                          {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
}

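// Same graph with "m" and "m2" fetched: the folded rank and size nodes are
// no longer needed and are pruned from the optimized graph.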
TEST_F(ConstantFoldingTest, SwitchNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  item.fetch.push_back("m");
  item.fetch.push_back("m2");

  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0"};
  std::set<string> not_present_nodes = {"rank", "size",
                                        "ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());

  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "i2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(2, found);

  auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
  v_ctrl_t.flat<bool>()(0) = true;
  auto tensors_expected = EvaluateNodes(
      item.graph, item.fetch, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);

  v_ctrl_t.flat<bool>()(0) = false;
  tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                   {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  tensors = EvaluateNodes(output, item.fetch,
                          {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
}

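// Merge folding: a Merge is foldable only if at least one of its inputs is a
// constant without control dependencies, since only then is it guaranteed to
// fire.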
TEST_F(ConstantFoldingTest, MergeNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
  Output y =
      ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
  Output const1 =
      ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
                 TensorShape({3, 5}));
  Output const2 =
      ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
  Output const3 =
      ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
                 TensorShape({3, 5}));

  // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
  ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
  ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
  ops::Merge m3(scope.WithOpName("m3"), {x, y});
  // m4 is not foldable because the only constant input
  // has a control input, so we cannot know if it will be
  // triggered.
  ops::Merge m4(scope.WithOpName("m4"), {x, const1});

  ops::Identity out1(scope.WithOpName("out1"), m1.output);
  ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
  ops::Identity out2(scope.WithOpName("out2"), m2.output);
  ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
  ops::Identity out3(scope.WithOpName("out3"), m3.output);
  ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
  ops::Identity out4(scope.WithOpName("out4"), m4.output);
  ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);

  GrapplerItem item;
  item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(19, output.node_size());
  int found_nodes = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "out1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1_index") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2:1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3:1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out4") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m4", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx4") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m4:1", node.input(0));
      ++found_nodes;
    }
  }
  // Make sure the graph contains all the nodes we're expecting: all ten
  // branches above must match exactly once.
  EXPECT_EQ(10, found_nodes);

  std::vector<string> fetch = {"out1", "idx1"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors.size());
  const Tensor& out_value = tensors[0];
  EXPECT_EQ(3 * 5, out_value.NumElements());
  for (int i = 0; i < 3 * 5; ++i) {
    EXPECT_EQ(3.14f, out_value.flat<float>()(i));
  }
  const Tensor& out_idx = tensors[1];
  EXPECT_EQ(1, out_idx.NumElements());
  EXPECT_EQ(2, out_idx.flat<int32>()(0));
}

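// A Split with num_split == 1 is a no-op and should be rewritten as an
// Identity of its input, keeping split_dim as a control dependency.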
TEST_F(ConstantFoldingTest, SplitRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
  auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
  ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
  ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);

  ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("split_dim", "Const", {}, {}, &want);
  AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
          &want);
  AddNode("s2", "Split", {"split_dim", "in2"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}

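// Likewise, a SplitV that produces a single slice covering the whole input
// degenerates to Identity, with the now-unused const inputs kept as control
// dependencies.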
TEST_F(ConstantFoldingTest, SplitVRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
  auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
  auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
  auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
  ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
  ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);

  ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("split_dim", "Const", {}, {}, &want);
  AddNode("size_splits1", "Const", {}, {}, &want);
  AddNode("size_splits2", "Const", {}, {}, &want);
  AddNode("s1", "Identity",
          {"in1", AsControlDependency("size_splits1"),
           AsControlDependency("split_dim")},
          {}, &want);
  AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}

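// A Transpose whose permutation only moves size-1 dimensions ({3, 1, 2, 0}
// on a {1, 4, 2, 1} tensor here) does not change the underlying data layout
// and can be replaced by Identity.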
TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
                             DT_FLOAT);
  Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
  ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
  ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
                    p2);

  ops::Add out1(scope.WithOpName("out1"), t1, t2);

  GrapplerItem item;
  item.fetch = {"out1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("p1", "Const", {}, {}, &want);
  AddNode("p2", "Const", {}, {}, &want);
  AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
  AddNode("t2", "Identity",
          {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
          &want);
  AddNode("out1", "Add", {"t1", "t2"}, {}, &want);

  CompareGraphs(want, got);
}

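// RandomShuffle of a scalar is a no-op, so it can be replaced by Identity
// even though the op is nondeterministic in general.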
TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
  ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
  ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
                        in2);

  ops::Add out1(scope.WithOpName("out1"), s1, s2);
  ops::Identity out2(scope.WithOpName("out2"), s2);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("s1", "Identity", {"in1"}, {}, &want);
  AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, {}, &want);
  AddNode("out1", "Add", {"s1", "s2"}, {}, &want);
  AddNode("out2", "Identity", {"s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(2, tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
}

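// A Reverse along size-1 dimensions only (axes {0, 3} of a {1, 2, 4, 1}
// tensor here) leaves the tensor unchanged and becomes Identity.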
TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
  ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
  ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
                  a2);

  ops::Add out1(scope.WithOpName("out1"), r1, r2);

  GrapplerItem item;
  item.fetch = {"out1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("a1", "Const", {}, {}, &want);
  AddNode("a2", "Const", {}, {}, &want);
  AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
  AddNode("r2", "Identity",
          {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
          &want);
  AddNode("out1", "Add", {"r1", "r2"}, {}, &want);

  CompareGraphs(want, got);
}

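// A Slice that covers the entire input, either with an explicit size or with
// size = -1 and a zero begin, is a no-op and folds to Identity.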
TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
  {  // size = {3, 5}
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
    auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
    auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
    Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
    ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
    ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin", "Const", {}, {}, &want);
    AddNode("size", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin"), AsControlDependency("size")},
            {}, &want);
    AddNode("s2", "Slice", {"in2", "begin", "size"}, {}, &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
  {  // size = {-1, -1}
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 =
        ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
    auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
    auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
    auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
    Output in2 =
        ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
    ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
    ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin1", "Const", {}, {}, &want);
    AddNode("begin2", "Const", {}, {}, &want);
    AddNode("size", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
            {}, &want);
    AddNode("s2", "Slice", {"in2", "begin2", "size"}, {}, &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
}

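// A StridedSlice that selects the whole input, including when it does so via
// begin/end/ellipsis masks, is likewise rewritten as Identity.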
TEST_F(ConstantFoldingTest,StridedSliceWithSameDimensionRemoval)2369 TEST_F(ConstantFoldingTest, StridedSliceWithSameDimensionRemoval) {
2370   {  // no mask
2371     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2372 
2373     auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5, 2}, DT_FLOAT);
2374     auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2375     auto end = ops::Const(scope.WithOpName("end"), {3, 5}, {2});
2376     auto strides = ops::Const(scope.WithOpName("strides"), {1, 1}, {2});
2377     Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6, 2}, DT_FLOAT);
2378     ops::StridedSlice s1(scope.WithOpName("s1"), in1, begin, end, strides);
2379     ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin, end, strides);
2380 
2381     ops::Add out(scope.WithOpName("out"), s1, s2);
2382 
2383     GrapplerItem item;
2384     item.fetch = {"out"};
2385     TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2386 
2387     ConstantFolding optimizer(/*cpu_device=*/nullptr);
2388     GraphDef got;
2389     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2390     TF_EXPECT_OK(status);
2391 
2392     GraphDef want;
2393     AddNode("in1", "VariableV2", {}, {}, &want);
2394     AddNode("in2", "VariableV2", {}, {}, &want);
2395     AddNode("begin", "Const", {}, {}, &want);
2396     AddNode("end", "Const", {}, {}, &want);
2397     AddNode("strides", "Const", {}, {}, &want);
2398     AddNode("s1", "Identity",
2399             {"in1", AsControlDependency("begin"), AsControlDependency("end"),
2400              AsControlDependency("strides")},
2401             {}, &want);
2402     AddNode("s2", "StridedSlice", {"in2", "begin", "end", "strides"}, {},
2403             &want);
2404     AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2405 
2406     CompareGraphs(want, got);
2407 
2408     auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 2}));
2409     auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6, 2}));
2410     auto tensors_expected =
2411         EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2412     EXPECT_EQ(1, tensors_expected.size());
2413     auto tensors =
2414         EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2415     EXPECT_EQ(1, tensors.size());
2416     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2417   }
  {  // with begin/end/ellipsis mask
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    // s1 = in1[:, ..., 0:5, 0:6]
    auto in1 =
        ops::Variable(scope.WithOpName("in1"), {2, 3, 4, 5, 6}, DT_FLOAT);
    auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0, 0}, {3});
    auto end1 = ops::Const(scope.WithOpName("end1"), {0, 5, 6}, {3});
    auto strides1 = ops::Const(scope.WithOpName("strides1"), {1, 1, 1}, {3});
    ops::StridedSlice s1(
        scope.WithOpName("s1"), in1, begin1, end1, strides1,
        ops::StridedSlice::Attrs().BeginMask(1).EndMask(1).EllipsisMask(2));

    Output in2 =
        ops::Variable(scope.WithOpName("in2"), {5, 8, 5, 6, 9}, DT_FLOAT);
    auto begin2 = ops::Const(scope.WithOpName("begin2"), {0, 0, 0, 0, 0}, {5});
    auto end2 = ops::Const(scope.WithOpName("end2"), {2, 3, 4, 5, 6}, {5});
    auto strides2 =
        ops::Const(scope.WithOpName("strides2"), {1, 1, 1, 1, 1}, {5});
    ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin2, end2, strides2);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin1", "Const", {}, {}, &want);
    AddNode("end1", "Const", {}, {}, &want);
    AddNode("strides1", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin1"), AsControlDependency("end1"),
             AsControlDependency("strides1")},
            {}, &want);
    AddNode("begin2", "Const", {}, {}, &want);
    AddNode("end2", "Const", {}, {}, &want);
    AddNode("strides2", "Const", {}, {}, &want);
    AddNode("s2", "StridedSlice", {"in2", "begin2", "end2", "strides2"}, {},
            &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3, 4, 5, 6}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 8, 5, 6, 9}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
}

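// Tiling with all multiples equal to 1 is a no-op, so constant folding should
// rewrite such a Tile to Identity (with a control dependency on the multiples
// input), while a Tile with real multiples is left alone.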
TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
  auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
  auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
  auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});

  ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
  ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);

  ops::Add out(scope.WithOpName("out"), t1, t2);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("multiplies1", "Const", {}, {}, &want);
  AddNode("multiplies2", "Const", {}, {}, &want);
  AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, {},
          &want);
  AddNode("t2", "Tile", {"in2", "multiplies2"}, {}, &want);
  AddNode("out", "Add", {"t1", "t2"}, {}, &want);

  CompareGraphs(want, got);
}

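// Two cascaded ConcatV2 nodes with the same axis can be merged into a single
// ConcatV2 over all of the inputs.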
TEST_F(ConstantFoldingTest, MergeConcat) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
  Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
  Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output axis = ops::Const(scope.WithOpName("axis"), 0, {});

  ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
  ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("in3", "VariableV2", {}, {}, &want);
  AddNode("axis", "Const", {}, {}, &want);
  AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);

  CompareGraphs(want, got);
}

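// Merging also works when the inner concat feeds the outer one several times:
// each occurrence is expanded in place into the inner concat's inputs.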
TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
  Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
  Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output axis = ops::Const(scope.WithOpName("axis"), 0, {});

  ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
  ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("in3", "VariableV2", {}, {}, &want);
  AddNode("axis", "Const", {}, {}, &want);
  AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
          &want);

  CompareGraphs(want, got);
}

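// Same merge as above with differently shaped operands; only the shared axis
// matters for the rewrite.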
TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
  Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
  Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output axis = ops::Const(scope.WithOpName("axis"), 0, {});

  ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
  ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("in3", "VariableV2", {}, {}, &want);
  AddNode("axis", "Const", {}, {}, &want);
  AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);

  CompareGraphs(want, got);
}

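// Concats along different axes must not be merged, since the merged node
// would not be equivalent; both ConcatV2 nodes should survive unchanged.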
TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
  Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
  Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
  Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});

  ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
  ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("in3", "VariableV2", {}, {}, &want);
  AddNode("axis1", "Const", {}, {}, &want);
  AddNode("axis2", "Const", {}, {}, &want);
  AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
  AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);

  CompareGraphs(want, got);
}

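// When a merged concat ends up with a run of constant inputs, that run is
// folded into a single constant and the non-constant inputs are kept.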
TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
  Scope scope = Scope::NewRootScope();
  Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
  Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
  Output c3 = ops::Const(scope.WithOpName("c3"), 3.0f, {2, 2});
  Output c4 = ops::Const(scope.WithOpName("c4"), 4.0f, {2, 2});
  Output ph = ops::Placeholder(scope.WithOpName("ph"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output axis = ops::Const(scope.WithOpName("axis"), 0, {});

  ops::Concat concat1(scope.WithOpName("concat1"), {c1, c2, ph}, axis);
  ops::Concat concat2(scope.WithOpName("concat2"), {c3, c4, Output(concat1)},
                      axis);

  GrapplerItem item;
  item.fetch = {"concat2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
  AddNode("axis", "Const", {}, {}, &want);
  AddNode("ph", "Placeholder", {}, {}, &want);
  AddNode("concat2", "ConcatV2",
          {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);

  CompareGraphs(want, got);
}

TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
  PaddingWithZeroSize<int32>();
  PaddingWithZeroSize<int64>();
}

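// Squeeze on an input whose dimensions are all greater than one is a no-op
// and should be rewritten to Identity; an input with squeezable dimensions
// must keep its Squeeze.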
TEST_F(ConstantFoldingTest, SqueezeWithAllDimensionsGreaterThanOne) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
  auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);

  ops::Squeeze s1(scope.WithOpName("s1"), in1);
  ops::Squeeze s2(scope.WithOpName("s2"), in2);

  ops::Add out(scope.WithOpName("out"), s1, s2);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("s1", "Identity", {"in1"}, {}, &want);
  AddNode("s2", "Squeeze", {"in2"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 3}));
  auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({1, 2, 3, 1}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, NoOpReduction) {
  // Build a simple graph with reductions that can be reduced to the
  // identity.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
  Output i = ops::Identity(scope.WithOpName("i"), c);
  Output p = ops::Prod(scope.WithOpName("p"), v, i);
  Output s = ops::Square(scope.WithOpName("s"), p);

  Output v2 = ops::Variable(scope.WithOpName("v2"), {3, 5, 1}, DT_FLOAT);
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(v), 2, {1});
  ops::Prod::Attrs attr;
  attr = attr.KeepDims(true);
  Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);

  // Test with unknown input shape.
  Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT);
  Output p3 = ops::Prod(scope.WithOpName("p3"), a, i, attr);

  GrapplerItem item;
  item.fetch = {"s", "p2", "p3"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p") {
      found++;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("v", node.input(0));
      EXPECT_EQ("^i", node.input(1));
    } else if (node.name() == "p2") {
      found++;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("v2", node.input(0));
      EXPECT_EQ("^c2", node.input(1));
    } else if (node.name() == "p3") {
      found++;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("a", node.input(0));
      EXPECT_EQ("^i", node.input(1));
    }
  }
  EXPECT_EQ(3, found);

  auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                        {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
  EXPECT_EQ(3, tensors_expected.size());
  auto tensors =
      EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
  EXPECT_EQ(3, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
  test::ExpectTensorNear<float>(tensors_expected[2], tensors[2], 1e-5);
}

TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
  // Build a simple graph with reductions that involve single-element input and
  // no axes to reduce along.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output input_var_three_dim = ops::Variable(
      scope.WithOpName("input_var_three_dim"), {1, 1, 1}, DT_FLOAT);
  Output input_var_one_dim =
      ops::Variable(scope.WithOpName("input_var_one_dim"), {1}, DT_FLOAT);
  Output one_axis = ops::Const(scope.WithOpName("one_axis"), {0}, {1});
  Output multiple_axes =
      ops::Const(scope.WithOpName("multiple_axes"), {1, 0}, {2});
  Output variable_axis =
      ops::Variable(scope.WithOpName("input_var_axis"), {1}, DT_INT32);
  ops::Mean::Attrs attr;
  attr = attr.KeepDims(false);
  // Should be optimized to Reshape.
  Output mean_1 = ops::Mean(scope.WithOpName("mean_1"), input_var_three_dim,
                            one_axis, attr.KeepDims(false));
  Output mean_2 = ops::Mean(scope.WithOpName("mean_2"), input_var_three_dim,
                            multiple_axes, attr.KeepDims(false));
  // Should remain as-is, since OutputProperties will not be known for this node.
  Output mean_3 = ops::Mean(scope.WithOpName("mean_3"), input_var_one_dim,
                            one_axis, attr.KeepDims(false));
  // Should remain as-is.
  Output mean_4 = ops::Mean(scope.WithOpName("mean_4"), input_var_three_dim,
                            variable_axis, attr.KeepDims(false));
  // Should be optimized to Identity, since KeepDims=true.
  Output mean_5 = ops::Mean(scope.WithOpName("mean_5"), input_var_three_dim,
                            multiple_axes, attr.KeepDims(true));

  GrapplerItem item;
  item.fetch = {"mean_1", "mean_2", "mean_3", "mean_4", "mean_5"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Ensure Mean node is optimized to Reshape.
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "mean_1" || node.name() == "mean_2") {
      found++;
      EXPECT_EQ("Reshape", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("input_var_three_dim", node.input(0));
    } else if (node.name() == "mean_3") {
      found++;
      EXPECT_EQ("Mean", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("input_var_one_dim", node.input(0));
    } else if (node.name() == "mean_4") {
      found++;
      EXPECT_EQ("Mean", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("input_var_three_dim", node.input(0));
    } else if (node.name() == "mean_5") {
      found++;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^multiple_axes", node.input(1));
    }
  }
  EXPECT_EQ(5, found);

  // Ensure resultant values from Mean and Reshape are the same.
  auto input_var_three_dim_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto input_var_one_dim_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
  Tensor input_var_axis_t(DT_INT32, TensorShape({1}));
  input_var_axis_t.flat<int32>()(0) = 0;
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch,
                    {{"input_var_three_dim", input_var_three_dim_t},
                     {"input_var_one_dim", input_var_one_dim_t},
                     {"input_var_axis", input_var_axis_t}});
  EXPECT_EQ(5, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"input_var_three_dim", input_var_three_dim_t},
                                {"input_var_one_dim", input_var_one_dim_t},
                                {"input_var_axis", input_var_axis_t}});
  EXPECT_EQ(5, tensors.size());
  for (int i = 0; i < 5; ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
  }
}

TEST_F(ConstantFoldingTest, NoOpReshape) {
  // Build a simple graph with a reshape that can be reduced to the identity.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // A reshape that can be optimized
  Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
  Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
  Output c1 =
      ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
  Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
  Output r1 =
      ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
  Output s1 = ops::Square(scope.WithOpName("s1"), r1);

  // A multi-dimensional reshape that can be optimized
  Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
  Output c3 =
      ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
  Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
  Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
  Output s3 = ops::Square(scope.WithOpName("s3"), r3);

  // A multi-dimensional, partially defined reshape that can be optimized
  Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
  Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
                         {5, -1, 5}, {3});
  Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
  Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
  Output s4 = ops::Square(scope.WithOpName("s4"), r4);

  // A reshape that can't be optimized
  Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
  Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
  Output s2 = ops::Square(scope.WithOpName("s2"), r2);

  GrapplerItem item;
  item.fetch = {"s1", "s2", "s3", "s4"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "r1") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("^i1", node.input(1));
      EXPECT_EQ("^d1", node.input(2));
    } else if (node.name() == "r3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v3", node.input(0));
      EXPECT_EQ("^i3", node.input(1));
    } else if (node.name() == "r4") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v4", node.input(0));
      EXPECT_EQ("^i4", node.input(1));
    } else if (node.name() == "r2") {
      ++found;
      EXPECT_EQ("Reshape", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v2", node.input(0));
      EXPECT_EQ("c2", node.input(1));
    }
  }
  EXPECT_EQ(4, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17, 1}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
  auto v4_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch,
                    {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
  EXPECT_EQ(4, tensors_expected.size());
  auto tensors =
      EvaluateNodes(output, item.fetch,
                    {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
  EXPECT_EQ(4, tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, Packing) {
  // Build a simple graph with a large constant that can be folded.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
  Output i1 = ops::Identity(scope.WithOpName("i1"), c);
  Output i2 = ops::Identity(scope.WithOpName("i2"), c);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  const std::vector<string> fetch_nodes = {"i1", "i2"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
  EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes);
  EXPECT_EQ(fetch_nodes.size(), tensors.size());
  for (int i = 0; i < fetch_nodes.size(); i++)
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);

  // Make sure that the representation of the folded constant is space
  // efficient: in particular, the whole message should be smaller than 8k
  // (the size needed to naively encode 1000 floats folded twice).
  EXPECT_GT(8000, output.ByteSizeLong());
}

TEST_F(ConstantFoldingTest, LargeConstantNoSizeIncrease) {
  // Build a simple graph with a large constant with size greater than
  // kMaxConstantSize that can be folded because the resulting size does not
  // increase.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const int64 large_constant_size = kMaxConstantSize + 1;
  Output a = ops::Variable(scope.WithOpName("a"), {1, 1}, DT_FLOAT);
  Output b_const =
      ops::Const(scope.WithOpName("b_const"), 3.14f, {1, large_constant_size});
  Output b = ops::Identity(scope.WithOpName("b"), b_const);
  Output matmul = ops::MatMul(scope.WithOpName("matmul"), a, b);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  for (const auto& node : output.node()) {
    if (node.name() == "b") {
      EXPECT_EQ("Const", node.op());
    }
  }
  EXPECT_EQ(4, output.node_size());
  EXPECT_LT(output.ByteSizeLong(), sizeof(float) * large_constant_size + 500);
}

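// BroadcastGradientArgs is materialized when its outputs are statically known
// (identical symbolic input shapes reduce along no dimensions, yielding empty
// constants); with genuinely unknown shapes the node is left untouched.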
TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b = ops::Square(s.WithOpName("b"), a);
  Output c = ops::Mul(s.WithOpName("c"), a, b);
  Output d = ops::Shape(s.WithOpName("d"), a);
  Output e = ops::Shape(s.WithOpName("e"), b);

  auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
  Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
  Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);

  Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
                              ops::Placeholder::Shape(PartialTensorShape({1})));
  Output h = ops::Shape(s.WithOpName("h"), g);
  auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
  Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
  Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
  auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
  auto tensors_expected =
      EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());

  // Run a second time to make sure the optimization is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "o1") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
    } else if (node.name() == "o2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
    } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^f", node.input(0));
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^f", node.input(0));
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "p1") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("i", node.input(0));
    } else if (node.name() == "p2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("i:1", node.input(0));
    }
  }
  EXPECT_EQ(6, found);

  auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors.size());
  for (int i = 0; i < fetch_nodes.size(); i++)
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}

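// Guards against a rematerialization loop: once the shape inputs themselves
// are folded to constants, a second optimizer pass must reach a fixed point
// instead of growing the graph (hence the exact node-count check below).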
TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({2, 2})));
  Output b = ops::Square(s.WithOpName("b"), a);
  Output c = ops::Mul(s.WithOpName("c"), a, b);
  Output d = ops::Shape(s.WithOpName("d"), a);
  Output e = ops::Shape(s.WithOpName("e"), b);

  auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
  Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
  Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  std::vector<string> fetch_nodes = {"o1", "o2"};
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Run a second time to make sure the optimization is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(11, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "ConstantFolding/f-folded-1") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^a", node.input(0));
      EXPECT_EQ("^b", node.input(1));
    } else if (node.name() == "d") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^a", node.input(0));
    } else if (node.name() == "e") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^b", node.input(0));
    } else if (node.name() == "o1") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
    } else if (node.name() == "o2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
    } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    }
  }
  EXPECT_EQ(7, found);
  auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors.size());
  for (int i = 0; i < fetch_nodes.size(); i++)
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}

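// When the input rank is known and the indices cover every dimension, the
// reduction indices themselves can be materialized as a constant.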
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
  for (bool use_reshape : {true, false}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output input =
        ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                         ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
    // If use_reshape is false, we need to know the number of indices to apply
    // the rewrite.
    Output indices = ops::Placeholder(
        s.WithOpName("indices"), DT_INT32,
        ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
    Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
    if (use_reshape) {
      Output size = ops::Const(s.WithOpName("size"), 1, {1});
      Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
    }

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch.push_back(use_reshape ? "reshape" : "sum");

    auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
    Tensor indices_t(DT_INT32, TensorShape({2}));
    indices_t.flat<int>()(0) = 0;
    indices_t.flat<int>()(1) = 1;
    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
    EXPECT_EQ(1, tensors_expected.size());

    // Use aggressive mode to force the shape inference to propagate placeholder
    // shapes.
    ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                              /*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    // Run a second time to make sure the optimization is idempotent.
    item.graph.Swap(&output);
    status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    int found = 0;
    for (const auto& node : output.node()) {
      if (node.name() == "ConstantFolding/sum-reduction_indices") {
        ++found;
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^indices", node.input(0));
        EXPECT_EQ(2,
                  TensorShape(node.attr().at("value").tensor().tensor_shape())
                      .num_elements());
      } else if (node.name() == "sum") {
        ++found;
        EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
      } else if (node.name() == "indices") {
        ++found;
      }
    }
    EXPECT_EQ(3, found);

    auto tensors = EvaluateNodes(output, item.fetch,
                                 {{"input", input_t}, {"indices", indices_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
}

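// If the reduction cannot be proven to cover all dimensions (unknown input
// rank, or fewer indices than dimensions), the indices must not be
// materialized and the graph is left unchanged.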
TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
  for (bool input_rank_known : {true, false}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output input =
        (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                             ops::Placeholder::Shape(
                                                 PartialTensorShape({-1, -1})))
                          : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
    Output indices =
        ops::Placeholder(s.WithOpName("indices"), DT_INT32,
                         ops::Placeholder::Shape(
                             PartialTensorShape({input_rank_known ? 1 : 2})));
    Output sum = ops::Sum(s.WithOpName("sum"), input, indices);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch.push_back("sum");

    // Use aggressive mode to force the shape inference to propagate placeholder
    // shapes.
    ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                              /*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    CompareGraphs(item.graph, output);
  }
}

TEST_F(ConstantFoldingTest, LargeConstant) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Generate a 4k by 4k constant, non-compressible matrix.
  Output mat_diag =
      ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
  Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
  Output out = ops::Identity(scope.WithOpName("out"), mat);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("out");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Make sure the diag node hasn't been folded, since it would use too much
  // memory to encode the corresponding constant.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "mat");
      ++found;
    } else if (node.name() == "mat") {
      EXPECT_EQ(node.op(), "Diag");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "mat_diag");
      ++found;
    }
  }
  EXPECT_EQ(found, 2);
  // The output should be no larger than the size of the constant "mat_diag"
  // plus a small constant amount for the remaining nodes.
  EXPECT_LT(output.ByteSizeLong(), sizeof(int) * 4 * 1024 + 500);

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

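// A Switch whose predicate is also its data input has statically known
// outputs: output_false is only live when the value is false and output_true
// when it is true, so the LogicalNot consumers fold to constants anchored by
// control dependencies on the corresponding live Switch output.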
TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
                              ops::Placeholder::Shape(TensorShape({})));
  ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
  Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
  Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);

  GrapplerItem item;
  item.fetch.push_back("id_false");
  item.fetch.push_back("id_true");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(6, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "switch" || node.name() == "x") {
      ++found;
    }
    if (node.name() == "id_false") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
      ++found;
    }
    if (node.name() == "id_true") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      ++found;
    }
    if (node.name() == "ConstantFoldingCtrl/switch_0") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch", node.input(0));
      ++found;
    }
    if (node.name() == "ConstantFoldingCtrl/switch_1") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(6, found);

  // Evaluate id_true when input tensor x is true.
  Tensor x_t(DT_BOOL, TensorShape({}));
  x_t.flat<bool>()(0) = true;
  auto tensors_expected = EvaluateNodes(item.graph, {"id_true"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, {"id_true"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);

  // Evaluate id_false when input tensor is false.
  x_t.flat<bool>()(0) = false;
  tensors_expected = EvaluateNodes(item.graph, {"id_false"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors_expected.size());
  tensors = EvaluateNodes(output, {"id_false"}, {{"x", x_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
}

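// AddN and AccumulateNV2 are associative and commutative, so any subset of
// two or more constant inputs can be pre-summed into a single constant while
// the non-constant inputs are preserved.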
TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
  std::function<Output(const Scope&, InputList)> addn_fun =
      [](const Scope& scope, InputList inputs) {
        return ops::AddN(scope, inputs);
      };
  std::function<Output(const Scope&, InputList)> accumulate_fun =
      [](const Scope& scope, InputList inputs) {
        return ops::AccumulateNV2(scope, inputs, TensorShape({2, 2}));
      };
  for (bool use_add_n : {true, false}) {
    auto fun = use_add_n ? addn_fun : accumulate_fun;
    const string op_name = use_add_n ? "AddN" : "AccumulateNV2";
    Scope s = Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
    Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
    Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2, 2});
    Output acc0 = fun(s.WithOpName("acc0"), {c1, c2, c3});
    Output acc1 = fun(s.WithOpName("acc1"), {x, y, z});
    Output acc2 = fun(s.WithOpName("acc2"), {c1, x, y});
    Output acc3 = fun(s.WithOpName("acc3"), {c1, c2, z});
    Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
    Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
    Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
    Output stack = ops::Stack(s.WithOpName("stack"),
                              {acc0, acc1, acc2, acc3, acc4, acc5, acc6});

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(16, output.node_size());
    for (const NodeDef& node : output.node()) {
      if (node.name() == "acc0") {
        EXPECT_EQ("Const", node.op());
      }
      if (node.name() == "acc1") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(3, node.input_size());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("y", node.input(1));
        EXPECT_EQ("z", node.input(2));
      }
      if (node.name() == "acc2") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(3, node.input_size());
        EXPECT_EQ("c1", node.input(0));
        EXPECT_EQ("x", node.input(1));
        EXPECT_EQ("y", node.input(2));
      }
      if (node.name() == "acc3") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("ConstantFolding/acc3_partial_split_2", node.input(0));
        EXPECT_EQ("z", node.input(1));
      }
      if (node.name() == "acc4") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("ConstantFolding/acc4_partial_split_2", node.input(0));
        EXPECT_EQ("y", node.input(1));
      }
      if (node.name() == "acc5") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("ConstantFolding/acc5_partial_split_2", node.input(1));
      }
      if (node.name() == "acc6") {
        EXPECT_EQ(op_name, node.op());
        EXPECT_EQ(3, node.input_size());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("ConstantFolding/acc6_partial_split_2", node.input(1));
        EXPECT_EQ("y", node.input(2));
      }
      if (absl::StartsWith(node.name(), "ConstantFolding/")) {
        EXPECT_EQ("Const", node.op());
      }
    }

    std::vector<string> fetch = {"acc0"};
    auto tensors_expected = EvaluateNodes(item.graph, fetch);
    auto tensors = EvaluateNodes(output, fetch);
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  }
}

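// Concat is order-sensitive, so unlike AddN only contiguous runs of at least
// two constant inputs can be folded together; isolated or non-adjacent
// constants must stay in place.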
TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
  Scope s = Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
  Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
  Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
  Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
  Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
  Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
  Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
  Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
  Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
  Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
  Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
                "concat5", "concat6", "concat7", "concat8", "concat9"};

  auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
  EXPECT_EQ(1, tensors_expected.size());
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(21, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "concat0") {
      EXPECT_EQ("Const", node.op());
    } else if (node.name() == "concat3") {
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
      EXPECT_EQ("z", node.input(1));
      EXPECT_EQ("axis", node.input(2));
    } else if (node.name() == "concat5") {
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
      EXPECT_EQ("axis", node.input(2));
    } else if (node.name() == "concat7") {
      EXPECT_EQ(4, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
      EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
      EXPECT_EQ("axis", node.input(3));
    } else if (node.name() == "concat8") {
      EXPECT_EQ(4, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
      EXPECT_EQ("y", node.input(2));
      EXPECT_EQ("axis", node.input(3));
    } else if (node.name() == "concat9") {
      EXPECT_EQ(4, node.input_size());
      EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
      EXPECT_EQ("x", node.input(1));
      EXPECT_EQ("y", node.input(2));
      EXPECT_EQ("axis", node.input(3));
    } else if (absl::StartsWith(node.name(), "ConstantFolding/")) {
      EXPECT_EQ("Const", node.op());
    } else {
      EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
    }
  }

  auto tensors = EvaluateNodes(output, {"concat0"});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}

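// Only the constant outputs of IdentityN can be folded; consumers of the
// non-constant output are left alone, and each folded consumer picks up a
// control dependency on id_n.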
TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({})));
  Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
  Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
  auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {c1, x, c2});
  auto id0 = ops::Identity(scope.WithOpName("id0"), id_n[0]);
  auto id1 = ops::Identity(scope.WithOpName("id1"), id_n[1]);
  auto add0 = ops::Add(scope.WithOpName("add0"), id_n[0], id_n[1]);
  auto add1 = ops::Add(scope.WithOpName("add1"), id_n[0], id_n[2]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("id0");
  item.fetch.push_back("id1");
  item.fetch.push_back("add0");
  item.fetch.push_back("add1");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(8, output.node_size());
  for (const auto& node : output.node()) {
    // id_n should remain unchanged.
    if (node.name() == "id_n") {
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("x", node.input(1));
      EXPECT_EQ("c2", node.input(2));
    }
    // id0 should be constant folded and have a control dependency on id_n.
    if (node.name() == "id0") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^id_n", node.input(0));
    }
    // id1 is unchanged.
    if ("id1" == node.name()) {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("id_n:1", node.input(0));
    }

    if ("add0" == node.name()) {
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("id_n:1", node.input(1));
    }
    // add1 should be constant folded and have a control dependency on id_n.
    if ("add1" == node.name()) {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^id_n", node.input(0));
    }
  }

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
  EXPECT_EQ(4, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  EXPECT_EQ(4, tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
  }
}

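// Stacking a single tensor just inserts a dimension of size one, so a
// Pack/Stack with one input is rewritten to ExpandDims with a materialized
// axis constant.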
TEST_F(ConstantFoldingTest, TrivialPack) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
  Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
  auto stack =
      ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
                 ops::Stack::Axis(1));
  auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"stack", "stack_no_axis"};

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(7, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "stack") {
      EXPECT_EQ("ExpandDims", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
      EXPECT_EQ("^y", node.input(2));
      ++found;
    } else if (node.name() == "stack_no_axis") {
      EXPECT_EQ("ExpandDims", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
      ++found;
    } else if (node.name() == "ConstantFolding/stack_const_axis") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^x", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(found, 3);

  std::vector<string> fetch = {"stack", "stack_no_axis"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors_expected.size());
  EXPECT_EQ(2, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
  EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}

3709 // The test does not evalaute the optimized and original graphs to check if
3710 // their outputs are the same. See b/78233179.
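// Only Identities reading an Enter that is marked is_constant and fed by a
// Const should be folded; consumers of other Enter nodes are left untouched.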
TEST_F(ConstantFoldingTest, Enter) {
  GrapplerItem item;
  AttrValue frame_name;
  frame_name.set_s("foo");
  AttrValue is_constant_true;
  is_constant_true.set_b(true);
  AttrValue is_constant_false;
  is_constant_false.set_b(false);
  AttrValue type;
  type.set_type(DT_FLOAT);
  AttrValue value;
  Tensor value_tensor(DT_FLOAT, TensorShape({}));
  value_tensor.flat<float>()(0) = 1;
  value_tensor.AsProtoTensorContent(value.mutable_tensor());

  GraphDef& graph = item.graph;
  AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
  AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
  AddNode("enter1", "Enter", {"x"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_true}},
          &graph);
  AddNode("enter2", "Enter", {"c1"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_true}},
          &graph);
  AddNode("enter3", "Enter", {"c1"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_false}},
          &graph);
  AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
  AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
  AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
  AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");
  item.fetch.push_back("id3");
  item.fetch.push_back("id4");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(9, output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("enter1", node.input(0));
    }
    if (node.name() == "id2" || node.name() == "id3") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^enter2", node.input(0));
    }
    if (node.name() == "id4") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("enter3", node.input(0));
    }
  }
}

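// TensorArraySize of a statically sized TensorArray folds to a Const, while
// dynamically sized arrays and handles of unknown provenance are left alone.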
TEST_F(ConstantFoldingTest, TensorArraySize) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
  Output placeholder =
      ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
                       ops::Placeholder::Shape(TensorShape({2})));
  Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
  auto dynamic_array =
      ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(true));
  auto static_array =
      ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(false));
  auto dynamic_sz = ops::TensorArraySize(
      scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
  auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
                                        static_array.handle, static_array.flow);
  auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
                                             placeholder, foo);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected =
      EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  EXPECT_EQ("dynamic_sz", output.node(5).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
  EXPECT_EQ("static_sz", output.node(6).name());
  EXPECT_EQ("Const", output.node(6).op());
  EXPECT_EQ("placeholder_sz", output.node(7).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(7).op());

  auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
  EXPECT_EQ(2, tensors_expected.size());
  EXPECT_EQ(2, tensors_actual.size());
  test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
  test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}

TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
  // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ.
  // Make sure constant folding behaves the same way as TensorFlow.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
  Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
  Output c = ops::Mul(s.WithOpName("c"), a, b);

  GrapplerItem item;
  item.fetch.push_back("c");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("c", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"c"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

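// Regression test: concatenations involving large constants should neither be
// folded into oversized constants nor send the optimizer into a rewrite loop.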
TEST_F(ConstantFoldingTest, EvaluatingLargeConstantNoFoldingMergingLoop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  int size = 10 * 1024 * 1024 / 4 / 2;
  Output nonconst =
      ops::RandomUniform(s.WithOpName("nonconst"), {size, 1}, DT_FLOAT);
  Output const1 = ops::Const(s.WithOpName("const1"), 0.0f, {size, 1});
  Output const2 = ops::Const(s.WithOpName("const2"), 1.0f, {size, 1});
  Output axis = ops::Const(s.WithOpName("axis"), -1, {});
  Output concat1 =
      ops::Concat(s.WithOpName("concat1"), {nonconst, const1}, axis);
  Output result = ops::Concat(s.WithOpName("result"), {concat1, const2}, axis);

  GrapplerItem item;
  item.fetch.push_back("result");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> fetch = {"result"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
}

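// Harness for Const->Cast folding: builds a small Const/Cast graph with
// Identity consumers, fetches every combination of nodes, and checks that the
// optimized graph contains exactly one (folded) node per fetched output.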
class ConstantFoldingCastConstTest : public GrapplerTest {
 protected:
  void ConstantFoldingCastConst(bool fetch_const, bool fetch_cast,
                                bool fetch_const_child, bool fetch_cast_child) {
    if (!fetch_const && !fetch_cast && !fetch_const_child &&
        !fetch_cast_child) {
      return;
    }

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    CreateCastConstGraph(s);
    GrapplerItem item;
    int expected_output_size = SetFetch(&item, fetch_const, fetch_cast,
                                        fetch_const_child, fetch_cast_child);
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    GraphDef output = ConstantFoldingOptimize(item);
    EXPECT_EQ(expected_output_size, output.node_size());

    EvaluateAndCompareUnoptimized(item.graph, output, item.fetch);
  }

 private:
  void CreateCastConstGraph(const tensorflow::Scope& s) {
    Output const1 = ops::Const(s.WithOpName("const1"), 2, {5, 5});
    Output cast = ops::Cast(s.WithOpName("cast"), const1, DT_FLOAT);
    Output const1_child = ops::Identity(s.WithOpName("const1_child"), const1);
    Output cast_child = ops::Identity(s.WithOpName("cast_child"), cast);
  }

  int SetFetch(GrapplerItem* item, bool fetch_const, bool fetch_cast,
               bool fetch_const_child, bool fetch_cast_child) {
    int expected_output_size = 0;
    if (fetch_const) {
      item->fetch.push_back("const1");
      expected_output_size++;
    }
    if (fetch_cast) {
      item->fetch.push_back("cast");
      expected_output_size++;
    }
    if (fetch_const_child) {
      item->fetch.push_back("const1_child");
      expected_output_size++;
    }
    if (fetch_cast_child) {
      item->fetch.push_back("cast_child");
      expected_output_size++;
    }
    return expected_output_size;
  }

  GraphDef ConstantFoldingOptimize(const GrapplerItem& item) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    return output;
  }

  void EvaluateAndCompareUnoptimized(const GraphDef& unoptimized_graph,
                                     const GraphDef& optimized_graph,
                                     const std::vector<string>& fetch_nodes) {
    auto tensors_expected = EvaluateNodes(unoptimized_graph, fetch_nodes);
    auto tensors = EvaluateNodes(optimized_graph, fetch_nodes);
    ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
    ASSERT_EQ(fetch_nodes.size(), tensors.size());
    for (int i = 0; i < fetch_nodes.size(); i++) {
      if (fetch_nodes[i] == "const1" || fetch_nodes[i] == "const1_child") {
        test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
      } else {
        test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
      }
    }
  }
};

TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
  for (bool fetch_const : {false, true}) {
    for (bool fetch_cast : {false, true}) {
      for (bool fetch_const_child : {false, true}) {
        for (bool fetch_cast_child : {false, true}) {
          ConstantFoldingCastConst(fetch_const, fetch_cast, fetch_const_child,
                                   fetch_cast_child);
        }
      }
    }
  }
}

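// OnesLike, ZerosLike, and Fill with known shapes and values should be
// materialized as Const nodes, keeping control dependencies on their inputs.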
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 6);
  for (const auto& node : output.node()) {
    if (node.name() != "x") {
      EXPECT_EQ(node.op(), "Const");
    }
    if (node.name() == "ones_like" || node.name() == "zeros_like") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "^x");
    }
    if (node.name() == "fill") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  ASSERT_EQ(item.fetch.size(), tensors.size());
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    if (item.fetch[i] == "fill") {
      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
    }
  }
}

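// With compressed tensor optimization disabled, the same constant-valued
// nodes must be left in their original form.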
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeDisableCompression) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});

  ConstantFolding optimizer(/*cpu_device=*/nullptr,
                            /*disable_compressed_tensor_optimization=*/true);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 6);
  for (const auto& node : output.node()) {
    if (node.name() == "ones_like") {
      EXPECT_EQ(node.op(), "OnesLike");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
    if (node.name() == "zeros_like") {
      EXPECT_EQ(node.op(), "ZerosLike");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
    if (node.name() == "fill") {
      EXPECT_EQ(node.op(), "Fill");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "Const/Const");
      EXPECT_EQ(node.input(1), "Const_1/Const");
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  ASSERT_EQ(item.fetch.size(), tensors.size());
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    if (item.fetch[i] == "fill") {
      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
    }
  }
}

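// A Fill whose output is far too large to materialize should still be
// rewritten as a Const, which requires a compact constant-valued
// representation rather than a fully expanded tensor.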
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeHugeFill) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output value = ops::Const(scope.WithOpName("value"), 42, {});
  Output shape_const = ops::Const(scope.WithOpName("shape"),
                                  {1024, 1024, 1024, 1024, 1024}, {5});
  Output fill_huge =
      ops::Fill(scope.WithOpName("fill_huge"), shape_const, value);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Manually convert the input value format to tensor_content to test this
  // case.
  NodeDef* node = item.graph.mutable_node(0);
  ASSERT_EQ(node->name(), "value");
  TensorProto* t = (*node->mutable_attr())["value"].mutable_tensor();
  t->clear_int_val();
  int val = 42;
  port::CopyFromArray(t->mutable_tensor_content(),
                      reinterpret_cast<const char*>(&val), sizeof(int));
  item.fetch = {"fill_huge"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 3);
  for (const auto& node : output.node()) {
    EXPECT_EQ(node.op(), "Const");
    if (node.name() == "fill_huge") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "^shape");
      EXPECT_EQ(node.input(1), "^value");
    }
  }
}

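// Bitcasting int64 -> float -> int64 must round-trip bit-exactly; folding the
// intermediate float tensor must not flush denormal values to zero.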
TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Tensor x_t(DT_INT64, TensorShape({2, 2}));
  x_t.flat<int64>()(0) = 9223372036854775807L;
  x_t.flat<int64>()(1) = 1L;
  x_t.flat<int64>()(2) = 9223372036854775807L;
  x_t.flat<int64>()(3) = 1L;
  Output x = ops::Const(scope.WithOpName("x"), x_t);
  Output y = ops::Bitcast(scope.WithOpName("y"), x, DT_FLOAT);
  Output z = ops::Bitcast(scope.WithOpName("z"), y, DT_INT64);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"z"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  ASSERT_EQ(output.node_size(), 1);
  const NodeDef& node = output.node(0);
  EXPECT_EQ(node.name(), "z");
  EXPECT_EQ(node.op(), "Const");

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 1);
  ASSERT_EQ(tensors_expected.size(), 1);
  test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
}

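// A Case op with a constant branch index should be lowered to a
// PartitionedCall of the selected branch function.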
TEST_F(ConstantFoldingTest, SimplifyCase) {
  using test::function::NDef;

  for (int index = 0; index < 2; ++index) {
    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
    GrapplerItem item;
    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
    AttrValue branches;
    auto* f = branches.mutable_list()->add_func();
    f->set_name("XTimesTwo");
    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
    auto* g = branches.mutable_list()->add_func();
    *g = *f;
    g->set_name("NonZero");

    // Add a pair of somewhat arbitrary output shapes to test that they are
    // correctly propagated to the _output_shapes attribute.
    AttrValue output_shapes;
    // The first shape is a scalar.
    output_shapes.mutable_list()->add_shape();
    // The second shape is unknown.
    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
    g_shape->set_unknown_rank(true);

    const Tensor kZero = test::AsScalar<int32>(0);
    const Tensor kOne = test::AsScalar<int32>(1);
    item.graph = test::function::GDef(
        {NDef("one", "Const", {},
              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
              kDevice),
         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("case", "Case", {"one", "x"},
              {{"Tin", DataTypeSlice{DT_FLOAT}},
               {"Tout", DataTypeSlice{DT_FLOAT}},
               {"branches", branches},
               {"output_shapes", output_shapes}},
              kDevice),
         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
        // FunctionLib
        {
            test::function::XTimesTwo(),
            test::function::NonZero(),
        });
    VLOG(1) << "Before: " << item.graph.DebugString();

    item.fetch = {"y"};
    const Tensor kTwo = test::AsScalar<float>(2.0f);
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef optimized_graph;
    TF_ASSERT_OK(
        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
    VLOG(1) << "After: " << optimized_graph.DebugString();

    int pco_count = 0;
    for (const auto& node : optimized_graph.node()) {
      EXPECT_NE(node.op(), "Case");
      if (node.op() == "PartitionedCall") {
        ++pco_count;
        const auto& shape_list = node.attr().at("_output_shapes").list();
        ASSERT_EQ(shape_list.shape_size(), 1);
        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
        if (index == 0) {
          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
        } else {
          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
        }
      }
    }
    EXPECT_EQ(pco_count, 1);

    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}

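// A SelectV2 whose predicate is a known constant should simplify to an
// Identity of the taken branch, with control dependencies on the other inputs.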
TEST_F(ConstantFoldingTest, SimplifySelect) {
  for (bool scalar_pred : {true, false}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if (scalar_pred) {
        if_t.reset(new Tensor(DT_BOOL, TensorShape()));
      } else {
        if_t.reset(new Tensor(DT_BOOL, TensorShape({2, 2})));
      }
      for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2}));
      const Tensor kTwo =
          test::AsTensor<float>({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

      ASSERT_EQ(optimized_graph.node_size(), 5);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "Identity");
          ASSERT_EQ(node.input_size(), 3);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then");
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

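// When the constant predicate implies broadcasting, the SelectV2 should be
// rewritten as a BroadcastTo of the taken branch instead of an Identity.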
TEST_F(ConstantFoldingTest, SimplifySelect_BroadcastTo) {
  for (TensorShape pred_shape : {TensorShape{2, 1}, TensorShape{2, 2, 1}}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if_t.reset(new Tensor(DT_BOOL, pred_shape));
      for (int i = 0; i < pred_shape.num_elements(); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 1})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 4})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f}, TensorShape({2, 1}));
      const Tensor kTwo = test::AsTensor<float>(
          {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f},
          TensorShape({2, 4}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

      ASSERT_EQ(optimized_graph.node_size(), 6);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "BroadcastTo");
          ASSERT_EQ(node.input_size(), 4);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1),
                    strings::StrCat("ConstantFolding/select-broadcastto_shape-",
                                    pred_val ? 1 : 2));
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
          EXPECT_EQ(node.input(3), pred_val ? "^if" : "^then");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      ASSERT_EQ(tensors[0].shape(), pred_shape.num_elements() == 2
                                        ? TensorShape({2, 4})
                                        : TensorShape({2, 2, 4}));
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

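// QuantizeAndDequantizeV2 nodes should only be folded when the optimizer is
// configured to fold quantization emulation ops.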
TEST_F(ConstantFoldingTest, QuantizationEmulation) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {0.0f, 1.0f, 2.0f, 3.0f}, {4});
  Output min_range = ops::Const(scope.WithOpName("min_range"), 0.0f, {});
  Output max_range = ops::Const(scope.WithOpName("max_range"), 3.0f, {});
  Output y = ops::QuantizeAndDequantizeV2(scope.WithOpName("y"), x, min_range,
                                          max_range);
  Output id = ops::Identity(scope.WithOpName("id"), y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};

  std::vector<Tensor> expected_tensors = EvaluateNodes(item.graph, item.fetch);

  for (const bool fold_quantization_emulation : {false, true}) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr,
                              /*disable_compressed_tensor_optimization=*/false,
                              fold_quantization_emulation);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    int num_quantization_emulation_ops = 0;
    for (const NodeDef& node : output.node()) {
      if (node.op() == "QuantizeAndDequantizeV2") {
        num_quantization_emulation_ops++;
      }
    }
    EXPECT_EQ(fold_quantization_emulation ? 0 : 1,
              num_quantization_emulation_ops);

    std::vector<Tensor> actual_tensors = EvaluateNodes(output, item.fetch);
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorEqual<float>(expected_tensors[i], actual_tensors[i]);
    }
  }
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow