1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_split.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/shape_inference.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/util/einsum_op_util.h"
27 
28 namespace tensorflow {
29 
30 namespace shape_inference {
31 
32 // The V2 version computes windowed output size with arbitrary dilation_rate and
33 // explicit padding, while the original version only handles the cases where
34 // dilation_rates equal to 1 and the padding is SAME or VALID.
GetWindowedOutputSizeFromDimsV2(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64 dilation_rate,int64 stride,Padding padding_type,int64 padding_before,int64 padding_after,shape_inference::DimensionHandle * output_size)35 Status GetWindowedOutputSizeFromDimsV2(
36     shape_inference::InferenceContext* c,
37     shape_inference::DimensionHandle input_size,
38     shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
39     int64 stride, Padding padding_type, int64 padding_before,
40     int64 padding_after, shape_inference::DimensionHandle* output_size) {
41   if (stride <= 0) {
42     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
43   }
44 
45   if (dilation_rate < 1) {
46     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
47                                    dilation_rate);
48   }
49 
50   // See also the parallel implementation in GetWindowedOutputSizeVerbose.
51   switch (padding_type) {
52     case Padding::VALID:
53       padding_before = padding_after = 0;
54       TF_FALLTHROUGH_INTENDED;
55     case Padding::EXPLICIT:
56       TF_RETURN_IF_ERROR(
57           c->Add(input_size, padding_before + padding_after, &input_size));
58       if (dilation_rate > 1) {
59         DimensionHandle window_size;
60         TF_RETURN_IF_ERROR(
61             c->Subtract(c->MakeDim(filter_size), 1, &window_size));
62         TF_RETURN_IF_ERROR(
63             c->Multiply(window_size, dilation_rate, &window_size));
64         TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
65         TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
66       } else {
67         TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
68       }
69       TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
70       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
71                                    /*evenly_divisible=*/false, output_size));
72       break;
73     case Padding::SAME:
74       TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
75       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
76                                    /*evenly_divisible=*/false, output_size));
77       break;
78   }
79   return Status::OK();
80 }
81 
GetWindowedOutputSizeFromDims(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64 stride,Padding padding_type,shape_inference::DimensionHandle * output_size)82 Status GetWindowedOutputSizeFromDims(
83     shape_inference::InferenceContext* c,
84     shape_inference::DimensionHandle input_size,
85     shape_inference::DimensionOrConstant filter_size, int64 stride,
86     Padding padding_type, shape_inference::DimensionHandle* output_size) {
87   if (padding_type == Padding::EXPLICIT) {
88     return errors::Internal(
89         "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
90         "GetWindowedOutputSizeFromDimsV2 instead");
91   }
92   return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
93                                          /*dilation_rate=*/1, stride,
94                                          padding_type,
95                                          // Give dummy values of -1 to
96                                          // padding_before and padding_after,
97                                          // since explicit padding is not used.
98                                          -1, -1, output_size);
99 }
100 
UnchangedShape(shape_inference::InferenceContext * c)101 Status UnchangedShape(shape_inference::InferenceContext* c) {
102   c->set_output(0, c->input(0));
103   auto* handle_data = c->input_handle_shapes_and_types(0);
104   if (handle_data != nullptr) {
105     c->set_output_handle_shapes_and_types(0, *handle_data);
106   }
107   return Status::OK();
108 }
109 
MatMulShape(shape_inference::InferenceContext * c)110 Status MatMulShape(shape_inference::InferenceContext* c) {
111   ShapeHandle a;
112   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
113 
114   ShapeHandle b;
115   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
116 
117   bool transpose_a, transpose_b;
118   TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
119   TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
120   DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
121   DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
122 
123   // Validate that the inner shapes are compatible.
124   DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
125   DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
126   DimensionHandle merged;
127   TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
128 
129   c->set_output(0, c->Matrix(output_rows, output_cols));
130   return Status::OK();
131 }
132 
133 namespace {
134 
135 // Validate that an Einsum subscript contains exactly one or zero ellipsis; and
136 // that periods (.) occur only within an ellipses (...).
ValidateEinsumEllipsis(absl::string_view subscript,bool * found_ellipsis)137 Status ValidateEinsumEllipsis(absl::string_view subscript,
138                               bool* found_ellipsis) {
139   const int num_periods = absl::c_count(subscript, '.');
140   if (num_periods != 0 && num_periods != 3) {
141     return errors::InvalidArgument(
142         "Expected at most one ellipsis (...), but found ", num_periods,
143         " periods (.) in the input subscript: ", subscript);
144   }
145   if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
146     return errors::InvalidArgument(
147         "Periods found outside of ellipsis in subscript: ", subscript);
148   }
149   *found_ellipsis = num_periods > 0;
150   return Status::OK();
151 }
152 
153 }  // namespace
154 
EinsumShape(shape_inference::InferenceContext * c)155 Status EinsumShape(shape_inference::InferenceContext* c) {
156   // We assume that the equation has a valid format. Either (x),(y)->(z)
157   // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or
158   // more latin alphabets and contains at most one ellipsis ('...').
159   string equation;
160   TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
161   gtl::InlinedVector<string, 2> input_labels;
162   string output_labels;
163   TF_RETURN_IF_ERROR(
164       ParseEinsumEquation(equation, &input_labels, &output_labels));
165 
166   if (c->num_inputs() == 0 || c->num_inputs() > 2) {
167     return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
168                                    c->num_inputs());
169   }
170   const int input_labels_size = input_labels.size();
171   if (c->num_inputs() != input_labels_size) {
172     return errors::InvalidArgument("Expected ", input_labels.size(),
173                                    " inputs for equation ", equation,
174                                    " but got: ", c->num_inputs());
175   }
176 
177   // Validate input subscripts, build the label to dimension mapping and obtain
178   // the broadcast shapes that map to ellipsis.
179   absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
180   gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
181   for (int i = 0, end = c->num_inputs(); i < end; ++i) {
182     bool has_ellipsis = false;
183     TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
184     ShapeHandle input_shape = c->input(i);
185     // Validate that the input rank is sufficient for the given number of named
186     // labels.
187     if (c->RankKnown(input_shape)) {
188       if (has_ellipsis) {
189         const int num_named_labels =
190             static_cast<int>(input_labels[i].size()) - 3;
191         TF_RETURN_WITH_CONTEXT_IF_ERROR(
192             c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
193             " for ", i, "th input and equation: ", equation);
194       } else {
195         const int num_named_labels = static_cast<int>(input_labels[i].size());
196         TF_RETURN_WITH_CONTEXT_IF_ERROR(
197             c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
198             i, "th input and equation: ", equation);
199       }
200     }
201 
202     bool seen_ellipsis = false;
203     input_bcast_shapes[i] = c->Scalar();
204     // Run through the input labels; populate label_to_dimension mapping and
205     // compute the broadcast shapes corresponding to the ellipsis (if present).
206     for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
207          ++label_idx) {
208       const char label = input_labels[i][label_idx];
209       // Calculate the input axis that the current label is referring to. After
210       // the ellipsis, the axis may be found by using negative indices; i.e the
211       // (rank - k)th dimension corresponds to the (num_labels - k)th label.
212       const int64 axis_before_ellipsis = label_idx;
213       const int64 axis_after_ellipsis =
214           c->RankKnown(input_shape)
215               ? label_idx + c->Rank(input_shape) - input_labels[i].size()
216               : -1;
217 
218       // Populate the input broadcast shape when we encounter an ellipsis (...).
219       if (label == '.') {
220         if (!c->RankKnown(input_shape)) {
221           input_bcast_shapes[i] = c->UnknownShape();
222         } else {
223           // The broadcast shape runs till the named label right after the
224           // ellipsis, the label with index (label_idx + 3).
225           TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
226                                          axis_after_ellipsis + 3,
227                                          &input_bcast_shapes[i]));
228         }
229         label_idx += 2;  // Skip the rest of the ellipsis.
230         seen_ellipsis = true;
231         continue;
232       }
233       // Obtain the dimension that the current label corresponds to.
234       int64 axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
235       DimensionHandle new_dim = c->RankKnown(input_shape)
236                                     ? c->Dim(input_shape, axis)
237                                     : c->UnknownDim();
238       // If we've seen this label before, make sure previous and current
239       // dimensions are compatible.
240       if (label_to_dimension.contains(label)) {
241         DimensionHandle merged;
242         TF_RETURN_IF_ERROR(
243             c->Merge(label_to_dimension[label], new_dim, &merged));
244         label_to_dimension[label] = merged;
245       } else {
246         label_to_dimension[label] = new_dim;
247       }
248     }
249   }
250 
251   // For two inputs, broadcast the two input broadcast shapes to create the
252   // output broadcast shape. For one input, just copy the single broadcast
253   // shape.
254   ShapeHandle output_bcast_shape;
255   if (input_bcast_shapes.size() == 1) {
256     output_bcast_shape = input_bcast_shapes[0];
257   } else if (input_bcast_shapes.size() == 2) {
258     TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
259         c, input_bcast_shapes[0], input_bcast_shapes[1], true,
260         &output_bcast_shape));
261   }
262 
263   bool output_has_ellipsis = false;
264   TF_RETURN_IF_ERROR(
265       ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
266   if (output_has_ellipsis) {
267     // If the output subscript has ellipsis and the output broadcast rank is
268     // unknown, then the output shape should have unknown rank.
269     if (!c->RankKnown(output_bcast_shape)) {
270       c->set_output(0, c->UnknownShape());
271       return Status::OK();
272     }
273   } else {
274     // If the output subscripts don't have ellipsis then make sure the output
275     // broadcasting shape is empty.
276     TF_RETURN_WITH_CONTEXT_IF_ERROR(
277         c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
278         " for einsum equation '", equation,
279         "' without ellipsis (...) in the output subscripts where input(s) have "
280         "non-empty broadcasting shape");
281     output_bcast_shape = c->Scalar();
282   }
283 
284   // Create the output shape from output labels and label_to_dimension mapping.
285   std::vector<DimensionHandle> output_dims;
286   for (int label_idx = 0, end = output_labels.size(); label_idx < end;
287        ++label_idx) {
288     const char label = output_labels[label_idx];
289     // Append the output_bcast_shape when the ellipsis is encountered.
290     if (label == '.') {
291       for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
292         output_dims.push_back(c->Dim(output_bcast_shape, k));
293       }
294       label_idx += 2;  // Skip the rest of the ellipsis.
295       continue;
296     }
297     auto dimension_it = label_to_dimension.find(label);
298     if (dimension_it == label_to_dimension.end()) {
299       return errors::InvalidArgument(
300           "Einsum output subscripts for equation '", equation, "' has label '",
301           label, "' which is not present in the input subscripts");
302     }
303     output_dims.push_back(dimension_it->second);
304   }
305   c->set_output(0, c->MakeShape(output_dims));
306   return Status::OK();
307 }
308 
BatchMatMulV2Shape(shape_inference::InferenceContext * c)309 Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
310   ShapeHandle a_shape;
311   ShapeHandle b_shape;
312   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
313   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
314 
315   // Determine output rows and columns.
316   bool adj_x;
317   bool adj_y;
318   TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
319   TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
320   DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
321   DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
322 
323   // Inner dimensions should be compatible.
324   DimensionHandle inner_merged;
325   TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
326                               c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));
327 
328   // Batch dimensions should broadcast with each other.
329   ShapeHandle a_batch_shape;
330   ShapeHandle b_batch_shape;
331   ShapeHandle output_batch_shape;
332   TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
333   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
334 
335   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
336       c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
337 
338   ShapeHandle output_shape;
339   TF_RETURN_IF_ERROR(c->Concatenate(
340       output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));
341 
342   c->set_output(0, output_shape);
343   return Status::OK();
344 }
345 
BatchMatMulShape(shape_inference::InferenceContext * c)346 Status BatchMatMulShape(shape_inference::InferenceContext* c) {
347   ShapeHandle a_shape;
348   ShapeHandle b_shape;
349   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
350   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
351 
352   // Determine output rows and cols.
353   bool adj_x;
354   bool adj_y;
355   TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
356   TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
357   DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
358   DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
359 
360   // Batch dims match between inputs.
361   ShapeHandle a_batch_dims;
362   ShapeHandle b_batch_dims;
363   ShapeHandle batch_dims;
364   TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
365   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
366   TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
367 
368   // Assert inner dims match.
369   DimensionHandle unused;
370   TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
371                               c->Dim(b_shape, adj_y ? -1 : -2), &unused));
372 
373   ShapeHandle out;
374   TF_RETURN_IF_ERROR(
375       c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out));
376   c->set_output(0, out);
377   return Status::OK();
378 }
379 
380 // --------------------------------------------------------------------------
381 
BiasAddShape(shape_inference::InferenceContext * c)382 Status BiasAddShape(shape_inference::InferenceContext* c) {
383   ShapeHandle input_shape;
384 
385   // Fetch the data_format attribute, which may not exist.
386   string data_format;
387   Status s = c->GetAttr("data_format", &data_format);
388 
389   if (s.ok() && data_format == "NCHW") {
390     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
391   } else {
392     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
393   }
394 
395   ShapeHandle bias_shape;
396   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
397   DimensionHandle bias_dim = c->Dim(bias_shape, 0);
398 
399   // If rank unknown, return unknown shape.
400   if (!c->RankKnown(input_shape)) {
401     c->set_output(0, c->UnknownShape());
402     return Status::OK();
403   }
404 
405   // Output has the same shape as the input, and matches the length of
406   // the bias in its bias dimension.
407   ShapeHandle output_shape;
408   if (s.ok() && data_format == "NCHW") {
409     // Merge the length of bias_shape into the third to last dimension
410     ShapeHandle first;
411     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));
412 
413     ShapeHandle last;
414     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));
415 
416     DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
417     DimensionHandle merged_bias_dim;
418     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
419     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
420 
421     ShapeHandle temp;
422     TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
423     TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
424   } else {
425     ShapeHandle all_but_bias;
426     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
427 
428     DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
429     DimensionHandle merged_bias_dim;
430     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
431 
432     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
433     TF_RETURN_IF_ERROR(
434         c->Concatenate(all_but_bias, merged_bias, &output_shape));
435   }
436 
437   c->set_output(0, output_shape);
438   return Status::OK();
439 }
440 
BiasAddGradShape(shape_inference::InferenceContext * c)441 Status BiasAddGradShape(shape_inference::InferenceContext* c) {
442   ShapeHandle input_shape;
443   // Fetch the data_format attribute, which may not exist.
444   string data_format;
445   Status s = c->GetAttr("data_format", &data_format);
446 
447   if (s.ok() && data_format == "NCHW") {
448     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
449     c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
450   } else {
451     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
452     c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
453   }
454 
455   return Status::OK();
456 }
457 
CheckFormatConstraintsOnShape(const TensorFormat tensor_format,const ShapeHandle shape_handle,const string & tensor_name,shape_inference::InferenceContext * c)458 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
459                                      const ShapeHandle shape_handle,
460                                      const string& tensor_name,
461                                      shape_inference::InferenceContext* c) {
462   if (tensor_format == FORMAT_NCHW_VECT_C) {
463     // Check that the vect dim has size 4.
464     const int num_dims = c->Rank(shape_handle);
465     DimensionHandle vect_dim = c->Dim(
466         shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
467     DimensionHandle unused_vect_dim;
468     TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
469   }
470 
471   return Status::OK();
472 }
473 
DatasetIteratorShape(shape_inference::InferenceContext * c)474 Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
475   shape_inference::ShapeHandle unused;
476   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
477   std::vector<PartialTensorShape> output_shapes;
478   TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
479   const int output_shapes_size = output_shapes.size();
480   if (output_shapes_size != c->num_outputs()) {
481     return errors::InvalidArgument(
482         "`output_shapes` must be the same length as `output_types` (",
483         output_shapes.size(), " vs. ", c->num_outputs());
484   }
485   for (size_t i = 0; i < output_shapes.size(); ++i) {
486     shape_inference::ShapeHandle output_shape_handle;
487     TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
488         output_shapes[i], &output_shape_handle));
489     c->set_output(static_cast<int>(i), output_shape_handle);
490   }
491   return Status::OK();
492 }
493 
MakeShapeFromFormat(TensorFormat format,DimensionOrConstant N,const std::vector<DimensionOrConstant> & spatial,DimensionOrConstant C,ShapeHandle * out,shape_inference::InferenceContext * context)494 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
495                            const std::vector<DimensionOrConstant>& spatial,
496                            DimensionOrConstant C, ShapeHandle* out,
497                            shape_inference::InferenceContext* context) {
498   const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
499   std::vector<DimensionHandle> dims_actual(num_dims);
500   dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
501   int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
502   dims_actual[outer_c_index] = context->MakeDim(C);
503   if (format == FORMAT_NCHW_VECT_C) {
504     dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
505         context->MakeDim(4);
506   } else if (format == FORMAT_NHWC_VECT_W) {
507     dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
508         context->MakeDim(4);
509   }
510   for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end;
511        spatial_dim++) {
512     dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
513         context->MakeDim(spatial[spatial_dim]);
514   }
515   *out = context->MakeShape(dims_actual);
516   return Status::OK();
517 }
518 
DimensionsFromShape(ShapeHandle shape,TensorFormat format,DimensionHandle * batch_dim,gtl::MutableArraySlice<DimensionHandle> spatial_dims,DimensionHandle * filter_dim,InferenceContext * context)519 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
520                            DimensionHandle* batch_dim,
521                            gtl::MutableArraySlice<DimensionHandle> spatial_dims,
522                            DimensionHandle* filter_dim,
523                            InferenceContext* context) {
524   const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
525   // Batch.
526   *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
527   // Spatial.
528   for (int spatial_dim_index = 0, end = spatial_dims.size();
529        spatial_dim_index < end; ++spatial_dim_index) {
530     spatial_dims[spatial_dim_index] = context->Dim(
531         shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
532   }
533   // Channel.
534   *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
535   if (format == FORMAT_NCHW_VECT_C) {
536     TF_RETURN_IF_ERROR(context->Multiply(
537         *filter_dim,
538         context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
539         filter_dim));
540   }
541   return Status::OK();
542 }
543 
ShapeFromDimensions(DimensionHandle batch_dim,gtl::ArraySlice<DimensionHandle> spatial_dims,DimensionHandle filter_dim,TensorFormat format,InferenceContext * context,ShapeHandle * shape)544 Status ShapeFromDimensions(DimensionHandle batch_dim,
545                            gtl::ArraySlice<DimensionHandle> spatial_dims,
546                            DimensionHandle filter_dim, TensorFormat format,
547                            InferenceContext* context, ShapeHandle* shape) {
548   const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
549   std::vector<DimensionHandle> out_dims(rank);
550 
551   // Batch.
552   out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
553   // Spatial.
554   for (int spatial_dim_index = 0, end = spatial_dims.size();
555        spatial_dim_index < end; ++spatial_dim_index) {
556     out_dims[tensorflow::GetTensorSpatialDimIndex(
557         rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
558   }
559   // Channel.
560   if (format == tensorflow::FORMAT_NCHW_VECT_C) {
561     // When format is NCHW_VECT_C, factor the feature map count
562     // into the outer feature count and the inner feature count (=4).
563     TF_RETURN_IF_ERROR(context->Divide(
564         filter_dim, 4, /*evenly_divisible=*/true,
565         &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
566     out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
567   } else {
568     out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
569   }
570 
571   *shape = context->MakeShape(out_dims);
572   return tensorflow::Status::OK();
573 }
574 
575 namespace {
576 
Conv2DShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)577 Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
578                        bool supports_explicit_padding) {
579   string data_format_str, filter_format_str;
580   if (!c->GetAttr("data_format", &data_format_str).ok()) {
581     data_format_str = "NHWC";
582   }
583   if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
584     filter_format_str = "HWIO";
585   }
586 
587   TensorFormat data_format;
588   if (!FormatFromString(data_format_str, &data_format)) {
589     return errors::InvalidArgument("Invalid data format string: ",
590                                    data_format_str);
591   }
592   FilterTensorFormat filter_format;
593   if (!FilterFormatFromString(filter_format_str, &filter_format)) {
594     return errors::InvalidArgument("Invalid filter format string: ",
595                                    filter_format_str);
596   }
597 
598   constexpr int num_spatial_dims = 2;
599   const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
600   ShapeHandle conv_input_shape;
601   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
602   TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
603       data_format, conv_input_shape, "conv_input", c));
604 
605   // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
606   ShapeHandle filter_shape;
607   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
608   TF_RETURN_IF_ERROR(
609       CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
610 
611   std::vector<int32> dilations;
612   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
613 
614   if (dilations.size() != 4) {
615     return errors::InvalidArgument(
616         "Conv2D requires the dilation attribute to contain 4 values, but got: ",
617         dilations.size());
618   }
619 
620   std::vector<int32> strides;
621   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
622 
623   // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
624   if (strides.size() != 4) {
625     return errors::InvalidArgument("Conv2D on data format ", data_format_str,
626                                    " requires the stride attribute to contain"
627                                    " 4 values, but got: ",
628                                    strides.size());
629   }
630 
631   const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
632   const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
633   const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
634   const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');
635 
636   DimensionHandle batch_size_dim;
637   DimensionHandle input_depth_dim;
638   gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
639   TF_RETURN_IF_ERROR(DimensionsFromShape(
640       conv_input_shape, data_format, &batch_size_dim,
641       absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
642 
643   DimensionHandle output_depth_dim = c->Dim(
644       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
645   DimensionHandle filter_rows_dim = c->Dim(
646       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
647   DimensionHandle filter_cols_dim = c->Dim(
648       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
649   DimensionHandle filter_input_depth_dim;
650   if (filter_format == FORMAT_OIHW_VECT_I) {
651     TF_RETURN_IF_ERROR(c->Multiply(
652         c->Dim(filter_shape,
653                GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
654         c->Dim(filter_shape,
655                GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
656         &filter_input_depth_dim));
657   } else {
658     filter_input_depth_dim = c->Dim(
659         filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
660   }
661 
662   // Check that the input tensor and the filter tensor agree on the channel
663   // count.
664   if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
665     int64 input_depth_value = c->Value(input_depth_dim),
666           filter_input_depth_value = c->Value(filter_input_depth_dim);
667     if (input_depth_value % filter_input_depth_value != 0)
668       return errors::InvalidArgument(
669           "Depth of input (", input_depth_value,
670           ") is not a multiple of input depth of filter (",
671           filter_input_depth_value, ")");
672     if (input_depth_value != filter_input_depth_value) {
673       int64 num_groups = input_depth_value / filter_input_depth_value;
674       if (c->ValueKnown(output_depth_dim)) {
675         int64 output_depth_value = c->Value(output_depth_dim);
676         if (output_depth_value % num_groups != 0)
677           return errors::InvalidArgument(
678               "Depth of output (", output_depth_value,
679               ") is not a multiple of the number of groups (", num_groups, ")");
680       }
681     }
682   }
683 
684   Padding padding;
685   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
686 
687   std::vector<int64> explicit_paddings;
688   if (supports_explicit_padding) {
689     Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
690     // Use the default value, which is an empty list, if the attribute is not
691     // found. Otherwise return the error to the caller.
692     if (!s.ok() && !errors::IsNotFound(s)) {
693       return s;
694     }
695     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
696                                          /*num_dims=*/4, data_format));
697   } else {
698     CHECK(padding != Padding::EXPLICIT);  // Crash ok.
699   }
700 
701   DimensionHandle output_rows, output_cols;
702   int64 pad_rows_before = -1, pad_rows_after = -1;
703   int64 pad_cols_before = -1, pad_cols_after = -1;
704   if (padding == Padding::EXPLICIT) {
705     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
706                              &pad_rows_before, &pad_rows_after);
707     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
708                              &pad_cols_before, &pad_cols_after);
709   }
710   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
711       c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
712       padding, pad_rows_before, pad_rows_after, &output_rows));
713   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
714       c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
715       padding, pad_cols_before, pad_cols_after, &output_cols));
716 
717   ShapeHandle output_shape;
718   TF_RETURN_IF_ERROR(
719       ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
720                           output_depth_dim, data_format, c, &output_shape));
721   c->set_output(0, output_shape);
722   return Status::OK();
723 }
724 
725 }  // namespace
726 
727 // Shape function for Conv2D-like operations that support explicit padding.
Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext * c)728 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
729   return Conv2DShapeImpl(c, true);
730 }
731 
732 // Shape function for Conv2D-like operations that do not support explicit
733 // padding.
Conv2DShape(shape_inference::InferenceContext * c)734 Status Conv2DShape(shape_inference::InferenceContext* c) {
735   return Conv2DShapeImpl(c, false);
736 }
737 
738 // TODO(mjanusz): Unify all conv/pooling shape functions.
Conv3DShape(shape_inference::InferenceContext * c)739 Status Conv3DShape(shape_inference::InferenceContext* c) {
740   ShapeHandle input_shape;
741   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
742   ShapeHandle filter_shape;
743   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
744 
745   string data_format;
746   Status s = c->GetAttr("data_format", &data_format);
747 
748   std::vector<int32> dilations;
749   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
750 
751   if (dilations.size() != 5) {
752     return errors::InvalidArgument(
753         "Conv3D requires the dilation attribute to contain 5 values, but got: ",
754         dilations.size());
755   }
756 
757   std::vector<int32> strides;
758   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
759   if (strides.size() != 5) {
760     return errors::InvalidArgument(
761         "Conv3D requires the stride attribute to contain 5 values, but got: ",
762         strides.size());
763   }
764 
765   int32 stride_planes, stride_rows, stride_cols;
766   int32 dilation_planes, dilation_rows, dilation_cols;
767   if (s.ok() && data_format == "NCDHW") {
768     // Convert input_shape to NDHWC.
769     auto dim = [&](char dimension) {
770       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
771     };
772     input_shape =
773         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
774     stride_planes = strides[2];
775     stride_rows = strides[3];
776     stride_cols = strides[4];
777     dilation_planes = dilations[2];
778     dilation_cols = dilations[3];
779     dilation_rows = dilations[4];
780   } else {
781     stride_planes = strides[1];
782     stride_rows = strides[2];
783     stride_cols = strides[3];
784     dilation_planes = dilations[1];
785     dilation_cols = dilations[2];
786     dilation_rows = dilations[3];
787   }
788 
789   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
790   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
791   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
792   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
793   DimensionHandle input_depth_dim = c->Dim(input_shape, 4);
794 
795   DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
796   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
797   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
798   DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
799   DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
800 
801   // Check that the input tensor and the filter tensor agree on the channel
802   // count.
803   if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
804     int64 input_depth_value = c->Value(input_depth_dim),
805           filter_input_depth_value = c->Value(filter_input_depth_dim);
806     if (input_depth_value % filter_input_depth_value != 0)
807       return errors::InvalidArgument(
808           "Depth of input (", input_depth_value,
809           ") is not a multiple of input depth of filter (",
810           filter_input_depth_value, ")");
811     if (input_depth_value != filter_input_depth_value) {
812       int64 num_groups = input_depth_value / filter_input_depth_value;
813       if (c->ValueKnown(output_depth_dim)) {
814         int64 output_depth_value = c->Value(output_depth_dim);
815         if (output_depth_value % num_groups != 0)
816           return errors::InvalidArgument(
817               "Depth of output (", output_depth_value,
818               ") is not a multiple of the number of groups (", num_groups, ")");
819       }
820     }
821   }
822 
823   Padding padding;
824   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
825   DimensionHandle output_planes, output_rows, output_cols;
826 
827   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
828       c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
829       padding, -1, -1, &output_planes));
830   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
831       c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
832       -1, &output_rows));
833   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
834       c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
835       -1, &output_cols));
836 
837   ShapeHandle output_shape;
838   if (data_format == "NCDHW") {
839     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
840                                  output_planes, output_rows, output_cols});
841   } else {
842     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
843                                  output_cols, output_depth_dim});
844   }
845   c->set_output(0, output_shape);
846   return Status::OK();
847 }
848 
Conv2DBackpropInputShape(shape_inference::InferenceContext * c)849 Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
850   string data_format_str;
851   if (!c->GetAttr("data_format", &data_format_str).ok()) {
852     data_format_str = "NHWC";
853   }
854   TensorFormat data_format;
855   if (!FormatFromString(data_format_str, &data_format)) {
856     return errors::InvalidArgument("Invalid data format string: ",
857                                    data_format_str);
858   }
859 
860   // For the rest of this function, output_grad_* describes out_backprop and
861   // input_grad_* describes in_backprop.
862   ShapeHandle output_grad_shape = c->input(2);
863   TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
864   ShapeHandle filter_shape = c->input(1);
865   TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));
866 
867   DimensionHandle batch_size_dim;
868   DimensionHandle output_grad_depth_dim;
869   gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
870   TF_RETURN_IF_ERROR(DimensionsFromShape(
871       output_grad_shape, data_format, &batch_size_dim,
872       absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
873   DimensionHandle unused;
874   TF_RETURN_IF_ERROR(
875       c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));
876 
877   ShapeHandle specified_input_grad_shape;
878   TF_RETURN_IF_ERROR(
879       c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
880   if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
881     TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
882                                    &specified_input_grad_shape));
883   }
884 
885   // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number
886   // of groups is larger than 1. If input_sizes is a 4D shape, we collect
887   // input_grad_depth_dim from input_sizes; otherwise we compute it as
888   // c->Dim(filter_shape,2).
889   DimensionHandle input_grad_depth_dim;
890   gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
891   int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
892   if (specified_input_grad_rank == 4) {
893     DimensionHandle specified_batch_size_dim;
894     TF_RETURN_IF_ERROR(DimensionsFromShape(
895         specified_input_grad_shape, data_format, &specified_batch_size_dim,
896         absl::MakeSpan(specified_input_grad_spatial_dims),
897         &input_grad_depth_dim, c));
898     TF_RETURN_IF_ERROR(
899         c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
900   } else if (specified_input_grad_rank == 2) {
901     specified_input_grad_spatial_dims[0] =
902         c->Dim(specified_input_grad_shape, 0);
903     specified_input_grad_spatial_dims[1] =
904         c->Dim(specified_input_grad_shape, 1);
905     input_grad_depth_dim = c->Dim(filter_shape, 2);
906   } else {
907     return errors::InvalidArgument(
908         "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
909         "values, but got: ",
910         specified_input_grad_rank);
911   }
912 
913   ShapeHandle input_grad_shape;
914   TF_RETURN_IF_ERROR(ShapeFromDimensions(
915       batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
916       data_format, c, &input_grad_shape));
917   c->set_output(0, input_grad_shape);
918   return Status::OK();
919 }
920 
921 namespace {
922 
DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)923 Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
924                                       bool supports_explicit_padding) {
925   ShapeHandle input_shape;
926   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
927   ShapeHandle filter_shape;
928   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
929 
930   std::vector<int32> strides;
931   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
932 
933   if (strides.size() != 4) {
934     return errors::InvalidArgument(
935         "DepthwiseConv2D requires the stride attribute to contain 4 values, "
936         "but got: ",
937         strides.size());
938   }
939 
940   std::vector<int32> dilations;
941   if (!c->GetAttr("dilations", &dilations).ok()) {
942     dilations.resize(4, 1);
943   }
944 
945   if (dilations.size() != 4) {
946     return errors::InvalidArgument(
947         "DepthwiseConv2D requires the dilations attribute to contain 4 values, "
948         "but got: ",
949         dilations.size());
950   }
951 
952   string data_format_str;
953   Status s = c->GetAttr("data_format", &data_format_str);
954   TensorFormat data_format;
955   if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
956     data_format = FORMAT_NHWC;
957   }
958   int32 stride_rows;
959   int32 stride_cols;
960   int32 dilation_rows;
961   int32 dilation_cols;
962   if (data_format == FORMAT_NCHW) {
963     // Canonicalize input shape to NHWC so the shape inference code below can
964     // process it.
965     input_shape =
966         c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
967                        c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
968     stride_rows = strides[2];
969     stride_cols = strides[3];
970     dilation_rows = dilations[2];
971     dilation_cols = dilations[3];
972   } else {
973     stride_rows = strides[1];
974     stride_cols = strides[2];
975     dilation_rows = dilations[1];
976     dilation_cols = dilations[2];
977   }
978 
979   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
980   DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
981   DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
982 
983   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
984   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
985   DimensionHandle input_depth = c->Dim(filter_shape, 2);
986   DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
987 
988   // Check that the input depths are compatible.
989   TF_RETURN_IF_ERROR(
990       c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
991 
992   DimensionHandle output_depth;
993   TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
994 
995   Padding padding;
996   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
997 
998   std::vector<int64> explicit_paddings;
999   if (supports_explicit_padding) {
1000     Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
1001     // Use the default value, which is an empty list, if the attribute is not
1002     // found. Otherwise return the error to the caller.
1003     if (!status.ok() && !errors::IsNotFound(status)) {
1004       return status;
1005     }
1006     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
1007                                          /*num_dims=*/4, data_format));
1008   } else {
1009     DCHECK(padding != Padding::EXPLICIT);
1010   }
1011 
1012   // TODO(mrry,shlens): Raise an error if the stride would cause
1013   // information in the input to be ignored. This will require a change
1014   // in the kernel implementation.
1015   DimensionHandle output_rows, output_cols;
1016   int64 pad_rows_before = -1, pad_rows_after = -1;
1017   int64 pad_cols_before = -1, pad_cols_after = -1;
1018   if (padding == Padding::EXPLICIT) {
1019     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
1020                              &pad_rows_before, &pad_rows_after);
1021     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
1022                              &pad_cols_before, &pad_cols_after);
1023   }
1024   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1025       c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
1026       pad_rows_before, pad_rows_after, &output_rows));
1027   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1028       c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
1029       pad_cols_before, pad_cols_after, &output_cols));
1030 
1031   ShapeHandle output_shape;
1032   if (data_format == FORMAT_NCHW) {
1033     output_shape =
1034         c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
1035   } else {
1036     output_shape =
1037         c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
1038   }
1039   c->set_output(0, output_shape);
1040   return Status::OK();
1041 }
1042 
1043 };  // namespace
1044 
DepthwiseConv2DNativeShape(shape_inference::InferenceContext * c)1045 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
1046   return DepthwiseConv2DNativeShapeImpl(c, false);
1047 }
1048 
DepthwiseConv2DNativeShapeWithExplicitPadding(shape_inference::InferenceContext * c)1049 Status DepthwiseConv2DNativeShapeWithExplicitPadding(
1050     shape_inference::InferenceContext* c) {
1051   return DepthwiseConv2DNativeShapeImpl(c, true);
1052 }
1053 
AvgPoolShape(shape_inference::InferenceContext * c)1054 Status AvgPoolShape(shape_inference::InferenceContext* c) {
1055   string data_format_str;
1056   TensorFormat data_format;
1057   Status s = c->GetAttr("data_format", &data_format_str);
1058   if (s.ok()) {
1059     FormatFromString(data_format_str, &data_format);
1060   } else {
1061     data_format = FORMAT_NHWC;
1062   }
1063 
1064   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1065   ShapeHandle input_shape;
1066   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1067 
1068   TF_RETURN_IF_ERROR(
1069       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1070 
1071   std::vector<int32> strides;
1072   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1073   if (strides.size() != 4) {
1074     return errors::InvalidArgument(
1075         "AvgPool requires the stride attribute to contain 4 values, but got: ",
1076         strides.size());
1077   }
1078 
1079   std::vector<int32> kernel_sizes;
1080   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1081   if (kernel_sizes.size() != 4) {
1082     return errors::InvalidArgument(
1083         "AvgPool requires the ksize attribute to contain 4 values, but got: ",
1084         kernel_sizes.size());
1085   }
1086 
1087   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
1088   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
1089   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1090   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1091 
1092   constexpr int num_spatial_dims = 2;
1093   DimensionHandle batch_size_dim = c->Dim(
1094       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1095   DimensionHandle in_rows_dim = c->Dim(
1096       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1097   DimensionHandle in_cols_dim = c->Dim(
1098       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1099   DimensionHandle depth_dim = c->Dim(
1100       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1101 
1102   Padding padding;
1103   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1104 
1105   // TODO(mrry,shlens): Raise an error if the stride would cause
1106   // information in the input to be ignored. This will require a change
1107   // in the kernel implementation.
1108 
1109   DimensionHandle output_rows, output_cols;
1110   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1111       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1112   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1113       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1114 
1115   ShapeHandle output_shape;
1116   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1117                                          {output_rows, output_cols}, depth_dim,
1118                                          &output_shape, c));
1119   c->set_output(0, output_shape);
1120   return Status::OK();
1121 }
1122 
AvgPoolGradShape(shape_inference::InferenceContext * c)1123 Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
1124   ShapeHandle s;
1125   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1126   TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1127   c->set_output(0, s);
1128   return Status::OK();
1129 }
1130 
FusedBatchNormShape(shape_inference::InferenceContext * c)1131 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
1132   string data_format_str;
1133   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1134   TensorFormat data_format;
1135   if (!FormatFromString(data_format_str, &data_format)) {
1136     return errors::InvalidArgument("Invalid data format string: ",
1137                                    data_format_str);
1138   }
1139   const int rank =
1140       (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1141   ShapeHandle x;
1142   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
1143 
1144   bool is_training;
1145   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
1146   float exponential_avg_factor;
1147   if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
1148     exponential_avg_factor = 1.0f;  // default value
1149   }
1150   int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
1151 
1152   int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1153   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
1154 
1155   // covers scale, offset, and if is_training is false, mean, variance
1156   for (int i = 1; i < number_inputs; ++i) {
1157     ShapeHandle vec;
1158     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
1159     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
1160   }
1161 
1162   ShapeHandle y;
1163   TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
1164   c->set_output(0, y);
1165   ShapeHandle vector_shape = c->Vector(channel_dim);
1166   c->set_output(1, vector_shape);
1167   c->set_output(2, vector_shape);
1168   c->set_output(3, vector_shape);
1169   c->set_output(4, vector_shape);
1170   return Status::OK();
1171 }
1172 
FusedBatchNormV3Shape(shape_inference::InferenceContext * c)1173 Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) {
1174   TF_RETURN_IF_ERROR(FusedBatchNormShape(c));
1175   c->set_output(5, c->UnknownShape());
1176   return Status::OK();
1177 }
1178 
FusedBatchNormExShape(shape_inference::InferenceContext * c)1179 Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
1180   TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c));
1181 
1182   string data_format_str;
1183   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1184   TensorFormat data_format;
1185   if (!FormatFromString(data_format_str, &data_format)) {
1186     return errors::InvalidArgument("Invalid data format string: ",
1187                                    data_format_str);
1188   }
1189   ShapeHandle x;
1190   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
1191 
1192   int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
1193   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
1194 
1195   // This is a cuDNN implementation constraint.
1196   if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) {
1197     return errors::InvalidArgument(
1198         "_FusedBatchNormEx channel dimension must be divisible by 4.");
1199   }
1200 
1201   return Status::OK();
1202 }
1203 
FusedBatchNormGradShape(shape_inference::InferenceContext * c)1204 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
1205   string data_format_str;
1206   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1207   TensorFormat data_format;
1208   if (!FormatFromString(data_format_str, &data_format)) {
1209     return errors::InvalidArgument("Invalid data format string: ",
1210                                    data_format_str);
1211   }
1212   const int rank =
1213       (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1214   ShapeHandle y_backprop;
1215   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
1216   ShapeHandle x;
1217   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
1218 
1219   bool is_training;
1220   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
1221 
1222   int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1223   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
1224   TF_RETURN_IF_ERROR(
1225       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
1226 
1227   // covers scale, mean (reserve_space_1), variance (reserve_space_2)
1228   for (int i = 2; i < 5; ++i) {
1229     ShapeHandle vec;
1230     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
1231     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
1232   }
1233 
1234   ShapeHandle x_backprop;
1235   TF_RETURN_IF_ERROR(
1236       c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
1237   c->set_output(0, x_backprop);
1238   c->set_output(1, c->Vector(channel_dim));
1239   c->set_output(2, c->Vector(channel_dim));
1240   c->set_output(3, c->Vector(0));
1241   c->set_output(4, c->Vector(0));
1242   return Status::OK();
1243 }
1244 
ReadDiagIndex(InferenceContext * c,const Tensor * diag_index_tensor,int32 * lower_diag_index,int32 * upper_diag_index)1245 Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
1246                      int32* lower_diag_index, int32* upper_diag_index) {
1247   // This function assumes that the shape of diag_index_tensor is fully defined.
1248   if (diag_index_tensor->dims() == 0) {
1249     *lower_diag_index = diag_index_tensor->scalar<int32>()();
1250     *upper_diag_index = *lower_diag_index;
1251   } else {
1252     int32 num_elements = diag_index_tensor->dim_size(0);
1253     if (num_elements == 1) {
1254       *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1255       *upper_diag_index = *lower_diag_index;
1256     } else if (num_elements == 2) {
1257       *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1258       *upper_diag_index = diag_index_tensor->vec<int32>()(1);
1259     } else {
1260       return errors::InvalidArgument(
1261           "diag_index must be a vector with one or two elements. It has ",
1262           num_elements, " elements.");
1263     }
1264   }
1265   return Status::OK();
1266 }
1267 
MatrixDiagPartV2Shape(shape_inference::InferenceContext * c)1268 Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
1269   ShapeHandle input_shape, diag_index_shape, unused_shape;
1270   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
1271   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1272   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1273 
1274   const Tensor* diag_index_tensor = c->input_tensor(1);
1275   if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1276       diag_index_tensor == nullptr) {
1277     c->set_output(0, c->UnknownShape());
1278     return Status::OK();
1279   }
1280   int32 lower_diag_index = 0;
1281   int32 upper_diag_index = 0;
1282   TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1283                                    &upper_diag_index));
1284   if (lower_diag_index > upper_diag_index) {
1285     return errors::InvalidArgument(
1286         "lower_diag_index is greater than upper_diag_index");
1287   }
1288 
1289   // Validates lower_diag_index and upper_diag_index.
1290   const int32 input_rank = c->Rank(input_shape);
1291   const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
1292   const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
1293   int32 max_diag_len = InferenceContext::kUnknownDim;
1294   if (num_rows != InferenceContext::kUnknownDim &&
1295       num_cols != InferenceContext::kUnknownDim) {
1296     if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1297         (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
1298       return errors::InvalidArgument("lower_diag_index is out of bound.");
1299     }
1300     if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1301         (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
1302       return errors::InvalidArgument("upper_diag_index is out of bound.");
1303     }
1304     max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
1305                             num_cols - std::max(lower_diag_index, 0));
1306   }
1307 
1308   std::vector<DimensionHandle> dims;
1309   dims.reserve(input_rank - 2);
1310   for (int i = 0; i < input_rank - 2; ++i) {
1311     dims.push_back(c->Dim(input_shape, i));
1312   }
1313   if (lower_diag_index < upper_diag_index) {
1314     dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
1315   }
1316   dims.push_back(c->MakeDim(max_diag_len));
1317   c->set_output(0, c->MakeShape(dims));
1318   return Status::OK();
1319 }
1320 
MatrixDiagV2Shape(shape_inference::InferenceContext * c)1321 Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
1322   // Checks input ranks.
1323   ShapeHandle input_shape, diag_index_shape, unused_shape;
1324   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1325   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1326   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1327   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
1328   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
1329 
1330   // Reads the diagonal indices.
1331   const Tensor* diag_index_tensor = c->input_tensor(1);
1332   if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1333       diag_index_tensor == nullptr) {
1334     c->set_output(0, c->UnknownShape());
1335     return Status::OK();
1336   }
1337   int32 lower_diag_index = 0;
1338   int32 upper_diag_index = 0;
1339   TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1340                                    &upper_diag_index));
1341   if (lower_diag_index > upper_diag_index) {
1342     return errors::InvalidArgument(
1343         "lower_diag_index is greater than upper_diag_index");
1344   }
1345 
1346   // Checks if the number of diagonals provided matches what we imply from
1347   // lower_diag_index and upper_diag_index.
1348   const int32 input_rank = c->Rank(input_shape);
1349   if (lower_diag_index < upper_diag_index) {
1350     const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
1351     const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
1352 
1353     if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
1354       return errors::InvalidArgument(
1355           "The number of rows of `diagonal` doesn't match the number of "
1356           "diagonals implied from `d_lower` and `d_upper`.\n",
1357           "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
1358           ", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim);
1359     }
1360   }
1361 
1362   // Reads num_rows and num_cols.
1363   const Tensor* num_rows_tensor = c->input_tensor(2);
1364   const Tensor* num_cols_tensor = c->input_tensor(3);
1365   int64 num_rows = -1;
1366   int64 num_cols = -1;
1367   if (num_rows_tensor != nullptr) {
1368     TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
1369   }
1370   if (num_cols_tensor != nullptr) {
1371     TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
1372   }
1373 
1374   // Infers the missing num_rows or num_cols: If both are missing, assume
1375   // output is square. Otherwise, use the smallest possible value. Also
1376   // validates the provided values.
1377   const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
1378   const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
1379   const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
1380   if (num_rows == -1 && num_cols == -1) {  // Special case.
1381     num_rows = std::max(min_num_rows, min_num_cols);
1382     num_cols = num_rows;
1383   }
1384   if (num_rows == -1) {
1385     num_rows = min_num_rows;
1386   } else if (num_rows < min_num_rows) {
1387     return errors::InvalidArgument("num_rows is too small");
1388   }
1389   if (num_cols == -1) {
1390     num_cols = min_num_cols;
1391   } else if (num_cols < min_num_cols) {
1392     return errors::InvalidArgument("num_cols is too small.");
1393   }
1394   // At least one of them must match the minimum length.
1395   if (num_rows != min_num_rows && num_cols != min_num_cols) {
1396     return errors::InvalidArgument(
1397         "num_rows and num_cols are not consistent with lower_diag_index, "
1398         "upper_diag_index, and the length of the given diagonals.\n",
1399         "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
1400         ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
1401   }
1402 
1403   // Sets output shape.
1404   ShapeHandle output_shape;
1405   const DimensionHandle output_row_dim = c->MakeDim(num_rows);
1406   const DimensionHandle output_col_dim = c->MakeDim(num_cols);
1407   if (lower_diag_index == upper_diag_index) {
1408     TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
1409                                      output_row_dim, &output_shape));
1410     TF_RETURN_IF_ERROR(
1411         c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
1412   } else {
1413     TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
1414                                      output_row_dim, &output_shape));
1415     TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
1416                                      output_col_dim, &output_shape));
1417   }
1418   c->set_output(0, output_shape);
1419   return Status::OK();
1420 }
1421 
MatrixSetDiagV2Shape(shape_inference::InferenceContext * c)1422 Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
1423   ShapeHandle input_shape, diag_shape, diag_index_shape;
1424   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
1425   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
1426   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
1427 
1428   int32 lower_diag_index = 0;
1429   int32 upper_diag_index = 0;
1430   bool diag_index_known = false;
1431   const Tensor* diag_index_tensor = c->input_tensor(2);
1432   if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
1433     diag_index_known = true;
1434     TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1435                                      &upper_diag_index));
1436     if (lower_diag_index > upper_diag_index) {
1437       return errors::InvalidArgument(
1438           "lower_diag_index is greater than upper_diag_index");
1439     }
1440   }
1441 
1442   // Do more checks when input rank is known.
1443   if (c->RankKnown(input_shape)) {
1444     int32 input_rank = c->Rank(input_shape);
1445 
1446     // If diag_index is set, we know the exact rank of diagonal.
1447     if (diag_index_known) {
1448       TF_RETURN_IF_ERROR(c->WithRank(
1449           c->input(1),
1450           (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
1451           &diag_shape));
1452     } else {
1453       TF_RETURN_IF_ERROR(
1454           c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
1455       TF_RETURN_IF_ERROR(
1456           c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
1457     }
1458 
1459     // Validates lower_diag_index and upper_diag_index.
1460     const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
1461     const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
1462     if (num_rows != InferenceContext::kUnknownDim &&
1463         num_cols != InferenceContext::kUnknownDim) {
1464       if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1465           (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
1466         return errors::InvalidArgument("lower_diag_index is out of bound.");
1467       }
1468       if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1469           (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
1470         return errors::InvalidArgument("upper_diag_index is out of bound.");
1471       }
1472     }
1473   }
1474 
1475   ShapeHandle output_shape = input_shape;
1476   if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
1477     // Try to infer parts of shape from diag.
1478     ShapeHandle diag_prefix;
1479     TF_RETURN_IF_ERROR(c->Subshape(
1480         diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
1481         &diag_prefix));
1482 
1483     // The inner matrices can be rectangular, so we can't pinpoint their
1484     // exact height and width by just lower_diag_index, upper_diag_index,
1485     // and the longest length of given diagonals.
1486     TF_RETURN_IF_ERROR(
1487         c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
1488     TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
1489   }
1490   c->set_output(0, output_shape);
1491   return Status::OK();
1492 }
1493 
MaxPoolShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)1494 Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
1495                         bool supports_explicit_padding) {
1496   string data_format_str;
1497   TensorFormat data_format;
1498   Status s = c->GetAttr("data_format", &data_format_str);
1499   if (s.ok()) {
1500     FormatFromString(data_format_str, &data_format);
1501   } else {
1502     data_format = FORMAT_NHWC;
1503   }
1504 
1505   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1506   ShapeHandle input_shape;
1507   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1508 
1509   TF_RETURN_IF_ERROR(
1510       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1511 
1512   std::vector<int32> strides;
1513   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1514   if (strides.size() != 4) {
1515     return errors::InvalidArgument(
1516         "MaxPool requires the stride attribute to contain 4 values, but got: ",
1517         strides.size());
1518   }
1519 
1520   std::vector<int32> kernel_sizes;
1521   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1522   if (kernel_sizes.size() != 4) {
1523     return errors::InvalidArgument(
1524         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1525         kernel_sizes.size());
1526   }
1527 
1528   int32 stride_depth = GetTensorDim(strides, data_format, 'C');
1529   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
1530   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
1531   int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1532   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1533   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1534 
1535   constexpr int num_spatial_dims = 2;
1536   DimensionHandle batch_size_dim = c->Dim(
1537       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1538   DimensionHandle in_rows_dim = c->Dim(
1539       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1540   DimensionHandle in_cols_dim = c->Dim(
1541       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1542   DimensionHandle in_depth_dim = c->Dim(
1543       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1544 
1545   Padding padding;
1546   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1547 
1548   std::vector<int64> explicit_paddings;
1549   if (supports_explicit_padding) {
1550     Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
1551     // Use the default value, which is an empty list, if the attribute is not
1552     // found. Otherwise return the error to the caller.
1553     if (!status.ok() && !errors::IsNotFound(status)) {
1554       return status;
1555     }
1556     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
1557                                          /*num_dims=*/4, data_format));
1558   } else {
1559     DCHECK(padding != Padding::EXPLICIT);
1560   }
1561 
1562   ShapeHandle output_shape;
1563   DimensionHandle output_rows, output_cols, output_depth;
1564   int64 pad_rows_before = -1, pad_rows_after = -1;
1565   int64 pad_cols_before = -1, pad_cols_after = -1;
1566   if (padding == Padding::EXPLICIT) {
1567     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
1568                              &pad_rows_before, &pad_rows_after);
1569     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
1570                              &pad_cols_before, &pad_cols_after);
1571   }
1572   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1573       c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
1574       pad_rows_before, pad_rows_after, &output_rows));
1575   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1576       c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
1577       pad_cols_before, pad_cols_after, &output_cols));
1578   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1579       c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
1580       /*pad_before*/ 0, /*pad_after*/ 0, &output_depth));
1581 
1582   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1583                                          {output_rows, output_cols},
1584                                          output_depth, &output_shape, c));
1585 
1586   c->set_output(0, output_shape);
1587   return Status::OK();
1588 }
1589 
MaxPoolShape(shape_inference::InferenceContext * c)1590 Status MaxPoolShape(shape_inference::InferenceContext* c) {
1591   return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
1592 }
1593 
MaxPoolGradShape(shape_inference::InferenceContext * c)1594 Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
1595   return UnchangedShapeWithRank(c, 4);
1596 }
1597 
MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext * c)1598 Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
1599   return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
1600 }
1601 
MaxPoolV2Shape(shape_inference::InferenceContext * c,int num_inputs)1602 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
1603   string data_format_str;
1604   TensorFormat data_format;
1605   Status s = c->GetAttr("data_format", &data_format_str);
1606   if (s.ok()) {
1607     FormatFromString(data_format_str, &data_format);
1608   } else {
1609     data_format = FORMAT_NHWC;
1610   }
1611 
1612   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1613   ShapeHandle input_shape;
1614   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1615 
1616   TF_RETURN_IF_ERROR(
1617       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1618 
1619   std::vector<int32> kernel_sizes;
1620   std::vector<int32> strides;
1621 
1622   if (c->num_inputs() + 2 == num_inputs) {
1623     TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1624 
1625     TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1626   } else {
1627     // Verify shape of ksize and strides input.
1628     ShapeHandle size;
1629     DimensionHandle unused;
1630     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
1631     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
1632     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
1633     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
1634 
1635     const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
1636     if (kernel_sizes_tensor == nullptr) {
1637       c->set_output(0, c->UnknownShape());
1638       return Status::OK();
1639     }
1640     kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
1641     auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
1642     std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
1643                 kernel_sizes.begin());
1644 
1645     const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
1646     if (strides_tensor == nullptr) {
1647       c->set_output(0, c->UnknownShape());
1648       return Status::OK();
1649     }
1650     strides.resize(strides_tensor->shape().num_elements());
1651     auto strides_vec = strides_tensor->flat<int32>();
1652     std::copy_n(&strides_vec(0), strides.size(), strides.begin());
1653   }
1654 
1655   if (strides.size() != 4) {
1656     return errors::InvalidArgument(
1657         "MaxPool requires the stride attribute to contain 4 values, but "
1658         "got: ",
1659         strides.size());
1660   }
1661   if (kernel_sizes.size() != 4) {
1662     return errors::InvalidArgument(
1663         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1664         kernel_sizes.size());
1665   }
1666 
1667   int32 stride_depth = GetTensorDim(strides, data_format, 'C');
1668   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
1669   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
1670   int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1671   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1672   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1673 
1674   constexpr int num_spatial_dims = 2;
1675   DimensionHandle batch_size_dim = c->Dim(
1676       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1677   DimensionHandle in_rows_dim = c->Dim(
1678       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1679   DimensionHandle in_cols_dim = c->Dim(
1680       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1681   DimensionHandle in_depth_dim = c->Dim(
1682       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1683 
1684   Padding padding;
1685   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1686 
1687   ShapeHandle output_shape;
1688   DimensionHandle output_rows, output_cols, output_depth;
1689   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1690       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1691   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1692       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1693   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1694       c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
1695 
1696   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1697                                          {output_rows, output_cols},
1698                                          output_depth, &output_shape, c));
1699 
1700   c->set_output(0, output_shape);
1701   return Status::OK();
1702 }
1703 
Pool3DShape(shape_inference::InferenceContext * c)1704 Status Pool3DShape(shape_inference::InferenceContext* c) {
1705   ShapeHandle input_shape;
1706   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
1707 
1708   string data_format;
1709   Status s = c->GetAttr("data_format", &data_format);
1710 
1711   std::vector<int32> strides;
1712   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1713   if (strides.size() != 5) {
1714     return errors::InvalidArgument(
1715         "Pool3D ops require the stride attribute to contain 5 values, but "
1716         "got: ",
1717         strides.size());
1718   }
1719 
1720   std::vector<int32> kernel_sizes;
1721   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1722   if (kernel_sizes.size() != 5) {
1723     return errors::InvalidArgument(
1724         "Pool3D requires the ksize attribute to contain 5 values, but got: ",
1725         kernel_sizes.size());
1726   }
1727 
1728   int32 stride_planes, stride_rows, stride_cols;
1729   int32 kernel_planes, kernel_rows, kernel_cols;
1730 
1731   if (s.ok() && data_format == "NCDHW") {
1732     // Convert input_shape to NDHWC.
1733     auto dim = [&](char dimension) {
1734       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
1735     };
1736     input_shape =
1737         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
1738     stride_planes = strides[2];
1739     stride_rows = strides[3];
1740     stride_cols = strides[4];
1741     kernel_planes = kernel_sizes[2];
1742     kernel_rows = kernel_sizes[3];
1743     kernel_cols = kernel_sizes[4];
1744   } else {
1745     stride_planes = strides[1];
1746     stride_rows = strides[2];
1747     stride_cols = strides[3];
1748     kernel_planes = kernel_sizes[1];
1749     kernel_rows = kernel_sizes[2];
1750     kernel_cols = kernel_sizes[3];
1751   }
1752 
1753   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1754   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
1755   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
1756   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
1757   DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
1758 
1759   Padding padding;
1760   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1761 
1762   // TODO(mrry,shlens): Raise an error if the stride would cause
1763   // information in the input to be ignored. This will require a change
1764   // in the kernel implementation.
1765   DimensionHandle output_planes, output_rows, output_cols;
1766   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1767       c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
1768   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1769       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1770   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1771       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1772 
1773   ShapeHandle output_shape;
1774   if (data_format == "NCDHW") {
1775     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
1776                                  output_planes, output_rows, output_cols});
1777   } else {
1778     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
1779                                  output_cols, output_depth_dim});
1780   }
1781 
1782   c->set_output(0, output_shape);
1783   return Status::OK();
1784 }
1785 
MaxPool3DGradShape(shape_inference::InferenceContext * c)1786 Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
1787   return UnchangedShapeWithRank(c, 5);
1788 }
1789 
AvgPool3DGradShape(shape_inference::InferenceContext * c)1790 Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
1791   ShapeHandle s;
1792   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1793   TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1794   c->set_output(0, s);
1795   return Status::OK();
1796 }
1797 
UnknownShape(shape_inference::InferenceContext * c)1798 Status UnknownShape(shape_inference::InferenceContext* c) {
1799   for (int i = 0; i < c->num_outputs(); ++i) {
1800     c->set_output(i, c->UnknownShape());
1801   }
1802   return Status::OK();
1803 }
1804 
1805 template <typename T>
ReductionShapeHelper(const Tensor * reduction_indices_t,const int32 input_rank,std::set<int64> * true_indices)1806 Status ReductionShapeHelper(const Tensor* reduction_indices_t,
1807                             const int32 input_rank,
1808                             std::set<int64>* true_indices) {
1809   auto reduction_indices = reduction_indices_t->flat<T>();
1810   for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
1811     const T reduction_index = reduction_indices(i);
1812     if (reduction_index < -input_rank || reduction_index >= input_rank) {
1813       return errors::InvalidArgument("Invalid reduction dimension ",
1814                                      reduction_index, " for input with ",
1815                                      input_rank, " dimensions.");
1816     }
1817 
1818     auto wrapped_index = reduction_index;
1819     if (wrapped_index < 0) {
1820       wrapped_index += input_rank;
1821     }
1822 
1823     true_indices->insert(wrapped_index);
1824   }
1825   return Status::OK();
1826 }
1827 
ReductionShape(InferenceContext * c)1828 Status ReductionShape(InferenceContext* c) {
1829   ShapeHandle input = c->input(0);
1830 
1831   ShapeHandle indices;
1832   // Older versions of TensorFlow accidentally allowed higher rank tensors like
1833   // [[1,2]] or [[1],[2]] to represent axis=[1,2].
1834   if (c->graph_def_version() < 21) {
1835     indices = c->input(1);
1836   } else {
1837     TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
1838   }
1839 
1840   bool keep_dims;
1841   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
1842 
1843   const Tensor* reduction_indices_t = c->input_tensor(1);
1844   if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
1845     // If we do not have the reduction values at runtime, or the
1846     // rank of the input, we don't know the output shape.
1847 
1848     if (keep_dims && c->RankKnown(input)) {
1849       // output rank matches input input if <keep_dims>.
1850       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
1851       return Status::OK();
1852     } else {
1853       return shape_inference::UnknownShape(c);
1854     }
1855   }
1856 
1857   const int32 input_rank = c->Rank(input);
1858   std::set<int64> true_indices;
1859   if (reduction_indices_t->dtype() == DataType::DT_INT32) {
1860     TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
1861                                                    input_rank, &true_indices));
1862   } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
1863     TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
1864                                                    input_rank, &true_indices));
1865   } else {
1866     return errors::InvalidArgument(
1867         "reduction_indices can only be int32 or int64");
1868   }
1869 
1870   std::vector<DimensionHandle> dims;
1871   for (int i = 0; i < input_rank; ++i) {
1872     if (true_indices.count(i) > 0) {
1873       if (keep_dims) {
1874         dims.emplace_back(c->MakeDim(1));
1875       }
1876     } else {
1877       dims.emplace_back(c->Dim(input, i));
1878     }
1879   }
1880 
1881   c->set_output(0, c->MakeShape(dims));
1882   return Status::OK();
1883 }
1884 
ConcatShapeHelper(InferenceContext * c,int start_value_index,int end_value_index,int dim_index)1885 Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
1886                          int end_value_index, int dim_index) {
1887   ShapeHandle unused;
1888   TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
1889   const Tensor* concat_dim_t = c->input_tensor(dim_index);
1890   if (concat_dim_t == nullptr) {
1891     // Return an unknown shape with same rank as inputs, or an unknown rank
1892     // if no input's rank is known.
1893 
1894     // Find rank.
1895     int32 rank = InferenceContext::kUnknownRank;
1896     for (int i = start_value_index; i < end_value_index; ++i) {
1897       if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
1898       if (rank != InferenceContext::kUnknownRank) {
1899         break;
1900       }
1901     }
1902     if (rank == InferenceContext::kUnknownRank) {
1903       c->set_output(0, c->UnknownShape());
1904       return Status::OK();
1905     } else if (rank == 0) {
1906       return errors::InvalidArgument(
1907           "Can't concatenate scalars (use tf.stack instead)");
1908     } else {
1909       for (int i = start_value_index; i < end_value_index; ++i) {
1910         // Check that all the inputs are of the correct rank.
1911         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
1912       }
1913     }
1914     // Build result of <rank> different unknown dims.
1915     std::vector<DimensionHandle> dims;
1916     dims.reserve(rank);
1917     for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
1918     c->set_output(0, c->MakeShape(dims));
1919     return Status::OK();
1920   }
1921 
1922   // Merge all the non-concat dims, and sum the concat dim to make an output
1923   // shape.
1924   int64 concat_dim;
1925   if (concat_dim_t->dtype() == DT_INT32) {
1926     concat_dim = static_cast<int64>(concat_dim_t->flat<int32>()(0));
1927   } else {
1928     concat_dim = concat_dim_t->flat<int64>()(0);
1929   }
1930 
1931   // Minimum required number of dimensions.
1932   const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
1933 
1934   ShapeHandle output_before;
1935   ShapeHandle output_after;
1936 
1937   ShapeHandle input = c->input(end_value_index - 1);
1938   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1939   TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
1940   DimensionHandle output_middle = c->Dim(input, concat_dim);
1941   if (concat_dim == -1) {
1942     output_after = c->Scalar();  // no dimensions.
1943   } else {
1944     TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
1945   }
1946 
1947   for (int i = end_value_index - 2; i >= start_value_index; --i) {
1948     ShapeHandle before;
1949     ShapeHandle after;
1950     input = c->input(i);
1951     TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1952     TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
1953     DimensionHandle middle = c->Dim(input, concat_dim);
1954     if (concat_dim == -1) {
1955       after = c->Scalar();
1956     } else {
1957       TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
1958     }
1959 
1960     TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
1961     TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
1962     TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
1963   }
1964 
1965   ShapeHandle s;
1966   TF_RETURN_IF_ERROR(
1967       c->Concatenate(output_before, c->Vector(output_middle), &s));
1968   TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
1969   c->set_output(0, s);
1970   return Status::OK();
1971 }
1972 
ConcatShape(InferenceContext * c,int num_inputs_to_concat)1973 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
1974   return ConcatShapeHelper(c, 1 /* start_value_index */,
1975                            1 + num_inputs_to_concat /* end_value_index */,
1976                            0 /* dim_index */);
1977 }
1978 
ConcatV2Shape(InferenceContext * c)1979 Status ConcatV2Shape(InferenceContext* c) {
1980   return ConcatShapeHelper(c, 0 /* start_value_index */,
1981                            c->num_inputs() - 1 /* end_value_index */,
1982                            c->num_inputs() - 1 /* dim_index */);
1983 }
1984 
QuantizedConcatV2Shape(InferenceContext * c,int num_inputs_to_concat)1985 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
1986   return ConcatShapeHelper(c, 0 /* start_value_index */,
1987                            num_inputs_to_concat /* end_value_index */,
1988                            num_inputs_to_concat /* dim_index */);
1989 }
1990 
BroadcastBinaryOpOutputShapeFnHelper(InferenceContext * c,ShapeHandle shape_x,ShapeHandle shape_y,bool incompatible_shape_error,ShapeHandle * out)1991 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
1992                                             ShapeHandle shape_x,
1993                                             ShapeHandle shape_y,
1994                                             bool incompatible_shape_error,
1995                                             ShapeHandle* out) {
1996   CHECK_NOTNULL(out);
1997   if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
1998     *out = c->UnknownShape();
1999     return Status::OK();
2000   }
2001   const int32 rank_x = c->Rank(shape_x);
2002   const int32 rank_y = c->Rank(shape_y);
2003   const int32 rank_out = std::max(rank_x, rank_y);
2004 
2005   // To compute the broadcast dimensions, we zip together shape_x and shape_y
2006   // and
2007   // pad with 1 to make them the same length.
2008   std::vector<DimensionHandle> dims;
2009   DimensionHandle dim_one;
2010   if (rank_x != rank_y) dim_one = c->MakeDim(1);
2011   for (int i = 0; i < rank_out; ++i) {
2012     const auto dim_x = i < (rank_out - rank_x)
2013                            ? dim_one
2014                            : c->Dim(shape_x, i - (rank_out - rank_x));
2015     const bool dim_y_is_one = (i < (rank_out - rank_y));
2016     const auto dim_y =
2017         dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
2018     if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
2019       // One or both dimensions is unknown.
2020       //
2021       // - If either dimension is greater than 1, we assume that the program is
2022       // correct, and the other dimension will be broadcast to match it.
2023       // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
2024       // in C++ op code, we must still assert that the unknown dim is either 1
2025       // or the same as the known dim.
2026       // - If either dimension is 1, the other dimension is the output.
2027       // - If both are unknown then dimension is unknown
2028       if (c->Value(dim_x) > 1) {
2029         if (!incompatible_shape_error) {
2030           *out = c->UnknownShape();
2031           return Status::OK();
2032         }
2033         dims.push_back(dim_x);
2034       } else if (c->Value(dim_y) > 1) {
2035         if (!incompatible_shape_error) {
2036           *out = c->UnknownShape();
2037           return Status::OK();
2038         }
2039         dims.push_back(dim_y);
2040       } else if (c->Value(dim_x) == 1) {
2041         dims.push_back(dim_y);
2042       } else if (c->Value(dim_y) == 1) {
2043         dims.push_back(dim_x);
2044       } else if (dim_y.SameHandle(dim_x)) {
2045         dims.push_back(dim_x);
2046       } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
2047         dims.push_back(c->UnknownDim());
2048       } else {
2049         if (!incompatible_shape_error) {
2050           *out = c->UnknownShape();
2051           return Status::OK();
2052         }
2053         dims.push_back(c->UnknownDim());
2054       }
2055     } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
2056       if (c->Value(dim_x) == 1 && !dim_y_is_one) {
2057         // We will broadcast dim_x to dim_y.
2058         dims.push_back(dim_y);
2059       } else {
2060         DCHECK_EQ(c->Value(dim_y), 1);
2061         // We will broadcast dim_y to dim_x.
2062         dims.push_back(dim_x);
2063       }
2064     } else {
2065       DimensionHandle dim;
2066       Status s = c->Merge(dim_x, dim_y, &dim);
2067       if (!s.ok()) {
2068         if (!incompatible_shape_error) {
2069           *out = c->MakeShape({});
2070           return Status::OK();
2071         }
2072         return s;
2073       }
2074       dims.push_back(dim);
2075     }
2076   }
2077 
2078   *out = c->MakeShape(dims);
2079   return Status::OK();
2080 }
2081 
RandomShape(shape_inference::InferenceContext * c)2082 Status RandomShape(shape_inference::InferenceContext* c) {
2083   shape_inference::ShapeHandle out;
2084   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
2085   c->set_output(0, out);
2086   return Status::OK();
2087 }
2088 
UnsortedSegmentReductionShapeFn(InferenceContext * c)2089 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
2090   ShapeHandle s_data = c->input(0);
2091   ShapeHandle s_segment_ids = c->input(1);
2092   ShapeHandle s_num_segments = c->input(2);
2093   TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
2094 
2095   ShapeHandle out;
2096 
2097   // Leading dimensions of data must be compatible with dimensions of
2098   // <s_segment_ids>.
2099   if (c->RankKnown(s_segment_ids)) {
2100     TF_RETURN_IF_ERROR(
2101         c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
2102 
2103     // Get the value of the num_segments input tensor.
2104     DimensionHandle num_segments_dim;
2105     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
2106 
2107     // Output is {segment_id_rank} + s_data[segment_id_rank:].
2108     ShapeHandle s_data_suffix;
2109     TF_RETURN_IF_ERROR(
2110         c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
2111     TF_RETURN_IF_ERROR(
2112         c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
2113   } else {
2114     out = c->UnknownShape();
2115   }
2116   c->set_output(0, out);
2117   return Status::OK();
2118 }
2119 
2120 namespace {
2121 
2122 // This SliceHelper processes the output shape of the `slice`
2123 // when the tensor of `sizes` is available.
2124 template <typename T>
SliceHelper(InferenceContext * c,ShapeHandle begin_value,const Tensor * sizes_value,std::vector<DimensionHandle> * dims)2125 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
2126                    const Tensor* sizes_value,
2127                    std::vector<DimensionHandle>* dims) {
2128   auto sizes_vec = sizes_value->vec<T>();
2129   for (int i = 0; i < sizes_value->NumElements(); ++i) {
2130     DimensionHandle dim = c->Dim(c->input(0), i);
2131     if (sizes_vec(i) != -1) {
2132       auto dim_val = c->Value(dim);
2133       if (sizes_vec(i) < 0) {
2134         return errors::InvalidArgument(
2135             "Out of bounds slicing on dimension ", i, " of length ", dim_val,
2136             ": sizes vector cannot be < -1, but was ", sizes_vec(i));
2137       }
2138 
2139       dims->emplace_back(c->MakeDim(sizes_vec(i)));
2140     } else {
2141       DimensionHandle result;
2142       TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
2143       dims->emplace_back(result);
2144     }
2145   }
2146 
2147   return Status::OK();
2148 }
2149 }  // namespace
2150 
SliceShape(InferenceContext * c)2151 Status SliceShape(InferenceContext* c) {
2152   ShapeHandle input = c->input(0);
2153   ShapeHandle begin_shape;
2154   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
2155   ShapeHandle sizes_shape;
2156   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
2157 
2158   // Merge to check compatibility of begin and sizes tensors.
2159   TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
2160 
2161   DimensionHandle ndims = c->Dim(begin_shape, 0);
2162   if (c->ValueKnown(ndims)) {
2163     TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
2164   }
2165 
2166   // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
2167   // values, even though the `begin` value does not represent a shape.
2168   ShapeHandle begin_value;
2169   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
2170 
2171   // We check the tensor value here and will only use
2172   // `MakeShapeFromShapeTensor` when `sizes_value` is null.
2173   // The reason is that `sizes` might contain -1, which can't
2174   // be represented (-1 in the ShapeHandle would mean "unknown").
2175   const Tensor* sizes_value = c->input_tensor(2);
2176 
2177   if (sizes_value != nullptr) {
2178     TF_RETURN_IF_ERROR(
2179         c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
2180     std::vector<DimensionHandle> dims;
2181     // If the begin and sizes tensors are available, then
2182     // we can be precise about the shape of the output.
2183     if (sizes_value->dtype() == DT_INT64) {
2184       TF_RETURN_IF_ERROR(
2185           SliceHelper<int64>(c, begin_value, sizes_value, &dims));
2186     } else {
2187       TF_RETURN_IF_ERROR(
2188           SliceHelper<int32>(c, begin_value, sizes_value, &dims));
2189     }
2190     c->set_output(0, c->MakeShape(dims));
2191     return Status::OK();
2192   } else {
2193     // In case `sizes` is not available (`sizes_value` is null),
2194     // we could try to use `MakeShapeFromShapeTensor` here.
2195     // If sizes contain -1, we will simply consider it as `Unknown`.
2196     // This is less than ideal but still an improvement of shape inference.
2197     // The following is an example that returns [None, 1, None] with this
2198     // code path:
2199     //   z = tf.zeros((1, 2, 3))
2200     //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
2201     //   m.get_shape().as_list()
2202     ShapeHandle sizes_value;
2203     TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
2204     if (c->RankKnown(sizes_value)) {
2205       TF_RETURN_IF_ERROR(
2206           c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
2207       std::vector<DimensionHandle> dims;
2208       dims.reserve(c->Rank(sizes_value));
2209       for (int i = 0; i < c->Rank(sizes_value); ++i) {
2210         dims.emplace_back(c->Dim(sizes_value, i));
2211       }
2212       c->set_output(0, c->MakeShape(dims));
2213       return Status::OK();
2214     }
2215     // We might know the rank of the input.
2216     if (c->RankKnown(input)) {
2217       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
2218       return Status::OK();
2219     } else {
2220       return shape_inference::UnknownShape(c);
2221     }
2222   }
2223 
2224   return Status::OK();
2225 }
2226 
ValidateSparseTensor(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle values_shape,ShapeHandle shape_shape)2227 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
2228                             ShapeHandle values_shape, ShapeHandle shape_shape) {
2229   // Validate ranks.
2230   ShapeHandle unused_shape;
2231   TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
2232   TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
2233   TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
2234 
2235   // Number of elements in indices and values must match.
2236   DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
2237   if (c->ValueKnown(num_index_elements_dim)) {
2238     DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
2239     if (c->ValueKnown(num_values_elements_dim)) {
2240       int64 num_index_elements = c->Value(num_index_elements_dim);
2241       int64 num_values_elements = c->Value(num_values_elements_dim);
2242       if (num_index_elements != num_values_elements) {
2243         return errors::InvalidArgument("Number of elements in index (",
2244                                        num_index_elements, ") and values (",
2245                                        num_values_elements, ") do not match.");
2246       }
2247     }
2248   }
2249 
2250   // Rank embedded in indices must match shape.
2251   DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
2252   if (c->ValueKnown(index_rank_dim)) {
2253     DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
2254     if (c->ValueKnown(shape_rank_dim)) {
2255       int64 index_rank = c->Value(index_rank_dim);
2256       int32 shape_rank = c->Value(shape_rank_dim);
2257       if (index_rank != shape_rank) {
2258         return errors::InvalidArgument("Index rank (", index_rank,
2259                                        ") and shape rank (", shape_rank,
2260                                        ") do not match.");
2261       }
2262     }
2263   }
2264 
2265   return Status::OK();
2266 }
2267 
ValidateVariableResourceHandle(InferenceContext * c,std::vector<ShapeAndType> * shape_and_type)2268 Status ValidateVariableResourceHandle(
2269     InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) {
2270   auto* handle_data = c->input_handle_shapes_and_types(0);
2271   if (handle_data == nullptr || handle_data->empty()) {
2272     shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID);
2273   } else {
2274     *shape_and_type = *handle_data;
2275     DataType value_dtype;
2276     TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
2277     if (shape_and_type->at(0).dtype != value_dtype) {
2278       return errors::InvalidArgument(
2279           "Trying to read variable with wrong dtype. "
2280           "Expected ",
2281           DataTypeString(shape_and_type->at(0).dtype), " got ",
2282           DataTypeString(value_dtype));
2283     }
2284   }
2285   return Status::OK();
2286 }
2287 
GatherNdShape(InferenceContext * c)2288 Status GatherNdShape(InferenceContext* c) {
2289   ShapeHandle params;
2290   std::vector<ShapeAndType> handle_shape_and_type;
2291   if (c->input_handle_shapes_and_types(0) != nullptr) {
2292     TF_RETURN_IF_ERROR(
2293         ValidateVariableResourceHandle(c, &handle_shape_and_type));
2294     params = handle_shape_and_type[0].shape;
2295   } else {
2296     params = c->input(0);
2297   }
2298   ShapeHandle indices;
2299   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
2300   DimensionHandle r_dim = c->Dim(indices, -1);
2301 
2302   if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
2303     c->set_output(0, c->UnknownShape());
2304     return Status::OK();
2305   }
2306 
2307   if (c->Value(r_dim) > c->Rank(params)) {
2308     return errors::InvalidArgument(
2309         "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
2310         c->DebugString(indices), " and params shape: ", c->DebugString(params));
2311   }
2312 
2313   // Remove r_dim from indices to get output.
2314   ShapeHandle indices_slice;
2315   ShapeHandle params_slice;
2316   TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
2317   TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
2318   ShapeHandle out;
2319   TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
2320   c->set_output(0, out);
2321   return Status::OK();
2322 }
2323 
ScatterNdShapeHelper(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle updates_shape,ShapeHandle input_shape)2324 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
2325                             ShapeHandle updates_shape,
2326                             ShapeHandle input_shape) {
2327   if (c->Value(c->NumElements(input_shape)) == 0 &&
2328       (c->Value(c->NumElements(indices_shape)) > 0 ||
2329        c->Value(c->NumElements(updates_shape)) > 0)) {
2330     return errors::InvalidArgument(
2331         "Indices and updates specified for empty input");
2332   }
2333 
2334   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
2335     const int64 outer_dims = c->Rank(indices_shape) - 1;
2336     const DimensionHandle ixdim = c->Dim(indices_shape, -1);
2337 
2338     // We can only do more validation if the last dimension of indices
2339     // is a known value.
2340     if (c->ValueKnown(ixdim)) {
2341       int64 ix = c->Value(ixdim);
2342       ShapeHandle unused;
2343       ShapeHandle prefix_indices;
2344       TF_RETURN_IF_ERROR(
2345           c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
2346       ShapeHandle prefix_updates;
2347       TF_RETURN_IF_ERROR(
2348           c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
2349 
2350       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
2351       if (!s.ok()) {
2352         return errors::InvalidArgument(
2353             "Dimensions [0,", outer_dims,
2354             ") of indices[shape=", c->DebugString(indices_shape),
2355             "] = ", c->DebugString(prefix_indices),
2356             " must match dimensions [0,", outer_dims,
2357             ") of updates[shape=", c->DebugString(updates_shape),
2358             "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
2359       }
2360 
2361       ShapeHandle suffix_output;
2362       TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
2363       ShapeHandle suffix_updates;
2364       TF_RETURN_IF_ERROR(
2365           c->Subshape(updates_shape, outer_dims, &suffix_updates));
2366       s = c->Merge(suffix_output, suffix_updates, &unused);
2367       if (!s.ok()) {
2368         return errors::InvalidArgument(
2369             "Dimensions [", ix, ",", c->Rank(input_shape),
2370             ") of input[shape=", c->DebugString(input_shape),
2371             "] = ", c->DebugString(suffix_output), " must match dimensions [",
2372             outer_dims, ",", c->Rank(updates_shape),
2373             ") of updates[shape=", c->DebugString(updates_shape),
2374             "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
2375       }
2376     }
2377   }
2378 
2379   if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
2380     // This is called for tf.scatter_nd; output is a tensor with this shape.
2381     c->set_output(0, input_shape);
2382   }
2383   return Status::OK();
2384 }
2385 
ExplicitShape(InferenceContext * c)2386 Status ExplicitShape(InferenceContext* c) {
2387   PartialTensorShape shape;
2388   TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2389   ShapeHandle output_shape;
2390   TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
2391   c->set_output(0, output_shape);
2392   return Status::OK();
2393 }
2394 
ExplicitShapes(InferenceContext * c)2395 Status ExplicitShapes(InferenceContext* c) {
2396   std::vector<PartialTensorShape> shapes;
2397   TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
2398   if (shapes.empty()) {
2399     return errors::Internal("shapes attribute is empty");
2400   }
2401   for (int i = 0, end = shapes.size(); i < end; ++i) {
2402     ShapeHandle output_shape;
2403     TF_RETURN_IF_ERROR(
2404         c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
2405     c->set_output(i, output_shape);
2406   }
2407   return Status::OK();
2408 }
2409 
SparseReduceShapeFn(InferenceContext * c)2410 Status SparseReduceShapeFn(InferenceContext* c) {
2411   // Input 0: input_indices
2412   // Input 1: input_values
2413   // Input 2: input_shape
2414   // Input 3: reduction_axes
2415   // Attr: keep_dims
2416   bool keep_dims = false;
2417   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
2418 
2419   const Tensor* shape_tensor = c->input_tensor(2);
2420   const Tensor* axes_tensor = c->input_tensor(3);
2421   if (shape_tensor != nullptr && axes_tensor != nullptr) {
2422     auto shape_vec = shape_tensor->flat<int64>();
2423     auto axes_vec = axes_tensor->flat<int32>();
2424 
2425     int64 ndims = shape_vec.size();
2426     absl::flat_hash_set<int64> axes;
2427     for (int i = 0; i < axes_vec.size(); i++) {
2428       axes.insert((axes_vec(i) + ndims) % ndims);
2429     }
2430 
2431     std::vector<DimensionHandle> dims;
2432     if (keep_dims) {
2433       dims.reserve(ndims);
2434       for (int d = 0; d < ndims; ++d) {
2435         if (axes.find(d) == axes.end()) {
2436           dims.push_back(c->MakeDim(shape_vec(d)));
2437         } else {
2438           dims.push_back(c->MakeDim(1));
2439         }
2440       }
2441     } else {
2442       for (int d = 0; d < ndims; ++d) {
2443         if (axes.find(d) == axes.end()) {
2444           dims.push_back(c->MakeDim(shape_vec(d)));
2445         }
2446       }
2447     }
2448 
2449     c->set_output(0, c->MakeShape(dims));
2450     return Status::OK();
2451   }
2452   return UnknownShape(c);
2453 }
2454 
2455 }  // namespace shape_inference
2456 
2457 }  // namespace tensorflow
2458