1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_builder.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference_testutil.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 
// Checks the AddN shape function, which merges the shapes of all N inputs.
// When the merged shape is exactly one of the inputs, that input's shape
// handle is expected to be forwarded to the output ("inK" expectations).
TEST(MathOpsTest, AddN_ShapeFn) {
  ShapeInferenceTestOp op("AddN");

  // Rebuilds op.node_def as an AddN over `num_inputs` float inputs.
  auto configure_inputs = [&op](int num_inputs) {
    const std::vector<NodeDefBuilder::NodeOut> inputs(
        num_inputs, NodeDefBuilder::NodeOut("a", 0, DT_FLOAT));
    TF_ASSERT_OK(NodeDefBuilder("test", "AddN")
                     .Input(inputs)
                     .Attr("N", num_inputs)
                     .Finalize(&op.node_def));
  };

  // Configures `num_inputs` inputs, then checks every
  // {input shapes, expected output} pair in order.
  using ShapePair = std::pair<const char*, const char*>;
  auto expect_merged = [&](int num_inputs,
                           std::initializer_list<ShapePair> cases) {
    configure_inputs(num_inputs);
    for (const ShapePair& merge_case : cases) {
      INFER_OK(op, merge_case.first, merge_case.second);
    }
  };

  expect_merged(2, {
                       // Two unknowns: either input may be forwarded.
                       {"?;?", "in0|in1"},
                       // known + unknown: the known input is forwarded.
                       {"[1];[?]", "in0"},
                       {"[1];?", "in0"},
                       {"[?];[1]", "in1"},
                       {"?;[1]", "in1"},
                   });
  expect_merged(2, {
                       {"[1,2];[?,2]", "in0"},
                       {"[1,2];[1,2]", "in0|in1"},
                       {"[?,2];[1,2]", "in1"},
                   });
  expect_merged(3, {
                       {"[1,?];[?,2];[1,2]", "in2"},
                       {"[1,2];[?,2];[1,?]", "in0"},
                       {"?;?;[1,2]", "in2"},
                   });
  // When no single input equals the merged shape, a new shape is built whose
  // known dimensions come from whichever input provides them.
  expect_merged(2, {
                       {"?;[1,2]", "in1"},
                       {"[1,?];[?,2]", "[d0_0,d1_1]"},
                       {"[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"},
                       {"[?,2];[1,?]", "[d1_0,d0_1]"},
                   });

  // Conflicting dimensions or ranks are reported along with the index of the
  // shape being merged.
  configure_inputs(3);
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op,
              "[1,2];?;[1,4]");
  INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]");
  configure_inputs(4);
  INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
              "?;[1,2];?;[1,2,3]");
  INFER_ERROR("From merging shape 1 with other shapes.", op,
              "?;[1,2];?;[1,2,3]");
}
76 
// Cast does not alter the shape of its input, so for every input shape the
// output is expected to reuse input 0's shape handle.
TEST(MathOpsTest, UnchangedShape_ShapeFn) {
  ShapeInferenceTestOp op("Cast");
  for (const char* input_shape : {"?", "[?]", "[1,?,3,4]"}) {
    INFER_OK(op, input_shape, "in0");
  }
}
83 
// Checks SegmentReductionShapeFn, which is shared by the sorted segment
// reduction ops below. The output keeps every trailing dimension of the data
// input (input 0) and replaces its leading dimension with an unknown size.
TEST(MathOpsTest, Segment_ShapeFn) {
  const char* const kOpNames[] = {"SegmentMax", "SegmentMean", "SegmentMin",
                                  "SegmentProd", "SegmentSum"};

  // Pairs of {input shapes, expected output shape}.
  const std::pair<const char*, const char*> kOkCases[] = {
      // Unknown-rank data: nothing can be inferred.
      {"?;?", "?"},
      {"?;[100]", "?"},
      // Rank-1 data: only the unknown leading dimension remains.
      {"[?];?", "[?]"},
      {"[?];[100]", "[?]"},
      {"[1];?", "[?]"},
      {"[1];[100]", "[?]"},
      // Higher-rank data: trailing dimensions are forwarded from input 0.
      {"[?,?];?", "[?,d0_1]"},
      {"[?,2];[100]", "[?,d0_1]"},
      {"[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"},
      {"[1,?];?", "[?,d0_1]"},
      {"[1,2];[100]", "[?,d0_1]"},
      {"[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"},
  };

  for (const char* op_name : kOpNames) {
    ShapeInferenceTestOp op(op_name);
    for (const auto& ok_case : kOkCases) {
      INFER_OK(op, ok_case.first, ok_case.second);
    }
    // The second input must be a vector, and the data must not be a scalar.
    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]");
  }
}
111 
TEST(MathOpsTest,BroadcastBinaryOps_ShapeFn)112 TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
113   auto test_shapes = [&](ShapeInferenceTestOp& op,
114                          bool incompatible_shape_error) {
115     INFER_OK(op, "?;?", "?");
116     INFER_OK(op, "[1,2];?", "?");
117     INFER_OK(op, "?;[1,2]", "?");
118 
119     INFER_OK(op, "[?];[1]", "[d0_0]");
120     INFER_OK(op, "[1];[?]", "[d1_0]");
121     INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
122     INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
123     INFER_OK(op, "[?];[?]", "[?]");
124     INFER_OK(op, "[];[?]", "[d1_0]");
125     INFER_OK(op, "[?];[]", "[d0_0]");
126 
127     INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
128     INFER_OK(op, "[];[1]", "[d1_0]");
129     INFER_OK(op, "[1];[]", "[d0_0]");
130 
131     INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
132     INFER_OK(op, "[];[2]", "[d1_0]");
133     INFER_OK(op, "[1];[2]", "[d1_0]");
134     INFER_OK(op, "[2];[1]", "[d0_0]");
135     INFER_OK(op, "[2];[]", "[d0_0]");
136     INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
137 
138     INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
139     INFER_OK(op, "[];[0]", "[d1_0]");
140     INFER_OK(op, "[1];[0]", "[d1_0]");
141     INFER_OK(op, "[0];[1]", "[d0_0]");
142     INFER_OK(op, "[0];[]", "[d0_0]");
143 
144     INFER_OK(op, "[2];[?,?]", incompatible_shape_error ? "[d1_0,d0_0]" : "?");
145     INFER_OK(op, "[2,2];[?,?,?]",
146              incompatible_shape_error ? "[d1_0,d0_0,d0_1]" : "?");
147 
148     // Multiple dimension cases (same test cases, switching x and y).
149     INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
150              incompatible_shape_error ? "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"
151                                       : "?");
152     INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
153              incompatible_shape_error ? "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"
154                                       : "?");
155 
156     if (incompatible_shape_error) {
157       INFER_ERROR("Dimensions must be equal", op, "[2];[3]");
158     } else {
159       INFER_OK(op, "[2];[3]", "[]");
160     }
161   };
162 
163   for (string op_name : {"Add",        "Complex",
164                          "Div",        "Equal",
165                          "Greater",    "GreaterEqual",
166                          "Igamma",     "Igammac",
167                          "Zeta",       "Polygamma",
168                          "Less",       "LessEqual",
169                          "LogicalAnd", "LogicalOr",
170                          "Maximum",    "Minimum",
171                          "Mod",        "Mul",
172                          "NotEqual",   "Pow",
173                          "Sub",        "SquaredDifference",
174                          "DivNoNan"}) {
175     ShapeInferenceTestOp op(op_name);
176     AddNodeAttr("incompatible_shape_error", true, &op.node_def);
177     test_shapes(op, true);
178 
179     if ((op_name == "Equal") || (op_name == "NotEqual")) {
180       ShapeInferenceTestOp op(op_name);
181       AddNodeAttr("incompatible_shape_error", false, &op.node_def);
182       test_shapes(op, false);
183     }
184   }
185 }
186 
TEST(MathOpsTest,Select_ShapeFn)187 TEST(MathOpsTest, Select_ShapeFn) {
188   ShapeInferenceTestOp op("Select");
189   INFER_OK(op, "?;?;?", "in1|in2");
190 
191   // scalar case
192   INFER_OK(op, "[];[1];?", "in1");
193   INFER_OK(op, "[];?;?", "in1|in2");
194 
195   INFER_OK(op, "[1];?;?",
196            "in1|in2");  // When cond is vector, t/e may not match it.
197   INFER_OK(op, "[1,2];?;?", "in1|in2?");
198 
199   INFER_OK(op, "?;[];?", "in1");
200   INFER_OK(op, "?;?;[]", "in2");
201   INFER_OK(op, "?;[1];?", "in1");
202   INFER_OK(op, "?;?;[1]", "in2");
203   INFER_OK(op, "?;[1,2];?", "in1");
204   INFER_OK(op, "?;?;[1,2]", "in2");
205 
206   INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?");
207   INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]");
208   INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
209   INFER_OK(op, "[2];[?];[?]", "in1|in2");
210 
211   INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]");
212   INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]");
213   INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]");
214   INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
215               "[2,?];[?,?,3];[?,2,?]");
216   INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
217   INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
218               "[2,?,5];[?,?,3];[?,2,?]");
219 
220   // Test that handles were merged.
221   //
222   // Tests below will modify handle_data and call run_inference_for_handles to
223   // rerun shape inference, updating the context <c>.
224   const OpRegistrationData* op_reg_data;
225   TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
226   typedef std::vector<std::pair<PartialTensorShape, DataType>> ShapeDtypeV;
227   std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
228   std::unique_ptr<shape_inference::InferenceContext> c;
229   auto run_inference_for_handles = [&]() -> Status {
230     CHECK(op_reg_data->shape_inference_fn != nullptr);
231     c.reset(new shape_inference::InferenceContext(
232         TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
233         {PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {},
234         {}, handle_data));
235     TF_CHECK_OK(c->construction_status());
236     Status s = c->Run(op_reg_data->shape_inference_fn);
237     LOG(INFO) << "Inference got " << s;
238     return s;
239   };
240   auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
241     TensorShapeProto p;
242     for (auto i : dim_sizes) p.add_dim()->set_size(i);
243     return p;
244   };
245 
246   auto i0 = PartialTensorShape({1, -1});
247   auto i1 = PartialTensorShape({-1, 2});
248   PartialTensorShape unknown_shape;
249   auto scalar = PartialTensorShape({});
250 
251   handle_data.emplace_back(
252       new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
253   handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}});
254   handle_data.emplace_back(
255       new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});
256 
257   TF_ASSERT_OK(run_inference_for_handles());
258   auto* out = c->output_handle_shapes_and_types(0);
259   ASSERT_EQ(2, out->size());
260   EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
261   EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
262   EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
263   EXPECT_EQ(DT_INT32, out->at(1).dtype);
264 
265   // Expect an error when the shapes can't be merged.
266   handle_data[2]->at(0).first = shape_proto({2, 2});
267   EXPECT_TRUE(absl::StrContains(run_inference_for_handles().error_message(),
268                                 "must be equal, but are 1 and 2"));
269   handle_data[2]->at(0).first = i1;  // restore to valid
270 
271   // Expect an error when the types can't be merged.
272   handle_data[2]->at(1).second = DT_INT64;
273   EXPECT_TRUE(absl::StrContains(run_inference_for_handles().error_message(),
274                                 "pointing to different dtypes"));
275   handle_data[2]->at(1).second = DT_INT32;  // restore to valid
276 
277   // Expect an error when different numbers of tensors are merged.
278   handle_data[2]->push_back({i1, DT_FLOAT});
279   EXPECT_TRUE(absl::StrContains(run_inference_for_handles().error_message(),
280                                 "pointing to different numbers of tensors"));
281   handle_data[2]->pop_back();  // restore to valid.
282 }
283 
// Checks the Range shape function. Each of start/limit/delta must be a
// scalar. The output is a vector whose length is known only once all three
// inputs are constant.
TEST(MathOpsTest, Range_ShapeFn) {
  ShapeInferenceTestOp op("Range");

  TF_ASSERT_OK(NodeDefBuilder("test", "Range")
                   .Input({"start", {}, DT_INT32})
                   .Input({"limit", {}, DT_INT32})
                   .Input({"delta", {}, DT_INT32})
                   .Attr("Tidx", DT_INT32)
                   .Finalize(&op.node_def));
  op.input_tensors.resize(3);

  INFER_OK(op, "?;?;?", "[?]");

  // Pairs of {input shapes, error text naming the non-scalar input}.
  const std::pair<const char*, const char*> kNonScalarCases[] = {
      {"[1,2];?;?", "for 'start'"},
      {"?;[1,2];?", "for 'limit'"},
      {"?;?;[1,2]", "for 'delta'"},
  };
  for (const auto& bad_rank : kNonScalarCases) {
    INFER_ERROR("Shape must be rank 0 but is rank 2", op, bad_rank.first);
    INFER_ERROR(bad_rank.second, op, bad_rank.first);
  }

  // Binding constants one at a time: the length stays unknown until every
  // input is constant.
  Tensor start_t = test::AsScalar(1);
  op.input_tensors[0] = &start_t;
  INFER_OK(op, "?;?;?", "[?]");
  Tensor limit_t = test::AsScalar(1);
  op.input_tensors[1] = &limit_t;
  INFER_OK(op, "?;?;?", "[?]");
  Tensor delta_t = test::AsScalar(1);
  op.input_tensors[2] = &delta_t;

  // With all three inputs constant, either the exact length or an error is
  // produced.
  struct RangeCase {
    int start;
    int limit;
    int delta;
    const char* expected_shape;  // Used when expected_error is nullptr.
    const char* expected_error;  // nullptr when inference should succeed.
  };
  const RangeCase kConstantCases[] = {
      {1, 1, 1, "[0]", nullptr},
      {1, 1, 0, nullptr, "Requires delta != 0"},
      {1, -1, 3, nullptr, "Requires start <= limit when delta > 0: 1/-1"},
      {1, -1, -1, "[2]", nullptr},
      {1, 4, -1, nullptr, "Requires start >= limit when delta < 0: 1/4"},
      {2, 100, 3, "[33]", nullptr},
  };
  for (const RangeCase& range_case : kConstantCases) {
    start_t = test::AsScalar(range_case.start);
    limit_t = test::AsScalar(range_case.limit);
    delta_t = test::AsScalar(range_case.delta);
    if (range_case.expected_error != nullptr) {
      INFER_ERROR(range_case.expected_error, op, "?;?;?");
    } else {
      INFER_OK(op, "?;?;?", range_case.expected_shape);
    }
  }
}
334 
// Checks the LinSpace shape function. start, stop and num must be scalars.
// The output is a vector whose length equals `num` when `num` is constant.
TEST(MathOpsTest, LinSpace_ShapeFn) {
  ShapeInferenceTestOp op("LinSpace");
  op.input_tensors.resize(3);

  // Without a constant `num`, only the output rank is known.
  INFER_OK(op, "?;?;?", "[?]");

  // Pairs of {input shapes, error text naming the non-scalar input}.
  const std::pair<const char*, const char*> kNonScalarCases[] = {
      {"[1,2];?;?", "for 'start'"},
      {"?;[1,2];?", "for 'stop'"},
      {"?;?;[1,2]", "for 'num'"},
  };
  for (const auto& bad_rank : kNonScalarCases) {
    INFER_ERROR("Shape must be rank 0 but is rank 2", op, bad_rank.first);
    INFER_ERROR(bad_rank.second, op, bad_rank.first);
  }

  // Pairs of {constant num, expected output shape}.
  const std::pair<int, const char*> kNumCases[] = {{1, "[1]"}, {2, "[2]"}};
  Tensor num_t;
  op.input_tensors[2] = &num_t;
  for (const auto& num_case : kNumCases) {
    num_t = test::AsScalar(num_case.first);
    INFER_OK(op, "?;?;?", num_case.second);
  }

  // A num that is not > 0 is rejected.
  num_t = test::AsScalar(-1);
  INFER_ERROR("Requires num > 0: -1", op, "?;?;?");
}
354 
// Checks the UnsortedSegmentSum shape function. Input 2 is a scalar segment
// count. Once it is constant, the output is [count] followed by the data
// dimensions not covered by the second input's rank.
TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
  ShapeInferenceTestOp op("UnsortedSegmentSum");
  op.input_tensors.resize(3);

  // Cases evaluated while the segment count is not a constant.
  struct ShapeCase {
    const char* inputs;
    const char* expected_shape;  // Used when expected_error is nullptr.
    const char* expected_error;  // nullptr when inference should succeed.
  };
  const ShapeCase kUnknownCountCases[] = {
      {"?;?;?", "?", nullptr},
      {"?;[?];?", "?", nullptr},
      // The segment count must be a scalar.
      {"?;?;[1,2]", nullptr, "Shape must be rank 0 but is rank 2"},
      // The second input must match the leading dimensions of the data.
      {"[1,?,2];[1,?,3];?", nullptr,
       "Dimensions must be equal, but are 2 and 3"},
      {"?;[3];?", "?", nullptr},
      {"[1,2];[1,2,3];?", nullptr,
       "Shape must be at least rank 3 but is rank 2"},
  };
  for (const ShapeCase& shape_case : kUnknownCountCases) {
    if (shape_case.expected_error != nullptr) {
      INFER_ERROR(shape_case.expected_error, op, shape_case.inputs);
    } else {
      INFER_OK(op, shape_case.inputs, shape_case.expected_shape);
    }
  }

  Tensor num_segments = test::AsScalar(100);
  op.input_tensors[2] = &num_segments;
  INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]");

  // A negative segment count is rejected.
  num_segments = test::AsScalar(-1);
  INFER_ERROR(("Dimension size, given by scalar input 2, must be "
               "non-negative but is -1"),
              op, "[3];[3];?");
}
376 
// Checks the SparseSegmentSum shape function. The output keeps the trailing
// data dimensions with an unknown leading dimension. Inputs 1 and 2 must be
// vectors of equal length.
TEST(MathOpsTest, SparseSegment_ShapeFn) {
  ShapeInferenceTestOp op("SparseSegmentSum");
  op.input_tensors.resize(3);

  // Pairs of {input shapes, expected output shape}.
  const std::pair<const char*, const char*> kOkCases[] = {
      {"?;?;?", "?"},
      {"[2,4,3];[3];[3]", "[?,d0_1,d0_2]"},
  };
  for (const auto& ok_case : kOkCases) {
    INFER_OK(op, ok_case.first, ok_case.second);
  }

  // Pairs of {expected error text, input shapes}.
  const std::pair<const char*, const char*> kErrorCases[] = {
      {"Shape must be rank 1 but is rank 0", "[2,4,3];[];[3]"},
      {"Shape must be rank 1 but is rank 2", "[2,4,3];[3];[3,4]"},
      {"Dimension 0 in both shapes must be equal, but are 3 and 4",
       "[2,4,3];[3];[4]"},
  };
  for (const auto& error_case : kErrorCases) {
    INFER_ERROR(error_case.first, op, error_case.second);
  }
}
389 
// Checks the SparseSegmentMeanGrad shape function. Input 3 is a scalar that,
// when constant, becomes the leading output dimension. It must not be
// negative.
TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) {
  ShapeInferenceTestOp op("SparseSegmentMeanGrad");
  op.input_tensors.resize(4);
  const char* const kInputShapes = "[2,4,3];[3];[3];[]";

  // Without a constant input 3, the leading output dimension is unknown.
  INFER_OK(op, "?;?;?;?", "?");
  INFER_OK(op, kInputShapes, "[?,d0_1,d0_2]");

  Tensor output_dim0 = test::AsScalar(100);
  op.input_tensors[3] = &output_dim0;
  INFER_OK(op, kInputShapes, "[100,d0_1,d0_2]");

  // Input 3 must be a scalar.
  INFER_ERROR("Shape must be rank 0 but is rank 2", op,
              "[2,4,3];[3];[3];[1,1]");

  // Negative value is not allowed.
  output_dim0 = test::AsScalar(-100);
  INFER_ERROR("Cannot specify a negative value", op, kInputShapes);
}
408 
// Checks the BatchMatMul shape function under each adj_x/adj_y combination.
// Both inputs must have rank >= 2. In the cases below, the batch dimensions
// come from the first input and the inner matrix dimensions must agree.
TEST(MathOpsTest, BatchMatMul_ShapeFn) {
  ShapeInferenceTestOp op("BatchMatMul");

  // Rebuilds op.node_def with the given adjoint flags.
  auto configure_adjoints = [&op](bool adj_x, bool adj_y) {
    TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul")
                     .Input({"a", 0, DT_FLOAT})
                     .Input({"b", 0, DT_FLOAT})
                     .Attr("adj_x", adj_x)
                     .Attr("adj_y", adj_y)
                     .Finalize(&op.node_def));
  };

  configure_adjoints(false, false);

  // Rank checks.
  INFER_ERROR("at least rank 2", op, "[1];?");
  INFER_ERROR("at least rank 2", op, "?;[2]");

  INFER_OK(op, "?;?", "?");
  // No batch dimensions.
  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
  // Two batch dimensions with an unknown second input.
  INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");

  // For each adjoint setting: one successful case that checks which input
  // dimensions reach the output, and one inner-dimension mismatch.
  struct AdjointCase {
    bool adj_x;
    bool adj_y;
    const char* ok_inputs;
    const char* ok_output;
    const char* mismatched_inputs;
  };
  const AdjointCase kAdjointCases[] = {
      {false, false, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]",
       "[?,1,2];[?,3,1]"},
      {true, false, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]",
       "[?,2,1];[?,3,1]"},
      {false, true, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]",
       "[?,1,2];[?,1,3]"},
      {true, true, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]",
       "[?,2,1];[?,1,3]"},
  };
  for (const AdjointCase& adjoint_case : kAdjointCases) {
    configure_adjoints(adjoint_case.adj_x, adjoint_case.adj_y);
    INFER_OK(op, adjoint_case.ok_inputs, adjoint_case.ok_output);
    INFER_ERROR("are 2 and 3", op, adjoint_case.mismatched_inputs);
  }
}
450 
// Checks the ArgMax shape function. Input 1 is a scalar dimension index, and
// the output drops that dimension from input 0.
TEST(MathOpsTest, ArgOps_ShapeFn) {
  ShapeInferenceTestOp op("ArgMax");
  op.input_tensors.resize(2);

  INFER_OK(op, "?;?", "?");

  // Inputs of rank <= 1 reduce to a scalar.
  INFER_OK(op, "[2];?", "[]");
  INFER_OK(op, "[];?", "[]");

  // The dimension input must be a scalar.
  INFER_ERROR("must be rank 0", op, "[2];[1]");

  // Unknown dimension with a known input rank: the output has rank one less
  // than the input, and every dimension is unknown.
  INFER_OK(op, "[2,3,4];?", "[?,?]");
  INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]");

  Tensor dimension;
  op.input_tensors[1] = &dimension;

  // Pairs of {dimension, expected output for a [2,3,4] input}.
  const std::pair<int, const char*> kInRangeDims[] = {
      {0, "[d0_1,d0_2]"},
      {1, "[d0_0,d0_2]"},
      {2, "[d0_0,d0_1]"},
  };
  for (const auto& dim_case : kInRangeDims) {
    dimension = test::AsScalar(dim_case.first);
    INFER_OK(op, "[2,3,4];[]", dim_case.second);
  }

  // Dimension value out of bounds.
  for (int out_of_range : {10, -10}) {
    dimension = test::AsScalar(out_of_range);
    INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
  }

  // A negative in-range dimension counts from the end.
  dimension = test::AsScalar(-1);
  INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
}
495 
// Checks the Betainc shape function. As the cases below show, the three
// input shapes are merged, and scalar inputs are skipped when another input
// is not a scalar.
TEST(MathOpsTest, Betainc_ShapeFn) {
  ShapeInferenceTestOp op("Betainc");

  // Pairs of {input shapes, expected output shape}.
  const std::pair<const char*, const char*> kOkCases[] = {
      {"?;?;?", "?"},
      {"[?,?];?;?", "in0"},
      {"[?,2];?;[1,?]", "[d2_0,d0_1]"},
      {"[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]"},
      // A scalar does not constrain the merged shape.
      {"[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]"},
      {"[];[];[?,?,3]", "in2"},
      // All but one is a scalar, so use it.
      {"[];[];?", "in2"},
      {"[];[];[1,2,3,4]", "in2"},
      // All scalar input; implementation picks in0.
      {"[];[];[]", "in0"},
  };
  for (const auto& ok_case : kOkCases) {
    INFER_OK(op, ok_case.first, ok_case.second);
  }

  // Non-scalars must match shape.
  for (const char* mismatched : {"[1,2];[];[1,4]", "[1,2];[];[1,2,3]"}) {
    INFER_ERROR("must be equal", op, mismatched);
  }
}
518 
// Checks the Requantize shape function. Output 0 forwards input 0, and the
// other two outputs are scalars. Inputs 1-4 must be scalars.
TEST(MathOpsTest, Requantize_ShapeFn) {
  ShapeInferenceTestOp op("Requantize");

  for (const char* input_shapes : {"?;?;?;?;?", "?;[];[];[];[]"}) {
    INFER_OK(op, input_shapes, "in0;[];[]");
  }

  // Each of inputs 1-4 is rejected when it is not a scalar.
  for (const char* non_scalar :
       {"?;[1];?;?;?", "?;?;[2];?;?", "?;?;?;[3];?", "?;?;?;?;[4]"}) {
    INFER_ERROR("must be rank 0", op, non_scalar);
  }
}
531 
// Checks the RequantizationRange shape function. Both outputs are scalars,
// and inputs 1 and 2 must be scalars.
//
// Registered under the same MathOpsTest suite as every other test in this
// file, so suite-level filters such as --gtest_filter=MathOpsTest.* pick it
// up.
TEST(MathOpsTest, RequantizationRange_ShapeFn) {
  ShapeInferenceTestOp op("RequantizationRange");

  INFER_OK(op, "?;?;?", "[];[]");
  INFER_OK(op, "?;[];[]", "[];[]");

  // Rank checks on input scalars.
  INFER_ERROR("must be rank 0", op, "?;[1];?");
  INFER_ERROR("must be rank 0", op, "?;?;[2]");
}
542 
// Checks the Cross shape function. The inputs must have rank >= 1 and equal
// shapes whose last dimension is 3. The output forwards input 0.
TEST(MathOpsTest, Cross_ShapeFn) {
  ShapeInferenceTestOp op("Cross");

  // Pairs of {expected error text, input shapes}.
  const std::pair<const char*, const char*> kErrorCases[] = {
      {"Shape must be at least rank 1 but is rank 0", "[];[]"},
      {"Dimension 0 in both shapes must be equal, but", "[3];[5]"},
      {"Dimension must be 3 but", "[3,5];[3,5]"},
  };
  for (const auto& error_case : kErrorCases) {
    INFER_ERROR(error_case.first, op, error_case.second);
  }

  for (const char* input_shapes : {"?;?", "[?];[?]", "[1,?,3];[?,?,?]"}) {
    INFER_OK(op, input_shapes, "in0");
  }
}
554 
// Checks the HistogramFixedWidth shape function. value_range must be a
// 2-element vector and nbins must be a scalar. The output is a vector of
// unknown length.
TEST(MathOpsTest, HistogramFixedWidth_ShapeFn) {
  ShapeInferenceTestOp op("HistogramFixedWidth");

  // Pairs of {expected error text, input shapes}.
  const std::pair<const char*, const char*> kErrorCases[] = {
      // value_range should be vector.
      {"Shape must be rank 1 but is rank 0", "[];[];[]"},
      // value_range should have 2 elements.
      {"Dimension must be 2 but is 3", "[];[3];[]"},
      // nbins should be scalar.
      {"Shape must be rank 0 but is rank 1", "[];[2];[2]"},
  };
  for (const auto& error_case : kErrorCases) {
    INFER_ERROR(error_case.first, op, error_case.second);
  }

  for (const char* input_shapes : {"?;?;?", "[?];[2];[]", "[?];[2];?"}) {
    INFER_OK(op, input_shapes, "[?]");
  }
}
569 
// Checks the QuantizedAdd shape function. Output 0 is derived from inputs 0
// and 1, and the other two outputs are scalars. Inputs 2-5 must be scalars.
TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedAdd");

  // Pairs of {input shapes, expected output shapes}.
  const std::pair<const char*, const char*> kOkCases[] = {
      {"?;?;?;?;?;?", "?;[];[]"},
      {"?;?;[];[];[];[]", "?;[];[]"},
      {"[1,2];?;[];[];[];[]", "?;[];[]"},
      {"[];[2];[];[];[];[]", "[d1_0];[];[]"},
  };
  for (const auto& ok_case : kOkCases) {
    INFER_OK(op, ok_case.first, ok_case.second);
  }

  // Each of inputs 2-5 is rejected when it is not a scalar.
  for (const char* non_scalar : {"?;?;[1];?;?;?", "?;?;?;[2];?;?",
                                 "?;?;?;?;[3];?", "?;?;?;?;?;[4]"}) {
    INFER_ERROR("must be rank 0", op, non_scalar);
  }
}
584 
// Checks the Bincount shape function. The size input (input 1) must be a
// scalar, and the output is a vector of unknown length.
TEST(MathOpsTest, Bincount_ShapeFn) {
  ShapeInferenceTestOp op("Bincount");

  // size should be scalar.
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");

  for (const char* input_shapes : {"?;?;?", "?;[];?", "[?];[];?", "[?];[];[?]"}) {
    INFER_OK(op, input_shapes, "[?]");
  }
}
596 
// Checks the SobolSample shape function. All three inputs must be scalars,
// and the output is a rank-2 shape with unknown dimensions.
TEST(MathOpsTest, SobolSample) {
  ShapeInferenceTestOp op("SobolSample");

  for (const char* non_scalar : {"[1];?;?", "?;[1];?", "?;?;[1]"}) {
    INFER_ERROR("must be rank 0", op, non_scalar);
  }

  INFER_OK(op, "[];[];[]", "[?,?]");
}
607 
// Checks broadcasting in the Equal shape function with
// incompatible_shape_error=true, including partially unknown shapes.
TEST(MathOpsTest, EqualOp) {
  ShapeInferenceTestOp op("Equal");
  AddNodeAttr("incompatible_shape_error", true, &op.node_def);

  // Pairs of {input shapes, expected output shape}.
  const std::pair<const char*, const char*> kCases[] = {
      {"?;?", "?"},
      {"[1,2];?", "?"},
      {"?;[1,2]", "?"},
      {"[1,2,3];[1]", "[d0_0,d0_1,d0_2]"},
      {"[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]"},
      {"[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]"},
      {"[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]"},
      // Note: Test case for GitHub issue 40471
      {"[?,10,1];[?,1,4]", "[?,d0_1,d1_2]"},
      {"[10,?,1];[1,?,4]", "[d0_0,?,d1_2]"},
  };
  for (const auto& equal_case : kCases) {
    INFER_OK(op, equal_case.first, equal_case.second);
  }
}
625 }  // end namespace tensorflow
626