1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <ostream>
18 
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/kernel_shape_util.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/util/mirror_pad_mode.h"
28 #include "tensorflow/core/util/padding.h"
29 #include "tensorflow/core/util/strided_slice_op.h"
30 #include "tensorflow/core/util/tensor_format.h"
31 
32 namespace tensorflow {
33 
34 using shape_inference::DimensionHandle;
35 using shape_inference::InferenceContext;
36 using shape_inference::ShapeHandle;
37 using shape_inference::UnchangedShape;
38 
39 namespace {
40 
41 Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
42                                int32* axis) {
43   TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
44   if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
45     return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
46                                    -1 * rank_after_pack, ",", rank_after_pack,
47                                    ")");
48   }
49   if (*axis < 0) *axis = (rank_after_pack + *axis);
50   return Status::OK();
51 }
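// Illustrative example (hypothetical shapes): packing three rank-2 inputs
// gives rank_after_pack == 3, so a valid axis lies in [-3, 3); an axis of -1
// is canonicalized to 2 by the adjustment above.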
52 
53 template <typename T>
54 std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) {
55   std::vector<int64> ret(num_elements);
56   auto data = tensor->vec<T>();
57   for (int64 i = 0; i < num_elements; ++i) {
58     ret[i] = data(i);
59   }
60   return ret;
61 }
62 
63 template <typename T>
64 Status PadKnown(InferenceContext* c, ShapeHandle input,
65                 const Tensor* paddings_t, int64 num_dims) {
66   // paddings_t is known.
67   std::vector<DimensionHandle> dims(num_dims);
68   auto paddings_data = paddings_t->matrix<T>();
69   for (int64 i = 0; i < num_dims; ++i) {
70     const T pad0 = paddings_data(i, 0);
71     const T pad1 = paddings_data(i, 1);
72     if (pad0 < 0 || pad1 < 0) {
73       return errors::InvalidArgument("Paddings must be non-negative");
74     }
75     TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
76   }
77   c->set_output(0, c->MakeShape(dims));
78   return Status::OK();
79 }
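// Illustrative example (hypothetical shapes): an input of shape [2, 3] with
// paddings [[1, 1], [0, 2]] yields an output shape of [4, 5], since each
// output dimension is input_dim + pad0 + pad1.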
80 
81 Status PadShapeFn(InferenceContext* c) {
82   // Paddings is a matrix of [input_rank, 2].
83   ShapeHandle paddings;
84   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
85   DimensionHandle unused;
86   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
87 
88   // n_dim and input.rank are equivalent.
89   ShapeHandle input = c->input(0);
90   DimensionHandle n_dim = c->Dim(paddings, 0);
91   if (c->ValueKnown(n_dim)) {
92     TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
93   } else if (c->RankKnown(input)) {
94     TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
95   }
96 
97   const Tensor* paddings_t = c->input_tensor(1);
98 
99   // paddings_t is unknown
100   if (paddings_t == nullptr) {
101     if (c->ValueKnown(n_dim)) {
102       // Make output with n_dim unknown dims.
103       c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
104     } else {
105       c->set_output(0, c->UnknownShape());
106     }
107     return Status::OK();
108   }
109 
110   const int64 num_dims = paddings_t->shape().dim_size(0);
111   TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
112   TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
113 
114   if (paddings_t->dtype() == DT_INT32) {
115     return PadKnown<int32>(c, input, paddings_t, num_dims);
116   } else {
117     return PadKnown<int64>(c, input, paddings_t, num_dims);
118   }
119 }
120 
121 Status TransposeShapeFn(InferenceContext* c) {
122   ShapeHandle input = c->input(0);
123   ShapeHandle perm_shape = c->input(1);
124   const Tensor* perm = c->input_tensor(1);
125   DimensionHandle perm_elems = c->NumElements(perm_shape);
126   // If we don't have rank information on the input or value information on
127   // perm, we can't return any shape information; otherwise we have enough
128   // information to at least find the rank of the output.
129   if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
130     c->set_output(0, c->UnknownShape());
131     return Status::OK();
132   }
133 
134   // Find our value of the rank.
135   int64 rank;
136   if (c->RankKnown(input)) {
137     rank = c->Rank(input);
138   } else if (c->ValueKnown(perm_elems)) {
139     rank = c->Value(perm_elems);
140   } else {
141     rank = perm->NumElements();
142   }
143   if (!c->RankKnown(input) && rank < 2) {
144     // A permutation array containing a single element is ambiguous. It could
145     // indicate either a scalar or a 1-dimensional array, both of which the
146     // transpose op returns unchanged.
147     c->set_output(0, input);
148     return Status::OK();
149   }
150 
151   std::vector<DimensionHandle> dims;
152   dims.resize(rank);
153   TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
154   // Ensure that perm is a vector and has rank elements.
155   TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
156   TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
157 
158   // If we know the rank of the input and the value of perm, we can return
159   // all shape information, otherwise we can only return rank information,
160   // but no information for the dimensions.
161   if (perm != nullptr) {
162     std::vector<int64> data;
163     if (perm->dtype() == DT_INT32) {
164       data = AsInt64<int32>(perm, rank);
165     } else {
166       data = AsInt64<int64>(perm, rank);
167     }
168 
169     for (int32 i = 0; i < rank; ++i) {
170       int64 in_idx = data[i];
171       if (in_idx >= rank) {
172         return errors::InvalidArgument("perm dim ", in_idx,
173                                        " is out of range of input rank ", rank);
174       }
175       dims[i] = c->Dim(input, in_idx);
176     }
177   } else {
178     for (int i = 0; i < rank; ++i) {
179       dims[i] = c->UnknownDim();
180     }
181   }
182 
183   c->set_output(0, c->MakeShape(dims));
184   return Status::OK();
185 }
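// Illustrative example (hypothetical shapes): an input of shape [2, 3, 5]
// with perm == [2, 0, 1] produces an output of shape [5, 2, 3], i.e. output
// dimension i is input dimension perm[i].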
186 
187 Status SetOutputShapeForReshape(InferenceContext* c) {
188   ShapeHandle in = c->input(0);
189   ShapeHandle out;
190   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
191 
192   if (!c->RankKnown(out)) {
193     // We have no information about the shape of the output.
194     c->set_output(0, out);
195     return Status::OK();
196   }
197   if (c->RankKnown(in)) {
198     // We don't know the number of output elements, but we can try to infer
199     // the missing dimension.
200     bool too_many_unknown = false;
201     int32 out_unknown_idx = -1;
202 
203     DimensionHandle known_out_elems = c->NumElements(out);
204     if (!c->ValueKnown(known_out_elems)) {
205       known_out_elems = c->MakeDim(1);
206       for (int32 i = 0; i < c->Rank(out); ++i) {
207         DimensionHandle dim = c->Dim(out, i);
208         if (!c->ValueKnown(dim)) {
209           if (out_unknown_idx >= 0) {
210             too_many_unknown = true;
211             break;
212           }
213           out_unknown_idx = i;
214         } else {
215           TF_RETURN_IF_ERROR(
216               c->Multiply(known_out_elems, dim, &known_out_elems));
217         }
218       }
219     }
220     int32 in_unknown_idx = -1;
221     DimensionHandle known_in_elems = c->NumElements(in);
222     if (!c->ValueKnown(known_in_elems)) {
223       known_in_elems = c->MakeDim(1);
224       for (int32 i = 0; i < c->Rank(in); ++i) {
225         DimensionHandle dim = c->Dim(in, i);
226         if (!c->ValueKnown(dim)) {
227           if (in_unknown_idx >= 0) {
228             too_many_unknown = true;
229             break;
230           }
231           in_unknown_idx = i;
232         } else {
233           TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
234         }
235       }
236     }
237 
238     if (!too_many_unknown) {
239       if (in_unknown_idx < 0 && out_unknown_idx < 0) {
240         // Just check that the dimensions match.
241         if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
242           return errors::InvalidArgument(
243               "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
244               " elements to shape ", c->DebugString(out), " (",
245               c->DebugString(known_out_elems), " elements)");
246         }
247       } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
248                  c->Value(known_out_elems) > 0) {
249         // Input fully known, infer the one missing output dim
250         DimensionHandle inferred_dim;
251         TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
252                                      true /* evenly_divisible */,
253                                      &inferred_dim));
254         TF_RETURN_IF_ERROR(
255             c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
256 
257       } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
258                  c->Value(known_in_elems) != 0) {
259         // Output fully known, infer the one missing input dim
260         DimensionHandle inferred_dim;
261         TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
262                                      true /* evenly_divisible */,
263                                      &inferred_dim));
264         DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
265         TF_RETURN_IF_ERROR(
266             c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
267       } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
268         // Exactly one unknown dimension in both input and output. These 2 are
269         // equal iff the known elements are equal.
270         if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
271           DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
272           TF_RETURN_IF_ERROR(
273               c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
274         }
275       }
276     }
277   }
278   c->set_output(0, out);
279   return Status::OK();
280 }
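// Illustrative example (hypothetical shapes): reshaping a [2, 6] tensor
// (12 elements) to target shape [3, -1] infers the unknown dimension as
// 12 / 3 = 4, so the inferred output shape is [3, 4].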
281 
282 }  // namespace
283 
284 REGISTER_OP("ParallelConcat")
285     .Input("values: N * T")
286     .Output("output: T")
287     .Attr("N: int >= 1")
288     .Attr("T: type")
289     .Attr("shape: shape")
290     .SetShapeFn([](InferenceContext* c) {
291       // Validate that the shape attr is correct.
292       PartialTensorShape shape;
293       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
294       ShapeHandle passed_shape;
295       TF_RETURN_IF_ERROR(
296           c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
297       if (!c->FullyDefined(passed_shape)) {
298         return errors::InvalidArgument("shape attr must be fully defined.");
299       }
300       ShapeHandle cur;
301       TF_RETURN_IF_ERROR(c->ReplaceDim(
302           passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
303           &cur));
304       for (int i = 0; i < c->num_inputs(); ++i) {
305         if (!c->FullyDefined(c->input(i))) {
306           return errors::InvalidArgument(
307               "All input shapes must be fully defined.");
308         }
309         DimensionHandle unused;
310         if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
311           return errors::InvalidArgument("Size of first dimension must be 1.");
312         }
313         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
314                                         "From merging shape ", i,
315                                         " with other shapes.");
316       }
317 
318       c->set_output(0, passed_shape);
319 
320       return Status::OK();
321     });
322 
323 REGISTER_OP("Pack")
324     .Input("values: N * T")
325     .Output("output: T")
326     .Attr("N: int >= 1")
327     .Attr("T: type")
328     .Attr("axis: int = 0")
329     .SetShapeFn([](InferenceContext* c) {
330       // Validate shapes of all inputs are compatible
331       ShapeHandle cur = c->input(c->num_inputs() - 1);
332       for (int i = c->num_inputs() - 2; i >= 0; --i) {
333         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
334                                         "From merging shape ", i,
335                                         " with other shapes.");
336       }
337       if (!c->RankKnown(cur)) {
338         c->set_output(0, c->UnknownShape());
339         return Status::OK();
340       }
341       // Determine the axis that will be added, converting negative axes
342       // to non-negative values per negative indexing rules.
343       int32 rank = c->Rank(cur);
344       int32 axis;
345       TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
346 
347       // Copy all dimensions over, inserting a dimension of value #inputs
348       // at <axis>.
349       std::vector<DimensionHandle> dims;
350       int index = 0;
351       while (index < axis) dims.push_back(c->Dim(cur, index++));
352       dims.push_back(c->MakeDim(c->num_inputs()));
353       while (index < rank) dims.push_back(c->Dim(cur, index++));
354 
355       c->set_output(0, c->MakeShape(dims));
356       for (int i = 0; i < c->num_inputs(); ++i) {
357         auto* shape_and_type = c->input_handle_shapes_and_types(i);
358         if (shape_and_type) {
359           if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
360             c->set_output_handle_shapes_and_types(
361                 0, std::vector<shape_inference::ShapeAndType>({}));
362             break;
363           }
364         }
365       }
366       return Status::OK();
367     });
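// Illustrative example (hypothetical shapes): packing N == 3 inputs of shape
// [4, 5] with axis == 1 inserts a dimension of size 3, giving an output
// shape of [4, 3, 5].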
368 
369 REGISTER_OP("DeepCopy")
370     .Input("x: T")
371     .Output("y: T")
372     .Attr("T: type")
373     .SetIsStateful()
374     .SetShapeFn(UnchangedShape);
375 
376 REGISTER_OP("InplaceUpdate")
377     .Input("x: T")
378     .Input("i: int32")
379     .Input("v: T")
380     .Output("y: T")
381     .Attr("T: type")
382     .SetShapeFn(UnchangedShape);
383 
384 REGISTER_OP("InplaceAdd")
385     .Input("x: T")
386     .Input("i: int32")
387     .Input("v: T")
388     .Output("y: T")
389     .Attr("T: type")
390     .SetShapeFn(UnchangedShape);
391 
392 REGISTER_OP("InplaceSub")
393     .Input("x: T")
394     .Input("i: int32")
395     .Input("v: T")
396     .Output("y: T")
397     .Attr("T: type")
398     .SetShapeFn(UnchangedShape);
399 
400 REGISTER_OP("Empty")
401     .Input("shape: int32")
402     .Output("output: dtype")
403     .Attr("dtype: type")
404     .Attr("init: bool = false")
405     .SetDoNotOptimize()
406     .SetShapeFn([](InferenceContext* c) {
407       ShapeHandle out;
408       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
409       c->set_output(0, out);
410       return Status::OK();
411     });
412 
413 // --------------------------------------------------------------------------
414 REGISTER_OP("Unpack")
415     .Input("value: T")
416     .Output("output: num * T")
417     .Attr("num: int >= 0")
418     .Attr("T: type")
419     .Attr("axis: int = 0")
420     .SetShapeFn([](InferenceContext* c) {
421       ShapeHandle s = c->input(0);
422       ShapeHandle out;
423       if (c->RankKnown(s)) {
424         // Determine the axis that will be removed, converting negative axes
425         // to non-negative values per negative indexing rules.
426         int32 rank = c->Rank(s);
427         int32 axis;
428         TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
429 
430         // The axis dim matches the number of outputs.
431         DimensionHandle unused;
432         TF_RETURN_IF_ERROR(
433             c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
434 
435         // Copy all dimensions, removing the <axis> dimension.
436         std::vector<DimensionHandle> dims;
437         for (int i = 0; i < rank; ++i) {
438           if (i != axis) dims.push_back(c->Dim(s, i));
439         }
440         out = c->MakeShape(dims);
441       } else {
442         // All outputs are the same shape, but it's not known.
443         out = c->UnknownShape();
444       }
445       for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
446       return Status::OK();
447     });
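// Illustrative example (hypothetical shapes): unpacking a [3, 4, 5] tensor
// with axis == 1 requires num == 4 and produces outputs of shape [3, 5].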
448 
449 REGISTER_OP("UnravelIndex")
450     .Input("indices: Tidx")
451     .Input("dims: Tidx")
452     .Output("output: Tidx")
453     .Attr("Tidx: {int32, int64} = DT_INT32")
454     .SetShapeFn([](InferenceContext* c) {
455       ShapeHandle indices = c->input(0);
456       ShapeHandle dims;
457       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
458       if (c->RankKnown(indices) && c->Rank(indices) == 0) {
459         c->set_output(0, c->Vector(c->Dim(dims, 0)));
460       } else if (c->RankKnown(indices)) {
461         c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
462       } else {
463         c->set_output(0, c->UnknownShape());
464       }
465       return Status::OK();
466     });
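// Illustrative example (hypothetical shapes): with dims == [3, 4], a scalar
// index yields an output of shape [2] (one coordinate per dimension), while
// indices of shape [5] yield an output of shape [2, 5].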
467 
468 REGISTER_OP("BroadcastTo")
469     .Input("input: T")
470     .Input("shape: Tidx")
471     .Output("output: T")
472     .Attr("T: type")
473     .Attr("Tidx: {int32, int64} = DT_INT32")
474     .SetShapeFn([](InferenceContext* c) {
475       ShapeHandle shape_in = c->input(1);
476       TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
477       ShapeHandle out;
478       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
479       if (!c->RankKnown(out)) {
480         // We have no information about the shape of the output.
481         c->set_output(0, out);
482         return Status::OK();
483       }
484 
485       ShapeHandle in = c->input(0);
486       if (!c->RankKnown(in)) {
487         // We have no information about the shape of the input,
488         // nothing to do here.
489         c->set_output(0, out);
490         return Status::OK();
491       }
492       int out_rank = c->Rank(out);
493       TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
494       int in_rank = c->Rank(in);
495       for (int i = 0; i < in_rank; ++i) {
496         auto in_dim = c->Dim(in, in_rank - i - 1);
497         if (c->Value(in_dim) > 1) {
498           // If the input dimension is greater than 1 then the output dimension
499           // must be equal to it, since we only broadcast "from left to right".
500           auto out_dim = c->Dim(out, out_rank - i - 1);
501           TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
502           TF_RETURN_IF_ERROR(
503               c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
504         }
505       }
506       c->set_output(0, out);
507       return Status::OK();
508     });
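// Illustrative example (hypothetical shapes): broadcasting an input of shape
// [3, 1] to shape [2, 3, 4] is valid and yields [2, 3, 4]; an input of shape
// [3, 2] is rejected because the trailing dimension 2 cannot merge with 4.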
509 
510 // --------------------------------------------------------------------------
511 // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
512 // in the N == 1 case to remove the node.
513 REGISTER_OP("Concat")
514     .Input("concat_dim: int32")
515     .Input("values: N * T")
516     .Output("output: T")
517     .Attr("N: int >= 2")
518     .Attr("T: type")
519     .SetShapeFn([](InferenceContext* c) {
520       return shape_inference::ConcatShape(c, c->num_inputs() - 1);
521     });
522 
523 REGISTER_OP("ConcatV2")
524     .Input("values: N * T")
525     .Input("axis: Tidx")
526     .Output("output: T")
527     .Attr("N: int >= 2")
528     .Attr("T: type")
529     .Attr("Tidx: {int32, int64} = DT_INT32")
530     .SetShapeFn(shape_inference::ConcatV2Shape);
531 
532 // TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
533 // are not to be made user-accessible.
534 #ifdef INTEL_MKL
535 REGISTER_OP("_MklConcatV2")
536     .Input("values: N * T")
537     .Input("axis: Tidx")
538     .Input("mkl_values: N * uint8")
539     .Input("mkl_axis: uint8")
540     .Output("output: T")
541     .Output("mkl_output: uint8")
542     .Attr("N: int >= 2")
543     .Attr("T: type")
544     .Attr("Tidx: {int32, int64} = DT_INT32")
545     .SetShapeFn(shape_inference::ConcatV2Shape)
546     .Doc(R"doc(
547 MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.
548 
549 NOTE: Do not invoke this operator directly in Python. The graph rewrite pass
550 is expected to invoke this operator.
551 )doc");
552 #endif
553 
554 REGISTER_OP("ConcatOffset")
555     .Input("concat_dim: int32")
556     .Input("shape: N * int32")
557     .Output("offset: N * int32")
558     .Attr("N: int >= 2")
559     .SetShapeFn([](InferenceContext* c) {
560       for (int i = 1; i < c->num_inputs(); ++i) {
561         c->set_output(i - 1, c->input(i));
562       }
563       return Status::OK();
564     });
565 
566 // --------------------------------------------------------------------------
567 REGISTER_OP("Split")
568     .Input("split_dim: int32")
569     .Input("value: T")
570     .Output("output: num_split * T")
571     .Attr("num_split: int >= 1")
572     .Attr("T: type")
573     .SetShapeFn([](InferenceContext* c) {
574       DimensionHandle split_dimension;
575       ShapeHandle input = c->input(1);
576       TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
577           0, c->Rank(input), &split_dimension));
578       int num_split = c->num_outputs();
579       ShapeHandle out;
580       if (!c->ValueKnown(split_dimension)) {
581         if (c->RankKnown(input)) {
582           out = c->UnknownShapeOfRank(c->Rank(input));
583         } else {
584           out = c->UnknownShape();
585         }
586       } else {
587         int64 split_dim = c->Value(split_dimension);
588         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
589         DimensionHandle split_dim_size;
590         TF_RETURN_WITH_CONTEXT_IF_ERROR(
591             c->Divide(c->Dim(input, split_dim), num_split,
592                       true /* evenly_divisible */, &split_dim_size),
593             "Number of ways to split should evenly divide the split dimension");
594         TF_RETURN_IF_ERROR(
595             c->ReplaceDim(input, split_dim, split_dim_size, &out));
596       }
597       for (int i = 0; i < num_split; ++i) c->set_output(i, out);
598       return Status::OK();
599     });
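// Illustrative example (hypothetical shapes): splitting a [4, 6] tensor with
// split_dim == 1 and num_split == 3 yields three outputs of shape [4, 2];
// a split dimension of size 7 would be rejected as not evenly divisible.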
600 
601 REGISTER_OP("SplitV")
602     .Input("value: T")
603     .Input("size_splits: Tlen")
604     .Input("split_dim: int32")
605     .Output("output: num_split * T")
606     .Attr("num_split: int >= 1")
607     .Attr("T: type")
608     .Attr("Tlen: {int32, int64} = DT_INT64")
609     .SetShapeFn([](InferenceContext* c) {
610       DimensionHandle split_dimension;
611       ShapeHandle input = c->input(0);
612       TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
613           2, c->Rank(input), &split_dimension));
614       int32 num_outputs = c->num_outputs();
615       int32 rank = c->Rank(input);
616       ShapeHandle output_shape;
617       const Tensor* size_splits = c->input_tensor(1);
618       if (rank == InferenceContext::kUnknownRank) {
619         // If the rank of input tensor is unknown, then return unknown shapes.
620         // Note that the shape of each output can be different.
621         for (int i = 0; i < num_outputs; ++i) {
622           c->set_output(i, c->UnknownShape());
623         }
624       } else if (rank == 0) {
625         // Throw error if input is a scalar.
626         return errors::InvalidArgument("Can't split scalars");
627       } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
628         // If the split dimension is known but the sizes are unknown, then
629         // only the split dimension is unknown.
630         output_shape = input;
631         for (int i = 0; i < num_outputs; ++i) {
632           TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
633                                            c->Value(split_dimension),
634                                            c->UnknownDim(), &output_shape));
635           c->set_output(i, output_shape);
636         }
637       } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
638         // If split dimension or tensor containing the split sizes is unknown,
639         // then return unknown shapes of same rank as input. Note that each
640         // output shape can be different since splitv doesn't always split
641         // tensors evenly.
642         for (int i = 0; i < num_outputs; ++i) {
643           c->set_output(i, c->UnknownShapeOfRank(rank));
644         }
645       } else {
646         // Determine the output shape if split dimension and split sizes are
647         // known.
648         int64 split_dim = c->Value(split_dimension);
649         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
650         std::vector<int64> data;
651         if (size_splits->dtype() == DT_INT32) {
652           data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
653         } else {
654           data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
655         }
656         if (num_outputs != data.size()) {
657           return errors::InvalidArgument(
658               "Length of size_splits should be equal to num_outputs");
659         }
660         int64_t total_size = 0;
661         bool has_neg_one = false;
662         for (const auto size : data) {
663           if (size == -1) {
664             if (has_neg_one) {
665               return errors::InvalidArgument(
666                   "size_splits can only have one -1");
667             }
668             has_neg_one = true;
669           } else {
670             total_size += size;
671           }
672         }
673         auto split_dim_size = c->Value(c->Dim(input, split_dim));
674         // If the sizes of the splits are known, then
675         // make sure that the sizes add up to the expected
676         // dimension size, with the possibility of a -1.
677         // Specify the full output shapes.
678         for (int i = 0; i < num_outputs; ++i) {
679           auto size = data[i];
680           if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
681             size = split_dim_size - total_size;
682           }
683           TF_RETURN_IF_ERROR(
684               c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
685           c->set_output(i, output_shape);
686         }
687         if (c->ValueKnown(split_dim_size)) {
688           if (has_neg_one ? total_size > split_dim_size
689                           : total_size != split_dim_size) {
690             return errors::InvalidArgument(
691                 "can't split axis of size ", split_dim_size,
692                 " into pieces of size [", absl::StrJoin(data, ","), "]");
693           }
694         }
695       }
696 
697       return Status::OK();
698     });
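// Illustrative example (hypothetical shapes): splitting a [4, 10] tensor
// along dimension 1 with size_splits == [3, -1, 2] resolves the -1 to
// 10 - (3 + 2) = 5 and yields outputs [4, 3], [4, 5], and [4, 2].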
699 
700 // --------------------------------------------------------------------------
701 REGISTER_OP("Const")
702     .Output("output: dtype")
703     .Attr("value: tensor")
704     .Attr("dtype: type")
705     .SetShapeFn([](InferenceContext* c) {
706       const TensorProto* proto = nullptr;
707       TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
708       TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
709       TensorShape shape(proto->tensor_shape());
710       std::vector<DimensionHandle> dims;
711       dims.reserve(shape.dims());
712       for (int i = 0; i < shape.dims(); ++i) {
713         dims.push_back(c->MakeDim(shape.dim_size(i)));
714       }
715       c->set_output(0, c->MakeShape(dims));
716       return Status::OK();
717     });
718 
719 // Returns a constant tensor on the host.  Useful for writing C++ tests
720 // and benchmarks which run on GPU but require arguments pinned to the host.
721 // Used by test::graph::HostConstant.
722 // value: Attr `value` is the tensor to return.
723 REGISTER_OP("HostConst")
724     .Output("output: dtype")
725     .Attr("value: tensor")
726     .Attr("dtype: type")
727     .SetShapeFn(shape_inference::UnknownShape);
728 
729 // --------------------------------------------------------------------------
730 // TODO(mgubin): Update the doc when the freeze_graph script supports converting
731 // into memmapped format.
732 REGISTER_OP("ImmutableConst")
733     .Attr("dtype: type")
734     .Attr("shape: shape")
735     .Attr("memory_region_name: string")
736     .Output("tensor: dtype")
737     .SetShapeFn(shape_inference::ExplicitShape);
738 
739 REGISTER_OP("GuaranteeConst")
740     .Input("input: T")
741     .Output("output: T")
742     .Attr("T: type")
743     .SetShapeFn([](shape_inference::InferenceContext* c) {
744       return UnchangedShape(c);
745     })
746     // We don't want this to be optimized away.
747     .SetDoNotOptimize();
748 
749 // --------------------------------------------------------------------------
750 REGISTER_OP("ZerosLike")
751     .Input("x: T")
752     .Output("y: T")
753     .Attr("T: type")
754     .SetShapeFn(shape_inference::UnchangedShape);
755 
756 // --------------------------------------------------------------------------
757 REGISTER_OP("OnesLike")
758     .Input("x: T")
759     .Output("y: T")
760     .Attr(
761         "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
762         "int64, complex64, complex128, bool}")
763     .SetShapeFn(shape_inference::UnchangedShape);
764 
765 // --------------------------------------------------------------------------
766 REGISTER_OP("Diag")
767     .Input("diagonal: T")
768     .Output("output: T")
769     .Attr(
770         "T: {bfloat16, half, float, double, int32, int64, complex64, "
771         "complex128}")
772     .SetShapeFn([](InferenceContext* c) {
773       ShapeHandle in = c->input(0);
774       TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
775       // Output shape is original concatenated with itself.
776       ShapeHandle out;
777       TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
778       c->set_output(0, out);
779       return Status::OK();
780     });
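// Illustrative example (hypothetical shapes): a diagonal of shape [2, 3]
// produces an output of shape [2, 3, 2, 3], the input shape concatenated
// with itself.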
781 
782 // --------------------------------------------------------------------------
783 REGISTER_OP("DiagPart")
784     .Input("input: T")
785     .Output("diagonal: T")
786     .Attr(
787         "T: {bfloat16, half, float, double, int32, int64, complex64, "
788         "complex128}")
789     .SetShapeFn([](InferenceContext* c) {
790       ShapeHandle in = c->input(0);
791       if (!c->RankKnown(in)) {
792         c->set_output(0, c->UnknownShape());
793         return Status::OK();
794       }
795       // Rank must be even, and result will have rank <rank/2>.
796       const int32 rank = c->Rank(in);
797       if ((rank % 2) != 0 || rank <= 0) {
798         return errors::InvalidArgument(
799             "Input must have even and non-zero rank, input rank is ", rank);
800       }
801       const int32 mid = rank / 2;
802 
803       // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
804       std::vector<DimensionHandle> dims(mid);
805       for (int i = 0; i < mid; ++i) {
806         TF_RETURN_IF_ERROR(
807             c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
808       }
809       c->set_output(0, c->MakeShape(dims));
810       return Status::OK();
811     });
812 
813 // --------------------------------------------------------------------------
814 REGISTER_OP("MatrixDiag")
815     .Input("diagonal: T")
816     .Output("output: T")
817     .Attr("T: type")
818     .SetShapeFn([](InferenceContext* c) {
819       ShapeHandle in;
820       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
821       if (!c->RankKnown(in)) {
822         c->set_output(0, c->UnknownShape());
823         return Status::OK();
824       }
825       const int32 rank = c->Rank(in);
826       ShapeHandle out;
827       TF_RETURN_IF_ERROR(
828           c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
829       c->set_output(0, out);
830       return Status::OK();
831     });
832 
833 REGISTER_OP("MatrixDiagV2")
834     .Input("diagonal: T")
835     .Input("k: int32")
836     .Input("num_rows: int32")
837     .Input("num_cols: int32")
838     .Input("padding_value: T")
839     .Output("output: T")
840     .Attr("T: type")
841     .SetShapeFn(shape_inference::MatrixDiagV2Shape);
842 
843 REGISTER_OP("MatrixDiagV3")
844     .Input("diagonal: T")
845     .Input("k: int32")
846     .Input("num_rows: int32")
847     .Input("num_cols: int32")
848     .Input("padding_value: T")
849     .Output("output: T")
850     .Attr("T: type")
851     .Attr(
852         "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
853         "'RIGHT_LEFT'")
854     .SetShapeFn(shape_inference::MatrixDiagV2Shape);
855 
856 // --------------------------------------------------------------------------
857 REGISTER_OP("MatrixSetDiag")
858     .Input("input: T")
859     .Input("diagonal: T")
860     .Output("output: T")
861     .Attr("T: type")
862     .SetShapeFn([](InferenceContext* c) {
863       ShapeHandle input;
864       ShapeHandle diag;
865       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
866       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
867       if (c->RankKnown(input)) {
868         TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
869       }
870       DimensionHandle smallest_dim;
871       TF_RETURN_IF_ERROR(
872           c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
873       TF_RETURN_IF_ERROR(
874           c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
875 
876       ShapeHandle output = input;
877       if (c->RankKnown(diag) && !c->FullyDefined(input)) {
878         // Try to infer parts of shape from diag.
879         ShapeHandle diag_batch_shape;
880         TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_batch_shape));
881         TF_RETURN_IF_ERROR(
882             c->Concatenate(diag_batch_shape, c->UnknownShapeOfRank(2), &diag));
883         TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
884       }
885       c->set_output(0, output);
886       return Status::OK();
887     });
888 
889 REGISTER_OP("MatrixSetDiagV2")
890     .Input("input: T")
891     .Input("diagonal: T")
892     .Input("k: int32")
893     .Output("output: T")
894     .Attr("T: type")
895     .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
896 
897 REGISTER_OP("MatrixSetDiagV3")
898     .Input("input: T")
899     .Input("diagonal: T")
900     .Input("k: int32")
901     .Output("output: T")
902     .Attr("T: type")
903     .Attr(
904         "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
905         "'RIGHT_LEFT'")
906     .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
907 
908 // --------------------------------------------------------------------------
909 REGISTER_OP("MatrixDiagPart")
910     .Input("input: T")
911     .Output("diagonal: T")
912     .Attr("T: type")
913     .SetShapeFn([](InferenceContext* c) {
914       ShapeHandle in;
915       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
916       if (!c->RankKnown(in)) {
917         c->set_output(0, c->UnknownShape());
918         return Status::OK();
919       }
920       const int32 rank = c->Rank(in);
921       std::vector<DimensionHandle> dims;
922       dims.reserve(rank - 2);
923       for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
924 
925       DimensionHandle min_dim;
926       TF_RETURN_IF_ERROR(
927           c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
928       dims.push_back(min_dim);
929       c->set_output(0, c->MakeShape(dims));
930       return Status::OK();
931     });
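// Illustrative example (hypothetical shapes): an input of shape [2, 4, 6]
// keeps its batch dimension and yields a diagonal of shape [2, 4], i.e.
// [..., min(m, n)] for an input of shape [..., m, n].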
932 
933 REGISTER_OP("MatrixDiagPartV2")
934     .Input("input: T")
935     .Input("k: int32")
936     .Input("padding_value: T")
937     .Output("diagonal: T")
938     .Attr("T: type")
939     .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
940 
941 REGISTER_OP("MatrixDiagPartV3")
942     .Input("input: T")
943     .Input("k: int32")
944     .Input("padding_value: T")
945     .Output("diagonal: T")
946     .Attr("T: type")
947     .Attr(
948         "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
949         "'RIGHT_LEFT'")
950     .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
951 
952 // --------------------------------------------------------------------------
953 REGISTER_OP("MatrixBandPart")
954     .Input("input: T")
955     .Input("num_lower: Tindex")
956     .Input("num_upper: Tindex")
957     .Output("band: T")
958     .Attr("T: type")
959     .Attr("Tindex: {int32, int64} = DT_INT64")
960     .SetShapeFn(shape_inference::UnchangedShape);
961 
962 // --------------------------------------------------------------------------
963 REGISTER_OP("Reverse")
964     .Input("tensor: T")
965     .Input("dims: bool")
966     .Output("output: T")
967     .Attr(
968         "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
969         "float, double, complex64, complex128, string}")
970     .SetShapeFn([](InferenceContext* c) {
971       ShapeHandle input = c->input(0);
972       ShapeHandle dims;
973       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
974       DimensionHandle dims_dim = c->Dim(dims, 0);
975       if (c->ValueKnown(dims_dim)) {
976         TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
977       }
978       if (c->Rank(input) > 8) {
979         return errors::InvalidArgument(
980             "reverse does not work on tensors with more than 8 dimensions");
981       }
982       c->set_output(0, input);
983       return Status::OK();
984     });
985 
986 // --------------------------------------------------------------------------
987 REGISTER_OP("ReverseV2")
988     .Input("tensor: T")
989     .Input("axis: Tidx")
990     .Output("output: T")
991     .Attr("Tidx: {int32, int64} = DT_INT32")
992     .Attr(
993         "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
994         "float, double, complex64, complex128, string}")
995     .SetShapeFn([](InferenceContext* c) {
996       ShapeHandle input = c->input(0);
997       ShapeHandle axis;
998       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
999       if (c->Rank(input) > 8) {
1000         return errors::InvalidArgument(
1001             "reverse does not work on tensors with more than 8 dimensions");
1002       }
1003       const Tensor* axis_tensor = c->input_tensor(1);
1004       if (axis_tensor != nullptr && c->RankKnown(input)) {
1005         int32 rank = c->Rank(input);
1006         std::vector<int64> axis_value;
1007         if (axis_tensor->dtype() == DT_INT32) {
1008           axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
1009         } else {
1010           axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements());
1011         }
1012         std::vector<bool> axes_dense(c->Rank(input), false);
1013         for (int i = 0; i < axis_value.size(); i++) {
1014           int64 canonical_axis =
1015               axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
1016           if (canonical_axis < 0 || canonical_axis >= rank) {
1017             return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
1018                                            " is out of valid range [", 0, ", ",
1019                                            rank - 1, "]");
1020           }
1021           if (axes_dense[canonical_axis]) {
1022             return errors::InvalidArgument("axis ", canonical_axis,
1023                                            " specified more than once.");
1024           }
1025           axes_dense[canonical_axis] = true;
1026         }
1027       }
1028       c->set_output(0, input);
1029       return Status::OK();
1030     });
1031 
1032 // --------------------------------------------------------------------------
1033 REGISTER_OP("EditDistance")
1034     .Input("hypothesis_indices: int64")
1035     .Input("hypothesis_values: T")
1036     .Input("hypothesis_shape: int64")
1037     .Input("truth_indices: int64")
1038     .Input("truth_values: T")
1039     .Input("truth_shape: int64")
1040     .Attr("normalize: bool = true")
1041     .Attr("T: type")
1042     .Output("output: float")
1043     .SetShapeFn([](InferenceContext* c) {
1044       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1045           c, c->input(0), c->input(1), c->input(2)));
1046       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1047           c, c->input(3), c->input(4), c->input(5)));
1048       const Tensor* hypothesis_shape_t = c->input_tensor(2);
1049       const Tensor* truth_shape_t = c->input_tensor(5);
1050       if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
1051         // We need to know the runtime shape of the two tensors,
1052         // or else the output shape is unknown.
1053         return shape_inference::UnknownShape(c);
1054       }
1055 
1056       if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
1057         return errors::InvalidArgument(
1058             "Num elements of hypothesis_shape does not match truth_shape: ",
1059             hypothesis_shape_t->NumElements(), " vs. ",
1060             truth_shape_t->NumElements());
1061       }
1062 
1063       auto h_values = hypothesis_shape_t->flat<int64>();
1064       auto t_values = truth_shape_t->flat<int64>();
1065       std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
1066       for (int i = 0; i < dims.size(); ++i) {
1067         dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
1068       }
1069 
1070       c->set_output(0, c->MakeShape(dims));
1071       return Status::OK();
1072     });
1073 
1074 // --------------------------------------------------------------------------
1075 REGISTER_OP("Fill")
1076     .Input("dims: index_type")
1077     .Input("value: T")
1078     .Output("output: T")
1079     .Attr("T: type")
1080     .Attr("index_type: {int32, int64} = DT_INT32")
1081     .SetShapeFn([](InferenceContext* c) {
1082       DataType index_type = DT_INT32;
1083       Status s = c->GetAttr("index_type", &index_type);
1084       if (!s.ok() && s.code() != error::NOT_FOUND) {
1085         return s;
1086       }
1087       ShapeHandle unused;
1088       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1089       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1090 
1091       const Tensor* t = c->input_tensor(0);
1092       if (t != nullptr) {
1093         for (int i = 0; i < t->NumElements(); ++i) {
1094           if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
1095               (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
1096             return errors::InvalidArgument("Fill dimensions must be >= 0");
1097           }
1098         }
1099       }
1100 
1101       ShapeHandle out;
1102       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1103       c->set_output(0, out);
1104 
1105       auto* shape_and_type = c->input_handle_shapes_and_types(1);
1106       if (shape_and_type) {
1107         c->set_output_handle_shapes_and_types(0, *shape_and_type);
1108       }
1109 
1110       return Status::OK();
1111     });
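// Illustrative example (hypothetical values): dims == [2, 3] with a scalar
// value produces an output of shape [2, 3]; negative entries in dims are
// rejected by the check above.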
1112 
1113 // --------------------------------------------------------------------------
1114 REGISTER_OP("_ParallelConcatStart")
1115     .Output("output: dtype")
1116     .Attr("shape: shape")
1117     .Attr("dtype: type")
1118     .SetIsStateful()
1119     .SetShapeFn(shape_inference::ExplicitShape)
1120     .Doc(R"doc(
1121 Creates an empty Tensor with shape `shape` and type `dtype`.
1122 
1123 The memory can optionally be initialized. This is usually useful in
1124 conjunction with inplace operations.
1125 
1126 shape: 1-D `Tensor` indicating the shape of the output.
1127 dtype: The element type of the returned tensor.
1128 output: An empty Tensor of the specified type.
1129 )doc");
1130 
1131 // --------------------------------------------------------------------------
1132 REGISTER_OP("_ParallelConcatUpdate")
1133     .Input("value: T")
1134     .Input("update: T")
1135     .Output("output: T")
1136     .Attr("T: type")
1137     .Attr("loc: int")
1138     .SetShapeFn(shape_inference::UnchangedShape)
1139     .Doc(R"doc(
1140 Updates input `value` at `loc` with `update`.
1141 
1142 If you use this function, you will almost certainly want to add
1143 a control dependency, as done in the implementation of parallel_stack, to
1144 avoid race conditions.
1145 
1146 value: A `Tensor` object that will be updated in-place.
1147 loc: A scalar indicating the index of the first dimension such that
1148          value[loc, :] is updated.
1149 update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
1150         otherwise of rank equal to `value` that contains the new values
1151         for `value`.
1152 output: `value` that has been updated accordingly.
1153 )doc");
1154 
1155 // --------------------------------------------------------------------------
1156 REGISTER_OP("Gather")
1157     .Input("params: Tparams")
1158     .Input("indices: Tindices")
1159     .Attr("validate_indices: bool = true")
1160     .Output("output: Tparams")
1161     .Attr("Tparams: type")
1162     .Attr("Tindices: {int32,int64}")
1163     .SetShapeFn([](InferenceContext* c) {
1164       ShapeHandle unused;
1165       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
1166       ShapeHandle params_subshape;
1167       TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &params_subshape));
1168       ShapeHandle indices_shape = c->input(1);
1169       ShapeHandle out;
1170       TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
1171       c->set_output(0, out);
1172       return Status::OK();
1173     });
1174 
1175 // --------------------------------------------------------------------------
1176 REGISTER_OP("GatherV2")
1177     .Input("params: Tparams")
1178     .Input("indices: Tindices")
1179     .Input("axis: Taxis")
1180     .Attr("batch_dims: int = 0")
1181     .Output("output: Tparams")
1182     .Attr("Tparams: type")
1183     .Attr("Tindices: {int32,int64}")
1184     .Attr("Taxis: {int32,int64}")
1185     .SetShapeFn([](InferenceContext* c) {
1186       ShapeHandle params_shape;
1187       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &params_shape));
1188 
1189       ShapeHandle indices_shape = c->input(1);
1190       ShapeHandle unused_axis_shape;
1191       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
1192       const Tensor* axis_t = c->input_tensor(2);
1193 
1194       // If axis is unknown, we can only infer that the result has rank
1195       // params_rank + indices_rank - 1.
1196       if (axis_t == nullptr) {
1197         if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
1198           c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
1199                                                  c->Rank(indices_shape) - 1));
1200         } else {
1201           c->set_output(0, c->UnknownShape());
1202         }
1203         return Status::OK();
1204       }
1205 
1206       // Note, axis can be negative.
1207       int64 axis = 0;
1208       if (axis_t->dtype() == DT_INT32) {
1209         axis = axis_t->scalar<int32>()();
1210       } else {
1211         axis = axis_t->scalar<int64>()();
1212       }
1213 
1214       // Check that params has rank of at least axis + 1.
1215       ShapeHandle unused;
1216       TF_RETURN_IF_ERROR(c->WithRankAtLeast(
1217           params_shape, axis < 0 ? -axis : axis + 1, &unused));
1218 
1219       // Note, batch_dims can be negative.
1220       int32 batch_dims;
1221       TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
1222       // -rank(indices) <= batch_dims <= rank(indices)
1223       TF_RETURN_IF_ERROR(
1224           c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused));
1225       if (batch_dims < 0) {
1226         batch_dims += c->Rank(indices_shape);
1227       }
1228       // rank(params) > batch_dims
1229       TF_RETURN_IF_ERROR(
1230           c->WithRankAtLeast(params_shape, batch_dims + 1, &unused));
1231 
1232       ShapeHandle params_outer_subshape;
1233       TF_RETURN_IF_ERROR(
1234           c->Subshape(params_shape, 0, axis, &params_outer_subshape));
1235 
1236       ShapeHandle indices_inner_subshape;
1237       TF_RETURN_IF_ERROR(
1238           c->Subshape(indices_shape, batch_dims, &indices_inner_subshape));
1239 
1240       ShapeHandle out;
1241       TF_RETURN_IF_ERROR(
1242           c->Concatenate(params_outer_subshape, indices_inner_subshape, &out));
1243 
1244       // Slice from axis + 1 to the end of params_shape to collect the inner
1245       // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
1246       // we slice from 0 to the end of shape. Subshape() handles all other
1247       // out-of-bounds checking.
1248       if (axis != -1) {
1249         ShapeHandle params_inner_subshape;
1250         TF_RETURN_IF_ERROR(
1251             c->Subshape(params_shape, axis + 1, &params_inner_subshape));
1252         TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
1253       }
1254 
1255       c->set_output(0, out);
1256       return Status::OK();
1257     });
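// Illustrative example (hypothetical shapes): params of shape [5, 6, 7],
// indices of shape [2, 3], axis == 1 and batch_dims == 0 produce an output
// of shape [5, 2, 3, 7]: the axis dimension is replaced by the indices shape.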
1258 
1259 // --------------------------------------------------------------------------
1260 REGISTER_OP("GatherNd")
1261     .Input("params: Tparams")
1262     .Input("indices: Tindices")
1263     .Output("output: Tparams")
1264     .Attr("Tparams: type")
1265     .Attr("Tindices: {int32,int64}")
1266     .SetShapeFn(shape_inference::GatherNdShape);
1267 
1268 // --------------------------------------------------------------------------
1269 REGISTER_OP("Identity")
1270     .Input("input: T")
1271     .Output("output: T")
1272     .Attr("T: type")
1273     .SetShapeFn(shape_inference::UnchangedShape);
1274 
1275 REGISTER_OP("Snapshot")
1276     .Input("input: T")
1277     .Output("output: T")
1278     .Attr("T: type")
1279     .SetShapeFn(shape_inference::UnchangedShape);
1280 
1281 #ifdef INTEL_MKL
1282 REGISTER_OP("_MklIdentity")
1283     .Input("input: T")
1284     .Input("mkl_input: uint8")
1285     .Output("output: T")
1286     .Output("mkl_output: uint8")
1287     .Attr("T: type")
1288     .SetShapeFn(shape_inference::UnchangedShape)
1289     .Doc(R"Doc( Mkl implementation of IdentityOp
1290 )Doc");
1291 #endif
1292 
1293 REGISTER_OP("IdentityN")
1294     .Input("input: T")
1295     .Output("output: T")
1296     .Attr("T: list(type)")
1297     .SetShapeFn([](shape_inference::InferenceContext* c) {
1298       std::vector<ShapeHandle> input;
1299       TF_RETURN_IF_ERROR(c->input("input", &input));
1300       TF_RETURN_IF_ERROR(c->set_output("output", input));
1301       // If any of the input shapes are not known, we return an error.
1302       for (int i = 0; i < input.size(); i++) {
1303         if (!input[i].Handle()) {
1304           return errors::InvalidArgument(absl::StrCat(
1305               "Cannot infer output shape #", i,
1306               " for IdentityN node because input shape #", i, " is unknown."));
1307         }
1308       }
1309       return Status::OK();
1310     });
1311 
1312 // --------------------------------------------------------------------------
1313 REGISTER_OP("RefIdentity")
1314     .Input("input: Ref(T)")
1315     .Output("output: Ref(T)")
1316     .Attr("T: type")
1317     .SetShapeFn(shape_inference::UnchangedShape)
1318     .SetAllowsUninitializedInput();
1319 
1320 // --------------------------------------------------------------------------
1321 REGISTER_OP("DebugGradientIdentity")
1322     .Input("input: T")
1323     .Output("output: T")
1324     .Attr("T: type")
1325     .SetShapeFn(shape_inference::UnchangedShape)
1326     .SetAllowsUninitializedInput();
1327 
1328 REGISTER_OP("DebugGradientRefIdentity")
1329     .Input("input: Ref(T)")
1330     .Output("output: Ref(T)")
1331     .Attr("T: type")
1332     .SetShapeFn(shape_inference::UnchangedShape)
1333     .SetAllowsUninitializedInput();
1334 
1335 // --------------------------------------------------------------------------
1336 REGISTER_OP("StopGradient")
1337     .Input("input: T")
1338     .Output("output: T")
1339     .Attr("T: type")
1340     .SetShapeFn(shape_inference::UnchangedShape);
1341 
1342 REGISTER_OP("PreventGradient")
1343     .Input("input: T")
1344     .Output("output: T")
1345     .Attr("T: type")
1346     .Attr("message: string = ''")
1347     .SetShapeFn(shape_inference::UnchangedShape);
1348 
1349 // --------------------------------------------------------------------------
1350 REGISTER_OP("CheckNumerics")
1351     .Input("tensor: T")
1352     .Output("output: T")
1353     .Attr("T: {bfloat16, half, float, double}")
1354     .Attr("message: string")
1355     .SetIsStateful()
1356     .SetShapeFn(shape_inference::UnchangedShape);
1357 
1358 // --------------------------------------------------------------------------
1359 REGISTER_OP("CheckNumericsV2")
1360     .Input("tensor: T")
1361     .Output("output: T")
1362     .Attr("T: {bfloat16, half, float, double}")
1363     .Attr("message: string")
1364     .SetIsStateful()
1365     .SetShapeFn(shape_inference::UnchangedShape);
1366 
1367 // --------------------------------------------------------------------------
1368 REGISTER_OP("Reshape")
1369     .Input("tensor: T")
1370     .Input("shape: Tshape")
1371     .Output("output: T")
1372     .Attr("T: type")
1373     .Attr("Tshape: {int32, int64} = DT_INT32")
1374     .SetShapeFn([](InferenceContext* c) {
1375       return SetOutputShapeForReshape(c);
1376     });
1377 
1378 #ifdef INTEL_MKL
1379 REGISTER_OP("_MklReshape")
1380     .Input("tensor: T")
1381     .Input("shape: Tshape")
1382     .Input("mkl_tensor: uint8")
1383     .Input("mkl_shape: uint8")
1384     .Output("output: T")
1385     .Output("mkl_output: uint8")
1386     .Attr("T: type")
1387     .Attr("Tshape: {int32, int64} = DT_INT32")
1388     .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
1389     .Doc(R"Doc( MKL implementation of ReshapeOp.
1390 )Doc");
1391 #endif  // INTEL_MKL
1392 
1393 // --------------------------------------------------------------------------
1394 REGISTER_OP("InvertPermutation")
1395     .Input("x: T")
1396     .Output("y: T")
1397     .Attr("T: {int32, int64} = DT_INT32")
1398     .SetShapeFn([](InferenceContext* c) {
1399       ShapeHandle x;
1400       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
1401       c->set_output(0, x);
1402       return Status::OK();
1403     });
1404 
1405 // --------------------------------------------------------------------------
1406 REGISTER_OP("Transpose")
1407     .Input("x: T")
1408     .Input("perm: Tperm")
1409     .Output("y: T")
1410     .Attr("T: type")
1411     .Attr("Tperm: {int32, int64} = DT_INT32")
1412     .SetShapeFn(TransposeShapeFn);
1413 
1414 #ifdef INTEL_MKL
1415 REGISTER_OP("_MklTranspose")
1416     .Input("x: T")
1417     .Input("perm: Tperm")
1418     .Output("y: T")
1419     .Attr("T: type")
1420     .Attr("Tperm: {int32, int64} = DT_INT32")
1421     .SetShapeFn(TransposeShapeFn);
1422 #endif  // INTEL_MKL
1423 
1424 // --------------------------------------------------------------------------
1425 REGISTER_OP("ConjugateTranspose")
1426     .Input("x: T")
1427     .Input("perm: Tperm")
1428     .Output("y: T")
1429     .Attr("T: type")
1430     .Attr("Tperm: {int32, int64} = DT_INT32")
1431     .SetShapeFn(TransposeShapeFn);
1432 
1433 #ifdef INTEL_MKL
1434 REGISTER_OP("_MklConjugateTranspose")
1435     .Input("x: T")
1436     .Input("perm: Tperm")
1437     .Output("y: T")
1438     .Attr("T: type")
1439     .Attr("Tperm: {int32, int64} = DT_INT32")
1440     .SetShapeFn(TransposeShapeFn);
1441 #endif  // INTEL_MKL
1442 
1443 // --------------------------------------------------------------------------
1444 namespace {
1445 Status UniqueIdxShapeFn(InferenceContext* c) {
1446   ShapeHandle input = c->input(0);
1447   const Tensor* axis_t = c->input_tensor(1);
1448   if (axis_t == nullptr || !c->RankKnown(input)) {
1449     c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1450     return Status::OK();
1451   }
1452 
1453   if (c->Rank(c->input(1)) != 1) {
1454     return errors::InvalidArgument("axis expects a 1D vector.");
1455   }
1456 
1457   int32 n = axis_t->NumElements();
1458   if (n == 0) {
1459     if (c->Rank(input) != 1) {
1460       return errors::InvalidArgument("x expects a 1D vector.");
1461     }
1462     c->set_output(1, input);
1463     return Status::OK();
1464   } else if (n == 1) {
1465     int64 axis;
1466     if (axis_t->dtype() == DT_INT32) {
1467       axis = static_cast<int64>(axis_t->flat<int32>()(0));
1468     } else {
1469       axis = axis_t->flat<int64>()(0);
1470     }
1471 
1472     int64 input_rank = c->Rank(input);
1473     if (axis < -input_rank || axis >= input_rank) {
1474       return errors::InvalidArgument("axis expects to be in the range [",
1475                                      -input_rank, ", ", input_rank, ")");
1476     }
1477     if (axis < 0) {
1478       axis += input_rank;
1479     }
1480     c->set_output(1, c->Vector(c->Dim(input, axis)));
1481     return Status::OK();
1482   }
1483   return errors::InvalidArgument(
1484       "axis does not support input tensors with more than 1 element.");
1485 }
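// For illustration: with x of shape [2, 3] and a one-element axis tensor whose
// value is 1, UniqueIdxShapeFn infers an idx vector of shape [3] (the size of
// dimension 1); with an empty axis tensor, x must be 1-D and idx matches x.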
1486 }  // namespace
1487 
1488 REGISTER_OP("Unique")
1489     .Input("x: T")
1490     .Output("y: T")
1491     .Output("idx: out_idx")
1492     .Attr("T: type")
1493     .Attr("out_idx: {int32, int64} = DT_INT32")
1494     .SetShapeFn([](InferenceContext* c) {
1495       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1496       c->set_output(1, c->input(0));
1497       // Assert that the input rank is 1.
1498       ShapeHandle dummy;
1499       return c->WithRank(c->input(0), 1, &dummy);
1500     });
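// For illustration: Unique requires a 1-D x; y is a vector of statically
// unknown length and idx has the same shape as x. For example, x = [1, 1, 2,
// 4, 4, 4] yields y = [1, 2, 4] and idx = [0, 0, 1, 2, 2, 2] at run time.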
1501 
1502 REGISTER_OP("UniqueV2")
1503     .Input("x: T")
1504     .Input("axis: Taxis")
1505     .Output("y: T")
1506     .Output("idx: out_idx")
1507     .Attr("T: type")
1508     .Attr("Taxis: {int32,int64} = DT_INT64")
1509     .Attr("out_idx: {int32, int64} = DT_INT32")
1510     .SetShapeFn([](InferenceContext* c) {
1511       c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1512       TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1513       return Status::OK();
1514     });
1515 
1516 // --------------------------------------------------------------------------
1517 REGISTER_OP("UniqueWithCounts")
1518     .Input("x: T")
1519     .Output("y: T")
1520     .Output("idx: out_idx")
1521     .Output("count: out_idx")
1522     .Attr("T: type")
1523     .Attr("out_idx: {int32, int64} = DT_INT32")
1524     .SetShapeFn([](InferenceContext* c) {
1525       auto uniq = c->Vector(InferenceContext::kUnknownDim);
1526       c->set_output(0, uniq);
1527       c->set_output(1, c->input(0));
1528       c->set_output(2, uniq);
1529       return Status::OK();
1530     });
1531 
1532 REGISTER_OP("UniqueWithCountsV2")
1533     .Input("x: T")
1534     .Input("axis: Taxis")
1535     .Output("y: T")
1536     .Output("idx: out_idx")
1537     .Output("count: out_idx")
1538     .Attr("T: type")
1539     .Attr("Taxis: {int32,int64} = DT_INT64")
1540     .Attr("out_idx: {int32, int64} = DT_INT32")
1541     .SetShapeFn([](InferenceContext* c) {
1542       c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1543       TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1544       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
1545       return Status::OK();
1546     });
1547 
1548 namespace {
1549 
1550 Status ShapeShapeFn(InferenceContext* c) {
1551   for (int i = 0; i < c->num_inputs(); ++i) {
1552     DimensionHandle dim;
1553     if (c->RankKnown(c->input(i))) {
1554       dim = c->MakeDim(c->Rank(c->input(i)));
1555     } else {
1556       dim = c->UnknownDim();
1557     }
1558     c->set_output(i, c->Vector(dim));
1559   }
1560   return Status::OK();
1561 }
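// For illustration: ShapeShapeFn gives a rank-3 input a shape output of shape
// [3]; when the input rank is unknown, the output is a vector of unknown size.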
1562 
1563 }  // namespace
1564 
1565 // --------------------------------------------------------------------------
1566 REGISTER_OP("Shape")
1567     .Input("input: T")
1568     .Output("output: out_type")
1569     .Attr("T: type")
1570     .Attr("out_type: {int32, int64} = DT_INT32")
1571     .SetShapeFn(ShapeShapeFn);
1572 
1573 REGISTER_OP("ShapeN")
1574     .Input("input: N * T")
1575     .Output("output: N * out_type")
1576     .Attr("N: int")
1577     .Attr("T: type")
1578     .Attr("out_type: {int32, int64} = DT_INT32")
1579     .SetShapeFn(ShapeShapeFn);
1580 
1581 REGISTER_OP("EnsureShape")
1582     .Input("input: T")
1583     .Output("output: T")
1584     .Attr("shape: shape")
1585     .Attr("T: type")
1586     .SetShapeFn([](InferenceContext* c) {
1587       // Merges desired shape and statically known shape of input
1588       PartialTensorShape desired_shape;
1589       TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
1590 
1591       int rank = desired_shape.dims();
1592       ShapeHandle input_shape_handle;
1593       ShapeHandle desired_shape_handle;
1594       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
1595       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
1596           desired_shape, &desired_shape_handle));
1597 
1598       ShapeHandle merged_shape;
1599       TF_RETURN_IF_ERROR(
1600           c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
1601       c->set_output(0, merged_shape);
1602       return Status::OK();
1603     });
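// For illustration: an attr shape of [?, 3] merged with an input statically
// known as [2, ?] gives EnsureShape an output shape of [2, 3]; incompatible
// shapes are rejected by the Merge call during shape inference.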
1604 
1605 // --------------------------------------------------------------------------
1606 REGISTER_OP("ReverseSequence")
1607     .Input("input: T")
1608     .Input("seq_lengths: Tlen")
1609     .Output("output: T")
1610     .Attr("seq_dim: int")
1611     .Attr("batch_dim: int = 0")
1612     .Attr("T: type")
1613     .Attr("Tlen: {int32, int64} = DT_INT64")
1614     .SetShapeFn([](InferenceContext* c) {
1615       ShapeHandle input = c->input(0);
1616       ShapeHandle seq_lens_shape;
1617       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
1618 
1619       int64 seq_dim;
1620       TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
1621       int64 batch_dim;
1622       TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1623 
1624       if (!c->RankKnown(input)) {
1625         return shape_inference::UnknownShape(c);
1626       }
1627 
1628       // Validate batch_dim and seq_dim against input.
1629       const int32 input_rank = c->Rank(input);
1630       if (batch_dim >= input_rank) {
1631         return errors::InvalidArgument(
1632             "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1633       }
1634       if (seq_dim >= input_rank) {
1635         return errors::InvalidArgument(
1636             "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
1637       }
1638 
1639       DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1640       TF_RETURN_IF_ERROR(
1641           c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
1642 
1643       // Replace batch_dim of input with batch_size
1644       ShapeHandle output_shape;
1645       TF_RETURN_IF_ERROR(
1646           c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
1647       c->set_output(0, output_shape);
1648       return Status::OK();
1649     });
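// For illustration: for an input of shape [batch, time, features] with
// batch_dim = 0 and seq_dim = 1, seq_lengths must be a vector of length
// `batch`, and the inferred output shape equals the input shape.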
1650 
1651 // --------------------------------------------------------------------------
1652 REGISTER_OP("Rank")
1653     .Input("input: T")
1654     .Output("output: int32")
1655     .Attr("T: type")
1656     .SetShapeFn(shape_inference::ScalarShape);
1657 
1658 // --------------------------------------------------------------------------
1659 REGISTER_OP("Size")
1660     .Input("input: T")
1661     .Output("output: out_type")
1662     .Attr("T: type")
1663     .Attr("out_type: {int32, int64} = DT_INT32")
1664     .SetShapeFn(shape_inference::ScalarShape);
1665 
1666 // --------------------------------------------------------------------------
1667 REGISTER_OP("Slice")
1668     .Input("input: T")
1669     .Input("begin: Index")
1670     .Input("size: Index")
1671     .Output("output: T")
1672     .Attr("T: type")
1673     .Attr("Index: {int32,int64}")
1674     .SetShapeFn(shape_inference::SliceShape);
1675 
1676 #ifdef INTEL_MKL
1677 REGISTER_OP("_MklSlice")
1678     .Input("input: T")
1679     .Input("begin: Index")
1680     .Input("size: Index")
1681     .Input("mkl_input: uint8")
1682     .Input("mkl_begin: uint8")
1683     .Input("mkl_size: uint8")
1684     .Output("output: T")
1685     .Output("mkl_output: uint8")
1686     .Attr("T: type")
1687     .Attr("Index: {int32,int64}")
1688     .SetShapeFn(shape_inference::SliceShape);
1689 #endif
1690 
1691 REGISTER_OP("StridedSlice")
1692     .Input("input: T")
1693     .Input("begin: Index")
1694     .Input("end: Index")
1695     .Input("strides: Index")
1696     .Output("output: T")
1697     .Attr("T: type")
1698     .Attr("Index: {int32, int64}")
1699     .Attr("begin_mask: int = 0")
1700     .Attr("end_mask: int = 0")
1701     .Attr("ellipsis_mask: int = 0")
1702     .Attr("new_axis_mask: int = 0")
1703     .Attr("shrink_axis_mask: int = 0")
1704     .SetShapeFn([](InferenceContext* c) {
1705       ShapeHandle input = c->input(0);
1706       ShapeHandle begin_shape, end_shape, strides_shape;
1707       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1708       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
1709       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
1710       TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
1711       TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
1712       DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
1713 
1714       const Tensor* strides_value = c->input_tensor(3);
1715       // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
1716       // more shape inference here (e.g. for x[3, ::T]).
1717       if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
1718           strides_value == nullptr) {
1719         c->set_output(0, c->UnknownShape());
1720         return Status::OK();
1721       }
1722 
1723       PartialTensorShape input_shape({});
1724       for (int i = 0; i < c->Rank(input); ++i) {
1725         auto dim = c->Dim(input, i);
1726         input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
1727       }
1728 
1729       int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask,
1730           shrink_axis_mask;
1731       TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
1732       TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
1733       TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
1734       TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
1735       TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
1736 
1737       const Tensor* begin_value = c->input_tensor(1);
1738       const Tensor* end_value = c->input_tensor(2);
1739 
1740       PartialTensorShape processing_shape, final_shape;
1741       bool is_identity, is_simple_slice, slice_dim0;
1742       gtl::InlinedVector<int64, 4> begin, end, strides;
1743       TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
1744           begin_value, end_value, *strides_value, input_shape, begin_mask,
1745           end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
1746           &processing_shape, &final_shape, &is_identity, &is_simple_slice,
1747           &slice_dim0, &begin, &end, &strides));
1748 
1749       ShapeHandle out;
1750       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
1751       c->set_output(0, out);
1752 
1753       auto* shape_and_type = c->input_handle_shapes_and_types(0);
1754       if (shape_and_type) {
1755         c->set_output_handle_shapes_and_types(0, *shape_and_type);
1756       }
1757 
1758       return Status::OK();
1759     });
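// For illustration: when begin, end and strides are constants, the output
// shape is fully inferred; e.g. an input of shape [10, 10] with begin [2, 0],
// end [8, 10], strides [2, 1] and no masks yields an output of shape [3, 10].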
1760 
1761 REGISTER_OP("StridedSliceGrad")
1762     .Input("shape: Index")
1763     .Input("begin: Index")
1764     .Input("end: Index")
1765     .Input("strides: Index")
1766     .Input("dy: T")
1767     .Output("output: T")
1768     .Attr("T: type")
1769     .Attr("Index: {int32, int64}")
1770     .Attr("begin_mask: int = 0")
1771     .Attr("end_mask: int = 0")
1772     .Attr("ellipsis_mask: int = 0")
1773     .Attr("new_axis_mask: int = 0")
1774     .Attr("shrink_axis_mask: int = 0")
1775     .SetShapeFn([](InferenceContext* c) {
1776       ShapeHandle out;
1777       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1778       c->set_output(0, out);
1779       return Status::OK();
1780     });
1781 
1782 REGISTER_OP("StridedSliceAssign")
1783     .Input("ref: Ref(T)")
1784     .Input("begin: Index")
1785     .Input("end: Index")
1786     .Input("strides: Index")
1787     .Input("value: T")
1788     .Output("output_ref: Ref(T)")
1789     .Attr("T: type")
1790     .Attr("Index: {int32, int64}")
1791     .Attr("begin_mask: int = 0")
1792     .Attr("end_mask: int = 0")
1793     .Attr("ellipsis_mask: int = 0")
1794     .Attr("new_axis_mask: int = 0")
1795     .Attr("shrink_axis_mask: int = 0")
1796     .SetShapeFn(shape_inference::UnchangedShape);
1797 // TODO(aselle): Fix this documentation once StridedSliceAssign supports
1798 // broadcasting.
1799 // --------------------------------------------------------------------------
1800 
1801 REGISTER_OP("ResourceStridedSliceAssign")
1802     .Input("ref: resource")
1803     .Input("begin: Index")
1804     .Input("end: Index")
1805     .Input("strides: Index")
1806     .Input("value: T")
1807     .Attr("T: type")
1808     .Attr("Index: {int32, int64}")
1809     .Attr("begin_mask: int = 0")
1810     .Attr("end_mask: int = 0")
1811     .Attr("ellipsis_mask: int = 0")
1812     .Attr("new_axis_mask: int = 0")
1813     .Attr("shrink_axis_mask: int = 0")
1814     .SetShapeFn(shape_inference::NoOutputs);
1815 
1816 REGISTER_OP("TensorStridedSliceUpdate")
1817     .Input("input: T")
1818     .Input("begin: Index")
1819     .Input("end: Index")
1820     .Input("strides: Index")
1821     .Input("value: T")
1822     .Output("output: T")
1823     .Attr("T: type")
1824     .Attr("Index: {int32, int64}")
1825     .Attr("begin_mask: int = 0")
1826     .Attr("end_mask: int = 0")
1827     .Attr("ellipsis_mask: int = 0")
1828     .Attr("new_axis_mask: int = 0")
1829     .Attr("shrink_axis_mask: int = 0")
1830     .SetShapeFn(shape_inference::UnchangedShape);
1831 
1832 REGISTER_OP("Tile")
1833     .Input("input: T")
1834     .Input("multiples: Tmultiples")
1835     .Output("output: T")
1836     .Attr("T: type")
1837     .Attr("Tmultiples: {int32, int64} = DT_INT32")
1838     .SetShapeFn([](InferenceContext* c) {
1839       ShapeHandle input = c->input(0);
1840       // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
1841       // it is a vector of non-negative integers, and (ii) doing so allows
1842       // us to handle partially-known multiples.
1843       ShapeHandle multiples;
1844       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
1845       if (c->RankKnown(input)) {
1846         TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
1847         ShapeHandle dummy;
1848         TF_RETURN_IF_ERROR(
1849             c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
1850       }
1851 
1852       if (!c->RankKnown(multiples)) {
1853         return shape_inference::UnknownShape(c);
1854       }
1855 
1856       int32 rank = c->Rank(multiples);
1857       TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
1858       std::vector<DimensionHandle> dims(rank);
1859       for (int i = 0; i < rank; ++i) {
1860         TF_RETURN_IF_ERROR(
1861             c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
1862       }
1863       c->set_output(0, c->MakeShape(dims));
1864       return Status::OK();
1865     });
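// For illustration: an input of shape [2, 3] tiled with constant multiples
// [3, 2] yields an output shape of [6, 6]; if multiples is not a constant,
// only the output rank can be inferred.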
1866 
1867 // --------------------------------------------------------------------------
1868 REGISTER_OP("TileGrad")
1869     .Input("input: T")
1870     .Input("multiples: int32")
1871     .Output("output: T")
1872     .Attr("T: type")
1873     .Deprecated(3, "TileGrad has been replaced with reduce_sum")
1874     .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1875 
1876 // --------------------------------------------------------------------------
1877 REGISTER_OP("Where")
1878     .Input("input: T")
1879     .Attr("T: {numbertype, bool} = DT_BOOL")
1880     .Output("index: int64")
1881     .SetShapeFn([](InferenceContext* c) {
1882       c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
1883       return Status::OK();
1884     });
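// For illustration: Where emits one row of coordinates per true (or nonzero)
// element, so a rank-3 input produces an index of shape [?, 3]; the number of
// rows is only known at run time.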
1885 
1886 // --------------------------------------------------------------------------
1887 REGISTER_OP("BroadcastArgs")
1888     .Input("s0: T")
1889     .Input("s1: T")
1890     .Output("r0: T")
1891     .Attr("T: {int32, int64} = DT_INT32")
1892     .SetShapeFn([](InferenceContext* c) {
1893       ShapeHandle unused;
1894       ShapeHandle shape_x = c->input(0);
1895       ShapeHandle shape_y = c->input(1);
1896       TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
1897       TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
1898 
1899       if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
1900           !c->ValueKnown(c->Dim(shape_y, 0))) {
1901         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1902         return Status::OK();
1903       }
1904 
1905       int64 x_dim = c->Value(c->Dim(shape_x, 0));
1906       int64 y_dim = c->Value(c->Dim(shape_y, 0));
1907 
1908       // Broadcasted shape is going to be as large as the largest dimension.
1909       c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
1910       return Status::OK();
1911     });
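// For illustration: BroadcastArgs operates on shape vectors; if s0 has shape
// [2] and s1 has shape [4], the broadcast shape vector r0 has shape [4], the
// length of the longer input.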
1912 
1913 // --------------------------------------------------------------------------
1914 REGISTER_OP("BroadcastGradientArgs")
1915     .Input("s0: T")
1916     .Input("s1: T")
1917     .Output("r0: T")
1918     .Output("r1: T")
1919     .Attr("T: {int32, int64} = DT_INT32")
1920     .SetShapeFn([](InferenceContext* c) {
1921       // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
1922       ShapeHandle unused;
1923       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1924       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1925       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1926       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1927       return Status::OK();
1928     });
1929 
1930 // --------------------------------------------------------------------------
1931 REGISTER_OP("Pad")
1932     .Input("input: T")
1933     .Input("paddings: Tpaddings")
1934     .Output("output: T")
1935     .Attr("T: type")
1936     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1937     .SetShapeFn(PadShapeFn);
1938 
1939 // --------------------------------------------------------------------------
1940 REGISTER_OP("PadV2")
1941     .Input("input: T")
1942     .Input("paddings: Tpaddings")
1943     .Input("constant_values: T")
1944     .Output("output: T")
1945     .Attr("T: type")
1946     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1947     .SetShapeFn(PadShapeFn);
1948 
1949 // --------------------------------------------------------------------------
1950 REGISTER_OP("MirrorPad")
1951     .Input("input: T")
1952     .Input("paddings: Tpaddings")
1953     .Output("output: T")
1954     .Attr("T: type")
1955     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1956     .Attr(GetMirrorPadModeAttrString())
1957     .SetShapeFn(PadShapeFn);
1958 
1959 // --------------------------------------------------------------------------
1960 namespace {
1961 template <typename T>
1962 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
1963                       const Tensor* paddings_t, int64 input_rank) {
1964   auto paddings_data = paddings_t->matrix<T>();
1965   std::vector<DimensionHandle> dims(input_rank);
1966   for (int64 i = 0; i < input_rank; ++i) {
1967     const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
1968     const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
1969     if (pad0 < 0 || pad1 < 0) {
1970       return errors::InvalidArgument("Paddings must be non-negative");
1971     }
1972 
1973     TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
1974   }
1975   c->set_output(0, c->MakeShape(dims));
1976   return Status::OK();
1977 }
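// For illustration: MirrorPadGrad subtracts the paddings, so an input of shape
// [10] with paddings [[2, 3]] produces an output of shape [5].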
1978 
1979 }  // namespace
1980 
1981 REGISTER_OP("MirrorPadGrad")
1982     .Input("input: T")
1983     .Input("paddings: Tpaddings")
1984     .Output("output: T")
1985     .Attr("T: type")
1986     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1987     .Attr(GetMirrorPadModeAttrString())
1988     .SetShapeFn([](InferenceContext* c) {
1989       ShapeHandle paddings;
1990       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
1991       DimensionHandle pad_0 = c->Dim(paddings, 0);
1992       if (!c->ValueKnown(pad_0)) {
1993         // We don't know the rank of the output since the first
1994         // padding dimension is unknown.
1995         c->set_output(0, c->UnknownShape());
1996         return Status::OK();
1997       }
1998 
1999       int64 input_rank = c->Value(pad_0);
2000       ShapeHandle input;
2001       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
2002       TF_RETURN_IF_ERROR(
2003           c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
2004 
2005       const Tensor* paddings_t = c->input_tensor(1);
2006       if (paddings_t == nullptr) {
2007       // The values of 'paddings' are not available, but we know the
2008         // input rank, so return the rank of the output with unknown
2009         // dimensions.
2010         c->set_output(0, c->UnknownShapeOfRank(input_rank));
2011         return Status::OK();
2012       }
2013 
2014       if (paddings_t->dtype() == DT_INT32) {
2015         return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
2016       } else {
2017         return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
2018       }
2019     });
2020 
2021 // --------------------------------------------------------------------------
2022 REGISTER_OP("Placeholder")
2023     .Output("output: dtype")
2024     .Attr("dtype: type")
2025     .Attr("shape: shape = { unknown_rank: true }")
2026     .SetShapeFn([](InferenceContext* c) {
2027       PartialTensorShape shape;
2028       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2029 
2030       // Placeholder has legacy behavior where we cannot tell the difference
2031       // between a scalar shape attribute and 'unknown shape'.  So if the shape
2032       // is a scalar, we return an unknown shape.
2033       if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
2034         return shape_inference::UnknownShape(c);
2035       }
2036 
2037       ShapeHandle out;
2038       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2039       c->set_output(0, out);
2040       return Status::OK();
2041     });
2042 
2043 // Placeholder was modified in a backwards compatible way to do what
2044 // PlaceholderV2 did, so we have deprecated V2 (no one was really
2045 // using it).
2046 REGISTER_OP("PlaceholderV2")
2047     .Output("output: dtype")
2048     .Attr("dtype: type")
2049     .Attr("shape: shape")
2050     .SetShapeFn(shape_inference::ExplicitShape)
2051     .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
2052 
2053 // --------------------------------------------------------------------------
2054 REGISTER_OP("PlaceholderWithDefault")
2055     .Input("input: dtype")
2056     .Output("output: dtype")
2057     .Attr("dtype: type")
2058     .Attr("shape: shape")
2059     .SetShapeFn([](InferenceContext* c) {
2060       ShapeHandle input = c->input(0);
2061       PartialTensorShape shape;
2062       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2063       ShapeHandle out;
2064       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2065 
2066       // We merge for compatibility checking, but return the output,
2067       // since output_shape may be less precise than input_shape.
2068       ShapeHandle unused;
2069       TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
2070       c->set_output(0, out);
2071       return Status::OK();
2072     });
2073 
2074 // --------------------------------------------------------------------------
2075 REGISTER_OP("ExpandDims")
2076     .Input("input: T")
2077     .Input("dim: Tdim")
2078     .Output("output: T")
2079     .Attr("T: type")
2080     .Attr("Tdim: {int32, int64} = DT_INT32")
2081     .SetShapeFn([](InferenceContext* c) {
2082       ShapeHandle input = c->input(0);
2083 
2084       const Tensor* dim_t = c->input_tensor(1);
2085       if (dim_t != nullptr && dim_t->NumElements() != 1) {
2086         return errors::InvalidArgument(
2087             "'dim' input must be a tensor with a single value");
2088       }
2089       if (dim_t == nullptr || !c->RankKnown(input)) {
2090         c->set_output(0, c->UnknownShape());
2091         return Status::OK();
2092       }
2093 
2094       int64 dim;
2095       if (dim_t->dtype() == DT_INT32) {
2096         dim = static_cast<int64>(dim_t->flat<int32>()(0));
2097       } else {
2098         dim = dim_t->flat<int64>()(0);
2099       }
2100 
2101       const int32 rank = c->Rank(input);
2102       const int32 min_dim = -1 * rank - 1;
2103       if (dim < min_dim || dim > rank) {
2104         return errors::InvalidArgument("dim ", dim, " not in the interval [",
2105                                        min_dim, ", ", rank, "].");
2106       }
2107 
2108       if (dim < 0) {
2109         dim += rank + 1;
2110       }
2111 
2112       ShapeHandle end;
2113       TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
2114 
2115       // Build output as start + 1 + end.
2116       ShapeHandle output;
2117       TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
2118       TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
2119       TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
2120       c->set_output(0, output);
2121       return Status::OK();
2122     });
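// For illustration: for an input of shape [2, 3], a constant dim of 0 gives an
// output shape of [1, 2, 3], while dim = -1 (wrapped to 2) gives [2, 3, 1].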
2123 
2124 // --------------------------------------------------------------------------
2125 REGISTER_OP("Squeeze")
2126     .Input("input: T")
2127     .Output("output: T")
2128     .Attr("T: type")
2129     .Attr("squeeze_dims: list(int) >= 0 = []")
2130     .SetShapeFn([](InferenceContext* c) {
2131       ShapeHandle input = c->input(0);
2132       if (!c->RankKnown(input)) {
2133         // Input shape unknown.
2134         return shape_inference::UnknownShape(c);
2135       }
2136 
2137       const int32 input_rank = c->Rank(input);
2138 
2139       // Validate and wrap squeeze dimensions.
2140       std::vector<int32> squeeze_dims;
2141       TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
2142       for (int i = 0; i < squeeze_dims.size(); ++i) {
2143         if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
2144           return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
2145                                          -input_rank, ",", input_rank, ").");
2146         }
2147 
2148         if (squeeze_dims[i] < 0) {
2149           squeeze_dims[i] += input_rank;
2150         }
2151       }
2152 
2153       std::vector<DimensionHandle> result_shape;
2154       for (int i = 0; i < input_rank; ++i) {
2155         // True if squeeze_dims contains an entry to squeeze this
2156         // dimension.
2157         bool is_explicit_match =
2158             std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
2159             squeeze_dims.end();
2160 
2161         DimensionHandle dim = c->Dim(input, i);
2162 
2163         if (!c->ValueKnown(dim)) {
2164           // Assume that the squeezed dimension will be 1 at runtime.
2165           if (is_explicit_match) continue;
2166 
2167           // If we are squeezing all size-1 dimensions and we see an
2168           // unknown value, give up and return an unknown shape.
2169           if (squeeze_dims.empty()) {
2170             c->set_output(0, c->UnknownShape());
2171             return Status::OK();
2172           }
2173         } else if (c->Value(dim) == 1) {
2174           if (is_explicit_match || squeeze_dims.empty()) {
2175             // If explicitly squeezing, or squeezing all 1s, remove
2176             // this dimension.
2177             continue;
2178           }
2179         } else if (is_explicit_match) {
2180           return errors::InvalidArgument("Can not squeeze dim[", i,
2181                                          "], expected a dimension of 1, got ",
2182                                          c->Value(c->Dim(input, i)));
2183         }
2184 
2185         result_shape.emplace_back(dim);
2186       }
2187 
2188       c->set_output(0, c->MakeShape(result_shape));
2189       return Status::OK();
2190     });
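// For illustration: an input of shape [1, 2, 1, 3] with empty squeeze_dims
// drops every size-1 dimension and yields [2, 3], while squeeze_dims = [0]
// removes only the first dimension and yields [2, 1, 3].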
2191 
2192 // --------------------------------------------------------------------------
2193 REGISTER_OP("ListDiff")
2194     .Input("x: T")
2195     .Input("y: T")
2196     .Output("out: T")
2197     .Output("idx: out_idx")
2198     .Attr("T: type")
2199     .Attr("out_idx: {int32, int64} = DT_INT32")
2200     .SetShapeFn([](InferenceContext* c) {
2201       ShapeHandle unused;
2202       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
2203       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2204       // TODO(mrry): Indicate that the length falls within an interval?
2205       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
2206       c->set_output(0, out);
2207       c->set_output(1, out);
2208       return Status::OK();
2209     });
2210 
2211 namespace {
2212 
2213 // Converts Tensor to flat std::vector<int64>.
2214 template <typename InputType>
2215 std::vector<int64> GetFlatInt64(const Tensor& t) {
2216   std::vector<int64> output(t.shape().num_elements());
2217   if (t.shape().num_elements() > 0) {
2218     auto eigen_vec = t.flat<InputType>();
2219     std::copy_n(&eigen_vec(0), output.size(), output.begin());
2220   }
2221   return output;
2222 }
2223 
2224 // Converts int32 or int64 Tensor to flat std::vector<int64>.
2225 std::vector<int64> GetFlatInt64(const Tensor& t) {
2226   if (t.dtype() == DT_INT32) {
2227     return GetFlatInt64<int32>(t);
2228   } else {
2229     return GetFlatInt64<int64>(t);
2230   }
2231 }
2232 
2233 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2234                                ShapeHandle block_shape_shape,
2235                                const Tensor* block_shape_t,
2236                                ShapeHandle paddings_shape,
2237                                const Tensor* paddings_t) {
2238   if (c->Rank(block_shape_shape) != 1) {
2239     return errors::InvalidArgument("block_shape must have rank 1.");
2240   }
2241 
2242   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2243   if (!c->ValueKnown(num_block_dims_handle)) {
2244     return errors::InvalidArgument("block_shape must have known size.");
2245   }
2246 
2247   const int64 num_block_dims = c->Value(num_block_dims_handle);
2248 
2249   TF_RETURN_IF_ERROR(
2250       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2251 
2252   TF_RETURN_IF_ERROR(
2253       c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
2254 
2255   DimensionHandle batch_size = c->Dim(input_shape, 0);
2256   std::vector<int64> block_shape_vec;
2257   if (block_shape_t && (block_shape_t->NumElements() > 0)) {
2258     block_shape_vec = GetFlatInt64(*block_shape_t);
2259     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2260       const int64 block_shape_value = block_shape_vec[dim];
2261       if (block_shape_value < 1) {
2262         return errors::InvalidArgument("block_shape must be positive");
2263       }
2264       if (c->ValueKnown(batch_size)) {
2265         TF_RETURN_IF_ERROR(
2266             c->Multiply(batch_size, block_shape_value, &batch_size));
2267       } else {
2268         batch_size = c->UnknownDim();
2269       }
2270     }
2271   } else if (num_block_dims > 0) {
2272     batch_size = c->UnknownDim();
2273   }
2274 
2275   std::vector<DimensionHandle> output_dims{batch_size};
2276   output_dims.resize(num_block_dims + 1, c->UnknownDim());
2277 
2278   if (paddings_t && (paddings_t->NumElements() > 0)) {
2279     const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t);
2280     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2281       const int64 pad_start = paddings_vec[dim * 2],
2282                   pad_end = paddings_vec[dim * 2 + 1];
2283       if (pad_start < 0 || pad_end < 0) {
2284         return errors::InvalidArgument("paddings cannot be negative");
2285       }
2286       if (block_shape_t) {
2287         DimensionHandle padded_size;
2288         TF_RETURN_IF_ERROR(
2289             c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
2290         TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
2291         TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
2292                                      /*evenly_divisible=*/true,
2293                                      &output_dims[dim + 1]));
2294       }
2295     }
2296   }
2297 
2298   ShapeHandle remaining_input_shape;
2299   TF_RETURN_IF_ERROR(
2300       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2301 
2302   ShapeHandle result;
2303   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2304                                     remaining_input_shape, &result));
2305   c->set_output(0, result);
2306   return Status::OK();
2307 }
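// For illustration: an input of shape [1, 4, 4, 1] with block_shape [2, 2] and
// zero paddings maps to an output of shape [4, 2, 2, 1]: the batch dimension
// is multiplied by the block elements and each block dimension is divided.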
2308 
2309 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2310                                ShapeHandle block_shape_shape,
2311                                const Tensor* block_shape_t,
2312                                ShapeHandle crops_shape, const Tensor* crops_t) {
2313   if (c->Rank(block_shape_shape) != 1) {
2314     return errors::InvalidArgument("block_shape must have rank 1.");
2315   }
2316 
2317   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2318   if (!c->ValueKnown(num_block_dims_handle)) {
2319     return errors::InvalidArgument("block_shape must have known size.");
2320   }
2321 
2322   const int64 num_block_dims = c->Value(num_block_dims_handle);
2323 
2324   TF_RETURN_IF_ERROR(
2325       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2326 
2327   TF_RETURN_IF_ERROR(
2328       c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
2329 
2330   DimensionHandle batch_size = c->Dim(input_shape, 0);
2331   std::vector<int64> block_shape_vec;
2332   if (block_shape_t) {
2333     block_shape_vec = GetFlatInt64(*block_shape_t);
2334     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2335       const int64 block_shape_value = block_shape_vec[dim];
2336       if (block_shape_value < 1) {
2337         return errors::InvalidArgument("block_shape must be positive");
2338       }
2339       if (c->ValueKnown(batch_size)) {
2340         TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
2341                                      /*evenly_divisible=*/true, &batch_size));
2342       } else {
2343         batch_size = c->UnknownDim();
2344       }
2345     }
2346   } else if (num_block_dims > 0) {
2347     batch_size = c->UnknownDim();
2348   }
2349 
2350   std::vector<DimensionHandle> output_dims{batch_size};
2351   output_dims.resize(num_block_dims + 1, c->UnknownDim());
2352 
2353   if (crops_t) {
2354     const std::vector<int64> crops_vec = GetFlatInt64(*crops_t);
2355     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2356       const int64 crop_start = crops_vec[dim * 2],
2357                   crop_end = crops_vec[dim * 2 + 1];
2358       if (crop_start < 0 || crop_end < 0) {
2359         return errors::InvalidArgument("crops cannot be negative");
2360       }
2361       if (block_shape_t) {
2362         DimensionHandle cropped_size;
2363         TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
2364                                        block_shape_vec[dim], &cropped_size));
2365         TF_RETURN_IF_ERROR(
2366             c->Subtract(cropped_size, crop_start, &cropped_size));
2367         TF_RETURN_IF_ERROR(
2368             c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
2369       }
2370     }
2371   }
2372 
2373   ShapeHandle remaining_input_shape;
2374   TF_RETURN_IF_ERROR(
2375       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2376 
2377   ShapeHandle result;
2378   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2379                                     remaining_input_shape, &result));
2380   c->set_output(0, result);
2381   return Status::OK();
2382 }
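// For illustration: this is the inverse mapping; an input of shape [4, 2, 2, 1]
// with block_shape [2, 2] and zero crops maps back to [1, 4, 4, 1].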
2383 
2384 }  // namespace
2385 
2386 // --------------------------------------------------------------------------
2387 REGISTER_OP("SpaceToBatchND")
2388     .Input("input: T")
2389     .Input("block_shape: Tblock_shape")
2390     .Input("paddings: Tpaddings")
2391     .Output("output: T")
2392     .Attr("T: type")
2393     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2394     .Attr("Tpaddings: {int32, int64} = DT_INT32")
2395     .SetShapeFn([](InferenceContext* c) {
2396       return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
2397                                      c->input_tensor(1), c->input(2),
2398                                      c->input_tensor(2));
2399     });
2400 
2401 // --------------------------------------------------------------------------
2402 REGISTER_OP("SpaceToBatch")
2403     .Input("input: T")
2404     .Input("paddings: Tpaddings")
2405     .Output("output: T")
2406     .Attr("T: type")
2407     .Attr("Tpaddings: {int32, int64} = DT_INT32")
2408     .Attr("block_size: int >= 2")
2409     .SetShapeFn([](InferenceContext* c) {
2410       ShapeHandle input_shape;
2411       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2412 
2413       int32 block_size;
2414       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2415 
2416       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2417       auto block_shape_vec = block_shape.vec<int64>();
2418       block_shape_vec(0) = block_size;
2419       block_shape_vec(1) = block_size;
2420 
2421       return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
2422                                      &block_shape, c->input(1),
2423                                      c->input_tensor(1));
2424     });
2425 
2426 // --------------------------------------------------------------------------
2427 REGISTER_OP("BatchToSpaceND")
2428     .Input("input: T")
2429     .Input("block_shape: Tblock_shape")
2430     .Input("crops: Tcrops")
2431     .Output("output: T")
2432     .Attr("T: type")
2433     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2434     .Attr("Tcrops: {int32, int64} = DT_INT32")
2435     .SetShapeFn([](InferenceContext* c) {
2436       return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
2437                                      c->input_tensor(1), c->input(2),
2438                                      c->input_tensor(2));
2439     });
2440 
2441 // --------------------------------------------------------------------------
2442 REGISTER_OP("BatchToSpace")
2443     .Input("input: T")
2444     .Input("crops: Tidx")
2445     .Output("output: T")
2446     .Attr("T: type")
2447     .Attr("block_size: int >= 2")
2448     .Attr("Tidx: {int32, int64} = DT_INT32")
2449     .SetShapeFn([](InferenceContext* c) {
2450       ShapeHandle input_shape;
2451       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2452 
2453       int32 block_size;
2454       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2455 
2456       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2457       auto block_shape_vec = block_shape.vec<int64>();
2458       block_shape_vec(0) = block_size;
2459       block_shape_vec(1) = block_size;
2460 
2461       return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
2462                                      &block_shape, c->input(1),
2463                                      c->input_tensor(1));
2464     });
2465 
2466 // --------------------------------------------------------------------------
2467 REGISTER_OP("SpaceToDepth")
2468     .Input("input: T")
2469     .Output("output: T")
2470     .Attr("T: type")
2471     .Attr("block_size: int >= 2")
2472     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2473     // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
2474     .SetShapeFn([](InferenceContext* c) {
2475       string data_format_str;
2476       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2477       TensorFormat data_format;
2478       FormatFromString(data_format_str, &data_format);
2479 
2480       constexpr int num_spatial_dims = 2;
2481       const int dims =
2482           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2483       ShapeHandle input;
2484       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2485 
2486       int32 block_size;
2487       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2488 
2489       DimensionHandle batch_size =
2490           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2491       DimensionHandle input_height =
2492           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2493       DimensionHandle input_width =
2494           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2495       DimensionHandle input_depth =
2496           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2497 
2498       DimensionHandle output_height;
2499       DimensionHandle output_width;
2500       DimensionHandle output_depth;
2501       // Will return an error if input height or width are not evenly divisible.
2502       TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
2503                                    true /* evenly_divisible */,
2504                                    &output_height));
2505       TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
2506                                    true /* evenly_divisible */, &output_width));
2507 
2508       TF_RETURN_IF_ERROR(
2509           c->Multiply(input_depth, block_size * block_size, &output_depth));
2510 
2511       ShapeHandle output_shape;
2512       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2513                                              {output_height, output_width},
2514                                              output_depth, &output_shape, c));
2515 
2516       c->set_output(0, output_shape);
2517       return Status::OK();
2518     });
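// For illustration: an NHWC input of shape [1, 4, 4, 3] with block_size 2
// produces an output of shape [1, 2, 2, 12]: height and width are divided by
// the block size and depth is multiplied by block_size * block_size.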
2519 
2520 // --------------------------------------------------------------------------
2521 REGISTER_OP("DepthToSpace")
2522     .Input("input: T")
2523     .Output("output: T")
2524     .Attr("T: type")
2525     .Attr("block_size: int >= 2")
2526     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2527     // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
2528     .SetShapeFn([](InferenceContext* c) {
2529       string data_format_str;
2530       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2531       TensorFormat data_format;
2532       FormatFromString(data_format_str, &data_format);
2533 
2534       constexpr int num_spatial_dims = 2;
2535       const int dims =
2536           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2537 
2538       ShapeHandle input;
2539       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2540 
2541       int32 block_size;
2542       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2543 
2544       DimensionHandle batch_size =
2545           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2546       DimensionHandle input_height =
2547           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2548       DimensionHandle input_width =
2549           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2550       DimensionHandle input_depth =
2551           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2552 
2553       DimensionHandle output_height;
2554       DimensionHandle output_width;
2555       DimensionHandle output_depth;
2556       TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
2557       TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
2558 
2559       // Will return an error if input_depth is not evenly divisible.
2560       TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
2561                                    true /* evenly_divisible */, &output_depth));
2562 
2563       ShapeHandle output_shape;
2564       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2565                                              {output_height, output_width},
2566                                              output_depth, &output_shape, c));
2567 
2568       c->set_output(0, output_shape);
2569       return Status::OK();
2570     });
2571 
2572 // --------------------------------------------------------------------------
2573 
2574 REGISTER_OP("ExtractImagePatches")
2575     .Input("images: T")
2576     .Output("patches: T")
2577     .Attr("ksizes: list(int) >= 4")
2578     .Attr("strides: list(int) >= 4")
2579     .Attr("rates: list(int) >= 4")
2580     .Attr(
2581         "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
2582         "uint8, uint16, uint32, uint64, complex64, complex128, bool}")
2583     .Attr(GetPaddingAttrString())
2584     .SetShapeFn([](InferenceContext* c) {
2585       ShapeHandle input_shape;
2586       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2587 
2588       std::vector<int32> ksizes;
2589       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2590       if (ksizes.size() != 4) {
2591         return errors::InvalidArgument(
2592             "ExtractImagePatches requires the ksizes attribute to contain 4 "
2593             "values, but got: ",
2594             ksizes.size());
2595       }
2596 
2597       std::vector<int32> strides;
2598       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2599       if (strides.size() != 4) {
2600         return errors::InvalidArgument(
2601             "ExtractImagePatches requires the strides attribute to contain 4 "
2602             "values, but got: ",
2603             strides.size());
2604       }
2605 
2606       std::vector<int32> rates;
2607       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2608       if (rates.size() != 4) {
2609         return errors::InvalidArgument(
2610             "ExtractImagePatches requires the rates attribute to contain 4 "
2611             "values, but got: ",
2612             rates.size());
2613       }
2614 
2615       int32 ksize_rows = ksizes[1];
2616       int32 ksize_cols = ksizes[2];
2617 
2618       int32 stride_rows = strides[1];
2619       int32 stride_cols = strides[2];
2620 
2621       int32 rate_rows = rates[1];
2622       int32 rate_cols = rates[2];
2623 
2624       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2625       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2626 
2627       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2628       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
2629       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
2630       DimensionHandle output_depth_dim;
2631       TF_RETURN_IF_ERROR(c->Multiply(
2632           c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
2633 
2634       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
2635         ShapeHandle output_shape =
2636             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2637                           InferenceContext::kUnknownDim, output_depth_dim});
2638         c->set_output(0, output_shape);
2639         return Status::OK();
2640       }
2641       auto in_rows = c->Value(in_rows_dim);
2642       auto in_cols = c->Value(in_cols_dim);
2643 
2644       Padding padding;
2645       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2646 
2647       int64 output_rows, output_cols;
2648       int64 padding_before, padding_after;
2649       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2650           in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2651           &padding_before, &padding_after));
2652       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2653           in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2654           &padding_before, &padding_after));
2655       ShapeHandle output_shape = c->MakeShape(
2656           {batch_size_dim, output_rows, output_cols, output_depth_dim});
2657       c->set_output(0, output_shape);
2658       return Status::OK();
2659     });
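// For illustration: images of shape [1, 10, 10, 1] with ksizes [1, 3, 3, 1],
// strides [1, 5, 5, 1], rates [1, 1, 1, 1] and VALID padding produce patches
// of shape [1, 2, 2, 9], where 9 = 3 * 3 * input depth.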
2660 
2661 // --------------------------------------------------------------------------
2662 
2663 // To enable rates, uncomment all lines commented below and use ksize_*_eff
2664 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead
2665 // of ksize_*.
2666 REGISTER_OP("ExtractVolumePatches")
2667     .Input("input: T")
2668     .Output("patches: T")
2669     .Attr("ksizes: list(int) >= 5")
2670     .Attr("strides: list(int) >= 5")
2671     /* .Attr("rates: list(int) >= 5") */
2672     .Attr("T: realnumbertype")
2673     .Attr(GetPaddingAttrString())
2674     .SetShapeFn([](InferenceContext* c) {
2675       ShapeHandle input_shape;
2676       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
2677 
2678       std::vector<int32> ksizes;
2679       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2680       if (ksizes.size() != 5) {
2681         return errors::InvalidArgument(
2682             "ExtractVolumePatches requires the ksizes attribute to contain 5 "
2683             "values, but got: ",
2684             ksizes.size());
2685       }
2686 
2687       std::vector<int32> strides;
2688       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2689       if (strides.size() != 5) {
2690         return errors::InvalidArgument(
2691             "ExtractVolumePatches requires the strides attribute to contain 5 "
2692             "values, but got: ",
2693             strides.size());
2694       }
2695 
2696       /*
2697       // TODO(hsgkim): Enable rates.
2698       // See extract_volume_patches_op.cc for why rates are disabled now.
2699 
2700       std::vector<int32> rates;
2701       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2702       if (rates.size() != 5) {
2703         return errors::InvalidArgument(
2704             "ExtractVolumePatches requires the rates attribute to contain 5 "
2705             "values, but got: ",
2706             rates.size());
2707       }
2708       */
2709 
2710       int32 ksize_planes = ksizes[1];
2711       int32 ksize_rows = ksizes[2];
2712       int32 ksize_cols = ksizes[3];
2713 
2714       int32 stride_planes = strides[1];
2715       int32 stride_rows = strides[2];
2716       int32 stride_cols = strides[3];
2717 
2718       /*
2719       int32 rate_planes = rates[1];
2720       int32 rate_rows = rates[2];
2721       int32 rate_cols = rates[3];
2722 
2723       int32 ksize_planes_eff = ksize_planes +
2724                                (ksize_planes - 1) * (rate_planes - 1);
2725       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2726       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2727       */
2728 
2729       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2730       DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
2731       DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
2732       DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
2733       DimensionHandle output_depth_dim;
2734       TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
2735                                      ksize_planes * ksize_rows * ksize_cols,
2736                                      &output_depth_dim));
2737 
2738       if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
2739           !c->ValueKnown(in_cols_dim)) {
2740         // The output is 5-D, so report three unknown spatial dimensions here.
2741         ShapeHandle output_shape = c->MakeShape(
2742             {batch_size_dim, c->UnknownDim(), c->UnknownDim(), c->UnknownDim(), output_depth_dim});
2743         c->set_output(0, output_shape);
2744         return Status::OK();
2745       }
2746       auto in_planes = c->Value(in_planes_dim);
2747       auto in_rows = c->Value(in_rows_dim);
2748       auto in_cols = c->Value(in_cols_dim);
2749 
2750       Padding padding;
2751       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2752 
2753       int64 output_planes, output_rows, output_cols;
2754       int64 padding_before, padding_after;
      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
          in_planes, ksize_planes, stride_planes, padding, &output_planes,
          &padding_before, &padding_after));
      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
          in_rows, ksize_rows, stride_rows, padding, &output_rows,
          &padding_before, &padding_after));
      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
          in_cols, ksize_cols, stride_cols, padding, &output_cols,
          &padding_before, &padding_after));
      ShapeHandle output_shape =
          c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
                        output_depth_dim});
      c->set_output(0, output_shape);
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("OneHot")
    .Input("indices: TI")
    .Input("depth: int32")
    .Input("on_value: T")
    .Input("off_value: T")
    .Attr("axis: int = -1")
    .Output("output: T")
    .Attr("T: type")
    .Attr("TI: {uint8, int32, int64} = DT_INT64")
    .SetShapeFn([](InferenceContext* c) {
      int32 axis;
      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
      if (axis < -1) return errors::InvalidArgument("axis must be >= -1");

      DimensionHandle depth;
      TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));

      ShapeHandle indices = c->input(0);
      if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);

      int32 new_rank = c->Rank(indices) + 1;
      // We need to add new_rank to axis in the case the axis is -1 because
      // C++ returns negative values from % if the dividend is negative.
      int32 depth_index = (axis + new_rank) % new_rank;
      // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
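      // For example, indices of shape [batch, features] with axis = -1
      // (depth_index == 2) yield an output of shape [batch, features, depth].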
      ShapeHandle front;
      ShapeHandle back;
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
      TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
      TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
      TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
      c->set_output(0, out);
      return Status::OK();
    });

// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
REGISTER_OP("QuantizeAndDequantize")
    .Input("input: T")
    .Attr("signed_input: bool = true")
    .Attr("num_bits: int = 8")
    .Attr("range_given: bool = false")
    .Attr("input_min: float = 0")
    .Attr("input_max: float = 0")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");

// TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
REGISTER_OP("QuantizeAndDequantizeV2")
    .Input("input: T")
    .Input("input_min: T")
    .Input("input_max: T")
    .Attr("signed_input: bool = true")
    .Attr("num_bits: int = 8")
    .Attr("range_given: bool = false")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr(
        "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
        "'HALF_TO_EVEN'")
    .Attr("narrow_range: bool = false")
    .Attr("axis: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      int axis;
      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
      const int minmax_rank = (axis == -1) ? 0 : 1;
      ShapeHandle minmax;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
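      // For per-axis quantization (axis != -1) input_min/input_max are 1-D and
      // their length must match dimension `axis` of the input, which the check
      // below enforces; for axis == -1 they are scalars, enforced above.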
      if (axis != -1) {
        ShapeHandle input;
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
        DimensionHandle depth;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
      }
      c->set_output(0, c->input(0));
      return Status::OK();
    });

REGISTER_OP("QuantizeAndDequantizeV4")
    .Input("input: T")
    .Input("input_min: T")
    .Input("input_max: T")
    .Attr("signed_input: bool = true")
    .Attr("num_bits: int = 8")
    .Attr("range_given: bool = false")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr(
        "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
        "'HALF_TO_EVEN'")
    .Attr("narrow_range: bool = false")
    .Attr("axis: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      int axis;
      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
      const int minmax_rank = (axis == -1) ? 0 : 1;
      ShapeHandle minmax;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
      if (axis != -1) {
        ShapeHandle input;
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
        DimensionHandle depth;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
      }
      c->set_output(0, c->input(0));
      return Status::OK();
    });

REGISTER_OP("QuantizeAndDequantizeV4Grad")
    .Input("gradients: T")
    .Input("input: T")
    .Input("input_min: T")
    .Input("input_max: T")
    .Output("input_backprop: T")
    .Output("input_min_backprop: T")
    .Output("input_max_backprop: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("axis: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      int axis;
      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
      const int minmax_rank = (axis == -1) ? 0 : 1;
      ShapeHandle minmax;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
      TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax));
      if (axis != -1) {
        ShapeHandle input;
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
        DimensionHandle depth;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
      }
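      // gradients (input 0) and input (input 1) must have identical shapes;
      // the min/max backprops mirror the shape of the range inputs (a scalar,
      // or a [depth] vector when axis != -1).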
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
      c->set_output(0, inputs);
      c->set_output(1, minmax);
      c->set_output(2, minmax);
      return Status::OK();
    });

REGISTER_OP("QuantizeAndDequantizeV3")
    .Input("input: T")
    .Input("input_min: T")
    .Input("input_max: T")
    .Input("num_bits: int32")
    .Attr("signed_input: bool = true")
    .Attr("range_given: bool = true")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("narrow_range: bool = false")
    .Attr("axis: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      int axis;
      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
      const int minmax_rank = (axis == -1) ? 0 : 1;
      ShapeHandle minmax;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
      if (axis != -1) {
        ShapeHandle input;
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
        DimensionHandle depth;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
      }
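      // Unlike V2, num_bits is a runtime input rather than an attr, so it is
      // only checked here to be a scalar.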
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      c->set_output(0, c->input(0));
      return Status::OK();
    });

REGISTER_OP("QuantizeV2")
    .Input("input: float")
    .Input("min_range: float")
    .Input("max_range: float")
    .Output("output: T")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("T: quantizedtype")
    .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
    .Attr(
        "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
        "'HALF_AWAY_FROM_ZERO'")
    .Attr("narrow_range: bool = false")
    .Attr("axis: int = -1")
    .Attr("ensure_minimum_range: float = 0.01")
    .SetShapeFn([](InferenceContext* c) {
      int axis = -1;
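      // A NOT_FOUND error from GetAttr is tolerated so that graphs serialized
      // before the axis attr was added still pass shape inference; axis then
      // keeps its default of -1.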
      Status s = c->GetAttr("axis", &axis);
      if (!s.ok() && s.code() != error::NOT_FOUND) {
        return s;
      }
      const int minmax_rank = (axis == -1) ? 0 : 1;
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle minmax;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
      if (axis != -1) {
        ShapeHandle input;
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
        DimensionHandle depth;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
      }
      c->set_output(1, minmax);
      c->set_output(2, minmax);
      return Status::OK();
    });

REGISTER_OP("Dequantize")
    .Input("input: T")
    .Input("min_range: float")
    .Input("max_range: float")
    .Output("output: dtype")
    .Attr("T: quantizedtype")
    .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
    .Attr("narrow_range: bool = false")
    .Attr("axis: int = -1")
    .Attr("dtype: {bfloat16, float} = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
      int axis = -1;
      Status s = c->GetAttr("axis", &axis);
      if (!s.ok() && s.code() != error::NOT_FOUND) {
        return s;
      }
      const int minmax_rank = (axis == -1) ? 0 : 1;
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle minmax;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
      if (axis != -1) {
        ShapeHandle input;
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
        DimensionHandle depth;
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
      }
      return Status::OK();
    });

REGISTER_OP("QuantizedConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("input_mins: N * float32")
    .Input("input_maxes: N * float32")
    .Output("output: T")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      const int n = (c->num_inputs() - 1) / 3;
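      // Inputs are laid out as concat_dim, then N values, then N input_mins,
      // then N input_maxes, hence n = (num_inputs - 1) / 3; the loop below
      // verifies that every min/max is a scalar.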
      TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
      ShapeHandle unused;
      for (int i = n + 1; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
      }
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("QuantizedReshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output: T")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("QuantizedInstanceNorm")
    .Input("x: T")
    .Input("x_min: float")
    .Input("x_max: float")
    .Output("y: T")
    .Output("y_min: float")
    .Output("y_max: float")
    .Attr("T: quantizedtype")
    .Attr("output_range_given: bool = false")
    .Attr("given_y_min: float = 0")
    .Attr("given_y_max: float = 0")
    .Attr("variance_epsilon: float = 1e-5")
    .Attr("min_separation: float = 1e-3")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // x should be a rank 4 tensor.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
      // Assert x_min and x_max are scalars (rank 0).
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      // y has the same shape as x.
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      // y_min and y_max are scalars.
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

namespace {

Status ScatterNdTensorShape(InferenceContext* c) {
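  // Shape function shared by the TensorScatter* ops: the output always takes
  // the shape of `tensor` (input 0), and ScatterNdShapeHelper checks that
  // indices and updates are compatible with it.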
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
  ShapeHandle updates_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
  return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
                                               output_shape);
}

}  // namespace

REGISTER_OP("UpperBound")
    .Input("sorted_inputs: T")
    .Input("values: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
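      // sorted_inputs and values are both rank-2 [batch, ...] tensors; the
      // output indices have the same shape as values.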
      c->set_output(0, c->input(1));
      return Status::OK();
    });

REGISTER_OP("LowerBound")
    .Input("sorted_inputs: T")
    .Input("values: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
      c->set_output(0, c->input(1));
      return Status::OK();
    });

REGISTER_OP("ScatterNd")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Input("shape: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
      ShapeHandle updates_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
      ShapeHandle output_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
      return shape_inference::ScatterNdShapeHelper(c, indices_shape,
                                                   updates_shape, output_shape);
    });

REGISTER_OP("TensorScatterUpdate")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterAdd")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterSub")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterMin")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterMax")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("ScatterNdNonAliasingAdd")
    .Input("input: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("FakeQuantWithMinMaxArgs")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Output("outputs: float")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Output("backprops: float")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("FakeQuantWithMinMaxVars")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("outputs: float")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return Status::OK();
    });

REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("backprops_wrt_input: float")
    .Output("backprop_wrt_min: float")
    .Output("backprop_wrt_max: float")
    .SetShapeFn([](InferenceContext* c) {
      // gradients and inputs are same size.
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));

      // min and max are scalars
      ShapeHandle min_max;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));

      c->set_output(0, inputs);
      c->set_output(1, min_max);
      c->set_output(2, min_max);
      return Status::OK();
    });

REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("outputs: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input, min, max;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
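      // min and max are per-channel vectors whose length must equal the last
      // dimension of inputs; the Merges below enforce this.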

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));

      c->set_output(0, input);
      return Status::OK();
    });

REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("backprops_wrt_input: float")
    .Output("backprop_wrt_min: float")
    .Output("backprop_wrt_max: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
      TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));

      ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
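      // last_dim is the per-channel depth; min, max, and their backprops must
      // all be vectors of this length.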

      ShapeHandle min_max;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));

      c->set_output(0, inputs);
      c->set_output(1, min_max);
      c->set_output(2, min_max);
      return Status::OK();
    });

REGISTER_OP("Fingerprint")
    .Input("data: T")
    .Input("method: string")
    .Output("fingerprint: uint8")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      DimensionHandle fingerprint_size;
      const Tensor* method = c->input_tensor(1);
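      // input_tensor() returns nullptr when `method` is not a constant at
      // graph-construction time, in which case the fingerprint width is left
      // unknown.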
      if (method == nullptr) {
        fingerprint_size = c->UnknownDim();
      } else {
        if (method->dims() != 0) {
          return errors::InvalidArgument("`method` must be rank 0: ",
                                         method->shape());
        }
        const string& method_string = method->scalar<tstring>()();
        if (method_string != "farmhash64") {
          return errors::InvalidArgument("Unsupported method: ", method_string);
        }
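        // farmhash64 produces a 64-bit (sizeof(uint64) == 8 byte) fingerprint
        // for each slice along the first dimension of `data`.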
        fingerprint_size = c->MakeDim(sizeof(uint64));
      }

      DimensionHandle batch = c->Dim(c->input(0), 0);
      c->set_output(0, c->MakeShape({batch, fingerprint_size}));
      return Status::OK();
    });

#ifdef INTEL_MKL
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatShape(c, c->num_inputs() - 3);
    })
    .Doc(R"doc(
MKL version of the Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
expected to invoke these operators.
)doc");
#endif

// Deprecated op registrations:

// The following can be deleted after 10mar2017.
REGISTER_OP("BatchMatrixDiag")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixSetDiag")
    .Input("input: T")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixSetDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixDiagPart")
    .Input("input: T")
    .Output("diagonal: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiagPart")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixBandPart")
    .Input("input: T")
    .Input("num_lower: int64")
    .Input("num_upper: int64")
    .Output("band: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixBandPart")
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow