1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <limits>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/strided_slice_op.h"
#include "tensorflow/core/util/tensor_format.h"
24 
25 namespace tensorflow {
26 
27 using shape_inference::DimensionHandle;
28 using shape_inference::InferenceContext;
29 using shape_inference::ShapeHandle;
30 using shape_inference::UnchangedShape;
31 
32 namespace {
33 
GetAxisForPackAndUnpack(InferenceContext * c,int32 rank_after_pack,int32 * axis)34 Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
35                                int32* axis) {
36   TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
37   if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
38     return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
39                                    -1 * rank_after_pack, ",", rank_after_pack,
40                                    ")");
41   }
42   if (*axis < 0) *axis = (rank_after_pack + *axis);
43   return Status::OK();
44 }
45 
46 template <typename T>
AsInt64(const Tensor * tensor,int64 num_elements)47 std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) {
48   std::vector<int64> ret(num_elements);
49   auto data = tensor->vec<T>();
50   for (int64 i = 0; i < num_elements; ++i) {
51     ret[i] = data(i);
52   }
53   return ret;
54 }
55 
56 template <typename T>
PadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64 num_dims)57 Status PadKnown(InferenceContext* c, ShapeHandle input,
58                 const Tensor* paddings_t, int64 num_dims) {
59   // paddings_t is known.
60   std::vector<DimensionHandle> dims(num_dims);
61   auto paddings_data = paddings_t->matrix<T>();
62   for (int64 i = 0; i < num_dims; ++i) {
63     const T pad0 = paddings_data(i, 0);
64     const T pad1 = paddings_data(i, 1);
65     if (pad0 < 0 || pad1 < 0) {
66       return errors::InvalidArgument("Paddings must be non-negative");
67     }
68     TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
69   }
70   c->set_output(0, c->MakeShape(dims));
71   return Status::OK();
72 }
73 
PadShapeFn(InferenceContext * c)74 Status PadShapeFn(InferenceContext* c) {
75   // Paddings is a matrix of [input_rank, 2].
76   ShapeHandle paddings;
77   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
78   DimensionHandle unused;
79   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
80 
81   // n_dim and input.rank are equivalent.
82   ShapeHandle input = c->input(0);
83   DimensionHandle n_dim = c->Dim(paddings, 0);
84   if (c->ValueKnown(n_dim)) {
85     TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
86   } else if (c->RankKnown(input)) {
87     TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
88   }
89 
90   const Tensor* paddings_t = c->input_tensor(1);
91 
92   // paddings_t is unknown
93   if (paddings_t == nullptr) {
94     if (c->ValueKnown(n_dim)) {
95       // Make output with n_dim unknown dims.
96       c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
97     } else {
98       c->set_output(0, c->UnknownShape());
99     }
100     return Status::OK();
101   }
102 
103   const int64 num_dims = paddings_t->shape().dim_size(0);
104   TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
105   TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
106 
107   if (paddings_t->dtype() == DT_INT32) {
108     return PadKnown<int32>(c, input, paddings_t, num_dims);
109   } else {
110     return PadKnown<int64>(c, input, paddings_t, num_dims);
111   }
112 }
113 
TransposeShapeFn(InferenceContext * c)114 Status TransposeShapeFn(InferenceContext* c) {
115   ShapeHandle input = c->input(0);
116   ShapeHandle perm_shape = c->input(1);
117   const Tensor* perm = c->input_tensor(1);
118   DimensionHandle perm_elems = c->NumElements(perm_shape);
119   // If we don't have rank information on the input or value information on
120   // perm we can't return any shape information, otherwise we have enough
121   // information to at least find the rank of the output.
122   if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
123     c->set_output(0, c->UnknownShape());
124     return Status::OK();
125   }
126 
127   // Find our value of the rank.
128   int64 rank;
129   if (c->RankKnown(input)) {
130     rank = c->Rank(input);
131   } else if (c->ValueKnown(perm_elems)) {
132     rank = c->Value(perm_elems);
133   } else {
134     rank = perm->NumElements();
135   }
136   if (!c->RankKnown(input) && rank < 2) {
137     // A permutation array containing a single element is ambiguous. It could
138     // indicate either a scalar or a 1-dimensional array, both of which the
139     // transpose op returns unchanged.
140     c->set_output(0, input);
141     return Status::OK();
142   }
143 
144   std::vector<DimensionHandle> dims;
145   dims.resize(rank);
146   TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
147   // Ensure that perm is a vector and has rank elements.
148   TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
149   TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
150 
151   // If we know the rank of the input and the value of perm, we can return
152   // all shape informantion, otherwise we can only return rank information,
153   // but no information for the dimensions.
154   if (perm != nullptr) {
155     std::vector<int64> data;
156     if (perm->dtype() == DT_INT32) {
157       data = AsInt64<int32>(perm, rank);
158     } else {
159       data = AsInt64<int64>(perm, rank);
160     }
161 
162     for (int32 i = 0; i < rank; ++i) {
163       int64 in_idx = data[i];
164       if (in_idx >= rank) {
165         return errors::InvalidArgument("perm dim ", in_idx,
166                                        " is out of range of input rank ", rank);
167       }
168       dims[i] = c->Dim(input, in_idx);
169     }
170   } else {
171     for (int i = 0; i < rank; ++i) {
172       dims[i] = c->UnknownDim();
173     }
174   }
175 
176   c->set_output(0, c->MakeShape(dims));
177   return Status::OK();
178 }
179 
SetOutputShapeForReshape(InferenceContext * c)180 Status SetOutputShapeForReshape(InferenceContext* c) {
181   ShapeHandle in = c->input(0);
182   ShapeHandle out;
183   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
184 
185   if (!c->RankKnown(out)) {
186     // We have no information about the shape of the output.
187     c->set_output(0, out);
188     return Status::OK();
189   }
190 
191   if (c->RankKnown(out) && c->RankKnown(in)) {
192     // We don't know the number of output elements, but we can try to infer
193     // the missing dimension.
194     bool too_many_unknown = false;
195     int32 out_unknown_idx = -1;
196 
197     DimensionHandle known_out_elems = c->NumElements(out);
198     if (!c->ValueKnown(known_out_elems)) {
199       known_out_elems = c->MakeDim(1);
200       for (int32 i = 0; i < c->Rank(out); ++i) {
201         DimensionHandle dim = c->Dim(out, i);
202         if (!c->ValueKnown(dim)) {
203           if (out_unknown_idx >= 0) {
204             too_many_unknown = true;
205             break;
206           }
207           out_unknown_idx = i;
208         } else {
209           TF_RETURN_IF_ERROR(
210               c->Multiply(known_out_elems, dim, &known_out_elems));
211         }
212       }
213     }
214     int32 in_unknown_idx = -1;
215     DimensionHandle known_in_elems = c->NumElements(in);
216     if (!c->ValueKnown(known_in_elems)) {
217       known_in_elems = c->MakeDim(1);
218       for (int32 i = 0; i < c->Rank(in); ++i) {
219         DimensionHandle dim = c->Dim(in, i);
220         if (!c->ValueKnown(dim)) {
221           if (in_unknown_idx >= 0) {
222             too_many_unknown = true;
223             break;
224           }
225           in_unknown_idx = i;
226         } else {
227           TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
228         }
229       }
230     }
231 
232     if (!too_many_unknown) {
233       if (in_unknown_idx < 0 && out_unknown_idx < 0) {
234         // Just check that the dimensions match.
235         if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
236           return errors::InvalidArgument(
237               "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
238               " elements to shape ", c->DebugString(out), " (",
239               c->DebugString(known_out_elems), " elements)");
240         }
241       } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
242                  c->Value(known_out_elems) > 0) {
243         // Input fully known, infer the one missing output dim
244         DimensionHandle inferred_dim;
245         TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
246                                      true /* evenly_divisible */,
247                                      &inferred_dim));
248         TF_RETURN_IF_ERROR(
249             c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
250 
251       } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
252                  c->Value(known_in_elems) != 0) {
253         // Output fully known, infer the one missing input dim
254         DimensionHandle inferred_dim;
255         TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
256                                      true /* evenly_divisible */,
257                                      &inferred_dim));
258         DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
259         TF_RETURN_IF_ERROR(
260             c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
261       } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
262         // Exactly one unknown dimension in both input and output. These 2 are
263         // equal iff the known elements are equal.
264         if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
265           DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
266           TF_RETURN_IF_ERROR(
267               c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
268         }
269       }
270     }
271   }
272   c->set_output(0, out);
273   return Status::OK();
274 }
275 
276 }  // namespace
277 
278 REGISTER_OP("ParallelConcat")
279     .Input("values: N * T")
280     .Output("output: T")
281     .Attr("N: int >= 1")
282     .Attr("T: type")
283     .Attr("shape: shape")
__anondb9326b20202(InferenceContext* c) 284     .SetShapeFn([](InferenceContext* c) {
285       // Validate that the shape attr is correct.
286       PartialTensorShape shape;
287       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
288       ShapeHandle passed_shape;
289       TF_RETURN_IF_ERROR(
290           c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
291       if (!c->FullyDefined(passed_shape)) {
292         return errors::InvalidArgument("shape attr must be fully defined.");
293       }
294       ShapeHandle cur;
295       TF_RETURN_IF_ERROR(c->ReplaceDim(
296           passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
297           &cur));
298       for (int i = 0; i < c->num_inputs(); ++i) {
299         if (!c->FullyDefined(c->input(i))) {
300           return errors::InvalidArgument(
301               "All input shapes must be fully defined.");
302         }
303         DimensionHandle unused;
304         if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
305           return errors::InvalidArgument("Size of first dimension must be 1.");
306         }
307         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
308                                         "From merging shape ", i,
309                                         " with other shapes.");
310       }
311 
312       c->set_output(0, passed_shape);
313 
314       return Status::OK();
315     });
316 
317 REGISTER_OP("Pack")
318     .Input("values: N * T")
319     .Output("output: T")
320     .Attr("N: int >= 1")
321     .Attr("T: type")
322     .Attr("axis: int = 0")
__anondb9326b20302(InferenceContext* c) 323     .SetShapeFn([](InferenceContext* c) {
324       // Validate shapes of all inputs are compatible
325       ShapeHandle cur = c->input(c->num_inputs() - 1);
326       for (int i = c->num_inputs() - 2; i >= 0; --i) {
327         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
328                                         "From merging shape ", i,
329                                         " with other shapes.");
330       }
331       if (!c->RankKnown(cur)) {
332         c->set_output(0, c->UnknownShape());
333         return Status::OK();
334       }
335       // Determine the axis that will be added, converting from negative
336       // axes to a positive point per negative indexing rules.
337       int32 rank = c->Rank(cur);
338       int32 axis;
339       TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
340 
341       // Copy all dimensions over, inserting a dimension of value #inputs
342       // at <axis>.
343       std::vector<DimensionHandle> dims;
344       int index = 0;
345       while (index < axis) dims.push_back(c->Dim(cur, index++));
346       dims.push_back(c->MakeDim(c->num_inputs()));
347       while (index < rank) dims.push_back(c->Dim(cur, index++));
348 
349       c->set_output(0, c->MakeShape(dims));
350       for (int i = 0; i < c->num_inputs(); ++i) {
351         auto* shape_and_type = c->input_handle_shapes_and_types(i);
352         if (shape_and_type) {
353           if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
354             c->set_output_handle_shapes_and_types(
355                 0, std::vector<shape_inference::ShapeAndType>({}));
356             break;
357           }
358         }
359       }
360       return Status::OK();
361     });
362 
// DeepCopy registration: one input `x` and one output `y`, both of type T.
// The op is marked stateful. Its shape function is UnchangedShape, so `y` is
// reported with the shape of `x` (input 0).
REGISTER_OP("DeepCopy")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(UnchangedShape);
369 
// InplaceUpdate registration.
// - Inputs: `x` (T), int32 `i`, and `v` (T).
// - Output: `y` (T).
// UnchangedShape reports `y` with the shape of `x` (input 0). No shape
// relationship between `i`/`v` and `x` is checked here.
REGISTER_OP("InplaceUpdate")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);
377 
// InplaceAdd registration. It has the same signature as InplaceUpdate:
// - Inputs: `x` (T), int32 `i`, and `v` (T).
// - Output: `y` (T).
// UnchangedShape reports `y` with the shape of `x` (input 0).
REGISTER_OP("InplaceAdd")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);
385 
// InplaceSub registration. It has the same signature as InplaceUpdate:
// - Inputs: `x` (T), int32 `i`, and `v` (T).
// - Output: `y` (T).
// UnchangedShape reports `y` with the shape of `x` (input 0).
REGISTER_OP("InplaceSub")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);
393 
394 REGISTER_OP("Empty")
395     .Input("shape: int32")
396     .Output("output: dtype")
397     .Attr("dtype: type")
398     .Attr("init: bool = false")
399     .SetIsStateful()
__anondb9326b20402(InferenceContext* c) 400     .SetShapeFn([](InferenceContext* c) {
401       ShapeHandle out;
402       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
403       c->set_output(0, out);
404       return Status::OK();
405     });
406 
407 // --------------------------------------------------------------------------
408 REGISTER_OP("Unpack")
409     .Input("value: T")
410     .Output("output: num * T")
411     .Attr("num: int >= 0")
412     .Attr("T: type")
413     .Attr("axis: int = 0")
__anondb9326b20502(InferenceContext* c) 414     .SetShapeFn([](InferenceContext* c) {
415       ShapeHandle s = c->input(0);
416       ShapeHandle out;
417       if (c->RankKnown(s)) {
418         // Determine the axis that will be removed, converting from negative
419         // axes to a positive point per negative indexing rules.
420         int32 rank = c->Rank(s);
421         int32 axis;
422         TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
423 
424         // The axis dim matches the number of outputs.
425         DimensionHandle unused;
426         TF_RETURN_IF_ERROR(
427             c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
428 
429         // Copy all dimensions, removing the <axis> dimension.
430         std::vector<DimensionHandle> dims;
431         for (int i = 0; i < rank; ++i) {
432           if (i != axis) dims.push_back(c->Dim(s, i));
433         }
434         out = c->MakeShape(dims);
435       } else {
436         // All outputs are the same shape, but it's not known.
437         out = c->UnknownShape();
438       }
439       for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
440       return Status::OK();
441     });
442 
443 REGISTER_OP("UnravelIndex")
444     .Input("indices: Tidx")
445     .Input("dims: Tidx")
446     .Output("output: Tidx")
447     .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b20602(InferenceContext* c) 448     .SetShapeFn([](InferenceContext* c) {
449       ShapeHandle indices = c->input(0);
450       ShapeHandle dims;
451       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
452       if (c->RankKnown(indices) && c->Rank(indices) == 0) {
453         c->set_output(0, c->Vector(c->Dim(dims, 0)));
454       } else if (c->RankKnown(indices)) {
455         c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
456       } else {
457         c->set_output(0, c->UnknownShape());
458       }
459       return Status::OK();
460     });
461 
462 REGISTER_OP("BroadcastTo")
463     .Input("input: T")
464     .Input("shape: Tidx")
465     .Output("output: T")
466     .Attr("T: type")
467     .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b20702(InferenceContext* c) 468     .SetShapeFn([](InferenceContext* c) {
469       ShapeHandle shape_in = c->input(1);
470       TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
471       ShapeHandle out;
472       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
473       if (!c->RankKnown(out)) {
474         // We have no information about the shape of the output.
475         c->set_output(0, out);
476         return Status::OK();
477       }
478 
479       ShapeHandle in = c->input(0);
480       if (!c->RankKnown(in)) {
481         // We have no information about the shape of the input,
482         // nothing to do here.
483         c->set_output(0, out);
484         return Status::OK();
485       }
486       int out_rank = c->Rank(out);
487       TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
488       int in_rank = c->Rank(in);
489       for (int i = 0; i < in_rank; ++i) {
490         auto in_dim = c->Dim(in, in_rank - i - 1);
491         if (c->Value(in_dim) > 1) {
492           // If the input dimension is greater than 1 then the output dimension
493           // must be equal to it, since we only broadcast "from left to right".
494           auto out_dim = c->Dim(out, out_rank - i - 1);
495           TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
496           TF_RETURN_IF_ERROR(
497               c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
498         }
499       }
500       c->set_output(0, out);
501       return Status::OK();
502     });
503 
504 // --------------------------------------------------------------------------
505 // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
506 // in the N == 1 case to remove the node.
507 REGISTER_OP("Concat")
508     .Input("concat_dim: int32")
509     .Input("values: N * T")
510     .Output("output: T")
511     .Attr("N: int >= 2")
512     .Attr("T: type")
__anondb9326b20802(InferenceContext* c) 513     .SetShapeFn([](InferenceContext* c) {
514       return shape_inference::ConcatShape(c, c->num_inputs() - 1);
515     });
516 
// ConcatV2 registration. The N (>= 2) `values` come first, and the `axis`
// input comes last. Unlike Concat, where the int32 concat_dim is input 0,
// the axis here may be int32 or int64 (attr Tidx, default int32). The shape
// function is shape_inference::ConcatV2Shape.
REGISTER_OP("ConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Output("output: T")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape);
525 
526 // TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
527 // are not to be made user-accessible.
#ifdef INTEL_MKL
// _MklConcatV2 registration, compiled only when INTEL_MKL is defined.
// - Signature: ConcatV2's inputs and output, plus uint8 `mkl_values`,
//   `mkl_axis` inputs and a uint8 `mkl_output` output.
// - The shape function (ConcatV2Shape) is the same as ConcatV2's.
// - Per the Doc string below, this op is meant to be inserted by a graph
//   rewrite pass, not called directly.
REGISTER_OP("_MklConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape)
    .Doc(R"doc(
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
547 
548 REGISTER_OP("ConcatOffset")
549     .Input("concat_dim: int32")
550     .Input("shape: N * int32")
551     .Output("offset: N * int32")
552     .Attr("N: int >= 2")
__anondb9326b20902(InferenceContext* c) 553     .SetShapeFn([](InferenceContext* c) {
554       for (int i = 1; i < c->num_inputs(); ++i) {
555         c->set_output(i - 1, c->input(i));
556       }
557       return Status::OK();
558     });
559 
560 // --------------------------------------------------------------------------
561 REGISTER_OP("Split")
562     .Input("split_dim: int32")
563     .Input("value: T")
564     .Output("output: num_split * T")
565     .Attr("num_split: int >= 1")
566     .Attr("T: type")
__anondb9326b20a02(InferenceContext* c) 567     .SetShapeFn([](InferenceContext* c) {
568       DimensionHandle split_dimension;
569       ShapeHandle input = c->input(1);
570       TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
571           0, c->Rank(input), &split_dimension));
572       int num_split = c->num_outputs();
573       ShapeHandle out;
574       if (!c->ValueKnown(split_dimension)) {
575         if (c->RankKnown(input)) {
576           out = c->UnknownShapeOfRank(c->Rank(input));
577         } else {
578           out = c->UnknownShape();
579         }
580       } else {
581         int64 split_dim = c->Value(split_dimension);
582         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
583         DimensionHandle split_dim_size;
584         TF_RETURN_WITH_CONTEXT_IF_ERROR(
585             c->Divide(c->Dim(input, split_dim), num_split,
586                       true /* evenly_divisible */, &split_dim_size),
587             "Number of ways to split should evenly divide the split dimension");
588         TF_RETURN_IF_ERROR(
589             c->ReplaceDim(input, split_dim, split_dim_size, &out));
590       }
591       for (int i = 0; i < num_split; ++i) c->set_output(i, out);
592       return Status::OK();
593     });
594 
595 REGISTER_OP("SplitV")
596     .Input("value: T")
597     .Input("size_splits: Tlen")
598     .Input("split_dim: int32")
599     .Output("output: num_split * T")
600     .Attr("num_split: int >= 1")
601     .Attr("T: type")
602     .Attr("Tlen: {int32, int64} = DT_INT64")
__anondb9326b20b02(InferenceContext* c) 603     .SetShapeFn([](InferenceContext* c) {
604       DimensionHandle split_dimension;
605       ShapeHandle input = c->input(0);
606       TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
607           2, c->Rank(input), &split_dimension));
608       int32 num_outputs = c->num_outputs();
609       int32 rank = c->Rank(input);
610       ShapeHandle output_shape;
611       const Tensor* size_splits = c->input_tensor(1);
612       if (rank == InferenceContext::kUnknownRank) {
613         // If the rank of input tensor is unknown, then return unknown shapes.
614         // Note that the shape of each output can be different.
615         for (int i = 0; i < num_outputs; ++i) {
616           c->set_output(i, c->UnknownShape());
617         }
618       } else if (rank == 0) {
619         // Throw error if input is a scalar.
620         return errors::InvalidArgument("Can't split scalars");
621       } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
622         // If split dimension is known, but the sizes are unknown, then
623         // only the split dimension is unknown
624         output_shape = input;
625         for (int i = 0; i < num_outputs; ++i) {
626           TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
627                                            c->Value(split_dimension),
628                                            c->UnknownDim(), &output_shape));
629           c->set_output(i, output_shape);
630         }
631       } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
632         // If split dimension or tensor containing the split sizes is unknown,
633         // then return unknown shapes of same rank as input. Note that each
634         // output shape can be different since splitv doesn't always split
635         // tensors evenly.
636         for (int i = 0; i < num_outputs; ++i) {
637           c->set_output(i, c->UnknownShapeOfRank(rank));
638         }
639       } else {
640         // Determine the output shape if split dimension and split sizes are
641         // known.
642         int64 split_dim = c->Value(split_dimension);
643         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
644         std::vector<int64> data;
645         if (size_splits->dtype() == DT_INT32) {
646           data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
647         } else {
648           data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
649         }
650         if (num_outputs != data.size()) {
651           return errors::InvalidArgument(
652               "Length of size_splits should be equal to num_outputs");
653         }
654         int64_t total_size = 0;
655         bool has_neg_one = false;
656         for (const auto size : data) {
657           if (size == -1) {
658             if (has_neg_one) {
659               return errors::InvalidArgument(
660                   "size_splits can only have one -1");
661             }
662             has_neg_one = true;
663           } else {
664             total_size += size;
665           }
666         }
667         auto split_dim_size = c->Value(c->Dim(input, split_dim));
668         // If the sizes of the splits are known, then
669         // make sure that the sizes add up to the expected
670         // dimension size, with the possibility of a -1.
671         // Specify the full output shapes.
672         for (int i = 0; i < num_outputs; ++i) {
673           auto size = data[i];
674           if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
675             size = split_dim_size - total_size;
676           }
677           TF_RETURN_IF_ERROR(
678               c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
679           c->set_output(i, output_shape);
680         }
681         if (c->ValueKnown(split_dim_size)) {
682           if (has_neg_one ? total_size > split_dim_size
683                           : total_size != split_dim_size) {
684             return errors::InvalidArgument(
685                 "can't split axis of size ", split_dim_size,
686                 " into pieces of size [", str_util::Join(data, ","), "]");
687           }
688         }
689       }
690 
691       return Status::OK();
692     });
693 
694 // --------------------------------------------------------------------------
695 REGISTER_OP("Const")
696     .Output("output: dtype")
697     .Attr("value: tensor")
698     .Attr("dtype: type")
__anondb9326b20c02(InferenceContext* c) 699     .SetShapeFn([](InferenceContext* c) {
700       const TensorProto* proto = nullptr;
701       TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
702       TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
703       TensorShape shape(proto->tensor_shape());
704       std::vector<DimensionHandle> dims;
705       dims.reserve(shape.dims());
706       for (int i = 0; i < shape.dims(); ++i) {
707         dims.push_back(c->MakeDim(shape.dim_size(i)));
708       }
709       c->set_output(0, c->MakeShape(dims));
710       return Status::OK();
711     });
712 
// Returns a constant tensor on the host.  Useful for writing C++ tests
// and benchmarks which run on GPU but require arguments pinned to the host.
// Used by test::graph::HostConstant.
// value: Attr `value` is the tensor to return.
// Unlike Const, whose shape function reads the shape from `value`, this op
// reports an unknown output shape (shape_inference::UnknownShape).
REGISTER_OP("HostConst")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetShapeFn(shape_inference::UnknownShape);
722 
723 // --------------------------------------------------------------------------
// TODO(mgubin): Update the doc when the freeze_graph script supports converting
// into memmapped format.
// ImmutableConst registration: it has no inputs and one output of type
// `dtype`. The output shape is the `shape` attr, via
// shape_inference::ExplicitShape. `memory_region_name` is a string attr.
// Presumably it names the region holding the tensor data; confirm against
// the kernel.
REGISTER_OP("ImmutableConst")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .Attr("memory_region_name: string")
    .Output("tensor: dtype")
    .SetShapeFn(shape_inference::ExplicitShape);
732 
733 REGISTER_OP("GuaranteeConst")
734     .Input("input: T")
735     .Output("output: T")
736     .Attr("T: type")
__anondb9326b20d02(shape_inference::InferenceContext* c) 737     .SetShapeFn([](shape_inference::InferenceContext* c) {
738       return UnchangedShape(c);
739     })
740     // We don't want this to be optimized away.
741     .SetIsStateful();
742 
743 // --------------------------------------------------------------------------
// ZerosLike registration: one input `x` and one output `y` of any type T.
// UnchangedShape reports `y` with the shape of `x`.
REGISTER_OP("ZerosLike")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
749 
750 // --------------------------------------------------------------------------
// OnesLike registration: one input `x` and one output `y` of type T. Unlike
// ZerosLike, T is restricted to the numeric types listed below plus bool.
// UnchangedShape reports `y` with the shape of `x`.
REGISTER_OP("OnesLike")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
        "int64, complex64, complex128, bool}")
    .SetShapeFn(shape_inference::UnchangedShape);
758 
759 // --------------------------------------------------------------------------
760 REGISTER_OP("Diag")
761     .Input("diagonal: T")
762     .Output("output: T")
763     .Attr(
764         "T: {bfloat16, half, float, double, int32, int64, complex64, "
765         "complex128}")
__anondb9326b20e02(InferenceContext* c) 766     .SetShapeFn([](InferenceContext* c) {
767       ShapeHandle in = c->input(0);
768       TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
769       // Output shape is original concatenated with itself.
770       ShapeHandle out;
771       TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
772       c->set_output(0, out);
773       return Status::OK();
774     });
775 
776 // --------------------------------------------------------------------------
777 REGISTER_OP("DiagPart")
778     .Input("input: T")
779     .Output("diagonal: T")
780     .Attr(
781         "T: {bfloat16, half, float, double, int32, int64, complex64, "
782         "complex128}")
__anondb9326b20f02(InferenceContext* c) 783     .SetShapeFn([](InferenceContext* c) {
784       ShapeHandle in = c->input(0);
785       if (!c->RankKnown(in)) {
786         c->set_output(0, c->UnknownShape());
787         return Status::OK();
788       }
789       // Rank must be even, and result will have rank <rank/2>.
790       const int32 rank = c->Rank(in);
791       if ((rank % 2) != 0 || rank <= 0) {
792         return errors::InvalidArgument(
793             "Input must have even and non-zero rank, input rank is ", rank);
794       }
795       const int32 mid = rank / 2;
796 
797       // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
798       std::vector<DimensionHandle> dims(mid);
799       for (int i = 0; i < mid; ++i) {
800         TF_RETURN_IF_ERROR(
801             c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
802       }
803       c->set_output(0, c->MakeShape(dims));
804       return Status::OK();
805     });
806 
807 // --------------------------------------------------------------------------
808 REGISTER_OP("MatrixDiag")
809     .Input("diagonal: T")
810     .Output("output: T")
811     .Attr("T: type")
__anondb9326b21002(InferenceContext* c) 812     .SetShapeFn([](InferenceContext* c) {
813       ShapeHandle in;
814       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
815       if (!c->RankKnown(in)) {
816         c->set_output(0, c->UnknownShape());
817         return Status::OK();
818       }
819       const int32 rank = c->Rank(in);
820       ShapeHandle out;
821       TF_RETURN_IF_ERROR(
822           c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
823       c->set_output(0, out);
824       return Status::OK();
825     });
826 
827 // --------------------------------------------------------------------------
// MatrixSetDiag: `input` must have rank >= 2 and `diagonal` rank >= 1 (exactly
// rank(input) - 1 when rank(input) is known). The output shape starts as the
// input shape and may be refined using the leading dims of `diagonal`.
REGISTER_OP("MatrixSetDiag")
    .Input("input: T")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      ShapeHandle diag;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
      if (c->RankKnown(input)) {
        // Tighten the diagonal to exactly one rank below the input.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
      }
      // Compatibility check only: the diagonal's last dim must merge with
      // min(input[-2], input[-1]). smallest_dim is not used afterwards.
      DimensionHandle smallest_dim;
      TF_RETURN_IF_ERROR(
          c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
      TF_RETURN_IF_ERROR(
          c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));

      ShapeHandle output = input;
      if (c->RankKnown(diag) && !c->FullyDefined(input)) {
        // Try to infer parts of shape from diag.
        // Build diag[:-1] + [?, ?] (note: this reuses `diag`) and merge it
        // with the input shape to fill in unknown leading dims.
        ShapeHandle diag_prefix;
        TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix));
        TF_RETURN_IF_ERROR(
            c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag));
        TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });
859 
860 // --------------------------------------------------------------------------
861 REGISTER_OP("MatrixDiagPart")
862     .Input("input: T")
863     .Output("diagonal: T")
864     .Attr("T: type")
__anondb9326b21202(InferenceContext* c) 865     .SetShapeFn([](InferenceContext* c) {
866       ShapeHandle in;
867       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
868       if (!c->RankKnown(in)) {
869         c->set_output(0, c->UnknownShape());
870         return Status::OK();
871       }
872       const int32 rank = c->Rank(in);
873       std::vector<DimensionHandle> dims;
874       dims.reserve(rank - 2);
875       for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
876 
877       DimensionHandle min_dim;
878       TF_RETURN_IF_ERROR(
879           c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
880       dims.push_back(min_dim);
881       c->set_output(0, c->MakeShape(dims));
882       return Status::OK();
883     });
884 
885 // --------------------------------------------------------------------------
// MatrixBandPart: `num_lower` and `num_upper` share the index dtype `Tindex`
// (int32 or int64, default int64). The output `band` has the same dtype and,
// via UnchangedShape, the same shape as `input`.
REGISTER_OP("MatrixBandPart")
    .Input("input: T")
    .Input("num_lower: Tindex")
    .Input("num_upper: Tindex")
    .Output("band: T")
    .Attr("T: type")
    .Attr("Tindex: {int32, int64} = DT_INT64")
    .SetShapeFn(shape_inference::UnchangedShape);
894 
895 // --------------------------------------------------------------------------
896 REGISTER_OP("Reverse")
897     .Input("tensor: T")
898     .Input("dims: bool")
899     .Output("output: T")
900     .Attr(
901         "T: {uint8, int8, uint16, int16, int32, int64, bool, half, "
902         "float, double, complex64, complex128, string}")
__anondb9326b21302(InferenceContext* c) 903     .SetShapeFn([](InferenceContext* c) {
904       ShapeHandle input = c->input(0);
905       ShapeHandle dims;
906       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
907       DimensionHandle dims_dim = c->Dim(dims, 0);
908       if (c->ValueKnown(dims_dim)) {
909         TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
910       }
911       if (c->Rank(input) > 8) {
912         return errors::InvalidArgument(
913             "reverse does not work on tensors with more than 8 dimensions");
914       }
915       c->set_output(0, input);
916       return Status::OK();
917     });
918 
919 // --------------------------------------------------------------------------
920 REGISTER_OP("ReverseV2")
921     .Input("tensor: T")
922     .Input("axis: Tidx")
923     .Output("output: T")
924     .Attr("Tidx: {int32, int64} = DT_INT32")
925     .Attr(
926         "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
927         "float, double, complex64, complex128, string}")
__anondb9326b21402(InferenceContext* c) 928     .SetShapeFn([](InferenceContext* c) {
929       ShapeHandle input = c->input(0);
930       ShapeHandle axis;
931       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
932       if (c->Rank(input) > 8) {
933         return errors::InvalidArgument(
934             "reverse does not work on tensors with more than 8 dimensions");
935       }
936       const Tensor* axis_tensor = c->input_tensor(1);
937       if (axis_tensor != nullptr && c->RankKnown(input)) {
938         int32 rank = c->Rank(input);
939         std::vector<int64> axis_value;
940         if (axis_tensor->dtype() == DT_INT32) {
941           axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
942         } else {
943           axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements());
944         }
945         std::vector<bool> axes_dense(c->Rank(input), false);
946         for (int i = 0; i < axis_value.size(); i++) {
947           int64 canonical_axis =
948               axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
949           if (canonical_axis < 0 || canonical_axis >= rank) {
950             return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
951                                            " is out of valid range [", 0, ", ",
952                                            rank - 1);
953           }
954           if (axes_dense[canonical_axis]) {
955             return errors::InvalidArgument("axis ", canonical_axis,
956                                            " specified more than once.");
957           }
958           axes_dense[canonical_axis] = true;
959         }
960       }
961       c->set_output(0, input);
962       return Status::OK();
963     });
964 
965 // --------------------------------------------------------------------------
966 REGISTER_OP("EditDistance")
967     .Input("hypothesis_indices: int64")
968     .Input("hypothesis_values: T")
969     .Input("hypothesis_shape: int64")
970     .Input("truth_indices: int64")
971     .Input("truth_values: T")
972     .Input("truth_shape: int64")
973     .Attr("normalize: bool = true")
974     .Attr("T: type")
975     .Output("output: float")
__anondb9326b21502(InferenceContext* c) 976     .SetShapeFn([](InferenceContext* c) {
977       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
978           c, c->input(0), c->input(1), c->input(2)));
979       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
980           c, c->input(3), c->input(4), c->input(5)));
981       const Tensor* hypothesis_shape_t = c->input_tensor(2);
982       const Tensor* truth_shape_t = c->input_tensor(5);
983       if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
984         // We need to know the runtime shape of the two tensors,
985         // or else the output shape is unknown.
986         return shape_inference::UnknownShape(c);
987       }
988 
989       if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
990         return errors::InvalidArgument(
991             "Num elements of hypothesis_shape does not match truth_shape: ",
992             hypothesis_shape_t->NumElements(), " vs. ",
993             truth_shape_t->NumElements());
994       }
995 
996       auto h_values = hypothesis_shape_t->flat<int64>();
997       auto t_values = truth_shape_t->flat<int64>();
998       std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
999       for (int i = 0; i < dims.size(); ++i) {
1000         dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
1001       }
1002 
1003       c->set_output(0, c->MakeShape(dims));
1004       return Status::OK();
1005     });
1006 
1007 // --------------------------------------------------------------------------
1008 REGISTER_OP("Fill")
1009     .Input("dims: index_type")
1010     .Input("value: T")
1011     .Output("output: T")
1012     .Attr("T: type")
1013     .Attr("index_type: {int32, int64} = DT_INT32")
__anondb9326b21602(InferenceContext* c) 1014     .SetShapeFn([](InferenceContext* c) {
1015       DataType index_type = DT_INT32;
1016       Status s = c->GetAttr("index_type", &index_type);
1017       if (!s.ok() && s.code() != error::NOT_FOUND) {
1018         return s;
1019       }
1020       ShapeHandle unused;
1021       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1022       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1023 
1024       const Tensor* t = c->input_tensor(0);
1025       if (t != nullptr) {
1026         for (int i = 0; i < t->NumElements(); ++i) {
1027           if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
1028               (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
1029             return errors::InvalidArgument("Fill dimensions must be >= 0");
1030           }
1031         }
1032       }
1033 
1034       ShapeHandle out;
1035       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1036       c->set_output(0, out);
1037 
1038       auto* shape_and_type = c->input_handle_shapes_and_types(1);
1039       if (shape_and_type) {
1040         c->set_output_handle_shapes_and_types(0, *shape_and_type);
1041       }
1042 
1043       return Status::OK();
1044     });
1045 
1046 // --------------------------------------------------------------------------
// _ParallelConcatStart: no inputs. The output shape comes from the `shape`
// attr via shape_inference::ExplicitShape. The op is marked stateful.
REGISTER_OP("_ParallelConcatStart")
    .Output("output: dtype")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape)
    .Doc(R"doc(
Creates an empty Tensor with shape `shape` and type `dtype`.

The memory can optionally be initialized. This is usually useful in
conjunction with inplace operations.

shape: 1-D `Tensor` indicating the shape of the output.
dtype: The element type of the returned tensor.
output: An empty Tensor of the specified type.
)doc");
1063 
1064 // --------------------------------------------------------------------------
// _ParallelConcatUpdate: `loc` is an int attr, not an input. The output shape
// mirrors input 0 (`value`) via UnchangedShape.
REGISTER_OP("_ParallelConcatUpdate")
    .Input("value: T")
    .Input("update: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("loc: int")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Updates input `value` at `loc` with `update`.

If you use this function you will almost certainly want to add
a control dependency as done in the implementation of parallel_stack to
avoid race conditions.

value: A `Tensor` object that will be updated in-place.
loc: A scalar indicating the index of the first dimension such that
         value[loc, :] is updated.
update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
        otherwise of rank equal to `value` that contains the new values
        for `value`.
output: `value` that has been updated accordingly.
)doc");
1087 
1088 // --------------------------------------------------------------------------
1089 REGISTER_OP("Gather")
1090     .Input("params: Tparams")
1091     .Input("indices: Tindices")
1092     .Attr("validate_indices: bool = true")
1093     .Output("output: Tparams")
1094     .Attr("Tparams: type")
1095     .Attr("Tindices: {int32,int64}")
__anondb9326b21702(InferenceContext* c) 1096     .SetShapeFn([](InferenceContext* c) {
1097       ShapeHandle unused;
1098       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
1099       ShapeHandle params_subshape;
1100       TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &params_subshape));
1101       ShapeHandle indices_shape = c->input(1);
1102       ShapeHandle out;
1103       TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
1104       c->set_output(0, out);
1105       return Status::OK();
1106     });
1107 
1108 // --------------------------------------------------------------------------
1109 REGISTER_OP("GatherV2")
1110     .Input("params: Tparams")
1111     .Input("indices: Tindices")
1112     .Input("axis: Taxis")
1113     .Output("output: Tparams")
1114     .Attr("Tparams: type")
1115     .Attr("Tindices: {int32,int64}")
1116     .Attr("Taxis: {int32,int64}")
__anondb9326b21802(InferenceContext* c) 1117     .SetShapeFn([](InferenceContext* c) {
1118       ShapeHandle params_shape;
1119       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &params_shape));
1120 
1121       ShapeHandle indices_shape = c->input(1);
1122       ShapeHandle unused_axis_shape;
1123       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
1124       const Tensor* axis_t = c->input_tensor(2);
1125 
1126       // If axis is unknown, we can only infer that the result is params_rank +
1127       // indices_rank - 1.
1128       if (axis_t == nullptr) {
1129         if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
1130           c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
1131                                                  c->Rank(indices_shape) - 1));
1132         } else {
1133           c->set_output(0, c->UnknownShape());
1134         }
1135         return Status::OK();
1136       }
1137 
1138       // Note, axis can be negative.
1139       int64 axis = 0;
1140       if (axis_t->dtype() == DT_INT32) {
1141         axis = axis_t->scalar<int32>()();
1142       } else {
1143         axis = axis_t->scalar<int64>()();
1144       }
1145 
1146       // Check that params has rank of at least axis + 1.
1147       ShapeHandle unused;
1148       TF_RETURN_IF_ERROR(c->WithRankAtLeast(
1149           params_shape, axis < 0 ? -axis : axis + 1, &unused));
1150 
1151       ShapeHandle params_outer_subshape;
1152       TF_RETURN_IF_ERROR(
1153           c->Subshape(params_shape, 0, axis, &params_outer_subshape));
1154 
1155       ShapeHandle out;
1156       TF_RETURN_IF_ERROR(
1157           c->Concatenate(params_outer_subshape, indices_shape, &out));
1158 
1159       // Slice from axis + 1 to the end of params_shape to collect the inner
1160       // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
1161       // we slice from 0 to the end of shape. Subshape() handles all other
1162       // out-of-bounds checking.
1163       if (axis != -1) {
1164         ShapeHandle params_inner_subshape;
1165         TF_RETURN_IF_ERROR(
1166             c->Subshape(params_shape, axis + 1, &params_inner_subshape));
1167         TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
1168       }
1169 
1170       c->set_output(0, out);
1171       return Status::OK();
1172     });
1173 
1174 // --------------------------------------------------------------------------
1175 REGISTER_OP("GatherNd")
1176     .Input("params: Tparams")
1177     .Input("indices: Tindices")
1178     .Output("output: Tparams")
1179     .Attr("Tparams: type")
1180     .Attr("Tindices: {int32,int64}")
__anondb9326b21902(InferenceContext* c) 1181     .SetShapeFn([](InferenceContext* c) {
1182       ShapeHandle params = c->input(0);
1183       ShapeHandle indices;
1184       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
1185       DimensionHandle r_dim = c->Dim(indices, -1);
1186 
1187       if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
1188         c->set_output(0, c->UnknownShape());
1189         return Status::OK();
1190       }
1191 
1192       if (c->Value(r_dim) > c->Rank(params)) {
1193         return errors::InvalidArgument(
1194             "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
1195             c->DebugString(indices),
1196             " and params shape: ", c->DebugString(params));
1197       }
1198 
1199       // Remove r_dim from indices to get output.
1200       ShapeHandle indices_slice;
1201       ShapeHandle params_slice;
1202       TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
1203       TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
1204       ShapeHandle out;
1205       TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
1206       c->set_output(0, out);
1207       return Status::OK();
1208     });
1209 
1210 // --------------------------------------------------------------------------
// Identity and Snapshot share the same signature (one input, one output of
// dtype T). Both use UnchangedShape, so the output shape mirrors the input.
REGISTER_OP("Identity")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Snapshot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
1222 
// Only registered in INTEL_MKL builds. Alongside the data tensor it carries a
// uint8 `mkl_input`/`mkl_output` pair. UnchangedShape sets output 0 from
// input 0; how the uint8 metadata output's shape is handled is not visible
// here.
#ifdef INTEL_MKL
REGISTER_OP("_MklIdentity")
    .Input("input: T")
    .Input("mkl_input: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"Doc( Mkl implementation of IdentityOp
)Doc");
#endif
1234 
1235 REGISTER_OP("IdentityN")
1236     .Input("input: T")
1237     .Output("output: T")
1238     .Attr("T: list(type)")
__anondb9326b21a02(shape_inference::InferenceContext* c) 1239     .SetShapeFn([](shape_inference::InferenceContext* c) {
1240       std::vector<ShapeHandle> input;
1241       TF_RETURN_IF_ERROR(c->input("input", &input));
1242       TF_RETURN_IF_ERROR(c->set_output("output", input));
1243       return Status::OK();
1244     });
1245 
1246 // --------------------------------------------------------------------------
// RefIdentity: ref-typed input and output. The output shape mirrors the input
// (UnchangedShape), and the op is allowed to take an uninitialized input.
REGISTER_OP("RefIdentity")
    .Input("input: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();
1253 
1254 // --------------------------------------------------------------------------
// Value and Ref(T) variants with identical shape behavior (UnchangedShape).
// Both allow an uninitialized input.
REGISTER_OP("DebugGradientIdentity")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();

REGISTER_OP("DebugGradientRefIdentity")
    .Input("input: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();
1268 
1269 // --------------------------------------------------------------------------
// StopGradient and PreventGradient both pass the input shape through
// unchanged. PreventGradient also carries a string `message` attr that
// defaults to empty.
REGISTER_OP("StopGradient")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("PreventGradient")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("message: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape);
1282 
1283 // --------------------------------------------------------------------------
// CheckNumerics: restricted to floating-point dtypes. It takes a required
// `message` attr (no default). The output shape mirrors the input.
REGISTER_OP("CheckNumerics")
    .Input("tensor: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("message: string")
    .SetShapeFn(shape_inference::UnchangedShape);
1290 
1291 // --------------------------------------------------------------------------
// Reshape: the target shape is input 1 (`shape`, int32 or int64, default
// int32). Shape inference is delegated to SetOutputShapeForReshape, which is
// defined elsewhere in this file (not visible in this excerpt).
REGISTER_OP("Reshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      return SetOutputShapeForReshape(c);
    });
1301 
// INTEL_MKL-only variant of Reshape with extra uint8 metadata inputs/outputs.
// It uses the same SetOutputShapeForReshape shape function as Reshape.
#ifdef INTEL_MKL
REGISTER_OP("_MklReshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Input("mkl_tensor: uint8")
    .Input("mkl_shape: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
    .Doc(R"Doc( MKL implementation of ReshapeOp.
)Doc");
#endif  // INTEL_MKL
1316 
1317 // --------------------------------------------------------------------------
1318 REGISTER_OP("InvertPermutation")
1319     .Input("x: T")
1320     .Output("y: T")
1321     .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b21d02(InferenceContext* c) 1322     .SetShapeFn([](InferenceContext* c) {
1323       ShapeHandle x;
1324       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
1325       c->set_output(0, x);
1326       return Status::OK();
1327     });
1328 
1329 // --------------------------------------------------------------------------
// Transpose and ConjugateTranspose share a signature and the TransposeShapeFn
// shape function (defined elsewhere in this file, not visible here). `perm`
// is int32 or int64, default int32.
REGISTER_OP("Transpose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(TransposeShapeFn);

// --------------------------------------------------------------------------
REGISTER_OP("ConjugateTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(TransposeShapeFn);
1346 
1347 // --------------------------------------------------------------------------
1348 REGISTER_OP("Unique")
1349     .Input("x: T")
1350     .Output("y: T")
1351     .Output("idx: out_idx")
1352     .Attr("T: type")
1353     .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b21e02(InferenceContext* c) 1354     .SetShapeFn([](InferenceContext* c) {
1355       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1356       c->set_output(1, c->input(0));
1357       // Assert that the input rank is 1.
1358       ShapeHandle dummy;
1359       return c->WithRank(c->input(0), 1, &dummy);
1360     });
1361 
1362 REGISTER_OP("UniqueV2")
1363     .Input("x: T")
1364     .Input("axis: Taxis")
1365     .Output("y: T")
1366     .Output("idx: out_idx")
1367     .Attr("T: type")
1368     .Attr("Taxis: {int32,int64} = DT_INT64")
1369     .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b21f02(InferenceContext* c) 1370     .SetShapeFn([](InferenceContext* c) {
1371       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1372       c->set_output(1, c->input(0));
1373       return Status::OK();
1374     });
1375 
1376 // --------------------------------------------------------------------------
1377 REGISTER_OP("UniqueWithCounts")
1378     .Input("x: T")
1379     .Output("y: T")
1380     .Output("idx: out_idx")
1381     .Output("count: out_idx")
1382     .Attr("T: type")
1383     .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b22002(InferenceContext* c) 1384     .SetShapeFn([](InferenceContext* c) {
1385       auto uniq = c->Vector(InferenceContext::kUnknownDim);
1386       c->set_output(0, uniq);
1387       c->set_output(1, c->input(0));
1388       c->set_output(2, uniq);
1389       return Status::OK();
1390     });
1391 
1392 REGISTER_OP("UniqueWithCountsV2")
1393     .Input("x: T")
1394     .Input("axis: Taxis")
1395     .Output("y: T")
1396     .Output("idx: out_idx")
1397     .Output("count: out_idx")
1398     .Attr("T: type")
1399     .Attr("Taxis: {int32,int64} = DT_INT64")
1400     .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b22102(InferenceContext* c) 1401     .SetShapeFn([](InferenceContext* c) {
1402       auto uniq = c->Vector(InferenceContext::kUnknownDim);
1403       c->set_output(0, uniq);
1404       c->set_output(1, c->input(0));
1405       c->set_output(2, uniq);
1406       return Status::OK();
1407     });
1408 
1409 namespace {
1410 
ShapeShapeFn(InferenceContext * c)1411 Status ShapeShapeFn(InferenceContext* c) {
1412   for (int i = 0; i < c->num_inputs(); ++i) {
1413     DimensionHandle dim;
1414     if (c->RankKnown(c->input(i))) {
1415       dim = c->MakeDim(c->Rank(c->input(i)));
1416     } else {
1417       dim = c->UnknownDim();
1418     }
1419     c->set_output(i, c->Vector(dim));
1420   }
1421   return Status::OK();
1422 }
1423 
1424 }  // namespace
1425 
1426 // --------------------------------------------------------------------------
// Shape (single input) and ShapeN (N inputs of a common dtype T) both report
// input shapes as int32/int64 vectors. Output shapes come from ShapeShapeFn,
// defined in the anonymous namespace just above these registrations.
REGISTER_OP("Shape")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(ShapeShapeFn);

REGISTER_OP("ShapeN")
    .Input("input: N * T")
    .Output("output: N * out_type")
    .Attr("N: int")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(ShapeShapeFn);
1441 
1442 REGISTER_OP("EnsureShape")
1443     .Input("input: T")
1444     .Output("output: T")
1445     .Attr("shape: shape")
1446     .Attr("T: type")
__anondb9326b22302(InferenceContext* c) 1447     .SetShapeFn([](InferenceContext* c) {
1448       // Merges desired shape and statically known shape of input
1449       PartialTensorShape desired_shape;
1450       TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
1451 
1452       int rank = desired_shape.dims();
1453       ShapeHandle input_shape_handle;
1454       ShapeHandle desired_shape_handle;
1455       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
1456       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
1457           desired_shape, &desired_shape_handle));
1458 
1459       ShapeHandle merged_shape;
1460       TF_RETURN_IF_ERROR(
1461           c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
1462       c->set_output(0, merged_shape);
1463       return Status::OK();
1464     });
1465 
1466 // --------------------------------------------------------------------------
1467 REGISTER_OP("ReverseSequence")
1468     .Input("input: T")
1469     .Input("seq_lengths: Tlen")
1470     .Output("output: T")
1471     .Attr("seq_dim: int")
1472     .Attr("batch_dim: int = 0")
1473     .Attr("T: type")
1474     .Attr("Tlen: {int32, int64} = DT_INT64")
__anondb9326b22402(InferenceContext* c) 1475     .SetShapeFn([](InferenceContext* c) {
1476       ShapeHandle input = c->input(0);
1477       ShapeHandle seq_lens_shape;
1478       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
1479 
1480       int64 seq_dim;
1481       TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
1482       int64 batch_dim;
1483       TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1484 
1485       if (!c->RankKnown(input)) {
1486         return shape_inference::UnknownShape(c);
1487       }
1488 
1489       // Validate batch_dim and seq_dim against input.
1490       const int32 input_rank = c->Rank(input);
1491       if (batch_dim >= input_rank) {
1492         return errors::InvalidArgument(
1493             "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1494       }
1495       if (seq_dim >= input_rank) {
1496         return errors::InvalidArgument(
1497             "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
1498       }
1499 
1500       DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1501       TF_RETURN_IF_ERROR(
1502           c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
1503 
1504       // Replace batch_dim of input with batch_size
1505       ShapeHandle output_shape;
1506       TF_RETURN_IF_ERROR(
1507           c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
1508       c->set_output(0, output_shape);
1509       return Status::OK();
1510     });
1511 
1512 // --------------------------------------------------------------------------
// Rank: always produces an int32 scalar (ScalarShape).
REGISTER_OP("Rank")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: type")
    .SetShapeFn(shape_inference::ScalarShape);

// --------------------------------------------------------------------------
// Size: scalar output whose dtype is selected by `out_type` (default int32).
REGISTER_OP("Size")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ScalarShape);
1526 
1527 // --------------------------------------------------------------------------
// Slice: `begin` and `size` share the index dtype `Index` (int32 or int64, no
// default). Shape inference is delegated to shape_inference::SliceShape.
REGISTER_OP("Slice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
1536 
// INTEL_MKL-only variant of Slice with one extra uint8 metadata input per data
// input and a uint8 metadata output. It reuses shape_inference::SliceShape.
#ifdef INTEL_MKL
REGISTER_OP("_MklSlice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Input("mkl_input: uint8")
    .Input("mkl_begin: uint8")
    .Input("mkl_size: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
#endif
1551 
1552 REGISTER_OP("StridedSlice")
1553     .Input("input: T")
1554     .Input("begin: Index")
1555     .Input("end: Index")
1556     .Input("strides: Index")
1557     .Output("output: T")
1558     .Attr("T: type")
1559     .Attr("Index: {int32, int64}")
1560     .Attr("begin_mask: int = 0")
1561     .Attr("end_mask: int = 0")
1562     .Attr("ellipsis_mask: int = 0")
1563     .Attr("new_axis_mask: int = 0")
1564     .Attr("shrink_axis_mask: int = 0")
__anondb9326b22502(InferenceContext* c) 1565     .SetShapeFn([](InferenceContext* c) {
1566       ShapeHandle input = c->input(0);
1567       ShapeHandle begin_shape, end_shape, strides_shape;
1568       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1569       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
1570       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
1571       TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
1572       TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
1573       DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
1574 
1575       const Tensor* strides_value = c->input_tensor(3);
1576       // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
1577       // more shape inference here (e.g. for x[3, ::T]).
1578       if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
1579           strides_value == nullptr) {
1580         c->set_output(0, c->UnknownShape());
1581         return Status::OK();
1582       }
1583 
1584       PartialTensorShape input_shape({});
1585       for (int i = 0; i < c->Rank(input); ++i) {
1586         auto dim = c->Dim(input, i);
1587         input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
1588       }
1589 
1590       int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask,
1591           shrink_axis_mask;
1592       TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
1593       TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
1594       TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
1595       TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
1596       TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
1597 
1598       const Tensor* begin_value = c->input_tensor(1);
1599       const Tensor* end_value = c->input_tensor(2);
1600 
1601       PartialTensorShape processing_shape, final_shape;
1602       bool is_identity, is_simple_slice, slice_dim0;
1603       gtl::InlinedVector<int64, 4> begin, end, strides;
1604       TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
1605           begin_value, end_value, *strides_value, input_shape, begin_mask,
1606           end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
1607           &processing_shape, &final_shape, &is_identity, &is_simple_slice,
1608           &slice_dim0, &begin, &end, &strides));
1609 
1610       ShapeHandle out;
1611       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
1612       c->set_output(0, out);
1613 
1614       auto* shape_and_type = c->input_handle_shapes_and_types(0);
1615       if (shape_and_type) {
1616         c->set_output_handle_shapes_and_types(0, *shape_and_type);
1617       }
1618 
1619       return Status::OK();
1620     });
1621 
1622 REGISTER_OP("StridedSliceGrad")
1623     .Input("shape: Index")
1624     .Input("begin: Index")
1625     .Input("end: Index")
1626     .Input("strides: Index")
1627     .Input("dy: T")
1628     .Output("output: T")
1629     .Attr("T: type")
1630     .Attr("Index: {int32, int64}")
1631     .Attr("begin_mask: int = 0")
1632     .Attr("end_mask: int = 0")
1633     .Attr("ellipsis_mask: int = 0")
1634     .Attr("new_axis_mask: int = 0")
1635     .Attr("shrink_axis_mask: int = 0")
__anondb9326b22602(InferenceContext* c) 1636     .SetShapeFn([](InferenceContext* c) {
1637       ShapeHandle out;
1638       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1639       c->set_output(0, out);
1640       return Status::OK();
1641     });
1642 
// Strided-slice assignment on a Ref(T) tensor. It takes the same begin/end/
// strides inputs and mask attrs as StridedSlice, plus a `value` input, and
// returns `output_ref`. The output shape equals the shape of `ref`
// (UnchangedShape forwards input 0's shape to output 0).
REGISTER_OP("StridedSliceAssign")
    .Input("ref: Ref(T)")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
// TODO(aselle): Fix this documentation once StridedSliceAssign supports
// broadcasting.
1660 // --------------------------------------------------------------------------
1661 
// Resource-variable counterpart of StridedSliceAssign. It has the same inputs
// and attrs, except that `ref` is a resource handle instead of a Ref(T)
// tensor. It declares no outputs (shape_inference::NoOutputs).
REGISTER_OP("ResourceStridedSliceAssign")
    .Input("ref: resource")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::NoOutputs);
1676 
1677 REGISTER_OP("Tile")
1678     .Input("input: T")
1679     .Input("multiples: Tmultiples")
1680     .Output("output: T")
1681     .Attr("T: type")
1682     .Attr("Tmultiples: {int32, int64} = DT_INT32")
__anondb9326b22702(InferenceContext* c) 1683     .SetShapeFn([](InferenceContext* c) {
1684       ShapeHandle input = c->input(0);
1685       // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
1686       // it is a vector of non-negative integers, and (ii) doing so allows
1687       // us to handle partially-known multiples.
1688       ShapeHandle multiples;
1689       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
1690       if (c->RankKnown(input)) {
1691         TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
1692         ShapeHandle dummy;
1693         TF_RETURN_IF_ERROR(
1694             c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
1695       }
1696 
1697       if (!c->RankKnown(multiples)) {
1698         return shape_inference::UnknownShape(c);
1699       }
1700 
1701       int32 rank = c->Rank(multiples);
1702       TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
1703       std::vector<DimensionHandle> dims(rank);
1704       for (int i = 0; i < rank; ++i) {
1705         TF_RETURN_IF_ERROR(
1706             c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
1707       }
1708       c->set_output(0, c->MakeShape(dims));
1709       return Status::OK();
1710     });
1711 
1712 // --------------------------------------------------------------------------
// Deprecated at GraphDef version 3 ("TileGrad has been replaced with
// reduce_sum"). The output shape is not inferred (UnknownShape).
// NOTE(review): presumably kept registered so that older GraphDefs that still
// reference it can be loaded -- confirm before removing.
REGISTER_OP("TileGrad")
    .Input("input: T")
    .Input("multiples: int32")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(3, "TileGrad has been replaced with reduce_sum")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1720 
1721 // --------------------------------------------------------------------------
1722 REGISTER_OP("Where")
1723     .Input("input: T")
1724     .Attr("T: {numbertype, bool} = DT_BOOL")
1725     .Output("index: int64")
__anondb9326b22802(InferenceContext* c) 1726     .SetShapeFn([](InferenceContext* c) {
1727       c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
1728       return Status::OK();
1729     });
1730 
1731 // --------------------------------------------------------------------------
1732 REGISTER_OP("BroadcastArgs")
1733     .Input("s0: T")
1734     .Input("s1: T")
1735     .Output("r0: T")
1736     .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b22902(InferenceContext* c) 1737     .SetShapeFn([](InferenceContext* c) {
1738       ShapeHandle unused;
1739       ShapeHandle shape_x = c->input(0);
1740       ShapeHandle shape_y = c->input(1);
1741       TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
1742       TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
1743 
1744       if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
1745           !c->ValueKnown(c->Dim(shape_y, 0))) {
1746         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1747         return Status::OK();
1748       }
1749 
1750       int64 x_dim = c->Value(c->Dim(shape_x, 0));
1751       int64 y_dim = c->Value(c->Dim(shape_y, 0));
1752 
1753       // Broadcasted shape is going to be as large as the largest dimension.
1754       c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
1755       return Status::OK();
1756     });
1757 
1758 // --------------------------------------------------------------------------
1759 REGISTER_OP("BroadcastGradientArgs")
1760     .Input("s0: T")
1761     .Input("s1: T")
1762     .Output("r0: T")
1763     .Output("r1: T")
1764     .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b22a02(InferenceContext* c) 1765     .SetShapeFn([](InferenceContext* c) {
1766       // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
1767       ShapeHandle unused;
1768       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1769       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1770       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1771       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1772       return Status::OK();
1773     });
1774 
1775 // --------------------------------------------------------------------------
// Pads `input` according to `paddings`. Shape inference is delegated to
// PadShapeFn, which PadV2 and MirrorPad also use.
REGISTER_OP("Pad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(PadShapeFn);
1783 
1784 // --------------------------------------------------------------------------
1785 REGISTER_OP("PadV2")
1786     .Input("input: T")
1787     .Input("paddings: Tpaddings")
1788     .Input("constant_values: T")
1789     .Output("output: T")
1790     .Attr("T: type")
1791     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1792     .SetShapeFn(PadShapeFn);
1793 
1794 // --------------------------------------------------------------------------
// Mirror-padding variant of Pad. The padding mode comes from the attr
// declared by GetMirrorPadModeAttrString(). The output shape uses the same
// PadShapeFn as Pad.
REGISTER_OP("MirrorPad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .Attr(GetMirrorPadModeAttrString())
    .SetShapeFn(PadShapeFn);
1803 
1804 // --------------------------------------------------------------------------
1805 namespace {
1806 template <typename T>
MirrorPadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64 input_rank)1807 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
1808                       const Tensor* paddings_t, int64 input_rank) {
1809   auto paddings_data = paddings_t->matrix<T>();
1810   std::vector<DimensionHandle> dims(input_rank);
1811   for (int64 i = 0; i < input_rank; ++i) {
1812     const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
1813     const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
1814     if (pad0 < 0 || pad1 < 0) {
1815       return errors::InvalidArgument("Paddings must be non-negative");
1816     }
1817 
1818     TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
1819   }
1820   c->set_output(0, c->MakeShape(dims));
1821   return Status::OK();
1822 }
1823 
1824 }  // namespace
1825 
1826 REGISTER_OP("MirrorPadGrad")
1827     .Input("input: T")
1828     .Input("paddings: Tpaddings")
1829     .Output("output: T")
1830     .Attr("T: type")
1831     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1832     .Attr(GetMirrorPadModeAttrString())
__anondb9326b22c02(InferenceContext* c) 1833     .SetShapeFn([](InferenceContext* c) {
1834       ShapeHandle paddings;
1835       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
1836       DimensionHandle pad_0 = c->Dim(paddings, 0);
1837       if (!c->ValueKnown(pad_0)) {
1838         // We don't know the rank of the output since the first
1839         // padding dimension is unknown.
1840         c->set_output(0, c->UnknownShape());
1841         return Status::OK();
1842       }
1843 
1844       int64 input_rank = c->Value(pad_0);
1845       ShapeHandle input;
1846       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
1847       TF_RETURN_IF_ERROR(
1848           c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
1849 
1850       const Tensor* paddings_t = c->input_tensor(1);
1851       if (paddings_t == nullptr) {
1852         // Values of 'paddings' is not available, but we know the
1853         // input rank, so return the rank of the output with unknown
1854         // dimensions.
1855         c->set_output(0, c->UnknownShapeOfRank(input_rank));
1856         return Status::OK();
1857       }
1858 
1859       if (paddings_t->dtype() == DT_INT32) {
1860         return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
1861       } else {
1862         return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
1863       }
1864     });
1865 
1866 // --------------------------------------------------------------------------
1867 REGISTER_OP("Placeholder")
1868     .Output("output: dtype")
1869     .Attr("dtype: type")
1870     .Attr("shape: shape = { unknown_rank: true }")
__anondb9326b22d02(InferenceContext* c) 1871     .SetShapeFn([](InferenceContext* c) {
1872       PartialTensorShape shape;
1873       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
1874 
1875       // Placeholder has legacy behavior where we cannot tell the difference
1876       // between a scalar shape attribute and 'unknown shape'.  So if the shape
1877       // is a scalar, we return an unknown shape.
1878       if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
1879         return shape_inference::UnknownShape(c);
1880       }
1881 
1882       ShapeHandle out;
1883       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
1884       c->set_output(0, out);
1885       return Status::OK();
1886     });
1887 
// Placeholder was modified in a backwards compatible way to do what
// PlaceholderV2 did, so we have deprecated V2 (no one was really
// using it).
// Deprecated at GraphDef version 23. The output shape is taken from the
// required `shape` attr via shape_inference::ExplicitShape.
REGISTER_OP("PlaceholderV2")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn(shape_inference::ExplicitShape)
    .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
1897 
1898 // --------------------------------------------------------------------------
1899 REGISTER_OP("PlaceholderWithDefault")
1900     .Input("input: dtype")
1901     .Output("output: dtype")
1902     .Attr("dtype: type")
1903     .Attr("shape: shape")
__anondb9326b22e02(InferenceContext* c) 1904     .SetShapeFn([](InferenceContext* c) {
1905       ShapeHandle input = c->input(0);
1906       PartialTensorShape shape;
1907       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
1908       ShapeHandle out;
1909       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
1910 
1911       // We merge for compatibility checking, but return the output,
1912       // since output_shape may be less precise than input_shape.
1913       ShapeHandle unused;
1914       TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
1915       c->set_output(0, out);
1916       return Status::OK();
1917     });
1918 
1919 // --------------------------------------------------------------------------
1920 REGISTER_OP("ExpandDims")
1921     .Input("input: T")
1922     .Input("dim: Tdim")
1923     .Output("output: T")
1924     .Attr("T: type")
1925     .Attr("Tdim: {int32, int64} = DT_INT32")
__anondb9326b22f02(InferenceContext* c) 1926     .SetShapeFn([](InferenceContext* c) {
1927       ShapeHandle input = c->input(0);
1928 
1929       const Tensor* dim_t = c->input_tensor(1);
1930       if (dim_t != nullptr && dim_t->NumElements() != 1) {
1931         return errors::InvalidArgument(
1932             "'dim' input must be a tensor with a single value");
1933       }
1934       if (dim_t == nullptr || !c->RankKnown(input)) {
1935         c->set_output(0, c->UnknownShape());
1936         return Status::OK();
1937       }
1938 
1939       int64 dim;
1940       if (dim_t->dtype() == DT_INT32) {
1941         dim = static_cast<int64>(dim_t->flat<int32>()(0));
1942       } else {
1943         dim = dim_t->flat<int64>()(0);
1944       }
1945 
1946       const int32 rank = c->Rank(input);
1947       const int32 min_dim = -1 * rank - 1;
1948       if (dim < min_dim || dim > rank) {
1949         return errors::InvalidArgument("dim ", dim, " not in the interval [",
1950                                        min_dim, ", ", rank, "].");
1951       }
1952 
1953       if (dim < 0) {
1954         dim += rank + 1;
1955       }
1956 
1957       ShapeHandle end;
1958       TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
1959 
1960       // Build output as start + 1 + end.
1961       ShapeHandle output;
1962       TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
1963       TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
1964       TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
1965       c->set_output(0, output);
1966       return Status::OK();
1967     });
1968 
1969 // --------------------------------------------------------------------------
1970 REGISTER_OP("Squeeze")
1971     .Input("input: T")
1972     .Output("output: T")
1973     .Attr("T: type")
1974     .Attr("squeeze_dims: list(int) >= 0 = []")
__anondb9326b23002(InferenceContext* c) 1975     .SetShapeFn([](InferenceContext* c) {
1976       ShapeHandle input = c->input(0);
1977       if (!c->RankKnown(input)) {
1978         // Input shape unknown.
1979         return shape_inference::UnknownShape(c);
1980       }
1981 
1982       const int32 input_rank = c->Rank(input);
1983 
1984       // Validate and wrap squeeze dimensions.
1985       std::vector<int32> squeeze_dims;
1986       TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
1987       for (int i = 0; i < squeeze_dims.size(); ++i) {
1988         if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
1989           return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
1990                                          -input_rank, ",", input_rank, ").");
1991         }
1992 
1993         if (squeeze_dims[i] < 0) {
1994           squeeze_dims[i] += input_rank;
1995         }
1996       }
1997 
1998       std::vector<DimensionHandle> result_shape;
1999       for (int i = 0; i < input_rank; ++i) {
2000         // True if squeeze_dims contains an entry to squeeze this
2001         // dimension.
2002         bool is_explicit_match =
2003             std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
2004             squeeze_dims.end();
2005 
2006         DimensionHandle dim = c->Dim(input, i);
2007 
2008         if (!c->ValueKnown(dim)) {
2009           // Assume that the squeezed dimension will be 1 at runtime.
2010           if (is_explicit_match) continue;
2011 
2012           // If squeezing all 1 dimensions, and we see an unknown value,
2013           // give up and return Unknown Shape.
2014           if (squeeze_dims.empty()) {
2015             c->set_output(0, c->UnknownShape());
2016             return Status::OK();
2017           }
2018         } else if (c->Value(dim) == 1) {
2019           if (is_explicit_match || squeeze_dims.empty()) {
2020             // If explicitly squeezing, or squeezing all 1s, remove
2021             // this dimension.
2022             continue;
2023           }
2024         } else if (is_explicit_match) {
2025           return errors::InvalidArgument("Can not squeeze dim[", i,
2026                                          "], expected a dimension of 1, got ",
2027                                          c->Value(c->Dim(input, i)));
2028         }
2029 
2030         result_shape.emplace_back(dim);
2031       }
2032 
2033       c->set_output(0, c->MakeShape(result_shape));
2034       return Status::OK();
2035     });
2036 
2037 // --------------------------------------------------------------------------
2038 REGISTER_OP("ListDiff")
2039     .Input("x: T")
2040     .Input("y: T")
2041     .Output("out: T")
2042     .Output("idx: out_idx")
2043     .Attr("T: type")
2044     .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b23102(InferenceContext* c) 2045     .SetShapeFn([](InferenceContext* c) {
2046       ShapeHandle unused;
2047       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
2048       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2049       // TODO(mrry): Indicate that the length falls within an interval?
2050       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
2051       c->set_output(0, out);
2052       c->set_output(1, out);
2053       return Status::OK();
2054     });
2055 
2056 namespace {
2057 
2058 // Converts Tensor to flat std::vector<int64>.
2059 template <typename InputType>
GetFlatInt64(const Tensor & t)2060 std::vector<int64> GetFlatInt64(const Tensor& t) {
2061   std::vector<int64> output(t.shape().num_elements());
2062   auto eigen_vec = t.flat<InputType>();
2063   std::copy_n(&eigen_vec(0), output.size(), output.begin());
2064   return output;
2065 }
2066 
2067 // Converts int32 or int64 Tensor to flat std::vector<int64>.
GetFlatInt64(const Tensor & t)2068 std::vector<int64> GetFlatInt64(const Tensor& t) {
2069   if (t.dtype() == DT_INT32) {
2070     return GetFlatInt64<int32>(t);
2071   } else {
2072     return GetFlatInt64<int64>(t);
2073   }
2074 }
2075 
SpaceToBatchShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle paddings_shape,const Tensor * paddings_t)2076 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2077                                ShapeHandle block_shape_shape,
2078                                const Tensor* block_shape_t,
2079                                ShapeHandle paddings_shape,
2080                                const Tensor* paddings_t) {
2081   if (c->Rank(block_shape_shape) != 1) {
2082     return errors::InvalidArgument("block_shape must have rank 1.");
2083   }
2084 
2085   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2086   if (!c->ValueKnown(num_block_dims_handle)) {
2087     return errors::InvalidArgument("block_shape must have known size.");
2088   }
2089 
2090   const int64 num_block_dims = c->Value(num_block_dims_handle);
2091 
2092   TF_RETURN_IF_ERROR(
2093       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2094 
2095   TF_RETURN_IF_ERROR(
2096       c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
2097 
2098   DimensionHandle batch_size = c->Dim(input_shape, 0);
2099   std::vector<int64> block_shape_vec;
2100   if (block_shape_t) {
2101     block_shape_vec = GetFlatInt64(*block_shape_t);
2102     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2103       const int64 block_shape_value = block_shape_vec[dim];
2104       if (block_shape_value < 1) {
2105         return errors::InvalidArgument("block_shape must be positive");
2106       }
2107       if (c->ValueKnown(batch_size)) {
2108         TF_RETURN_IF_ERROR(
2109             c->Multiply(batch_size, block_shape_value, &batch_size));
2110       } else {
2111         batch_size = c->UnknownDim();
2112       }
2113     }
2114   } else if (num_block_dims > 0) {
2115     batch_size = c->UnknownDim();
2116   }
2117 
2118   std::vector<DimensionHandle> output_dims{batch_size};
2119   output_dims.resize(num_block_dims + 1, c->UnknownDim());
2120 
2121   if (paddings_t) {
2122     const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t);
2123     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2124       const int64 pad_start = paddings_vec[dim * 2],
2125                   pad_end = paddings_vec[dim * 2 + 1];
2126       if (pad_start < 0 || pad_end < 0) {
2127         return errors::InvalidArgument("paddings cannot be negative");
2128       }
2129       if (block_shape_t) {
2130         DimensionHandle padded_size;
2131         TF_RETURN_IF_ERROR(
2132             c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
2133         TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
2134         TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
2135                                      /*evenly_divisible=*/true,
2136                                      &output_dims[dim + 1]));
2137       }
2138     }
2139   }
2140 
2141   ShapeHandle remaining_input_shape;
2142   TF_RETURN_IF_ERROR(
2143       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2144 
2145   ShapeHandle result;
2146   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2147                                     remaining_input_shape, &result));
2148   c->set_output(0, result);
2149   return Status::OK();
2150 }
2151 
BatchToSpaceShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle crops_shape,const Tensor * crops_t)2152 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2153                                ShapeHandle block_shape_shape,
2154                                const Tensor* block_shape_t,
2155                                ShapeHandle crops_shape, const Tensor* crops_t) {
2156   if (c->Rank(block_shape_shape) != 1) {
2157     return errors::InvalidArgument("block_shape must have rank 1.");
2158   }
2159 
2160   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2161   if (!c->ValueKnown(num_block_dims_handle)) {
2162     return errors::InvalidArgument("block_shape must have known size.");
2163   }
2164 
2165   const int64 num_block_dims = c->Value(num_block_dims_handle);
2166 
2167   TF_RETURN_IF_ERROR(
2168       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2169 
2170   TF_RETURN_IF_ERROR(
2171       c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
2172 
2173   DimensionHandle batch_size = c->Dim(input_shape, 0);
2174   std::vector<int64> block_shape_vec;
2175   if (block_shape_t) {
2176     block_shape_vec = GetFlatInt64(*block_shape_t);
2177     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2178       const int64 block_shape_value = block_shape_vec[dim];
2179       if (block_shape_value < 1) {
2180         return errors::InvalidArgument("block_shape must be positive");
2181       }
2182       if (c->ValueKnown(batch_size)) {
2183         TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
2184                                      /*evenly_divisible=*/true, &batch_size));
2185       } else {
2186         batch_size = c->UnknownDim();
2187       }
2188     }
2189   } else if (num_block_dims > 0) {
2190     batch_size = c->UnknownDim();
2191   }
2192 
2193   std::vector<DimensionHandle> output_dims{batch_size};
2194   output_dims.resize(num_block_dims + 1, c->UnknownDim());
2195 
2196   if (crops_t) {
2197     const std::vector<int64> crops_vec = GetFlatInt64(*crops_t);
2198     for (int64 dim = 0; dim < num_block_dims; ++dim) {
2199       const int64 crop_start = crops_vec[dim * 2],
2200                   crop_end = crops_vec[dim * 2 + 1];
2201       if (crop_start < 0 || crop_end < 0) {
2202         return errors::InvalidArgument("crops cannot be negative");
2203       }
2204       if (block_shape_t) {
2205         DimensionHandle cropped_size;
2206         TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
2207                                        block_shape_vec[dim], &cropped_size));
2208         TF_RETURN_IF_ERROR(
2209             c->Subtract(cropped_size, crop_start, &cropped_size));
2210         TF_RETURN_IF_ERROR(
2211             c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
2212       }
2213     }
2214   }
2215 
2216   ShapeHandle remaining_input_shape;
2217   TF_RETURN_IF_ERROR(
2218       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2219 
2220   ShapeHandle result;
2221   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2222                                     remaining_input_shape, &result));
2223   c->set_output(0, result);
2224   return Status::OK();
2225 }
2226 
2227 }  // namespace
2228 
2229 // --------------------------------------------------------------------------
2230 REGISTER_OP("SpaceToBatchND")
2231     .Input("input: T")
2232     .Input("block_shape: Tblock_shape")
2233     .Input("paddings: Tpaddings")
2234     .Output("output: T")
2235     .Attr("T: type")
2236     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2237     .Attr("Tpaddings: {int32, int64} = DT_INT32")
__anondb9326b23302(InferenceContext* c) 2238     .SetShapeFn([](InferenceContext* c) {
2239       return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
2240                                      c->input_tensor(1), c->input(2),
2241                                      c->input_tensor(2));
2242     });
2243 
2244 // --------------------------------------------------------------------------
2245 REGISTER_OP("SpaceToBatch")
2246     .Input("input: T")
2247     .Input("paddings: Tpaddings")
2248     .Output("output: T")
2249     .Attr("T: type")
2250     .Attr("Tpaddings: {int32, int64} = DT_INT32")
2251     .Attr("block_size: int >= 2")
__anondb9326b23402(InferenceContext* c) 2252     .SetShapeFn([](InferenceContext* c) {
2253       ShapeHandle input_shape;
2254       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2255 
2256       int32 block_size;
2257       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2258 
2259       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2260       auto block_shape_vec = block_shape.vec<int64>();
2261       block_shape_vec(0) = block_size;
2262       block_shape_vec(1) = block_size;
2263 
2264       return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
2265                                      &block_shape, c->input(1),
2266                                      c->input_tensor(1));
2267     });
2268 
2269 // --------------------------------------------------------------------------
2270 REGISTER_OP("BatchToSpaceND")
2271     .Input("input: T")
2272     .Input("block_shape: Tblock_shape")
2273     .Input("crops: Tcrops")
2274     .Output("output: T")
2275     .Attr("T: type")
2276     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2277     .Attr("Tcrops: {int32, int64} = DT_INT32")
__anondb9326b23502(InferenceContext* c) 2278     .SetShapeFn([](InferenceContext* c) {
2279       return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
2280                                      c->input_tensor(1), c->input(2),
2281                                      c->input_tensor(2));
2282     });
2283 
2284 // --------------------------------------------------------------------------
2285 REGISTER_OP("BatchToSpace")
2286     .Input("input: T")
2287     .Input("crops: Tidx")
2288     .Output("output: T")
2289     .Attr("T: type")
2290     .Attr("block_size: int >= 2")
2291     .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b23602(InferenceContext* c) 2292     .SetShapeFn([](InferenceContext* c) {
2293       ShapeHandle input_shape;
2294       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2295 
2296       int32 block_size;
2297       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2298 
2299       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2300       auto block_shape_vec = block_shape.vec<int64>();
2301       block_shape_vec(0) = block_size;
2302       block_shape_vec(1) = block_size;
2303 
2304       return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
2305                                      &block_shape, c->input(1),
2306                                      c->input_tensor(1));
2307     });
2308 
2309 // --------------------------------------------------------------------------
2310 REGISTER_OP("SpaceToDepth")
2311     .Input("input: T")
2312     .Output("output: T")
2313     .Attr("T: type")
2314     .Attr("block_size: int >= 2")
2315     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2316     // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
__anondb9326b23702(InferenceContext* c) 2317     .SetShapeFn([](InferenceContext* c) {
2318       string data_format_str;
2319       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2320       TensorFormat data_format;
2321       FormatFromString(data_format_str, &data_format);
2322 
2323       constexpr int num_spatial_dims = 2;
2324       const int dims =
2325           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2326       ShapeHandle input;
2327       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2328 
2329       int32 block_size;
2330       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2331 
2332       DimensionHandle batch_size =
2333           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2334       DimensionHandle input_height =
2335           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2336       DimensionHandle input_width =
2337           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2338       DimensionHandle input_depth =
2339           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2340 
2341       DimensionHandle output_height;
2342       DimensionHandle output_width;
2343       DimensionHandle output_depth;
2344       // Will return an error if input height or width are not evenly divisible.
2345       TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
2346                                    true /* evenly_divisible */,
2347                                    &output_height));
2348       TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
2349                                    true /* evenly_divisible */, &output_width));
2350 
2351       TF_RETURN_IF_ERROR(
2352           c->Multiply(input_depth, block_size * block_size, &output_depth));
2353 
2354       ShapeHandle output_shape;
2355       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2356                                              {output_height, output_width},
2357                                              output_depth, &output_shape, c));
2358 
2359       c->set_output(0, output_shape);
2360       return Status::OK();
2361     });
2362 
2363 // --------------------------------------------------------------------------
2364 REGISTER_OP("DepthToSpace")
2365     .Input("input: T")
2366     .Output("output: T")
2367     .Attr("T: type")
2368     .Attr("block_size: int >= 2")
2369     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2370     // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
__anondb9326b23802(InferenceContext* c) 2371     .SetShapeFn([](InferenceContext* c) {
2372       string data_format_str;
2373       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2374       TensorFormat data_format;
2375       FormatFromString(data_format_str, &data_format);
2376 
2377       constexpr int num_spatial_dims = 2;
2378       const int dims =
2379           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2380 
2381       ShapeHandle input;
2382       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2383 
2384       int32 block_size;
2385       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2386 
2387       DimensionHandle batch_size =
2388           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2389       DimensionHandle input_height =
2390           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2391       DimensionHandle input_width =
2392           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2393       DimensionHandle input_depth =
2394           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2395 
2396       DimensionHandle output_height;
2397       DimensionHandle output_width;
2398       DimensionHandle output_depth;
2399       TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
2400       TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
2401 
2402       // Will return an error if input_depth is not evenly divisible.
2403       TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
2404                                    true /* evenly_divisible */, &output_depth));
2405 
2406       ShapeHandle output_shape;
2407       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2408                                              {output_height, output_width},
2409                                              output_depth, &output_shape, c));
2410 
2411       c->set_output(0, output_shape);
2412       return Status::OK();
2413     });
2414 
2415 // --------------------------------------------------------------------------
2416 
2417 REGISTER_OP("ExtractImagePatches")
2418     .Input("images: T")
2419     .Output("patches: T")
2420     .Attr("ksizes: list(int) >= 4")
2421     .Attr("strides: list(int) >= 4")
2422     .Attr("rates: list(int) >= 4")
2423     .Attr("T: realnumbertype")
2424     .Attr(GetPaddingAttrString())
__anondb9326b23902(InferenceContext* c) 2425     .SetShapeFn([](InferenceContext* c) {
2426       ShapeHandle input_shape;
2427       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2428 
2429       std::vector<int32> ksizes;
2430       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2431       if (ksizes.size() != 4) {
2432         return errors::InvalidArgument(
2433             "ExtractImagePatches requires the ksizes attribute to contain 4 "
2434             "values, but got: ",
2435             ksizes.size());
2436       }
2437 
2438       std::vector<int32> strides;
2439       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2440       if (strides.size() != 4) {
2441         return errors::InvalidArgument(
2442             "ExtractImagePatches requires the stride attribute to contain 4 "
2443             "values, but got: ",
2444             strides.size());
2445       }
2446 
2447       std::vector<int32> rates;
2448       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2449       if (rates.size() != 4) {
2450         return errors::InvalidArgument(
2451             "ExtractImagePatches requires the rates attribute to contain 4 "
2452             "values, but got: ",
2453             rates.size());
2454       }
2455 
2456       int32 ksize_rows = ksizes[1];
2457       int32 ksize_cols = ksizes[2];
2458 
2459       int32 stride_rows = strides[1];
2460       int32 stride_cols = strides[2];
2461 
2462       int32 rate_rows = rates[1];
2463       int32 rate_cols = rates[2];
2464 
2465       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2466       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2467 
2468       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2469       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
2470       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
2471       DimensionHandle output_depth_dim;
2472       TF_RETURN_IF_ERROR(c->Multiply(
2473           c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
2474 
2475       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
2476         ShapeHandle output_shape =
2477             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2478                           InferenceContext::kUnknownDim, output_depth_dim});
2479         c->set_output(0, output_shape);
2480         return Status::OK();
2481       }
2482       auto in_rows = c->Value(in_rows_dim);
2483       auto in_cols = c->Value(in_cols_dim);
2484 
2485       Padding padding;
2486       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2487 
2488       int64 output_rows, output_cols;
2489       int64 padding_before, padding_after;
2490       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2491           in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2492           &padding_before, &padding_after));
2493       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2494           in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2495           &padding_before, &padding_after));
2496       ShapeHandle output_shape = c->MakeShape(
2497           {batch_size_dim, output_rows, output_cols, output_depth_dim});
2498       c->set_output(0, output_shape);
2499       return Status::OK();
2500     });
2501 
2502 // --------------------------------------------------------------------------
2503 
2504 // To enable rates, uncomment all lines commented below and use ksize_*_eff
2505 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead
2506 // of ksize_*.
2507 REGISTER_OP("ExtractVolumePatches")
2508     .Input("input: T")
2509     .Output("patches: T")
2510     .Attr("ksizes: list(int) >= 5")
2511     .Attr("strides: list(int) >= 5")
2512     /* .Attr("rates: list(int) >= 5") */
2513     .Attr("T: realnumbertype")
2514     .Attr(GetPaddingAttrString())
__anondb9326b23a02(InferenceContext* c) 2515     .SetShapeFn([](InferenceContext* c) {
2516       ShapeHandle input_shape;
2517       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
2518 
2519       std::vector<int32> ksizes;
2520       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2521       if (ksizes.size() != 5) {
2522         return errors::InvalidArgument(
2523             "ExtractVolumePatches requires the ksizes attribute to contain 5 "
2524             "values, but got: ",
2525             ksizes.size());
2526       }
2527 
2528       std::vector<int32> strides;
2529       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2530       if (strides.size() != 5) {
2531         return errors::InvalidArgument(
2532             "ExtractVolumePatches requires the stride attribute to contain 5 "
2533             "values, but got: ",
2534             strides.size());
2535       }
2536 
2537       /*
2538       // TODO(hsgkim): Enable rates.
2539       // See extract_volume_patches_op.cc for why rates are disabled now.
2540 
2541       std::vector<int32> rates;
2542       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2543       if (rates.size() != 5) {
2544         return errors::InvalidArgument(
2545             "ExtractVolumePatches requires the rates attribute to contain 5 "
2546             "values, but got: ",
2547             rates.size());
2548       }
2549       */
2550 
2551       int32 ksize_planes = ksizes[1];
2552       int32 ksize_rows = ksizes[2];
2553       int32 ksize_cols = ksizes[3];
2554 
2555       int32 stride_planes = strides[1];
2556       int32 stride_rows = strides[2];
2557       int32 stride_cols = strides[3];
2558 
2559       /*
2560       int32 rate_planes = rates[1];
2561       int32 rate_rows = rates[2];
2562       int32 rate_cols = rates[3];
2563 
2564       int32 ksize_planes_eff = ksize_planes +
2565                                (ksize_planes - 1) * (rate_planes - 1);
2566       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2567       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2568       */
2569 
2570       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2571       DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
2572       DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
2573       DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
2574       DimensionHandle output_depth_dim;
2575       TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
2576                                      ksize_planes * ksize_rows * ksize_cols,
2577                                      &output_depth_dim));
2578 
2579       if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
2580           !c->ValueKnown(in_cols_dim)) {
2581         ShapeHandle output_shape =
2582             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2583                           InferenceContext::kUnknownDim, output_depth_dim});
2584         c->set_output(0, output_shape);
2585         return Status::OK();
2586       }
2587       auto in_planes = c->Value(in_planes_dim);
2588       auto in_rows = c->Value(in_rows_dim);
2589       auto in_cols = c->Value(in_cols_dim);
2590 
2591       Padding padding;
2592       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2593 
2594       int64 output_planes, output_rows, output_cols;
2595       int64 padding_before, padding_after;
2596       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2597           in_planes, ksize_planes, stride_planes, padding, &output_planes,
2598           &padding_before, &padding_after));
2599       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2600           in_rows, ksize_rows, stride_rows, padding, &output_rows,
2601           &padding_before, &padding_after));
2602       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2603           in_cols, ksize_cols, stride_cols, padding, &output_cols,
2604           &padding_before, &padding_after));
2605       ShapeHandle output_shape =
2606           c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
2607                         output_depth_dim});
2608       c->set_output(0, output_shape);
2609       return Status::OK();
2610     });
2611 
2612 // --------------------------------------------------------------------------
2613 
2614 REGISTER_OP("Bitcast")
2615     .Input("input: T")
2616     .Output("output: type")
2617     // All supported dtypes are listed here to include qint16, quint16, uint32,
2618     // and uint64.
2619     .Attr(
2620         "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
2621         "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
2622         "qint16, quint16, qint32}")
2623     .Attr(
2624         "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
2625         "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
2626         "qint16, quint16, qint32}")
__anondb9326b23b02(InferenceContext* c) 2627     .SetShapeFn([](InferenceContext* c) {
2628       ShapeHandle input = c->input(0);
2629       if (!c->RankKnown(input)) {
2630         // Input shape unknown.
2631         return shape_inference::UnknownShape(c);
2632       }
2633 
2634       // Find the size of the input and output data types.
2635       DataType input_type;
2636       DataType output_type;
2637       TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type));
2638       TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type));
2639       const int input_type_size = DataTypeSize(input_type);
2640       const int output_type_size = DataTypeSize(output_type);
2641 
2642       if (input_type_size == 0 || output_type_size == 0) {
2643         return errors::InvalidArgument("Cannot bitcast types ",
2644                                        DataTypeString(input_type), " to ",
2645                                        DataTypeString(output_type),
2646                                        " because "
2647                                        "one of the type sizes is zero.");
2648       }
2649 
2650       ShapeHandle new_shape;
2651       if (input_type_size == output_type_size) {
2652         // No change in size.
2653         new_shape = input;
2654       } else if (input_type_size < output_type_size) {
2655         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));
2656 
2657         int64 divisor_val = output_type_size / input_type_size;
2658         DimensionHandle last_dim = c->Dim(new_shape, -1);
2659         if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
2660           TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
2661         } else {
2662           return errors::InvalidArgument("Cannot bitcast due to shape. ",
2663                                          c->Value(last_dim), " does not match ",
2664                                          divisor_val);
2665         }
2666       } else {
2667         // Input type size is larger than output type size.
2668         int64 divisor_val = input_type_size / output_type_size;
2669         ShapeHandle extension = c->Vector(divisor_val);
2670         TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
2671       }
2672 
2673       c->set_output(0, new_shape);
2674       return Status::OK();
2675     });
2676 
2677 REGISTER_OP("OneHot")
2678     .Input("indices: TI")
2679     .Input("depth: int32")
2680     .Input("on_value: T")
2681     .Input("off_value: T")
2682     .Attr("axis: int = -1")
2683     .Output("output: T")
2684     .Attr("T: type")
2685     .Attr("TI: {uint8, int32, int64} = DT_INT64")
__anondb9326b23c02(InferenceContext* c) 2686     .SetShapeFn([](InferenceContext* c) {
2687       int32 axis;
2688       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2689       if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
2690 
2691       DimensionHandle depth;
2692       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
2693 
2694       ShapeHandle indices = c->input(0);
2695       if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
2696 
2697       int32 new_rank = c->Rank(indices) + 1;
2698       // We need to add new_rank to axis in the case the axis is -1 because
2699       // C++ returns negative values from % if the dividend is negative.
2700       int32 depth_index = (axis + new_rank) % new_rank;
2701       // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
2702       ShapeHandle front;
2703       ShapeHandle back;
2704       ShapeHandle out;
2705       TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
2706       TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
2707       TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
2708       TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
2709       c->set_output(0, out);
2710       return Status::OK();
2711     });
2712 
2713 // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
// First version of QuantizeAndDequantize. The quantization range is supplied
// through the float attrs `input_min` / `input_max` (used when `range_given`
// is true) rather than through tensor inputs. `output` has the same shape as
// `input`. Marked deprecated at version 22 in favor of QuantizeAndDequantizeV2.
REGISTER_OP("QuantizeAndDequantize")
    .Input("input: T")
    .Attr("signed_input: bool = true")
    .Attr("num_bits: int = 8")
    .Attr("range_given: bool = false")
    .Attr("input_min: float = 0")
    .Attr("input_max: float = 0")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
2725 
2726 // TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
2727 REGISTER_OP("QuantizeAndDequantizeV2")
2728     .Input("input: T")
2729     .Input("input_min: T")
2730     .Input("input_max: T")
2731     .Attr("signed_input: bool = true")
2732     .Attr("num_bits: int = 8")
2733     .Attr("range_given: bool = false")
2734     .Output("output: T")
2735     .Attr("T: {bfloat16, half, float, double}")
2736     .Attr(
2737         "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2738         "'HALF_TO_EVEN'")
__anondb9326b23d02(InferenceContext* c) 2739     .SetShapeFn([](InferenceContext* c) {
2740       ShapeHandle unused;
2741       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2742       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2743       c->set_output(0, c->input(0));
2744       return Status::OK();
2745     });
2746 
2747 REGISTER_OP("QuantizeAndDequantizeV3")
2748     .Input("input: T")
2749     .Input("input_min: T")
2750     .Input("input_max: T")
2751     .Input("num_bits: int32")
2752     .Attr("signed_input: bool = true")
2753     .Attr("range_given: bool = true")
2754     .Output("output: T")
2755     .Attr("T: {bfloat16, half, float, double}")
__anondb9326b23e02(InferenceContext* c) 2756     .SetShapeFn([](InferenceContext* c) {
2757       ShapeHandle unused;
2758       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2759       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2760       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2761       c->set_output(0, c->input(0));
2762       return Status::OK();
2763     });
2764 
2765 REGISTER_OP("QuantizeV2")
2766     .Input("input: float")
2767     .Input("min_range: float")
2768     .Input("max_range: float")
2769     .Output("output: T")
2770     .Output("output_min: float")
2771     .Output("output_max: float")
2772     .Attr("T: quantizedtype")
2773     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
2774     .Attr(
2775         "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
2776         "'HALF_AWAY_FROM_ZERO'")
__anondb9326b23f02(InferenceContext* c) 2777     .SetShapeFn([](InferenceContext* c) {
2778       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2779       ShapeHandle unused;
2780       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2781       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2782       c->set_output(1, c->Scalar());
2783       c->set_output(2, c->Scalar());
2784       return Status::OK();
2785     });
2786 
2787 REGISTER_OP("Dequantize")
2788     .Input("input: T")
2789     .Input("min_range: float")
2790     .Input("max_range: float")
2791     .Output("output: float")
2792     .Attr("T: quantizedtype")
2793     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
__anondb9326b24002(InferenceContext* c) 2794     .SetShapeFn([](InferenceContext* c) {
2795       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2796       ShapeHandle unused;
2797       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2798       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2799       return Status::OK();
2800     });
2801 
2802 REGISTER_OP("QuantizedConcat")
2803     .Input("concat_dim: int32")
2804     .Input("values: N * T")
2805     .Input("input_mins: N * float32")
2806     .Input("input_maxes: N * float32")
2807     .Output("output: T")
2808     .Output("output_min: float")
2809     .Output("output_max: float")
2810     .Attr("N: int >= 2")
2811     .Attr("T: type")
__anondb9326b24102(InferenceContext* c) 2812     .SetShapeFn([](InferenceContext* c) {
2813       const int n = (c->num_inputs() - 1) / 3;
2814       TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
2815       ShapeHandle unused;
2816       for (int i = n + 1; i < c->num_inputs(); ++i) {
2817         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
2818       }
2819       c->set_output(1, c->Scalar());
2820       c->set_output(2, c->Scalar());
2821       return Status::OK();
2822     });
2823 
2824 REGISTER_OP("QuantizedReshape")
2825     .Input("tensor: T")
2826     .Input("shape: Tshape")
2827     .Input("input_min: float")
2828     .Input("input_max: float")
2829     .Output("output: T")
2830     .Output("output_min: float")
2831     .Output("output_max: float")
2832     .Attr("T: type")
2833     .Attr("Tshape: {int32, int64} = DT_INT32")
__anondb9326b24202(InferenceContext* c) 2834     .SetShapeFn([](InferenceContext* c) {
2835       TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
2836       ShapeHandle unused;
2837       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2838       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2839       c->set_output(1, c->Scalar());
2840       c->set_output(2, c->Scalar());
2841       return Status::OK();
2842     });
2843 
2844 REGISTER_OP("QuantizedInstanceNorm")
2845     .Input("x: T")
2846     .Input("x_min: float")
2847     .Input("x_max: float")
2848     .Output("y: T")
2849     .Output("y_min: float")
2850     .Output("y_max: float")
2851     .Attr("T: quantizedtype")
2852     .Attr("output_range_given: bool = false")
2853     .Attr("given_y_min: float = 0")
2854     .Attr("given_y_max: float = 0")
2855     .Attr("variance_epsilon: float = 1e-5")
2856     .Attr("min_separation: float = 1e-3")
__anondb9326b24302(shape_inference::InferenceContext* c) 2857     .SetShapeFn([](shape_inference::InferenceContext* c) {
2858       shape_inference::ShapeHandle unused;
2859       // x should be a rank 4 tensor.
2860       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
2861       // Assert x_min and x_max are scalars (rank 0).
2862       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2863       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2864       // y has the same shape as x.
2865       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2866       // y_min and y_max are scalars.
2867       c->set_output(1, c->Scalar());
2868       c->set_output(2, c->Scalar());
2869       return Status::OK();
2870     });
2871 
2872 namespace {
2873 
ScatterNdShapeHelper(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle updates_shape,ShapeHandle output_shape)2874 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
2875                             ShapeHandle updates_shape,
2876                             ShapeHandle output_shape) {
2877   if (c->Value(c->NumElements(output_shape)) == 0 &&
2878       (c->Value(c->NumElements(indices_shape)) > 0 ||
2879        c->Value(c->NumElements(updates_shape)) > 0)) {
2880     return errors::InvalidArgument(
2881         "Indices and updates specified for empty output shape");
2882   }
2883 
2884   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
2885     const int64 outer_dims = c->Rank(indices_shape) - 1;
2886     const DimensionHandle ixdim = c->Dim(indices_shape, -1);
2887 
2888     // We can only do more validation if the last dimension of indices
2889     // is a known value.
2890     if (c->ValueKnown(ixdim)) {
2891       int64 ix = c->Value(ixdim);
2892       ShapeHandle unused;
2893       ShapeHandle prefix_indices;
2894       TF_RETURN_IF_ERROR(
2895           c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
2896       ShapeHandle prefix_updates;
2897       TF_RETURN_IF_ERROR(
2898           c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
2899 
2900       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
2901       if (!s.ok()) {
2902         return errors::InvalidArgument(
2903             "The outer ", outer_dims,
2904             " dimensions of indices.shape=", c->DebugString(indices_shape),
2905             " must match the outer ", outer_dims,
2906             " dimensions of updates.shape=", c->DebugString(updates_shape),
2907             ": ", s.error_message());
2908       }
2909 
2910       ShapeHandle suffix_output;
2911       TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output));
2912       ShapeHandle suffix_updates;
2913       TF_RETURN_IF_ERROR(
2914           c->Subshape(updates_shape, outer_dims, &suffix_updates));
2915       s = c->Merge(suffix_output, suffix_updates, &unused);
2916       if (!s.ok()) {
2917         return errors::InvalidArgument(
2918             "The inner ", c->Rank(output_shape) - ix,
2919             " dimensions of output.shape=", c->DebugString(output_shape),
2920             " must match the inner ", c->Rank(updates_shape) - outer_dims,
2921             " dimensions of updates.shape=", c->DebugString(updates_shape),
2922             ": ", s.error_message());
2923       }
2924     }
2925   }
2926 
2927   c->set_output(0, output_shape);
2928   return Status::OK();
2929 }
2930 
ScatterNdShape(InferenceContext * c)2931 Status ScatterNdShape(InferenceContext* c) {
2932   ShapeHandle indices_shape;
2933   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
2934   ShapeHandle updates_shape;
2935   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
2936   ShapeHandle output_shape;
2937   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
2938   return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
2939 }
2940 
ScatterNdTensorShape(InferenceContext * c)2941 Status ScatterNdTensorShape(InferenceContext* c) {
2942   ShapeHandle output_shape;
2943   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
2944   ShapeHandle indices_shape;
2945   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
2946   ShapeHandle updates_shape;
2947   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
2948   return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
2949 }
2950 
2951 }  // namespace
2952 
2953 REGISTER_OP("UpperBound")
2954     .Input("sorted_inputs: T")
2955     .Input("values: T")
2956     .Output("output: out_type")
2957     .Attr("T: type")
2958     .Attr("out_type: {int32, int64} = DT_INT32")
__anondb9326b24502(InferenceContext* c) 2959     .SetShapeFn([](InferenceContext* c) {
2960       ShapeHandle unused_shape;
2961       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
2962       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
2963       c->set_output(0, c->input(1));
2964       return Status::OK();
2965     });
2966 
2967 REGISTER_OP("LowerBound")
2968     .Input("sorted_inputs: T")
2969     .Input("values: T")
2970     .Output("output: out_type")
2971     .Attr("T: type")
2972     .Attr("out_type: {int32, int64} = DT_INT32")
__anondb9326b24602(InferenceContext* c) 2973     .SetShapeFn([](InferenceContext* c) {
2974       ShapeHandle unused_shape;
2975       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
2976       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
2977       c->set_output(0, c->input(1));
2978       return Status::OK();
2979     });
2980 
// Shape function (ScatterNdShape) summary: the output shape is read from the
// value of the `shape` input (input 2). `indices` and `updates` must have
// rank >= 1 and are cross-checked against that output shape by
// ScatterNdShapeHelper.
REGISTER_OP("ScatterNd")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Input("shape: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdShape);
2989 
// Shape function (ScatterNdTensorShape) summary: `output` has the shape of
// `tensor` (input 0). All three inputs need rank >= 1, and `indices` /
// `updates` are cross-checked against `tensor`'s shape by
// ScatterNdShapeHelper.
REGISTER_OP("TensorScatterUpdate")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
2998 
// Same signature and shape function (ScatterNdTensorShape) as
// TensorScatterUpdate: `output` has the shape of `tensor` (input 0).
REGISTER_OP("TensorScatterAdd")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
3007 
// Same signature and shape function (ScatterNdTensorShape) as
// TensorScatterUpdate: `output` has the shape of `tensor` (input 0).
REGISTER_OP("TensorScatterSub")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
3016 
// Unlike the TensorScatter* ops above, this one restricts T to numeric types
// plus bool. It uses the shared shape_inference::ScatterNdUpdateShape helper
// rather than ScatterNdTensorShape (NOTE(review): that helper is defined
// outside this chunk, presumably in common_shape_fns; confirm its exact rules).
REGISTER_OP("ScatterNdNonAliasingAdd")
    .Input("input: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
3025 
// Fake-quantization variant whose range (`min`, `max`), `num_bits` and
// `narrow_range` are all attrs rather than tensor inputs. `outputs` has the
// same shape as `inputs` (UnchangedShape).
REGISTER_OP("FakeQuantWithMinMaxArgs")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Output("outputs: float")
    .SetShapeFn(shape_inference::UnchangedShape);
3034 
// Gradient counterpart of FakeQuantWithMinMaxArgs, carrying the same attrs.
// `backprops` takes the shape of input 0 (`gradients`) via UnchangedShape.
// The shape of `inputs` is not cross-checked here.
REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Output("backprops: float")
    .SetShapeFn(shape_inference::UnchangedShape);
3044 
3045 REGISTER_OP("FakeQuantWithMinMaxVars")
3046     .Attr("num_bits: int = 8")
3047     .Attr("narrow_range: bool = false")
3048     .Input("inputs: float")
3049     .Input("min: float")
3050     .Input("max: float")
3051     .Output("outputs: float")
__anondb9326b24702(InferenceContext* c) 3052     .SetShapeFn([](InferenceContext* c) {
3053       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3054       ShapeHandle unused;
3055       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3056       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3057       return Status::OK();
3058     });
3059 
3060 REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
3061     .Attr("num_bits: int = 8")
3062     .Attr("narrow_range: bool = false")
3063     .Input("gradients: float")
3064     .Input("inputs: float")
3065     .Input("min: float")
3066     .Input("max: float")
3067     .Output("backprops_wrt_input: float")
3068     .Output("backprop_wrt_min: float")
3069     .Output("backprop_wrt_max: float")
__anondb9326b24802(InferenceContext* c) 3070     .SetShapeFn([](InferenceContext* c) {
3071       // gradients and inputs are same size.
3072       ShapeHandle inputs;
3073       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
3074 
3075       // min and max are scalars
3076       ShapeHandle min_max;
3077       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
3078       TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
3079 
3080       c->set_output(0, inputs);
3081       c->set_output(1, min_max);
3082       c->set_output(2, min_max);
3083       return Status::OK();
3084     });
3085 
3086 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
3087     .Attr("num_bits: int = 8")
3088     .Attr("narrow_range: bool = false")
3089     .Input("inputs: float")
3090     .Input("min: float")
3091     .Input("max: float")
3092     .Output("outputs: float")
__anondb9326b24902(InferenceContext* c) 3093     .SetShapeFn([](InferenceContext* c) {
3094       ShapeHandle input, min, max;
3095       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
3096       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
3097       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
3098 
3099       DimensionHandle unused;
3100       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
3101       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
3102       TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));
3103 
3104       c->set_output(0, input);
3105       return Status::OK();
3106     });
3107 
3108 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
3109     .Attr("num_bits: int = 8")
3110     .Attr("narrow_range: bool = false")
3111     .Input("gradients: float")
3112     .Input("inputs: float")
3113     .Input("min: float")
3114     .Input("max: float")
3115     .Output("backprops_wrt_input: float")
3116     .Output("backprop_wrt_min: float")
3117     .Output("backprop_wrt_max: float")
__anondb9326b24a02(InferenceContext* c) 3118     .SetShapeFn([](InferenceContext* c) {
3119       ShapeHandle inputs;
3120       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
3121       TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
3122       TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
3123 
3124       ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
3125 
3126       ShapeHandle min_max;
3127       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
3128       TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
3129       TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
3130 
3131       c->set_output(0, inputs);
3132       c->set_output(1, min_max);
3133       c->set_output(2, min_max);
3134       return Status::OK();
3135     });
3136 
#ifdef INTEL_MKL
// _MklConcat: MKL variant of Concat, inserted by a graph rewrite pass.
// Input layout (2 * N + 2 inputs in total):
//   [0]            concat_dim
//   [1, N]         values (the data tensors to concatenate)
//   [N + 1]        mkl_concat_dim (MKL metadata for concat_dim)
//   [N + 2, 2N+1]  mkl_values (MKL metadata for each value)
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // Only the N `values` inputs take part in the concatenation. Since
      // num_inputs() is 2 * N + 2 here, the count cannot be derived as
      // "num_inputs() - k" the way plain Concat does. Deriving it that way
      // would pull the uint8 MKL metadata inputs into the shape computation.
      // Read N directly instead.
      int32 num_values;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &num_values));
      return shape_inference::ConcatShape(c, num_values);
    })
    .Doc(R"doc(
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
3157 
3158 // Deprecated op registrations:
3159 
3160 // The following can be deleted after 10mar2017.
3161 REGISTER_OP("BatchMatrixDiag")
3162     .Input("diagonal: T")
3163     .Output("output: T")
3164     .Attr("T: type")
3165     .Deprecated(14, "Use MatrixDiag")
3166     .SetShapeFn(shape_inference::UnknownShape);
3167 REGISTER_OP("BatchMatrixSetDiag")
3168     .Input("input: T")
3169     .Input("diagonal: T")
3170     .Output("output: T")
3171     .Attr("T: type")
3172     .Deprecated(14, "Use MatrixSetDiag")
3173     .SetShapeFn(shape_inference::UnknownShape);
3174 REGISTER_OP("BatchMatrixDiagPart")
3175     .Input("input: T")
3176     .Output("diagonal: T")
3177     .Attr("T: type")
3178     .Deprecated(14, "Use MatrixDiagPart")
3179     .SetShapeFn(shape_inference::UnknownShape);
3180 REGISTER_OP("BatchMatrixBandPart")
3181     .Input("input: T")
3182     .Input("num_lower: int64")
3183     .Input("num_upper: int64")
3184     .Output("band: T")
3185     .Attr("T: type")
3186     .Deprecated(14, "Use MatrixBandPart")
3187     .SetShapeFn(shape_inference::UnknownShape);
3188 
3189 }  // namespace tensorflow
3190