1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "tensorflow/core/framework/fake_input.h"
18 #include "tensorflow/core/framework/node_def_builder.h"
19 #include "tensorflow/core/framework/op_def_builder.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace shape_inference {
30 namespace {
31 
MakeOpDefWithLists()32 OpDef MakeOpDefWithLists() {
33   OpRegistrationData op_reg_data;
34   OpDefBuilder b("dummy");
35   b.Input(strings::StrCat("input: N * float"));
36   b.Output(strings::StrCat("output: N * float"));
37   CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok());
38   return op_reg_data.op_def;
39 }
40 
S(std::initializer_list<int64> dims)41 PartialTensorShape S(std::initializer_list<int64> dims) {
42   return PartialTensorShape(dims);
43 }
44 
Unknown()45 PartialTensorShape Unknown() { return PartialTensorShape(); }
46 
47 }  // namespace
48 
49 class ShapeInferenceTest : public ::testing::Test {
50  protected:
51   // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(DimensionHandle a,DimensionHandle b)52   bool SameHandle(DimensionHandle a, DimensionHandle b) {
53     return a.SameHandle(b);
54   }
SameHandle(ShapeHandle a,ShapeHandle b)55   bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
IsSet(DimensionHandle d)56   bool IsSet(DimensionHandle d) { return d.IsSet(); }
IsSet(ShapeHandle s)57   bool IsSet(ShapeHandle s) { return s.IsSet(); }
Relax(InferenceContext * c,DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)58   void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
59              DimensionHandle* out) {
60     c->Relax(d0, d1, out);
61   }
Relax(InferenceContext * c,ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)62   void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
63              ShapeHandle* out) {
64     c->Relax(s0, s1, out);
65   }
66   void TestMergeHandles(bool input_not_output);
67   void TestRelaxHandles(bool input_not_output);
68 
69   static const int kVersion = 0;  // used for graph-def version.
70 };
71 
TEST_F(ShapeInferenceTest, InputOutputByName) {
  // Set up an op whose "input" and "output" lists each hold N = 3 tensors.
  OpDef op_def = MakeOpDefWithLists();
  NodeDef def;
  // Fail fast if the NodeDef cannot be built: an incomplete `def` would make
  // every later expectation misleading.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_def)
                   .Attr("N", 3)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
  InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                     {}, {}, {});

  EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));

  // Getters: an unknown name fails; "input" yields all three list shapes.
  std::vector<ShapeHandle> shapes;
  EXPECT_FALSE(c.input("nonexistent", &shapes).ok());
  TF_ASSERT_OK(c.input("input", &shapes));
  // Stop before indexing if the list does not have the expected length.
  ASSERT_EQ(3, shapes.size());
  EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
  EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
  EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));

  // Setters: an unknown name fails; "output" assigns all three list outputs.
  EXPECT_FALSE(c.set_output("nonexistent", shapes).ok());
  // Fatal, because the outputs are read back below.
  TF_ASSERT_OK(c.set_output("output", shapes));
  EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
}
101 
MakeOpDef(int num_inputs,int num_outputs)102 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
103   OpRegistrationData op_reg_data;
104   OpDefBuilder b("dummy");
105   for (int i = 0; i < num_inputs; ++i) {
106     b.Input(strings::StrCat("i", i, ": float"));
107   }
108   for (int i = 0; i < num_outputs; ++i) {
109     b.Output(strings::StrCat("o", i, ": float"));
110   }
111   CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
112   return op_reg_data.op_def;
113 }
114 
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});

  // Value() hands back kUnknownDim and non-negative constants unchanged.
  const auto unknown = InferenceContext::kUnknownDim;
  EXPECT_EQ(unknown, c.Value(unknown));
  EXPECT_EQ(1, c.Value(1));

#ifndef NDEBUG
  // The negative-value check is a DCHECK, so the death test only applies
  // when DCHECKs are compiled in.
  EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to");
#endif
}
127 
TEST_F(ShapeInferenceTest, Run) {
  NodeDef def;
  def.set_name("foo");
  def.set_op("foo_op");
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  // Accepts input 0 when its rank is at most 6 and forwards it to both
  // outputs; the rank-1 input satisfies this.
  auto accept_rank_le_6 = [](InferenceContext* ctx) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 6, &unused));
    ctx->set_output(0, ctx->input(0));
    ctx->set_output(1, ctx->input(0));
    return Status::OK();
  };

  // Same as above but demands rank at most 0, which the rank-1 input violates.
  auto accept_rank_le_0 = [](InferenceContext* ctx) {
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 0, &unused));
    ctx->set_output(0, ctx->input(0));
    ctx->set_output(1, ctx->input(0));
    return Status::OK();
  };

  TF_ASSERT_OK(c.Run(accept_rank_le_6));

  // On failure, Run appends the node name and op type to the error message.
  const Status status = c.Run(accept_rank_le_0);
  EXPECT_TRUE(str_util::StrContains(
      status.ToString(),
      "Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')"))
      << status;
}
162 
// Tests the different kinds of context that Run attaches to an error it returns.
TEST_F(ShapeInferenceTest,AttachContext)164 TEST_F(ShapeInferenceTest, AttachContext) {
165   NodeDef def;
166   def.set_name("foo");
167   def.set_op("foo_op");
168   // Error when no constant tensors were requested.
169   {
170     InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
171                        {});
172     TF_ASSERT_OK(c.construction_status());
173     auto fn = [](InferenceContext* c) {
174       ShapeHandle h;
175       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
176       c->set_output(0, c->input(0));
177       return Status::OK();
178     };
179     EXPECT_EQ(
180         "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
181         "'foo' (op: 'foo_op') with input shapes: [1,2,3].",
182         c.Run(fn).ToString());
183   }
184 
185   // Error when a constant tensor value was requested.
186   {
187     Tensor input_t =
188         ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
189     InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
190                        {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
191     TF_ASSERT_OK(c.construction_status());
192     auto fn = [](InferenceContext* c) {
193       c->input_tensor(0);  // get this one, but it's null - won't be in error.
194       c->input_tensor(1);  // get this one, will now be in error.
195       ShapeHandle h;
196       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
197       c->set_output(0, c->input(0));
198       return Status::OK();
199     };
200     EXPECT_EQ(
201         "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
202         "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with "
203         "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.",
204         c.Run(fn).ToString());
205   }
206 
207   // Error when a constant tensor value as shape was requested, but no partial
208   // shapes provided.
209   {
210     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
211     InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
212                        {nullptr, &input_t}, {}, {});
213     TF_ASSERT_OK(c.construction_status());
214     auto fn = [](InferenceContext* c) {
215       ShapeHandle s;
216       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
217       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
218       ShapeHandle h;
219       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
220       c->set_output(0, c->input(0));
221       return Status::OK();
222     };
223     EXPECT_EQ(
224         "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
225         "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
226         "input tensors: input[1] = <1 2 3 4 5>.",
227         c.Run(fn).ToString());
228   }
229 
230   // Error when a constant tensor value as shape was requested, and a partial
231   // shape was provided.
232   {
233     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
234     InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
235                        {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
236     TF_ASSERT_OK(c.construction_status());
237     auto fn = [](InferenceContext* c) {
238       ShapeHandle s;
239       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
240       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
241       ShapeHandle h;
242       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
243       c->set_output(0, c->input(0));
244       return Status::OK();
245     };
246     EXPECT_EQ(
247         "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
248         "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
249         "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed "
250         "as partial shapes: input[0] = [10,?,5].",
251         c.Run(fn).ToString());
252   }
253 }
254 
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
                     {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
  EXPECT_EQ(3, c.num_inputs());
  EXPECT_EQ(2, c.num_outputs());

  // Input 0 has unknown rank: every index, even an out-of-range one, yields
  // an unknown dimension.
  const ShapeHandle in0 = c.input(0);
  EXPECT_EQ("?", c.DebugString(in0));
  EXPECT_FALSE(c.RankKnown(in0));
  EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(in0));
  for (int idx : {0, -1, 1000}) {
    EXPECT_EQ("?", c.DebugString(c.Dim(in0, idx))) << "index " << idx;
  }

  // Input 1 has rank 3. A negative index counts from the end and resolves to
  // the very same dimension handle.
  const ShapeHandle in1 = c.input(1);
  EXPECT_EQ("[1,?,3]", c.DebugString(in1));
  EXPECT_TRUE(c.RankKnown(in1));
  EXPECT_EQ(3, c.Rank(in1));

  const DimensionHandle first = c.Dim(in1, 0);
  EXPECT_EQ(1, c.Value(first));
  EXPECT_TRUE(SameHandle(first, c.Dim(in1, -3)));
  EXPECT_TRUE(c.ValueKnown(first));
  EXPECT_EQ("1", c.DebugString(first));

  const DimensionHandle middle = c.Dim(in1, 1);
  EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(middle));
  EXPECT_FALSE(c.ValueKnown(middle));
  EXPECT_TRUE(SameHandle(middle, c.Dim(in1, -2)));
  EXPECT_EQ("?", c.DebugString(middle));

  const DimensionHandle last = c.Dim(in1, 2);
  EXPECT_EQ(3, c.Value(last));
  EXPECT_TRUE(SameHandle(last, c.Dim(in1, -1)));
  EXPECT_TRUE(c.ValueKnown(last));
  EXPECT_EQ("3", c.DebugString(last));

  // Input 2 is a scalar: known rank 0.
  const ShapeHandle in2 = c.input(2);
  EXPECT_EQ("[]", c.DebugString(in2));
  EXPECT_TRUE(c.RankKnown(in2));
  EXPECT_EQ(0, c.Rank(in2));
}
295 
TEST_F(ShapeInferenceTest, NumElements) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
                     {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
  const ShapeHandle unknown_rank = c.input(0);
  const ShapeHandle partial = c.input(1);
  const ShapeHandle fully_defined = c.input(2);

  // The element count is unknown if the rank or any dimension is unknown.
  EXPECT_EQ("?", c.DebugString(c.NumElements(unknown_rank)));
  EXPECT_EQ("?", c.DebugString(c.NumElements(partial)));

  // That unknown count is a distinct handle from the unknown dimension.
  EXPECT_FALSE(SameHandle(c.Dim(partial, 1), c.NumElements(partial)));

  // 5 * 4 * 3 * 2 = 120.
  EXPECT_EQ("120", c.DebugString(c.NumElements(fully_defined)));
}
309 
TEST_F(ShapeInferenceTest, WithRank) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {Unknown(), S({1, -1, 3})}, {}, {}, {});
  const ShapeHandle unknown = c.input(0);
  const ShapeHandle known = c.input(1);

  // An unknown-rank input can take any rank; each call builds a fresh shape
  // whose dimensions are distinct unknowns.
  ShapeHandle vec;
  TF_EXPECT_OK(c.WithRank(unknown, 1, &vec));
  EXPECT_EQ("[?]", c.DebugString(vec));

  ShapeHandle matrix;
  TF_EXPECT_OK(c.WithRank(unknown, 2, &matrix));
  EXPECT_EQ("[?,?]", c.DebugString(matrix));
  EXPECT_FALSE(SameHandle(vec, matrix));
  EXPECT_FALSE(SameHandle(c.Dim(matrix, 0), c.Dim(matrix, 1)));

  ShapeHandle vec_again;
  TF_EXPECT_OK(c.WithRank(unknown, 1, &vec_again));
  EXPECT_EQ("[?]", c.DebugString(vec_again));
  EXPECT_FALSE(SameHandle(vec, vec_again));

  ShapeHandle scalar;
  TF_EXPECT_OK(c.WithRank(unknown, 0, &scalar));
  EXPECT_EQ("[]", c.DebugString(scalar));

  // A known-rank input must match exactly. A mismatch clears the output; a
  // match returns the input handle itself.
  ShapeHandle out = known;
  EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3",
            c.WithRank(known, 2, &out).ToString());
  EXPECT_FALSE(IsSet(out));
  TF_EXPECT_OK(c.WithRank(known, 3, &out));
  EXPECT_TRUE(SameHandle(out, known));

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(unknown));
  EXPECT_EQ("[1,?,3]", c.DebugString(known));
}
348 
TEST_F(ShapeInferenceTest, WithRankAtMost) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {Unknown(), S({1, -1, 3})}, {}, {}, {});
  const ShapeHandle unknown = c.input(0);
  const ShapeHandle known = c.input(1);

  // An unknown-rank input satisfies any upper bound and is passed through.
  ShapeHandle bound1;
  TF_EXPECT_OK(c.WithRankAtMost(unknown, 1, &bound1));
  EXPECT_EQ("?", c.DebugString(bound1));
  EXPECT_TRUE(SameHandle(unknown, bound1));

  ShapeHandle bound2;
  TF_EXPECT_OK(c.WithRankAtMost(unknown, 2, &bound2));
  EXPECT_EQ("?", c.DebugString(bound2));
  EXPECT_TRUE(SameHandle(bound1, bound2));

  // The rank-3 input violates a bound of 2, and the failed call clears its
  // output.
  ShapeHandle out = known;
  EXPECT_TRUE(str_util::StrContains(
      c.WithRankAtMost(known, 2, &out).ToString(),
      "Invalid argument: Shape must be at most rank 2 but is rank 3"));
  EXPECT_FALSE(IsSet(out));

  // Bounds of 3, 4 and 5 all succeed and return the input handle itself.
  for (int max_rank : {3, 4, 5}) {
    TF_EXPECT_OK(c.WithRankAtMost(known, max_rank, &out));
    EXPECT_TRUE(SameHandle(out, known)) << "max_rank " << max_rank;
  }

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(unknown));
  EXPECT_EQ("[1,?,3]", c.DebugString(known));
}
386 
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
                     {Unknown(), S({1, -1, 3})}, {}, {}, {});
  const ShapeHandle unknown = c.input(0);
  const ShapeHandle known = c.input(1);

  // An unknown-rank input satisfies any lower bound and is passed through.
  ShapeHandle bound1;
  TF_EXPECT_OK(c.WithRankAtLeast(unknown, 1, &bound1));
  EXPECT_EQ("?", c.DebugString(bound1));
  EXPECT_TRUE(SameHandle(unknown, bound1));

  ShapeHandle bound2;
  TF_EXPECT_OK(c.WithRankAtLeast(unknown, 2, &bound2));
  EXPECT_EQ("?", c.DebugString(bound2));
  EXPECT_TRUE(SameHandle(bound1, bound2));

  // The rank-3 input violates a bound of 4, and the failed call clears its
  // output.
  ShapeHandle out = known;
  EXPECT_TRUE(str_util::StrContains(
      c.WithRankAtLeast(known, 4, &out).ToString(),
      "Invalid argument: Shape must be at least rank 4 but is rank 3"));
  EXPECT_FALSE(IsSet(out));

  // Bounds of 3, 2 and 0 all succeed and return the input handle itself.
  for (int min_rank : {3, 2, 0}) {
    TF_EXPECT_OK(c.WithRankAtLeast(known, min_rank, &out));
    EXPECT_TRUE(SameHandle(out, known)) << "min_rank " << min_rank;
  }

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(unknown));
  EXPECT_EQ("[1,?,3]", c.DebugString(known));
}
424 
TEST_F(ShapeInferenceTest, WithValue) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
  const DimensionHandle known_one = c.Dim(c.input(0), 0);
  const DimensionHandle unknown = c.Dim(c.input(0), 1);

  // An unknown dimension accepts any value. Each call yields a new handle,
  // even when the same value is requested again.
  DimensionHandle as_one;
  TF_EXPECT_OK(c.WithValue(unknown, 1, &as_one));
  EXPECT_EQ(1, c.Value(as_one));

  DimensionHandle other;
  TF_EXPECT_OK(c.WithValue(unknown, 2, &other));
  EXPECT_EQ(2, c.Value(other));
  EXPECT_FALSE(SameHandle(as_one, other));
  EXPECT_FALSE(SameHandle(as_one, unknown));

  TF_EXPECT_OK(c.WithValue(unknown, 1, &other));
  EXPECT_EQ(1, c.Value(other));
  EXPECT_FALSE(SameHandle(as_one, other));

  // A known dimension only accepts its own value. Mismatches clear the
  // output; a match returns the input handle itself.
  DimensionHandle out = known_one;
  EXPECT_TRUE(
      str_util::StrContains(c.WithValue(known_one, 0, &out).ToString(),
                            "Invalid argument: Dimension must be 0 but is 1"));
  EXPECT_FALSE(IsSet(out));

  out = known_one;
  EXPECT_TRUE(
      str_util::StrContains(c.WithValue(known_one, 2, &out).ToString(),
                            "Invalid argument: Dimension must be 2 but is 1"));
  EXPECT_FALSE(IsSet(out));

  TF_EXPECT_OK(c.WithValue(known_one, 1, &out));
  EXPECT_TRUE(SameHandle(known_one, out));

  // Inputs are unchanged.
  EXPECT_EQ("1", c.DebugString(known_one));
  EXPECT_EQ("?", c.DebugString(unknown));
}
467 
// Checks InferenceContext::Merge on dimensions: which handle is returned, the
// error for conflicting known values, and which merges are recorded in
// MergedDims(). The history assertions depend on the exact order of the calls.
TEST_F(ShapeInferenceTest, MergeDim) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
                     {}, {}, {});

  // Two known 2s, a known 1, and two unknown dimensions of a single input.
  auto d2 = c.Dim(c.input(0), 0);
  auto d_unknown = c.Dim(c.input(0), 1);
  auto d2_b = c.Dim(c.input(0), 2);
  auto d1 = c.Dim(c.input(0), 3);
  auto d_unknown_b = c.Dim(c.input(0), 4);
  DimensionHandle out;

  // Merging a known dimension with an unknown one returns the known handle,
  // whichever side it is on; merging two unknowns returns the first.
  EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));

  // Each of the three merges above is recorded, in call order, as a
  // (first argument, second argument) pair.
  auto merged_dims = c.MergedDims();
  ASSERT_EQ(3, merged_dims.size());
  EXPECT_TRUE(merged_dims[0].first.SameHandle(d2));
  EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown));
  EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown));
  EXPECT_TRUE(merged_dims[1].second.SameHandle(d2));
  EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown));
  EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b));

  // Merging with self is a no-op, returns self, and records nothing new.
  EXPECT_TRUE(c.Merge(d2, d2, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));

  merged_dims = c.MergedDims();
  EXPECT_EQ(3, merged_dims.size());

  // Merging equal known values is a no-op, returns the first argument, and
  // records nothing new.
  EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok());
  EXPECT_TRUE(SameHandle(d2, out));
  EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok());
  EXPECT_TRUE(SameHandle(d2_b, out));

  merged_dims = c.MergedDims();
  EXPECT_EQ(3, merged_dims.size());

  // Merging unequal known values is an error. It leaves `out` unset (even
  // though it held a handle beforehand) and records nothing new.
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(d2, d1, &out).ToString(),
      "Invalid argument: Dimensions must be equal, but are 2 and 1"));

  EXPECT_FALSE(IsSet(out));
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(d1, d2, &out).ToString(),
      "Invalid argument: Dimensions must be equal, but are 1 and 2"));

  EXPECT_FALSE(IsSet(out));

  merged_dims = c.MergedDims();
  EXPECT_EQ(3, merged_dims.size());
}
530 
TEST_F(ShapeInferenceTest, RelaxDim) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2),
                     {S({2, InferenceContext::kUnknownDim, 2, 1,
                         InferenceContext::kUnknownDim})},
                     {}, {}, {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle two = c.Dim(in, 0);
  const DimensionHandle unknown = c.Dim(in, 1);
  const DimensionHandle two_b = c.Dim(in, 2);
  const DimensionHandle one = c.Dim(in, 3);
  const DimensionHandle unknown_b = c.Dim(in, 4);
  DimensionHandle out;

  // Asserts that `d` has an unknown value.
  auto expect_unknown = [&c](DimensionHandle d) {
    EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(d));
  };

  // Relaxing with an unknown on either side yields an unknown. When the
  // second argument is unknown, that very handle is returned.
  Relax(&c, two, unknown, &out);
  EXPECT_TRUE(SameHandle(unknown, out));
  EXPECT_FALSE(SameHandle(unknown_b, out));
  expect_unknown(out);

  Relax(&c, unknown, two, &out);
  EXPECT_FALSE(SameHandle(unknown, out));
  expect_unknown(out);

  Relax(&c, unknown, unknown_b, &out);
  EXPECT_FALSE(SameHandle(unknown, out));
  EXPECT_TRUE(SameHandle(unknown_b, out));
  expect_unknown(out);

  // Relaxing a handle with itself returns that handle.
  Relax(&c, two, two, &out);
  EXPECT_TRUE(SameHandle(two, out));
  Relax(&c, unknown, unknown, &out);
  EXPECT_TRUE(SameHandle(unknown, out));

  // Relaxing equal known values returns the first argument.
  Relax(&c, two, two_b, &out);
  EXPECT_TRUE(SameHandle(two, out));
  Relax(&c, two_b, two, &out);
  EXPECT_TRUE(SameHandle(two_b, out));

  // Relaxing different known values yields an unknown.
  Relax(&c, two, one, &out);
  expect_unknown(out);
  Relax(&c, one, two, &out);
  expect_unknown(out);
}
577 
TEST_F(ShapeInferenceTest,RelaxShape)578 TEST_F(ShapeInferenceTest, RelaxShape) {
579   NodeDef def;
580   InferenceContext c(
581       kVersion, &def, MakeOpDef(7, 2),
582       {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
583        S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
584       {}, {}, {});
585 
586   auto s_unknown = c.input(0);
587   auto s_1_2 = c.input(1);
588   auto s_u_2 = c.input(2);
589   auto s_1_u = c.input(3);
590   auto s_1_3 = c.input(4);
591   auto s_unknown_b = c.input(5);
592   auto s_1 = c.input(6);
593   ShapeHandle out;
594 
595   // Relaxing any shape with unknown returns a new unknown.
596   Relax(&c, s_unknown, s_1_2, &out);
597   EXPECT_FALSE(SameHandle(s_u_2, s_unknown));
598   EXPECT_EQ("?", c.DebugString(out));
599   Relax(&c, s_u_2, s_unknown, &out);
600   EXPECT_FALSE(SameHandle(s_u_2, out));
601   EXPECT_EQ("?", c.DebugString(out));
602   Relax(&c, s_unknown, s_unknown_b, &out);
603   EXPECT_FALSE(SameHandle(s_unknown, out));
604   EXPECT_TRUE(SameHandle(s_unknown_b, out));
605   EXPECT_EQ("?", c.DebugString(out));
606 
607   // Relaxing with self returns self.
608   Relax(&c, s_1_2, s_1_2, &out);
609   EXPECT_TRUE(SameHandle(out, s_1_2));
610 
611   // Relaxing where one of the inputs has less information.
612   out = ShapeHandle();
613   Relax(&c, s_1_2, s_u_2, &out);
614   EXPECT_FALSE(SameHandle(s_u_2, out));
615   EXPECT_EQ("[?,2]", c.DebugString(out));
616   out = ShapeHandle();
617   Relax(&c, s_u_2, s_1_2, &out);
618   EXPECT_FALSE(SameHandle(s_u_2, out));
619   EXPECT_EQ("[?,2]", c.DebugString(out));
620 
621   // Relaxing where each input has one distinct unknown dimension.
622   Relax(&c, s_u_2, s_1_u, &out);
623   EXPECT_EQ("[?,?]", c.DebugString(out));
624   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
625   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
626   auto s_u1 = c.UnknownShapeOfRank(1);
627   auto s_u2 = c.UnknownShapeOfRank(1);
628   Relax(&c, s_u1, s_u2, &out);
629   EXPECT_FALSE(SameHandle(s_u1, out));
630 
631   // Relaxing with mismatched values in a dimension returns a shape with that
632   // dimension unknown.
633   out = s_unknown;
634   Relax(&c, s_u_2, s_1_3, &out);
635   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
636   EXPECT_EQ("[?,?]", c.DebugString(out));
637   out = s_unknown;
638   Relax(&c, s_1_3, s_u_2, &out);
639   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
640   EXPECT_EQ("[?,?]", c.DebugString(out));
641   out = s_unknown;
642 
643   // Relaxing with mismatched ranks returns a new unknown.
644   Relax(&c, s_1, s_1_2, &out);
645   EXPECT_EQ("?", c.DebugString(out));
646 }
647 
// Checks InferenceContext::Merge on shapes: which handle is returned, the
// errors for incompatible shapes, and which merges are recorded in
// MergedShapes(). The history assertions depend on the exact order of calls.
TEST_F(ShapeInferenceTest, MergeShape) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
                     {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
                      Unknown(), S({1})},
                     {}, {}, {});

  auto s_unknown = c.input(0);
  auto s_1_2 = c.input(1);
  auto s_u_2 = c.input(2);
  auto s_1_u = c.input(3);
  auto s_1_3 = c.input(4);
  auto s_unknown_b = c.input(5);
  auto s_1 = c.input(6);
  ShapeHandle out;

  // Merging any shape with an unknown-rank shape returns the other shape;
  // merging two unknown-rank shapes returns the first.
  EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok());
  EXPECT_TRUE(SameHandle(s_1_2, out));
  EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(s_u_2, out));
  EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok());
  EXPECT_TRUE(SameHandle(s_unknown, out));

  // Each of the three merges above is recorded, in call order, as a
  // (first argument, second argument) pair.
  auto merged_shapes = c.MergedShapes();
  ASSERT_EQ(3, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown));
  EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2));
  EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown));
  EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown));
  EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b));

  // Merging with self returns self and records nothing new.
  EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok());
  EXPECT_TRUE(SameHandle(out, s_1_2));

  merged_shapes = c.MergedShapes();
  EXPECT_EQ(3, merged_shapes.size());

  // When one input already is the merged answer, that input is returned, in
  // either argument order.
  out = ShapeHandle();
  EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok());
  EXPECT_TRUE(SameHandle(s_1_2, out));
  out = ShapeHandle();
  EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok());
  EXPECT_TRUE(SameHandle(s_1_2, out));

  // Both of those merges are recorded as (first, second) pairs.
  merged_shapes = c.MergedShapes();
  ASSERT_EQ(5, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2));
  EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2));

  // When neither input is the answer, a new shape is built that reuses the
  // known dimension handles from each input.
  EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok());
  EXPECT_FALSE(SameHandle(out, s_u_2));
  EXPECT_FALSE(SameHandle(out, s_1_u));
  EXPECT_EQ("[1,2]", c.DebugString(out));
  EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1)));

  // That merge adds two records: (s_u_2, s_1_u) and (s_u_2, new shape).
  merged_shapes = c.MergedShapes();
  ASSERT_EQ(7, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u));
  EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2));
  EXPECT_TRUE(merged_shapes[6].second.SameHandle(out));

  // Merging two separately created rank-1 unknown shapes returns the first
  // and records one pair.
  auto s_u1 = c.UnknownShapeOfRank(1);
  auto s_u2 = c.UnknownShapeOfRank(1);
  TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out));
  EXPECT_TRUE(SameHandle(s_u1, out));

  merged_shapes = c.MergedShapes();
  ASSERT_EQ(8, merged_shapes.size());
  EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1));
  EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2));

  // Incompatible merges return errors and leave `out` unset (IsSet() is
  // false) even though it held a handle beforehand.
  out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(s_u_2, s_1_3, &out).ToString(),
      "Invalid argument: Dimension 1 in both shapes must be equal, but "
      "are 2 and 3"));

  EXPECT_FALSE(IsSet(out));
  out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(s_1_3, s_u_2, &out).ToString(),
      "Invalid argument: Dimension 1 in both shapes must be equal, but "
      "are 3 and 2"));

  EXPECT_FALSE(IsSet(out));
  out = s_unknown;
  EXPECT_TRUE(str_util::StrContains(
      c.Merge(s_1, s_1_2, &out).ToString(),
      "Invalid argument: Shapes must be equal rank, but are 1 and 2"));

  EXPECT_FALSE(IsSet(out));

  // Failed merges add nothing to the recorded history.
  merged_shapes = c.MergedShapes();
  EXPECT_EQ(8, merged_shapes.size());
}
753 
TEST_F(ShapeInferenceTest, MergePrefix) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
                     {
                         Unknown(),
                         S({-1, 2}),
                         S({1, -1, 3}),
                         S({2, 4}),
                     },
                     {}, {}, {});

  const ShapeHandle s_unknown = c.input(0);
  const ShapeHandle s_u_2 = c.input(1);
  const ShapeHandle s_1_u_3 = c.input(2);
  const ShapeHandle s_2_4 = c.input(3);

  ShapeHandle merged;
  ShapeHandle merged_prefix;
  // Points both outputs at a valid handle, so a failing call must clear them.
  auto reset_outputs = [&]() {
    merged = s_unknown;
    merged_prefix = s_unknown;
  };

  // With an unknown-rank operand, both inputs come back unchanged.
  TF_EXPECT_OK(c.MergePrefix(s_unknown, s_u_2, &merged, &merged_prefix));
  EXPECT_TRUE(SameHandle(merged, s_unknown));
  EXPECT_TRUE(SameHandle(merged_prefix, s_u_2));
  TF_EXPECT_OK(c.MergePrefix(s_1_u_3, s_unknown, &merged, &merged_prefix));
  EXPECT_TRUE(SameHandle(merged, s_1_u_3));
  EXPECT_TRUE(SameHandle(merged_prefix, s_unknown));

  // Merging [1,?,3] with prefix [?,2] fills in both shapes, and the outputs
  // share dimension handles with each other and with the inputs.
  TF_EXPECT_OK(c.MergePrefix(s_1_u_3, s_u_2, &merged, &merged_prefix));
  EXPECT_FALSE(SameHandle(merged, s_1_u_3));
  EXPECT_EQ("[1,2]", c.DebugString(merged_prefix));
  EXPECT_EQ("[1,2,3]", c.DebugString(merged));
  EXPECT_TRUE(SameHandle(c.Dim(merged_prefix, 0), c.Dim(merged, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(merged, 0), c.Dim(s_1_u_3, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(merged_prefix, 1), c.Dim(merged, 1)));
  EXPECT_TRUE(SameHandle(c.Dim(merged_prefix, 1), c.Dim(s_u_2, 1)));

  // Incompatible merges return errors and leave both outputs unset.
  reset_outputs();
  EXPECT_TRUE(str_util::StrContains(
      c.MergePrefix(s_1_u_3, s_2_4, &merged, &merged_prefix).ToString(),
      "Invalid argument: Dimensions must be equal, but are 1 and 2"));
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));

  reset_outputs();
  EXPECT_TRUE(str_util::StrContains(
      c.MergePrefix(s_2_4, s_1_u_3, &merged, &merged_prefix).ToString(),
      "Invalid argument: Shape must be at least rank 3 but is rank 2"));
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));
}
808 
TEST_F(ShapeInferenceTest,Subshape)809 TEST_F(ShapeInferenceTest, Subshape) {
810   NodeDef def;
811   InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
812                      {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
813 
814   ShapeHandle unknown = c.input(1);
815   ShapeHandle out;
816   EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok());
817   EXPECT_EQ("?", c.DebugString(out));
818   EXPECT_TRUE(SameHandle(out, unknown));
819   EXPECT_TRUE(c.Subshape(unknown, 1, &out).ok());
820   EXPECT_EQ("?", c.DebugString(out));
821   EXPECT_FALSE(SameHandle(out, unknown));
822   EXPECT_TRUE(c.Subshape(unknown, 200, &out).ok());
823   EXPECT_EQ("?", c.DebugString(out));
824   EXPECT_FALSE(SameHandle(out, unknown));
825 
826   const int kFullRank = 5;
827   ShapeHandle out_arr[4];
828   auto in0 = c.input(0);
829   EXPECT_TRUE(c.Subshape(in0, 0, &out).ok());
830   EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
831   EXPECT_TRUE(SameHandle(out, in0));
832   EXPECT_EQ(kFullRank, c.Rank(out));
833   for (int start = 0; start <= kFullRank + 1; ++start) {
834     for (int end = start; end <= kFullRank + 1; ++end) {
835       // Get subshapes using different start and end values that give the same
836       // range.
837       const int neg_start =
838           start >= kFullRank ? kFullRank : (start - kFullRank);
839       const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank);
840       ASSERT_TRUE(c.Subshape(in0, start, end, &out_arr[0]).ok());
841       ASSERT_TRUE(c.Subshape(in0, neg_start, end, &out_arr[1]).ok());
842       ASSERT_TRUE(c.Subshape(in0, start, neg_end, &out_arr[2]).ok());
843       ASSERT_TRUE(c.Subshape(in0, neg_start, neg_end, &out_arr[3]).ok());
844 
845       // Verify all computed subshapes.
846       for (int arr_idx = 0; arr_idx < 4; ++arr_idx) {
847         out = out_arr[arr_idx];
848         ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start),
849                   c.Rank(out))
850             << "start: " << start << " end: " << end << " arr_idx: " << arr_idx
851             << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out);
852         for (int d = 0; d < c.Rank(out); ++d) {
853           EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d)))
854               << "arr_idx: " << arr_idx;
855         }
856       }
857     }
858   }
859 
860   // Errors.
861   out = unknown;
862   EXPECT_TRUE(str_util::StrContains(
863       c.Subshape(in0, 6, -3, &out).ToString(),
864       "Invalid argument: Subshape must have computed start <= end, but is 5 "
865       "and 2 (computed from start 6 and end -3 over shape with rank 5)"));
866   EXPECT_FALSE(IsSet(out));
867   out = unknown;
868   EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, -50, 100, &out).ToString(),
869                                     "Invalid argument: Subshape start out of "
870                                     "bounds: -50, for shape with rank 5"));
871 
872   EXPECT_FALSE(IsSet(out));
873   out = unknown;
874   EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, 0, -50, &out).ToString(),
875                                     "Invalid argument: Subshape end out of "
876                                     "bounds: -50, for shape with rank 5"));
877 
878   EXPECT_FALSE(IsSet(out));
879 }
880 
// Concatenate joins two shapes end to end. Concatenating onto an unknown-rank
// shape yields a new unknown shape.
TEST_F(ShapeInferenceTest, Concatenate) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
                     {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});

  const ShapeHandle lhs = c.input(0);
  const ShapeHandle rhs = c.input(1);
  const ShapeHandle unknown = c.input(2);
  ShapeHandle out;

  for (const ShapeHandle& other : {unknown, lhs}) {
    EXPECT_TRUE(c.Concatenate(unknown, other, &out).ok());
    EXPECT_EQ("?", c.DebugString(out));
    EXPECT_FALSE(SameHandle(out, unknown));
  }

  EXPECT_TRUE(c.Concatenate(lhs, rhs, &out).ok());
  EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out));

  // The result reuses the operands' dimension handles: lhs first, then rhs.
  const int lhs_rank = c.Rank(lhs);
  for (int i = 0; i < lhs_rank; ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(lhs, i), c.Dim(out, i)));
  }
  for (int i = 0; i < c.Rank(rhs); ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(rhs, i), c.Dim(out, lhs_rank + i)));
  }
}
907 
// ReplaceDim substitutes one dimension of a shape, addressed by a positive or
// negative index, with a given dimension handle. Unknown-rank inputs stay
// unknown, and out-of-range indices are errors that clear the output.
TEST_F(ShapeInferenceTest, ReplaceDim) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                     {}, {}, {});

  auto in = c.input(0);
  auto unknown = c.input(1);

  ShapeHandle replaced;
  EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
  EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok());
  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
  EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("?", c.DebugString(replaced));

  // Negative indexing.
  EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok());
  EXPECT_EQ("?", c.DebugString(replaced));

  // Out-of-range indexing fails and clears the output. Seed `replaced` with a
  // valid handle before each call, so the IsSet() check does not depend on
  // what an earlier call happened to leave behind.
  replaced = in;
  EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
  replaced = in;
  EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok());
  EXPECT_FALSE(IsSet(replaced));
}
939 
// MakeShape builds a shape from dimension handles and/or literal sizes. The
// supplied handles are reused, and each call returns a distinct shape handle.
TEST_F(ShapeInferenceTest, MakeShape) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
                     {}, {});

  const ShapeHandle in0 = c.input(0);
  const int rank = c.Rank(in0);

  // in0's dimensions in reverse order.
  std::vector<DimensionHandle> reversed_dims;
  reversed_dims.reserve(rank);
  for (int i = rank - 1; i >= 0; --i) {
    reversed_dims.push_back(c.Dim(in0, i));
  }
  const DimensionHandle last_dim = c.Dim(in0, rank - 1);

  const ShapeHandle s = c.MakeShape(reversed_dims);
  EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s));
  EXPECT_TRUE(SameHandle(c.Dim(s, 0), last_dim));

  // Identical input gives a different shape handle with the same dims.
  const ShapeHandle s2 = c.MakeShape(reversed_dims);
  EXPECT_FALSE(SameHandle(s, s2));
  EXPECT_TRUE(SameHandle(c.Dim(s2, 0), last_dim));

  // Literal sizes and dimension handles can be mixed.
  const ShapeHandle s3 = c.MakeShape({1, 2, reversed_dims[2]});
  EXPECT_FALSE(SameHandle(s, s3));
  EXPECT_EQ("[1,2,3]", c.DebugString(s3));
}
965 
// Each UnknownShape() call returns a distinct unknown-rank shape handle.
TEST_F(ShapeInferenceTest, UnknownShape) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  const ShapeHandle first = c.UnknownShape();
  const ShapeHandle second = c.UnknownShape();
  for (const ShapeHandle& s : {first, second}) {
    EXPECT_EQ("?", c.DebugString(s));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
977 
// ShapeHandleToProto on a fully known shape records a known rank and every
// dimension size, in order.
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s = c.MakeShape({1, 2, 3});
  TensorShapeProto proto;
  c.ShapeHandleToProto(s, &proto);

  EXPECT_FALSE(proto.unknown_rank());
  // ASSERT so the per-dimension checks below never index past the end.
  ASSERT_EQ(3, proto.dim_size());
  // Check every dimension, not just the first, so a truncated or mis-ordered
  // conversion is caught.
  EXPECT_EQ(1, proto.dim(0).size());
  EXPECT_EQ(2, proto.dim(1).size());
  EXPECT_EQ(3, proto.dim(2).size());
}
991 
// An unknown-rank shape converts to a proto with unknown_rank set and no dims.
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  TensorShapeProto proto;
  c.ShapeHandleToProto(c.UnknownShape(), &proto);

  EXPECT_TRUE(proto.unknown_rank());
  EXPECT_EQ(0, proto.dim_size());
}
1004 
// Scalar() produces a rank-0 shape, printed as "[]", on every call.
TEST_F(ShapeInferenceTest, Scalar) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  for (int call = 0; call < 2; ++call) {
    EXPECT_EQ("[]", c.DebugString(c.Scalar())) << "call: " << call;
  }
}
1015 
// Vector() builds a rank-1 shape from a size, kUnknownDim, or an existing
// dimension handle. An existing handle is reused as-is.
TEST_F(ShapeInferenceTest, Vector) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  EXPECT_EQ("[1]", c.DebugString(c.Vector(1)));
  EXPECT_EQ("[?]", c.DebugString(c.Vector(InferenceContext::kUnknownDim)));

  const DimensionHandle dim = c.UnknownDim();
  const ShapeHandle from_handle = c.Vector(dim);
  EXPECT_EQ("[?]", c.DebugString(from_handle));
  EXPECT_TRUE(SameHandle(dim, c.Dim(from_handle, 0)));
}
1031 
// Matrix() builds a rank-2 shape from sizes, kUnknownDim, or existing
// dimension handles. Existing handles are reused as-is.
TEST_F(ShapeInferenceTest, Matrix) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s0 = c.Matrix(1, 2);
  EXPECT_EQ("[1,2]", c.DebugString(s0));
  auto s1 = c.Matrix(0, InferenceContext::kUnknownDim);
  EXPECT_EQ("[0,?]", c.DebugString(s1));

  auto d1 = c.UnknownDim();
  auto d2 = c.UnknownDim();
  auto s2 = c.Matrix(d1, d2);
  EXPECT_EQ("[?,?]", c.DebugString(s2));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
  EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1)));

  // Mixing a handle and a literal size: the handle must be reused in s3, the
  // shape just built. The old check inspected s2 here by mistake.
  auto s3 = c.Matrix(d1, 100);
  EXPECT_EQ("[?,100]", c.DebugString(s3));
  EXPECT_TRUE(SameHandle(d1, c.Dim(s3, 0)));
}
1053 
TEST_F(ShapeInferenceTest,MakeShapeFromShapeTensor)1054 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
1055   auto create = [&](Tensor* t) {
1056     NodeDef def;
1057     InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
1058                        {});
1059     ShapeHandle out;
1060     Status s = c.MakeShapeFromShapeTensor(0, &out);
1061     if (s.ok()) {
1062       return c.DebugString(out);
1063     } else {
1064       EXPECT_FALSE(IsSet(out));
1065       return s.error_message();
1066     }
1067   };
1068 
1069   Tensor t;
1070   EXPECT_EQ("?", create(nullptr));
1071 
1072   t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
1073   EXPECT_EQ("[1,2,3]", create(&t));
1074 
1075   t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
1076   EXPECT_EQ("[3,2,1]", create(&t));
1077 
1078   t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
1079   EXPECT_EQ("[3,?,1]", create(&t));
1080 
1081   t = ::tensorflow::test::AsTensor<int64>({});
1082   EXPECT_EQ("[]", create(&t));
1083 
1084   // Test negative scalar
1085   t = ::tensorflow::test::AsScalar<int32>(-1);
1086   EXPECT_EQ("?", create(&t));
1087 
1088   t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
1089   EXPECT_TRUE(str_util::StrContains(
1090       create(&t), "Input tensor must be int32 or int64, but was float"));
1091 
1092   t = ::tensorflow::test::AsScalar<int32>(1);
1093   auto s_scalar = create(&t);
1094   EXPECT_TRUE(str_util::StrContains(
1095       s_scalar,
1096       "Input tensor must be rank 1, or if its rank 0 it must have value -1"))
1097       << s_scalar;
1098 
1099   t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
1100   auto s_matrix = create(&t);
1101   EXPECT_TRUE(str_util::StrContains(
1102       s_matrix, "Input tensor must be rank 1, but was rank 2"))
1103       << s_matrix;
1104 
1105   // Test negative values for the dims.
1106   t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
1107   EXPECT_TRUE(str_util::StrContains(
1108       create(&t), "Invalid value in tensor used for shape: -2"));
1109 
1110   // Test negative values for the dims.
1111   t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
1112   EXPECT_TRUE(str_util::StrContains(
1113       create(&t), "Invalid value in tensor used for shape: -2"));
1114 
1115   // Test when the input shape is wrong.
1116   {
1117     NodeDef def;
1118     InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
1119                        {}, {});
1120     ShapeHandle out;
1121     EXPECT_EQ("Shape must be rank 1 but is rank 2",
1122               c.MakeShapeFromShapeTensor(0, &out).error_message());
1123   }
1124 }
1125 
// MakeShapeFromPartialTensorShape maps an unknown rank and unknown (-1)
// dimensions of a PartialTensorShape to "?" in the inferred shape.
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  ShapeHandle out;

  // Unknown rank.
  TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(Unknown(), &out));
  EXPECT_EQ("?", c.DebugString(out));

  // Known rank, without and with an unknown dimension.
  TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(S({0}), &out));
  EXPECT_EQ("[0]", c.DebugString(out));
  TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(S({0, -1, 1000}), &out));
  EXPECT_EQ("[0,?,1000]", c.DebugString(out));
}
1144 
// MakeShapeFromTensorShape converts fully defined TensorShapes, including the
// scalar shape and shapes that contain zero-sized dimensions.
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  struct Case {
    TensorShape shape;
    const char* expected;
  };
  const Case cases[] = {
      {TensorShape(), "[]"},
      {TensorShape({0}), "[0]"},
      {TensorShape({0, 7, 1000}), "[0,7,1000]"},
  };

  for (const Case& tc : cases) {
    ShapeHandle out;
    TF_ASSERT_OK(c.MakeShapeFromTensorShape(tc.shape, &out));
    EXPECT_EQ(tc.expected, c.DebugString(out));
  }
}
1158 
// MakeShapeFromShapeProto validates a TensorShapeProto: an unknown-rank proto
// must carry no dims, and dimension sizes below -1 are rejected. The proto is
// grown step by step, so each stage below builds on the previous one.
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  TensorShapeProto proto;
  ShapeHandle out;
  // Converts the current `proto` into `out` and returns the status.
  auto convert = [&]() { return c.MakeShapeFromShapeProto(proto, &out); };

  // Unknown rank with no dims is accepted.
  proto.set_unknown_rank(true);
  EXPECT_TRUE(convert().ok());
  EXPECT_EQ("?", c.DebugString(out));

  // Unknown rank plus a dim is rejected, and the output is cleared.
  proto.add_dim()->set_size(0);
  EXPECT_TRUE(str_util::StrContains(
      convert().error_message(),
      "An unknown shape must not have any dimensions set."));
  EXPECT_FALSE(IsSet(out));

  // Known rank.
  proto.set_unknown_rank(false);
  EXPECT_TRUE(convert().ok());
  EXPECT_EQ("[0]", c.DebugString(out));
  proto.add_dim()->set_size(-1);
  proto.add_dim()->set_size(1000);
  EXPECT_TRUE(convert().ok());
  EXPECT_EQ("[0,?,1000]", c.DebugString(out));

  // A size below -1 is invalid.
  proto.add_dim()->set_size(-2);
  EXPECT_TRUE(str_util::StrContains(
      convert().error_message(),
      "Shape [0,?,1000,-2] has dimensions with values below -1 "
      "(where -1 means unknown)"));
  EXPECT_FALSE(IsSet(out));
}
1194 
// MakeDim returns a fresh handle on every call, even for equal values.
TEST_F(ShapeInferenceTest, MakeDim) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  const DimensionHandle one_a = c.MakeDim(1);
  const DimensionHandle one_b = c.MakeDim(1);
  const DimensionHandle two = c.MakeDim(2);

  EXPECT_EQ("1", c.DebugString(one_a));
  EXPECT_EQ("1", c.DebugString(one_b));
  EXPECT_FALSE(SameHandle(one_a, one_b));
  EXPECT_EQ("2", c.DebugString(two));
}
1208 
// Each UnknownDim() call returns a distinct unknown dimension handle.
TEST_F(ShapeInferenceTest, UnknownDim) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  const DimensionHandle first = c.UnknownDim();
  const DimensionHandle second = c.UnknownDim();
  for (const DimensionHandle& d : {first, second}) {
    EXPECT_EQ("?", c.DebugString(d));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
1220 
// UnknownShapeOfRank(n) gives a known rank n with every dimension unknown.
// With n == 0 that is simply a scalar.
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  EXPECT_EQ("[?,?,?]", c.DebugString(c.UnknownShapeOfRank(3)));
  EXPECT_EQ("[]", c.DebugString(c.UnknownShapeOfRank(0)));
}
1232 
// input_tensor(i) returns exactly the pointers supplied at construction.
// Input 2, for which no tensor was supplied, yields nullptr.
TEST_F(ShapeInferenceTest, InputTensors) {
  const Tensor first = tensorflow::test::AsTensor<float>({10});
  const Tensor second = tensorflow::test::AsTensor<float>({20, 30});
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
                     {&first, &second}, {}, {});

  EXPECT_EQ(&first, c.input_tensor(0));
  EXPECT_EQ(&second, c.input_tensor(1));
  EXPECT_TRUE(nullptr == c.input_tensor(2));
}
1244 
// MakeDimForScalarInput reads a dimension size from a scalar input tensor and
// rejects negative values. The context is given pointers to t1/t2, so
// reassigning those tensors re-runs the same checks with int64 scalars.
TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
  Tensor t1 = tensorflow::test::AsScalar<int32>(20);
  Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
                     {&t1, &t2}, {}, {});

  // Expects input 0 to yield 20 and input 1 (value -1) to be rejected.
  auto check_inputs = [&c]() {
    DimensionHandle d;
    EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
    EXPECT_EQ("20", c.DebugString(d));

    EXPECT_TRUE(
        str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
                              "Dimension size, given by scalar input 1, must be "
                              "non-negative but is -1"));
  };

  check_inputs();  // int32 scalars.

  t1 = tensorflow::test::AsScalar<int64>(20);
  t2 = tensorflow::test::AsScalar<int64>(-1);
  check_inputs();  // int64 scalars.
}
1272 
// GetAttr reads attribute values from the NodeDef the context was built with.
TEST_F(ShapeInferenceTest, GetAttr) {
  OpRegistrationData op_reg_data;
  op_reg_data.op_def = MakeOpDef(0, 2);

  NodeDef def;
  const bool built = NodeDefBuilder("dummy", &op_reg_data.op_def)
                         .Attr("foo", "bar")
                         .Finalize(&def)
                         .ok();
  CHECK(built);

  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, op_reg_data.op_def, no_inputs, {}, {},
                     {});

  string foo_value;
  EXPECT_TRUE(c.GetAttr("foo", &foo_value).ok());
  EXPECT_EQ("bar", foo_value);
}
1288 
TEST_F(ShapeInferenceTest,Divide)1289 TEST_F(ShapeInferenceTest, Divide) {
1290   NodeDef def;
1291   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
1292                      {}, {});
1293 
1294   auto s = c.input(0);
1295   auto d_6 = c.Dim(s, 0);
1296   auto d_unknown = c.Dim(s, 1);
1297   auto d_1 = c.Dim(s, 2);
1298   auto d_2 = c.Dim(s, 3);
1299   auto d_0 = c.Dim(s, 4);
1300   bool evenly_divisible = true;
1301 
1302   // Dividing unknown by non-1 gives new unknown.
1303   DimensionHandle out;
1304   EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok());
1305   EXPECT_EQ("?", c.DebugString(out));
1306   EXPECT_FALSE(SameHandle(out, d_unknown));
1307 
1308   // Dividing anything by 1 returns the input.
1309   EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok());
1310   EXPECT_TRUE(SameHandle(out, d_unknown));
1311   EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
1312   EXPECT_TRUE(SameHandle(out, d_6));
1313   EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
1314   EXPECT_TRUE(SameHandle(out, d_unknown));
1315   EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
1316   EXPECT_TRUE(SameHandle(out, d_6));
1317 
1318   EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
1319   EXPECT_EQ("3", c.DebugString(out));
1320   EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
1321   EXPECT_EQ("3", c.DebugString(out));
1322 
1323   EXPECT_TRUE(str_util::StrContains(
1324       c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
1325       "Dimension size must be evenly divisible by 5 but is 6"));
1326 
1327   EXPECT_TRUE(str_util::StrContains(
1328       c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
1329       "Divisor must be positive but is 0"));
1330   EXPECT_TRUE(str_util::StrContains(
1331       c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
1332       "Divisor must be positive but is 0"));
1333 
1334   EXPECT_TRUE(str_util::StrContains(
1335       c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
1336       "Divisor must be positive but is -1"));
1337 
1338   // Repeat error cases above with evenly_divisible=false.
1339   evenly_divisible = false;
1340   EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
1341   EXPECT_EQ("1", c.DebugString(out));
1342 
1343   EXPECT_TRUE(str_util::StrContains(
1344       c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
1345       "Divisor must be positive but is 0"));
1346 
1347   EXPECT_TRUE(str_util::StrContains(
1348       c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
1349       "Divisor must be positive but is -1"));
1350 }
1351 
// Add sums a dimension with a constant or another dimension:
// - adding zero returns the original handle;
// - an unknown operand gives unknown;
// - exceeding the int64 maximum is an error.
TEST_F(ShapeInferenceTest, Add) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
                     {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle six = c.Dim(in, 0);
  const DimensionHandle unknown = c.Dim(in, 1);
  const DimensionHandle zero = c.Dim(in, 2);
  const int64 kMax = std::numeric_limits<int64>::max();
  DimensionHandle out;

  // A non-zero constant added to unknown gives a new unknown.
  EXPECT_TRUE(c.Add(unknown, 1, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  // Adding zero, as a constant or as a dimension, returns the input handle.
  for (const DimensionHandle& d : {unknown, six}) {
    EXPECT_TRUE(c.Add(d, 0, &out).ok());
    EXPECT_TRUE(SameHandle(out, d));
  }
  for (const DimensionHandle& d : {unknown, six}) {
    EXPECT_TRUE(c.Add(d, c.MakeDim(0ll), &out).ok());
    EXPECT_TRUE(SameHandle(out, d));
  }

  // Known + constant, up to exactly the int64 maximum.
  EXPECT_TRUE(c.Add(six, 2, &out).ok());
  EXPECT_EQ("8", c.DebugString(out));
  EXPECT_TRUE(c.Add(six, kMax - 6, &out).ok());
  EXPECT_EQ(kMax, c.Value(out));

  // Same, with the second operand given as a dimension handle.
  EXPECT_TRUE(c.Add(six, c.MakeDim(2), &out).ok());
  EXPECT_EQ("8", c.DebugString(out));
  EXPECT_TRUE(c.Add(six, c.MakeDim(kMax - 6), &out).ok());
  EXPECT_EQ(kMax, c.Value(out));
  EXPECT_TRUE(c.Add(six, c.UnknownDim(), &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(c.Add(zero, six, &out).ok());
  EXPECT_TRUE(SameHandle(out, six));

  // One past the int64 maximum overflows.
  EXPECT_TRUE(str_util::StrContains(
      c.Add(six, kMax - 5, &out).error_message(),
      "Dimension size overflow from adding 6 and 9223372036854775802"));
}
1401 
// Subtract takes a constant or another dimension away from a dimension:
// - subtracting zero returns the original handle;
// - an unknown operand gives unknown;
// - a negative result is an error.
TEST_F(ShapeInferenceTest, Subtract) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
                     {}, {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle six = c.Dim(in, 0);
  const DimensionHandle unknown = c.Dim(in, 1);
  const DimensionHandle zero = c.Dim(in, 2);
  const DimensionHandle five = c.Dim(in, 3);
  DimensionHandle out;

  // A non-zero constant subtracted from unknown gives a new unknown.
  EXPECT_TRUE(c.Subtract(unknown, 1, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  // Subtracting zero, as a constant or as a dimension, returns the input.
  for (const DimensionHandle& d : {unknown, six}) {
    EXPECT_TRUE(c.Subtract(d, 0ll, &out).ok());
    EXPECT_TRUE(SameHandle(out, d));
  }
  for (const DimensionHandle& d : {unknown, six}) {
    EXPECT_TRUE(c.Subtract(d, c.MakeDim(0ll), &out).ok());
    EXPECT_TRUE(SameHandle(out, d));
  }

  // Known - constant, down to exactly zero.
  EXPECT_TRUE(c.Subtract(six, 2, &out).ok());
  EXPECT_EQ("4", c.DebugString(out));
  EXPECT_TRUE(c.Subtract(six, 6, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));

  // Known - dimension.
  EXPECT_TRUE(c.Subtract(six, c.MakeDim(2), &out).ok());
  EXPECT_EQ("4", c.DebugString(out));
  EXPECT_TRUE(c.Subtract(six, five, &out).ok());
  EXPECT_EQ("1", c.DebugString(out));
  EXPECT_TRUE(c.Subtract(six, c.UnknownDim(), &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(c.Subtract(six, zero, &out).ok());
  EXPECT_TRUE(SameHandle(out, six));

  // A negative result is rejected.
  EXPECT_TRUE(str_util::StrContains(
      c.Subtract(five, six, &out).error_message(),
      "Negative dimension size caused by subtracting 6 from 5"));
}
1451 
// Multiply behavior covered here:
// - a zero operand gives 0;
// - a 1 operand returns the other operand's handle;
// - otherwise an unknown operand gives a new unknown;
// - two known values are multiplied.
TEST_F(ShapeInferenceTest, Multiply) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
                     {}, {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  auto d_1 = c.Dim(s, 3);

  // Multiplying unknown by a value other than 0 or 1 gives a new unknown,
  // i.e. a handle distinct from the unknown input.
  DimensionHandle out;
  EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Multiplying 0 to anything gives 0.
  EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));

  // Multiplying 1 to anything gives the original.
  // (unknown -> unknown)
  EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
  EXPECT_TRUE(SameHandle(d_unknown, out));
  // (known -> known)
  EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));
  EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
  EXPECT_TRUE(SameHandle(d_6, out));

  // Test multiplication.
  EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok());
  EXPECT_EQ("36", c.DebugString(out));

  // Test multiplication using dimension as second value.
  EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok());
  EXPECT_EQ("12", c.DebugString(out));
  EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
}
1504 
// FullyDefined is true only when the rank and every dimension are known.
TEST_F(ShapeInferenceTest, FullyDefined) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // A missing rank, or any unknown dimension, means not fully defined.
  const ShapeHandle unknown_rank = c.UnknownShape();
  EXPECT_FALSE(c.FullyDefined(unknown_rank));
  const ShapeHandle partly_known = c.Matrix(c.MakeDim(1), c.UnknownDim());
  EXPECT_FALSE(c.FullyDefined(partly_known));

  // Every dimension known, including the rank-0 case, means fully defined.
  const ShapeHandle known_matrix = c.Matrix(c.MakeDim(1), c.MakeDim(2));
  EXPECT_TRUE(c.FullyDefined(known_matrix));
  EXPECT_TRUE(c.FullyDefined(c.Scalar()));
}
1518 
// Min behavior covered here:
// - a known zero wins over an unknown;
// - otherwise any unknown operand gives unknown;
// - for two known values the smaller is returned, reusing its dimension
//   handle when it came from one.
TEST_F(ShapeInferenceTest, Min) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
                     {}, {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle one = c.Dim(in, 0);
  const DimensionHandle two = c.Dim(in, 1);
  const DimensionHandle unknown = c.Dim(in, 2);
  const DimensionHandle zero = c.Dim(in, 3);
  DimensionHandle out;

  // Zero wins over unknown, whichever side it is on.
  EXPECT_TRUE(c.Min(zero, unknown, &out).ok());
  EXPECT_TRUE(SameHandle(zero, out));
  EXPECT_TRUE(c.Min(unknown, zero, &out).ok());
  EXPECT_TRUE(SameHandle(zero, out));
  EXPECT_TRUE(c.Min(c.MakeDim(0ll), unknown, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));
  EXPECT_TRUE(c.Min(unknown, 0ll, &out).ok());
  EXPECT_EQ("0", c.DebugString(out));

  // Unknown combined with a non-zero value stays unknown.
  EXPECT_TRUE(c.Min(unknown, unknown, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(c.Min(unknown, 1, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(c.Min(one, unknown, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));

  // Dimension vs constant.
  EXPECT_TRUE(c.Min(one, 1, &out).ok());
  EXPECT_TRUE(SameHandle(one, out));
  EXPECT_TRUE(c.Min(one, 3, &out).ok());
  EXPECT_TRUE(SameHandle(one, out));
  EXPECT_TRUE(c.Min(two, 1, &out).ok());
  EXPECT_EQ("1", c.DebugString(out));

  // Dimension vs dimension: the smaller handle is returned in either order.
  struct MinCase {
    DimensionHandle lhs;
    DimensionHandle rhs;
    DimensionHandle expected;
  };
  const MinCase cases[] = {
      {one, one, one}, {one, two, one}, {two, one, one}, {two, two, two}};
  for (const MinCase& tc : cases) {
    EXPECT_TRUE(c.Min(tc.lhs, tc.rhs, &out).ok());
    EXPECT_TRUE(SameHandle(tc.expected, out));
  }
}
1567 
// Max behavior covered here:
// - any unknown operand gives unknown;
// - otherwise the larger value wins, and when that value came from a
//   dimension handle, that handle itself is returned.
TEST_F(ShapeInferenceTest, Max) {
  NodeDef def;
  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
                     {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle one = c.Dim(in, 0);
  const DimensionHandle two = c.Dim(in, 1);
  const DimensionHandle unknown = c.Dim(in, 2);
  DimensionHandle out;

  // Any unknown operand yields unknown.
  EXPECT_TRUE(c.Max(unknown, unknown, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(c.Max(unknown, 1, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(c.Max(one, unknown, &out).ok());
  EXPECT_EQ("?", c.DebugString(out));

  // Dimension vs constant.
  EXPECT_TRUE(c.Max(one, 1, &out).ok());
  EXPECT_TRUE(SameHandle(one, out));
  EXPECT_TRUE(c.Max(two, 1, &out).ok());
  EXPECT_TRUE(SameHandle(two, out));
  EXPECT_TRUE(c.Max(two, 3, &out).ok());
  EXPECT_EQ("3", c.DebugString(out));

  // Dimension vs dimension: the larger handle is returned in either order.
  struct MaxCase {
    DimensionHandle lhs;
    DimensionHandle rhs;
    DimensionHandle expected;
  };
  const MaxCase cases[] = {
      {one, one, one}, {one, two, two}, {two, one, two}, {two, two, two}};
  for (const MaxCase& tc : cases) {
    EXPECT_TRUE(c.Max(tc.lhs, tc.rhs, &out).ok());
    EXPECT_TRUE(SameHandle(tc.expected, out));
  }
}
1605 
TestMergeHandles(bool input_not_output)1606 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
1607   NodeDef def;
1608   InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1609                      {});
1610   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1611     ShapeHandle s;
1612     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1613     return s;
1614   };
1615   auto get_shapes_and_types_from_context = [&](int idx) {
1616     if (input_not_output) {
1617       return c.input_handle_shapes_and_types(idx);
1618     } else {
1619       return c.output_handle_shapes_and_types(idx);
1620     }
1621   };
1622   auto merge_shapes_and_types_to_context =
1623       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1624         if (input_not_output) {
1625           return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
1626         } else {
1627           return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
1628         }
1629       };
1630 
1631   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1632   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1633 
1634   // First merge will take the input completely.
1635   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1636                               {c.UnknownShape(), DT_INVALID},
1637                               {make_shape({4, 3, 2, 1}), DT_INT32}};
1638   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1639   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1640   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1641   ASSERT_EQ(3, v.size());
1642   for (int i = 0; i < v.size(); ++i) {
1643     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1644     EXPECT_EQ(t[i].dtype, v[i].dtype);
1645   }
1646 
1647   // Merge that fails because wrong number of values passed.
1648   // Fails, and no changes made.
1649   ASSERT_FALSE(merge_shapes_and_types_to_context(
1650       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1651   v = *get_shapes_and_types_from_context(0);
1652   ASSERT_EQ(3, v.size());
1653   for (int i = 0; i < v.size(); ++i) {
1654     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1655     EXPECT_EQ(t[i].dtype, v[i].dtype);
1656   }
1657 
1658   // Only difference is in a mismatched shape. That is ignored,
1659   // and there are no other changes, so nothing is done.
1660   //
1661   // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
1662   // return an error (separate error from 'refined' output)?
1663   auto t2 = t;
1664   t2[2].shape = make_shape({4, 3, 4, 1});
1665   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1666   v = *get_shapes_and_types_from_context(0);
1667   ASSERT_EQ(3, v.size());
1668   for (int i = 0; i < v.size(); ++i) {
1669     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1670     EXPECT_EQ(t[i].dtype, v[i].dtype);
1671   }
1672 
1673   // Only difference is in a mismatched dtype, but that cannot be
1674   // updated unless original dtype is DT_INVALID.
1675   t2 = t;
1676   t2[2].dtype = DT_FLOAT;
1677   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1678   v = *get_shapes_and_types_from_context(0);
1679   ASSERT_EQ(3, v.size());
1680   for (int i = 0; i < v.size(); ++i) {
1681     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1682     EXPECT_EQ(t[i].dtype, v[i].dtype);
1683   }
1684 
1685   // Difference is mergeable (new shape).
1686   t[1].shape = make_shape({1, 10});
1687   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1688   v = *get_shapes_and_types_from_context(0);
1689   ASSERT_EQ(3, v.size());
1690   for (int i = 0; i < v.size(); ++i) {
1691     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1692     EXPECT_EQ(t[i].dtype, v[i].dtype);
1693   }
1694 
1695   // Difference is mergeable (new type).
1696   t[1].dtype = DT_DOUBLE;
1697   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1698   v = *get_shapes_and_types_from_context(0);
1699   ASSERT_EQ(3, v.size());
1700   for (int i = 0; i < v.size(); ++i) {
1701     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1702     EXPECT_EQ(t[i].dtype, v[i].dtype);
1703   }
1704 
1705   // No difference.
1706   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
1707 }
1708 
// Runs the shared merge checks against the input-handle API.
TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
  TestMergeHandles(/*input_not_output=*/true);
}
1712 
// Runs the shared merge checks against the output-handle API.
TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
  TestMergeHandles(/*input_not_output=*/false);
}
1716 
TestRelaxHandles(bool input_not_output)1717 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
1718   NodeDef def;
1719   InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1720                      {});
1721   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
1722     ShapeHandle s;
1723     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1724     return s;
1725   };
1726   auto get_shapes_and_types_from_context = [&](int idx) {
1727     if (input_not_output) {
1728       return c.input_handle_shapes_and_types(idx);
1729     } else {
1730       return c.output_handle_shapes_and_types(idx);
1731     }
1732   };
1733   auto relax_shapes_and_types_to_context =
1734       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1735         if (input_not_output) {
1736           return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
1737         } else {
1738           return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
1739         }
1740       };
1741 
1742   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1743   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1744 
1745   // First relax will take the input completely.
1746   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1747                               {c.UnknownShape(), DT_INVALID},
1748                               {make_shape({4, 3, 2, 1}), DT_INT32}};
1749   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1750   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1751   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1752   ASSERT_EQ(3, v.size());
1753   for (int i = 0; i < v.size(); ++i) {
1754     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1755     EXPECT_EQ(t[i].dtype, v[i].dtype);
1756   }
1757 
1758   // Relax that fails because wrong number of values passed.
1759   // Fails, and no changes made.
1760   ASSERT_FALSE(relax_shapes_and_types_to_context(
1761       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1762   v = *get_shapes_and_types_from_context(0);
1763   ASSERT_EQ(3, v.size());
1764   for (int i = 0; i < v.size(); ++i) {
1765     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1766     EXPECT_EQ(t[i].dtype, v[i].dtype);
1767   }
1768 
1769   // Only difference is in a mismatched shape. This should replace
1770   // the mismatched dimension with an UnknownDim.
1771   auto t2 = t;
1772   t2[2].shape = make_shape({4, 3, 4, 1});
1773   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
1774   v = *get_shapes_and_types_from_context(0);
1775   EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
1776   for (int i = 0; i < v.size(); ++i) {
1777     EXPECT_EQ(t[i].dtype, v[i].dtype);
1778   }
1779 
1780   // Only difference is in a mismatched dtype, but that cannot be
1781   // updated unless original dtype is DT_INVALID.
1782   t2 = t;
1783   t2[2].dtype = DT_FLOAT;
1784   ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
1785   v = *get_shapes_and_types_from_context(0);
1786   ASSERT_EQ(3, v.size());
1787   for (int i = 0; i < v.size(); ++i) {
1788     EXPECT_EQ(t[i].dtype, v[i].dtype);
1789   }
1790 
1791   // Difference is a new shape, which will result in a new UnknownShape.
1792   t[1].shape = make_shape({1, 10});
1793   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1794   v = *get_shapes_and_types_from_context(0);
1795   ASSERT_EQ(3, v.size());
1796   EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
1797   EXPECT_EQ("?", c.DebugString(v[1].shape));
1798   for (int i = 0; i < v.size(); ++i) {
1799     EXPECT_EQ(t[i].dtype, v[i].dtype);
1800   }
1801 
1802   // Difference is relaxable (new type).
1803   t[1].dtype = DT_DOUBLE;
1804   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1805   v = *get_shapes_and_types_from_context(0);
1806   EXPECT_EQ(t[1].dtype, v[1].dtype);
1807 }
1808 
// Runs the shared relax checks against the input-handle API.
TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) {
  TestRelaxHandles(/*input_not_output=*/true);
}
1812 
// Runs the shared relax checks against the output-handle API.
TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) {
  TestRelaxHandles(/*input_not_output=*/false);
}
1816 
1817 }  // namespace shape_inference
1818 }  // namespace tensorflow
1819