1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/op_def_builder.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20
21 namespace tensorflow {
22
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 namespace {
28
DequeueManyV2Shape(InferenceContext * c,ShapeHandle n_shape)29 Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) {
30 auto* t = c->input_handle_shapes_and_types(0);
31 if (t != nullptr && t->size() == c->num_outputs()) {
32 for (int i = 0; i < c->num_outputs(); ++i) {
33 ShapeHandle combined_shape;
34 TF_RETURN_IF_ERROR(
35 c->Concatenate(n_shape, (*t)[i].shape, &combined_shape));
36 c->set_output(i, combined_shape);
37 }
38 return Status::OK();
39 } else {
40 return shape_inference::UnknownShape(c);
41 }
42 }
43
44 } // namespace
45
46 // --------------------------------------------------------------------------
47
// Partitions `data` into `num_partitions` tensors keyed by `partitions`.
// The shape function verifies that `data`'s shape begins with `partitions`'
// shape; each output then has an unknown leading dimension (partition sizes
// are data-dependent) followed by the remainder of `data`'s shape.
REGISTER_OP("DynamicPartition")
    .Input("data: T")
    .Input("partitions: int32")
    .Output("outputs: num_partitions * T")
    .Attr("num_partitions: int")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      int64 num_partitions;
      TF_RETURN_IF_ERROR(c->GetAttr("num_partitions", &num_partitions));

      ShapeHandle data_shape = c->input(0);
      ShapeHandle partitions_shape = c->input(1);

      // Without the rank of `partitions` we cannot split `data`'s shape.
      if (!c->RankKnown(partitions_shape)) {
        return shape_inference::UnknownShape(c);
      }

      const int64 rank = c->Rank(partitions_shape);

      // data shape must start with partitions_shape
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(
          c->MergePrefix(data_shape, partitions_shape, &unused, &unused));

      // The partition shape is dynamic in the 0th dimension, and matches
      // data_shape in the remaining dimensions.
      ShapeHandle unknown_dim0 = c->MakeShape({c->UnknownDim()});

      ShapeHandle data_suffix_shape;
      TF_RETURN_IF_ERROR(c->Subshape(data_shape, rank, &data_suffix_shape));
      ShapeHandle result_shape;
      TF_RETURN_IF_ERROR(
          c->Concatenate(unknown_dim0, data_suffix_shape, &result_shape));

      // All partitions share the same (partially unknown) shape.
      for (int i = 0; i < c->num_outputs(); ++i) {
        c->set_output(i, result_shape);
      }

      return Status::OK();
    });
88
89 namespace {
90
DynamicStitchShapeFunction(InferenceContext * c)91 Status DynamicStitchShapeFunction(InferenceContext* c) {
92 int32 num_partitions;
93 TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
94
95 bool all_indices_constant = true;
96 int32 max_index = -1;
97 ShapeHandle extra_shape = c->UnknownShape();
98 for (int i = 0; i < num_partitions; ++i) {
99 const Tensor* indices_t = c->input_tensor(i);
100 if (indices_t == nullptr) {
101 all_indices_constant = false;
102 }
103
104 ShapeHandle indices_shape = c->input(i);
105 ShapeHandle data_shape = c->input(i + num_partitions);
106 if (!c->RankKnown(indices_shape)) {
107 continue;
108 }
109 const int64 indices_rank = c->Rank(indices_shape);
110
111 // Assert that data_shape starts with indices_shape.
112 ShapeHandle unused;
113 TF_RETURN_IF_ERROR(
114 c->MergePrefix(data_shape, indices_shape, &unused, &unused));
115
116 // The rest belongs to output.
117 ShapeHandle rest;
118 TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
119 TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
120
121 if (indices_t != nullptr) {
122 // The length is based on the highest index from flattened indices.
123 const int32* indices = indices_t->flat<int32>().data();
124 int64 count = indices_t->NumElements();
125 for (int64 i = 0; i < count; ++i) {
126 if (indices[i] > max_index) {
127 max_index = indices[i];
128 }
129 }
130 }
131 }
132
133 ShapeHandle output_shape = c->Vector(
134 all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim());
135 TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape));
136 c->set_output(0, output_shape);
137 return Status::OK();
138 }
139
140 } // namespace
141
// Interleaves the values from the `data` tensors into a single tensor, with
// positions chosen by `indices`.  Shape inference is shared with the parallel
// variant below.
REGISTER_OP("DynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);

// Same computation as DynamicStitch, but the kernel may process the input
// pairs in parallel; the shape contract is identical.
REGISTER_OP("ParallelDynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);
157
158 // --------------------------------------------------------------------------
159
160 namespace {
TwoElementVectorInputsAndScalarOutputs(InferenceContext * c)161 Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
162 ShapeHandle handle;
163 DimensionHandle unused_handle;
164 for (int i = 0; i < c->num_inputs(); ++i) {
165 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
166 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
167 }
168 for (int i = 0; i < c->num_outputs(); ++i) {
169 c->set_output(i, c->Scalar());
170 }
171 return Status::OK();
172 }
173
// Declares output 0 to be a 2-element vector — the shape of a ref-typed
// queue/stack handle.
Status TwoElementOutput(InferenceContext* c) {
  c->set_output(0, c->Vector(2));
  return Status::OK();
}
178 } // namespace
179
// A queue that randomizes the order of elements (ref-handle variant: the
// handle is a 2-element string vector).
REGISTER_OP("RandomShuffleQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("min_after_dequeue: int = 0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Resource variant: the handle is a scalar resource tensor.
REGISTER_OP("RandomShuffleQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("min_after_dequeue: int = 0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// First-in-first-out queue (ref-handle variant).
REGISTER_OP("FIFOQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// First-in-first-out queue (resource variant).
REGISTER_OP("FIFOQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// FIFO queue that supports partially-specified component shapes, padding
// batched elements on dequeue (ref-handle variant).
REGISTER_OP("PaddingFIFOQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Padding FIFO queue (resource variant).
REGISTER_OP("PaddingFIFOQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Queue that produces elements sorted by the first component (ref-handle
// variant).  Note: unlike the queues above, `component_types` may be empty
// while `shapes` has no default.
REGISTER_OP("PriorityQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Priority queue (resource variant).
REGISTER_OP("PriorityQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
265
// Converts a resource handle into a deprecated ref-typed queue handle.
REGISTER_OP("FakeQueue")
    .Input("resource: resource")
    .Output("handle: Ref(string)")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Enqueues a single element tuple.  Shape inference yields nothing useful
// here (no outputs), so all of the enqueue ops use UnknownShape.
REGISTER_OP("QueueEnqueue")
    .Input("handle: Ref(string)")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Enqueue (resource variant).
REGISTER_OP("QueueEnqueueV2")
    .Input("handle: resource")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Enqueues zero or more elements; each component's 0th dimension indexes
// the elements.
REGISTER_OP("QueueEnqueueMany")
    .Input("handle: Ref(string)")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Batched enqueue (resource variant).
REGISTER_OP("QueueEnqueueManyV2")
    .Input("handle: resource")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Dequeues a single element tuple.  The ref-typed handle carries no shape
// metadata, so the outputs are unknown.
REGISTER_OP("QueueDequeue")
    .Input("handle: Ref(string)")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);
306
// Dequeues a single element tuple (resource variant).  Unlike the ref-typed
// op, the resource handle may carry per-component shapes, which become the
// output shapes directly (no batch dimension is added).
REGISTER_OP("QueueDequeueV2")
    .Input("handle: resource")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      auto* t = c->input_handle_shapes_and_types(0);
      if (t != nullptr && t->size() == c->num_outputs()) {
        for (int i = 0; i < c->num_outputs(); ++i) {
          c->set_output(i, (*t)[i].shape);
        }
        return Status::OK();
      } else {
        // Handle carries no (or mismatched) component shapes.
        return shape_inference::UnknownShape(c);
      }
    });

// Dequeues `n` element tuples, concatenated along a new 0th dimension.
REGISTER_OP("QueueDequeueMany")
    .Input("handle: Ref(string)")
    .Input("n: int32")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Batched dequeue (resource variant).  When `n` is a compile-time constant
// it becomes the known leading dimension of every output.
REGISTER_OP("QueueDequeueManyV2")
    .Input("handle: resource")
    .Input("n: int32")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle n_shape;
      if (c->input_tensor(1) == nullptr) {
        n_shape = c->Vector(InferenceContext::kUnknownDim);
      } else {
        // NOTE(review): reads the constant as a scalar without checking its
        // rank here — presumably the kernel validates that 'n' is scalar;
        // confirm before relying on it.
        const int32 n = c->input_tensor(1)->scalar<int32>()();
        if (n < 0) {
          return errors::InvalidArgument("Input 'n' must be >= 0, but is ", n);
        }
        n_shape = c->Vector(n);
      }
      return DequeueManyV2Shape(c, n_shape);
    });

// Dequeues up to `n` element tuples; fewer may be returned when the queue
// is closed.
REGISTER_OP("QueueDequeueUpTo")
    .Input("handle: Ref(string)")
    .Input("n: int32")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// "Up to" dequeue (resource variant).  The batch dimension is always
// unknown, even for constant `n`, because fewer elements may be returned.
REGISTER_OP("QueueDequeueUpToV2")
    .Input("handle: resource")
    .Input("n: int32")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      return DequeueManyV2Shape(c, c->Vector(InferenceContext::kUnknownDim));
    });
369
// Closes the queue; subsequent enqueues fail, pending dequeues drain.
REGISTER_OP("QueueClose")
    .Input("handle: Ref(string)")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// Close (resource variant) — no outputs to infer.
REGISTER_OP("QueueCloseV2")
    .Input("handle: resource")
    .SetShapeFn(shape_inference::NoOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// Returns whether the queue has been closed.
REGISTER_OP("QueueIsClosed")
    .Input("handle: Ref(string)")
    .Output("is_closed: bool")
    .SetShapeFn(shape_inference::ScalarShape);

// Is-closed (resource variant).
REGISTER_OP("QueueIsClosedV2")
    .Input("handle: resource")
    .Output("is_closed: bool")
    .SetShapeFn(shape_inference::ScalarShape);

// Returns the number of elements currently in the queue.
REGISTER_OP("QueueSize")
    .Input("handle: Ref(string)")
    .Output("size: int32")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
394
395 REGISTER_OP("QueueSizeV2")
396 .Input("handle: resource")
397 .Output("size: int32")
398 .SetShapeFn(shape_inference::UnchangedShape);
399
400 // --------------------------------------------------------------------------
401
// Returns the number of gradients aggregated in the accumulator.
REGISTER_OP("AccumulatorNumAccumulated")
    .Input("handle: Ref(string)")
    .Output("num_accumulated: int32")
    .SetShapeFn(shape_inference::ScalarShape);

// Updates the accumulator's global step; stale gradients (older than the
// global step) can then be rejected by the kernel.
REGISTER_OP("AccumulatorSetGlobalStep")
    .Input("handle: Ref(string)")
    .Input("new_global_step: int64")
    .SetShapeFn([](InferenceContext* c) {
      // new_global_step must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });
415
416 REGISTER_OP("ConditionalAccumulator")
417 .Output("handle: Ref(string)")
418 .Attr("dtype: numbertype")
419 .Attr("shape: shape")
420 .Attr("container: string = ''")
421 .Attr("shared_name: string = ''")
422 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
423 .SetIsStateful()
__anona2f81a440902(InferenceContext* c) 424 .SetShapeFn([](InferenceContext* c) {
425 c->set_output(0, c->Vector(2));
426 return Status::OK();
427 });
428
// Applies one gradient to the accumulator, tagged with the worker's
// local_step (used to discard stale gradients).
REGISTER_OP("AccumulatorApplyGradient")
    .Input("handle: Ref(string)")
    .Input("local_step: int64")
    .Input("gradient: dtype")
    .Attr("dtype: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      // local_step must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });

// Extracts the aggregated gradient once at least num_required gradients
// have been applied.
REGISTER_OP("AccumulatorTakeGradient")
    .Input("handle: Ref(string)")
    .Input("num_required: int32")
    .Output("average: dtype")
    .SetShapeFn([](InferenceContext* c) {
      // num_required must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // Shape of output is the shape of the accumulator referenced
      // by 'handle', but which is not available here, so we lose
      // shape information.
      return shape_inference::UnknownShape(c);
    })
    .Attr("dtype: numbertype");
453
454 // -----------------V2 accumulators that use resource -------------------------
455
// Resource-based counterpart of AccumulatorNumAccumulated.
REGISTER_OP("ResourceAccumulatorNumAccumulated")
    .Input("handle: resource")
    .Output("num_accumulated: int32")
    .SetShapeFn(shape_inference::ScalarShape);

// Resource-based counterpart of AccumulatorSetGlobalStep.
REGISTER_OP("ResourceAccumulatorSetGlobalStep")
    .Input("handle: resource")
    .Input("new_global_step: int64")
    .SetShapeFn([](InferenceContext* c) {
      // new_global_step must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });
469
470 REGISTER_OP("ResourceConditionalAccumulator")
471 .Output("handle: resource")
472 .Attr("dtype: numbertype")
473 .Attr("shape: shape")
474 .Attr("container: string = ''")
475 .Attr("shared_name: string = ''")
476 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
477 .SetIsStateful()
__anona2f81a440d02(InferenceContext* c) 478 .SetShapeFn([](InferenceContext* c) {
479 c->set_output(0, c->Vector(2));
480 return Status::OK();
481 });
482
// Resource-based counterpart of AccumulatorApplyGradient.
REGISTER_OP("ResourceAccumulatorApplyGradient")
    .Input("handle: resource")
    .Input("local_step: int64")
    .Input("gradient: dtype")
    .Attr("dtype: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      // local_step must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });

// Resource-based counterpart of AccumulatorTakeGradient.
REGISTER_OP("ResourceAccumulatorTakeGradient")
    .Input("handle: resource")
    .Input("num_required: int32")
    .Output("average: dtype")
    .SetShapeFn([](InferenceContext* c) {
      // num_required must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // Shape of output is the shape of the accumulator referenced
      // by 'handle', but which is not available here, so we lose
      // shape information.
      return shape_inference::UnknownShape(c);
    })
    .Attr("dtype: numbertype");
507
508 // TODO(nponomareva): change these all to use resources.
509 REGISTER_OP("SparseConditionalAccumulator")
510 .Output("handle: Ref(string)")
511 .Attr("dtype: numbertype")
512 .Attr("shape: shape")
513 .Attr("container: string = ''")
514 .Attr("shared_name: string = ''")
515 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
516 .SetIsStateful()
__anona2f81a441002(InferenceContext* c) 517 .SetShapeFn([](InferenceContext* c) {
518 c->set_output(0, c->Vector(2));
519 return Status::OK();
520 });
521
// Applies one sparse gradient (indices, values, dense shape) to the
// accumulator, tagged with the worker's local_step.
REGISTER_OP("SparseAccumulatorApplyGradient")
    .Input("handle: Ref(string)")
    .Input("local_step: int64")
    .Input("gradient_indices: int64")
    .Input("gradient_values: dtype")
    .Input("gradient_shape: int64")
    .Attr("dtype: numbertype")
    .Attr("has_known_shape: bool")
    .SetShapeFn([](InferenceContext* c) {
      // local_step must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });

// Extracts the aggregated sparse gradient as an (indices, values, shape)
// triple once num_required gradients have been applied.
REGISTER_OP("SparseAccumulatorTakeGradient")
    .Input("handle: Ref(string)")
    .Input("num_required: int32")
    .Output("indices: int64")
    .Output("values: dtype")
    .Output("shape: int64")
    .Attr("dtype: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      // num_required must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // Shape of output is the shape of the accumulator referenced
      // by 'handle', but which is not available here, so we lose
      // shape information.
      return shape_inference::UnknownShape(c);
    });
551
552 // --------------------------------------------------------------------------
553
// Creates a stack (resource variant).
// NOTE(review): the handle output is a resource, yet the shape fn declares a
// 2-element vector via TwoElementOutput — looks inherited from the ref-typed
// Stack op; confirm whether ScalarShape was intended.
REGISTER_OP("StackV2")
    .Input("max_size: int32")
    .Output("handle: resource")
    .Attr("elem_type: type")
    .Attr("stack_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// Pushes an element; the output mirrors the pushed value's shape.
REGISTER_OP("StackPushV2")
    .Input("handle: resource")
    .Input("elem: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("swap_memory: bool = false")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(1));
      return Status::OK();
    });

// Pops the top element; its shape is not tracked, so it is unknown.
REGISTER_OP("StackPopV2")
    .Input("handle: resource")
    .Output("elem: elem_type")
    .Attr("elem_type: type")
    .SetShapeFn(shape_inference::UnknownShape);

// Closes the stack.
// NOTE(review): this uses the 2-element-vector input check even though the
// handle here is a resource — confirm this is intentional.
REGISTER_OP("StackCloseV2")
    .Input("handle: resource")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);

// Deprecated ref-typed variants of stack.

REGISTER_OP("Stack")
    .Output("handle: Ref(string)")
    .Attr("elem_type: type")
    .Attr("stack_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("StackPush")
    .Input("handle: Ref(string)")
    .Input("elem: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("swap_memory: bool = false")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output mirrors the pushed value's shape.
      c->set_output(0, c->input(1));
      return Status::OK();
    });

REGISTER_OP("StackPop")
    .Input("handle: Ref(string)")
    .Output("elem: elem_type")
    .Attr("elem_type: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("StackClose")
    .Input("handle: Ref(string)")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
612
613 // --------------------------------------------------------------------------
614
// Creates a TensorArray of the given size.  Outputs the resource handle and
// the float "flow" value used to chain TensorArray ops in dataflow order.
REGISTER_OP("TensorArrayV3")
    .Input("size: int32")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .Attr("dynamic_size: bool = false")
    .Attr("clear_after_read: bool = true")
    .Attr("identical_element_shapes: bool = false")
    .Attr("tensor_array_name: string = ''")
    .Output("handle: resource")
    .Output("flow: float")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // size must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      c->set_output(0, c->Vector(2));
      c->set_output(1, c->Scalar());
      bool identical_shapes;
      TF_RETURN_IF_ERROR(
          c->GetAttr("identical_element_shapes", &identical_shapes));
      DataType t;
      TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
      PartialTensorShape p;
      TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
      // Record the element shape/dtype on the handle only when it is
      // trustworthy: either fully defined, or declared identical across
      // all elements.
      if (c->FullyDefined(s) || identical_shapes) {
        c->set_output_handle_shapes_and_types(
            0, std::vector<shape_inference::ShapeAndType>{{s, t}});
      }
      return Status::OK();
    });
646
// Creates (or looks up) the gradient TensorArray for the given handle and
// gradient `source` name.  Propagates the element shape/dtype metadata from
// the forward handle to the gradient handle.
REGISTER_OP("TensorArrayGradV3")
    .Input("handle: resource")
    .Input("flow_in: float")
    .Output("grad_handle: resource")
    .Output("flow_out: float")
    .Attr("source: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      c->set_output(0, c->Vector(2));
      c->set_output(1, c->Scalar());
      // Forward the element shape/dtype metadata, if present.
      if (c->input_handle_shapes_and_types(0)) {
        c->set_output_handle_shapes_and_types(
            0, *c->input_handle_shapes_and_types(0));
      }
      return Status::OK();
    });
667
// Like TensorArrayGradV3, but the gradient elements have `shape_to_prepend`
// extra leading dimensions relative to the forward elements.
REGISTER_OP("TensorArrayGradWithShape")
    .Input("handle: resource")
    .Input("flow_in: float")
    .Input("shape_to_prepend: int32")
    .Output("grad_handle: resource")
    .Output("flow_out: float")
    .Attr("source: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      c->set_output(0, c->Vector(2));
      c->set_output(1, c->Scalar());
      auto* shape_and_type = c->input_handle_shapes_and_types(0);
      if (shape_and_type) {
        auto input_shape = (*shape_and_type)[0].shape;
        auto dtype = (*shape_and_type)[0].dtype;
        // Note that shape_to_prepend is a rank-1 Tensor representing a shape.
        // The size of dimension 0 is the number of dimensions we need to add to
        // output shape.
        int64 prepend_rank = c->Value(c->Dim(c->input(2), 0));
        if (c->RankKnown(input_shape) &&
            prepend_rank != InferenceContext::kUnknownDim) {
          int32 input_rank = c->Rank(input_shape);
          std::vector<DimensionHandle> dims;
          dims.reserve(prepend_rank + input_rank);
          // Prepended dimensions have unknown sizes (only their count is
          // known here); the forward element dims follow unchanged.
          for (int i = 0; i < prepend_rank; ++i) {
            dims.push_back(c->UnknownDim());
          }
          for (int i = 0; i < input_rank; ++i) {
            dims.push_back(c->Dim(input_shape, i));
          }
          c->set_output_handle_shapes_and_types(0,
                                                {{c->MakeShape(dims), dtype}});
        } else {
          c->set_output_handle_shapes_and_types(0,
                                                {{c->UnknownShape(), dtype}});
        }
      }
      return Status::OK();
    });
711
// Writes `value` at `index`; validates the value against the element shape
// recorded on the handle, when available.  Output is the scalar flow value.
REGISTER_OP("TensorArrayWriteV3")
    .Input("handle: resource")
    .Input("index: int32")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));

      // index and flow_in must be scalars.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));

      // Check the written value against the recorded element shape.
      auto* handle_data = c->input_handle_shapes_and_types(0);
      if (handle_data != nullptr && !handle_data->empty()) {
        shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
        ShapeHandle value_shape = c->input(2);
        TF_RETURN_IF_ERROR(
            c->Merge(shape_and_type.shape, value_shape, &unused));
      }

      return shape_inference::ScalarShape(c);
    });

// Reads the element at `index`.  The output shape is the element shape
// recorded on the handle, or unknown when absent.
REGISTER_OP("TensorArrayReadV3")
    .Input("handle: resource")
    .Input("index: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      // index and flow_in must be scalars.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      auto shapes = c->input_handle_shapes_and_types(0);
      if (shapes != nullptr && !shapes->empty()) {
        ShapeHandle tensor_shape = shapes->at(0).shape;
        c->set_output(0, tensor_shape);
        return Status::OK();
      } else {
        return shape_inference::UnknownShape(c);
      }
    });
763
// Gathers the elements at `indices` into one tensor: output shape is
// [len(indices)] + element_shape, taken from the handle metadata when
// available, otherwise from the element_shape attr.
REGISTER_OP("TensorArrayGatherV3")
    .Input("handle: resource")
    .Input("indices: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      ShapeHandle unused;
      DimensionHandle unused_dim;
      // handle must be a 2-element vector; indices a vector; flow a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      auto shapes = c->input_handle_shapes_and_types(0);
      if (shapes != nullptr && !shapes->empty()) {
        // Prefer the element shape recorded on the handle.
        ShapeHandle tensor_shape = shapes->at(0).shape;
        ShapeHandle output_shape;
        TF_RETURN_IF_ERROR(
            c->Concatenate(indices, tensor_shape, &output_shape));
        c->set_output(0, output_shape);
        return Status::OK();
      } else {
        // Fall back to the element_shape attribute.
        PartialTensorShape p;
        TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
        ShapeHandle s;
        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
        ShapeHandle output_shape;
        TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape));
        c->set_output(0, output_shape);
        return Status::OK();
      }
    });

// Scatters rows of `value` to the positions given by `indices`; validates
// the row shape against the handle's element shape when available.
REGISTER_OP("TensorArrayScatterV3")
    .Input("handle: resource")
    .Input("indices: int32")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      ShapeHandle unused;
      DimensionHandle unused_dim;
      // handle must be a 2-element vector; indices a vector; flow a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      ShapeHandle value_shape;
      // Assert that the length of the indices tensor is equal to the first
      // dimension of the value tensor.
      TF_RETURN_IF_ERROR(
          c->MergePrefix(c->input(2), indices, &value_shape, &indices));
      auto shapes = c->input_handle_shapes_and_types(0);
      if (shapes != nullptr && !shapes->empty()) {
        // Each row of `value` (value shape minus the leading batch dim)
        // must merge with the element shape recorded on the handle.
        ShapeHandle tensor_shape = shapes->at(0).shape;
        ShapeHandle fed_shape;
        TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape));
        TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape));
      }
      return shape_inference::ScalarShape(c);
    });
828
// Concatenates all elements along dimension 0; also returns the original
// per-element lengths.  Both output shapes are data-dependent, hence unknown.
REGISTER_OP("TensorArrayConcatV3")
    .Input("handle: resource")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Output("lengths: int64")
    .Attr("dtype: type")
    .Attr("element_shape_except0: shape = { unknown_rank: true }")
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector; flow a scalar.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      c->set_output(0, c->UnknownShape());
      c->set_output(1, c->Vector(c->UnknownDim()));
      return Status::OK();
    });

// Splits `value` along dimension 0 into pieces of the given lengths and
// writes them as consecutive elements.
REGISTER_OP("TensorArraySplitV3")
    .Input("handle: resource")
    .Input("value: T")
    .Input("lengths: int64")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector; lengths a vector; flow a scalar.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Returns the current number of elements in the TensorArray.
REGISTER_OP("TensorArraySizeV3")
    .Input("handle: resource")
    .Input("flow_in: float")
    .Output("size: int32")
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      return shape_inference::ScalarShape(c);
    });

// Deletes the TensorArray, releasing its resources.
REGISTER_OP("TensorArrayCloseV3")
    .Input("handle: resource")
    .SetShapeFn([](InferenceContext* c) {
      // handle must be a 2-element vector; no outputs.
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      return Status::OK();
    });
887
888 // --------------------------------------------------------------------------
889
890 // Deprecated TensorArray methods
891
892 REGISTER_OP("TensorArray")
893 .Input("size: int32")
894 .Attr("dtype: type")
895 .Attr("dynamic_size: bool = false")
896 .Attr("clear_after_read: bool = true")
897 .Attr("tensor_array_name: string = ''")
898 .Attr("element_shape: shape = { unknown_rank: true }")
899 .Output("handle: Ref(string)")
900 .SetIsStateful()
901 .SetShapeFn(shape_inference::UnknownShape)
902 .Deprecated(16, "Use TensorArrayV3");
903 REGISTER_OP("TensorArrayV2")
904 .Input("size: int32")
905 .Attr("dtype: type")
906 .Attr("element_shape: shape = { unknown_rank: true }")
907 .Attr("dynamic_size: bool = false")
908 .Attr("clear_after_read: bool = true")
909 .Attr("tensor_array_name: string = ''")
910 .Output("handle: string")
911 .SetIsStateful()
__anona2f81a442002(InferenceContext* c) 912 .SetShapeFn([](InferenceContext* c) {
913 ShapeHandle unused;
914 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
915 c->set_output(0, c->Vector(2));
916 return Status::OK();
917 })
918 .Deprecated(26, "Use TensorArrayV3");
// V1 gradient-TensorArray lookup/creation.  Shape inference is left
// unknown.  Deprecated at GraphDef version 16 in favor of
// TensorArrayGradV3.
REGISTER_OP("TensorArrayGrad")
    .Input("handle: string")
    .Input("flow_in: float")
    .Output("grad_handle: Ref(string)")
    .Attr("source: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGradV3");
927 REGISTER_OP("TensorArrayGradV2")
928 .Input("handle: string")
929 .Input("flow_in: float")
930 .Output("grad_handle: string")
931 .Attr("source: string")
932 .SetIsStateful()
__anona2f81a442102(InferenceContext* c) 933 .SetShapeFn([](InferenceContext* c) {
934 ShapeHandle handle;
935 DimensionHandle unused_dim;
936 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
937 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
938 c->set_output(0, c->Vector(2));
939 return Status::OK();
940 })
941 .Deprecated(26, "Use TensorArrayGradV3");
// V1 element write.  Shape inference is left unknown.  Deprecated at
// GraphDef version 16 in favor of TensorArrayWriteV3.
REGISTER_OP("TensorArrayWrite")
    .Input("handle: Ref(string)")
    .Input("index: int32")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayWriteV3");
951 REGISTER_OP("TensorArrayWriteV2")
952 .Input("handle: string")
953 .Input("index: int32")
954 .Input("value: T")
955 .Input("flow_in: float")
956 .Output("flow_out: float")
957 .Attr("T: type")
__anona2f81a442202(InferenceContext* c) 958 .SetShapeFn([](InferenceContext* c) {
959 ShapeHandle handle;
960 DimensionHandle unused_dim;
961 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
962 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
963
964 ShapeHandle unused;
965 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
966 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
967 return shape_inference::ScalarShape(c);
968 })
969 .Deprecated(26, "Use TensorArrayWriteV3");
// V1 element read.  Shape inference is left unknown.  Deprecated at
// GraphDef version 16 in favor of TensorArrayReadV3.
REGISTER_OP("TensorArrayRead")
    .Input("handle: Ref(string)")
    .Input("index: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayReadV3");
978 REGISTER_OP("TensorArrayReadV2")
979 .Input("handle: string")
980 .Input("index: int32")
981 .Input("flow_in: float")
982 .Output("value: dtype")
983 .Attr("dtype: type")
__anona2f81a442302(InferenceContext* c) 984 .SetShapeFn([](InferenceContext* c) {
985 ShapeHandle handle;
986 DimensionHandle unused_dim;
987 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
988 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
989 ShapeHandle unused;
990 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
991 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
992 return shape_inference::UnknownShape(c);
993 })
994 .Deprecated(26, "Use TensorArrayReadV3");
// V1 pack: stacks every element into one tensor.  Shape inference is left
// unknown.  Deprecated at GraphDef version 16.
REGISTER_OP("TensorArrayPack")
    .Input("handle: Ref(string)")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp");
// V1 unpack: writes the unstacked slices of `value` into the array.  Shape
// inference is left unknown.  Deprecated at GraphDef version 20.
REGISTER_OP("TensorArrayUnpack")
    .Input("handle: Ref(string)")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp");
// V1 gather of selected indices.  Shape inference is left unknown.
// Deprecated at GraphDef version 16 in favor of TensorArrayGatherV3.
REGISTER_OP("TensorArrayGather")
    .Input("handle: Ref(string)")
    .Input("indices: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGatherV3");
1020 REGISTER_OP("TensorArrayGatherV2")
1021 .Input("handle: string")
1022 .Input("indices: int32")
1023 .Input("flow_in: float")
1024 .Output("value: dtype")
1025 .Attr("dtype: type")
1026 .Attr("element_shape: shape = { unknown_rank: true }")
__anona2f81a442402(InferenceContext* c) 1027 .SetShapeFn([](InferenceContext* c) {
1028 ShapeHandle unused;
1029 DimensionHandle unused_dim;
1030 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1031 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1032 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
1033 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1034 return shape_inference::UnknownShape(c);
1035 })
1036 .Deprecated(26, "Use TensorArrayGatherV3");
1037 REGISTER_OP("TensorArrayScatter")
1038 .Input("handle: Ref(string)")
1039 .Input("indices: int32")
1040 .Input("value: T")
1041 .Input("flow_in: float")
1042 .Output("flow_out: float")
1043 .Attr("T: type")
1044 .SetShapeFn(shape_inference::UnknownShape)
1045 .Deprecated(19, "Use TensorArrayGradV3");
1046 REGISTER_OP("TensorArrayScatterV2")
1047 .Input("handle: string")
1048 .Input("indices: int32")
1049 .Input("value: T")
1050 .Input("flow_in: float")
1051 .Output("flow_out: float")
1052 .Attr("T: type")
__anona2f81a442502(InferenceContext* c) 1053 .SetShapeFn([](InferenceContext* c) {
1054 ShapeHandle unused;
1055 DimensionHandle unused_dim;
1056 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1057 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1058 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
1059 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1060 return shape_inference::ScalarShape(c);
1061 })
1062 .Deprecated(26, "Use TensorArrayScatterV3");
1063 REGISTER_OP("TensorArrayConcat")
1064 .Input("handle: Ref(string)")
1065 .Input("flow_in: float")
1066 .Output("value: dtype")
1067 .Output("lengths: int64")
1068 .Attr("dtype: type")
1069 .Attr("element_shape_except0: shape = { unknown_rank: true }")
1070 .SetShapeFn(shape_inference::UnknownShape)
1071 .Deprecated(16, "Use TensorArrayGradV3");
1072 REGISTER_OP("TensorArrayConcatV2")
1073 .Input("handle: string")
1074 .Input("flow_in: float")
1075 .Output("value: dtype")
1076 .Output("lengths: int64")
1077 .Attr("dtype: type")
1078 .Attr("element_shape_except0: shape = { unknown_rank: true }")
__anona2f81a442602(InferenceContext* c) 1079 .SetShapeFn([](InferenceContext* c) {
1080 ShapeHandle handle;
1081 DimensionHandle unused_dim;
1082 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1083 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1084 ShapeHandle unused;
1085 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1086 c->set_output(0, c->UnknownShape());
1087 c->set_output(1, c->Vector(c->UnknownDim()));
1088 return Status::OK();
1089 });
// V1 split: writes `value` into the array as slices sized by `lengths`.
// Shape inference is left unknown.  Deprecated at GraphDef version 16 in
// favor of TensorArraySplitV3.
REGISTER_OP("TensorArraySplit")
    .Input("handle: Ref(string)")
    .Input("value: T")
    .Input("lengths: int64")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArraySplitV3");
1099 REGISTER_OP("TensorArraySplitV2")
1100 .Input("handle: string")
1101 .Input("value: T")
1102 .Input("lengths: int64")
1103 .Input("flow_in: float")
1104 .Output("flow_out: float")
1105 .Attr("T: type")
__anona2f81a442702(InferenceContext* c) 1106 .SetShapeFn([](InferenceContext* c) {
1107 ShapeHandle handle;
1108 DimensionHandle unused_dim;
1109 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1110 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1111 ShapeHandle unused;
1112 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
1113 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1114 return shape_inference::ScalarShape(c);
1115 })
1116 .Deprecated(26, "Use TensorArraySplitV3");
// V1 size query.  Shape inference is left unknown.  Deprecated at GraphDef
// version 16 in favor of TensorArraySizeV3.
REGISTER_OP("TensorArraySize")
    .Input("handle: Ref(string)")
    .Input("flow_in: float")
    .Output("size: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArraySizeV3");
1123 REGISTER_OP("TensorArraySizeV2")
1124 .Input("handle: string")
1125 .Input("flow_in: float")
1126 .Output("size: int32")
__anona2f81a442802(InferenceContext* c) 1127 .SetShapeFn([](InferenceContext* c) {
1128 ShapeHandle handle;
1129 DimensionHandle unused_dim;
1130 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1131 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1132 return shape_inference::ScalarShape(c);
1133 })
1134 .Deprecated(26, "Use TensorArraySizeV3");
// V1 close: no outputs and no shape constraints.  Deprecated at GraphDef
// version 16 in favor of TensorArrayCloseV3.
REGISTER_OP("TensorArrayClose")
    .Input("handle: Ref(string)")
    .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
    .Deprecated(16, "Use TensorArrayCloseV3");
1139 REGISTER_OP("TensorArrayCloseV2")
1140 .Input("handle: string")
__anona2f81a442a02(InferenceContext* c) 1141 .SetShapeFn([](InferenceContext* c) {
1142 ShapeHandle handle;
1143 DimensionHandle unused_dim;
1144 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1145 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1146 return Status::OK();
1147 })
1148 .Deprecated(26, "Use TensorArrayCloseV3");
1149
1150 // --------------------------------------------------------------------------
1151
// Creates a barrier and emits its Ref(string) handle (shape function
// TwoElementOutput is defined elsewhere in this file).
REGISTER_OP("Barrier")
    .SetIsStateful()
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(TwoElementOutput);
1161
1162 REGISTER_OP("BarrierInsertMany")
1163 .Input("handle: Ref(string)")
1164 .Input("keys: string")
1165 .Input("values: T")
1166 .Attr("T: type")
1167 .Attr("component_index: int")
__anona2f81a442b02(InferenceContext* c) 1168 .SetShapeFn([](InferenceContext* c) {
1169 ShapeHandle keys = c->input(1);
1170 ShapeHandle values = c->input(2);
1171 ShapeHandle handle;
1172 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1173 DimensionHandle unused_dim;
1174 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1175 TF_RETURN_IF_ERROR(c->WithRank(keys, 1, &keys));
1176 TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
1177 TF_RETURN_IF_ERROR(c->Merge(keys, c->Vector(c->Dim(values, 0)), &handle));
1178 return Status::OK();
1179 });
1180
// Takes completed elements from a barrier.  Output shapes depend on runtime
// state, so inference is left unknown.
REGISTER_OP("BarrierTakeMany")
    .Input("handle: Ref(string)")
    .Input("num_elements: int32")
    .Output("indices: int64")
    .Output("keys: string")
    .Output("values: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("allow_small_batch: bool = false")
    .Attr("wait_for_incomplete: bool = false")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Closes a barrier (shape function defined elsewhere in this file).
REGISTER_OP("BarrierClose")
    .Input("handle: Ref(string)")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// Number of complete elements in the barrier, as a scalar.
REGISTER_OP("BarrierReadySize")
    .Input("handle: Ref(string)")
    .Output("size: int32")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);

// Number of incomplete elements in the barrier, as a scalar.
REGISTER_OP("BarrierIncompleteSize")
    .Input("handle: Ref(string)")
    .Output("size: int32")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
1207
1208 // --------------------------------------------------------------------------
1209
// Stores a tensor in the session state and returns its handle as a scalar
// string.
REGISTER_OP("GetSessionHandle")
    .Input("value: T")
    .Output("handle: string")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Same as GetSessionHandle, but the handle is a scalar resource.
REGISTER_OP("GetSessionHandleV2")
    .Input("value: T")
    .Output("handle: resource")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
1223
1224 REGISTER_OP("GetSessionTensor")
1225 .Input("handle: string")
1226 .Output("value: dtype")
1227 .Attr("dtype: type")
1228 .SetIsStateful()
__anona2f81a442c02(InferenceContext* c) 1229 .SetShapeFn([](InferenceContext* c) {
1230 ShapeHandle unused;
1231 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1232 return shape_inference::UnknownShape(c);
1233 });
1234
1235 REGISTER_OP("DeleteSessionTensor")
1236 .Input("handle: string")
1237 .SetIsStateful()
__anona2f81a442d02(InferenceContext* c) 1238 .SetShapeFn([](InferenceContext* c) {
1239 ShapeHandle unused;
1240 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1241 return Status::OK();
1242 });
1243
// Staging area ops: a FIFO-style container keyed only by insertion, shared
// via container/shared_name.  Output shapes are runtime-dependent, so most
// shape functions are unknown.
REGISTER_OP("Stage")
    .Input("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();

// Removes and returns the next values from the staging area.
REGISTER_OP("Unstage")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();

// Returns the values at `index` without removing them.
REGISTER_OP("StagePeek")
    .Input("index: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();

// Number of elements currently staged, as a scalar.
REGISTER_OP("StageSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::ScalarShape)
    .SetIsStateful();

// Removes all elements from the staging area.
REGISTER_OP("StageClear")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();
1293
// Unordered-map staging ops: key/value staging keyed by int64, shared via
// container/shared_name.  Output shapes are runtime-dependent, so most
// shape functions are unknown.
REGISTER_OP("MapStage")
    .Input("key: int64")
    .Input("indices: int32")
    .Input("values: fake_dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("fake_dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();

// Returns the values at `key` without removing them.
REGISTER_OP("MapPeek")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// Removes and returns the values at `key`.
REGISTER_OP("MapUnstage")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// Removes and returns some (key, values) pair chosen by the runtime.
REGISTER_OP("MapUnstageNoKey")
    .Input("indices: int32")
    .Output("key: int64")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// Number of complete elements in the map, as a scalar.
REGISTER_OP("MapSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// Number of incomplete elements in the map, as a scalar.
REGISTER_OP("MapIncompleteSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// Removes all elements from the map.
REGISTER_OP("MapClear")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();
1372
// Ordered-map staging ops: same interface as the Map* ops above, but keys
// are kept in sorted order.
REGISTER_OP("OrderedMapStage")
    .Input("key: int64")
    .Input("indices: int32")
    .Input("values: fake_dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("fake_dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();

// Returns the values at `key` without removing them.
REGISTER_OP("OrderedMapPeek")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// Removes and returns the values at `key`.
REGISTER_OP("OrderedMapUnstage")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// Removes and returns a (key, values) pair chosen by the runtime.
REGISTER_OP("OrderedMapUnstageNoKey")
    .Input("indices: int32")
    .Output("key: int64")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// Number of complete elements in the map, as a scalar.
REGISTER_OP("OrderedMapSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// Number of incomplete elements in the map, as a scalar.
REGISTER_OP("OrderedMapIncompleteSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// Removes all elements from the map.
REGISTER_OP("OrderedMapClear")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();
1451
// Emits batches of records read from files matching `file_pattern`, with
// file-level shuffling controlled by the seed/shift/buffer attrs.  Output
// shape is left unknown.
REGISTER_OP("RecordInput")
    .Output("records: string")
    .Attr("file_pattern: string")
    .Attr("file_random_seed: int = 301")
    .Attr("file_shuffle_shift_ratio: float = 0")
    .Attr("file_buffer_size: int = 10000")
    .Attr("file_parallelism: int = 16")
    .Attr("batch_size: int = 32")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
1463
1464 } // namespace tensorflow
1465