1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 #include "tensorflow/cc/framework/scope.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/graph_def_util.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/grappler/clusters/single_machine.h"
27 #include "tensorflow/core/grappler/grappler_item.h"
28 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
29 #include "tensorflow/core/grappler/inputs/utils.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/io/path.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 namespace {
39 
// Directory, relative to the TensorFlow source root, that holds the GraphDef
// text fixtures (*.pbtxt) loaded by the tests in this file.
constexpr char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";
41 
42 class GraphPropertiesTest : public ::testing::Test {
43  public:
SetUp()44   void SetUp() override {
45     // Provision a single machine with 3 cpu cores
46     cluster_.reset(new SingleMachine(5 * 60, 3, 0));
47     TF_CHECK_OK(cluster_->Provision());
48 
49     // This function is simply
50     // out = Fill(shape, value), but
51     // Fill requires values in the shape input, not just shape of it, to infer
52     // output shape.
53     auto f = FunctionDefHelper::Create(
54         // Name
55         "MyFillFunc",
56         // Inputs
57         {"shape: int32", "value: float"},
58         // Outputs
59         {"out: float"},
60         // Attrs
61         {},
62         // Nodes
63         {
64             {{"a"},
65              "Fill",
66              {"shape", "value"},
67              {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
68         },
69         // Returns
70         {{"out", "a:output:0"}});
71     function_lib_.add_function()->Swap(&f);
72   }
73 
TearDown()74   void TearDown() override {
75     TF_CHECK_OK(cluster_->Shutdown());
76     cluster_.reset();
77   }
78 
79  protected:
80   // Returns a string form of <p>, suitable for comparing type and shape.
81   // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)82   string PropToString(const OpInfo::TensorProperties& p) {
83     string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
84     if (p.shape().unknown_rank()) {
85       strings::StrAppend(&s, "?");
86     } else {
87       strings::StrAppend(&s, "[");
88       for (int i = 0; i < p.shape().dim_size(); ++i) {
89         strings::StrAppend(&s, i == 0 ? "" : ",",
90                            std::max<int64>(p.shape().dim(i).size(), -1));
91       }
92       strings::StrAppend(&s, "]");
93     }
94     return s;
95   }
96 
97   // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
98   // ones.
ExpectTensorValues(const std::vector<int64> & expected,const TensorProto & tensor_proto_to_compare)99   void ExpectTensorValues(const std::vector<int64>& expected,
100                           const TensorProto& tensor_proto_to_compare) {
101     Tensor tensor;
102     EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
103     EXPECT_EQ(expected.size(), tensor.NumElements());
104     // We're interested in only integer tensors as only shapes are exported as
105     // graph properties values.
106     CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
107     if (tensor.dtype() == DT_INT32) {
108       for (int i = 0; i < tensor.NumElements(); i++) {
109         EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
110       }
111     } else {
112       for (int i = 0; i < tensor.NumElements(); i++) {
113         EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
114       }
115     }
116   }
117 
118   std::unique_ptr<SingleMachine> cluster_;
119   FunctionDefLibrary function_lib_;
120 };
121 
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_CHECK_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output. The size check is fatal because the
      // output is indexed right after.
      const auto props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      EXPECT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      EXPECT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      // AddN of a single input: the output properties must be identical to
      // the input properties, compared via their text-format serialization.
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      string in_prop_str;
      ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str);
      string out_prop_str;
      ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
                                                        &out_prop_str);
      EXPECT_EQ(in_prop_str, out_prop_str);
    }
  }
}
165 
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  Status s = properties.InferStatically(true);
  TF_CHECK_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Make sure there is something to clear; otherwise the emptiness check
      // below would pass trivially.
      EXPECT_EQ(1, properties.GetOutputProperties(node.name()).size());
      properties.ClearOutputProperties(node.name());
      const auto cleared_props = properties.GetOutputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      EXPECT_EQ(1, in_props.size());
      properties.ClearInputProperties(node.name());
      const auto cleared_props = properties.GetInputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    }
  }
}
192 
TEST_F(GraphPropertiesTest, DynamicProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_CHECK_OK(cluster_->Initialize(item));
  Status s = properties.InferDynamically(cluster_.get());
  TF_CHECK_OK(s);

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The random node is missing from the cost graph (why ?)
      EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
    } else if (node.op() == "AddN") {
      // Since the random node is missing, we can't infer the input properties
      // of the first AddN node. The other AddN nodes have the expected
      // properties. Size checks are fatal because props[0] is read next.
      if (node.name() == "AddN") {
        const auto props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, props.size());
        const OpInfo::TensorProperties& prop = props[0];
        EXPECT_EQ(DT_INVALID, prop.dtype());
        EXPECT_TRUE(prop.shape().unknown_rank());
      } else {
        const auto props = properties.GetInputProperties(node.name());
        ASSERT_EQ(1, props.size());
        const OpInfo::TensorProperties& prop = props[0];
        EXPECT_EQ(DT_FLOAT, prop.dtype());
        EXPECT_FALSE(prop.shape().unknown_rank());
        EXPECT_EQ(2, prop.shape().dim_size());
        EXPECT_EQ(10, prop.shape().dim(0).size());
        EXPECT_EQ(1, prop.shape().dim(1).size());
        const auto out_props = properties.GetOutputProperties(node.name());
        ASSERT_EQ(1, out_props.size());
        string prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
        string out_prop_str;
        ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
                                                          &out_prop_str);
        EXPECT_EQ(prop_str, out_prop_str);
      }
    }
  }
}
239 
TEST_F(GraphPropertiesTest, Variables) {
  // Graph: a [3,7] float ref Variable "Var", initialized from a Const via
  // Assign. Both static and dynamic inference must report the variable's
  // output as float_ref of shape [3,7].
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "Variable")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("value", initial_val)
                  .Finalize(item.graph.add_node()));
  TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign")
                  .Input("Var", 0, DT_FLOAT_REF)
                  .Input("InitialVal", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  {
    GraphProperties static_properties(item);
    TF_CHECK_OK(static_properties.InferStatically(false));

    const auto props = static_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
  {
    TF_CHECK_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get()));

    const auto props = dynamic_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
}
288 
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  // Graph: VarHandleOp -> Enter (DT_RESOURCE) -> ReadVariableOp. The [3,7]
  // shape declared on the VarHandleOp must reach the ReadVariableOp output
  // even though the resource passes through an Enter node first.
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));
  TF_CHECK_OK(NodeDefBuilder("Enter", "Enter")
                  .Attr("T", DT_RESOURCE)
                  .Attr("frame_name", "while_context")
                  .Attr("is_constant", true)
                  .Attr("parallel_iterations", 10)
                  .Input("Var", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));
  TF_CHECK_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                  .Attr("dtype", DT_FLOAT)
                  .Input("Enter", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  ASSERT_EQ(1, props.size());  // Fatal: props[0] is read next.
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
318 
TEST_F(GraphPropertiesTest, VarHandles) {
  // Graph: VarHandleOp -> ReadVariableOp. The shape declared on the handle
  // ([3,7]) must be recovered on the read's float output.
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                  .Attr("dtype", DT_FLOAT)
                  .Input("Var", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto props = properties.GetOutputProperties("VarRead");
  ASSERT_EQ(1, props.size());  // Fatal: props[0] is read next.
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
343 
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Some of these nodes (e.g. Merge) have 2 outputs; only the first one is
    // checked. The check is fatal because props[0] is read next.
    ASSERT_GE(props.size(), 1) << node;
    EXPECT_EQ("resource: []", PropToString(props[0])) << node;
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
381 
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  // A FIFO queue with neither a "shapes" attr nor any enqueue op: the dequeued
  // component's dtype is known, but nothing constrains its shape.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue = ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  auto dequeue =
      ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                        {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties inferred(item);
  TF_CHECK_OK(inferred.InferStatically(false));

  const auto dequeue_outputs = inferred.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_outputs.size());
  EXPECT_EQ("float: ?", PropToString(dequeue_outputs[0]));
}
398 
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  // No enqueue op feeds the queue, but its "shapes" attr pins the component
  // shape to [3,7,1], so the dequeue output shape is fully known.
  // Note: the Attrs().Shapes(...) argument stays inline in the FIFOQueue call
  // so the temporary shape list outlives its use.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                     ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
  auto dequeue =
      ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                        {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties inferred(item);
  TF_CHECK_OK(inferred.InferStatically(false));

  const auto dequeue_outputs = inferred.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_outputs.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeue_outputs[0]));
}
416 
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  // Same as the ShapeAttr case, but the last dimension in the "shapes" attr is
  // unknown (-1), and it must stay unknown in the dequeue output.
  // Note: the Attrs().Shapes(...) argument stays inline in the FIFOQueue call
  // so the temporary shape list outlives its use.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                     ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
  auto dequeue =
      ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                        {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties inferred(item);
  TF_CHECK_OK(inferred.InferStatically(false));

  const auto dequeue_outputs = inferred.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeue_outputs.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeue_outputs[0]));
}
434 
TEST_F(GraphPropertiesTest,Queues)435 TEST_F(GraphPropertiesTest, Queues) {
436   // Create a graph with known input shapes, and propagate the shapes through a
437   // couple of queues.
438   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
439 
440   auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
441   Output rnd =
442       ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
443   Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
444   auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
445   auto dequeue1 =
446       ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
447 
448   auto q2 =
449       ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
450   Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
451   auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
452   auto dequeue2 =
453       ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
454 
455   auto q4 =
456       ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
457   auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
458   auto enqueue4_2 =
459       ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
460   auto dequeue4 =
461       ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
462 
463   // Create a queue that takes in three tensors.
464   auto q5 = ops::RandomShuffleQueue(
465       root.WithOpName("Queue5"),
466       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
467   Output rnd2 =
468       ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
469   Output rnd3 =
470       ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
471   auto enqueue5 =
472       ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
473   auto dequeue5 = ops::QueueDequeue(
474       root.WithOpName("Dequeue5"), q5,
475       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
476 
477   GrapplerItem item;
478   TF_CHECK_OK(root.ToGraphDef(&item.graph));
479 
480   GraphProperties properties(item);
481   TF_CHECK_OK(properties.InferStatically(false));
482 
483   const auto props1 = properties.GetOutputProperties("Dequeue1");
484   ASSERT_EQ(1, props1.size());
485   EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
486 
487   const auto props2 = properties.GetOutputProperties("Dequeue2");
488   ASSERT_EQ(1, props2.size());
489   EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
490 
491   // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
492   // that we merge the 2 properly to determine the shape of the data coming out
493   // of the queue.
494   const auto props4 = properties.GetOutputProperties("Dequeue4");
495   ASSERT_EQ(1, props4.size());
496   EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
497 
498   // The dequeue5 op shape is known.
499   const auto props5 = properties.GetOutputProperties("Dequeue5");
500   ASSERT_EQ(3, props5.size());
501   EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
502   EXPECT_EQ("double: [10]", PropToString(props5[1]));
503   EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
504 }
505 
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
  std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
                                       "float: [1,2,1]"};
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto props = properties.GetOutputProperties(nodes[i]);
    // Only the first output of each node is checked. The check is fatal
    // because props[0] is read next.
    ASSERT_FALSE(props.empty()) << nodes[i];
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << nodes[i];
    EXPECT_EQ(expected_outputs[i], PropToString(prop)) << nodes[i];
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto props = properties.GetInputProperties("Less");
  EXPECT_EQ(2, props.size());
  for (const auto& prop : props) {
    EXPECT_EQ(DT_INT32, prop.dtype());
    EXPECT_TRUE(prop.has_value());
    EXPECT_EQ("int32: []", PropToString(prop));
  }
}
547 
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
              c, b, loop_vars=[i0, m0],
              shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};
  for (const string& node : nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,2]", PropToString(prop)) << node;
  }

  // The loop outputs batch dim should be different from the input batch dim
  // since we concatenated along the batch dim. Guard every access with a
  // fatal assertion so a missing output or a rank-0 shape fails the test
  // instead of throwing or tripping a proto bounds check.
  const auto in_props = properties.GetOutputProperties("ones");
  const auto out_props = properties.GetOutputProperties("while/Exit_1");
  ASSERT_FALSE(in_props.empty());
  ASSERT_FALSE(out_props.empty());
  const auto& shape_in = in_props[0].shape();
  const auto& shape_out = out_props[0].shape();
  ASSERT_GE(shape_in.dim_size(), 1);
  ASSERT_GE(shape_out.dim_size(), 1);
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
587 
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                              tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "nested_loop.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,1]", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)) << node;
  }
}
646 
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_queues.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [1,1,-1]", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype()) << node;
    EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)) << node;
  }
}
709 
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "loops_and_resource_vars.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
                                  "while/Exit_1"};
  std::vector<string> inner_nodes{"while/while/Merge_1",
                                  "while/while/NextIteration_1",
                                  "while/while/Exit_1"};
  for (const string& node : outer_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype()) << node;
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
  }
  for (const string& node : inner_nodes) {
    const auto props = properties.GetOutputProperties(node);
    ASSERT_FALSE(props.empty()) << node;  // Fatal: props[0] is read next.
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_INT32, prop.dtype()) << node;
    EXPECT_EQ("int32: []", PropToString(prop)) << node;
  }
}
767 
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // The test graph was generated in Python with:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0, q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue()
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  const string graph_path =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                   "queues_and_loops.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(graph_path, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // The loop body concatenates the carried tensor with itself along axis 0,
  // so only the second dimension (2) is expected to stay known.
  const std::vector<string> loop_nodes = {
      "while/Merge_1", "while/NextIteration_1", "while/Exit_1"};
  for (const string& node_name : loop_nodes) {
    const auto outputs = properties.GetOutputProperties(node_name);
    const OpInfo::TensorProperties& loop_prop = outputs[0];
    EXPECT_EQ(DT_FLOAT, loop_prop.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(loop_prop));
  }

  // The [-1,2] shape is expected to survive the trip through q1; concatenating
  // the dequeued tensor with itself along axis 1 then gives [-1,4].
  const auto concat_outputs = properties.GetOutputProperties("concat");
  const OpInfo::TensorProperties& concat_prop = concat_outputs[0];
  EXPECT_EQ(DT_FLOAT, concat_prop.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(concat_prop));
}
816 
TEST_F(GraphPropertiesTest,InferRestoreOpShape)817 TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
818   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
819   Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
820                              DataType::DT_FLOAT);
821   Output filename =
822       ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
823   Output tensor_name =
824       ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
825   Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
826                                 DataType::DT_FLOAT);
827   Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);
828 
829   Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
830                                       string("256 256 0,128:-"), TensorShape());
831   Output restore_slice =
832       ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
833                         shape_and_slice, DataType::DT_FLOAT);
834   Output init_restore_slice =
835       ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);
836 
837   Output restore_v2 =
838       ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
839                         shape_and_slice, DataType::DT_FLOAT);
840   Output init_restore_v2 =
841       ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);
842 
843   GrapplerItem item;
844   TF_CHECK_OK(s.ToGraphDef(&item.graph));
845   item.fetch.push_back("init_restore");
846 
847   GraphProperties properties(item);
848   TF_CHECK_OK(properties.InferStatically(false));
849 
850   const auto restore_props = properties.GetOutputProperties("restore");
851   const OpInfo::TensorProperties& restore_prop = restore_props[0];
852   EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
853   EXPECT_EQ("float: [128,256]", PropToString(restore_prop));
854 
855   const auto restore_slice_props =
856       properties.GetOutputProperties("restore_slice");
857   const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
858   EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
859   EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));
860 
861   const auto restorev2_props = properties.GetOutputProperties("restore_v2");
862   const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
863   EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
864   EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));
865 
866   // Check input shapes of assign op are propagted correctly.
867   const auto input_props = properties.GetInputProperties("init_restore");
868   ASSERT_EQ(2, input_props.size());
869   const OpInfo::TensorProperties& input_prop = input_props[1];
870   EXPECT_EQ(DT_FLOAT, input_prop.dtype());
871   EXPECT_EQ("float: [128,256]", PropToString(input_prop));
872 }
873 
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  // A single Restore op feeds two Assigns: one into `var`, declared with a
  // default (unspecified) PartialTensorShape, and one into `var2`, declared
  // as [128, 256]. The restore output is expected to be reported as
  // [128, 256].
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output shapeless_var = ops::Variable(root.WithOpName("var"),
                                       PartialTensorShape(), DataType::DT_FLOAT);
  Output shaped_var = ops::Variable(root.WithOpName("var2"),
                                    TensorShape({128, 256}), DataType::DT_FLOAT);
  Output file = ops::Const(root.WithOpName("filename"), string("model"),
                           TensorShape());
  Output name = ops::Const(root.WithOpName("tensorname"), string("a"),
                           TensorShape());
  Output restored = ops::Restore(root.WithOpName("restore"), file, name,
                                 DataType::DT_FLOAT);
  Output assign_first = ops::Assign(root.WithOpName("init"), shapeless_var,
                                    restored);
  Output assign_second = ops::Assign(root.WithOpName("init2"), shaped_var,
                                     restored);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("init");
  item.fetch.push_back("init2");

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto restore_outputs = properties.GetOutputProperties("restore");
  const OpInfo::TensorProperties& restored_prop = restore_outputs[0];
  EXPECT_EQ(DT_FLOAT, restored_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restored_prop));
}
902 
TEST_F(GraphPropertiesTest,TensorAsShapesPropagation)903 TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
904   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
905   Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
906   Output a1 = ops::Identity(s.WithOpName("a1"), a);
907   Output b = ops::Const(s.WithOpName("b"), 99, {});
908   Output b1 = ops::Identity(s.WithOpName("b1"), b);
909   Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
910   Output c1 = ops::Identity(s.WithOpName("c1"), c);
911 
912   GrapplerItem item;
913   TF_CHECK_OK(s.ToGraphDef(&item.graph));
914   GraphProperties properties(item);
915   TF_CHECK_OK(properties.InferStatically(false));
916 
917   // Check output shapes.
918   EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
919   EXPECT_EQ("int32: [2]",
920             PropToString(properties.GetOutputProperties("a1")[0]));
921   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
922   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
923   EXPECT_EQ("int32: [4,4,4]",
924             PropToString(properties.GetOutputProperties("c")[0]));
925   EXPECT_EQ("int32: [4,4,4]",
926             PropToString(properties.GetOutputProperties("c1")[0]));
927 
928   // Check has_value.
929   EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
930   EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
931   EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
932   EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
933   EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
934   EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
935   EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
936   EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
937   // Note that we propagate tensor value of only 1D vector and scalar.
938   EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());
939 
940   // Check values.
941   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
942   ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
943   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
944   ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
945   ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
946   ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
947   std::vector<int64> c_values;
948   for (int i = 0; i < 4 * 4 * 4; i++) {
949     c_values.push_back(1);
950   }
951   ExpectTensorValues({c_values},
952                      properties.GetOutputProperties("c")[0].value());
953   ExpectTensorValues({c_values},
954                      properties.GetInputProperties("c1")[0].value());
955   ExpectTensorValues({c_values},
956                      properties.GetOutputProperties("c1")[0].value());
957 }
958 
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 5, {2});
  Output b = ops::Identity(s.WithOpName("b"), a);
  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only b's shape but also b's value to figure out its output
  // shape; hence, the Identity op (b) must pass a's value on as
  // output_tensors_as_shape.
  Output d = ops::Fill(s.WithOpName("fill"), b, c);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
}
977 
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // When using aggressive_shape_inference, EvaluateNode() is run for
  // whitelisted ops with small input / output tensors. For instance, a Fill op
  // is evaluated and produces an output tensor value if its output is small
  // (currently, fewer than 17 elements); otherwise EvaluateNode() is skipped.
  // This avoids wasting time and memory on producing huge tensors (e.g.,
  // initializing a large table using Fill).
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 4, {2});  // 4x4
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is small; expect output values of Fill op.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
    EXPECT_TRUE(out_prop0.has_value());
  }
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const(s.WithOpName("a"), 1000, {4});  // 1000x1000x1000x1000
    Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
    // Shape described by a is huge; in that case value inference is skipped.
    // Otherwise, it'd be too much overhead.
    Output c = ops::Fill(s.WithOpName("fill"), a, b);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto& out_props = properties.GetOutputProperties("fill");
    const OpInfo::TensorProperties& out_prop0 = out_props[0];
    EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
    EXPECT_FALSE(out_prop0.has_value());
  }
}
1023 
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  // Four scalar Consts are packed into the vector {1, 2, 3, 4}, which is the
  // `dims` input of a Fill. Inferring Fill's output shape requires the packed
  // value, not just its [4] shape.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output one = ops::Const(scope.WithOpName("a"), 1, {});
  Output two = ops::Const(scope.WithOpName("b"), 2, {});
  Output three = ops::Const(scope.WithOpName("c"), 3, {});
  Output four = ops::Const(scope.WithOpName("d"), 4, {});
  // ops::Stack emits a Pack op.
  Output packed = ops::Stack(scope.WithOpName("pack"), {one, two, three, four});
  Output fill_value = ops::Const(scope.WithOpName("const"), 0.1f, {});
  Output filled = ops::Fill(scope.WithOpName("fill"), packed, fill_value);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto& fill_outputs = properties.GetOutputProperties("fill");
  EXPECT_EQ("float: [1,2,3,4]", PropToString(fill_outputs[0]));
}
1046 
TEST_F(GraphPropertiesTest, RankOp) {
  // Rank of a [4, 4, 4] constant is the int32 scalar 3. That value is expected
  // on the Rank output and also on the Identity that forwards it.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output cube = ops::Const(scope.WithOpName("Const"), 1, {4, 4, 4});
  Output rank = ops::Rank(scope.WithOpName("Rank"), cube);
  Output forwarded = ops::Identity(scope.WithOpName("Identity"), rank);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  for (const string& node_name : {string("Rank"), string("Identity")}) {
    const OpInfo::TensorProperties prop =
        properties.GetOutputProperties(node_name)[0];
    EXPECT_EQ("int32: []", PropToString(prop));
    EXPECT_TRUE(prop.has_value());
    ExpectTensorValues({3}, prop.value());
  }
}
1068 
TEST_F(GraphPropertiesTest, SizeOp) {
  // Size of a [1, 2, 3, 4] constant is 1 * 2 * 3 * 4 = 24. That scalar value is
  // expected on the Size output and on the Identity that forwards it.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output tensor = ops::Const(scope.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output size = ops::Size(scope.WithOpName("Size"), tensor);
  Output forwarded = ops::Identity(scope.WithOpName("Identity"), size);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  auto expect_scalar_24 = [&](const string& node_name) {
    const OpInfo::TensorProperties prop =
        properties.GetOutputProperties(node_name)[0];
    EXPECT_EQ("int32: []", PropToString(prop));
    EXPECT_TRUE(prop.has_value());
    ExpectTensorValues({24}, prop.value());
  };
  expect_scalar_24("Size");
  expect_scalar_24("Identity");
}
1090 
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
  // ops fed by Consts.
  // If output_tensors_as_shape is not set for those Identity ops, or the Pack
  // op doesn't consume input_tensors_as_shape, Fill op's `dims` input has no
  // known value; hence, its output shape becomes unknown.
  Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
  Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
  Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
  Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
  Output a = ops::Identity(s.WithOpName("a"), a0);
  Output b = ops::Identity(s.WithOpName("b"), b0);
  Output c = ops::Identity(s.WithOpName("c"), c0);
  Output d = ops::Identity(s.WithOpName("d"), d0);
  // Note ops::Stack instantiates Pack op.
  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only e's shape but also its value to figure out output
  // shape.
  Output g = ops::Fill(s.WithOpName("fill"), e, f);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& out_props = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
}
1122 
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  // Calls MyFillFunc (made available through the fixture's function_lib_)
  // with a Const shape {1, 2, 3, 4} and a scalar fill value. The call's output
  // is expected to have exactly that shape.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(scope.graph()->AddFunctionLibrary(function_lib_));
  Output dims = ops::Const(scope.WithOpName("shape"), {1, 2, 3, 4});
  Output fill_value = ops::Const(scope.WithOpName("value"), 0.1f, {});

  tensorflow::Node* call_node;
  TF_CHECK_OK(tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                      scope.graph()->op_registry())
                  .Input(tensorflow::ops::AsNodeOut(scope, dims))
                  .Input(tensorflow::ops::AsNodeOut(scope, fill_value))
                  .Finalize(scope.graph(), &call_node));

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& call_outputs = properties.GetOutputProperties("MyFillFunc");
  EXPECT_EQ("float: [1,2,3,4]", PropToString(call_outputs[0]));
}
1144 
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Like FunctionWithConstInput, except that the function's shape input is an
  // Identity of the Const rather than the Const itself. The {1, 2, 3, 4} value
  // therefore has to reach the function through the Identity's propagated
  // tensor-as-shape information, not through a Const tensor value.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(scope.graph()->AddFunctionLibrary(function_lib_));
  Output const_dims = ops::Const(scope.WithOpName("shape_"), {1, 2, 3, 4});
  Output dims = ops::Identity(scope.WithOpName("shape"), const_dims);
  Output fill_value = ops::Const(scope.WithOpName("value"), 0.1f, {});

  tensorflow::Node* call_node;
  TF_CHECK_OK(tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                      scope.graph()->op_registry())
                  .Input(tensorflow::ops::AsNodeOut(scope, dims))
                  .Input(tensorflow::ops::AsNodeOut(scope, fill_value))
                  .Finalize(scope.graph(), &call_node));

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto& call_outputs = properties.GetOutputProperties("MyFillFunc");
  EXPECT_EQ("float: [1,2,3,4]", PropToString(call_outputs[0]));
}
1170 
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  // MyFunc returns the Identity of its int32 input.
  FunctionDefLibrary lib;
  *lib.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(scope.graph()->AddFunctionLibrary(lib));

  // Feed the Const {5, 7} into MyFunc. The call's output is expected to carry
  // the same [2] shape and the same value (output_tensors_as_shape) as the
  // Const input.
  Output dims = ops::Const(scope.WithOpName("shape"), {5, 7}, {2});
  tensorflow::Node* call_node;
  TF_CHECK_OK(
      tensorflow::NodeBuilder("MyFunc", "MyFunc", scope.graph()->op_registry())
          .Input(tensorflow::ops::AsNodeOut(scope, dims))
          .Finalize(scope.graph(), &call_node));

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/true));
  const OpInfo::TensorProperties call_output =
      properties.GetOutputProperties("MyFunc")[0];
  EXPECT_EQ("int32: [2]", PropToString(call_output));
  EXPECT_TRUE(call_output.has_value());
  ExpectTensorValues({5, 7}, call_output.value());
  ExpectTensorValues({5, 7},
                     properties.GetInputProperties("MyFunc")[0].value());
}
1206 
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
  // MyFunc returns x + y.
  FunctionDefLibrary lib;
  *lib.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32", "y: int32"},                                   // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:z:0"}});                                        // Returns
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(scope.graph()->AddFunctionLibrary(lib));

  // Both function inputs are the same Const {5, 7}.
  Output operand = ops::Const(scope.WithOpName("shape"), {5, 7}, {2});
  const auto operand_out = tensorflow::ops::AsNodeOut(scope, operand);
  tensorflow::Node* call_node;
  TF_CHECK_OK(
      tensorflow::NodeBuilder("MyFunc", "MyFunc", scope.graph()->op_registry())
          .Input(operand_out)
          .Input(operand_out)
          .Finalize(scope.graph(), &call_node));

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  for (const bool aggressive : {false, true}) {
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/aggressive));
    const OpInfo::TensorProperties call_output =
        properties.GetOutputProperties("MyFunc")[0];
    EXPECT_EQ("int32: [2]", PropToString(call_output));

    if (!aggressive) {
      // Without aggressive_shape_inference the function body is not
      // evaluated, so only the output shape is known.
      EXPECT_FALSE(call_output.has_value());
      continue;
    }

    // With aggressive_shape_inference the sum {10, 14} is computed.
    EXPECT_TRUE(call_output.has_value());
    ExpectTensorValues({10, 14}, call_output.value());
    ExpectTensorValues({5, 7},
                       properties.GetInputProperties("MyFunc")[0].value());
    ExpectTensorValues({5, 7},
                       properties.GetInputProperties("MyFunc")[1].value());
  }
}
1261 
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Create a graph with a function that takes a scalar value, so that a
  // scalar Placeholder is used as the input for the function's shape
  // inference.
  // Placeholder -> Identity -> MyFunc, where MyFunc simply takes the Identity
  // of its input; all tensors are scalars.
  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
  Output placeholder =
      ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
  auto _identity = tensorflow::ops::AsNodeOut(s, identity);
  auto builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
  tensorflow::Node* func_op;
  TF_CHECK_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // For GraphDef producer versions <= 21, the Placeholder shape function
  // infers a Placeholder with an empty (scalar) shape attribute as unknown
  // rather than scalar, so this test requires producer version > 21.
  EXPECT_GT(item.graph.versions().producer(), 21);

  // MyFunc output shouldn't be unknown rank.
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(true));
  const auto& out_props = properties.GetOutputProperties("MyFunc");
  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());
}
1301 
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  // Fatal check: out_props[0] is dereferenced below.
  ASSERT_FALSE(out_props.empty());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  EXPECT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  const auto in_props = properties.GetInputProperties("MyAdd_55e046a8");
  // Fatal check: in_props[0] and in_props[1] are dereferenced below, so a
  // size mismatch must stop the test instead of reading out of bounds.
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1338 
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto out_props = properties.GetOutputProperties("y0");
  // Fatal check: out_props[0..1] are dereferenced below.
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));

  const auto in_props = properties.GetInputProperties("y0");
  // Fatal check: in_props[0..3] are dereferenced below.
  ASSERT_EQ(4, in_props.size());

  const OpInfo::TensorProperties& in_prop0 = in_props[0];
  EXPECT_EQ("float: [64]", PropToString(in_prop0));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));

  const OpInfo::TensorProperties& in_prop2 = in_props[2];
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));

  const OpInfo::TensorProperties& in_prop3 = in_props[3];
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
1371 
// Checks that a function wrapping a functional While loop exposes two outputs
// (int32 and float) whose ranks are known after static inference.
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  // Fatal check: out_props[1] below must not be read on a shorter vector.
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1409 
// Loads "function_error.pbtxt" and checks that static inference still succeeds.
// The function call's output is expected to have unknown rank, while its two
// inputs keep their known [1,2] shapes.
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  const auto out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  // Fatal size checks guard the indexed accesses that follow.
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1434 
// A function whose first argument comes from tf.case (a Switch/Merge
// construct). Both branches produce [1,2] constants, so inference should give
// [1,2] for the inputs and the output.
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  // MyAdd has a single output. Check the count before indexing so a missing
  // entry fails the test instead of reading out of bounds.
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1468 
// Same as FunctionSwitchStaticShapeInference, but the tf.case predicate is
// false, so the default branch is selected at runtime. Both branches are
// [1,2], so the inferred shapes are unchanged.
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  // Fatal check before indexing: MyAdd has exactly one output.
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1501 
// The function's two inputs have different shapes ([1,2] from tf.case and
// [1,3] from a constant). Each input shape should be inferred independently,
// and the output should follow x's [1,2] shape.
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  // Fatal check before indexing: MyAdd returns only `c`.
  ASSERT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
1538 
TEST_F(GraphPropertiesTest,SymbolicShapes)1539 TEST_F(GraphPropertiesTest, SymbolicShapes) {
1540   // Build a simple graph with placeholders of unknown dimensions. These
1541   // dimensions will be encoded symbolically.
1542   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1543 
1544   Output a =
1545       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
1546                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
1547   Output b =
1548       ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
1549                        ops::Placeholder::Shape(PartialTensorShape({-1})));
1550   Output c = ops::Identity(s.WithOpName("c"), a);
1551   Output d = ops::Identity(s.WithOpName("d"), b);
1552   Output e = ops::Add(s.WithOpName("e"), c, d);
1553   Output f = ops::Add(s.WithOpName("f"), a, c);
1554 
1555   Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
1556   Output g = ops::Shape(s.WithOpName("g"), c);
1557   Output h = ops::Fill(s.WithOpName("h"), g, zero);
1558   Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
1559   Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
1560 
1561   GrapplerItem item;
1562   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1563 
1564   GraphProperties properties(item);
1565   TF_CHECK_OK(properties.InferStatically(false));
1566   const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
1567   const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
1568   EXPECT_EQ(2, shape_a.dim_size());
1569   EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size());
1570   EXPECT_GE(-2, shape_a.dim(0).size());
1571   EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
1572   EXPECT_GE(-2, shape_a.dim(1).size());
1573   EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());
1574 
1575   PartialTensorShape shape(shape_a);
1576   EXPECT_FALSE(shape.IsFullyDefined());
1577   EXPECT_FALSE(shape.unknown_rank());
1578 
1579   const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
1580   const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
1581   EXPECT_EQ(1, shape_b.dim_size());
1582   EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size());
1583   EXPECT_GE(-2, shape_b.dim(0).size());
1584   EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
1585   EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());
1586 
1587   const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
1588   ASSERT_EQ(2, shape_e.dim_size());
1589   EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
1590   EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
1591   EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());
1592 
1593   const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
1594   ASSERT_EQ(2, shape_f.dim_size());
1595   EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
1596   EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
1597 
1598   const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
1599   ASSERT_EQ(2, shape_f.dim_size());
1600   EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
1601   EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
1602 
1603   const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
1604   ASSERT_EQ(1, shape_j.dim_size());
1605   EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
1606 }
1607 
// Static inference must not fail when a node's colocation constraint points
// at a node that no longer exists in the graph.
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output anchor = ops::Const(root.WithOpName("a"), 1.0f, {1});
  ops::Const(root.WithOpName("b"), 2.0f, {1});
  ops::Const(root.WithOpName("c").ColocateWith(anchor), 3.0f, {1});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Mimic a late optimization pass that removed "a" while "c" still declares
  // colocation with it. At that stage the constraints were already validated
  // and placement is done, so the dangling reference is acceptable.
  GraphDef pruned_graph;
  for (const auto& node : item.graph.node()) {
    if (node.name() == "a") continue;
    *pruned_graph.add_node() = node;
  }
  item.graph.Swap(&pruned_graph);

  // Expected to return OK, since colocation constraints are not checked
  // during shape inference.
  GraphProperties properties(item);
  TF_EXPECT_OK(properties.InferStatically(false));
}
1631 
// ShapeN extracts the shapes of two placeholders with unknown dims, and Fill
// rebuilds tensors from those shapes. Inference should track the dims so each
// Fill output has the same shape as the placeholder it came from.
TEST_F(GraphPropertiesTest, ShapeTracking) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output matrix =
      ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output vec =
      ops::Placeholder(root.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output zero = ops::Const(root.WithOpName("zero"), 0.0f, {});
  auto shapes = ops::ShapeN(root.WithOpName("shapes"), {matrix, vec});
  ops::Fill(root.WithOpName("o1"), shapes[0], zero);
  ops::Fill(root.WithOpName("o2"), shapes[1], zero);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));

  // Returns a copy of the first output shape of `node`.
  auto first_output_shape = [&properties](const string& node) {
    return properties.GetOutputProperties(node).at(0).shape();
  };
  EXPECT_EQ(first_output_shape("a").DebugString(),
            first_output_shape("o1").DebugString());
  EXPECT_EQ(first_output_shape("b").DebugString(),
            first_output_shape("o2").DebugString());
}
1657 
TEST_F(GraphPropertiesTest,FedNodes)1658 TEST_F(GraphPropertiesTest, FedNodes) {
1659   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
1660                                           cluster_->GetDeviceNames());
1661   GrapplerItem item;
1662   CHECK(fake_input.NextItem(&item));
1663 
1664   {
1665     // Conservative shape analysis: the shape of fed ports should be unknown
1666     GraphProperties properties(item);
1667     Status s = properties.InferStatically(false);
1668     TF_CHECK_OK(s);
1669     for (const auto& node : item.graph.node()) {
1670       if (node.op() == "Const") {
1671         continue;
1672       }
1673       const auto in_props = properties.GetInputProperties(node.name());
1674       EXPECT_EQ(1, in_props.size());
1675       const OpInfo::TensorProperties& in_prop = in_props[0];
1676       const auto out_props = properties.GetOutputProperties(node.name());
1677       EXPECT_EQ(1, out_props.size());
1678       const OpInfo::TensorProperties& out_prop = out_props[0];
1679 
1680       if (node.name() == "x") {
1681         // x is fed: its input should have a known shape, while its output
1682         // doesn't
1683         EXPECT_FALSE(in_prop.shape().unknown_rank());
1684         EXPECT_EQ(1, in_prop.shape().dim_size());
1685         EXPECT_EQ(2, in_prop.shape().dim(0).size());
1686         EXPECT_TRUE(out_prop.shape().unknown_rank());
1687       } else if (node.op() == "Square" || node.op() == "AddN") {
1688         // These nodes are in the fanout of x: their shapes should be unknown.
1689         EXPECT_TRUE(in_prop.shape().unknown_rank());
1690         EXPECT_TRUE(out_prop.shape().unknown_rank());
1691       }
1692     }
1693   }
1694   {
1695     // Optimistic shape analysis: the shape of fed ports should be derived from
1696     // the shape of the fanin.
1697     GraphProperties properties(item);
1698     Status s = properties.InferStatically(true);
1699     TF_CHECK_OK(s);
1700     for (const auto& node : item.graph.node()) {
1701       if (node.op() == "Square" || node.op() == "AddN") {
1702         const auto in_props = properties.GetInputProperties(node.name());
1703         EXPECT_EQ(1, in_props.size());
1704         const OpInfo::TensorProperties& in_prop = in_props[0];
1705         EXPECT_EQ(DT_FLOAT, in_prop.dtype());
1706         EXPECT_FALSE(in_prop.shape().unknown_rank());
1707         EXPECT_EQ(2, in_prop.shape().dim_size());
1708         const auto out_props = properties.GetOutputProperties(node.name());
1709         EXPECT_EQ(1, out_props.size());
1710         const OpInfo::TensorProperties& out_prop = out_props[0];
1711         EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString());
1712       }
1713     }
1714   }
1715 }
1716 
// Runs static shape inference on a large graph with many nested loops. This is
// meant to catch inference becoming too slow. There is no explicit timing
// assertion: the test only requires loading, default-attr filling and
// inference to succeed.
TEST_F(GraphPropertiesTest, Performance) {
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "large_graph.pbtxt.html");
  GrapplerItem item;
  TF_CHECK_OK(ReadGraphDefFromFile(graph_path, &item.graph));

  // Fill in default attribute values from the op registry plus the graph's
  // own function library before running inference.
  const FunctionLibraryDefinition flib(OpRegistry::Global(),
                                       item.graph.library());
  TF_CHECK_OK(AddDefaultAttrsToGraphDef(&item.graph, flib, 0, true));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
}
1732 
// b = shape(a)[0:1] and c = shape(a)[1:2]. Filling with those slices should
// give 1-D outputs whose single dims are a's dim 0 and dim 1 respectively.
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  // StridedSlice(input, begin, end, strides): index2 (== [1]) doubles as the
  // unit stride.
  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // Fatal rank checks guard the dim(i) reads below.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
1765 
// Slices the fully-known shape vector [5, 480, 40, 1] with shrink_axis_mask=1,
// so the result is the single element 5. Only aggressive shape inference is
// expected to materialize that constant value on the "slice" output.
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(root.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto input_shape = ops::Shape(root.WithOpName("input_shape"), input);

  Output begin = ops::Const(root.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(root.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(root.WithOpName("stride"), {1}, {1});
  ops::StridedSlice(root.WithOpName("slice"), input_shape, begin, end, stride,
                    ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Run inference once in each mode, each with a fresh GraphProperties.
  for (const bool aggressive : {false, true}) {
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/aggressive));
    const auto slice_props = properties.GetOutputProperties("slice");
    if (!aggressive) {
      // Without aggressive inference the output value stays unknown.
      EXPECT_FALSE(slice_props.at(0).has_value());
    } else {
      // Aggressive inference evaluates the slice to its constant value.
      EXPECT_TRUE(slice_props.at(0).has_value());
      const auto slice_value = slice_props.at(0).value();
      ExpectTensorValues({5}, slice_value);
    }
  }
}
1807 
// With aggressive shape inference, constant values should propagate through
// OnesLike and chained Add ops so every intermediate output carries its value.
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(s.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(s.WithOpName("c"), {2, 2}, {2});

  Output a1 = ops::OnesLike(s.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(s.WithOpName("a_plus_one"), a, a1);
  Output a_plus_a = ops::Add(s.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(s.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(s.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));

  // Check output shapes and values. Each property vector is copied into a
  // local and its size is asserted before indexing. This avoids an
  // out-of-bounds read, and no reference can outlive the value returned by
  // GetOutputProperties.
  const auto a_plus_one_props = properties.GetOutputProperties("a_plus_one");
  ASSERT_EQ(1, a_plus_one_props.size());
  const auto& a_plus_one_prop = a_plus_one_props[0];
  EXPECT_EQ("int32: [2]", PropToString(a_plus_one_prop));
  EXPECT_TRUE(a_plus_one_prop.has_value());
  ExpectTensorValues({6, 8}, a_plus_one_prop.value());

  const auto a_plus_a_props = properties.GetOutputProperties("a_plus_a");
  ASSERT_EQ(1, a_plus_a_props.size());
  const auto& a_plus_a_prop = a_plus_a_props[0];
  EXPECT_EQ("int32: [2]", PropToString(a_plus_a_prop));
  EXPECT_TRUE(a_plus_a_prop.has_value());
  ExpectTensorValues({10, 14}, a_plus_a_prop.value());

  const auto b_plus_2a_props = properties.GetOutputProperties("b_plus_2a");
  ASSERT_EQ(1, b_plus_2a_props.size());
  const auto& b_plus_2a_prop = b_plus_2a_props[0];
  EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_prop));
  EXPECT_TRUE(b_plus_2a_prop.has_value());
  ExpectTensorValues({18, 22}, b_plus_2a_prop.value());

  const auto c_plus_b_plus_2a_props =
      properties.GetOutputProperties("c_plus_b_plus_2a");
  ASSERT_EQ(1, c_plus_b_plus_2a_props.size());
  const auto& c_plus_b_plus_2a_prop = c_plus_b_plus_2a_props[0];
  EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_prop));
  EXPECT_TRUE(c_plus_b_plus_2a_prop.has_value());
  ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
}
1850 
// The "_output_shape_vector" annotation on a node should be used only under
// aggressive shape inference; otherwise it is ignored.
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, -1}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({5, 7})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false));
    const auto props = properties.GetOutputProperties("Identity");
    // Fatal check before indexing into props.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto props = properties.GetOutputProperties("Identity");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
1893 
// An annotated shape ([10,100]) that is compatible with the inferred shape
// ([-1,100]) should replace it under aggressive inference.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 100})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto props = properties.GetOutputProperties("Identity");
  // Fatal check before indexing into props.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
1920 
// An annotated shape ([10,10]) that conflicts with the inferred shape
// ([-1,100]) should be ignored even under aggressive inference.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 10})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto props = properties.GetOutputProperties("Identity");
  // Fatal check before indexing into props.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
1947 
1948 }  // namespace
1949 }  // namespace grappler
1950 }  // namespace tensorflow
1951