1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_builder.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference_testutil.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27
TEST(MathOpsTest, AddN_ShapeFn) {
  ShapeInferenceTestOp op("AddN");

  // Rebuilds op.node_def as an AddN over `n` identical float inputs.
  auto finalize_with_n = [&op](int n) {
    const std::vector<NodeDefBuilder::NodeOut> inputs(
        n, NodeDefBuilder::NodeOut("a", 0, DT_FLOAT));
    TF_ASSERT_OK(NodeDefBuilder("test", "AddN")
                     .Input(inputs)
                     .Attr("N", n)
                     .Finalize(&op.node_def));
  };

  // Successful inferences: `n` is the arity the op is rebuilt with before the
  // check, `expected` uses the shape-inference test syntax (inK = forwarded
  // input K, dA_B = dimension B of input A, `|` = either is acceptable).
  struct OkCase {
    int n;
    const char* inputs;
    const char* expected;
  };
  const OkCase kOkCases[] = {
      // Adding two unknowns returns either input.
      {2, "?;?", "in0|in1"},
      // known+unknown returns the known input.
      {2, "[1];[?]", "in0"},
      {2, "[1];?", "in0"},
      {2, "[?];[1]", "in1"},
      {2, "?;[1]", "in1"},
      // An input at least as well-defined as the merge result is forwarded.
      {2, "[1,2];[?,2]", "in0"},
      {2, "[1,2];[1,2]", "in0|in1"},
      {2, "[?,2];[1,2]", "in1"},
      {3, "[1,?];[?,2];[1,2]", "in2"},
      {3, "[1,2];[?,2];[1,?]", "in0"},
      {3, "?;?;[1,2]", "in2"},
      {2, "?;[1,2]", "in1"},
      // Otherwise the output is assembled dimension-by-dimension.
      {2, "[1,?];[?,2]", "[d0_0,d1_1]"},
      {2, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"},
      {2, "[?,2];[1,?]", "[d1_0,d0_1]"},
  };
  for (const OkCase& c : kOkCases) {
    SCOPED_TRACE(c.inputs);
    finalize_with_n(c.n);
    INFER_OK(op, c.inputs, c.expected);
  }

  // Incompatible inputs: the error must describe both the mismatch itself and
  // the merge step that reported it.
  struct ErrorCase {
    int n;
    const char* inputs;
    const char* mismatch;
    const char* context;
  };
  const ErrorCase kErrorCases[] = {
      {3, "[1,2];?;[1,4]",
       "Dimension 1 in both shapes must be equal, but are 2 and 4",
       "From merging shape 0 with other shapes."},
      {4, "?;[1,2];?;[1,2,3]", "Shapes must be equal rank, but are 2 and 3",
       "From merging shape 1 with other shapes."},
  };
  for (const ErrorCase& e : kErrorCases) {
    finalize_with_n(e.n);
    INFER_ERROR(e.mismatch, op, e.inputs);
    INFER_ERROR(e.context, op, e.inputs);
  }
}
76
TEST(MathOpsTest, UnchangedShape_ShapeFn) {
  // Cast's output shape is its input shape, however much of it is known.
  ShapeInferenceTestOp op("Cast");
  for (const char* shape : {"?", "[?]", "[1,?,3,4]"}) {
    SCOPED_TRACE(shape);
    INFER_OK(op, shape, "in0");
  }
}
83
TEST(MathOpsTest, Segment_ShapeFn) {
  // Tests SegmentReductionShapeFn via every sorted segment-reduction op.
  struct Case {
    const char* inputs;
    const char* expected;
  };
  const Case kOkCases[] = {
      // Nothing known about the data shape.
      {"?;?", "?"},
      {"?;[100]", "?"},
      // Data shape with single dimension: output length is unknown.
      {"[?];?", "[?]"},
      {"[?];[100]", "[?]"},
      {"[1];?", "[?]"},
      {"[1];[100]", "[?]"},
      // Data shape with multiple dimensions: trailing dims pass through.
      {"[?,?];?", "[?,d0_1]"},
      {"[?,2];[100]", "[?,d0_1]"},
      {"[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"},
      {"[1,?];?", "[?,d0_1]"},
      {"[1,2];[100]", "[?,d0_1]"},
      {"[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"},
  };

  for (const char* op_name : {"SegmentMax", "SegmentMean", "SegmentMin",
                              "SegmentProd", "SegmentSum"}) {
    SCOPED_TRACE(op_name);
    ShapeInferenceTestOp op(op_name);

    for (const Case& c : kOkCases) {
      SCOPED_TRACE(c.inputs);
      INFER_OK(op, c.inputs, c.expected);
    }

    // The second input must be a vector; the first must have rank >= 1.
    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]");
  }
}
111
TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
  // Runs the broadcasting shape-inference expectations against `op`.
  //
  // `incompatible_shape_error` mirrors the op attribute of the same name:
  // when true, shapes are broadcast eagerly and [2];[3] is an error; when
  // false, any case whose compatibility depends on an unknown dimension
  // infers an unknown shape ("?"), and [2];[3] infers a scalar instead of
  // failing.
  auto test_shapes = [](ShapeInferenceTestOp& op,
                        bool incompatible_shape_error) {
    INFER_OK(op, "?;?", "?");
    INFER_OK(op, "[1,2];?", "?");
    INFER_OK(op, "?;[1,2]", "?");

    INFER_OK(op, "[?];[1]", "[d0_0]");
    INFER_OK(op, "[1];[?]", "[d1_0]");
    INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
    INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
    INFER_OK(op, "[?];[?]", "[?]");
    INFER_OK(op, "[];[?]", "[d1_0]");
    INFER_OK(op, "[?];[]", "[d0_0]");

    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[1]", "[d1_0]");
    INFER_OK(op, "[1];[]", "[d0_0]");

    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[2]", "[d1_0]");
    INFER_OK(op, "[1];[2]", "[d1_0]");
    INFER_OK(op, "[2];[1]", "[d0_0]");
    INFER_OK(op, "[2];[]", "[d0_0]");
    // ("[2];[?]" is already covered above; the duplicate was removed.)

    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[0]", "[d1_0]");
    INFER_OK(op, "[1];[0]", "[d1_0]");
    INFER_OK(op, "[0];[1]", "[d0_0]");
    INFER_OK(op, "[0];[]", "[d0_0]");

    INFER_OK(op, "[2];[?,?]", incompatible_shape_error ? "[d1_0,d0_0]" : "?");
    INFER_OK(op, "[2,2];[?,?,?]",
             incompatible_shape_error ? "[d1_0,d0_0,d0_1]" : "?");

    // Multiple dimension cases (same test cases, switching x and y).
    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
             incompatible_shape_error ? "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"
                                      : "?");
    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
             incompatible_shape_error ? "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"
                                      : "?");

    if (incompatible_shape_error) {
      INFER_ERROR("Dimensions must be equal", op, "[2];[3]");
    } else {
      INFER_OK(op, "[2];[3]", "[]");
    }
  };

  for (string op_name : {"Add", "Complex",
                         "Div", "Equal",
                         "Greater", "GreaterEqual",
                         "Igamma", "Igammac",
                         "Zeta", "Polygamma",
                         "Less", "LessEqual",
                         "LogicalAnd", "LogicalOr",
                         "Maximum", "Minimum",
                         "Mod", "Mul",
                         "NotEqual", "Pow",
                         "Sub", "SquaredDifference",
                         "DivNoNan"}) {
    SCOPED_TRACE(op_name);
    ShapeInferenceTestOp op(op_name);
    AddNodeAttr("incompatible_shape_error", true, &op.node_def);
    test_shapes(op, true);

    // Only Equal/NotEqual are exercised with the attribute set to false. A
    // separately named op is used (instead of shadowing `op`) so its NodeDef
    // carries only the `false` setting.
    if ((op_name == "Equal") || (op_name == "NotEqual")) {
      ShapeInferenceTestOp op_without_error(op_name);
      AddNodeAttr("incompatible_shape_error", false,
                  &op_without_error.node_def);
      test_shapes(op_without_error, false);
    }
  }
}
186
// Select(cond, t, e): the output shape is the merge of t and e (inputs 1 and
// 2), with cond (input 0) constrained against them. The second half checks
// that resource/variant handle data on t and e is merged the same way.
TEST(MathOpsTest, Select_ShapeFn) {
  ShapeInferenceTestOp op("Select");
  INFER_OK(op, "?;?;?", "in1|in2");

  // scalar case
  INFER_OK(op, "[];[1];?", "in1");
  INFER_OK(op, "[];?;?", "in1|in2");

  INFER_OK(op, "[1];?;?",
           "in1|in2");  // When cond is vector, t/e may not match it.
  // NOTE(review): the trailing '?' in "in1|in2?" looks unintended; confirm
  // how ShapeInferenceTestutil parses it (it may only ever match in1).
  INFER_OK(op, "[1,2];?;?", "in1|in2?");

  // Whichever of t/e has a known shape is forwarded.
  INFER_OK(op, "?;[];?", "in1");
  INFER_OK(op, "?;?;[]", "in2");
  INFER_OK(op, "?;[1];?", "in1");
  INFER_OK(op, "?;?;[1]", "in2");
  INFER_OK(op, "?;[1,2];?", "in1");
  INFER_OK(op, "?;?;[1,2]", "in2");

  // Rank mismatches between cond, t and e.
  INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?");
  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]");
  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
  INFER_OK(op, "[2];[?];[?]", "in1|in2");

  // A vector cond is checked against the leading dimension of t/e; a
  // same-rank cond is merged with them dimension-by-dimension.
  INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]");
  INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]");
  INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]");
  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
              "[2,?];[?,?,3];[?,2,?]");
  INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
  INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
              "[2,?,5];[?,?,3];[?,2,?]");

  // Test that handles were merged.
  //
  // Tests below will modify handle_data and call run_inference_for_handles to
  // rerun shape inference, updating the context <c>.
  const OpRegistrationData* op_reg_data;
  TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
  // One (shape, dtype) list per input; the InferenceContext built below takes
  // handle_data as the per-input handle data.
  typedef std::vector<std::pair<PartialTensorShape, DataType>> ShapeDtypeV;
  std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
  std::unique_ptr<shape_inference::InferenceContext> c;
  // Builds a fresh context from the current handle_data (all three input
  // shapes unknown) and runs Select's shape function on it.
  // NOTE(review): CHECK/TF_CHECK_OK abort the whole test binary on failure
  // rather than failing just this test.
  auto run_inference_for_handles = [&]() -> Status {
    CHECK(op_reg_data->shape_inference_fn != nullptr);
    c.reset(new shape_inference::InferenceContext(
        TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
        {PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {},
        {}, handle_data));
    TF_CHECK_OK(c->construction_status());
    Status s = c->Run(op_reg_data->shape_inference_fn);
    LOG(INFO) << "Inference got " << s;
    return s;
  };
  // Builds a TensorShapeProto with the given dimension sizes.
  auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
    TensorShapeProto p;
    for (auto i : dim_sizes) p.add_dim()->set_size(i);
    return p;
  };

  auto i0 = PartialTensorShape({1, -1});
  auto i1 = PartialTensorShape({-1, 2});
  PartialTensorShape unknown_shape;
  auto scalar = PartialTensorShape({});

  // Input 0 (cond) gets a scalar entry that could not merge with [1,?], so a
  // successful run shows it does not take part in the merge.
  handle_data.emplace_back(
      new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
  handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}});
  handle_data.emplace_back(
      new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});

  // Expected output handle data: entry-wise merge of inputs 1 and 2, i.e.
  // [1,?] + [?,2] -> [1,2] and [?,2] + unknown -> [?,2].
  TF_ASSERT_OK(run_inference_for_handles());
  auto* out = c->output_handle_shapes_and_types(0);
  ASSERT_EQ(2, out->size());
  EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
  EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
  EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
  EXPECT_EQ(DT_INT32, out->at(1).dtype);

  // Expect an error when the shapes can't be merged.
  handle_data[2]->at(0).first = shape_proto({2, 2});
  EXPECT_TRUE(absl::StrContains(run_inference_for_handles().error_message(),
                                "must be equal, but are 1 and 2"));
  handle_data[2]->at(0).first = i1;  // restore to valid

  // Expect an error when the types can't be merged.
  handle_data[2]->at(1).second = DT_INT64;
  EXPECT_TRUE(absl::StrContains(run_inference_for_handles().error_message(),
                                "pointing to different dtypes"));
  handle_data[2]->at(1).second = DT_INT32;  // restore to valid

  // Expect an error when different numbers of tensors are merged.
  handle_data[2]->push_back({i1, DT_FLOAT});
  EXPECT_TRUE(absl::StrContains(run_inference_for_handles().error_message(),
                                "pointing to different numbers of tensors"));
  handle_data[2]->pop_back();  // restore to valid.
}
283
// Range(start, limit, delta): every input must be a scalar. The output length
// is only inferred once all three values are available as constant tensors.
TEST(MathOpsTest, Range_ShapeFn) {
  ShapeInferenceTestOp op("Range");

  TF_ASSERT_OK(NodeDefBuilder("test", "Range")
                   .Input({"start", {}, DT_INT32})
                   .Input({"limit", {}, DT_INT32})
                   .Input({"delta", {}, DT_INT32})
                   .Attr("Tidx", DT_INT32)
                   .Finalize(&op.node_def));

  op.input_tensors.resize(3);
  INFER_OK(op, "?;?;?", "[?]");
  // Each rank error must also name the offending input.
  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
  INFER_ERROR("for 'start'", op, "[1,2];?;?");

  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
  INFER_ERROR("for 'limit'", op, "?;[1,2];?");

  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
  INFER_ERROR("for 'delta'", op, "?;?;[1,2]");

  // Supplying the values one at a time: with any value missing, the length
  // stays unknown. The tensors below are aliased by op.input_tensors, so
  // later reassignments change what inference sees.
  Tensor start_t = test::AsScalar(1);
  op.input_tensors[0] = &start_t;
  INFER_OK(op, "?;?;?", "[?]");
  Tensor limit_t = test::AsScalar(1);
  op.input_tensors[1] = &limit_t;
  INFER_OK(op, "?;?;?", "[?]");

  // start == limit: empty range.
  Tensor delta_t = test::AsScalar(1);
  op.input_tensors[2] = &delta_t;
  INFER_OK(op, "?;?;?", "[0]");

  delta_t = test::AsScalar(0);
  INFER_ERROR("Requires delta != 0", op, "?;?;?");
  // Positive delta again, so the next check takes the delta > 0 path.
  delta_t = test::AsScalar(3);

  // The error reports start/limit.
  limit_t = test::AsScalar(-1);
  INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?");

  // 1 down to -1 (exclusive) by -1: {1, 0}.
  delta_t = test::AsScalar(-1);
  INFER_OK(op, "?;?;?", "[2]");

  limit_t = test::AsScalar(4);
  INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?");

  // 2 up to 100 (exclusive) by 3: {2, 5, ..., 98} has 33 elements.
  limit_t = test::AsScalar(100);
  start_t = test::AsScalar(2);
  delta_t = test::AsScalar(3);
  INFER_OK(op, "?;?;?", "[33]");
}
334
TEST(MathOpsTest, LinSpace_ShapeFn) {
  ShapeInferenceTestOp op("LinSpace");
  op.input_tensors.resize(3);
  INFER_OK(op, "?;?;?", "[?]");

  // start, stop and num must each be scalars; the error names the input.
  struct RankCase {
    const char* inputs;
    const char* names_input;
  };
  const RankCase kRankCases[] = {
      {"[1,2];?;?", "for 'start'"},
      {"?;[1,2];?", "for 'stop'"},
      {"?;?;[1,2]", "for 'num'"},
  };
  for (const RankCase& rc : kRankCases) {
    INFER_ERROR("Shape must be rank 0 but is rank 2", op, rc.inputs);
    INFER_ERROR(rc.names_input, op, rc.inputs);
  }

  // With a constant num, the output length is exactly num.
  struct NumCase {
    int num;
    const char* expected;
  };
  const NumCase kNumCases[] = {{1, "[1]"}, {2, "[2]"}};
  Tensor num_t;
  op.input_tensors[2] = &num_t;
  for (const NumCase& nc : kNumCases) {
    num_t = test::AsScalar(nc.num);
    INFER_OK(op, "?;?;?", nc.expected);
  }

  num_t = test::AsScalar(-1);
  INFER_ERROR("Requires num > 0: -1", op, "?;?;?");
}
354
TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
  ShapeInferenceTestOp op("UnsortedSegmentSum");
  op.input_tensors.resize(3);

  // Without a constant third input the output shape stays unknown.
  for (const char* inputs : {"?;?;?", "?;[?];?", "?;[3];?"}) {
    SCOPED_TRACE(inputs);
    INFER_OK(op, inputs, "?");
  }

  struct ErrorCase {
    const char* inputs;
    const char* message;
  };
  const ErrorCase kErrorCases[] = {
      {"?;?;[1,2]", "Shape must be rank 0 but is rank 2"},
      {"[1,?,2];[1,?,3];?", "Dimensions must be equal, but are 2 and 3"},
      {"[1,2];[1,2,3];?", "Shape must be at least rank 3 but is rank 2"},
  };
  for (const ErrorCase& ec : kErrorCases) {
    INFER_ERROR(ec.message, op, ec.inputs);
  }

  // A constant of 100 becomes the leading output dimension, followed by the
  // data dimensions beyond the rank of the second input.
  Tensor num_segments_t = test::AsScalar(100);
  op.input_tensors[2] = &num_segments_t;
  INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]");

  num_segments_t = test::AsScalar(-1);
  INFER_ERROR(("Dimension size, given by scalar input 2, must be "
               "non-negative but is -1"),
              op, "[3];[3];?");
}
376
TEST(MathOpsTest, SparseSegment_ShapeFn) {
  ShapeInferenceTestOp op("SparseSegmentSum");
  op.input_tensors.resize(3);

  INFER_OK(op, "?;?;?", "?");
  // Leading output dimension unknown; remaining data dimensions pass through.
  INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]");

  // The second and third inputs must be vectors of the same length.
  struct ErrorCase {
    const char* inputs;
    const char* message;
  };
  const ErrorCase kErrorCases[] = {
      {"[2,4,3];[];[3]", "Shape must be rank 1 but is rank 0"},
      {"[2,4,3];[3];[3,4]", "Shape must be rank 1 but is rank 2"},
      {"[2,4,3];[3];[4]",
       "Dimension 0 in both shapes must be equal, but are 3 and 4"},
  };
  for (const ErrorCase& ec : kErrorCases) {
    INFER_ERROR(ec.message, op, ec.inputs);
  }
}
389
TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) {
  ShapeInferenceTestOp op("SparseSegmentMeanGrad");
  op.input_tensors.resize(4);
  // Well-formed input shapes reused below; only the value of input 3 varies.
  const string kInputs = "[2,4,3];[3];[3];[]";

  INFER_OK(op, "?;?;?;?", "?");
  // Unknown input-3 value: leading output dimension is unknown.
  INFER_OK(op, kInputs, "[?,d0_1,d0_2]");

  // A constant input 3 becomes the leading output dimension.
  Tensor leading_dim = test::AsScalar(100);
  op.input_tensors[3] = &leading_dim;
  INFER_OK(op, kInputs, "[100,d0_1,d0_2]");

  INFER_ERROR("Shape must be rank 0 but is rank 2", op,
              "[2,4,3];[3];[3];[1,1]");

  // Negative value is not allowed (op.input_tensors[3] still aliases
  // leading_dim, so reassigning it is enough).
  leading_dim = test::AsScalar(-100);
  INFER_ERROR("Cannot specify a negative value", op, kInputs);
}
408
TEST(MathOpsTest, BatchMatMul_ShapeFn) {
  ShapeInferenceTestOp op("BatchMatMul");
  // Rebuilds op.node_def with the given adjoint flags.
  auto set_adj = [&op](bool adj_x, bool adj_y) {
    TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul")
                     .Input({"a", 0, DT_FLOAT})
                     .Input({"b", 0, DT_FLOAT})
                     .Attr("adj_x", adj_x)
                     .Attr("adj_y", adj_y)
                     .Finalize(&op.node_def));
  };

  set_adj(false, false);

  // Both operands need at least two dimensions.
  INFER_ERROR("at least rank 2", op, "[1];?");
  INFER_ERROR("at least rank 2", op, "?;[2]");

  INFER_OK(op, "?;?", "?");
  // 0 batch dims.
  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
  // 2 batch dims, second operand unknown.
  INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");

  // For each adjoint combination: one pair that infers successfully (showing
  // which operand dimension each output dimension comes from) and one pair
  // whose inner (contracted) dimensions disagree.
  struct AdjCase {
    bool adj_x;
    bool adj_y;
    const char* ok_inputs;
    const char* ok_output;
    const char* mismatched_inputs;
  };
  const AdjCase kAdjCases[] = {
      {false, false, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]",
       "[?,1,2];[?,3,1]"},
      {true, false, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]",
       "[?,2,1];[?,3,1]"},
      {false, true, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]",
       "[?,1,2];[?,1,3]"},
      {true, true, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]",
       "[?,2,1];[?,1,3]"},
  };
  for (const AdjCase& ac : kAdjCases) {
    SCOPED_TRACE(::testing::Message()
                 << "adj_x=" << ac.adj_x << " adj_y=" << ac.adj_y);
    set_adj(ac.adj_x, ac.adj_y);
    INFER_OK(op, ac.ok_inputs, ac.ok_output);
    INFER_ERROR("are 2 and 3", op, ac.mismatched_inputs);
  }
}
450
TEST(MathOpsTest, ArgOps_ShapeFn) {
  ShapeInferenceTestOp op("ArgMax");
  op.input_tensors.resize(2);

  INFER_OK(op, "?;?", "?");

  // input rank <= 1 produces scalar
  INFER_OK(op, "[2];?", "[]");
  INFER_OK(op, "[];?", "[]");

  // The dimension input must be a scalar.
  INFER_ERROR("must be rank 0", op, "[2];[1]");

  // Dimension value unavailable but input rank known: output is an unknown
  // shape with rank one less than the input.
  INFER_OK(op, "[2,3,4];?", "[?,?]");
  INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]");

  // With a constant dimension, that axis is dropped from the input shape.
  Tensor dimension;
  op.input_tensors[1] = &dimension;

  struct AxisCase {
    int axis;
    const char* expected;
  };
  const AxisCase kValidAxes[] = {
      {0, "[d0_1,d0_2]"},
      {1, "[d0_0,d0_2]"},
      {2, "[d0_0,d0_1]"},
      {-1, "[d0_0,d0_1]"},  // Negative axes count from the end.
  };
  for (const AxisCase& ac : kValidAxes) {
    SCOPED_TRACE(ac.axis);
    dimension = test::AsScalar(ac.axis);
    INFER_OK(op, "[2,3,4];[]", ac.expected);
  }

  // Dimension value out of bounds.
  for (int axis : {10, -10}) {
    SCOPED_TRACE(axis);
    dimension = test::AsScalar(axis);
    INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
  }
}
495
TEST(MathOpsTest, Betainc_ShapeFn) {
  ShapeInferenceTestOp op("Betainc");

  struct Case {
    const char* inputs;
    const char* expected;
  };
  const Case kOkCases[] = {
      {"?;?;?", "?"},
      {"[?,?];?;?", "in0"},
      {"[?,2];?;[1,?]", "[d2_0,d0_1]"},
      {"[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]"},
      // Scalars are ignored when merging the non-scalar inputs.
      {"[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]"},
      {"[];[];[?,?,3]", "in2"},
      // All but one is a scalar, so use it.
      {"[];[];?", "in2"},
      {"[];[];[1,2,3,4]", "in2"},
      // All scalar input; implementation picks in0.
      {"[];[];[]", "in0"},
  };
  for (const Case& c : kOkCases) {
    SCOPED_TRACE(c.inputs);
    INFER_OK(op, c.inputs, c.expected);
  }

  // Non-scalars must match shape.
  for (const char* inputs : {"[1,2];[];[1,4]", "[1,2];[];[1,2,3]"}) {
    INFER_ERROR("must be equal", op, inputs);
  }
}
518
TEST(MathOpsTest, Requantize_ShapeFn) {
  ShapeInferenceTestOp op("Requantize");

  // Output 0 forwards input 0; outputs 1 and 2 are scalars.
  for (const char* inputs : {"?;?;?;?;?", "?;[];[];[];[]"}) {
    SCOPED_TRACE(inputs);
    INFER_OK(op, inputs, "in0;[];[]");
  }

  // Rank checks on input scalars: inputs 1 through 4 must each be rank 0.
  for (const char* inputs :
       {"?;[1];?;?;?", "?;?;[2];?;?", "?;?;?;[3];?", "?;?;?;?;[4]"}) {
    INFER_ERROR("must be rank 0", op, inputs);
  }
}
531
// Suite renamed from the misspelled "MathOpstest" so this test is grouped
// (and filtered, e.g. --gtest_filter=MathOpsTest.*) with the rest of the file.
TEST(MathOpsTest, RequantizationRange_ShapeFn) {
  ShapeInferenceTestOp op("RequantizationRange");

  // Both outputs are scalars regardless of what is known about the inputs.
  INFER_OK(op, "?;?;?", "[];[]");
  INFER_OK(op, "?;[];[]", "[];[]");

  // Rank checks on input scalars.
  INFER_ERROR("must be rank 0", op, "?;[1];?");
  INFER_ERROR("must be rank 0", op, "?;?;[2]");
}
542
TEST(MathOpsTest, Cross_ShapeFn) {
  ShapeInferenceTestOp op("Cross");

  // Operands must be non-scalar, agree in shape, and have innermost size 3.
  struct ErrorCase {
    const char* inputs;
    const char* message;
  };
  const ErrorCase kErrorCases[] = {
      {"[];[]", "Shape must be at least rank 1 but is rank 0"},
      {"[3];[5]", "Dimension 0 in both shapes must be equal, but"},
      {"[3,5];[3,5]", "Dimension must be 3 but"},
  };
  for (const ErrorCase& ec : kErrorCases) {
    INFER_ERROR(ec.message, op, ec.inputs);
  }

  for (const char* inputs : {"?;?", "[?];[?]", "[1,?,3];[?,?,?]"}) {
    SCOPED_TRACE(inputs);
    INFER_OK(op, inputs, "in0");
  }
}
554
TEST(MathOpsTest, HistogramFixedWidth_ShapeFn) {
  ShapeInferenceTestOp op("HistogramFixedWidth");

  struct ErrorCase {
    const char* inputs;
    const char* message;
  };
  const ErrorCase kErrorCases[] = {
      // value_range should be vector.
      {"[];[];[]", "Shape must be rank 1 but is rank 0"},
      // value_range should have 2 elements.
      {"[];[3];[]", "Dimension must be 2 but is 3"},
      // nbins should be scalar.
      {"[];[2];[2]", "Shape must be rank 0 but is rank 1"},
  };
  for (const ErrorCase& ec : kErrorCases) {
    INFER_ERROR(ec.message, op, ec.inputs);
  }

  // The output is always a vector of unknown length here.
  for (const char* inputs : {"?;?;?", "[?];[2];[]", "[?];[2];?"}) {
    SCOPED_TRACE(inputs);
    INFER_OK(op, inputs, "[?]");
  }
}
569
TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedAdd");

  // Output 0 is the broadcast of inputs 0 and 1; outputs 1 and 2 are scalars.
  struct Case {
    const char* inputs;
    const char* expected;
  };
  const Case kOkCases[] = {
      {"?;?;?;?;?;?", "?;[];[]"},
      {"?;?;[];[];[];[]", "?;[];[]"},
      {"[1,2];?;[];[];[];[]", "?;[];[]"},
      {"[];[2];[];[];[];[]", "[d1_0];[];[]"},
  };
  for (const Case& c : kOkCases) {
    SCOPED_TRACE(c.inputs);
    INFER_OK(op, c.inputs, c.expected);
  }

  // Rank checks on input scalars: inputs 2 through 5 must each be rank 0.
  for (const char* inputs : {"?;?;[1];?;?;?", "?;?;?;[2];?;?",
                             "?;?;?;?;[3];?", "?;?;?;?;?;[4]"}) {
    INFER_ERROR("must be rank 0", op, inputs);
  }
}
584
TEST(MathOpsTest, Bincount_ShapeFn) {
  ShapeInferenceTestOp op("Bincount");

  // size should be scalar.
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");

  // The output is a vector whose length is not inferred from these inputs.
  for (const char* inputs : {"?;?;?", "?;[];?", "[?];[];?", "[?];[];[?]"}) {
    SCOPED_TRACE(inputs);
    INFER_OK(op, inputs, "[?]");
  }
}
596
TEST(MathOpsTest, SobolSample) {
  ShapeInferenceTestOp op("SobolSample");

  // All inputs should be scalar.
  for (const char* inputs : {"[1];?;?", "?;[1];?", "?;?;[1]"}) {
    INFER_ERROR("must be rank 0", op, inputs);
  }

  // Given scalar inputs, the output is a matrix of unknown dimensions.
  INFER_OK(op, "[];[];[]", "[?,?]");
}
607
TEST(MathOpsTest, EqualOp) {
  ShapeInferenceTestOp op("Equal");
  AddNodeAttr("incompatible_shape_error", true, &op.node_def);

  // Unknown rank on either side leaves the result unknown.
  for (const char* inputs : {"?;?", "[1,2];?", "?;[1,2]"}) {
    SCOPED_TRACE(inputs);
    INFER_OK(op, inputs, "?");
  }

  // Broadcasting: each output dimension comes from whichever input is not 1.
  struct Case {
    const char* inputs;
    const char* expected;
  };
  const Case kBroadcastCases[] = {
      {"[1,2,3];[1]", "[d0_0,d0_1,d0_2]"},
      {"[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]"},
      {"[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]"},
      {"[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]"},
      // Note: Test case for GitHub issue 40471
      {"[?,10,1];[?,1,4]", "[?,d0_1,d1_2]"},
      {"[10,?,1];[1,?,4]", "[d0_0,?,d1_2]"},
  };
  for (const Case& c : kBroadcastCases) {
    SCOPED_TRACE(c.inputs);
    INFER_OK(op, c.inputs, c.expected);
  }
}
625 } // end namespace tensorflow
626