1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <ostream>
18
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/kernel_shape_util.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/util/mirror_pad_mode.h"
28 #include "tensorflow/core/util/padding.h"
29 #include "tensorflow/core/util/strided_slice_op.h"
30 #include "tensorflow/core/util/tensor_format.h"
31
32 namespace tensorflow {
33
34 using shape_inference::DimensionHandle;
35 using shape_inference::InferenceContext;
36 using shape_inference::ShapeHandle;
37 using shape_inference::UnchangedShape;
38
39 namespace {
40
GetAxisForPackAndUnpack(InferenceContext * c,int32 rank_after_pack,int32 * axis)41 Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
42 int32* axis) {
43 TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
44 if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
45 return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
46 -1 * rank_after_pack, ",", rank_after_pack,
47 ")");
48 }
49 if (*axis < 0) *axis = (rank_after_pack + *axis);
50 return Status::OK();
51 }
52
53 template <typename T>
AsInt64(const Tensor * tensor,int64 num_elements)54 std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) {
55 std::vector<int64> ret(num_elements);
56 auto data = tensor->vec<T>();
57 for (int64 i = 0; i < num_elements; ++i) {
58 ret[i] = data(i);
59 }
60 return ret;
61 }
62
63 template <typename T>
PadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64 num_dims)64 Status PadKnown(InferenceContext* c, ShapeHandle input,
65 const Tensor* paddings_t, int64 num_dims) {
66 // paddings_t is known.
67 std::vector<DimensionHandle> dims(num_dims);
68 auto paddings_data = paddings_t->matrix<T>();
69 for (int64 i = 0; i < num_dims; ++i) {
70 const T pad0 = paddings_data(i, 0);
71 const T pad1 = paddings_data(i, 1);
72 if (pad0 < 0 || pad1 < 0) {
73 return errors::InvalidArgument("Paddings must be non-negative");
74 }
75 TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
76 }
77 c->set_output(0, c->MakeShape(dims));
78 return Status::OK();
79 }
80
PadShapeFn(InferenceContext * c)81 Status PadShapeFn(InferenceContext* c) {
82 // Paddings is a matrix of [input_rank, 2].
83 ShapeHandle paddings;
84 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
85 DimensionHandle unused;
86 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
87
88 // n_dim and input.rank are equivalent.
89 ShapeHandle input = c->input(0);
90 DimensionHandle n_dim = c->Dim(paddings, 0);
91 if (c->ValueKnown(n_dim)) {
92 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
93 } else if (c->RankKnown(input)) {
94 TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
95 }
96
97 const Tensor* paddings_t = c->input_tensor(1);
98
99 // paddings_t is unknown
100 if (paddings_t == nullptr) {
101 if (c->ValueKnown(n_dim)) {
102 // Make output with n_dim unknown dims.
103 c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
104 } else {
105 c->set_output(0, c->UnknownShape());
106 }
107 return Status::OK();
108 }
109
110 const int64 num_dims = paddings_t->shape().dim_size(0);
111 TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
112 TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
113
114 if (paddings_t->dtype() == DT_INT32) {
115 return PadKnown<int32>(c, input, paddings_t, num_dims);
116 } else {
117 return PadKnown<int64>(c, input, paddings_t, num_dims);
118 }
119 }
120
TransposeShapeFn(InferenceContext * c)121 Status TransposeShapeFn(InferenceContext* c) {
122 ShapeHandle input = c->input(0);
123 ShapeHandle perm_shape = c->input(1);
124 const Tensor* perm = c->input_tensor(1);
125 DimensionHandle perm_elems = c->NumElements(perm_shape);
126 // If we don't have rank information on the input or value information on
127 // perm we can't return any shape information, otherwise we have enough
128 // information to at least find the rank of the output.
129 if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
130 c->set_output(0, c->UnknownShape());
131 return Status::OK();
132 }
133
134 // Find our value of the rank.
135 int64 rank;
136 if (c->RankKnown(input)) {
137 rank = c->Rank(input);
138 } else if (c->ValueKnown(perm_elems)) {
139 rank = c->Value(perm_elems);
140 } else {
141 rank = perm->NumElements();
142 }
143 if (!c->RankKnown(input) && rank < 2) {
144 // A permutation array containing a single element is ambiguous. It could
145 // indicate either a scalar or a 1-dimensional array, both of which the
146 // transpose op returns unchanged.
147 c->set_output(0, input);
148 return Status::OK();
149 }
150
151 std::vector<DimensionHandle> dims;
152 dims.resize(rank);
153 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
154 // Ensure that perm is a vector and has rank elements.
155 TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
156 TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
157
158 // If we know the rank of the input and the value of perm, we can return
159 // all shape information, otherwise we can only return rank information,
160 // but no information for the dimensions.
161 if (perm != nullptr) {
162 std::vector<int64> data;
163 if (perm->dtype() == DT_INT32) {
164 data = AsInt64<int32>(perm, rank);
165 } else {
166 data = AsInt64<int64>(perm, rank);
167 }
168
169 for (int32 i = 0; i < rank; ++i) {
170 int64 in_idx = data[i];
171 if (in_idx >= rank) {
172 return errors::InvalidArgument("perm dim ", in_idx,
173 " is out of range of input rank ", rank);
174 }
175 dims[i] = c->Dim(input, in_idx);
176 }
177 } else {
178 for (int i = 0; i < rank; ++i) {
179 dims[i] = c->UnknownDim();
180 }
181 }
182
183 c->set_output(0, c->MakeShape(dims));
184 return Status::OK();
185 }
186
// Shape function for Reshape-style ops. Input 0 is the tensor being reshaped;
// input 1 is the requested shape, read via MakeShapeFromShapeTensor.
//
// The output starts as the requested shape. When both ranks are known, it is
// refined using element counts, case by case:
//   * no unknown dims on either side: the element counts must match;
//   * one unknown output dim, input fully known: the output dim is inferred
//     (only when the known output product is > 0);
//   * one unknown input dim, output fully known: the input dim is checked
//     for consistency (only when the known input product is != 0). The merged
//     result is not propagated anywhere;
//   * one unknown dim on each side: if the known products are equal, the
//     unknown output dim is replaced with the unknown input dim.
// More than one unknown dim on either side disables refinement.
Status SetOutputShapeForReshape(InferenceContext* c) {
  ShapeHandle in = c->input(0);
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));

  if (!c->RankKnown(out)) {
    // We have no information about the shape of the output.
    c->set_output(0, out);
    return Status::OK();
  }
  if (c->RankKnown(in)) {
    // We don't know the number of output elements, but we can try to infer
    // the missing dimension.
    bool too_many_unknown = false;
    int32 out_unknown_idx = -1;

    // known_out_elems: product of the known output dims.
    // out_unknown_idx: index of the single unknown output dim, if any.
    DimensionHandle known_out_elems = c->NumElements(out);
    if (!c->ValueKnown(known_out_elems)) {
      known_out_elems = c->MakeDim(1);
      for (int32 i = 0; i < c->Rank(out); ++i) {
        DimensionHandle dim = c->Dim(out, i);
        if (!c->ValueKnown(dim)) {
          if (out_unknown_idx >= 0) {
            too_many_unknown = true;
            break;
          }
          out_unknown_idx = i;
        } else {
          TF_RETURN_IF_ERROR(
              c->Multiply(known_out_elems, dim, &known_out_elems));
        }
      }
    }
    // Same bookkeeping for the input shape.
    int32 in_unknown_idx = -1;
    DimensionHandle known_in_elems = c->NumElements(in);
    if (!c->ValueKnown(known_in_elems)) {
      known_in_elems = c->MakeDim(1);
      for (int32 i = 0; i < c->Rank(in); ++i) {
        DimensionHandle dim = c->Dim(in, i);
        if (!c->ValueKnown(dim)) {
          if (in_unknown_idx >= 0) {
            too_many_unknown = true;
            break;
          }
          in_unknown_idx = i;
        } else {
          TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
        }
      }
    }

    if (!too_many_unknown) {
      if (in_unknown_idx < 0 && out_unknown_idx < 0) {
        // Just check that the dimensions match.
        if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
          return errors::InvalidArgument(
              "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
              " elements to shape ", c->DebugString(out), " (",
              c->DebugString(known_out_elems), " elements)");
        }
      } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
                 c->Value(known_out_elems) > 0) {
        // Input fully known, infer the one missing output dim. The > 0 guard
        // above keeps the divisor non-zero.
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
                                     true /* evenly_divisible */,
                                     &inferred_dim));
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));

      } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
                 c->Value(known_in_elems) != 0) {
        // Output fully known, infer the one missing input dim. This only
        // validates consistency: the merged dim stays in a local handle.
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
                                     true /* evenly_divisible */,
                                     &inferred_dim));
        DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
        TF_RETURN_IF_ERROR(
            c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
      } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
        // Exactly one unknown dimension in both input and output. These 2 are
        // equal iff the known elements are equal.
        if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
          DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
          TF_RETURN_IF_ERROR(
              c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
        }
      }
    }
  }
  c->set_output(0, out);
  return Status::OK();
}
281
282 } // namespace
283
284 REGISTER_OP("ParallelConcat")
285 .Input("values: N * T")
286 .Output("output: T")
287 .Attr("N: int >= 1")
288 .Attr("T: type")
289 .Attr("shape: shape")
__anondb9326b20202(InferenceContext* c) 290 .SetShapeFn([](InferenceContext* c) {
291 // Validate that the shape attr is correct.
292 PartialTensorShape shape;
293 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
294 ShapeHandle passed_shape;
295 TF_RETURN_IF_ERROR(
296 c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
297 if (!c->FullyDefined(passed_shape)) {
298 return errors::InvalidArgument("shape attr must be fully defined.");
299 }
300 ShapeHandle cur;
301 TF_RETURN_IF_ERROR(c->ReplaceDim(
302 passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
303 &cur));
304 for (int i = 0; i < c->num_inputs(); ++i) {
305 if (!c->FullyDefined(c->input(i))) {
306 return errors::InvalidArgument(
307 "All input shapes must be fully defined.");
308 }
309 DimensionHandle unused;
310 if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
311 return errors::InvalidArgument("Size of first dimension must be 1.");
312 }
313 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
314 "From merging shape ", i,
315 " with other shapes.");
316 }
317
318 c->set_output(0, passed_shape);
319
320 return Status::OK();
321 });
322
323 REGISTER_OP("Pack")
324 .Input("values: N * T")
325 .Output("output: T")
326 .Attr("N: int >= 1")
327 .Attr("T: type")
328 .Attr("axis: int = 0")
__anondb9326b20302(InferenceContext* c) 329 .SetShapeFn([](InferenceContext* c) {
330 // Validate shapes of all inputs are compatible
331 ShapeHandle cur = c->input(c->num_inputs() - 1);
332 for (int i = c->num_inputs() - 2; i >= 0; --i) {
333 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
334 "From merging shape ", i,
335 " with other shapes.");
336 }
337 if (!c->RankKnown(cur)) {
338 c->set_output(0, c->UnknownShape());
339 return Status::OK();
340 }
341 // Determine the axis that will be added, converting from negative
342 // axes to a positive point per negative indexing rules.
343 int32 rank = c->Rank(cur);
344 int32 axis;
345 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
346
347 // Copy all dimensions over, inserting a dimension of value #inputs
348 // at <axis>.
349 std::vector<DimensionHandle> dims;
350 int index = 0;
351 while (index < axis) dims.push_back(c->Dim(cur, index++));
352 dims.push_back(c->MakeDim(c->num_inputs()));
353 while (index < rank) dims.push_back(c->Dim(cur, index++));
354
355 c->set_output(0, c->MakeShape(dims));
356 for (int i = 0; i < c->num_inputs(); ++i) {
357 auto* shape_and_type = c->input_handle_shapes_and_types(i);
358 if (shape_and_type) {
359 if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
360 c->set_output_handle_shapes_and_types(
361 0, std::vector<shape_inference::ShapeAndType>({}));
362 break;
363 }
364 }
365 }
366 return Status::OK();
367 });
368
// DeepCopy: output `y` has the same shape as input `x` (UnchangedShape).
// Registered as stateful, presumably so that identical copies are not
// deduplicated or folded by graph optimizations (NOTE(review): confirm).
REGISTER_OP("DeepCopy")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(UnchangedShape);

// The Inplace* ops take a tensor `x`, int32 indices `i` and values `v`, and
// produce `y` with the same shape as `x` (UnchangedShape). Shape inference
// does not validate `i` or `v`.
REGISTER_OP("InplaceUpdate")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);

REGISTER_OP("InplaceAdd")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);

REGISTER_OP("InplaceSub")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);
399
400 REGISTER_OP("Empty")
401 .Input("shape: int32")
402 .Output("output: dtype")
403 .Attr("dtype: type")
404 .Attr("init: bool = false")
405 .SetDoNotOptimize()
__anondb9326b20402(InferenceContext* c) 406 .SetShapeFn([](InferenceContext* c) {
407 ShapeHandle out;
408 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
409 c->set_output(0, out);
410 return Status::OK();
411 });
412
413 // --------------------------------------------------------------------------
414 REGISTER_OP("Unpack")
415 .Input("value: T")
416 .Output("output: num * T")
417 .Attr("num: int >= 0")
418 .Attr("T: type")
419 .Attr("axis: int = 0")
__anondb9326b20502(InferenceContext* c) 420 .SetShapeFn([](InferenceContext* c) {
421 ShapeHandle s = c->input(0);
422 ShapeHandle out;
423 if (c->RankKnown(s)) {
424 // Determine the axis that will be removed, converting from negative
425 // axes to a positive point per negative indexing rules.
426 int32 rank = c->Rank(s);
427 int32 axis;
428 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
429
430 // The axis dim matches the number of outputs.
431 DimensionHandle unused;
432 TF_RETURN_IF_ERROR(
433 c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
434
435 // Copy all dimensions, removing the <axis> dimension.
436 std::vector<DimensionHandle> dims;
437 for (int i = 0; i < rank; ++i) {
438 if (i != axis) dims.push_back(c->Dim(s, i));
439 }
440 out = c->MakeShape(dims);
441 } else {
442 // All outputs are the same shape, but it's not known.
443 out = c->UnknownShape();
444 }
445 for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
446 return Status::OK();
447 });
448
449 REGISTER_OP("UnravelIndex")
450 .Input("indices: Tidx")
451 .Input("dims: Tidx")
452 .Output("output: Tidx")
453 .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b20602(InferenceContext* c) 454 .SetShapeFn([](InferenceContext* c) {
455 ShapeHandle indices = c->input(0);
456 ShapeHandle dims;
457 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
458 if (c->RankKnown(indices) && c->Rank(indices) == 0) {
459 c->set_output(0, c->Vector(c->Dim(dims, 0)));
460 } else if (c->RankKnown(indices)) {
461 c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
462 } else {
463 c->set_output(0, c->UnknownShape());
464 }
465 return Status::OK();
466 });
467
468 REGISTER_OP("BroadcastTo")
469 .Input("input: T")
470 .Input("shape: Tidx")
471 .Output("output: T")
472 .Attr("T: type")
473 .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b20702(InferenceContext* c) 474 .SetShapeFn([](InferenceContext* c) {
475 ShapeHandle shape_in = c->input(1);
476 TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
477 ShapeHandle out;
478 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
479 if (!c->RankKnown(out)) {
480 // We have no information about the shape of the output.
481 c->set_output(0, out);
482 return Status::OK();
483 }
484
485 ShapeHandle in = c->input(0);
486 if (!c->RankKnown(in)) {
487 // We have no information about the shape of the input,
488 // nothing to do here.
489 c->set_output(0, out);
490 return Status::OK();
491 }
492 int out_rank = c->Rank(out);
493 TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
494 int in_rank = c->Rank(in);
495 for (int i = 0; i < in_rank; ++i) {
496 auto in_dim = c->Dim(in, in_rank - i - 1);
497 if (c->Value(in_dim) > 1) {
498 // If the input dimension is greater than 1 then the output dimension
499 // must be equal to it, since we only broadcast "from left to right".
500 auto out_dim = c->Dim(out, out_rank - i - 1);
501 TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
502 TF_RETURN_IF_ERROR(
503 c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
504 }
505 }
506 c->set_output(0, out);
507 return Status::OK();
508 });
509
510 // --------------------------------------------------------------------------
511 // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
512 // in the N == 1 case to remove the node.
513 REGISTER_OP("Concat")
514 .Input("concat_dim: int32")
515 .Input("values: N * T")
516 .Output("output: T")
517 .Attr("N: int >= 2")
518 .Attr("T: type")
__anondb9326b20802(InferenceContext* c) 519 .SetShapeFn([](InferenceContext* c) {
520 return shape_inference::ConcatShape(c, c->num_inputs() - 1);
521 });
522
// ConcatV2: like Concat, but the axis is the last input (after the N values)
// and may be int32 or int64. Shape inference is delegated to
// shape_inference::ConcatV2Shape.
REGISTER_OP("ConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Output("output: T")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape);
531
// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
// are not to be made user-accessible.
#ifdef INTEL_MKL
// Compiled only with INTEL_MKL. Mirrors ConcatV2's values/axis inputs and
// adds uint8 inputs `mkl_values`/`mkl_axis` plus a uint8 output
// `mkl_output` (presumably MKL layout metadata — confirm). Shape inference
// reuses ConcatV2Shape.
REGISTER_OP("_MklConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape)
    .Doc(R"doc(
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
553
554 REGISTER_OP("ConcatOffset")
555 .Input("concat_dim: int32")
556 .Input("shape: N * int32")
557 .Output("offset: N * int32")
558 .Attr("N: int >= 2")
__anondb9326b20902(InferenceContext* c) 559 .SetShapeFn([](InferenceContext* c) {
560 for (int i = 1; i < c->num_inputs(); ++i) {
561 c->set_output(i - 1, c->input(i));
562 }
563 return Status::OK();
564 });
565
566 // --------------------------------------------------------------------------
567 REGISTER_OP("Split")
568 .Input("split_dim: int32")
569 .Input("value: T")
570 .Output("output: num_split * T")
571 .Attr("num_split: int >= 1")
572 .Attr("T: type")
__anondb9326b20a02(InferenceContext* c) 573 .SetShapeFn([](InferenceContext* c) {
574 DimensionHandle split_dimension;
575 ShapeHandle input = c->input(1);
576 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
577 0, c->Rank(input), &split_dimension));
578 int num_split = c->num_outputs();
579 ShapeHandle out;
580 if (!c->ValueKnown(split_dimension)) {
581 if (c->RankKnown(input)) {
582 out = c->UnknownShapeOfRank(c->Rank(input));
583 } else {
584 out = c->UnknownShape();
585 }
586 } else {
587 int64 split_dim = c->Value(split_dimension);
588 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
589 DimensionHandle split_dim_size;
590 TF_RETURN_WITH_CONTEXT_IF_ERROR(
591 c->Divide(c->Dim(input, split_dim), num_split,
592 true /* evenly_divisible */, &split_dim_size),
593 "Number of ways to split should evenly divide the split dimension");
594 TF_RETURN_IF_ERROR(
595 c->ReplaceDim(input, split_dim, split_dim_size, &out));
596 }
597 for (int i = 0; i < num_split; ++i) c->set_output(i, out);
598 return Status::OK();
599 });
600
601 REGISTER_OP("SplitV")
602 .Input("value: T")
603 .Input("size_splits: Tlen")
604 .Input("split_dim: int32")
605 .Output("output: num_split * T")
606 .Attr("num_split: int >= 1")
607 .Attr("T: type")
608 .Attr("Tlen: {int32, int64} = DT_INT64")
__anondb9326b20b02(InferenceContext* c) 609 .SetShapeFn([](InferenceContext* c) {
610 DimensionHandle split_dimension;
611 ShapeHandle input = c->input(0);
612 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
613 2, c->Rank(input), &split_dimension));
614 int32 num_outputs = c->num_outputs();
615 int32 rank = c->Rank(input);
616 ShapeHandle output_shape;
617 const Tensor* size_splits = c->input_tensor(1);
618 if (rank == InferenceContext::kUnknownRank) {
619 // If the rank of input tensor is unknown, then return unknown shapes.
620 // Note that the shape of each output can be different.
621 for (int i = 0; i < num_outputs; ++i) {
622 c->set_output(i, c->UnknownShape());
623 }
624 } else if (rank == 0) {
625 // Throw error if input is a scalar.
626 return errors::InvalidArgument("Can't split scalars");
627 } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
628 // If split dimension is known, but the sizes are unknown, then
629 // only the split dimension is unknown
630 output_shape = input;
631 for (int i = 0; i < num_outputs; ++i) {
632 TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
633 c->Value(split_dimension),
634 c->UnknownDim(), &output_shape));
635 c->set_output(i, output_shape);
636 }
637 } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
638 // If split dimension or tensor containing the split sizes is unknown,
639 // then return unknown shapes of same rank as input. Note that each
640 // output shape can be different since splitv doesn't always split
641 // tensors evenly.
642 for (int i = 0; i < num_outputs; ++i) {
643 c->set_output(i, c->UnknownShapeOfRank(rank));
644 }
645 } else {
646 // Determine the output shape if split dimension and split sizes are
647 // known.
648 int64 split_dim = c->Value(split_dimension);
649 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
650 std::vector<int64> data;
651 if (size_splits->dtype() == DT_INT32) {
652 data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
653 } else {
654 data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
655 }
656 if (num_outputs != data.size()) {
657 return errors::InvalidArgument(
658 "Length of size_splits should be equal to num_outputs");
659 }
660 int64_t total_size = 0;
661 bool has_neg_one = false;
662 for (const auto size : data) {
663 if (size == -1) {
664 if (has_neg_one) {
665 return errors::InvalidArgument(
666 "size_splits can only have one -1");
667 }
668 has_neg_one = true;
669 } else {
670 total_size += size;
671 }
672 }
673 auto split_dim_size = c->Value(c->Dim(input, split_dim));
674 // If the sizes of the splits are known, then
675 // make sure that the sizes add up to the expected
676 // dimension size, with the possibility of a -1.
677 // Specify the full output shapes.
678 for (int i = 0; i < num_outputs; ++i) {
679 auto size = data[i];
680 if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
681 size = split_dim_size - total_size;
682 }
683 TF_RETURN_IF_ERROR(
684 c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
685 c->set_output(i, output_shape);
686 }
687 if (c->ValueKnown(split_dim_size)) {
688 if (has_neg_one ? total_size > split_dim_size
689 : total_size != split_dim_size) {
690 return errors::InvalidArgument(
691 "can't split axis of size ", split_dim_size,
692 " into pieces of size [", absl::StrJoin(data, ","), "]");
693 }
694 }
695 }
696
697 return Status::OK();
698 });
699
700 // --------------------------------------------------------------------------
701 REGISTER_OP("Const")
702 .Output("output: dtype")
703 .Attr("value: tensor")
704 .Attr("dtype: type")
__anondb9326b20c02(InferenceContext* c) 705 .SetShapeFn([](InferenceContext* c) {
706 const TensorProto* proto = nullptr;
707 TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
708 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
709 TensorShape shape(proto->tensor_shape());
710 std::vector<DimensionHandle> dims;
711 dims.reserve(shape.dims());
712 for (int i = 0; i < shape.dims(); ++i) {
713 dims.push_back(c->MakeDim(shape.dim_size(i)));
714 }
715 c->set_output(0, c->MakeShape(dims));
716 return Status::OK();
717 });
718
// Returns a constant tensor on the host. Useful for writing C++ tests
// and benchmarks which run on GPU but require arguments pinned to the host.
// Used by test::graph::HostConstant.
// value: Attr `value` is the tensor to return.
// Unlike Const, shape inference does not inspect `value`: the output shape
// is left unknown (shape_inference::UnknownShape).
REGISTER_OP("HostConst")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetShapeFn(shape_inference::UnknownShape);
728
729 // --------------------------------------------------------------------------
// TODO(mgubin): Update the doc when the freeze_graph script supports converting
// into memmapped format.
// Produces a tensor of `dtype` whose output shape comes from the `shape`
// attr (shape_inference::ExplicitShape). `memory_region_name` names the
// backing data, presumably a memory-mapped region (NOTE(review): confirm
// against the kernel).
REGISTER_OP("ImmutableConst")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .Attr("memory_region_name: string")
    .Output("tensor: dtype")
    .SetShapeFn(shape_inference::ExplicitShape);
738
739 REGISTER_OP("GuaranteeConst")
740 .Input("input: T")
741 .Output("output: T")
742 .Attr("T: type")
__anondb9326b20d02(shape_inference::InferenceContext* c) 743 .SetShapeFn([](shape_inference::InferenceContext* c) {
744 return UnchangedShape(c);
745 })
746 // We don't want this to be optimized away.
747 .SetDoNotOptimize();
748
749 // --------------------------------------------------------------------------
// ZerosLike: output `y` has the same shape (and type T) as `x`.
REGISTER_OP("ZerosLike")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------
// OnesLike: output `y` has the same shape as `x`. Unlike ZerosLike, T is
// restricted to the numeric and bool types listed below.
REGISTER_OP("OnesLike")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
        "int64, complex64, complex128, bool}")
    .SetShapeFn(shape_inference::UnchangedShape);
764
765 // --------------------------------------------------------------------------
766 REGISTER_OP("Diag")
767 .Input("diagonal: T")
768 .Output("output: T")
769 .Attr(
770 "T: {bfloat16, half, float, double, int32, int64, complex64, "
771 "complex128}")
__anondb9326b20e02(InferenceContext* c) 772 .SetShapeFn([](InferenceContext* c) {
773 ShapeHandle in = c->input(0);
774 TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
775 // Output shape is original concatenated with itself.
776 ShapeHandle out;
777 TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
778 c->set_output(0, out);
779 return Status::OK();
780 });
781
782 // --------------------------------------------------------------------------
783 REGISTER_OP("DiagPart")
784 .Input("input: T")
785 .Output("diagonal: T")
786 .Attr(
787 "T: {bfloat16, half, float, double, int32, int64, complex64, "
788 "complex128}")
__anondb9326b20f02(InferenceContext* c) 789 .SetShapeFn([](InferenceContext* c) {
790 ShapeHandle in = c->input(0);
791 if (!c->RankKnown(in)) {
792 c->set_output(0, c->UnknownShape());
793 return Status::OK();
794 }
795 // Rank must be even, and result will have rank <rank/2>.
796 const int32 rank = c->Rank(in);
797 if ((rank % 2) != 0 || rank <= 0) {
798 return errors::InvalidArgument(
799 "Input must have even and non-zero rank, input rank is ", rank);
800 }
801 const int32 mid = rank / 2;
802
803 // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
804 std::vector<DimensionHandle> dims(mid);
805 for (int i = 0; i < mid; ++i) {
806 TF_RETURN_IF_ERROR(
807 c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
808 }
809 c->set_output(0, c->MakeShape(dims));
810 return Status::OK();
811 });
812
813 // --------------------------------------------------------------------------
814 REGISTER_OP("MatrixDiag")
815 .Input("diagonal: T")
816 .Output("output: T")
817 .Attr("T: type")
__anondb9326b21002(InferenceContext* c) 818 .SetShapeFn([](InferenceContext* c) {
819 ShapeHandle in;
820 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
821 if (!c->RankKnown(in)) {
822 c->set_output(0, c->UnknownShape());
823 return Status::OK();
824 }
825 const int32 rank = c->Rank(in);
826 ShapeHandle out;
827 TF_RETURN_IF_ERROR(
828 c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
829 c->set_output(0, out);
830 return Status::OK();
831 });
832
// MatrixDiagV2 / MatrixDiagV3 share the same inputs (diagonal, diagonal
// offsets `k`, num_rows, num_cols, padding_value) and the same shape function
// (shape_inference::MatrixDiagV2Shape). V3 only adds the `align` attr, which
// defaults to 'RIGHT_LEFT'.
REGISTER_OP("MatrixDiagV2")
    .Input("diagonal: T")
    .Input("k: int32")
    .Input("num_rows: int32")
    .Input("num_cols: int32")
    .Input("padding_value: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::MatrixDiagV2Shape);

REGISTER_OP("MatrixDiagV3")
    .Input("diagonal: T")
    .Input("k: int32")
    .Input("num_rows: int32")
    .Input("num_cols: int32")
    .Input("padding_value: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr(
        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
        "'RIGHT_LEFT'")
    .SetShapeFn(shape_inference::MatrixDiagV2Shape);
855
856 // --------------------------------------------------------------------------
857 REGISTER_OP("MatrixSetDiag")
858 .Input("input: T")
859 .Input("diagonal: T")
860 .Output("output: T")
861 .Attr("T: type")
__anondb9326b21102(InferenceContext* c) 862 .SetShapeFn([](InferenceContext* c) {
863 ShapeHandle input;
864 ShapeHandle diag;
865 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
866 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
867 if (c->RankKnown(input)) {
868 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
869 }
870 DimensionHandle smallest_dim;
871 TF_RETURN_IF_ERROR(
872 c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
873 TF_RETURN_IF_ERROR(
874 c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
875
876 ShapeHandle output = input;
877 if (c->RankKnown(diag) && !c->FullyDefined(input)) {
878 // Try to infer parts of shape from diag.
879 ShapeHandle diag_batch_shape;
880 TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_batch_shape));
881 TF_RETURN_IF_ERROR(
882 c->Concatenate(diag_batch_shape, c->UnknownShapeOfRank(2), &diag));
883 TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
884 }
885 c->set_output(0, output);
886 return Status::OK();
887 });
888
// MatrixSetDiagV2 / MatrixSetDiagV3 take the input matrix batch, the
// diagonal(s), and the diagonal offsets `k`. Both use
// shape_inference::MatrixSetDiagV2Shape. V3 only adds the `align` attr,
// which defaults to 'RIGHT_LEFT'.
REGISTER_OP("MatrixSetDiagV2")
    .Input("input: T")
    .Input("diagonal: T")
    .Input("k: int32")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);

REGISTER_OP("MatrixSetDiagV3")
    .Input("input: T")
    .Input("diagonal: T")
    .Input("k: int32")
    .Output("output: T")
    .Attr("T: type")
    .Attr(
        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
        "'RIGHT_LEFT'")
    .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
907
908 // --------------------------------------------------------------------------
909 REGISTER_OP("MatrixDiagPart")
910 .Input("input: T")
911 .Output("diagonal: T")
912 .Attr("T: type")
__anondb9326b21202(InferenceContext* c) 913 .SetShapeFn([](InferenceContext* c) {
914 ShapeHandle in;
915 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
916 if (!c->RankKnown(in)) {
917 c->set_output(0, c->UnknownShape());
918 return Status::OK();
919 }
920 const int32 rank = c->Rank(in);
921 std::vector<DimensionHandle> dims;
922 dims.reserve(rank - 2);
923 for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
924
925 DimensionHandle min_dim;
926 TF_RETURN_IF_ERROR(
927 c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
928 dims.push_back(min_dim);
929 c->set_output(0, c->MakeShape(dims));
930 return Status::OK();
931 });
932
// MatrixDiagPartV2 and MatrixDiagPartV3 share
// shape_inference::MatrixDiagPartV2Shape. Compared with MatrixDiagPart they
// add an int32 `k` input and a `padding_value` input of type T. V3 also
// exposes an `align` attr defaulting to 'RIGHT_LEFT'.
REGISTER_OP("MatrixDiagPartV2")
    .Input("input: T")
    .Input("k: int32")
    .Input("padding_value: T")
    .Output("diagonal: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);

REGISTER_OP("MatrixDiagPartV3")
    .Input("input: T")
    .Input("k: int32")
    .Input("padding_value: T")
    .Output("diagonal: T")
    .Attr("T: type")
    // Accepted values for `align`; the default is 'RIGHT_LEFT'.
    .Attr(
        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
        "'RIGHT_LEFT'")
    .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
951
// --------------------------------------------------------------------------
// MatrixBandPart: `band` has the same shape as `input` (UnchangedShape).
// `num_lower` / `num_upper` use the `Tindex` type: int32 or int64, with int64
// as the default.
REGISTER_OP("MatrixBandPart")
    .Input("input: T")
    .Input("num_lower: Tindex")
    .Input("num_upper: Tindex")
    .Output("band: T")
    .Attr("T: type")
    .Attr("Tindex: {int32, int64} = DT_INT64")
    .SetShapeFn(shape_inference::UnchangedShape);
961
962 // --------------------------------------------------------------------------
963 REGISTER_OP("Reverse")
964 .Input("tensor: T")
965 .Input("dims: bool")
966 .Output("output: T")
967 .Attr(
968 "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
969 "float, double, complex64, complex128, string}")
__anondb9326b21302(InferenceContext* c) 970 .SetShapeFn([](InferenceContext* c) {
971 ShapeHandle input = c->input(0);
972 ShapeHandle dims;
973 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
974 DimensionHandle dims_dim = c->Dim(dims, 0);
975 if (c->ValueKnown(dims_dim)) {
976 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
977 }
978 if (c->Rank(input) > 8) {
979 return errors::InvalidArgument(
980 "reverse does not work on tensors with more than 8 dimensions");
981 }
982 c->set_output(0, input);
983 return Status::OK();
984 });
985
986 // --------------------------------------------------------------------------
987 REGISTER_OP("ReverseV2")
988 .Input("tensor: T")
989 .Input("axis: Tidx")
990 .Output("output: T")
991 .Attr("Tidx: {int32, int64} = DT_INT32")
992 .Attr(
993 "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
994 "float, double, complex64, complex128, string}")
__anondb9326b21402(InferenceContext* c) 995 .SetShapeFn([](InferenceContext* c) {
996 ShapeHandle input = c->input(0);
997 ShapeHandle axis;
998 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
999 if (c->Rank(input) > 8) {
1000 return errors::InvalidArgument(
1001 "reverse does not work on tensors with more than 8 dimensions");
1002 }
1003 const Tensor* axis_tensor = c->input_tensor(1);
1004 if (axis_tensor != nullptr && c->RankKnown(input)) {
1005 int32 rank = c->Rank(input);
1006 std::vector<int64> axis_value;
1007 if (axis_tensor->dtype() == DT_INT32) {
1008 axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
1009 } else {
1010 axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements());
1011 }
1012 std::vector<bool> axes_dense(c->Rank(input), false);
1013 for (int i = 0; i < axis_value.size(); i++) {
1014 int64 canonical_axis =
1015 axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
1016 if (canonical_axis < 0 || canonical_axis >= rank) {
1017 return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
1018 " is out of valid range [", 0, ", ",
1019 rank - 1);
1020 }
1021 if (axes_dense[canonical_axis]) {
1022 return errors::InvalidArgument("axis ", canonical_axis,
1023 " specified more than once.");
1024 }
1025 axes_dense[canonical_axis] = true;
1026 }
1027 }
1028 c->set_output(0, input);
1029 return Status::OK();
1030 });
1031
1032 // --------------------------------------------------------------------------
1033 REGISTER_OP("EditDistance")
1034 .Input("hypothesis_indices: int64")
1035 .Input("hypothesis_values: T")
1036 .Input("hypothesis_shape: int64")
1037 .Input("truth_indices: int64")
1038 .Input("truth_values: T")
1039 .Input("truth_shape: int64")
1040 .Attr("normalize: bool = true")
1041 .Attr("T: type")
1042 .Output("output: float")
__anondb9326b21502(InferenceContext* c) 1043 .SetShapeFn([](InferenceContext* c) {
1044 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1045 c, c->input(0), c->input(1), c->input(2)));
1046 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1047 c, c->input(3), c->input(4), c->input(5)));
1048 const Tensor* hypothesis_shape_t = c->input_tensor(2);
1049 const Tensor* truth_shape_t = c->input_tensor(5);
1050 if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
1051 // We need to know the runtime shape of the two tensors,
1052 // or else the output shape is unknown.
1053 return shape_inference::UnknownShape(c);
1054 }
1055
1056 if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
1057 return errors::InvalidArgument(
1058 "Num elements of hypothesis_shape does not match truth_shape: ",
1059 hypothesis_shape_t->NumElements(), " vs. ",
1060 truth_shape_t->NumElements());
1061 }
1062
1063 auto h_values = hypothesis_shape_t->flat<int64>();
1064 auto t_values = truth_shape_t->flat<int64>();
1065 std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
1066 for (int i = 0; i < dims.size(); ++i) {
1067 dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
1068 }
1069
1070 c->set_output(0, c->MakeShape(dims));
1071 return Status::OK();
1072 });
1073
1074 // --------------------------------------------------------------------------
1075 REGISTER_OP("Fill")
1076 .Input("dims: index_type")
1077 .Input("value: T")
1078 .Output("output: T")
1079 .Attr("T: type")
1080 .Attr("index_type: {int32, int64} = DT_INT32")
__anondb9326b21602(InferenceContext* c) 1081 .SetShapeFn([](InferenceContext* c) {
1082 DataType index_type = DT_INT32;
1083 Status s = c->GetAttr("index_type", &index_type);
1084 if (!s.ok() && s.code() != error::NOT_FOUND) {
1085 return s;
1086 }
1087 ShapeHandle unused;
1088 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1089 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1090
1091 const Tensor* t = c->input_tensor(0);
1092 if (t != nullptr) {
1093 for (int i = 0; i < t->NumElements(); ++i) {
1094 if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
1095 (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
1096 return errors::InvalidArgument("Fill dimensions must be >= 0");
1097 }
1098 }
1099 }
1100
1101 ShapeHandle out;
1102 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1103 c->set_output(0, out);
1104
1105 auto* shape_and_type = c->input_handle_shapes_and_types(1);
1106 if (shape_and_type) {
1107 c->set_output_handle_shapes_and_types(0, *shape_and_type);
1108 }
1109
1110 return Status::OK();
1111 });
1112
// --------------------------------------------------------------------------
// Internal op, registered as stateful. The output shape comes from
// shape_inference::ExplicitShape (presumably reading the `shape` attr —
// confirm in common_shape_fns).
REGISTER_OP("_ParallelConcatStart")
    .Output("output: dtype")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape)
    .Doc(R"doc(
Creates an empty Tensor with shape `shape` and type `dtype`.

The memory can optionally be initialized. This is usually useful in
conjunction with inplace operations.

shape: 1-D `Tensor` indicating the shape of the output.
dtype: The element type of the returned tensor.
output: An empty Tensor of the specified type.
)doc");
1130
// --------------------------------------------------------------------------
// Internal op: the output has the shape of `value` (UnchangedShape). `loc` is
// an int attr, not a tensor input.
REGISTER_OP("_ParallelConcatUpdate")
    .Input("value: T")
    .Input("update: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("loc: int")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Updates input `value` at `loc` with `update`.

If you use this function you will almost certainly want to add
a control dependency as done in the implementation of parallel_stack to
avoid race conditions.

value: A `Tensor` object that will be updated in-place.
loc: A scalar indicating the index of the first dimension such that
         value[loc, :] is updated.
update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
        otherwise of rank equal to `value` that contains the new values
        for `value`.
output: `value` that has been updated accordingly.
)doc");
1154
1155 // --------------------------------------------------------------------------
1156 REGISTER_OP("Gather")
1157 .Input("params: Tparams")
1158 .Input("indices: Tindices")
1159 .Attr("validate_indices: bool = true")
1160 .Output("output: Tparams")
1161 .Attr("Tparams: type")
1162 .Attr("Tindices: {int32,int64}")
__anondb9326b21702(InferenceContext* c) 1163 .SetShapeFn([](InferenceContext* c) {
1164 ShapeHandle unused;
1165 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
1166 ShapeHandle params_subshape;
1167 TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, ¶ms_subshape));
1168 ShapeHandle indices_shape = c->input(1);
1169 ShapeHandle out;
1170 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
1171 c->set_output(0, out);
1172 return Status::OK();
1173 });
1174
1175 // --------------------------------------------------------------------------
1176 REGISTER_OP("GatherV2")
1177 .Input("params: Tparams")
1178 .Input("indices: Tindices")
1179 .Input("axis: Taxis")
1180 .Attr("batch_dims: int = 0")
1181 .Output("output: Tparams")
1182 .Attr("Tparams: type")
1183 .Attr("Tindices: {int32,int64}")
1184 .Attr("Taxis: {int32,int64}")
__anondb9326b21802(InferenceContext* c) 1185 .SetShapeFn([](InferenceContext* c) {
1186 ShapeHandle params_shape;
1187 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, ¶ms_shape));
1188
1189 ShapeHandle indices_shape = c->input(1);
1190 ShapeHandle unused_axis_shape;
1191 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
1192 const Tensor* axis_t = c->input_tensor(2);
1193
1194 // If axis is unknown, we can only infer that the result is params_rank +
1195 // indices_rank - 1.
1196 if (axis_t == nullptr) {
1197 if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
1198 c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
1199 c->Rank(indices_shape) - 1));
1200 } else {
1201 c->set_output(0, c->UnknownShape());
1202 }
1203 return Status::OK();
1204 }
1205
1206 // Note, axis can be negative.
1207 int64 axis = 0;
1208 if (axis_t->dtype() == DT_INT32) {
1209 axis = axis_t->scalar<int32>()();
1210 } else {
1211 axis = axis_t->scalar<int64>()();
1212 }
1213
1214 // Check that params has rank of at least axis + 1.
1215 ShapeHandle unused;
1216 TF_RETURN_IF_ERROR(c->WithRankAtLeast(
1217 params_shape, axis < 0 ? -axis : axis + 1, &unused));
1218
1219 // Note, batch_dims can be negative.
1220 int32 batch_dims;
1221 TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
1222 // -rank(indices) <= batch_dims <= rank(indices)
1223 TF_RETURN_IF_ERROR(
1224 c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused));
1225 if (batch_dims < 0) {
1226 batch_dims += c->Rank(indices_shape);
1227 }
1228 // rank(params) > batch_dims
1229 TF_RETURN_IF_ERROR(
1230 c->WithRankAtLeast(params_shape, batch_dims + 1, &unused));
1231
1232 ShapeHandle params_outer_subshape;
1233 TF_RETURN_IF_ERROR(
1234 c->Subshape(params_shape, 0, axis, ¶ms_outer_subshape));
1235
1236 ShapeHandle indices_inner_subshape;
1237 TF_RETURN_IF_ERROR(
1238 c->Subshape(indices_shape, batch_dims, &indices_inner_subshape));
1239
1240 ShapeHandle out;
1241 TF_RETURN_IF_ERROR(
1242 c->Concatenate(params_outer_subshape, indices_inner_subshape, &out));
1243
1244 // Slice from axis + 1 to the end of params_shape to collect the inner
1245 // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
1246 // we slice from 0 to the end of shape. Subshape() handles all other
1247 // out-of-bounds checking.
1248 if (axis != -1) {
1249 ShapeHandle params_inner_subshape;
1250 TF_RETURN_IF_ERROR(
1251 c->Subshape(params_shape, axis + 1, ¶ms_inner_subshape));
1252 TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
1253 }
1254
1255 c->set_output(0, out);
1256 return Status::OK();
1257 });
1258
// --------------------------------------------------------------------------
// GatherNd: shape inference is delegated to shape_inference::GatherNdShape.
REGISTER_OP("GatherNd")
    .Input("params: Tparams")
    .Input("indices: Tindices")
    .Output("output: Tparams")
    .Attr("Tparams: type")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(shape_inference::GatherNdShape);
1267
// --------------------------------------------------------------------------
// Identity and Snapshot: the output shape equals the input shape.
REGISTER_OP("Identity")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Snapshot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
1280
#ifdef INTEL_MKL
// MKL variant of Identity, registered only in INTEL_MKL builds. The extra
// uint8 input/output presumably carry MKL layout metadata (confirm).
// UnchangedShape makes output 0 match input 0.
REGISTER_OP("_MklIdentity")
    .Input("input: T")
    .Input("mkl_input: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"Doc( Mkl implementation of IdentityOp
)Doc");
#endif
1292
1293 REGISTER_OP("IdentityN")
1294 .Input("input: T")
1295 .Output("output: T")
1296 .Attr("T: list(type)")
__anondb9326b21902(shape_inference::InferenceContext* c) 1297 .SetShapeFn([](shape_inference::InferenceContext* c) {
1298 std::vector<ShapeHandle> input;
1299 TF_RETURN_IF_ERROR(c->input("input", &input));
1300 TF_RETURN_IF_ERROR(c->set_output("output", input));
1301 // If any of the input shapes are not known, we should return error.
1302 for (int i = 0; i < input.size(); i++) {
1303 if (!input[i].Handle()) {
1304 return errors::InvalidArgument(absl::StrCat(
1305 "Cannot infer output shape #", i,
1306 " for IdentityN node because input shape #", i, " is unknown."));
1307 }
1308 }
1309 return Status::OK();
1310 });
1311
// --------------------------------------------------------------------------
// RefIdentity: passes a Ref(T) input through with an unchanged shape. It is
// registered with SetAllowsUninitializedInput.
REGISTER_OP("RefIdentity")
    .Input("input: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();
1319
// --------------------------------------------------------------------------
// Debug identity ops, in value and Ref(T) forms. Both keep the input shape
// and allow uninitialized inputs.
REGISTER_OP("DebugGradientIdentity")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();

REGISTER_OP("DebugGradientRefIdentity")
    .Input("input: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();
1334
// --------------------------------------------------------------------------
// StopGradient and PreventGradient: shape-preserving pass-through ops.
// PreventGradient also carries an optional `message` string attr.
REGISTER_OP("StopGradient")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("PreventGradient")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("message: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape);
1348
// --------------------------------------------------------------------------
// CheckNumerics: stateful, shape-preserving, and restricted to floating-point
// element types.
REGISTER_OP("CheckNumerics")
    .Input("tensor: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("message: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------
// CheckNumericsV2: same registration signature as CheckNumerics.
REGISTER_OP("CheckNumericsV2")
    .Input("tensor: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("message: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape);
1366
// --------------------------------------------------------------------------
// Reshape: shape inference is delegated to SetOutputShapeForReshape, which is
// defined outside this chunk of the file.
REGISTER_OP("Reshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      return SetOutputShapeForReshape(c);
    });
1377
#ifdef INTEL_MKL
// MKL variant of Reshape (INTEL_MKL builds only). Its shape function matches
// Reshape's. The extra uint8 inputs/outputs presumably carry MKL layout
// metadata (confirm).
REGISTER_OP("_MklReshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Input("mkl_tensor: uint8")
    .Input("mkl_shape: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
    .Doc(R"Doc( MKL implementation of ReshapeOp.
)Doc");
#endif  // INTEL_MKL
1392
1393 // --------------------------------------------------------------------------
1394 REGISTER_OP("InvertPermutation")
1395 .Input("x: T")
1396 .Output("y: T")
1397 .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b21c02(InferenceContext* c) 1398 .SetShapeFn([](InferenceContext* c) {
1399 ShapeHandle x;
1400 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
1401 c->set_output(0, x);
1402 return Status::OK();
1403 });
1404
// --------------------------------------------------------------------------
// Transpose: shape inference via TransposeShapeFn, which is defined outside
// this chunk.
REGISTER_OP("Transpose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(TransposeShapeFn);
1413
#ifdef INTEL_MKL
// MKL variant of Transpose (INTEL_MKL builds only); same signature and shape
// function.
REGISTER_OP("_MklTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(TransposeShapeFn);
#endif  // INTEL_MKL
1423
// --------------------------------------------------------------------------
// ConjugateTranspose: same signature and shape function as Transpose.
REGISTER_OP("ConjugateTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(TransposeShapeFn);
1432
#ifdef INTEL_MKL
// MKL variant of ConjugateTranspose (INTEL_MKL builds only).
REGISTER_OP("_MklConjugateTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(TransposeShapeFn);
#endif  // INTEL_MKL
1442
1443 // --------------------------------------------------------------------------
1444 namespace {
UniqueIdxShapeFn(InferenceContext * c)1445 Status UniqueIdxShapeFn(InferenceContext* c) {
1446 ShapeHandle input = c->input(0);
1447 const Tensor* axis_t = c->input_tensor(1);
1448 if (axis_t == nullptr || !c->RankKnown(input)) {
1449 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1450 return Status::OK();
1451 }
1452
1453 if (c->Rank(c->input(1)) != 1) {
1454 return errors::InvalidArgument("axis expects a 1D vector.");
1455 }
1456
1457 int32 n = axis_t->NumElements();
1458 if (n == 0) {
1459 if (c->Rank(input) != 1) {
1460 return errors::InvalidArgument("x expects a 1D vector.");
1461 }
1462 c->set_output(1, input);
1463 return Status::OK();
1464 } else if (n == 1) {
1465 int64 axis;
1466 if (axis_t->dtype() == DT_INT32) {
1467 axis = static_cast<int64>(axis_t->flat<int32>()(0));
1468 } else {
1469 axis = axis_t->flat<int64>()(0);
1470 }
1471
1472 int64 input_rank = c->Rank(input);
1473 if (axis < -input_rank || axis >= input_rank) {
1474 return errors::InvalidArgument("axis expects to be in the range [",
1475 -input_rank, ", ", input_rank, ")");
1476 }
1477 if (axis < 0) {
1478 axis += input_rank;
1479 }
1480 c->set_output(1, c->Vector(c->Dim(input, axis)));
1481 return Status::OK();
1482 }
1483 return errors::InvalidArgument(
1484 "axis does not support input tensors larger than 1 elements.");
1485 }
1486 } // namespace
1487
1488 REGISTER_OP("Unique")
1489 .Input("x: T")
1490 .Output("y: T")
1491 .Output("idx: out_idx")
1492 .Attr("T: type")
1493 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b21e02(InferenceContext* c) 1494 .SetShapeFn([](InferenceContext* c) {
1495 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1496 c->set_output(1, c->input(0));
1497 // Assert that the input rank is 1.
1498 ShapeHandle dummy;
1499 return c->WithRank(c->input(0), 1, &dummy);
1500 });
1501
1502 REGISTER_OP("UniqueV2")
1503 .Input("x: T")
1504 .Input("axis: Taxis")
1505 .Output("y: T")
1506 .Output("idx: out_idx")
1507 .Attr("T: type")
1508 .Attr("Taxis: {int32,int64} = DT_INT64")
1509 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b21f02(InferenceContext* c) 1510 .SetShapeFn([](InferenceContext* c) {
1511 c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1512 TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1513 return Status::OK();
1514 });
1515
1516 // --------------------------------------------------------------------------
1517 REGISTER_OP("UniqueWithCounts")
1518 .Input("x: T")
1519 .Output("y: T")
1520 .Output("idx: out_idx")
1521 .Output("count: out_idx")
1522 .Attr("T: type")
1523 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b22002(InferenceContext* c) 1524 .SetShapeFn([](InferenceContext* c) {
1525 auto uniq = c->Vector(InferenceContext::kUnknownDim);
1526 c->set_output(0, uniq);
1527 c->set_output(1, c->input(0));
1528 c->set_output(2, uniq);
1529 return Status::OK();
1530 });
1531
1532 REGISTER_OP("UniqueWithCountsV2")
1533 .Input("x: T")
1534 .Input("axis: Taxis")
1535 .Output("y: T")
1536 .Output("idx: out_idx")
1537 .Output("count: out_idx")
1538 .Attr("T: type")
1539 .Attr("Taxis: {int32,int64} = DT_INT64")
1540 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b22102(InferenceContext* c) 1541 .SetShapeFn([](InferenceContext* c) {
1542 c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1543 TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1544 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
1545 return Status::OK();
1546 });
1547
1548 namespace {
1549
ShapeShapeFn(InferenceContext * c)1550 Status ShapeShapeFn(InferenceContext* c) {
1551 for (int i = 0; i < c->num_inputs(); ++i) {
1552 DimensionHandle dim;
1553 if (c->RankKnown(c->input(i))) {
1554 dim = c->MakeDim(c->Rank(c->input(i)));
1555 } else {
1556 dim = c->UnknownDim();
1557 }
1558 c->set_output(i, c->Vector(dim));
1559 }
1560 return Status::OK();
1561 }
1562
1563 } // namespace
1564
// --------------------------------------------------------------------------
// Shape / ShapeN: use ShapeShapeFn, so each output is a vector whose length
// is the rank of the matching input. `out_type` selects int32 (default) or
// int64.
REGISTER_OP("Shape")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(ShapeShapeFn);

REGISTER_OP("ShapeN")
    .Input("input: N * T")
    .Output("output: N * out_type")
    .Attr("N: int")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(ShapeShapeFn);
1580
1581 REGISTER_OP("EnsureShape")
1582 .Input("input: T")
1583 .Output("output: T")
1584 .Attr("shape: shape")
1585 .Attr("T: type")
__anondb9326b22302(InferenceContext* c) 1586 .SetShapeFn([](InferenceContext* c) {
1587 // Merges desired shape and statically known shape of input
1588 PartialTensorShape desired_shape;
1589 TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
1590
1591 int rank = desired_shape.dims();
1592 ShapeHandle input_shape_handle;
1593 ShapeHandle desired_shape_handle;
1594 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
1595 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
1596 desired_shape, &desired_shape_handle));
1597
1598 ShapeHandle merged_shape;
1599 TF_RETURN_IF_ERROR(
1600 c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
1601 c->set_output(0, merged_shape);
1602 return Status::OK();
1603 });
1604
1605 // --------------------------------------------------------------------------
1606 REGISTER_OP("ReverseSequence")
1607 .Input("input: T")
1608 .Input("seq_lengths: Tlen")
1609 .Output("output: T")
1610 .Attr("seq_dim: int")
1611 .Attr("batch_dim: int = 0")
1612 .Attr("T: type")
1613 .Attr("Tlen: {int32, int64} = DT_INT64")
__anondb9326b22402(InferenceContext* c) 1614 .SetShapeFn([](InferenceContext* c) {
1615 ShapeHandle input = c->input(0);
1616 ShapeHandle seq_lens_shape;
1617 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
1618
1619 int64 seq_dim;
1620 TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
1621 int64 batch_dim;
1622 TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1623
1624 if (!c->RankKnown(input)) {
1625 return shape_inference::UnknownShape(c);
1626 }
1627
1628 // Validate batch_dim and seq_dim against input.
1629 const int32 input_rank = c->Rank(input);
1630 if (batch_dim >= input_rank) {
1631 return errors::InvalidArgument(
1632 "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1633 }
1634 if (seq_dim >= input_rank) {
1635 return errors::InvalidArgument(
1636 "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
1637 }
1638
1639 DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1640 TF_RETURN_IF_ERROR(
1641 c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
1642
1643 // Replace batch_dim of input with batch_size
1644 ShapeHandle output_shape;
1645 TF_RETURN_IF_ERROR(
1646 c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
1647 c->set_output(0, output_shape);
1648 return Status::OK();
1649 });
1650
// --------------------------------------------------------------------------
// Rank: scalar int32 output.
REGISTER_OP("Rank")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: type")
    .SetShapeFn(shape_inference::ScalarShape);

// --------------------------------------------------------------------------
// Size: scalar output whose type is chosen by `out_type` (int32 by default).
REGISTER_OP("Size")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ScalarShape);
1665
// --------------------------------------------------------------------------
// Slice: shape inference is delegated to shape_inference::SliceShape.
REGISTER_OP("Slice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
1675
#ifdef INTEL_MKL
// MKL variant of Slice (INTEL_MKL builds only), using the same shape
// function. The extra uint8 inputs/outputs presumably carry MKL layout
// metadata (confirm).
REGISTER_OP("_MklSlice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Input("mkl_input: uint8")
    .Input("mkl_begin: uint8")
    .Input("mkl_size: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
#endif
1690
1691 REGISTER_OP("StridedSlice")
1692 .Input("input: T")
1693 .Input("begin: Index")
1694 .Input("end: Index")
1695 .Input("strides: Index")
1696 .Output("output: T")
1697 .Attr("T: type")
1698 .Attr("Index: {int32, int64}")
1699 .Attr("begin_mask: int = 0")
1700 .Attr("end_mask: int = 0")
1701 .Attr("ellipsis_mask: int = 0")
1702 .Attr("new_axis_mask: int = 0")
1703 .Attr("shrink_axis_mask: int = 0")
__anondb9326b22502(InferenceContext* c) 1704 .SetShapeFn([](InferenceContext* c) {
1705 ShapeHandle input = c->input(0);
1706 ShapeHandle begin_shape, end_shape, strides_shape;
1707 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1708 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
1709 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
1710 TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
1711 TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
1712 DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
1713
1714 const Tensor* strides_value = c->input_tensor(3);
1715 // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
1716 // more shape inference here (e.g. for x[3, ::T]).
1717 if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
1718 strides_value == nullptr) {
1719 c->set_output(0, c->UnknownShape());
1720 return Status::OK();
1721 }
1722
1723 PartialTensorShape input_shape({});
1724 for (int i = 0; i < c->Rank(input); ++i) {
1725 auto dim = c->Dim(input, i);
1726 input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
1727 }
1728
1729 int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask,
1730 shrink_axis_mask;
1731 TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
1732 TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
1733 TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
1734 TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
1735 TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
1736
1737 const Tensor* begin_value = c->input_tensor(1);
1738 const Tensor* end_value = c->input_tensor(2);
1739
1740 PartialTensorShape processing_shape, final_shape;
1741 bool is_identity, is_simple_slice, slice_dim0;
1742 gtl::InlinedVector<int64, 4> begin, end, strides;
1743 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
1744 begin_value, end_value, *strides_value, input_shape, begin_mask,
1745 end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
1746 &processing_shape, &final_shape, &is_identity, &is_simple_slice,
1747 &slice_dim0, &begin, &end, &strides));
1748
1749 ShapeHandle out;
1750 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
1751 c->set_output(0, out);
1752
1753 auto* shape_and_type = c->input_handle_shapes_and_types(0);
1754 if (shape_and_type) {
1755 c->set_output_handle_shapes_and_types(0, *shape_and_type);
1756 }
1757
1758 return Status::OK();
1759 });
1760
1761 REGISTER_OP("StridedSliceGrad")
1762 .Input("shape: Index")
1763 .Input("begin: Index")
1764 .Input("end: Index")
1765 .Input("strides: Index")
1766 .Input("dy: T")
1767 .Output("output: T")
1768 .Attr("T: type")
1769 .Attr("Index: {int32, int64}")
1770 .Attr("begin_mask: int = 0")
1771 .Attr("end_mask: int = 0")
1772 .Attr("ellipsis_mask: int = 0")
1773 .Attr("new_axis_mask: int = 0")
1774 .Attr("shrink_axis_mask: int = 0")
__anondb9326b22602(InferenceContext* c) 1775 .SetShapeFn([](InferenceContext* c) {
1776 ShapeHandle out;
1777 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1778 c->set_output(0, out);
1779 return Status::OK();
1780 });
1781
// StridedSliceAssign: ref-typed variant. `output_ref` keeps the shape of
// `ref` (UnchangedShape).
REGISTER_OP("StridedSliceAssign")
    .Input("ref: Ref(T)")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
// TODO(aselle): Fix this documentation once StridedSliceAssign supports
// broadcasting.
// --------------------------------------------------------------------------
1800
// ResourceStridedSliceAssign: resource-handle variant of StridedSliceAssign.
// It declares no outputs (NoOutputs).
REGISTER_OP("ResourceStridedSliceAssign")
    .Input("ref: resource")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::NoOutputs);
1815
// TensorStridedSliceUpdate: value-typed (non-ref) variant. `output` has the
// shape of `input` (UnchangedShape).
REGISTER_OP("TensorStridedSliceUpdate")
    .Input("input: T")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
1831
1832 REGISTER_OP("Tile")
1833 .Input("input: T")
1834 .Input("multiples: Tmultiples")
1835 .Output("output: T")
1836 .Attr("T: type")
1837 .Attr("Tmultiples: {int32, int64} = DT_INT32")
__anondb9326b22702(InferenceContext* c) 1838 .SetShapeFn([](InferenceContext* c) {
1839 ShapeHandle input = c->input(0);
1840 // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
1841 // it is a vector of non-negative integers, and (ii) doing so allows
1842 // us to handle partially-known multiples.
1843 ShapeHandle multiples;
1844 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
1845 if (c->RankKnown(input)) {
1846 TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
1847 ShapeHandle dummy;
1848 TF_RETURN_IF_ERROR(
1849 c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
1850 }
1851
1852 if (!c->RankKnown(multiples)) {
1853 return shape_inference::UnknownShape(c);
1854 }
1855
1856 int32 rank = c->Rank(multiples);
1857 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
1858 std::vector<DimensionHandle> dims(rank);
1859 for (int i = 0; i < rank; ++i) {
1860 TF_RETURN_IF_ERROR(
1861 c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
1862 }
1863 c->set_output(0, c->MakeShape(dims));
1864 return Status::OK();
1865 });
1866
1867 // --------------------------------------------------------------------------
1868 REGISTER_OP("TileGrad")
1869 .Input("input: T")
1870 .Input("multiples: int32")
1871 .Output("output: T")
1872 .Attr("T: type")
1873 .Deprecated(3, "TileGrad has been replaced with reduce_sum")
1874 .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1875
1876 // --------------------------------------------------------------------------
1877 REGISTER_OP("Where")
1878 .Input("input: T")
1879 .Attr("T: {numbertype, bool} = DT_BOOL")
1880 .Output("index: int64")
__anondb9326b22802(InferenceContext* c) 1881 .SetShapeFn([](InferenceContext* c) {
1882 c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
1883 return Status::OK();
1884 });
1885
1886 // --------------------------------------------------------------------------
1887 REGISTER_OP("BroadcastArgs")
1888 .Input("s0: T")
1889 .Input("s1: T")
1890 .Output("r0: T")
1891 .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b22902(InferenceContext* c) 1892 .SetShapeFn([](InferenceContext* c) {
1893 ShapeHandle unused;
1894 ShapeHandle shape_x = c->input(0);
1895 ShapeHandle shape_y = c->input(1);
1896 TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
1897 TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
1898
1899 if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
1900 !c->ValueKnown(c->Dim(shape_y, 0))) {
1901 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1902 return Status::OK();
1903 }
1904
1905 int64 x_dim = c->Value(c->Dim(shape_x, 0));
1906 int64 y_dim = c->Value(c->Dim(shape_y, 0));
1907
1908 // Broadcasted shape is going to be as large as the largest dimension.
1909 c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
1910 return Status::OK();
1911 });
1912
1913 // --------------------------------------------------------------------------
1914 REGISTER_OP("BroadcastGradientArgs")
1915 .Input("s0: T")
1916 .Input("s1: T")
1917 .Output("r0: T")
1918 .Output("r1: T")
1919 .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b22a02(InferenceContext* c) 1920 .SetShapeFn([](InferenceContext* c) {
1921 // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
1922 ShapeHandle unused;
1923 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1924 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1925 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1926 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1927 return Status::OK();
1928 });
1929
1930 // --------------------------------------------------------------------------
1931 REGISTER_OP("Pad")
1932 .Input("input: T")
1933 .Input("paddings: Tpaddings")
1934 .Output("output: T")
1935 .Attr("T: type")
1936 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1937 .SetShapeFn(PadShapeFn);
1938
1939 // --------------------------------------------------------------------------
1940 REGISTER_OP("PadV2")
1941 .Input("input: T")
1942 .Input("paddings: Tpaddings")
1943 .Input("constant_values: T")
1944 .Output("output: T")
1945 .Attr("T: type")
1946 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1947 .SetShapeFn(PadShapeFn);
1948
1949 // --------------------------------------------------------------------------
1950 REGISTER_OP("MirrorPad")
1951 .Input("input: T")
1952 .Input("paddings: Tpaddings")
1953 .Output("output: T")
1954 .Attr("T: type")
1955 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1956 .Attr(GetMirrorPadModeAttrString())
1957 .SetShapeFn(PadShapeFn);
1958
1959 // --------------------------------------------------------------------------
1960 namespace {
1961 template <typename T>
MirrorPadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64 input_rank)1962 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
1963 const Tensor* paddings_t, int64 input_rank) {
1964 auto paddings_data = paddings_t->matrix<T>();
1965 std::vector<DimensionHandle> dims(input_rank);
1966 for (int64 i = 0; i < input_rank; ++i) {
1967 const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
1968 const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
1969 if (pad0 < 0 || pad1 < 0) {
1970 return errors::InvalidArgument("Paddings must be non-negative");
1971 }
1972
1973 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
1974 }
1975 c->set_output(0, c->MakeShape(dims));
1976 return Status::OK();
1977 }
1978
1979 } // namespace
1980
1981 REGISTER_OP("MirrorPadGrad")
1982 .Input("input: T")
1983 .Input("paddings: Tpaddings")
1984 .Output("output: T")
1985 .Attr("T: type")
1986 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1987 .Attr(GetMirrorPadModeAttrString())
__anondb9326b22c02(InferenceContext* c) 1988 .SetShapeFn([](InferenceContext* c) {
1989 ShapeHandle paddings;
1990 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
1991 DimensionHandle pad_0 = c->Dim(paddings, 0);
1992 if (!c->ValueKnown(pad_0)) {
1993 // We don't know the rank of the output since the first
1994 // padding dimension is unknown.
1995 c->set_output(0, c->UnknownShape());
1996 return Status::OK();
1997 }
1998
1999 int64 input_rank = c->Value(pad_0);
2000 ShapeHandle input;
2001 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
2002 TF_RETURN_IF_ERROR(
2003 c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
2004
2005 const Tensor* paddings_t = c->input_tensor(1);
2006 if (paddings_t == nullptr) {
2007 // Values of 'paddings' is not available, but we know the
2008 // input rank, so return the rank of the output with unknown
2009 // dimensions.
2010 c->set_output(0, c->UnknownShapeOfRank(input_rank));
2011 return Status::OK();
2012 }
2013
2014 if (paddings_t->dtype() == DT_INT32) {
2015 return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
2016 } else {
2017 return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
2018 }
2019 });
2020
2021 // --------------------------------------------------------------------------
2022 REGISTER_OP("Placeholder")
2023 .Output("output: dtype")
2024 .Attr("dtype: type")
2025 .Attr("shape: shape = { unknown_rank: true }")
__anondb9326b22d02(InferenceContext* c) 2026 .SetShapeFn([](InferenceContext* c) {
2027 PartialTensorShape shape;
2028 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2029
2030 // Placeholder has legacy behavior where we cannot tell the difference
2031 // between a scalar shape attribute and 'unknown shape'. So if the shape
2032 // is a scalar, we return an unknown shape.
2033 if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
2034 return shape_inference::UnknownShape(c);
2035 }
2036
2037 ShapeHandle out;
2038 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2039 c->set_output(0, out);
2040 return Status::OK();
2041 });
2042
2043 // Placeholder was modified in a backwards compatible way to do what
2044 // PlaceholderV2 did, so we have deprecated V2 (no one was really
2045 // using it).
2046 REGISTER_OP("PlaceholderV2")
2047 .Output("output: dtype")
2048 .Attr("dtype: type")
2049 .Attr("shape: shape")
2050 .SetShapeFn(shape_inference::ExplicitShape)
2051 .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
2052
2053 // --------------------------------------------------------------------------
2054 REGISTER_OP("PlaceholderWithDefault")
2055 .Input("input: dtype")
2056 .Output("output: dtype")
2057 .Attr("dtype: type")
2058 .Attr("shape: shape")
__anondb9326b22e02(InferenceContext* c) 2059 .SetShapeFn([](InferenceContext* c) {
2060 ShapeHandle input = c->input(0);
2061 PartialTensorShape shape;
2062 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2063 ShapeHandle out;
2064 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2065
2066 // We merge for compatibility checking, but return the output,
2067 // since output_shape may be less precise than input_shape.
2068 ShapeHandle unused;
2069 TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
2070 c->set_output(0, out);
2071 return Status::OK();
2072 });
2073
2074 // --------------------------------------------------------------------------
2075 REGISTER_OP("ExpandDims")
2076 .Input("input: T")
2077 .Input("dim: Tdim")
2078 .Output("output: T")
2079 .Attr("T: type")
2080 .Attr("Tdim: {int32, int64} = DT_INT32")
__anondb9326b22f02(InferenceContext* c) 2081 .SetShapeFn([](InferenceContext* c) {
2082 ShapeHandle input = c->input(0);
2083
2084 const Tensor* dim_t = c->input_tensor(1);
2085 if (dim_t != nullptr && dim_t->NumElements() != 1) {
2086 return errors::InvalidArgument(
2087 "'dim' input must be a tensor with a single value");
2088 }
2089 if (dim_t == nullptr || !c->RankKnown(input)) {
2090 c->set_output(0, c->UnknownShape());
2091 return Status::OK();
2092 }
2093
2094 int64 dim;
2095 if (dim_t->dtype() == DT_INT32) {
2096 dim = static_cast<int64>(dim_t->flat<int32>()(0));
2097 } else {
2098 dim = dim_t->flat<int64>()(0);
2099 }
2100
2101 const int32 rank = c->Rank(input);
2102 const int32 min_dim = -1 * rank - 1;
2103 if (dim < min_dim || dim > rank) {
2104 return errors::InvalidArgument("dim ", dim, " not in the interval [",
2105 min_dim, ", ", rank, "].");
2106 }
2107
2108 if (dim < 0) {
2109 dim += rank + 1;
2110 }
2111
2112 ShapeHandle end;
2113 TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
2114
2115 // Build output as start + 1 + end.
2116 ShapeHandle output;
2117 TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
2118 TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
2119 TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
2120 c->set_output(0, output);
2121 return Status::OK();
2122 });
2123
2124 // --------------------------------------------------------------------------
2125 REGISTER_OP("Squeeze")
2126 .Input("input: T")
2127 .Output("output: T")
2128 .Attr("T: type")
2129 .Attr("squeeze_dims: list(int) >= 0 = []")
__anondb9326b23002(InferenceContext* c) 2130 .SetShapeFn([](InferenceContext* c) {
2131 ShapeHandle input = c->input(0);
2132 if (!c->RankKnown(input)) {
2133 // Input shape unknown.
2134 return shape_inference::UnknownShape(c);
2135 }
2136
2137 const int32 input_rank = c->Rank(input);
2138
2139 // Validate and wrap squeeze dimensions.
2140 std::vector<int32> squeeze_dims;
2141 TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
2142 for (int i = 0; i < squeeze_dims.size(); ++i) {
2143 if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
2144 return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
2145 -input_rank, ",", input_rank, ").");
2146 }
2147
2148 if (squeeze_dims[i] < 0) {
2149 squeeze_dims[i] += input_rank;
2150 }
2151 }
2152
2153 std::vector<DimensionHandle> result_shape;
2154 for (int i = 0; i < input_rank; ++i) {
2155 // True if squeeze_dims contains an entry to squeeze this
2156 // dimension.
2157 bool is_explicit_match =
2158 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
2159 squeeze_dims.end();
2160
2161 DimensionHandle dim = c->Dim(input, i);
2162
2163 if (!c->ValueKnown(dim)) {
2164 // Assume that the squeezed dimension will be 1 at runtime.
2165 if (is_explicit_match) continue;
2166
2167 // If squeezing all 1 dimensions, and we see an unknown value,
2168 // give up and return Unknown Shape.
2169 if (squeeze_dims.empty()) {
2170 c->set_output(0, c->UnknownShape());
2171 return Status::OK();
2172 }
2173 } else if (c->Value(dim) == 1) {
2174 if (is_explicit_match || squeeze_dims.empty()) {
2175 // If explicitly squeezing, or squeezing all 1s, remove
2176 // this dimension.
2177 continue;
2178 }
2179 } else if (is_explicit_match) {
2180 return errors::InvalidArgument("Can not squeeze dim[", i,
2181 "], expected a dimension of 1, got ",
2182 c->Value(c->Dim(input, i)));
2183 }
2184
2185 result_shape.emplace_back(dim);
2186 }
2187
2188 c->set_output(0, c->MakeShape(result_shape));
2189 return Status::OK();
2190 });
2191
2192 // --------------------------------------------------------------------------
2193 REGISTER_OP("ListDiff")
2194 .Input("x: T")
2195 .Input("y: T")
2196 .Output("out: T")
2197 .Output("idx: out_idx")
2198 .Attr("T: type")
2199 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b23102(InferenceContext* c) 2200 .SetShapeFn([](InferenceContext* c) {
2201 ShapeHandle unused;
2202 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
2203 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2204 // TODO(mrry): Indicate that the length falls within an interval?
2205 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
2206 c->set_output(0, out);
2207 c->set_output(1, out);
2208 return Status::OK();
2209 });
2210
2211 namespace {
2212
2213 // Converts Tensor to flat std::vector<int64>.
2214 template <typename InputType>
GetFlatInt64(const Tensor & t)2215 std::vector<int64> GetFlatInt64(const Tensor& t) {
2216 std::vector<int64> output(t.shape().num_elements());
2217 if (t.shape().num_elements() > 0) {
2218 auto eigen_vec = t.flat<InputType>();
2219 std::copy_n(&eigen_vec(0), output.size(), output.begin());
2220 }
2221 return output;
2222 }
2223
2224 // Converts int32 or int64 Tensor to flat std::vector<int64>.
GetFlatInt64(const Tensor & t)2225 std::vector<int64> GetFlatInt64(const Tensor& t) {
2226 if (t.dtype() == DT_INT32) {
2227 return GetFlatInt64<int32>(t);
2228 } else {
2229 return GetFlatInt64<int64>(t);
2230 }
2231 }
2232
SpaceToBatchShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle paddings_shape,const Tensor * paddings_t)2233 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2234 ShapeHandle block_shape_shape,
2235 const Tensor* block_shape_t,
2236 ShapeHandle paddings_shape,
2237 const Tensor* paddings_t) {
2238 if (c->Rank(block_shape_shape) != 1) {
2239 return errors::InvalidArgument("block_shape must have rank 1.");
2240 }
2241
2242 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2243 if (!c->ValueKnown(num_block_dims_handle)) {
2244 return errors::InvalidArgument("block_shape must have known size.");
2245 }
2246
2247 const int64 num_block_dims = c->Value(num_block_dims_handle);
2248
2249 TF_RETURN_IF_ERROR(
2250 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2251
2252 TF_RETURN_IF_ERROR(
2253 c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
2254
2255 DimensionHandle batch_size = c->Dim(input_shape, 0);
2256 std::vector<int64> block_shape_vec;
2257 if (block_shape_t && (block_shape_t->NumElements() > 0)) {
2258 block_shape_vec = GetFlatInt64(*block_shape_t);
2259 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2260 const int64 block_shape_value = block_shape_vec[dim];
2261 if (block_shape_value < 1) {
2262 return errors::InvalidArgument("block_shape must be positive");
2263 }
2264 if (c->ValueKnown(batch_size)) {
2265 TF_RETURN_IF_ERROR(
2266 c->Multiply(batch_size, block_shape_value, &batch_size));
2267 } else {
2268 batch_size = c->UnknownDim();
2269 }
2270 }
2271 } else if (num_block_dims > 0) {
2272 batch_size = c->UnknownDim();
2273 }
2274
2275 std::vector<DimensionHandle> output_dims{batch_size};
2276 output_dims.resize(num_block_dims + 1, c->UnknownDim());
2277
2278 if (paddings_t && (paddings_t->NumElements() > 0)) {
2279 const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t);
2280 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2281 const int64 pad_start = paddings_vec[dim * 2],
2282 pad_end = paddings_vec[dim * 2 + 1];
2283 if (pad_start < 0 || pad_end < 0) {
2284 return errors::InvalidArgument("paddings cannot be negative");
2285 }
2286 if (block_shape_t) {
2287 DimensionHandle padded_size;
2288 TF_RETURN_IF_ERROR(
2289 c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
2290 TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
2291 TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
2292 /*evenly_divisible=*/true,
2293 &output_dims[dim + 1]));
2294 }
2295 }
2296 }
2297
2298 ShapeHandle remaining_input_shape;
2299 TF_RETURN_IF_ERROR(
2300 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2301
2302 ShapeHandle result;
2303 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2304 remaining_input_shape, &result));
2305 c->set_output(0, result);
2306 return Status::OK();
2307 }
2308
BatchToSpaceShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle crops_shape,const Tensor * crops_t)2309 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2310 ShapeHandle block_shape_shape,
2311 const Tensor* block_shape_t,
2312 ShapeHandle crops_shape, const Tensor* crops_t) {
2313 if (c->Rank(block_shape_shape) != 1) {
2314 return errors::InvalidArgument("block_shape must have rank 1.");
2315 }
2316
2317 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2318 if (!c->ValueKnown(num_block_dims_handle)) {
2319 return errors::InvalidArgument("block_shape must have known size.");
2320 }
2321
2322 const int64 num_block_dims = c->Value(num_block_dims_handle);
2323
2324 TF_RETURN_IF_ERROR(
2325 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2326
2327 TF_RETURN_IF_ERROR(
2328 c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
2329
2330 DimensionHandle batch_size = c->Dim(input_shape, 0);
2331 std::vector<int64> block_shape_vec;
2332 if (block_shape_t) {
2333 block_shape_vec = GetFlatInt64(*block_shape_t);
2334 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2335 const int64 block_shape_value = block_shape_vec[dim];
2336 if (block_shape_value < 1) {
2337 return errors::InvalidArgument("block_shape must be positive");
2338 }
2339 if (c->ValueKnown(batch_size)) {
2340 TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
2341 /*evenly_divisible=*/true, &batch_size));
2342 } else {
2343 batch_size = c->UnknownDim();
2344 }
2345 }
2346 } else if (num_block_dims > 0) {
2347 batch_size = c->UnknownDim();
2348 }
2349
2350 std::vector<DimensionHandle> output_dims{batch_size};
2351 output_dims.resize(num_block_dims + 1, c->UnknownDim());
2352
2353 if (crops_t) {
2354 const std::vector<int64> crops_vec = GetFlatInt64(*crops_t);
2355 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2356 const int64 crop_start = crops_vec[dim * 2],
2357 crop_end = crops_vec[dim * 2 + 1];
2358 if (crop_start < 0 || crop_end < 0) {
2359 return errors::InvalidArgument("crops cannot be negative");
2360 }
2361 if (block_shape_t) {
2362 DimensionHandle cropped_size;
2363 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
2364 block_shape_vec[dim], &cropped_size));
2365 TF_RETURN_IF_ERROR(
2366 c->Subtract(cropped_size, crop_start, &cropped_size));
2367 TF_RETURN_IF_ERROR(
2368 c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
2369 }
2370 }
2371 }
2372
2373 ShapeHandle remaining_input_shape;
2374 TF_RETURN_IF_ERROR(
2375 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2376
2377 ShapeHandle result;
2378 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2379 remaining_input_shape, &result));
2380 c->set_output(0, result);
2381 return Status::OK();
2382 }
2383
2384 } // namespace
2385
2386 // --------------------------------------------------------------------------
2387 REGISTER_OP("SpaceToBatchND")
2388 .Input("input: T")
2389 .Input("block_shape: Tblock_shape")
2390 .Input("paddings: Tpaddings")
2391 .Output("output: T")
2392 .Attr("T: type")
2393 .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2394 .Attr("Tpaddings: {int32, int64} = DT_INT32")
__anondb9326b23302(InferenceContext* c) 2395 .SetShapeFn([](InferenceContext* c) {
2396 return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
2397 c->input_tensor(1), c->input(2),
2398 c->input_tensor(2));
2399 });
2400
2401 // --------------------------------------------------------------------------
2402 REGISTER_OP("SpaceToBatch")
2403 .Input("input: T")
2404 .Input("paddings: Tpaddings")
2405 .Output("output: T")
2406 .Attr("T: type")
2407 .Attr("Tpaddings: {int32, int64} = DT_INT32")
2408 .Attr("block_size: int >= 2")
__anondb9326b23402(InferenceContext* c) 2409 .SetShapeFn([](InferenceContext* c) {
2410 ShapeHandle input_shape;
2411 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2412
2413 int32 block_size;
2414 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2415
2416 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2417 auto block_shape_vec = block_shape.vec<int64>();
2418 block_shape_vec(0) = block_size;
2419 block_shape_vec(1) = block_size;
2420
2421 return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
2422 &block_shape, c->input(1),
2423 c->input_tensor(1));
2424 });
2425
2426 // --------------------------------------------------------------------------
2427 REGISTER_OP("BatchToSpaceND")
2428 .Input("input: T")
2429 .Input("block_shape: Tblock_shape")
2430 .Input("crops: Tcrops")
2431 .Output("output: T")
2432 .Attr("T: type")
2433 .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2434 .Attr("Tcrops: {int32, int64} = DT_INT32")
__anondb9326b23502(InferenceContext* c) 2435 .SetShapeFn([](InferenceContext* c) {
2436 return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
2437 c->input_tensor(1), c->input(2),
2438 c->input_tensor(2));
2439 });
2440
2441 // --------------------------------------------------------------------------
2442 REGISTER_OP("BatchToSpace")
2443 .Input("input: T")
2444 .Input("crops: Tidx")
2445 .Output("output: T")
2446 .Attr("T: type")
2447 .Attr("block_size: int >= 2")
2448 .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b23602(InferenceContext* c) 2449 .SetShapeFn([](InferenceContext* c) {
2450 ShapeHandle input_shape;
2451 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2452
2453 int32 block_size;
2454 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2455
2456 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2457 auto block_shape_vec = block_shape.vec<int64>();
2458 block_shape_vec(0) = block_size;
2459 block_shape_vec(1) = block_size;
2460
2461 return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
2462 &block_shape, c->input(1),
2463 c->input_tensor(1));
2464 });
2465
2466 // --------------------------------------------------------------------------
2467 REGISTER_OP("SpaceToDepth")
2468 .Input("input: T")
2469 .Output("output: T")
2470 .Attr("T: type")
2471 .Attr("block_size: int >= 2")
2472 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2473 // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
__anondb9326b23702(InferenceContext* c) 2474 .SetShapeFn([](InferenceContext* c) {
2475 string data_format_str;
2476 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2477 TensorFormat data_format;
2478 FormatFromString(data_format_str, &data_format);
2479
2480 constexpr int num_spatial_dims = 2;
2481 const int dims =
2482 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2483 ShapeHandle input;
2484 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2485
2486 int32 block_size;
2487 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2488
2489 DimensionHandle batch_size =
2490 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2491 DimensionHandle input_height =
2492 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2493 DimensionHandle input_width =
2494 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2495 DimensionHandle input_depth =
2496 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2497
2498 DimensionHandle output_height;
2499 DimensionHandle output_width;
2500 DimensionHandle output_depth;
2501 // Will return an error if input height or width are not evenly divisible.
2502 TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
2503 true /* evenly_divisible */,
2504 &output_height));
2505 TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
2506 true /* evenly_divisible */, &output_width));
2507
2508 TF_RETURN_IF_ERROR(
2509 c->Multiply(input_depth, block_size * block_size, &output_depth));
2510
2511 ShapeHandle output_shape;
2512 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2513 {output_height, output_width},
2514 output_depth, &output_shape, c));
2515
2516 c->set_output(0, output_shape);
2517 return Status::OK();
2518 });
2519
2520 // --------------------------------------------------------------------------
2521 REGISTER_OP("DepthToSpace")
2522 .Input("input: T")
2523 .Output("output: T")
2524 .Attr("T: type")
2525 .Attr("block_size: int >= 2")
2526 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2527 // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
__anondb9326b23802(InferenceContext* c) 2528 .SetShapeFn([](InferenceContext* c) {
2529 string data_format_str;
2530 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2531 TensorFormat data_format;
2532 FormatFromString(data_format_str, &data_format);
2533
2534 constexpr int num_spatial_dims = 2;
2535 const int dims =
2536 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2537
2538 ShapeHandle input;
2539 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2540
2541 int32 block_size;
2542 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2543
2544 DimensionHandle batch_size =
2545 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2546 DimensionHandle input_height =
2547 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2548 DimensionHandle input_width =
2549 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2550 DimensionHandle input_depth =
2551 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2552
2553 DimensionHandle output_height;
2554 DimensionHandle output_width;
2555 DimensionHandle output_depth;
2556 TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
2557 TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
2558
2559 // Will return an error if input_depth is not evenly divisible.
2560 TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
2561 true /* evenly_divisible */, &output_depth));
2562
2563 ShapeHandle output_shape;
2564 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2565 {output_height, output_width},
2566 output_depth, &output_shape, c));
2567
2568 c->set_output(0, output_shape);
2569 return Status::OK();
2570 });
2571
2572 // --------------------------------------------------------------------------
2573
2574 REGISTER_OP("ExtractImagePatches")
2575 .Input("images: T")
2576 .Output("patches: T")
2577 .Attr("ksizes: list(int) >= 4")
2578 .Attr("strides: list(int) >= 4")
2579 .Attr("rates: list(int) >= 4")
2580 .Attr(
2581 "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
2582 "uint8, uint16, uint32, uint64, complex64, complex128, bool}")
2583 .Attr(GetPaddingAttrString())
__anondb9326b23902(InferenceContext* c) 2584 .SetShapeFn([](InferenceContext* c) {
2585 ShapeHandle input_shape;
2586 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2587
2588 std::vector<int32> ksizes;
2589 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2590 if (ksizes.size() != 4) {
2591 return errors::InvalidArgument(
2592 "ExtractImagePatches requires the ksizes attribute to contain 4 "
2593 "values, but got: ",
2594 ksizes.size());
2595 }
2596
2597 std::vector<int32> strides;
2598 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2599 if (strides.size() != 4) {
2600 return errors::InvalidArgument(
2601 "ExtractImagePatches requires the stride attribute to contain 4 "
2602 "values, but got: ",
2603 strides.size());
2604 }
2605
2606 std::vector<int32> rates;
2607 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2608 if (rates.size() != 4) {
2609 return errors::InvalidArgument(
2610 "ExtractImagePatches requires the rates attribute to contain 4 "
2611 "values, but got: ",
2612 rates.size());
2613 }
2614
2615 int32 ksize_rows = ksizes[1];
2616 int32 ksize_cols = ksizes[2];
2617
2618 int32 stride_rows = strides[1];
2619 int32 stride_cols = strides[2];
2620
2621 int32 rate_rows = rates[1];
2622 int32 rate_cols = rates[2];
2623
2624 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2625 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2626
2627 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2628 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
2629 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
2630 DimensionHandle output_depth_dim;
2631 TF_RETURN_IF_ERROR(c->Multiply(
2632 c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
2633
2634 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
2635 ShapeHandle output_shape =
2636 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2637 InferenceContext::kUnknownDim, output_depth_dim});
2638 c->set_output(0, output_shape);
2639 return Status::OK();
2640 }
2641 auto in_rows = c->Value(in_rows_dim);
2642 auto in_cols = c->Value(in_cols_dim);
2643
2644 Padding padding;
2645 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2646
2647 int64 output_rows, output_cols;
2648 int64 padding_before, padding_after;
2649 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2650 in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2651 &padding_before, &padding_after));
2652 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2653 in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2654 &padding_before, &padding_after));
2655 ShapeHandle output_shape = c->MakeShape(
2656 {batch_size_dim, output_rows, output_cols, output_depth_dim});
2657 c->set_output(0, output_shape);
2658 return Status::OK();
2659 });
2660
2661 // --------------------------------------------------------------------------
2662
2663 // To enable rates, uncomment all lines commented below and use ksize_*_eff
2664 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead
2665 // of ksize_*.
2666 REGISTER_OP("ExtractVolumePatches")
2667 .Input("input: T")
2668 .Output("patches: T")
2669 .Attr("ksizes: list(int) >= 5")
2670 .Attr("strides: list(int) >= 5")
2671 /* .Attr("rates: list(int) >= 5") */
2672 .Attr("T: realnumbertype")
2673 .Attr(GetPaddingAttrString())
__anondb9326b23a02(InferenceContext* c) 2674 .SetShapeFn([](InferenceContext* c) {
2675 ShapeHandle input_shape;
2676 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
2677
2678 std::vector<int32> ksizes;
2679 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2680 if (ksizes.size() != 5) {
2681 return errors::InvalidArgument(
2682 "ExtractVolumePatches requires the ksizes attribute to contain 5 "
2683 "values, but got: ",
2684 ksizes.size());
2685 }
2686
2687 std::vector<int32> strides;
2688 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2689 if (strides.size() != 5) {
2690 return errors::InvalidArgument(
2691 "ExtractVolumePatches requires the stride attribute to contain 5 "
2692 "values, but got: ",
2693 strides.size());
2694 }
2695
2696 /*
2697 // TODO(hsgkim): Enable rates.
2698 // See extract_volume_patches_op.cc for why rates are disabled now.
2699
2700 std::vector<int32> rates;
2701 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2702 if (rates.size() != 5) {
2703 return errors::InvalidArgument(
2704 "ExtractVolumePatches requires the rates attribute to contain 5 "
2705 "values, but got: ",
2706 rates.size());
2707 }
2708 */
2709
2710 int32 ksize_planes = ksizes[1];
2711 int32 ksize_rows = ksizes[2];
2712 int32 ksize_cols = ksizes[3];
2713
2714 int32 stride_planes = strides[1];
2715 int32 stride_rows = strides[2];
2716 int32 stride_cols = strides[3];
2717
2718 /*
2719 int32 rate_planes = rates[1];
2720 int32 rate_rows = rates[2];
2721 int32 rate_cols = rates[3];
2722
2723 int32 ksize_planes_eff = ksize_planes +
2724 (ksize_planes - 1) * (rate_planes - 1);
2725 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2726 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2727 */
2728
2729 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2730 DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
2731 DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
2732 DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
2733 DimensionHandle output_depth_dim;
2734 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
2735 ksize_planes * ksize_rows * ksize_cols,
2736 &output_depth_dim));
2737
2738 if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
2739 !c->ValueKnown(in_cols_dim)) {
2740 ShapeHandle output_shape =
2741 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2742 InferenceContext::kUnknownDim, output_depth_dim});
2743 c->set_output(0, output_shape);
2744 return Status::OK();
2745 }
2746 auto in_planes = c->Value(in_planes_dim);
2747 auto in_rows = c->Value(in_rows_dim);
2748 auto in_cols = c->Value(in_cols_dim);
2749
2750 Padding padding;
2751 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2752
2753 int64 output_planes, output_rows, output_cols;
2754 int64 padding_before, padding_after;
2755 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2756 in_planes, ksize_planes, stride_planes, padding, &output_planes,
2757 &padding_before, &padding_after));
2758 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2759 in_rows, ksize_rows, stride_rows, padding, &output_rows,
2760 &padding_before, &padding_after));
2761 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2762 in_cols, ksize_cols, stride_cols, padding, &output_cols,
2763 &padding_before, &padding_after));
2764 ShapeHandle output_shape =
2765 c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
2766 output_depth_dim});
2767 c->set_output(0, output_shape);
2768 return Status::OK();
2769 });
2770
2771 // --------------------------------------------------------------------------
2772
2773 REGISTER_OP("OneHot")
2774 .Input("indices: TI")
2775 .Input("depth: int32")
2776 .Input("on_value: T")
2777 .Input("off_value: T")
2778 .Attr("axis: int = -1")
2779 .Output("output: T")
2780 .Attr("T: type")
2781 .Attr("TI: {uint8, int32, int64} = DT_INT64")
__anondb9326b23b02(InferenceContext* c) 2782 .SetShapeFn([](InferenceContext* c) {
2783 int32 axis;
2784 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2785 if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
2786
2787 DimensionHandle depth;
2788 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
2789
2790 ShapeHandle indices = c->input(0);
2791 if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
2792
2793 int32 new_rank = c->Rank(indices) + 1;
2794 // We need to add new_rank to axis in the case the axis is -1 because
2795 // C++ returns negative values from % if the dividend is negative.
2796 int32 depth_index = (axis + new_rank) % new_rank;
2797 // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
2798 ShapeHandle front;
2799 ShapeHandle back;
2800 ShapeHandle out;
2801 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
2802 TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
2803 TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
2804 TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
2805 c->set_output(0, out);
2806 return Status::OK();
2807 });
2808
2809 // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
// First version of quantize-and-dequantize: the range is supplied through the
// input_min/input_max attrs rather than as tensor inputs. The output has the
// same shape as `input` (UnchangedShape). Marked deprecated at version 22 in
// favor of QuantizeAndDequantizeV2.
REGISTER_OP("QuantizeAndDequantize")
    .Input("input: T")
    .Attr("signed_input: bool = true")
    .Attr("num_bits: int = 8")
    .Attr("range_given: bool = false")
    .Attr("input_min: float = 0")
    .Attr("input_max: float = 0")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
2821
2822 // TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
2823 REGISTER_OP("QuantizeAndDequantizeV2")
2824 .Input("input: T")
2825 .Input("input_min: T")
2826 .Input("input_max: T")
2827 .Attr("signed_input: bool = true")
2828 .Attr("num_bits: int = 8")
2829 .Attr("range_given: bool = false")
2830 .Output("output: T")
2831 .Attr("T: {bfloat16, half, float, double}")
2832 .Attr(
2833 "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2834 "'HALF_TO_EVEN'")
2835 .Attr("narrow_range: bool = false")
2836 .Attr("axis: int = -1")
__anondb9326b23c02(InferenceContext* c) 2837 .SetShapeFn([](InferenceContext* c) {
2838 int axis;
2839 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2840 const int minmax_rank = (axis == -1) ? 0 : 1;
2841 ShapeHandle minmax;
2842 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2843 TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2844 if (axis != -1) {
2845 ShapeHandle input;
2846 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2847 DimensionHandle depth;
2848 TF_RETURN_IF_ERROR(
2849 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2850 }
2851 c->set_output(0, c->input(0));
2852 return Status::OK();
2853 });
2854
2855 REGISTER_OP("QuantizeAndDequantizeV4")
2856 .Input("input: T")
2857 .Input("input_min: T")
2858 .Input("input_max: T")
2859 .Attr("signed_input: bool = true")
2860 .Attr("num_bits: int = 8")
2861 .Attr("range_given: bool = false")
2862 .Output("output: T")
2863 .Attr("T: {bfloat16, half, float, double}")
2864 .Attr(
2865 "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2866 "'HALF_TO_EVEN'")
2867 .Attr("narrow_range: bool = false")
2868 .Attr("axis: int = -1")
__anondb9326b23d02(InferenceContext* c) 2869 .SetShapeFn([](InferenceContext* c) {
2870 int axis;
2871 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2872 const int minmax_rank = (axis == -1) ? 0 : 1;
2873 ShapeHandle minmax;
2874 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2875 TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2876 if (axis != -1) {
2877 ShapeHandle input;
2878 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2879 DimensionHandle depth;
2880 TF_RETURN_IF_ERROR(
2881 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2882 }
2883 c->set_output(0, c->input(0));
2884 return Status::OK();
2885 });
2886
2887 REGISTER_OP("QuantizeAndDequantizeV4Grad")
2888 .Input("gradients: T")
2889 .Input("input: T")
2890 .Input("input_min: T")
2891 .Input("input_max: T")
2892 .Output("input_backprop: T")
2893 .Output("input_min_backprop: T")
2894 .Output("input_max_backprop: T")
2895 .Attr("T: {bfloat16, half, float, double}")
2896 .Attr("axis: int = -1")
__anondb9326b23e02(InferenceContext* c) 2897 .SetShapeFn([](InferenceContext* c) {
2898 int axis;
2899 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2900 const int minmax_rank = (axis == -1) ? 0 : 1;
2901 ShapeHandle minmax;
2902 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
2903 TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax));
2904 if (axis != -1) {
2905 ShapeHandle input;
2906 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2907 DimensionHandle depth;
2908 TF_RETURN_IF_ERROR(
2909 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2910 }
2911 ShapeHandle inputs;
2912 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
2913 c->set_output(0, inputs);
2914 c->set_output(1, minmax);
2915 c->set_output(2, minmax);
2916 return Status::OK();
2917 });
2918
2919 REGISTER_OP("QuantizeAndDequantizeV3")
2920 .Input("input: T")
2921 .Input("input_min: T")
2922 .Input("input_max: T")
2923 .Input("num_bits: int32")
2924 .Attr("signed_input: bool = true")
2925 .Attr("range_given: bool = true")
2926 .Output("output: T")
2927 .Attr("T: {bfloat16, half, float, double}")
2928 .Attr("narrow_range: bool = false")
2929 .Attr("axis: int = -1")
__anondb9326b23f02(InferenceContext* c) 2930 .SetShapeFn([](InferenceContext* c) {
2931 int axis;
2932 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2933 const int minmax_rank = (axis == -1) ? 0 : 1;
2934 ShapeHandle minmax;
2935 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2936 TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2937 if (axis != -1) {
2938 ShapeHandle input;
2939 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2940 DimensionHandle depth;
2941 TF_RETURN_IF_ERROR(
2942 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2943 }
2944 ShapeHandle unused;
2945 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2946 c->set_output(0, c->input(0));
2947 return Status::OK();
2948 });
2949
2950 REGISTER_OP("QuantizeV2")
2951 .Input("input: float")
2952 .Input("min_range: float")
2953 .Input("max_range: float")
2954 .Output("output: T")
2955 .Output("output_min: float")
2956 .Output("output_max: float")
2957 .Attr("T: quantizedtype")
2958 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
2959 .Attr(
2960 "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
2961 "'HALF_AWAY_FROM_ZERO'")
2962 .Attr("narrow_range: bool = false")
2963 .Attr("axis: int = -1")
2964 .Attr("ensure_minimum_range: float = 0.01")
__anondb9326b24002(InferenceContext* c) 2965 .SetShapeFn([](InferenceContext* c) {
2966 int axis = -1;
2967 Status s = c->GetAttr("axis", &axis);
2968 if (!s.ok() && s.code() != error::NOT_FOUND) {
2969 return s;
2970 }
2971 const int minmax_rank = (axis == -1) ? 0 : 1;
2972 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2973 ShapeHandle minmax;
2974 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2975 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
2976 if (axis != -1) {
2977 ShapeHandle input;
2978 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2979 DimensionHandle depth;
2980 TF_RETURN_IF_ERROR(
2981 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2982 }
2983 c->set_output(1, minmax);
2984 c->set_output(2, minmax);
2985 return Status::OK();
2986 });
2987
2988 REGISTER_OP("Dequantize")
2989 .Input("input: T")
2990 .Input("min_range: float")
2991 .Input("max_range: float")
2992 .Output("output: dtype")
2993 .Attr("T: quantizedtype")
2994 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
2995 .Attr("narrow_range: bool = false")
2996 .Attr("axis: int = -1")
2997 .Attr("dtype: {bfloat16, float} = DT_FLOAT")
__anondb9326b24102(InferenceContext* c) 2998 .SetShapeFn([](InferenceContext* c) {
2999 int axis = -1;
3000 Status s = c->GetAttr("axis", &axis);
3001 if (!s.ok() && s.code() != error::NOT_FOUND) {
3002 return s;
3003 }
3004 const int minmax_rank = (axis == -1) ? 0 : 1;
3005 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3006 ShapeHandle minmax;
3007 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
3008 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
3009 if (axis != -1) {
3010 ShapeHandle input;
3011 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
3012 DimensionHandle depth;
3013 TF_RETURN_IF_ERROR(
3014 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
3015 }
3016 return Status::OK();
3017 });
3018
3019 REGISTER_OP("QuantizedConcat")
3020 .Input("concat_dim: int32")
3021 .Input("values: N * T")
3022 .Input("input_mins: N * float32")
3023 .Input("input_maxes: N * float32")
3024 .Output("output: T")
3025 .Output("output_min: float")
3026 .Output("output_max: float")
3027 .Attr("N: int >= 2")
3028 .Attr("T: type")
__anondb9326b24202(InferenceContext* c) 3029 .SetShapeFn([](InferenceContext* c) {
3030 const int n = (c->num_inputs() - 1) / 3;
3031 TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
3032 ShapeHandle unused;
3033 for (int i = n + 1; i < c->num_inputs(); ++i) {
3034 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
3035 }
3036 c->set_output(1, c->Scalar());
3037 c->set_output(2, c->Scalar());
3038 return Status::OK();
3039 });
3040
3041 REGISTER_OP("QuantizedReshape")
3042 .Input("tensor: T")
3043 .Input("shape: Tshape")
3044 .Input("input_min: float")
3045 .Input("input_max: float")
3046 .Output("output: T")
3047 .Output("output_min: float")
3048 .Output("output_max: float")
3049 .Attr("T: type")
3050 .Attr("Tshape: {int32, int64} = DT_INT32")
__anondb9326b24302(InferenceContext* c) 3051 .SetShapeFn([](InferenceContext* c) {
3052 TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
3053 ShapeHandle unused;
3054 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3055 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3056 c->set_output(1, c->Scalar());
3057 c->set_output(2, c->Scalar());
3058 return Status::OK();
3059 });
3060
3061 REGISTER_OP("QuantizedInstanceNorm")
3062 .Input("x: T")
3063 .Input("x_min: float")
3064 .Input("x_max: float")
3065 .Output("y: T")
3066 .Output("y_min: float")
3067 .Output("y_max: float")
3068 .Attr("T: quantizedtype")
3069 .Attr("output_range_given: bool = false")
3070 .Attr("given_y_min: float = 0")
3071 .Attr("given_y_max: float = 0")
3072 .Attr("variance_epsilon: float = 1e-5")
3073 .Attr("min_separation: float = 1e-3")
__anondb9326b24402(shape_inference::InferenceContext* c) 3074 .SetShapeFn([](shape_inference::InferenceContext* c) {
3075 shape_inference::ShapeHandle unused;
3076 // x should be a rank 4 tensor.
3077 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
3078 // Assert x_min and x_max are scalars (rank 0).
3079 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3080 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3081 // y has the same shape as x.
3082 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3083 // y_min and y_max are scalars.
3084 c->set_output(1, c->Scalar());
3085 c->set_output(2, c->Scalar());
3086 return Status::OK();
3087 });
3088
3089 namespace {
3090
ScatterNdTensorShape(InferenceContext * c)3091 Status ScatterNdTensorShape(InferenceContext* c) {
3092 ShapeHandle output_shape;
3093 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
3094 ShapeHandle indices_shape;
3095 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
3096 ShapeHandle updates_shape;
3097 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
3098 return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
3099 output_shape);
3100 }
3101
3102 } // namespace
3103
3104 REGISTER_OP("UpperBound")
3105 .Input("sorted_inputs: T")
3106 .Input("values: T")
3107 .Output("output: out_type")
3108 .Attr("T: type")
3109 .Attr("out_type: {int32, int64} = DT_INT32")
__anondb9326b24602(InferenceContext* c) 3110 .SetShapeFn([](InferenceContext* c) {
3111 ShapeHandle unused_shape;
3112 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
3113 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
3114 c->set_output(0, c->input(1));
3115 return Status::OK();
3116 });
3117
3118 REGISTER_OP("LowerBound")
3119 .Input("sorted_inputs: T")
3120 .Input("values: T")
3121 .Output("output: out_type")
3122 .Attr("T: type")
3123 .Attr("out_type: {int32, int64} = DT_INT32")
__anondb9326b24702(InferenceContext* c) 3124 .SetShapeFn([](InferenceContext* c) {
3125 ShapeHandle unused_shape;
3126 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
3127 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
3128 c->set_output(0, c->input(1));
3129 return Status::OK();
3130 });
3131
3132 REGISTER_OP("ScatterNd")
3133 .Input("indices: Tindices")
3134 .Input("updates: T")
3135 .Input("shape: Tindices")
3136 .Output("output: T")
3137 .Attr("T: type")
3138 .Attr("Tindices: {int32, int64}")
__anondb9326b24802(InferenceContext* c) 3139 .SetShapeFn([](InferenceContext* c) {
3140 ShapeHandle indices_shape;
3141 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
3142 ShapeHandle updates_shape;
3143 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
3144 ShapeHandle output_shape;
3145 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
3146 return shape_inference::ScatterNdShapeHelper(c, indices_shape,
3147 updates_shape, output_shape);
3148 });
3149
// Tensor-scatter family: each op takes (tensor/input, indices, updates) and
// uses ScatterNdTensorShape, so the output shape is taken from input 0. The
// combining rule (update/add/sub/min/max) is implemented by the kernels, not
// by these registrations.
REGISTER_OP("TensorScatterUpdate")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterAdd")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterSub")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterMin")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

REGISTER_OP("TensorScatterMax")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);

// Same shape contract as the TensorScatter* ops above, but T is restricted
// to numeric types and bool.
REGISTER_OP("ScatterNdNonAliasingAdd")
    .Input("input: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
3203
// Fake quantization with a fixed range given by the min/max attrs (defaults
// [-6, 6]). The output takes the shape of the first input (UnchangedShape).
REGISTER_OP("FakeQuantWithMinMaxArgs")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Output("outputs: float")
    .SetShapeFn(shape_inference::UnchangedShape);

// Gradient of FakeQuantWithMinMaxArgs. The backprops output takes the shape
// of the first input, `gradients` (UnchangedShape); `inputs` is not
// shape-checked here.
REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Output("backprops: float")
    .SetShapeFn(shape_inference::UnchangedShape);
3222
3223 REGISTER_OP("FakeQuantWithMinMaxVars")
3224 .Attr("num_bits: int = 8")
3225 .Attr("narrow_range: bool = false")
3226 .Input("inputs: float")
3227 .Input("min: float")
3228 .Input("max: float")
3229 .Output("outputs: float")
__anondb9326b24902(InferenceContext* c) 3230 .SetShapeFn([](InferenceContext* c) {
3231 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3232 ShapeHandle unused;
3233 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3234 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3235 return Status::OK();
3236 });
3237
3238 REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
3239 .Attr("num_bits: int = 8")
3240 .Attr("narrow_range: bool = false")
3241 .Input("gradients: float")
3242 .Input("inputs: float")
3243 .Input("min: float")
3244 .Input("max: float")
3245 .Output("backprops_wrt_input: float")
3246 .Output("backprop_wrt_min: float")
3247 .Output("backprop_wrt_max: float")
__anondb9326b24a02(InferenceContext* c) 3248 .SetShapeFn([](InferenceContext* c) {
3249 // gradients and inputs are same size.
3250 ShapeHandle inputs;
3251 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
3252
3253 // min and max are scalars
3254 ShapeHandle min_max;
3255 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
3256 TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
3257
3258 c->set_output(0, inputs);
3259 c->set_output(1, min_max);
3260 c->set_output(2, min_max);
3261 return Status::OK();
3262 });
3263
3264 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
3265 .Attr("num_bits: int = 8")
3266 .Attr("narrow_range: bool = false")
3267 .Input("inputs: float")
3268 .Input("min: float")
3269 .Input("max: float")
3270 .Output("outputs: float")
__anondb9326b24b02(InferenceContext* c) 3271 .SetShapeFn([](InferenceContext* c) {
3272 ShapeHandle input, min, max;
3273 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
3274 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
3275 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
3276
3277 DimensionHandle unused;
3278 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
3279 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
3280 TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));
3281
3282 c->set_output(0, input);
3283 return Status::OK();
3284 });
3285
3286 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
3287 .Attr("num_bits: int = 8")
3288 .Attr("narrow_range: bool = false")
3289 .Input("gradients: float")
3290 .Input("inputs: float")
3291 .Input("min: float")
3292 .Input("max: float")
3293 .Output("backprops_wrt_input: float")
3294 .Output("backprop_wrt_min: float")
3295 .Output("backprop_wrt_max: float")
__anondb9326b24c02(InferenceContext* c) 3296 .SetShapeFn([](InferenceContext* c) {
3297 ShapeHandle inputs;
3298 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
3299 TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
3300 TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
3301
3302 ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
3303
3304 ShapeHandle min_max;
3305 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
3306 TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
3307 TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
3308
3309 c->set_output(0, inputs);
3310 c->set_output(1, min_max);
3311 c->set_output(2, min_max);
3312 return Status::OK();
3313 });
3314
3315 REGISTER_OP("Fingerprint")
3316 .Input("data: T")
3317 .Input("method: string")
3318 .Output("fingerprint: uint8")
3319 .Attr("T: type")
__anondb9326b24d02(InferenceContext* c) 3320 .SetShapeFn([](InferenceContext* c) {
3321 ShapeHandle unused;
3322 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
3323 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3324
3325 DimensionHandle fingerprint_size;
3326 const Tensor* method = c->input_tensor(1);
3327 if (method == nullptr) {
3328 fingerprint_size = c->UnknownDim();
3329 } else {
3330 if (method->dims() != 0) {
3331 return errors::InvalidArgument("`method` must be rank 0: ",
3332 method->shape());
3333 }
3334 const string& method_string = method->scalar<tstring>()();
3335 if (method_string != "farmhash64") {
3336 return errors::InvalidArgument("Unsupported method: ", method_string);
3337 }
3338 fingerprint_size = c->MakeDim(sizeof(uint64));
3339 }
3340
3341 DimensionHandle batch = c->Dim(c->input(0), 0);
3342 c->set_output(0, c->MakeShape({batch, fingerprint_size}));
3343 return Status::OK();
3344 });
3345
3346 #ifdef INTEL_MKL
// MKL variant of Concat (see the .Doc string). Inputs are concat_dim, N data
// values, then a uint8 MKL metadata tensor for concat_dim and for each value,
// i.e. 2N + 2 inputs in total.
// NOTE(review): ConcatShape is called with num_inputs() - 3 == 2N - 1. If
// ConcatShape's second argument is the number of values to concatenate (as
// the QuantizedConcat shape function in this file uses it), that count equals
// N only when N == 1. Confirm whether this is intended, e.g. because shape
// inference never runs on rewritten _MklConcat nodes.
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatShape(c, c->num_inputs() - 3);
    })
    .Doc(R"doc(
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
3365 #endif
3366
3367 // Deprecated op registrations:
3368
3369 // The following can be deleted after 10mar2017.
// Each of these legacy ops is marked deprecated at version 14 and points to
// its Matrix* replacement. Their shape functions report a fully unknown
// output shape.
REGISTER_OP("BatchMatrixDiag")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixSetDiag")
    .Input("input: T")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixSetDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixDiagPart")
    .Input("input: T")
    .Output("diagonal: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiagPart")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixBandPart")
    .Input("input: T")
    .Input("num_lower: int64")
    .Input("num_upper: int64")
    .Output("band: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixBandPart")
    .SetShapeFn(shape_inference::UnknownShape);
3397
3398 } // namespace tensorflow
3399