1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/framework/tensor.pb.h"
20 #include "tensorflow/core/util/mirror_pad_mode.h"
21 #include "tensorflow/core/util/padding.h"
22 #include "tensorflow/core/util/strided_slice_op.h"
23 #include "tensorflow/core/util/tensor_format.h"
24
25 namespace tensorflow {
26
27 using shape_inference::DimensionHandle;
28 using shape_inference::InferenceContext;
29 using shape_inference::ShapeHandle;
30 using shape_inference::UnchangedShape;
31
32 namespace {
33
GetAxisForPackAndUnpack(InferenceContext * c,int32 rank_after_pack,int32 * axis)34 Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
35 int32* axis) {
36 TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
37 if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
38 return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
39 -1 * rank_after_pack, ",", rank_after_pack,
40 ")");
41 }
42 if (*axis < 0) *axis = (rank_after_pack + *axis);
43 return Status::OK();
44 }
45
46 template <typename T>
AsInt64(const Tensor * tensor,int64 num_elements)47 std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) {
48 std::vector<int64> ret(num_elements);
49 auto data = tensor->vec<T>();
50 for (int64 i = 0; i < num_elements; ++i) {
51 ret[i] = data(i);
52 }
53 return ret;
54 }
55
56 template <typename T>
PadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64 num_dims)57 Status PadKnown(InferenceContext* c, ShapeHandle input,
58 const Tensor* paddings_t, int64 num_dims) {
59 // paddings_t is known.
60 std::vector<DimensionHandle> dims(num_dims);
61 auto paddings_data = paddings_t->matrix<T>();
62 for (int64 i = 0; i < num_dims; ++i) {
63 const T pad0 = paddings_data(i, 0);
64 const T pad1 = paddings_data(i, 1);
65 if (pad0 < 0 || pad1 < 0) {
66 return errors::InvalidArgument("Paddings must be non-negative");
67 }
68 TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
69 }
70 c->set_output(0, c->MakeShape(dims));
71 return Status::OK();
72 }
73
PadShapeFn(InferenceContext * c)74 Status PadShapeFn(InferenceContext* c) {
75 // Paddings is a matrix of [input_rank, 2].
76 ShapeHandle paddings;
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
78 DimensionHandle unused;
79 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
80
81 // n_dim and input.rank are equivalent.
82 ShapeHandle input = c->input(0);
83 DimensionHandle n_dim = c->Dim(paddings, 0);
84 if (c->ValueKnown(n_dim)) {
85 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
86 } else if (c->RankKnown(input)) {
87 TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
88 }
89
90 const Tensor* paddings_t = c->input_tensor(1);
91
92 // paddings_t is unknown
93 if (paddings_t == nullptr) {
94 if (c->ValueKnown(n_dim)) {
95 // Make output with n_dim unknown dims.
96 c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
97 } else {
98 c->set_output(0, c->UnknownShape());
99 }
100 return Status::OK();
101 }
102
103 const int64 num_dims = paddings_t->shape().dim_size(0);
104 TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
105 TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
106
107 if (paddings_t->dtype() == DT_INT32) {
108 return PadKnown<int32>(c, input, paddings_t, num_dims);
109 } else {
110 return PadKnown<int64>(c, input, paddings_t, num_dims);
111 }
112 }
113
TransposeShapeFn(InferenceContext * c)114 Status TransposeShapeFn(InferenceContext* c) {
115 ShapeHandle input = c->input(0);
116 ShapeHandle perm_shape = c->input(1);
117 const Tensor* perm = c->input_tensor(1);
118 DimensionHandle perm_elems = c->NumElements(perm_shape);
119 // If we don't have rank information on the input or value information on
120 // perm we can't return any shape information, otherwise we have enough
121 // information to at least find the rank of the output.
122 if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
123 c->set_output(0, c->UnknownShape());
124 return Status::OK();
125 }
126
127 // Find our value of the rank.
128 int64 rank;
129 if (c->RankKnown(input)) {
130 rank = c->Rank(input);
131 } else if (c->ValueKnown(perm_elems)) {
132 rank = c->Value(perm_elems);
133 } else {
134 rank = perm->NumElements();
135 }
136 if (!c->RankKnown(input) && rank < 2) {
137 // A permutation array containing a single element is ambiguous. It could
138 // indicate either a scalar or a 1-dimensional array, both of which the
139 // transpose op returns unchanged.
140 c->set_output(0, input);
141 return Status::OK();
142 }
143
144 std::vector<DimensionHandle> dims;
145 dims.resize(rank);
146 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
147 // Ensure that perm is a vector and has rank elements.
148 TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
149 TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
150
151 // If we know the rank of the input and the value of perm, we can return
152 // all shape informantion, otherwise we can only return rank information,
153 // but no information for the dimensions.
154 if (perm != nullptr) {
155 std::vector<int64> data;
156 if (perm->dtype() == DT_INT32) {
157 data = AsInt64<int32>(perm, rank);
158 } else {
159 data = AsInt64<int64>(perm, rank);
160 }
161
162 for (int32 i = 0; i < rank; ++i) {
163 int64 in_idx = data[i];
164 if (in_idx >= rank) {
165 return errors::InvalidArgument("perm dim ", in_idx,
166 " is out of range of input rank ", rank);
167 }
168 dims[i] = c->Dim(input, in_idx);
169 }
170 } else {
171 for (int i = 0; i < rank; ++i) {
172 dims[i] = c->UnknownDim();
173 }
174 }
175
176 c->set_output(0, c->MakeShape(dims));
177 return Status::OK();
178 }
179
SetOutputShapeForReshape(InferenceContext * c)180 Status SetOutputShapeForReshape(InferenceContext* c) {
181 ShapeHandle in = c->input(0);
182 ShapeHandle out;
183 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
184
185 if (!c->RankKnown(out)) {
186 // We have no information about the shape of the output.
187 c->set_output(0, out);
188 return Status::OK();
189 }
190
191 if (c->RankKnown(out) && c->RankKnown(in)) {
192 // We don't know the number of output elements, but we can try to infer
193 // the missing dimension.
194 bool too_many_unknown = false;
195 int32 out_unknown_idx = -1;
196
197 DimensionHandle known_out_elems = c->NumElements(out);
198 if (!c->ValueKnown(known_out_elems)) {
199 known_out_elems = c->MakeDim(1);
200 for (int32 i = 0; i < c->Rank(out); ++i) {
201 DimensionHandle dim = c->Dim(out, i);
202 if (!c->ValueKnown(dim)) {
203 if (out_unknown_idx >= 0) {
204 too_many_unknown = true;
205 break;
206 }
207 out_unknown_idx = i;
208 } else {
209 TF_RETURN_IF_ERROR(
210 c->Multiply(known_out_elems, dim, &known_out_elems));
211 }
212 }
213 }
214 int32 in_unknown_idx = -1;
215 DimensionHandle known_in_elems = c->NumElements(in);
216 if (!c->ValueKnown(known_in_elems)) {
217 known_in_elems = c->MakeDim(1);
218 for (int32 i = 0; i < c->Rank(in); ++i) {
219 DimensionHandle dim = c->Dim(in, i);
220 if (!c->ValueKnown(dim)) {
221 if (in_unknown_idx >= 0) {
222 too_many_unknown = true;
223 break;
224 }
225 in_unknown_idx = i;
226 } else {
227 TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
228 }
229 }
230 }
231
232 if (!too_many_unknown) {
233 if (in_unknown_idx < 0 && out_unknown_idx < 0) {
234 // Just check that the dimensions match.
235 if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
236 return errors::InvalidArgument(
237 "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
238 " elements to shape ", c->DebugString(out), " (",
239 c->DebugString(known_out_elems), " elements)");
240 }
241 } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
242 c->Value(known_out_elems) > 0) {
243 // Input fully known, infer the one missing output dim
244 DimensionHandle inferred_dim;
245 TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
246 true /* evenly_divisible */,
247 &inferred_dim));
248 TF_RETURN_IF_ERROR(
249 c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
250
251 } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
252 c->Value(known_in_elems) != 0) {
253 // Output fully known, infer the one missing input dim
254 DimensionHandle inferred_dim;
255 TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
256 true /* evenly_divisible */,
257 &inferred_dim));
258 DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
259 TF_RETURN_IF_ERROR(
260 c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
261 } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
262 // Exactly one unknown dimension in both input and output. These 2 are
263 // equal iff the known elements are equal.
264 if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
265 DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
266 TF_RETURN_IF_ERROR(
267 c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
268 }
269 }
270 }
271 }
272 c->set_output(0, out);
273 return Status::OK();
274 }
275
276 } // namespace
277
278 REGISTER_OP("ParallelConcat")
279 .Input("values: N * T")
280 .Output("output: T")
281 .Attr("N: int >= 1")
282 .Attr("T: type")
283 .Attr("shape: shape")
__anondb9326b20202(InferenceContext* c) 284 .SetShapeFn([](InferenceContext* c) {
285 // Validate that the shape attr is correct.
286 PartialTensorShape shape;
287 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
288 ShapeHandle passed_shape;
289 TF_RETURN_IF_ERROR(
290 c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
291 if (!c->FullyDefined(passed_shape)) {
292 return errors::InvalidArgument("shape attr must be fully defined.");
293 }
294 ShapeHandle cur;
295 TF_RETURN_IF_ERROR(c->ReplaceDim(
296 passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
297 &cur));
298 for (int i = 0; i < c->num_inputs(); ++i) {
299 if (!c->FullyDefined(c->input(i))) {
300 return errors::InvalidArgument(
301 "All input shapes must be fully defined.");
302 }
303 DimensionHandle unused;
304 if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
305 return errors::InvalidArgument("Size of first dimension must be 1.");
306 }
307 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
308 "From merging shape ", i,
309 " with other shapes.");
310 }
311
312 c->set_output(0, passed_shape);
313
314 return Status::OK();
315 });
316
317 REGISTER_OP("Pack")
318 .Input("values: N * T")
319 .Output("output: T")
320 .Attr("N: int >= 1")
321 .Attr("T: type")
322 .Attr("axis: int = 0")
__anondb9326b20302(InferenceContext* c) 323 .SetShapeFn([](InferenceContext* c) {
324 // Validate shapes of all inputs are compatible
325 ShapeHandle cur = c->input(c->num_inputs() - 1);
326 for (int i = c->num_inputs() - 2; i >= 0; --i) {
327 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
328 "From merging shape ", i,
329 " with other shapes.");
330 }
331 if (!c->RankKnown(cur)) {
332 c->set_output(0, c->UnknownShape());
333 return Status::OK();
334 }
335 // Determine the axis that will be added, converting from negative
336 // axes to a positive point per negative indexing rules.
337 int32 rank = c->Rank(cur);
338 int32 axis;
339 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
340
341 // Copy all dimensions over, inserting a dimension of value #inputs
342 // at <axis>.
343 std::vector<DimensionHandle> dims;
344 int index = 0;
345 while (index < axis) dims.push_back(c->Dim(cur, index++));
346 dims.push_back(c->MakeDim(c->num_inputs()));
347 while (index < rank) dims.push_back(c->Dim(cur, index++));
348
349 c->set_output(0, c->MakeShape(dims));
350 for (int i = 0; i < c->num_inputs(); ++i) {
351 auto* shape_and_type = c->input_handle_shapes_and_types(i);
352 if (shape_and_type) {
353 if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
354 c->set_output_handle_shapes_and_types(
355 0, std::vector<shape_inference::ShapeAndType>({}));
356 break;
357 }
358 }
359 }
360 return Status::OK();
361 });
362
363 REGISTER_OP("DeepCopy")
364 .Input("x: T")
365 .Output("y: T")
366 .Attr("T: type")
367 .SetIsStateful()
368 .SetShapeFn(UnchangedShape);
369
370 REGISTER_OP("InplaceUpdate")
371 .Input("x: T")
372 .Input("i: int32")
373 .Input("v: T")
374 .Output("y: T")
375 .Attr("T: type")
376 .SetShapeFn(UnchangedShape);
377
378 REGISTER_OP("InplaceAdd")
379 .Input("x: T")
380 .Input("i: int32")
381 .Input("v: T")
382 .Output("y: T")
383 .Attr("T: type")
384 .SetShapeFn(UnchangedShape);
385
386 REGISTER_OP("InplaceSub")
387 .Input("x: T")
388 .Input("i: int32")
389 .Input("v: T")
390 .Output("y: T")
391 .Attr("T: type")
392 .SetShapeFn(UnchangedShape);
393
394 REGISTER_OP("Empty")
395 .Input("shape: int32")
396 .Output("output: dtype")
397 .Attr("dtype: type")
398 .Attr("init: bool = false")
399 .SetIsStateful()
__anondb9326b20402(InferenceContext* c) 400 .SetShapeFn([](InferenceContext* c) {
401 ShapeHandle out;
402 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
403 c->set_output(0, out);
404 return Status::OK();
405 });
406
407 // --------------------------------------------------------------------------
408 REGISTER_OP("Unpack")
409 .Input("value: T")
410 .Output("output: num * T")
411 .Attr("num: int >= 0")
412 .Attr("T: type")
413 .Attr("axis: int = 0")
__anondb9326b20502(InferenceContext* c) 414 .SetShapeFn([](InferenceContext* c) {
415 ShapeHandle s = c->input(0);
416 ShapeHandle out;
417 if (c->RankKnown(s)) {
418 // Determine the axis that will be removed, converting from negative
419 // axes to a positive point per negative indexing rules.
420 int32 rank = c->Rank(s);
421 int32 axis;
422 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
423
424 // The axis dim matches the number of outputs.
425 DimensionHandle unused;
426 TF_RETURN_IF_ERROR(
427 c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
428
429 // Copy all dimensions, removing the <axis> dimension.
430 std::vector<DimensionHandle> dims;
431 for (int i = 0; i < rank; ++i) {
432 if (i != axis) dims.push_back(c->Dim(s, i));
433 }
434 out = c->MakeShape(dims);
435 } else {
436 // All outputs are the same shape, but it's not known.
437 out = c->UnknownShape();
438 }
439 for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
440 return Status::OK();
441 });
442
443 REGISTER_OP("UnravelIndex")
444 .Input("indices: Tidx")
445 .Input("dims: Tidx")
446 .Output("output: Tidx")
447 .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b20602(InferenceContext* c) 448 .SetShapeFn([](InferenceContext* c) {
449 ShapeHandle indices = c->input(0);
450 ShapeHandle dims;
451 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
452 if (c->RankKnown(indices) && c->Rank(indices) == 0) {
453 c->set_output(0, c->Vector(c->Dim(dims, 0)));
454 } else if (c->RankKnown(indices)) {
455 c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
456 } else {
457 c->set_output(0, c->UnknownShape());
458 }
459 return Status::OK();
460 });
461
462 REGISTER_OP("BroadcastTo")
463 .Input("input: T")
464 .Input("shape: Tidx")
465 .Output("output: T")
466 .Attr("T: type")
467 .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b20702(InferenceContext* c) 468 .SetShapeFn([](InferenceContext* c) {
469 ShapeHandle shape_in = c->input(1);
470 TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
471 ShapeHandle out;
472 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
473 if (!c->RankKnown(out)) {
474 // We have no information about the shape of the output.
475 c->set_output(0, out);
476 return Status::OK();
477 }
478
479 ShapeHandle in = c->input(0);
480 if (!c->RankKnown(in)) {
481 // We have no information about the shape of the input,
482 // nothing to do here.
483 c->set_output(0, out);
484 return Status::OK();
485 }
486 int out_rank = c->Rank(out);
487 TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
488 int in_rank = c->Rank(in);
489 for (int i = 0; i < in_rank; ++i) {
490 auto in_dim = c->Dim(in, in_rank - i - 1);
491 if (c->Value(in_dim) > 1) {
492 // If the input dimension is greater than 1 then the output dimension
493 // must be equal to it, since we only broadcast "from left to right".
494 auto out_dim = c->Dim(out, out_rank - i - 1);
495 TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
496 TF_RETURN_IF_ERROR(
497 c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
498 }
499 }
500 c->set_output(0, out);
501 return Status::OK();
502 });
503
504 // --------------------------------------------------------------------------
505 // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
506 // in the N == 1 case to remove the node.
507 REGISTER_OP("Concat")
508 .Input("concat_dim: int32")
509 .Input("values: N * T")
510 .Output("output: T")
511 .Attr("N: int >= 2")
512 .Attr("T: type")
__anondb9326b20802(InferenceContext* c) 513 .SetShapeFn([](InferenceContext* c) {
514 return shape_inference::ConcatShape(c, c->num_inputs() - 1);
515 });
516
517 REGISTER_OP("ConcatV2")
518 .Input("values: N * T")
519 .Input("axis: Tidx")
520 .Output("output: T")
521 .Attr("N: int >= 2")
522 .Attr("T: type")
523 .Attr("Tidx: {int32, int64} = DT_INT32")
524 .SetShapeFn(shape_inference::ConcatV2Shape);
525
// TODO(vivek.v.rane@intel.com): Prefix the op names with an underscore if the
// ops are not meant to be user-accessible.
#ifdef INTEL_MKL
// MKL variant of ConcatV2. The extra uint8 `mkl_*` inputs and output
// presumably carry MKL layout metadata (TODO confirm). Shape inference reuses
// ConcatV2Shape.
REGISTER_OP("_MklConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatV2Shape(c);
    })
    .Doc(R"doc(
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
547
548 REGISTER_OP("ConcatOffset")
549 .Input("concat_dim: int32")
550 .Input("shape: N * int32")
551 .Output("offset: N * int32")
552 .Attr("N: int >= 2")
__anondb9326b20902(InferenceContext* c) 553 .SetShapeFn([](InferenceContext* c) {
554 for (int i = 1; i < c->num_inputs(); ++i) {
555 c->set_output(i - 1, c->input(i));
556 }
557 return Status::OK();
558 });
559
560 // --------------------------------------------------------------------------
561 REGISTER_OP("Split")
562 .Input("split_dim: int32")
563 .Input("value: T")
564 .Output("output: num_split * T")
565 .Attr("num_split: int >= 1")
566 .Attr("T: type")
__anondb9326b20a02(InferenceContext* c) 567 .SetShapeFn([](InferenceContext* c) {
568 DimensionHandle split_dimension;
569 ShapeHandle input = c->input(1);
570 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
571 0, c->Rank(input), &split_dimension));
572 int num_split = c->num_outputs();
573 ShapeHandle out;
574 if (!c->ValueKnown(split_dimension)) {
575 if (c->RankKnown(input)) {
576 out = c->UnknownShapeOfRank(c->Rank(input));
577 } else {
578 out = c->UnknownShape();
579 }
580 } else {
581 int64 split_dim = c->Value(split_dimension);
582 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
583 DimensionHandle split_dim_size;
584 TF_RETURN_WITH_CONTEXT_IF_ERROR(
585 c->Divide(c->Dim(input, split_dim), num_split,
586 true /* evenly_divisible */, &split_dim_size),
587 "Number of ways to split should evenly divide the split dimension");
588 TF_RETURN_IF_ERROR(
589 c->ReplaceDim(input, split_dim, split_dim_size, &out));
590 }
591 for (int i = 0; i < num_split; ++i) c->set_output(i, out);
592 return Status::OK();
593 });
594
595 REGISTER_OP("SplitV")
596 .Input("value: T")
597 .Input("size_splits: Tlen")
598 .Input("split_dim: int32")
599 .Output("output: num_split * T")
600 .Attr("num_split: int >= 1")
601 .Attr("T: type")
602 .Attr("Tlen: {int32, int64} = DT_INT64")
__anondb9326b20b02(InferenceContext* c) 603 .SetShapeFn([](InferenceContext* c) {
604 DimensionHandle split_dimension;
605 ShapeHandle input = c->input(0);
606 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
607 2, c->Rank(input), &split_dimension));
608 int32 num_outputs = c->num_outputs();
609 int32 rank = c->Rank(input);
610 ShapeHandle output_shape;
611 const Tensor* size_splits = c->input_tensor(1);
612 if (rank == InferenceContext::kUnknownRank) {
613 // If the rank of input tensor is unknown, then return unknown shapes.
614 // Note that the shape of each output can be different.
615 for (int i = 0; i < num_outputs; ++i) {
616 c->set_output(i, c->UnknownShape());
617 }
618 } else if (rank == 0) {
619 // Throw error if input is a scalar.
620 return errors::InvalidArgument("Can't split scalars");
621 } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
622 // If split dimension is known, but the sizes are unknown, then
623 // only the split dimension is unknown
624 output_shape = input;
625 for (int i = 0; i < num_outputs; ++i) {
626 TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
627 c->Value(split_dimension),
628 c->UnknownDim(), &output_shape));
629 c->set_output(i, output_shape);
630 }
631 } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
632 // If split dimension or tensor containing the split sizes is unknown,
633 // then return unknown shapes of same rank as input. Note that each
634 // output shape can be different since splitv doesn't always split
635 // tensors evenly.
636 for (int i = 0; i < num_outputs; ++i) {
637 c->set_output(i, c->UnknownShapeOfRank(rank));
638 }
639 } else {
640 // Determine the output shape if split dimension and split sizes are
641 // known.
642 int64 split_dim = c->Value(split_dimension);
643 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
644 std::vector<int64> data;
645 if (size_splits->dtype() == DT_INT32) {
646 data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
647 } else {
648 data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
649 }
650 if (num_outputs != data.size()) {
651 return errors::InvalidArgument(
652 "Length of size_splits should be equal to num_outputs");
653 }
654 int64_t total_size = 0;
655 bool has_neg_one = false;
656 for (const auto size : data) {
657 if (size == -1) {
658 if (has_neg_one) {
659 return errors::InvalidArgument(
660 "size_splits can only have one -1");
661 }
662 has_neg_one = true;
663 } else {
664 total_size += size;
665 }
666 }
667 auto split_dim_size = c->Value(c->Dim(input, split_dim));
668 // If the sizes of the splits are known, then
669 // make sure that the sizes add up to the expected
670 // dimension size, with the possibility of a -1.
671 // Specify the full output shapes.
672 for (int i = 0; i < num_outputs; ++i) {
673 auto size = data[i];
674 if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
675 size = split_dim_size - total_size;
676 }
677 TF_RETURN_IF_ERROR(
678 c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
679 c->set_output(i, output_shape);
680 }
681 if (c->ValueKnown(split_dim_size)) {
682 if (has_neg_one ? total_size > split_dim_size
683 : total_size != split_dim_size) {
684 return errors::InvalidArgument(
685 "can't split axis of size ", split_dim_size,
686 " into pieces of size [", str_util::Join(data, ","), "]");
687 }
688 }
689 }
690
691 return Status::OK();
692 });
693
694 // --------------------------------------------------------------------------
695 REGISTER_OP("Const")
696 .Output("output: dtype")
697 .Attr("value: tensor")
698 .Attr("dtype: type")
__anondb9326b20c02(InferenceContext* c) 699 .SetShapeFn([](InferenceContext* c) {
700 const TensorProto* proto = nullptr;
701 TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
702 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
703 TensorShape shape(proto->tensor_shape());
704 std::vector<DimensionHandle> dims;
705 dims.reserve(shape.dims());
706 for (int i = 0; i < shape.dims(); ++i) {
707 dims.push_back(c->MakeDim(shape.dim_size(i)));
708 }
709 c->set_output(0, c->MakeShape(dims));
710 return Status::OK();
711 });
712
713 // Returns a constant tensor on the host. Useful for writing C++ tests
714 // and benchmarks which run on GPU but require arguments pinned to the host.
715 // Used by test::graph::HostConstant.
716 // value: Attr `value` is the tensor to return.
717 REGISTER_OP("HostConst")
718 .Output("output: dtype")
719 .Attr("value: tensor")
720 .Attr("dtype: type")
721 .SetShapeFn(shape_inference::UnknownShape);
722
723 // --------------------------------------------------------------------------
724 // TODO(mgubin): Update the doc when the freeze_graph script supports converting
725 // into memmapped format.
726 REGISTER_OP("ImmutableConst")
727 .Attr("dtype: type")
728 .Attr("shape: shape")
729 .Attr("memory_region_name: string")
730 .Output("tensor: dtype")
731 .SetShapeFn(shape_inference::ExplicitShape);
732
733 REGISTER_OP("GuaranteeConst")
734 .Input("input: T")
735 .Output("output: T")
736 .Attr("T: type")
__anondb9326b20d02(shape_inference::InferenceContext* c) 737 .SetShapeFn([](shape_inference::InferenceContext* c) {
738 return UnchangedShape(c);
739 })
740 // We don't want this to be optimized away.
741 .SetIsStateful();
742
743 // --------------------------------------------------------------------------
744 REGISTER_OP("ZerosLike")
745 .Input("x: T")
746 .Output("y: T")
747 .Attr("T: type")
748 .SetShapeFn(shape_inference::UnchangedShape);
749
750 // --------------------------------------------------------------------------
751 REGISTER_OP("OnesLike")
752 .Input("x: T")
753 .Output("y: T")
754 .Attr(
755 "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
756 "int64, complex64, complex128, bool}")
757 .SetShapeFn(shape_inference::UnchangedShape);
758
759 // --------------------------------------------------------------------------
760 REGISTER_OP("Diag")
761 .Input("diagonal: T")
762 .Output("output: T")
763 .Attr(
764 "T: {bfloat16, half, float, double, int32, int64, complex64, "
765 "complex128}")
__anondb9326b20e02(InferenceContext* c) 766 .SetShapeFn([](InferenceContext* c) {
767 ShapeHandle in = c->input(0);
768 TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
769 // Output shape is original concatenated with itself.
770 ShapeHandle out;
771 TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
772 c->set_output(0, out);
773 return Status::OK();
774 });
775
776 // --------------------------------------------------------------------------
777 REGISTER_OP("DiagPart")
778 .Input("input: T")
779 .Output("diagonal: T")
780 .Attr(
781 "T: {bfloat16, half, float, double, int32, int64, complex64, "
782 "complex128}")
__anondb9326b20f02(InferenceContext* c) 783 .SetShapeFn([](InferenceContext* c) {
784 ShapeHandle in = c->input(0);
785 if (!c->RankKnown(in)) {
786 c->set_output(0, c->UnknownShape());
787 return Status::OK();
788 }
789 // Rank must be even, and result will have rank <rank/2>.
790 const int32 rank = c->Rank(in);
791 if ((rank % 2) != 0 || rank <= 0) {
792 return errors::InvalidArgument(
793 "Input must have even and non-zero rank, input rank is ", rank);
794 }
795 const int32 mid = rank / 2;
796
797 // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
798 std::vector<DimensionHandle> dims(mid);
799 for (int i = 0; i < mid; ++i) {
800 TF_RETURN_IF_ERROR(
801 c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
802 }
803 c->set_output(0, c->MakeShape(dims));
804 return Status::OK();
805 });
806
807 // --------------------------------------------------------------------------
808 REGISTER_OP("MatrixDiag")
809 .Input("diagonal: T")
810 .Output("output: T")
811 .Attr("T: type")
__anondb9326b21002(InferenceContext* c) 812 .SetShapeFn([](InferenceContext* c) {
813 ShapeHandle in;
814 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
815 if (!c->RankKnown(in)) {
816 c->set_output(0, c->UnknownShape());
817 return Status::OK();
818 }
819 const int32 rank = c->Rank(in);
820 ShapeHandle out;
821 TF_RETURN_IF_ERROR(
822 c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
823 c->set_output(0, out);
824 return Status::OK();
825 });
826
827 // --------------------------------------------------------------------------
828 REGISTER_OP("MatrixSetDiag")
829 .Input("input: T")
830 .Input("diagonal: T")
831 .Output("output: T")
832 .Attr("T: type")
__anondb9326b21102(InferenceContext* c) 833 .SetShapeFn([](InferenceContext* c) {
834 ShapeHandle input;
835 ShapeHandle diag;
836 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
837 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
838 if (c->RankKnown(input)) {
839 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
840 }
841 DimensionHandle smallest_dim;
842 TF_RETURN_IF_ERROR(
843 c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
844 TF_RETURN_IF_ERROR(
845 c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
846
847 ShapeHandle output = input;
848 if (c->RankKnown(diag) && !c->FullyDefined(input)) {
849 // Try to infer parts of shape from diag.
850 ShapeHandle diag_prefix;
851 TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix));
852 TF_RETURN_IF_ERROR(
853 c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag));
854 TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
855 }
856 c->set_output(0, output);
857 return Status::OK();
858 });
859
860 // --------------------------------------------------------------------------
861 REGISTER_OP("MatrixDiagPart")
862 .Input("input: T")
863 .Output("diagonal: T")
864 .Attr("T: type")
__anondb9326b21202(InferenceContext* c) 865 .SetShapeFn([](InferenceContext* c) {
866 ShapeHandle in;
867 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
868 if (!c->RankKnown(in)) {
869 c->set_output(0, c->UnknownShape());
870 return Status::OK();
871 }
872 const int32 rank = c->Rank(in);
873 std::vector<DimensionHandle> dims;
874 dims.reserve(rank - 2);
875 for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
876
877 DimensionHandle min_dim;
878 TF_RETURN_IF_ERROR(
879 c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
880 dims.push_back(min_dim);
881 c->set_output(0, c->MakeShape(dims));
882 return Status::OK();
883 });
884
885 // --------------------------------------------------------------------------
886 REGISTER_OP("MatrixBandPart")
887 .Input("input: T")
888 .Input("num_lower: Tindex")
889 .Input("num_upper: Tindex")
890 .Output("band: T")
891 .Attr("T: type")
892 .Attr("Tindex: {int32, int64} = DT_INT64")
893 .SetShapeFn(shape_inference::UnchangedShape);
894
895 // --------------------------------------------------------------------------
896 REGISTER_OP("Reverse")
897 .Input("tensor: T")
898 .Input("dims: bool")
899 .Output("output: T")
900 .Attr(
901 "T: {uint8, int8, uint16, int16, int32, int64, bool, half, "
902 "float, double, complex64, complex128, string}")
__anondb9326b21302(InferenceContext* c) 903 .SetShapeFn([](InferenceContext* c) {
904 ShapeHandle input = c->input(0);
905 ShapeHandle dims;
906 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
907 DimensionHandle dims_dim = c->Dim(dims, 0);
908 if (c->ValueKnown(dims_dim)) {
909 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
910 }
911 if (c->Rank(input) > 8) {
912 return errors::InvalidArgument(
913 "reverse does not work on tensors with more than 8 dimensions");
914 }
915 c->set_output(0, input);
916 return Status::OK();
917 });
918
919 // --------------------------------------------------------------------------
920 REGISTER_OP("ReverseV2")
921 .Input("tensor: T")
922 .Input("axis: Tidx")
923 .Output("output: T")
924 .Attr("Tidx: {int32, int64} = DT_INT32")
925 .Attr(
926 "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
927 "float, double, complex64, complex128, string}")
__anondb9326b21402(InferenceContext* c) 928 .SetShapeFn([](InferenceContext* c) {
929 ShapeHandle input = c->input(0);
930 ShapeHandle axis;
931 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
932 if (c->Rank(input) > 8) {
933 return errors::InvalidArgument(
934 "reverse does not work on tensors with more than 8 dimensions");
935 }
936 const Tensor* axis_tensor = c->input_tensor(1);
937 if (axis_tensor != nullptr && c->RankKnown(input)) {
938 int32 rank = c->Rank(input);
939 std::vector<int64> axis_value;
940 if (axis_tensor->dtype() == DT_INT32) {
941 axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
942 } else {
943 axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements());
944 }
945 std::vector<bool> axes_dense(c->Rank(input), false);
946 for (int i = 0; i < axis_value.size(); i++) {
947 int64 canonical_axis =
948 axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
949 if (canonical_axis < 0 || canonical_axis >= rank) {
950 return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
951 " is out of valid range [", 0, ", ",
952 rank - 1);
953 }
954 if (axes_dense[canonical_axis]) {
955 return errors::InvalidArgument("axis ", canonical_axis,
956 " specified more than once.");
957 }
958 axes_dense[canonical_axis] = true;
959 }
960 }
961 c->set_output(0, input);
962 return Status::OK();
963 });
964
965 // --------------------------------------------------------------------------
966 REGISTER_OP("EditDistance")
967 .Input("hypothesis_indices: int64")
968 .Input("hypothesis_values: T")
969 .Input("hypothesis_shape: int64")
970 .Input("truth_indices: int64")
971 .Input("truth_values: T")
972 .Input("truth_shape: int64")
973 .Attr("normalize: bool = true")
974 .Attr("T: type")
975 .Output("output: float")
__anondb9326b21502(InferenceContext* c) 976 .SetShapeFn([](InferenceContext* c) {
977 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
978 c, c->input(0), c->input(1), c->input(2)));
979 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
980 c, c->input(3), c->input(4), c->input(5)));
981 const Tensor* hypothesis_shape_t = c->input_tensor(2);
982 const Tensor* truth_shape_t = c->input_tensor(5);
983 if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
984 // We need to know the runtime shape of the two tensors,
985 // or else the output shape is unknown.
986 return shape_inference::UnknownShape(c);
987 }
988
989 if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
990 return errors::InvalidArgument(
991 "Num elements of hypothesis_shape does not match truth_shape: ",
992 hypothesis_shape_t->NumElements(), " vs. ",
993 truth_shape_t->NumElements());
994 }
995
996 auto h_values = hypothesis_shape_t->flat<int64>();
997 auto t_values = truth_shape_t->flat<int64>();
998 std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
999 for (int i = 0; i < dims.size(); ++i) {
1000 dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
1001 }
1002
1003 c->set_output(0, c->MakeShape(dims));
1004 return Status::OK();
1005 });
1006
1007 // --------------------------------------------------------------------------
1008 REGISTER_OP("Fill")
1009 .Input("dims: index_type")
1010 .Input("value: T")
1011 .Output("output: T")
1012 .Attr("T: type")
1013 .Attr("index_type: {int32, int64} = DT_INT32")
__anondb9326b21602(InferenceContext* c) 1014 .SetShapeFn([](InferenceContext* c) {
1015 DataType index_type = DT_INT32;
1016 Status s = c->GetAttr("index_type", &index_type);
1017 if (!s.ok() && s.code() != error::NOT_FOUND) {
1018 return s;
1019 }
1020 ShapeHandle unused;
1021 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1022 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1023
1024 const Tensor* t = c->input_tensor(0);
1025 if (t != nullptr) {
1026 for (int i = 0; i < t->NumElements(); ++i) {
1027 if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
1028 (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
1029 return errors::InvalidArgument("Fill dimensions must be >= 0");
1030 }
1031 }
1032 }
1033
1034 ShapeHandle out;
1035 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1036 c->set_output(0, out);
1037
1038 auto* shape_and_type = c->input_handle_shapes_and_types(1);
1039 if (shape_and_type) {
1040 c->set_output_handle_shapes_and_types(0, *shape_and_type);
1041 }
1042
1043 return Status::OK();
1044 });
1045
1046 // --------------------------------------------------------------------------
// _ParallelConcatStart: underscore-prefixed (presumably internal) op.
// It is marked stateful, and its output shape is taken from the `shape`
// attr via shape_inference::ExplicitShape.
1047 REGISTER_OP("_ParallelConcatStart")
1048 .Output("output: dtype")
1049 .Attr("shape: shape")
1050 .Attr("dtype: type")
1051 .SetIsStateful()
1052 .SetShapeFn(shape_inference::ExplicitShape)
1053 .Doc(R"doc(
1054 Creates an empty Tensor with shape `shape` and type `dtype`.
1055
1056 The memory can optionally be initialized. This is usually useful in
1057 conjunction with inplace operations.
1058
1059 shape: 1-D `Tensor` indicating the shape of the output.
1060 dtype: The element type of the returned tensor.
1061 output: An empty Tensor of the specified type.
1062 )doc");
1063
1064 // --------------------------------------------------------------------------
// _ParallelConcatUpdate: the output shape comes from
// shape_inference::UnchangedShape (the shape of `value`).
// `loc` is an int attr, not an input. The shape function applies no check
// relating `update` to `value`.
1065 REGISTER_OP("_ParallelConcatUpdate")
1066 .Input("value: T")
1067 .Input("update: T")
1068 .Output("output: T")
1069 .Attr("T: type")
1070 .Attr("loc: int")
1071 .SetShapeFn(shape_inference::UnchangedShape)
1072 .Doc(R"doc(
1073 Updates input `value` at `loc` with `update`.
1074
1075 If you use this function you will almost certainly want to add
1076 a control dependency as done in the implementation of parallel_stack to
1077 avoid race conditions.
1078
1079 value: A `Tensor` object that will be updated in-place.
1080 loc: A scalar indicating the index of the first dimension such that
1081 value[loc, :] is updated.
1082 update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
1083 otherwise of rank equal to `value` that contains the new values
1084 for `value`.
1085 output: `value` that has been updated accordingly.
1086 )doc");
1087
1088 // --------------------------------------------------------------------------
1089 REGISTER_OP("Gather")
1090 .Input("params: Tparams")
1091 .Input("indices: Tindices")
1092 .Attr("validate_indices: bool = true")
1093 .Output("output: Tparams")
1094 .Attr("Tparams: type")
1095 .Attr("Tindices: {int32,int64}")
__anondb9326b21702(InferenceContext* c) 1096 .SetShapeFn([](InferenceContext* c) {
1097 ShapeHandle unused;
1098 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
1099 ShapeHandle params_subshape;
1100 TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, ¶ms_subshape));
1101 ShapeHandle indices_shape = c->input(1);
1102 ShapeHandle out;
1103 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
1104 c->set_output(0, out);
1105 return Status::OK();
1106 });
1107
1108 // --------------------------------------------------------------------------
1109 REGISTER_OP("GatherV2")
1110 .Input("params: Tparams")
1111 .Input("indices: Tindices")
1112 .Input("axis: Taxis")
1113 .Output("output: Tparams")
1114 .Attr("Tparams: type")
1115 .Attr("Tindices: {int32,int64}")
1116 .Attr("Taxis: {int32,int64}")
__anondb9326b21802(InferenceContext* c) 1117 .SetShapeFn([](InferenceContext* c) {
1118 ShapeHandle params_shape;
1119 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, ¶ms_shape));
1120
1121 ShapeHandle indices_shape = c->input(1);
1122 ShapeHandle unused_axis_shape;
1123 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
1124 const Tensor* axis_t = c->input_tensor(2);
1125
1126 // If axis is unknown, we can only infer that the result is params_rank +
1127 // indices_rank - 1.
1128 if (axis_t == nullptr) {
1129 if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
1130 c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
1131 c->Rank(indices_shape) - 1));
1132 } else {
1133 c->set_output(0, c->UnknownShape());
1134 }
1135 return Status::OK();
1136 }
1137
1138 // Note, axis can be negative.
1139 int64 axis = 0;
1140 if (axis_t->dtype() == DT_INT32) {
1141 axis = axis_t->scalar<int32>()();
1142 } else {
1143 axis = axis_t->scalar<int64>()();
1144 }
1145
1146 // Check that params has rank of at least axis + 1.
1147 ShapeHandle unused;
1148 TF_RETURN_IF_ERROR(c->WithRankAtLeast(
1149 params_shape, axis < 0 ? -axis : axis + 1, &unused));
1150
1151 ShapeHandle params_outer_subshape;
1152 TF_RETURN_IF_ERROR(
1153 c->Subshape(params_shape, 0, axis, ¶ms_outer_subshape));
1154
1155 ShapeHandle out;
1156 TF_RETURN_IF_ERROR(
1157 c->Concatenate(params_outer_subshape, indices_shape, &out));
1158
1159 // Slice from axis + 1 to the end of params_shape to collect the inner
1160 // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
1161 // we slice from 0 to the end of shape. Subshape() handles all other
1162 // out-of-bounds checking.
1163 if (axis != -1) {
1164 ShapeHandle params_inner_subshape;
1165 TF_RETURN_IF_ERROR(
1166 c->Subshape(params_shape, axis + 1, ¶ms_inner_subshape));
1167 TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
1168 }
1169
1170 c->set_output(0, out);
1171 return Status::OK();
1172 });
1173
1174 // --------------------------------------------------------------------------
1175 REGISTER_OP("GatherNd")
1176 .Input("params: Tparams")
1177 .Input("indices: Tindices")
1178 .Output("output: Tparams")
1179 .Attr("Tparams: type")
1180 .Attr("Tindices: {int32,int64}")
__anondb9326b21902(InferenceContext* c) 1181 .SetShapeFn([](InferenceContext* c) {
1182 ShapeHandle params = c->input(0);
1183 ShapeHandle indices;
1184 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
1185 DimensionHandle r_dim = c->Dim(indices, -1);
1186
1187 if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
1188 c->set_output(0, c->UnknownShape());
1189 return Status::OK();
1190 }
1191
1192 if (c->Value(r_dim) > c->Rank(params)) {
1193 return errors::InvalidArgument(
1194 "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
1195 c->DebugString(indices),
1196 " and params shape: ", c->DebugString(params));
1197 }
1198
1199 // Remove r_dim from indices to get output.
1200 ShapeHandle indices_slice;
1201 ShapeHandle params_slice;
1202 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
1203 TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), ¶ms_slice));
1204 ShapeHandle out;
1205 TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
1206 c->set_output(0, out);
1207 return Status::OK();
1208 });
1209
1210 // --------------------------------------------------------------------------
// Identity: one input, one output of the same type T. The output shape comes
// from shape_inference::UnchangedShape (the input's shape).
1211 REGISTER_OP("Identity")
1212 .Input("input: T")
1213 .Output("output: T")
1214 .Attr("T: type")
1215 .SetShapeFn(shape_inference::UnchangedShape);
1216
// Snapshot: same signature and shape function as Identity.
1217 REGISTER_OP("Snapshot")
1218 .Input("input: T")
1219 .Output("output: T")
1220 .Attr("T: type")
1221 .SetShapeFn(shape_inference::UnchangedShape);
1222
// _MklIdentity (registered only when INTEL_MKL is defined): Identity plus an
// extra uint8 `mkl_input` -> `mkl_output` pair.
// The shape function (UnchangedShape) only propagates the shape of `input`.
1223 #ifdef INTEL_MKL
1224 REGISTER_OP("_MklIdentity")
1225 .Input("input: T")
1226 .Input("mkl_input: uint8")
1227 .Output("output: T")
1228 .Output("mkl_output: uint8")
1229 .Attr("T: type")
1230 .SetShapeFn(shape_inference::UnchangedShape)
1231 .Doc(R"Doc( Mkl implementation of IdentityOp
1232 )Doc");
1233 #endif
1234
1235 REGISTER_OP("IdentityN")
1236 .Input("input: T")
1237 .Output("output: T")
1238 .Attr("T: list(type)")
__anondb9326b21a02(shape_inference::InferenceContext* c) 1239 .SetShapeFn([](shape_inference::InferenceContext* c) {
1240 std::vector<ShapeHandle> input;
1241 TF_RETURN_IF_ERROR(c->input("input", &input));
1242 TF_RETURN_IF_ERROR(c->set_output("output", input));
1243 return Status::OK();
1244 });
1245
1246 // --------------------------------------------------------------------------
// RefIdentity: Identity on a Ref(T) input/output.
// SetAllowsUninitializedInput() permits an uninitialized input.
1247 REGISTER_OP("RefIdentity")
1248 .Input("input: Ref(T)")
1249 .Output("output: Ref(T)")
1250 .Attr("T: type")
1251 .SetShapeFn(shape_inference::UnchangedShape)
1252 .SetAllowsUninitializedInput();
1253
1254 // --------------------------------------------------------------------------
// DebugGradientIdentity / DebugGradientRefIdentity: value and Ref(T)
// variants. Both keep the input shape and allow an uninitialized input.
1255 REGISTER_OP("DebugGradientIdentity")
1256 .Input("input: T")
1257 .Output("output: T")
1258 .Attr("T: type")
1259 .SetShapeFn(shape_inference::UnchangedShape)
1260 .SetAllowsUninitializedInput();
1261
1262 REGISTER_OP("DebugGradientRefIdentity")
1263 .Input("input: Ref(T)")
1264 .Output("output: Ref(T)")
1265 .Attr("T: type")
1266 .SetShapeFn(shape_inference::UnchangedShape)
1267 .SetAllowsUninitializedInput();
1268
1269 // --------------------------------------------------------------------------
// StopGradient / PreventGradient: the output shape equals the input shape.
// PreventGradient also carries a `message` string attr (default empty).
1270 REGISTER_OP("StopGradient")
1271 .Input("input: T")
1272 .Output("output: T")
1273 .Attr("T: type")
1274 .SetShapeFn(shape_inference::UnchangedShape);
1275
1276 REGISTER_OP("PreventGradient")
1277 .Input("input: T")
1278 .Output("output: T")
1279 .Attr("T: type")
1280 .Attr("message: string = ''")
1281 .SetShapeFn(shape_inference::UnchangedShape);
1282
1283 // --------------------------------------------------------------------------
// CheckNumerics: T is restricted to bfloat16/half/float/double. The
// `message` attr has no default and must be supplied. Output shape equals
// input shape.
1284 REGISTER_OP("CheckNumerics")
1285 .Input("tensor: T")
1286 .Output("output: T")
1287 .Attr("T: {bfloat16, half, float, double}")
1288 .Attr("message: string")
1289 .SetShapeFn(shape_inference::UnchangedShape);
1290
1291 // --------------------------------------------------------------------------
1292 REGISTER_OP("Reshape")
1293 .Input("tensor: T")
1294 .Input("shape: Tshape")
1295 .Output("output: T")
1296 .Attr("T: type")
1297 .Attr("Tshape: {int32, int64} = DT_INT32")
__anondb9326b21b02(InferenceContext* c) 1298 .SetShapeFn([](InferenceContext* c) {
1299 return SetOutputShapeForReshape(c);
1300 });
1301
// _MklReshape (registered only when INTEL_MKL is defined): Reshape plus
// uint8 MKL side inputs/outputs. Its shape function delegates to
// SetOutputShapeForReshape (defined elsewhere in this file).
1302 #ifdef INTEL_MKL
1303 REGISTER_OP("_MklReshape")
1304 .Input("tensor: T")
1305 .Input("shape: Tshape")
1306 .Input("mkl_tensor: uint8")
1307 .Input("mkl_shape: uint8")
1308 .Output("output: T")
1309 .Output("mkl_output: uint8")
1310 .Attr("T: type")
1311 .Attr("Tshape: {int32, int64} = DT_INT32")
__anondb9326b21c02(InferenceContext* c) 1312 .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
1313 .Doc(R"Doc( MKL implementation of ReshapeOp.
1314 )Doc");
1315 #endif // INTEL_MKL
1316
1317 // --------------------------------------------------------------------------
1318 REGISTER_OP("InvertPermutation")
1319 .Input("x: T")
1320 .Output("y: T")
1321 .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b21d02(InferenceContext* c) 1322 .SetShapeFn([](InferenceContext* c) {
1323 ShapeHandle x;
1324 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
1325 c->set_output(0, x);
1326 return Status::OK();
1327 });
1328
1329 // --------------------------------------------------------------------------
// Transpose: shape inference is done by TransposeShapeFn (defined elsewhere
// in this file). `perm` is int32 (default) or int64.
1330 REGISTER_OP("Transpose")
1331 .Input("x: T")
1332 .Input("perm: Tperm")
1333 .Output("y: T")
1334 .Attr("T: type")
1335 .Attr("Tperm: {int32, int64} = DT_INT32")
1336 .SetShapeFn(TransposeShapeFn);
1337
1338 // --------------------------------------------------------------------------
// ConjugateTranspose: same signature and shape function as Transpose.
1339 REGISTER_OP("ConjugateTranspose")
1340 .Input("x: T")
1341 .Input("perm: Tperm")
1342 .Output("y: T")
1343 .Attr("T: type")
1344 .Attr("Tperm: {int32, int64} = DT_INT32")
1345 .SetShapeFn(TransposeShapeFn);
1346
1347 // --------------------------------------------------------------------------
1348 REGISTER_OP("Unique")
1349 .Input("x: T")
1350 .Output("y: T")
1351 .Output("idx: out_idx")
1352 .Attr("T: type")
1353 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b21e02(InferenceContext* c) 1354 .SetShapeFn([](InferenceContext* c) {
1355 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1356 c->set_output(1, c->input(0));
1357 // Assert that the input rank is 1.
1358 ShapeHandle dummy;
1359 return c->WithRank(c->input(0), 1, &dummy);
1360 });
1361
1362 REGISTER_OP("UniqueV2")
1363 .Input("x: T")
1364 .Input("axis: Taxis")
1365 .Output("y: T")
1366 .Output("idx: out_idx")
1367 .Attr("T: type")
1368 .Attr("Taxis: {int32,int64} = DT_INT64")
1369 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b21f02(InferenceContext* c) 1370 .SetShapeFn([](InferenceContext* c) {
1371 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1372 c->set_output(1, c->input(0));
1373 return Status::OK();
1374 });
1375
1376 // --------------------------------------------------------------------------
1377 REGISTER_OP("UniqueWithCounts")
1378 .Input("x: T")
1379 .Output("y: T")
1380 .Output("idx: out_idx")
1381 .Output("count: out_idx")
1382 .Attr("T: type")
1383 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b22002(InferenceContext* c) 1384 .SetShapeFn([](InferenceContext* c) {
1385 auto uniq = c->Vector(InferenceContext::kUnknownDim);
1386 c->set_output(0, uniq);
1387 c->set_output(1, c->input(0));
1388 c->set_output(2, uniq);
1389 return Status::OK();
1390 });
1391
1392 REGISTER_OP("UniqueWithCountsV2")
1393 .Input("x: T")
1394 .Input("axis: Taxis")
1395 .Output("y: T")
1396 .Output("idx: out_idx")
1397 .Output("count: out_idx")
1398 .Attr("T: type")
1399 .Attr("Taxis: {int32,int64} = DT_INT64")
1400 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b22102(InferenceContext* c) 1401 .SetShapeFn([](InferenceContext* c) {
1402 auto uniq = c->Vector(InferenceContext::kUnknownDim);
1403 c->set_output(0, uniq);
1404 c->set_output(1, c->input(0));
1405 c->set_output(2, uniq);
1406 return Status::OK();
1407 });
1408
1409 namespace {
1410
ShapeShapeFn(InferenceContext * c)1411 Status ShapeShapeFn(InferenceContext* c) {
1412 for (int i = 0; i < c->num_inputs(); ++i) {
1413 DimensionHandle dim;
1414 if (c->RankKnown(c->input(i))) {
1415 dim = c->MakeDim(c->Rank(c->input(i)));
1416 } else {
1417 dim = c->UnknownDim();
1418 }
1419 c->set_output(i, c->Vector(dim));
1420 }
1421 return Status::OK();
1422 }
1423
1424 } // namespace
1425
1426 // --------------------------------------------------------------------------
// Shape: the output is a vector of type `out_type` (int32 default, or
// int64). Its length is the rank of `input` (ShapeShapeFn).
1427 REGISTER_OP("Shape")
1428 .Input("input: T")
1429 .Output("output: out_type")
1430 .Attr("T: type")
1431 .Attr("out_type: {int32, int64} = DT_INT32")
1432 .SetShapeFn(ShapeShapeFn);
1433
// ShapeN: N inputs of a common type T and N outputs of `out_type`.
// Output i is a vector whose length is the rank of input i (ShapeShapeFn).
1434 REGISTER_OP("ShapeN")
1435 .Input("input: N * T")
1436 .Output("output: N * out_type")
1437 .Attr("N: int")
1438 .Attr("T: type")
1439 .Attr("out_type: {int32, int64} = DT_INT32")
1440 .SetShapeFn(ShapeShapeFn);
1441
1442 REGISTER_OP("EnsureShape")
1443 .Input("input: T")
1444 .Output("output: T")
1445 .Attr("shape: shape")
1446 .Attr("T: type")
__anondb9326b22302(InferenceContext* c) 1447 .SetShapeFn([](InferenceContext* c) {
1448 // Merges desired shape and statically known shape of input
1449 PartialTensorShape desired_shape;
1450 TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
1451
1452 int rank = desired_shape.dims();
1453 ShapeHandle input_shape_handle;
1454 ShapeHandle desired_shape_handle;
1455 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
1456 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
1457 desired_shape, &desired_shape_handle));
1458
1459 ShapeHandle merged_shape;
1460 TF_RETURN_IF_ERROR(
1461 c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
1462 c->set_output(0, merged_shape);
1463 return Status::OK();
1464 });
1465
1466 // --------------------------------------------------------------------------
1467 REGISTER_OP("ReverseSequence")
1468 .Input("input: T")
1469 .Input("seq_lengths: Tlen")
1470 .Output("output: T")
1471 .Attr("seq_dim: int")
1472 .Attr("batch_dim: int = 0")
1473 .Attr("T: type")
1474 .Attr("Tlen: {int32, int64} = DT_INT64")
__anondb9326b22402(InferenceContext* c) 1475 .SetShapeFn([](InferenceContext* c) {
1476 ShapeHandle input = c->input(0);
1477 ShapeHandle seq_lens_shape;
1478 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
1479
1480 int64 seq_dim;
1481 TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
1482 int64 batch_dim;
1483 TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1484
1485 if (!c->RankKnown(input)) {
1486 return shape_inference::UnknownShape(c);
1487 }
1488
1489 // Validate batch_dim and seq_dim against input.
1490 const int32 input_rank = c->Rank(input);
1491 if (batch_dim >= input_rank) {
1492 return errors::InvalidArgument(
1493 "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1494 }
1495 if (seq_dim >= input_rank) {
1496 return errors::InvalidArgument(
1497 "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
1498 }
1499
1500 DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1501 TF_RETURN_IF_ERROR(
1502 c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
1503
1504 // Replace batch_dim of input with batch_size
1505 ShapeHandle output_shape;
1506 TF_RETURN_IF_ERROR(
1507 c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
1508 c->set_output(0, output_shape);
1509 return Status::OK();
1510 });
1511
1512 // --------------------------------------------------------------------------
// Rank: the output is an int32 scalar (shape_inference::ScalarShape).
1513 REGISTER_OP("Rank")
1514 .Input("input: T")
1515 .Output("output: int32")
1516 .Attr("T: type")
1517 .SetShapeFn(shape_inference::ScalarShape);
1518
1519 // --------------------------------------------------------------------------
// Size: the output is a scalar of `out_type` (int32 default, or int64).
1520 REGISTER_OP("Size")
1521 .Input("input: T")
1522 .Output("output: out_type")
1523 .Attr("T: type")
1524 .Attr("out_type: {int32, int64} = DT_INT32")
1525 .SetShapeFn(shape_inference::ScalarShape);
1526
1527 // --------------------------------------------------------------------------
// Slice: `begin` and `size` share the index type `Index` (int32 or int64).
// Shape inference is done by shape_inference::SliceShape.
1528 REGISTER_OP("Slice")
1529 .Input("input: T")
1530 .Input("begin: Index")
1531 .Input("size: Index")
1532 .Output("output: T")
1533 .Attr("T: type")
1534 .Attr("Index: {int32,int64}")
1535 .SetShapeFn(shape_inference::SliceShape);
1536
// _MklSlice (registered only when INTEL_MKL is defined): Slice plus uint8
// MKL side inputs/outputs. It uses the same SliceShape shape function.
1537 #ifdef INTEL_MKL
1538 REGISTER_OP("_MklSlice")
1539 .Input("input: T")
1540 .Input("begin: Index")
1541 .Input("size: Index")
1542 .Input("mkl_input: uint8")
1543 .Input("mkl_begin: uint8")
1544 .Input("mkl_size: uint8")
1545 .Output("output: T")
1546 .Output("mkl_output: uint8")
1547 .Attr("T: type")
1548 .Attr("Index: {int32,int64}")
1549 .SetShapeFn(shape_inference::SliceShape);
1550 #endif
1551
1552 REGISTER_OP("StridedSlice")
1553 .Input("input: T")
1554 .Input("begin: Index")
1555 .Input("end: Index")
1556 .Input("strides: Index")
1557 .Output("output: T")
1558 .Attr("T: type")
1559 .Attr("Index: {int32, int64}")
1560 .Attr("begin_mask: int = 0")
1561 .Attr("end_mask: int = 0")
1562 .Attr("ellipsis_mask: int = 0")
1563 .Attr("new_axis_mask: int = 0")
1564 .Attr("shrink_axis_mask: int = 0")
__anondb9326b22502(InferenceContext* c) 1565 .SetShapeFn([](InferenceContext* c) {
1566 ShapeHandle input = c->input(0);
1567 ShapeHandle begin_shape, end_shape, strides_shape;
1568 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1569 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
1570 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
1571 TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
1572 TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
1573 DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
1574
1575 const Tensor* strides_value = c->input_tensor(3);
1576 // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
1577 // more shape inference here (e.g. for x[3, ::T]).
1578 if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
1579 strides_value == nullptr) {
1580 c->set_output(0, c->UnknownShape());
1581 return Status::OK();
1582 }
1583
1584 PartialTensorShape input_shape({});
1585 for (int i = 0; i < c->Rank(input); ++i) {
1586 auto dim = c->Dim(input, i);
1587 input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
1588 }
1589
1590 int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask,
1591 shrink_axis_mask;
1592 TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
1593 TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
1594 TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
1595 TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
1596 TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
1597
1598 const Tensor* begin_value = c->input_tensor(1);
1599 const Tensor* end_value = c->input_tensor(2);
1600
1601 PartialTensorShape processing_shape, final_shape;
1602 bool is_identity, is_simple_slice, slice_dim0;
1603 gtl::InlinedVector<int64, 4> begin, end, strides;
1604 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
1605 begin_value, end_value, *strides_value, input_shape, begin_mask,
1606 end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
1607 &processing_shape, &final_shape, &is_identity, &is_simple_slice,
1608 &slice_dim0, &begin, &end, &strides));
1609
1610 ShapeHandle out;
1611 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
1612 c->set_output(0, out);
1613
1614 auto* shape_and_type = c->input_handle_shapes_and_types(0);
1615 if (shape_and_type) {
1616 c->set_output_handle_shapes_and_types(0, *shape_and_type);
1617 }
1618
1619 return Status::OK();
1620 });
1621
1622 REGISTER_OP("StridedSliceGrad")
1623 .Input("shape: Index")
1624 .Input("begin: Index")
1625 .Input("end: Index")
1626 .Input("strides: Index")
1627 .Input("dy: T")
1628 .Output("output: T")
1629 .Attr("T: type")
1630 .Attr("Index: {int32, int64}")
1631 .Attr("begin_mask: int = 0")
1632 .Attr("end_mask: int = 0")
1633 .Attr("ellipsis_mask: int = 0")
1634 .Attr("new_axis_mask: int = 0")
1635 .Attr("shrink_axis_mask: int = 0")
__anondb9326b22602(InferenceContext* c) 1636 .SetShapeFn([](InferenceContext* c) {
1637 ShapeHandle out;
1638 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1639 c->set_output(0, out);
1640 return Status::OK();
1641 });
1642
// StridedSliceAssign: `output_ref` gets the shape of `ref`
// (shape_inference::UnchangedShape). The shape function does not check
// `value` against the slice.
1643 REGISTER_OP("StridedSliceAssign")
1644 .Input("ref: Ref(T)")
1645 .Input("begin: Index")
1646 .Input("end: Index")
1647 .Input("strides: Index")
1648 .Input("value: T")
1649 .Output("output_ref: Ref(T)")
1650 .Attr("T: type")
1651 .Attr("Index: {int32, int64}")
1652 .Attr("begin_mask: int = 0")
1653 .Attr("end_mask: int = 0")
1654 .Attr("ellipsis_mask: int = 0")
1655 .Attr("new_axis_mask: int = 0")
1656 .Attr("shrink_axis_mask: int = 0")
1657 .SetShapeFn(shape_inference::UnchangedShape);
1658 // TODO(aselle): Fix this documentation once StridedSliceAssign Supports
1659 // broadcasting.
1660 // --------------------------------------------------------------------------
1661
// ResourceStridedSliceAssign: like StridedSliceAssign, but `ref` is a
// resource handle and the op has no outputs (shape_inference::NoOutputs).
1662 REGISTER_OP("ResourceStridedSliceAssign")
1663 .Input("ref: resource")
1664 .Input("begin: Index")
1665 .Input("end: Index")
1666 .Input("strides: Index")
1667 .Input("value: T")
1668 .Attr("T: type")
1669 .Attr("Index: {int32, int64}")
1670 .Attr("begin_mask: int = 0")
1671 .Attr("end_mask: int = 0")
1672 .Attr("ellipsis_mask: int = 0")
1673 .Attr("new_axis_mask: int = 0")
1674 .Attr("shrink_axis_mask: int = 0")
1675 .SetShapeFn(shape_inference::NoOutputs);
1676
1677 REGISTER_OP("Tile")
1678 .Input("input: T")
1679 .Input("multiples: Tmultiples")
1680 .Output("output: T")
1681 .Attr("T: type")
1682 .Attr("Tmultiples: {int32, int64} = DT_INT32")
__anondb9326b22702(InferenceContext* c) 1683 .SetShapeFn([](InferenceContext* c) {
1684 ShapeHandle input = c->input(0);
1685 // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
1686 // it is a vector of non-negative integers, and (ii) doing so allows
1687 // us to handle partially-known multiples.
1688 ShapeHandle multiples;
1689 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
1690 if (c->RankKnown(input)) {
1691 TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
1692 ShapeHandle dummy;
1693 TF_RETURN_IF_ERROR(
1694 c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
1695 }
1696
1697 if (!c->RankKnown(multiples)) {
1698 return shape_inference::UnknownShape(c);
1699 }
1700
1701 int32 rank = c->Rank(multiples);
1702 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
1703 std::vector<DimensionHandle> dims(rank);
1704 for (int i = 0; i < rank; ++i) {
1705 TF_RETURN_IF_ERROR(
1706 c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
1707 }
1708 c->set_output(0, c->MakeShape(dims));
1709 return Status::OK();
1710 });
1711
1712 // --------------------------------------------------------------------------
// TileGrad: marked Deprecated(3, ...) in favor of reduce_sum. Its output
// shape is left fully unknown (shape_inference::UnknownShape).
1713 REGISTER_OP("TileGrad")
1714 .Input("input: T")
1715 .Input("multiples: int32")
1716 .Output("output: T")
1717 .Attr("T: type")
1718 .Deprecated(3, "TileGrad has been replaced with reduce_sum")
1719 .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1720
1721 // --------------------------------------------------------------------------
1722 REGISTER_OP("Where")
1723 .Input("input: T")
1724 .Attr("T: {numbertype, bool} = DT_BOOL")
1725 .Output("index: int64")
__anondb9326b22802(InferenceContext* c) 1726 .SetShapeFn([](InferenceContext* c) {
1727 c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
1728 return Status::OK();
1729 });
1730
1731 // --------------------------------------------------------------------------
1732 REGISTER_OP("BroadcastArgs")
1733 .Input("s0: T")
1734 .Input("s1: T")
1735 .Output("r0: T")
1736 .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b22902(InferenceContext* c) 1737 .SetShapeFn([](InferenceContext* c) {
1738 ShapeHandle unused;
1739 ShapeHandle shape_x = c->input(0);
1740 ShapeHandle shape_y = c->input(1);
1741 TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
1742 TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
1743
1744 if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
1745 !c->ValueKnown(c->Dim(shape_y, 0))) {
1746 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1747 return Status::OK();
1748 }
1749
1750 int64 x_dim = c->Value(c->Dim(shape_x, 0));
1751 int64 y_dim = c->Value(c->Dim(shape_y, 0));
1752
1753 // Broadcasted shape is going to be as large as the largest dimension.
1754 c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
1755 return Status::OK();
1756 });
1757
1758 // --------------------------------------------------------------------------
1759 REGISTER_OP("BroadcastGradientArgs")
1760 .Input("s0: T")
1761 .Input("s1: T")
1762 .Output("r0: T")
1763 .Output("r1: T")
1764 .Attr("T: {int32, int64} = DT_INT32")
__anondb9326b22a02(InferenceContext* c) 1765 .SetShapeFn([](InferenceContext* c) {
1766 // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
1767 ShapeHandle unused;
1768 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1769 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1770 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1771 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1772 return Status::OK();
1773 });
1774
1775 // --------------------------------------------------------------------------
// Pad: shape inference is done by PadShapeFn (defined elsewhere in this
// file). `paddings` is int32 (default) or int64.
1776 REGISTER_OP("Pad")
1777 .Input("input: T")
1778 .Input("paddings: Tpaddings")
1779 .Output("output: T")
1780 .Attr("T: type")
1781 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1782 .SetShapeFn(PadShapeFn);
1783
1784 // --------------------------------------------------------------------------
// PadV2: Pad plus a `constant_values` input of type T; same PadShapeFn.
1785 REGISTER_OP("PadV2")
1786 .Input("input: T")
1787 .Input("paddings: Tpaddings")
1788 .Input("constant_values: T")
1789 .Output("output: T")
1790 .Attr("T: type")
1791 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1792 .SetShapeFn(PadShapeFn);
1793
1794 // --------------------------------------------------------------------------
// MirrorPad: Pad plus the attr spec returned by GetMirrorPadModeAttrString();
// same PadShapeFn.
1795 REGISTER_OP("MirrorPad")
1796 .Input("input: T")
1797 .Input("paddings: Tpaddings")
1798 .Output("output: T")
1799 .Attr("T: type")
1800 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1801 .Attr(GetMirrorPadModeAttrString())
1802 .SetShapeFn(PadShapeFn);
1803
1804 // --------------------------------------------------------------------------
1805 namespace {
1806 template <typename T>
MirrorPadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64 input_rank)1807 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
1808 const Tensor* paddings_t, int64 input_rank) {
1809 auto paddings_data = paddings_t->matrix<T>();
1810 std::vector<DimensionHandle> dims(input_rank);
1811 for (int64 i = 0; i < input_rank; ++i) {
1812 const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
1813 const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
1814 if (pad0 < 0 || pad1 < 0) {
1815 return errors::InvalidArgument("Paddings must be non-negative");
1816 }
1817
1818 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
1819 }
1820 c->set_output(0, c->MakeShape(dims));
1821 return Status::OK();
1822 }
1823
1824 } // namespace
1825
1826 REGISTER_OP("MirrorPadGrad")
1827 .Input("input: T")
1828 .Input("paddings: Tpaddings")
1829 .Output("output: T")
1830 .Attr("T: type")
1831 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1832 .Attr(GetMirrorPadModeAttrString())
__anondb9326b22c02(InferenceContext* c) 1833 .SetShapeFn([](InferenceContext* c) {
1834 ShapeHandle paddings;
1835 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
1836 DimensionHandle pad_0 = c->Dim(paddings, 0);
1837 if (!c->ValueKnown(pad_0)) {
1838 // We don't know the rank of the output since the first
1839 // padding dimension is unknown.
1840 c->set_output(0, c->UnknownShape());
1841 return Status::OK();
1842 }
1843
1844 int64 input_rank = c->Value(pad_0);
1845 ShapeHandle input;
1846 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
1847 TF_RETURN_IF_ERROR(
1848 c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
1849
1850 const Tensor* paddings_t = c->input_tensor(1);
1851 if (paddings_t == nullptr) {
1852 // Values of 'paddings' is not available, but we know the
1853 // input rank, so return the rank of the output with unknown
1854 // dimensions.
1855 c->set_output(0, c->UnknownShapeOfRank(input_rank));
1856 return Status::OK();
1857 }
1858
1859 if (paddings_t->dtype() == DT_INT32) {
1860 return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
1861 } else {
1862 return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
1863 }
1864 });
1865
1866 // --------------------------------------------------------------------------
1867 REGISTER_OP("Placeholder")
1868 .Output("output: dtype")
1869 .Attr("dtype: type")
1870 .Attr("shape: shape = { unknown_rank: true }")
__anondb9326b22d02(InferenceContext* c) 1871 .SetShapeFn([](InferenceContext* c) {
1872 PartialTensorShape shape;
1873 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
1874
1875 // Placeholder has legacy behavior where we cannot tell the difference
1876 // between a scalar shape attribute and 'unknown shape'. So if the shape
1877 // is a scalar, we return an unknown shape.
1878 if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
1879 return shape_inference::UnknownShape(c);
1880 }
1881
1882 ShapeHandle out;
1883 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
1884 c->set_output(0, out);
1885 return Status::OK();
1886 });
1887
// Placeholder was modified in a backwards-compatible way to do what
// PlaceholderV2 did, so V2 is deprecated (it saw very little use).
// Deprecated at GraphDef version 23. Unlike Placeholder, the `shape` attr has
// no default. Shape inference uses shape_inference::ExplicitShape, which
// presumably emits the `shape` attr as the output shape; see
// common_shape_fns to confirm.
REGISTER_OP("PlaceholderV2")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn(shape_inference::ExplicitShape)
    .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
1897
1898 // --------------------------------------------------------------------------
1899 REGISTER_OP("PlaceholderWithDefault")
1900 .Input("input: dtype")
1901 .Output("output: dtype")
1902 .Attr("dtype: type")
1903 .Attr("shape: shape")
__anondb9326b22e02(InferenceContext* c) 1904 .SetShapeFn([](InferenceContext* c) {
1905 ShapeHandle input = c->input(0);
1906 PartialTensorShape shape;
1907 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
1908 ShapeHandle out;
1909 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
1910
1911 // We merge for compatibility checking, but return the output,
1912 // since output_shape may be less precise than input_shape.
1913 ShapeHandle unused;
1914 TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
1915 c->set_output(0, out);
1916 return Status::OK();
1917 });
1918
1919 // --------------------------------------------------------------------------
1920 REGISTER_OP("ExpandDims")
1921 .Input("input: T")
1922 .Input("dim: Tdim")
1923 .Output("output: T")
1924 .Attr("T: type")
1925 .Attr("Tdim: {int32, int64} = DT_INT32")
__anondb9326b22f02(InferenceContext* c) 1926 .SetShapeFn([](InferenceContext* c) {
1927 ShapeHandle input = c->input(0);
1928
1929 const Tensor* dim_t = c->input_tensor(1);
1930 if (dim_t != nullptr && dim_t->NumElements() != 1) {
1931 return errors::InvalidArgument(
1932 "'dim' input must be a tensor with a single value");
1933 }
1934 if (dim_t == nullptr || !c->RankKnown(input)) {
1935 c->set_output(0, c->UnknownShape());
1936 return Status::OK();
1937 }
1938
1939 int64 dim;
1940 if (dim_t->dtype() == DT_INT32) {
1941 dim = static_cast<int64>(dim_t->flat<int32>()(0));
1942 } else {
1943 dim = dim_t->flat<int64>()(0);
1944 }
1945
1946 const int32 rank = c->Rank(input);
1947 const int32 min_dim = -1 * rank - 1;
1948 if (dim < min_dim || dim > rank) {
1949 return errors::InvalidArgument("dim ", dim, " not in the interval [",
1950 min_dim, ", ", rank, "].");
1951 }
1952
1953 if (dim < 0) {
1954 dim += rank + 1;
1955 }
1956
1957 ShapeHandle end;
1958 TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
1959
1960 // Build output as start + 1 + end.
1961 ShapeHandle output;
1962 TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
1963 TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
1964 TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
1965 c->set_output(0, output);
1966 return Status::OK();
1967 });
1968
1969 // --------------------------------------------------------------------------
1970 REGISTER_OP("Squeeze")
1971 .Input("input: T")
1972 .Output("output: T")
1973 .Attr("T: type")
1974 .Attr("squeeze_dims: list(int) >= 0 = []")
__anondb9326b23002(InferenceContext* c) 1975 .SetShapeFn([](InferenceContext* c) {
1976 ShapeHandle input = c->input(0);
1977 if (!c->RankKnown(input)) {
1978 // Input shape unknown.
1979 return shape_inference::UnknownShape(c);
1980 }
1981
1982 const int32 input_rank = c->Rank(input);
1983
1984 // Validate and wrap squeeze dimensions.
1985 std::vector<int32> squeeze_dims;
1986 TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
1987 for (int i = 0; i < squeeze_dims.size(); ++i) {
1988 if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
1989 return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
1990 -input_rank, ",", input_rank, ").");
1991 }
1992
1993 if (squeeze_dims[i] < 0) {
1994 squeeze_dims[i] += input_rank;
1995 }
1996 }
1997
1998 std::vector<DimensionHandle> result_shape;
1999 for (int i = 0; i < input_rank; ++i) {
2000 // True if squeeze_dims contains an entry to squeeze this
2001 // dimension.
2002 bool is_explicit_match =
2003 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
2004 squeeze_dims.end();
2005
2006 DimensionHandle dim = c->Dim(input, i);
2007
2008 if (!c->ValueKnown(dim)) {
2009 // Assume that the squeezed dimension will be 1 at runtime.
2010 if (is_explicit_match) continue;
2011
2012 // If squeezing all 1 dimensions, and we see an unknown value,
2013 // give up and return Unknown Shape.
2014 if (squeeze_dims.empty()) {
2015 c->set_output(0, c->UnknownShape());
2016 return Status::OK();
2017 }
2018 } else if (c->Value(dim) == 1) {
2019 if (is_explicit_match || squeeze_dims.empty()) {
2020 // If explicitly squeezing, or squeezing all 1s, remove
2021 // this dimension.
2022 continue;
2023 }
2024 } else if (is_explicit_match) {
2025 return errors::InvalidArgument("Can not squeeze dim[", i,
2026 "], expected a dimension of 1, got ",
2027 c->Value(c->Dim(input, i)));
2028 }
2029
2030 result_shape.emplace_back(dim);
2031 }
2032
2033 c->set_output(0, c->MakeShape(result_shape));
2034 return Status::OK();
2035 });
2036
2037 // --------------------------------------------------------------------------
2038 REGISTER_OP("ListDiff")
2039 .Input("x: T")
2040 .Input("y: T")
2041 .Output("out: T")
2042 .Output("idx: out_idx")
2043 .Attr("T: type")
2044 .Attr("out_idx: {int32, int64} = DT_INT32")
__anondb9326b23102(InferenceContext* c) 2045 .SetShapeFn([](InferenceContext* c) {
2046 ShapeHandle unused;
2047 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
2048 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2049 // TODO(mrry): Indicate that the length falls within an interval?
2050 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
2051 c->set_output(0, out);
2052 c->set_output(1, out);
2053 return Status::OK();
2054 });
2055
2056 namespace {
2057
2058 // Converts Tensor to flat std::vector<int64>.
2059 template <typename InputType>
GetFlatInt64(const Tensor & t)2060 std::vector<int64> GetFlatInt64(const Tensor& t) {
2061 std::vector<int64> output(t.shape().num_elements());
2062 auto eigen_vec = t.flat<InputType>();
2063 std::copy_n(&eigen_vec(0), output.size(), output.begin());
2064 return output;
2065 }
2066
2067 // Converts int32 or int64 Tensor to flat std::vector<int64>.
GetFlatInt64(const Tensor & t)2068 std::vector<int64> GetFlatInt64(const Tensor& t) {
2069 if (t.dtype() == DT_INT32) {
2070 return GetFlatInt64<int32>(t);
2071 } else {
2072 return GetFlatInt64<int64>(t);
2073 }
2074 }
2075
SpaceToBatchShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle paddings_shape,const Tensor * paddings_t)2076 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2077 ShapeHandle block_shape_shape,
2078 const Tensor* block_shape_t,
2079 ShapeHandle paddings_shape,
2080 const Tensor* paddings_t) {
2081 if (c->Rank(block_shape_shape) != 1) {
2082 return errors::InvalidArgument("block_shape must have rank 1.");
2083 }
2084
2085 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2086 if (!c->ValueKnown(num_block_dims_handle)) {
2087 return errors::InvalidArgument("block_shape must have known size.");
2088 }
2089
2090 const int64 num_block_dims = c->Value(num_block_dims_handle);
2091
2092 TF_RETURN_IF_ERROR(
2093 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2094
2095 TF_RETURN_IF_ERROR(
2096 c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
2097
2098 DimensionHandle batch_size = c->Dim(input_shape, 0);
2099 std::vector<int64> block_shape_vec;
2100 if (block_shape_t) {
2101 block_shape_vec = GetFlatInt64(*block_shape_t);
2102 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2103 const int64 block_shape_value = block_shape_vec[dim];
2104 if (block_shape_value < 1) {
2105 return errors::InvalidArgument("block_shape must be positive");
2106 }
2107 if (c->ValueKnown(batch_size)) {
2108 TF_RETURN_IF_ERROR(
2109 c->Multiply(batch_size, block_shape_value, &batch_size));
2110 } else {
2111 batch_size = c->UnknownDim();
2112 }
2113 }
2114 } else if (num_block_dims > 0) {
2115 batch_size = c->UnknownDim();
2116 }
2117
2118 std::vector<DimensionHandle> output_dims{batch_size};
2119 output_dims.resize(num_block_dims + 1, c->UnknownDim());
2120
2121 if (paddings_t) {
2122 const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t);
2123 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2124 const int64 pad_start = paddings_vec[dim * 2],
2125 pad_end = paddings_vec[dim * 2 + 1];
2126 if (pad_start < 0 || pad_end < 0) {
2127 return errors::InvalidArgument("paddings cannot be negative");
2128 }
2129 if (block_shape_t) {
2130 DimensionHandle padded_size;
2131 TF_RETURN_IF_ERROR(
2132 c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
2133 TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
2134 TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
2135 /*evenly_divisible=*/true,
2136 &output_dims[dim + 1]));
2137 }
2138 }
2139 }
2140
2141 ShapeHandle remaining_input_shape;
2142 TF_RETURN_IF_ERROR(
2143 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2144
2145 ShapeHandle result;
2146 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2147 remaining_input_shape, &result));
2148 c->set_output(0, result);
2149 return Status::OK();
2150 }
2151
BatchToSpaceShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle crops_shape,const Tensor * crops_t)2152 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2153 ShapeHandle block_shape_shape,
2154 const Tensor* block_shape_t,
2155 ShapeHandle crops_shape, const Tensor* crops_t) {
2156 if (c->Rank(block_shape_shape) != 1) {
2157 return errors::InvalidArgument("block_shape must have rank 1.");
2158 }
2159
2160 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2161 if (!c->ValueKnown(num_block_dims_handle)) {
2162 return errors::InvalidArgument("block_shape must have known size.");
2163 }
2164
2165 const int64 num_block_dims = c->Value(num_block_dims_handle);
2166
2167 TF_RETURN_IF_ERROR(
2168 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2169
2170 TF_RETURN_IF_ERROR(
2171 c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
2172
2173 DimensionHandle batch_size = c->Dim(input_shape, 0);
2174 std::vector<int64> block_shape_vec;
2175 if (block_shape_t) {
2176 block_shape_vec = GetFlatInt64(*block_shape_t);
2177 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2178 const int64 block_shape_value = block_shape_vec[dim];
2179 if (block_shape_value < 1) {
2180 return errors::InvalidArgument("block_shape must be positive");
2181 }
2182 if (c->ValueKnown(batch_size)) {
2183 TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
2184 /*evenly_divisible=*/true, &batch_size));
2185 } else {
2186 batch_size = c->UnknownDim();
2187 }
2188 }
2189 } else if (num_block_dims > 0) {
2190 batch_size = c->UnknownDim();
2191 }
2192
2193 std::vector<DimensionHandle> output_dims{batch_size};
2194 output_dims.resize(num_block_dims + 1, c->UnknownDim());
2195
2196 if (crops_t) {
2197 const std::vector<int64> crops_vec = GetFlatInt64(*crops_t);
2198 for (int64 dim = 0; dim < num_block_dims; ++dim) {
2199 const int64 crop_start = crops_vec[dim * 2],
2200 crop_end = crops_vec[dim * 2 + 1];
2201 if (crop_start < 0 || crop_end < 0) {
2202 return errors::InvalidArgument("crops cannot be negative");
2203 }
2204 if (block_shape_t) {
2205 DimensionHandle cropped_size;
2206 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
2207 block_shape_vec[dim], &cropped_size));
2208 TF_RETURN_IF_ERROR(
2209 c->Subtract(cropped_size, crop_start, &cropped_size));
2210 TF_RETURN_IF_ERROR(
2211 c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
2212 }
2213 }
2214 }
2215
2216 ShapeHandle remaining_input_shape;
2217 TF_RETURN_IF_ERROR(
2218 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2219
2220 ShapeHandle result;
2221 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2222 remaining_input_shape, &result));
2223 c->set_output(0, result);
2224 return Status::OK();
2225 }
2226
2227 } // namespace
2228
2229 // --------------------------------------------------------------------------
2230 REGISTER_OP("SpaceToBatchND")
2231 .Input("input: T")
2232 .Input("block_shape: Tblock_shape")
2233 .Input("paddings: Tpaddings")
2234 .Output("output: T")
2235 .Attr("T: type")
2236 .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2237 .Attr("Tpaddings: {int32, int64} = DT_INT32")
__anondb9326b23302(InferenceContext* c) 2238 .SetShapeFn([](InferenceContext* c) {
2239 return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
2240 c->input_tensor(1), c->input(2),
2241 c->input_tensor(2));
2242 });
2243
2244 // --------------------------------------------------------------------------
2245 REGISTER_OP("SpaceToBatch")
2246 .Input("input: T")
2247 .Input("paddings: Tpaddings")
2248 .Output("output: T")
2249 .Attr("T: type")
2250 .Attr("Tpaddings: {int32, int64} = DT_INT32")
2251 .Attr("block_size: int >= 2")
__anondb9326b23402(InferenceContext* c) 2252 .SetShapeFn([](InferenceContext* c) {
2253 ShapeHandle input_shape;
2254 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2255
2256 int32 block_size;
2257 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2258
2259 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2260 auto block_shape_vec = block_shape.vec<int64>();
2261 block_shape_vec(0) = block_size;
2262 block_shape_vec(1) = block_size;
2263
2264 return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
2265 &block_shape, c->input(1),
2266 c->input_tensor(1));
2267 });
2268
2269 // --------------------------------------------------------------------------
2270 REGISTER_OP("BatchToSpaceND")
2271 .Input("input: T")
2272 .Input("block_shape: Tblock_shape")
2273 .Input("crops: Tcrops")
2274 .Output("output: T")
2275 .Attr("T: type")
2276 .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2277 .Attr("Tcrops: {int32, int64} = DT_INT32")
__anondb9326b23502(InferenceContext* c) 2278 .SetShapeFn([](InferenceContext* c) {
2279 return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
2280 c->input_tensor(1), c->input(2),
2281 c->input_tensor(2));
2282 });
2283
2284 // --------------------------------------------------------------------------
2285 REGISTER_OP("BatchToSpace")
2286 .Input("input: T")
2287 .Input("crops: Tidx")
2288 .Output("output: T")
2289 .Attr("T: type")
2290 .Attr("block_size: int >= 2")
2291 .Attr("Tidx: {int32, int64} = DT_INT32")
__anondb9326b23602(InferenceContext* c) 2292 .SetShapeFn([](InferenceContext* c) {
2293 ShapeHandle input_shape;
2294 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2295
2296 int32 block_size;
2297 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2298
2299 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2300 auto block_shape_vec = block_shape.vec<int64>();
2301 block_shape_vec(0) = block_size;
2302 block_shape_vec(1) = block_size;
2303
2304 return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
2305 &block_shape, c->input(1),
2306 c->input_tensor(1));
2307 });
2308
2309 // --------------------------------------------------------------------------
2310 REGISTER_OP("SpaceToDepth")
2311 .Input("input: T")
2312 .Output("output: T")
2313 .Attr("T: type")
2314 .Attr("block_size: int >= 2")
2315 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2316 // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
__anondb9326b23702(InferenceContext* c) 2317 .SetShapeFn([](InferenceContext* c) {
2318 string data_format_str;
2319 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2320 TensorFormat data_format;
2321 FormatFromString(data_format_str, &data_format);
2322
2323 constexpr int num_spatial_dims = 2;
2324 const int dims =
2325 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2326 ShapeHandle input;
2327 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2328
2329 int32 block_size;
2330 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2331
2332 DimensionHandle batch_size =
2333 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2334 DimensionHandle input_height =
2335 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2336 DimensionHandle input_width =
2337 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2338 DimensionHandle input_depth =
2339 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2340
2341 DimensionHandle output_height;
2342 DimensionHandle output_width;
2343 DimensionHandle output_depth;
2344 // Will return an error if input height or width are not evenly divisible.
2345 TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
2346 true /* evenly_divisible */,
2347 &output_height));
2348 TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
2349 true /* evenly_divisible */, &output_width));
2350
2351 TF_RETURN_IF_ERROR(
2352 c->Multiply(input_depth, block_size * block_size, &output_depth));
2353
2354 ShapeHandle output_shape;
2355 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2356 {output_height, output_width},
2357 output_depth, &output_shape, c));
2358
2359 c->set_output(0, output_shape);
2360 return Status::OK();
2361 });
2362
2363 // --------------------------------------------------------------------------
2364 REGISTER_OP("DepthToSpace")
2365 .Input("input: T")
2366 .Output("output: T")
2367 .Attr("T: type")
2368 .Attr("block_size: int >= 2")
2369 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2370 // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
__anondb9326b23802(InferenceContext* c) 2371 .SetShapeFn([](InferenceContext* c) {
2372 string data_format_str;
2373 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2374 TensorFormat data_format;
2375 FormatFromString(data_format_str, &data_format);
2376
2377 constexpr int num_spatial_dims = 2;
2378 const int dims =
2379 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2380
2381 ShapeHandle input;
2382 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2383
2384 int32 block_size;
2385 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2386
2387 DimensionHandle batch_size =
2388 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2389 DimensionHandle input_height =
2390 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2391 DimensionHandle input_width =
2392 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2393 DimensionHandle input_depth =
2394 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2395
2396 DimensionHandle output_height;
2397 DimensionHandle output_width;
2398 DimensionHandle output_depth;
2399 TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
2400 TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
2401
2402 // Will return an error if input_depth is not evenly divisible.
2403 TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
2404 true /* evenly_divisible */, &output_depth));
2405
2406 ShapeHandle output_shape;
2407 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2408 {output_height, output_width},
2409 output_depth, &output_shape, c));
2410
2411 c->set_output(0, output_shape);
2412 return Status::OK();
2413 });
2414
2415 // --------------------------------------------------------------------------
2416
2417 REGISTER_OP("ExtractImagePatches")
2418 .Input("images: T")
2419 .Output("patches: T")
2420 .Attr("ksizes: list(int) >= 4")
2421 .Attr("strides: list(int) >= 4")
2422 .Attr("rates: list(int) >= 4")
2423 .Attr("T: realnumbertype")
2424 .Attr(GetPaddingAttrString())
__anondb9326b23902(InferenceContext* c) 2425 .SetShapeFn([](InferenceContext* c) {
2426 ShapeHandle input_shape;
2427 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2428
2429 std::vector<int32> ksizes;
2430 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2431 if (ksizes.size() != 4) {
2432 return errors::InvalidArgument(
2433 "ExtractImagePatches requires the ksizes attribute to contain 4 "
2434 "values, but got: ",
2435 ksizes.size());
2436 }
2437
2438 std::vector<int32> strides;
2439 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2440 if (strides.size() != 4) {
2441 return errors::InvalidArgument(
2442 "ExtractImagePatches requires the stride attribute to contain 4 "
2443 "values, but got: ",
2444 strides.size());
2445 }
2446
2447 std::vector<int32> rates;
2448 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2449 if (rates.size() != 4) {
2450 return errors::InvalidArgument(
2451 "ExtractImagePatches requires the rates attribute to contain 4 "
2452 "values, but got: ",
2453 rates.size());
2454 }
2455
2456 int32 ksize_rows = ksizes[1];
2457 int32 ksize_cols = ksizes[2];
2458
2459 int32 stride_rows = strides[1];
2460 int32 stride_cols = strides[2];
2461
2462 int32 rate_rows = rates[1];
2463 int32 rate_cols = rates[2];
2464
2465 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2466 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2467
2468 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2469 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
2470 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
2471 DimensionHandle output_depth_dim;
2472 TF_RETURN_IF_ERROR(c->Multiply(
2473 c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
2474
2475 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
2476 ShapeHandle output_shape =
2477 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2478 InferenceContext::kUnknownDim, output_depth_dim});
2479 c->set_output(0, output_shape);
2480 return Status::OK();
2481 }
2482 auto in_rows = c->Value(in_rows_dim);
2483 auto in_cols = c->Value(in_cols_dim);
2484
2485 Padding padding;
2486 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2487
2488 int64 output_rows, output_cols;
2489 int64 padding_before, padding_after;
2490 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2491 in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2492 &padding_before, &padding_after));
2493 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2494 in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2495 &padding_before, &padding_after));
2496 ShapeHandle output_shape = c->MakeShape(
2497 {batch_size_dim, output_rows, output_cols, output_depth_dim});
2498 c->set_output(0, output_shape);
2499 return Status::OK();
2500 });
2501
2502 // --------------------------------------------------------------------------
2503
2504 // To enable rates, uncomment all lines commented below and use ksize_*_eff
2505 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead
2506 // of ksize_*.
2507 REGISTER_OP("ExtractVolumePatches")
2508 .Input("input: T")
2509 .Output("patches: T")
2510 .Attr("ksizes: list(int) >= 5")
2511 .Attr("strides: list(int) >= 5")
2512 /* .Attr("rates: list(int) >= 5") */
2513 .Attr("T: realnumbertype")
2514 .Attr(GetPaddingAttrString())
__anondb9326b23a02(InferenceContext* c) 2515 .SetShapeFn([](InferenceContext* c) {
2516 ShapeHandle input_shape;
2517 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
2518
2519 std::vector<int32> ksizes;
2520 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2521 if (ksizes.size() != 5) {
2522 return errors::InvalidArgument(
2523 "ExtractVolumePatches requires the ksizes attribute to contain 5 "
2524 "values, but got: ",
2525 ksizes.size());
2526 }
2527
2528 std::vector<int32> strides;
2529 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2530 if (strides.size() != 5) {
2531 return errors::InvalidArgument(
2532 "ExtractVolumePatches requires the stride attribute to contain 5 "
2533 "values, but got: ",
2534 strides.size());
2535 }
2536
2537 /*
2538 // TODO(hsgkim): Enable rates.
2539 // See extract_volume_patches_op.cc for why rates are disabled now.
2540
2541 std::vector<int32> rates;
2542 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2543 if (rates.size() != 5) {
2544 return errors::InvalidArgument(
2545 "ExtractVolumePatches requires the rates attribute to contain 5 "
2546 "values, but got: ",
2547 rates.size());
2548 }
2549 */
2550
2551 int32 ksize_planes = ksizes[1];
2552 int32 ksize_rows = ksizes[2];
2553 int32 ksize_cols = ksizes[3];
2554
2555 int32 stride_planes = strides[1];
2556 int32 stride_rows = strides[2];
2557 int32 stride_cols = strides[3];
2558
2559 /*
2560 int32 rate_planes = rates[1];
2561 int32 rate_rows = rates[2];
2562 int32 rate_cols = rates[3];
2563
2564 int32 ksize_planes_eff = ksize_planes +
2565 (ksize_planes - 1) * (rate_planes - 1);
2566 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2567 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2568 */
2569
2570 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2571 DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
2572 DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
2573 DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
2574 DimensionHandle output_depth_dim;
2575 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
2576 ksize_planes * ksize_rows * ksize_cols,
2577 &output_depth_dim));
2578
2579 if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
2580 !c->ValueKnown(in_cols_dim)) {
2581 ShapeHandle output_shape =
2582 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2583 InferenceContext::kUnknownDim, output_depth_dim});
2584 c->set_output(0, output_shape);
2585 return Status::OK();
2586 }
2587 auto in_planes = c->Value(in_planes_dim);
2588 auto in_rows = c->Value(in_rows_dim);
2589 auto in_cols = c->Value(in_cols_dim);
2590
2591 Padding padding;
2592 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2593
2594 int64 output_planes, output_rows, output_cols;
2595 int64 padding_before, padding_after;
2596 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2597 in_planes, ksize_planes, stride_planes, padding, &output_planes,
2598 &padding_before, &padding_after));
2599 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2600 in_rows, ksize_rows, stride_rows, padding, &output_rows,
2601 &padding_before, &padding_after));
2602 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2603 in_cols, ksize_cols, stride_cols, padding, &output_cols,
2604 &padding_before, &padding_after));
2605 ShapeHandle output_shape =
2606 c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
2607 output_depth_dim});
2608 c->set_output(0, output_shape);
2609 return Status::OK();
2610 });
2611
2612 // --------------------------------------------------------------------------
2613
2614 REGISTER_OP("Bitcast")
2615 .Input("input: T")
2616 .Output("output: type")
2617 // All supported dtypes are listed here to include qint16, quint16, uint32,
2618 // and uint64.
2619 .Attr(
2620 "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
2621 "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
2622 "qint16, quint16, qint32}")
2623 .Attr(
2624 "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
2625 "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
2626 "qint16, quint16, qint32}")
__anondb9326b23b02(InferenceContext* c) 2627 .SetShapeFn([](InferenceContext* c) {
2628 ShapeHandle input = c->input(0);
2629 if (!c->RankKnown(input)) {
2630 // Input shape unknown.
2631 return shape_inference::UnknownShape(c);
2632 }
2633
2634 // Find the size of the input and output data types.
2635 DataType input_type;
2636 DataType output_type;
2637 TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type));
2638 TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type));
2639 const int input_type_size = DataTypeSize(input_type);
2640 const int output_type_size = DataTypeSize(output_type);
2641
2642 if (input_type_size == 0 || output_type_size == 0) {
2643 return errors::InvalidArgument("Cannot bitcast types ",
2644 DataTypeString(input_type), " to ",
2645 DataTypeString(output_type),
2646 " because "
2647 "one of the type sizes is zero.");
2648 }
2649
2650 ShapeHandle new_shape;
2651 if (input_type_size == output_type_size) {
2652 // No change in size.
2653 new_shape = input;
2654 } else if (input_type_size < output_type_size) {
2655 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));
2656
2657 int64 divisor_val = output_type_size / input_type_size;
2658 DimensionHandle last_dim = c->Dim(new_shape, -1);
2659 if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
2660 TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
2661 } else {
2662 return errors::InvalidArgument("Cannot bitcast due to shape. ",
2663 c->Value(last_dim), " does not match ",
2664 divisor_val);
2665 }
2666 } else {
2667 // Input type size is larger than output type size.
2668 int64 divisor_val = input_type_size / output_type_size;
2669 ShapeHandle extension = c->Vector(divisor_val);
2670 TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
2671 }
2672
2673 c->set_output(0, new_shape);
2674 return Status::OK();
2675 });
2676
2677 REGISTER_OP("OneHot")
2678 .Input("indices: TI")
2679 .Input("depth: int32")
2680 .Input("on_value: T")
2681 .Input("off_value: T")
2682 .Attr("axis: int = -1")
2683 .Output("output: T")
2684 .Attr("T: type")
2685 .Attr("TI: {uint8, int32, int64} = DT_INT64")
__anondb9326b23c02(InferenceContext* c) 2686 .SetShapeFn([](InferenceContext* c) {
2687 int32 axis;
2688 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2689 if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
2690
2691 DimensionHandle depth;
2692 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
2693
2694 ShapeHandle indices = c->input(0);
2695 if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
2696
2697 int32 new_rank = c->Rank(indices) + 1;
2698 // We need to add new_rank to axis in the case the axis is -1 because
2699 // C++ returns negative values from % if the dividend is negative.
2700 int32 depth_index = (axis + new_rank) % new_rank;
2701 // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
2702 ShapeHandle front;
2703 ShapeHandle back;
2704 ShapeHandle out;
2705 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
2706 TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
2707 TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
2708 TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
2709 c->set_output(0, out);
2710 return Status::OK();
2711 });
2712
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
// Original form of the op: the quantization range is supplied through the
// `input_min`/`input_max` attrs. It is deprecated at GraphDef version 22 in
// favor of QuantizeAndDequantizeV2, which takes the range as tensor inputs.
REGISTER_OP("QuantizeAndDequantize")
    .Input("input: T")
    .Attr("signed_input: bool = true")
    .Attr("num_bits: int = 8")
    .Attr("range_given: bool = false")
    .Attr("input_min: float = 0")
    .Attr("input_max: float = 0")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    // The output has exactly the shape of `input`.
    .SetShapeFn(shape_inference::UnchangedShape)
    .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
2725
2726 // TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
2727 REGISTER_OP("QuantizeAndDequantizeV2")
2728 .Input("input: T")
2729 .Input("input_min: T")
2730 .Input("input_max: T")
2731 .Attr("signed_input: bool = true")
2732 .Attr("num_bits: int = 8")
2733 .Attr("range_given: bool = false")
2734 .Output("output: T")
2735 .Attr("T: {bfloat16, half, float, double}")
2736 .Attr(
2737 "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2738 "'HALF_TO_EVEN'")
__anondb9326b23d02(InferenceContext* c) 2739 .SetShapeFn([](InferenceContext* c) {
2740 ShapeHandle unused;
2741 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2742 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2743 c->set_output(0, c->input(0));
2744 return Status::OK();
2745 });
2746
2747 REGISTER_OP("QuantizeAndDequantizeV3")
2748 .Input("input: T")
2749 .Input("input_min: T")
2750 .Input("input_max: T")
2751 .Input("num_bits: int32")
2752 .Attr("signed_input: bool = true")
2753 .Attr("range_given: bool = true")
2754 .Output("output: T")
2755 .Attr("T: {bfloat16, half, float, double}")
__anondb9326b23e02(InferenceContext* c) 2756 .SetShapeFn([](InferenceContext* c) {
2757 ShapeHandle unused;
2758 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2759 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2760 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2761 c->set_output(0, c->input(0));
2762 return Status::OK();
2763 });
2764
2765 REGISTER_OP("QuantizeV2")
2766 .Input("input: float")
2767 .Input("min_range: float")
2768 .Input("max_range: float")
2769 .Output("output: T")
2770 .Output("output_min: float")
2771 .Output("output_max: float")
2772 .Attr("T: quantizedtype")
2773 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
2774 .Attr(
2775 "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
2776 "'HALF_AWAY_FROM_ZERO'")
__anondb9326b23f02(InferenceContext* c) 2777 .SetShapeFn([](InferenceContext* c) {
2778 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2779 ShapeHandle unused;
2780 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2781 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2782 c->set_output(1, c->Scalar());
2783 c->set_output(2, c->Scalar());
2784 return Status::OK();
2785 });
2786
2787 REGISTER_OP("Dequantize")
2788 .Input("input: T")
2789 .Input("min_range: float")
2790 .Input("max_range: float")
2791 .Output("output: float")
2792 .Attr("T: quantizedtype")
2793 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
__anondb9326b24002(InferenceContext* c) 2794 .SetShapeFn([](InferenceContext* c) {
2795 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2796 ShapeHandle unused;
2797 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2798 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2799 return Status::OK();
2800 });
2801
2802 REGISTER_OP("QuantizedConcat")
2803 .Input("concat_dim: int32")
2804 .Input("values: N * T")
2805 .Input("input_mins: N * float32")
2806 .Input("input_maxes: N * float32")
2807 .Output("output: T")
2808 .Output("output_min: float")
2809 .Output("output_max: float")
2810 .Attr("N: int >= 2")
2811 .Attr("T: type")
__anondb9326b24102(InferenceContext* c) 2812 .SetShapeFn([](InferenceContext* c) {
2813 const int n = (c->num_inputs() - 1) / 3;
2814 TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
2815 ShapeHandle unused;
2816 for (int i = n + 1; i < c->num_inputs(); ++i) {
2817 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
2818 }
2819 c->set_output(1, c->Scalar());
2820 c->set_output(2, c->Scalar());
2821 return Status::OK();
2822 });
2823
2824 REGISTER_OP("QuantizedReshape")
2825 .Input("tensor: T")
2826 .Input("shape: Tshape")
2827 .Input("input_min: float")
2828 .Input("input_max: float")
2829 .Output("output: T")
2830 .Output("output_min: float")
2831 .Output("output_max: float")
2832 .Attr("T: type")
2833 .Attr("Tshape: {int32, int64} = DT_INT32")
__anondb9326b24202(InferenceContext* c) 2834 .SetShapeFn([](InferenceContext* c) {
2835 TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
2836 ShapeHandle unused;
2837 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2838 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2839 c->set_output(1, c->Scalar());
2840 c->set_output(2, c->Scalar());
2841 return Status::OK();
2842 });
2843
2844 REGISTER_OP("QuantizedInstanceNorm")
2845 .Input("x: T")
2846 .Input("x_min: float")
2847 .Input("x_max: float")
2848 .Output("y: T")
2849 .Output("y_min: float")
2850 .Output("y_max: float")
2851 .Attr("T: quantizedtype")
2852 .Attr("output_range_given: bool = false")
2853 .Attr("given_y_min: float = 0")
2854 .Attr("given_y_max: float = 0")
2855 .Attr("variance_epsilon: float = 1e-5")
2856 .Attr("min_separation: float = 1e-3")
__anondb9326b24302(shape_inference::InferenceContext* c) 2857 .SetShapeFn([](shape_inference::InferenceContext* c) {
2858 shape_inference::ShapeHandle unused;
2859 // x should be a rank 4 tensor.
2860 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
2861 // Assert x_min and x_max are scalars (rank 0).
2862 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2863 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2864 // y has the same shape as x.
2865 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2866 // y_min and y_max are scalars.
2867 c->set_output(1, c->Scalar());
2868 c->set_output(2, c->Scalar());
2869 return Status::OK();
2870 });
2871
2872 namespace {
2873
ScatterNdShapeHelper(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle updates_shape,ShapeHandle output_shape)2874 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
2875 ShapeHandle updates_shape,
2876 ShapeHandle output_shape) {
2877 if (c->Value(c->NumElements(output_shape)) == 0 &&
2878 (c->Value(c->NumElements(indices_shape)) > 0 ||
2879 c->Value(c->NumElements(updates_shape)) > 0)) {
2880 return errors::InvalidArgument(
2881 "Indices and updates specified for empty output shape");
2882 }
2883
2884 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
2885 const int64 outer_dims = c->Rank(indices_shape) - 1;
2886 const DimensionHandle ixdim = c->Dim(indices_shape, -1);
2887
2888 // We can only do more validation if the last dimension of indices
2889 // is a known value.
2890 if (c->ValueKnown(ixdim)) {
2891 int64 ix = c->Value(ixdim);
2892 ShapeHandle unused;
2893 ShapeHandle prefix_indices;
2894 TF_RETURN_IF_ERROR(
2895 c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
2896 ShapeHandle prefix_updates;
2897 TF_RETURN_IF_ERROR(
2898 c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
2899
2900 Status s = c->Merge(prefix_indices, prefix_updates, &unused);
2901 if (!s.ok()) {
2902 return errors::InvalidArgument(
2903 "The outer ", outer_dims,
2904 " dimensions of indices.shape=", c->DebugString(indices_shape),
2905 " must match the outer ", outer_dims,
2906 " dimensions of updates.shape=", c->DebugString(updates_shape),
2907 ": ", s.error_message());
2908 }
2909
2910 ShapeHandle suffix_output;
2911 TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output));
2912 ShapeHandle suffix_updates;
2913 TF_RETURN_IF_ERROR(
2914 c->Subshape(updates_shape, outer_dims, &suffix_updates));
2915 s = c->Merge(suffix_output, suffix_updates, &unused);
2916 if (!s.ok()) {
2917 return errors::InvalidArgument(
2918 "The inner ", c->Rank(output_shape) - ix,
2919 " dimensions of output.shape=", c->DebugString(output_shape),
2920 " must match the inner ", c->Rank(updates_shape) - outer_dims,
2921 " dimensions of updates.shape=", c->DebugString(updates_shape),
2922 ": ", s.error_message());
2923 }
2924 }
2925 }
2926
2927 c->set_output(0, output_shape);
2928 return Status::OK();
2929 }
2930
ScatterNdShape(InferenceContext * c)2931 Status ScatterNdShape(InferenceContext* c) {
2932 ShapeHandle indices_shape;
2933 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
2934 ShapeHandle updates_shape;
2935 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
2936 ShapeHandle output_shape;
2937 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
2938 return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
2939 }
2940
ScatterNdTensorShape(InferenceContext * c)2941 Status ScatterNdTensorShape(InferenceContext* c) {
2942 ShapeHandle output_shape;
2943 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
2944 ShapeHandle indices_shape;
2945 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
2946 ShapeHandle updates_shape;
2947 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
2948 return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
2949 }
2950
2951 } // namespace
2952
2953 REGISTER_OP("UpperBound")
2954 .Input("sorted_inputs: T")
2955 .Input("values: T")
2956 .Output("output: out_type")
2957 .Attr("T: type")
2958 .Attr("out_type: {int32, int64} = DT_INT32")
__anondb9326b24502(InferenceContext* c) 2959 .SetShapeFn([](InferenceContext* c) {
2960 ShapeHandle unused_shape;
2961 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
2962 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
2963 c->set_output(0, c->input(1));
2964 return Status::OK();
2965 });
2966
2967 REGISTER_OP("LowerBound")
2968 .Input("sorted_inputs: T")
2969 .Input("values: T")
2970 .Output("output: out_type")
2971 .Attr("T: type")
2972 .Attr("out_type: {int32, int64} = DT_INT32")
__anondb9326b24602(InferenceContext* c) 2973 .SetShapeFn([](InferenceContext* c) {
2974 ShapeHandle unused_shape;
2975 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
2976 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
2977 c->set_output(0, c->input(1));
2978 return Status::OK();
2979 });
2980
// ScatterNd: the output shape is read from the `shape` input tensor, and
// `indices`/`updates` are validated against it by ScatterNdShape.
REGISTER_OP("ScatterNd")
    .Input("indices: Tindices")
    .Input("updates: T")
    // Shares the Tindices type with `indices`, so the shape vector uses the
    // same integer dtype as the indices.
    .Input("shape: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdShape);
2989
// TensorScatterUpdate: the output has the shape of `tensor` (input 0), and
// `indices`/`updates` are validated against it by ScatterNdTensorShape.
REGISTER_OP("TensorScatterUpdate")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
2998
// TensorScatterAdd: the output has the shape of `tensor` (input 0), and
// `indices`/`updates` are validated against it by ScatterNdTensorShape.
REGISTER_OP("TensorScatterAdd")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
3007
// TensorScatterSub: the output has the shape of `tensor` (input 0), and
// `indices`/`updates` are validated against it by ScatterNdTensorShape.
REGISTER_OP("TensorScatterSub")
    .Input("tensor: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ScatterNdTensorShape);
3016
// ScatterNdNonAliasingAdd: takes (input, indices, updates) like the
// TensorScatter* ops, but its shape function is the library helper
// shape_inference::ScatterNdUpdateShape instead of ScatterNdTensorShape.
// NOTE(review): presumably the output shape follows `input`; confirm in
// common_shape_fns.
REGISTER_OP("ScatterNdNonAliasingAdd")
    .Input("input: T")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output: T")
    // Restricted to numeric types and bool.
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
3025
// FakeQuantWithMinMaxArgs: the quantization range and bit width are fixed
// attrs (defaults min=-6.0, max=6.0, num_bits=8) rather than tensor inputs.
// The output shape equals the input shape.
REGISTER_OP("FakeQuantWithMinMaxArgs")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Output("outputs: float")
    .SetShapeFn(shape_inference::UnchangedShape);
3034
3035 REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
3036 .Attr("min: float = -6.0")
3037 .Attr("max: float = 6.0")
3038 .Attr("num_bits: int = 8")
3039 .Attr("narrow_range: bool = false")
3040 .Input("gradients: float")
3041 .Input("inputs: float")
3042 .Output("backprops: float")
3043 .SetShapeFn(shape_inference::UnchangedShape);
3044
3045 REGISTER_OP("FakeQuantWithMinMaxVars")
3046 .Attr("num_bits: int = 8")
3047 .Attr("narrow_range: bool = false")
3048 .Input("inputs: float")
3049 .Input("min: float")
3050 .Input("max: float")
3051 .Output("outputs: float")
__anondb9326b24702(InferenceContext* c) 3052 .SetShapeFn([](InferenceContext* c) {
3053 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3054 ShapeHandle unused;
3055 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3056 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3057 return Status::OK();
3058 });
3059
3060 REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
3061 .Attr("num_bits: int = 8")
3062 .Attr("narrow_range: bool = false")
3063 .Input("gradients: float")
3064 .Input("inputs: float")
3065 .Input("min: float")
3066 .Input("max: float")
3067 .Output("backprops_wrt_input: float")
3068 .Output("backprop_wrt_min: float")
3069 .Output("backprop_wrt_max: float")
__anondb9326b24802(InferenceContext* c) 3070 .SetShapeFn([](InferenceContext* c) {
3071 // gradients and inputs are same size.
3072 ShapeHandle inputs;
3073 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
3074
3075 // min and max are scalars
3076 ShapeHandle min_max;
3077 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
3078 TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
3079
3080 c->set_output(0, inputs);
3081 c->set_output(1, min_max);
3082 c->set_output(2, min_max);
3083 return Status::OK();
3084 });
3085
3086 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
3087 .Attr("num_bits: int = 8")
3088 .Attr("narrow_range: bool = false")
3089 .Input("inputs: float")
3090 .Input("min: float")
3091 .Input("max: float")
3092 .Output("outputs: float")
__anondb9326b24902(InferenceContext* c) 3093 .SetShapeFn([](InferenceContext* c) {
3094 ShapeHandle input, min, max;
3095 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
3096 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
3097 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
3098
3099 DimensionHandle unused;
3100 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
3101 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
3102 TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));
3103
3104 c->set_output(0, input);
3105 return Status::OK();
3106 });
3107
3108 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
3109 .Attr("num_bits: int = 8")
3110 .Attr("narrow_range: bool = false")
3111 .Input("gradients: float")
3112 .Input("inputs: float")
3113 .Input("min: float")
3114 .Input("max: float")
3115 .Output("backprops_wrt_input: float")
3116 .Output("backprop_wrt_min: float")
3117 .Output("backprop_wrt_max: float")
__anondb9326b24a02(InferenceContext* c) 3118 .SetShapeFn([](InferenceContext* c) {
3119 ShapeHandle inputs;
3120 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
3121 TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
3122 TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
3123
3124 ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
3125
3126 ShapeHandle min_max;
3127 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
3128 TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
3129 TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
3130
3131 c->set_output(0, inputs);
3132 c->set_output(1, min_max);
3133 c->set_output(2, min_max);
3134 return Status::OK();
3135 });
3136
#ifdef INTEL_MKL
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // Inputs are [concat_dim, N values, mkl_concat_dim, N mkl_values], so
      // there are 2 * N + 2 in total. ConcatShape's second argument is the
      // number of value tensors (N). Passing num_inputs() - 3 (= 2N - 1)
      // would also pull the MKL metadata tensors into the concatenation.
      const int num_values = (c->num_inputs() - 2) / 2;
      return shape_inference::ConcatShape(c, num_values);
    })
    .Doc(R"doc(
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
3157
// Deprecated op registrations:

// The following can be deleted after 10mar2017.
// NOTE(review): that date has passed; confirm nothing still loads GraphDefs
// referencing these ops before removing them.
// Each op below is marked deprecated at GraphDef version 14 with a pointer to
// its Matrix* replacement, and uses UnknownShape, so no output shape is
// inferred for it.
REGISTER_OP("BatchMatrixDiag")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixSetDiag")
    .Input("input: T")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixSetDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixDiagPart")
    .Input("input: T")
    .Output("diagonal: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiagPart")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixBandPart")
    .Input("input: T")
    .Input("num_lower: int64")
    .Input("num_upper: int64")
    .Output("band: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixBandPart")
    .SetShapeFn(shape_inference::UnknownShape);
3188
3189 } // namespace tensorflow
3190