1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/op_def_builder.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20
21 namespace tensorflow {
22
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 namespace {
28
DequeueManyV2Shape(InferenceContext * c,ShapeHandle n_shape)29 Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) {
30 auto* t = c->input_handle_shapes_and_types(0);
31 if (t != nullptr && t->size() == c->num_outputs()) {
32 for (int i = 0; i < c->num_outputs(); ++i) {
33 ShapeHandle combined_shape;
34 TF_RETURN_IF_ERROR(
35 c->Concatenate(n_shape, (*t)[i].shape, &combined_shape));
36 c->set_output(i, combined_shape);
37 }
38 return Status::OK();
39 } else {
40 return shape_inference::UnknownShape(c);
41 }
42 }
43
44 } // namespace
45
46 // --------------------------------------------------------------------------
47
48 REGISTER_OP("DynamicPartition")
49 .Input("data: T")
50 .Input("partitions: int32")
51 .Output("outputs: num_partitions * T")
52 .Attr("num_partitions: int")
53 .Attr("T: type")
__anona2f81a440202(InferenceContext* c) 54 .SetShapeFn([](InferenceContext* c) {
55 int64 num_partitions;
56 TF_RETURN_IF_ERROR(c->GetAttr("num_partitions", &num_partitions));
57
58 ShapeHandle data_shape = c->input(0);
59 ShapeHandle partitions_shape = c->input(1);
60
61 if (!c->RankKnown(partitions_shape)) {
62 return shape_inference::UnknownShape(c);
63 }
64
65 const int64 rank = c->Rank(partitions_shape);
66
67 // data shape must start with partitions_shape
68 ShapeHandle unused;
69 TF_RETURN_IF_ERROR(
70 c->MergePrefix(data_shape, partitions_shape, &unused, &unused));
71
72 // The partition shape is dynamic in the 0th dimension, and matches
73 // data_shape in the remaining dimensions.
74 ShapeHandle unknown_dim0 = c->MakeShape({c->UnknownDim()});
75
76 ShapeHandle data_suffix_shape;
77 TF_RETURN_IF_ERROR(c->Subshape(data_shape, rank, &data_suffix_shape));
78 ShapeHandle result_shape;
79 TF_RETURN_IF_ERROR(
80 c->Concatenate(unknown_dim0, data_suffix_shape, &result_shape));
81
82 for (int i = 0; i < c->num_outputs(); ++i) {
83 c->set_output(i, result_shape);
84 }
85
86 return Status::OK();
87 });
88
89 namespace {
90
DynamicStitchShapeFunction(InferenceContext * c)91 Status DynamicStitchShapeFunction(InferenceContext* c) {
92 int32 num_partitions;
93 TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
94
95 bool all_indices_constant = true;
96 int32 max_index = 0;
97 ShapeHandle extra_shape = c->UnknownShape();
98 for (int i = 0; i < num_partitions; ++i) {
99 const Tensor* indices_t = c->input_tensor(i);
100 if (indices_t == nullptr) {
101 all_indices_constant = false;
102 }
103
104 ShapeHandle indices_shape = c->input(i);
105 ShapeHandle data_shape = c->input(i + num_partitions);
106 if (!c->RankKnown(indices_shape)) {
107 continue;
108 }
109 const int64 indices_rank = c->Rank(indices_shape);
110
111 // Assert that data_shape starts with indices_shape.
112 ShapeHandle unused;
113 TF_RETURN_IF_ERROR(
114 c->MergePrefix(data_shape, indices_shape, &unused, &unused));
115
116 // The rest belongs to output.
117 ShapeHandle rest;
118 TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
119 TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
120
121 if (indices_t != nullptr) {
122 // The length is based on the highest index from flattened indices.
123 const int32* indices = indices_t->flat<int32>().data();
124 int64 count = indices_t->NumElements();
125 for (int64 i = 0; i < count; ++i) {
126 if (indices[i] > max_index) {
127 max_index = indices[i];
128 }
129 }
130 }
131 }
132
133 ShapeHandle output_shape = c->Vector(
134 all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim());
135 TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape));
136 c->set_output(0, output_shape);
137 return Status::OK();
138 }
139
140 } // namespace
141
// DynamicStitch and ParallelDynamicStitch share one signature: N int32
// `indices` inputs followed by N `data` inputs of type T, producing a single
// `merged` output. They also share DynamicStitchShapeFunction.
REGISTER_OP("DynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);

REGISTER_OP("ParallelDynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);
157
158 // --------------------------------------------------------------------------
159
160 namespace {
TwoElementVectorInputsAndScalarOutputs(InferenceContext * c)161 Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
162 ShapeHandle handle;
163 DimensionHandle unused_handle;
164 for (int i = 0; i < c->num_inputs(); ++i) {
165 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
166 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
167 }
168 for (int i = 0; i < c->num_outputs(); ++i) {
169 c->set_output(i, c->Scalar());
170 }
171 return Status::OK();
172 }
173
TwoElementOutput(InferenceContext * c)174 Status TwoElementOutput(InferenceContext* c) {
175 c->set_output(0, c->Vector(2));
176 return Status::OK();
177 }
178 } // namespace
179
// RandomShuffleQueue comes in two flavors, like the other queue constructors
// in this file:
// - The original op returns a Ref(string) handle, shaped by TwoElementOutput
//   as a two-element vector.
// - The V2 op returns a `resource` handle with scalar shape.
// Neither shape function reads the attrs.
REGISTER_OP("RandomShuffleQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("min_after_dequeue: int = 0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("RandomShuffleQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("min_after_dequeue: int = 0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
205
// FIFO and padding-FIFO queue constructors share one attr list.
// - The original ops return a Ref(string) handle shaped as a two-element
//   vector by TwoElementOutput.
// - The V2 ops return a scalar `resource` handle.
REGISTER_OP("FIFOQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("FIFOQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("PaddingFIFOQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("PaddingFIFOQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
245
// Unlike the other queue constructors, PriorityQueue gives `component_types`
// a default of [] (>= 0) while `shapes` has no default and must be supplied.
REGISTER_OP("PriorityQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("PriorityQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Takes a `resource` input and produces a Ref(string) handle, shaped as a
// two-element vector. NOTE(review): presumably wraps a resource queue for
// consumers that expect a Ref handle; confirm against the kernel.
REGISTER_OP("FakeQueue")
    .Input("resource: resource")
    .Output("handle: Ref(string)")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);
271
// Enqueue ops declare no outputs. Their shape function (UnknownShape)
// performs no validation of `handle` or `components`.
REGISTER_OP("QueueEnqueue")
    .Input("handle: Ref(string)")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("QueueEnqueueV2")
    .Input("handle: resource")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("QueueEnqueueMany")
    .Input("handle: Ref(string)")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("QueueEnqueueManyV2")
    .Input("handle: resource")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);
299
300 REGISTER_OP("QueueDequeue")
301 .Input("handle: Ref(string)")
302 .Output("components: component_types")
303 .Attr("component_types: list(type) >= 1")
304 .Attr("timeout_ms: int = -1")
305 .SetShapeFn(shape_inference::UnknownShape);
306
307 REGISTER_OP("QueueDequeueV2")
308 .Input("handle: resource")
309 .Output("components: component_types")
310 .Attr("component_types: list(type) >= 1")
311 .Attr("timeout_ms: int = -1")
__anona2f81a440502(InferenceContext* c) 312 .SetShapeFn([](InferenceContext* c) {
313 auto* t = c->input_handle_shapes_and_types(0);
314 if (t != nullptr && t->size() == c->num_outputs()) {
315 for (int i = 0; i < c->num_outputs(); ++i) {
316 c->set_output(i, (*t)[i].shape);
317 }
318 return Status::OK();
319 } else {
320 return shape_inference::UnknownShape(c);
321 }
322 });
323
324 REGISTER_OP("QueueDequeueMany")
325 .Input("handle: Ref(string)")
326 .Input("n: int32")
327 .Output("components: component_types")
328 .Attr("component_types: list(type) >= 1")
329 .Attr("timeout_ms: int = -1")
330 .SetShapeFn(shape_inference::UnknownShape);
331
332 REGISTER_OP("QueueDequeueManyV2")
333 .Input("handle: resource")
334 .Input("n: int32")
335 .Output("components: component_types")
336 .Attr("component_types: list(type) >= 1")
337 .Attr("timeout_ms: int = -1")
__anona2f81a440602(InferenceContext* c) 338 .SetShapeFn([](InferenceContext* c) {
339 ShapeHandle n_shape;
340 if (c->input_tensor(1) == nullptr) {
341 n_shape = c->Vector(InferenceContext::kUnknownDim);
342 } else {
343 const int32 n = c->input_tensor(1)->scalar<int32>()();
344 if (n < 0) {
345 return errors::InvalidArgument("Input 'n' must be >= 0, but is ", n);
346 }
347 n_shape = c->Vector(n);
348 }
349 return DequeueManyV2Shape(c, n_shape);
350 });
351
352 REGISTER_OP("QueueDequeueUpTo")
353 .Input("handle: Ref(string)")
354 .Input("n: int32")
355 .Output("components: component_types")
356 .Attr("component_types: list(type) >= 1")
357 .Attr("timeout_ms: int = -1")
358 .SetShapeFn(shape_inference::UnknownShape);
359
360 REGISTER_OP("QueueDequeueUpToV2")
361 .Input("handle: resource")
362 .Input("n: int32")
363 .Output("components: component_types")
364 .Attr("component_types: list(type) >= 1")
365 .Attr("timeout_ms: int = -1")
__anona2f81a440702(InferenceContext* c) 366 .SetShapeFn([](InferenceContext* c) {
367 return DequeueManyV2Shape(c, c->Vector(InferenceContext::kUnknownDim));
368 });
369
// QueueClose declares no outputs, so TwoElementVectorInputsAndScalarOutputs
// only checks that the Ref(string) handle is a two-element vector.
REGISTER_OP("QueueClose")
    .Input("handle: Ref(string)")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// The V2 variant performs no handle validation.
REGISTER_OP("QueueCloseV2")
    .Input("handle: resource")
    .SetShapeFn(shape_inference::NoOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// Both IsClosed variants produce a scalar bool. The shape function only sets
// that scalar output.
REGISTER_OP("QueueIsClosed")
    .Input("handle: Ref(string)")
    .Output("is_closed: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("QueueIsClosedV2")
    .Input("handle: resource")
    .Output("is_closed: bool")
    .SetShapeFn(shape_inference::ScalarShape);
389
390 REGISTER_OP("QueueSize")
391 .Input("handle: Ref(string)")
392 .Output("size: int32")
393 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
394
395 REGISTER_OP("QueueSizeV2")
396 .Input("handle: resource")
397 .Output("size: int32")
398 .SetShapeFn(shape_inference::UnchangedShape);
399
400 // --------------------------------------------------------------------------
401
402 REGISTER_OP("AccumulatorNumAccumulated")
403 .Input("handle: Ref(string)")
404 .Output("num_accumulated: int32")
405 .SetShapeFn(shape_inference::ScalarShape);
406
407 REGISTER_OP("AccumulatorSetGlobalStep")
408 .Input("handle: Ref(string)")
409 .Input("new_global_step: int64")
__anona2f81a440802(InferenceContext* c) 410 .SetShapeFn([](InferenceContext* c) {
411 ShapeHandle unused;
412 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
413 return Status::OK();
414 });
415
416 REGISTER_OP("ConditionalAccumulator")
417 .Output("handle: Ref(string)")
418 .Attr("dtype: numbertype")
419 .Attr("shape: shape")
420 .Attr("container: string = ''")
421 .Attr("shared_name: string = ''")
422 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
423 .SetIsStateful()
__anona2f81a440902(InferenceContext* c) 424 .SetShapeFn([](InferenceContext* c) {
425 c->set_output(0, c->Vector(2));
426 return Status::OK();
427 });
428
429 REGISTER_OP("AccumulatorApplyGradient")
430 .Input("handle: Ref(string)")
431 .Input("local_step: int64")
432 .Input("gradient: dtype")
433 .Attr("dtype: numbertype")
__anona2f81a440a02(InferenceContext* c) 434 .SetShapeFn([](InferenceContext* c) {
435 ShapeHandle unused;
436 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
437 return Status::OK();
438 });
439
440 REGISTER_OP("AccumulatorTakeGradient")
441 .Input("handle: Ref(string)")
442 .Input("num_required: int32")
443 .Output("average: dtype")
__anona2f81a440b02(InferenceContext* c) 444 .SetShapeFn([](InferenceContext* c) {
445 ShapeHandle unused;
446 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
447 // Shape of output is the shape of the accumulator referenced
448 // by 'handle', but which is not available here, so we lose
449 // shape information.
450 return shape_inference::UnknownShape(c);
451 })
452 .Attr("dtype: numbertype");
453
454 REGISTER_OP("SparseConditionalAccumulator")
455 .Output("handle: Ref(string)")
456 .Attr("dtype: numbertype")
457 .Attr("shape: shape")
458 .Attr("container: string = ''")
459 .Attr("shared_name: string = ''")
460 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
461 .SetIsStateful()
__anona2f81a440c02(InferenceContext* c) 462 .SetShapeFn([](InferenceContext* c) {
463 c->set_output(0, c->Vector(2));
464 return Status::OK();
465 });
466
467 REGISTER_OP("SparseAccumulatorApplyGradient")
468 .Input("handle: Ref(string)")
469 .Input("local_step: int64")
470 .Input("gradient_indices: int64")
471 .Input("gradient_values: dtype")
472 .Input("gradient_shape: int64")
473 .Attr("dtype: numbertype")
474 .Attr("has_known_shape: bool")
__anona2f81a440d02(InferenceContext* c) 475 .SetShapeFn([](InferenceContext* c) {
476 ShapeHandle unused;
477 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
478 return Status::OK();
479 });
480
481 REGISTER_OP("SparseAccumulatorTakeGradient")
482 .Input("handle: Ref(string)")
483 .Input("num_required: int32")
484 .Output("indices: int64")
485 .Output("values: dtype")
486 .Output("shape: int64")
487 .Attr("dtype: numbertype")
__anona2f81a440e02(InferenceContext* c) 488 .SetShapeFn([](InferenceContext* c) {
489 ShapeHandle unused;
490 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
491 // Shape of output is the shape of the accumulator referenced
492 // by 'handle', but which is not available here, so we lose
493 // shape information.
494 return shape_inference::UnknownShape(c);
495 });
496
497 // --------------------------------------------------------------------------
498
499 REGISTER_OP("StackV2")
500 .Input("max_size: int32")
501 .Output("handle: resource")
502 .Attr("elem_type: type")
503 .Attr("stack_name: string = ''")
504 .SetIsStateful()
505 .SetShapeFn(TwoElementOutput);
506
507 REGISTER_OP("StackPushV2")
508 .Input("handle: resource")
509 .Input("elem: T")
510 .Output("output: T")
511 .Attr("T: type")
512 .Attr("swap_memory: bool = false")
__anona2f81a440f02(shape_inference::InferenceContext* c) 513 .SetShapeFn([](shape_inference::InferenceContext* c) {
514 c->set_output(0, c->input(1));
515 return Status::OK();
516 });
517
518 REGISTER_OP("StackPopV2")
519 .Input("handle: resource")
520 .Output("elem: elem_type")
521 .Attr("elem_type: type")
522 .SetShapeFn(shape_inference::UnknownShape);
523
524 REGISTER_OP("StackCloseV2")
525 .Input("handle: resource")
526 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
527
528 // Deprecated ref-typed variants of stack.
529
530 REGISTER_OP("Stack")
531 .Output("handle: Ref(string)")
532 .Attr("elem_type: type")
533 .Attr("stack_name: string = ''")
534 .SetIsStateful()
535 .SetShapeFn(TwoElementOutput);
536
537 REGISTER_OP("StackPush")
538 .Input("handle: Ref(string)")
539 .Input("elem: T")
540 .Output("output: T")
541 .Attr("T: type")
542 .Attr("swap_memory: bool = false")
__anona2f81a441002(shape_inference::InferenceContext* c) 543 .SetShapeFn([](shape_inference::InferenceContext* c) {
544 c->set_output(0, c->input(1));
545 return Status::OK();
546 });
547
548 REGISTER_OP("StackPop")
549 .Input("handle: Ref(string)")
550 .Output("elem: elem_type")
551 .Attr("elem_type: type")
552 .SetShapeFn(shape_inference::UnknownShape);
553
554 REGISTER_OP("StackClose")
555 .Input("handle: Ref(string)")
556 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
557
558 // --------------------------------------------------------------------------
559
560 REGISTER_OP("TensorArrayV3")
561 .Input("size: int32")
562 .Attr("dtype: type")
563 .Attr("element_shape: shape = { unknown_rank: true }")
564 .Attr("dynamic_size: bool = false")
565 .Attr("clear_after_read: bool = true")
566 .Attr("identical_element_shapes: bool = false")
567 .Attr("tensor_array_name: string = ''")
568 .Output("handle: resource")
569 .Output("flow: float")
570 .SetIsStateful()
__anona2f81a441102(InferenceContext* c) 571 .SetShapeFn([](InferenceContext* c) {
572 ShapeHandle unused;
573 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
574 c->set_output(0, c->Vector(2));
575 c->set_output(1, c->Scalar());
576 bool identical_shapes;
577 TF_RETURN_IF_ERROR(
578 c->GetAttr("identical_element_shapes", &identical_shapes));
579 DataType t;
580 TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
581 PartialTensorShape p;
582 TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
583 ShapeHandle s;
584 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
585 if (c->FullyDefined(s) || identical_shapes) {
586 c->set_output_handle_shapes_and_types(
587 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
588 }
589 return Status::OK();
590 });
591
592 REGISTER_OP("TensorArrayGradV3")
593 .Input("handle: resource")
594 .Input("flow_in: float")
595 .Output("grad_handle: resource")
596 .Output("flow_out: float")
597 .Attr("source: string")
598 .SetIsStateful()
__anona2f81a441202(InferenceContext* c) 599 .SetShapeFn([](InferenceContext* c) {
600 ShapeHandle handle;
601 DimensionHandle unused_dim;
602 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
603 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
604 c->set_output(0, c->Vector(2));
605 c->set_output(1, c->Scalar());
606 if (c->input_handle_shapes_and_types(0)) {
607 c->set_output_handle_shapes_and_types(
608 0, *c->input_handle_shapes_and_types(0));
609 }
610 return Status::OK();
611 });
612
613 REGISTER_OP("TensorArrayGradWithShape")
614 .Input("handle: resource")
615 .Input("flow_in: float")
616 .Input("shape_to_prepend: int32")
617 .Output("grad_handle: resource")
618 .Output("flow_out: float")
619 .Attr("source: string")
620 .SetIsStateful()
__anona2f81a441302(InferenceContext* c) 621 .SetShapeFn([](InferenceContext* c) {
622 ShapeHandle handle;
623 DimensionHandle unused_dim;
624 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
625 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
626 c->set_output(0, c->Vector(2));
627 c->set_output(1, c->Scalar());
628 auto* shape_and_type = c->input_handle_shapes_and_types(0);
629 if (shape_and_type) {
630 auto input_shape = (*shape_and_type)[0].shape;
631 auto dtype = (*shape_and_type)[0].dtype;
632 // Note that shape_to_preped is a rank 1 Tensor representing a shape.
633 // The size of dimension 0 is the number of dimensions we need to add to
634 // output shape.
635 int64 prepend_rank = c->Value(c->Dim(c->input(2), 0));
636 if (c->RankKnown(input_shape) &&
637 prepend_rank != InferenceContext::kUnknownDim) {
638 int32 input_rank = c->Rank(input_shape);
639 std::vector<DimensionHandle> dims;
640 dims.reserve(prepend_rank + input_rank);
641 for (int i = 0; i < prepend_rank; ++i) {
642 dims.push_back(c->UnknownDim());
643 }
644 for (int i = 0; i < input_rank; ++i) {
645 dims.push_back(c->Dim(input_shape, i));
646 }
647 c->set_output_handle_shapes_and_types(0,
648 {{c->MakeShape(dims), dtype}});
649 } else {
650 c->set_output_handle_shapes_and_types(0,
651 {{c->UnknownShape(), dtype}});
652 }
653 }
654 return Status::OK();
655 });
656
657 REGISTER_OP("TensorArrayWriteV3")
658 .Input("handle: resource")
659 .Input("index: int32")
660 .Input("value: T")
661 .Input("flow_in: float")
662 .Output("flow_out: float")
663 .Attr("T: type")
__anona2f81a441402(InferenceContext* c) 664 .SetShapeFn([](InferenceContext* c) {
665 ShapeHandle handle;
666 DimensionHandle unused_dim;
667 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
668 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
669
670 ShapeHandle unused;
671 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
672 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
673
674 auto* handle_data = c->input_handle_shapes_and_types(0);
675 if (handle_data != nullptr && !handle_data->empty()) {
676 shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
677 ShapeHandle value_shape = c->input(2);
678 TF_RETURN_IF_ERROR(
679 c->Merge(shape_and_type.shape, value_shape, &unused));
680 }
681
682 return shape_inference::ScalarShape(c);
683 });
684
685 REGISTER_OP("TensorArrayReadV3")
686 .Input("handle: resource")
687 .Input("index: int32")
688 .Input("flow_in: float")
689 .Output("value: dtype")
690 .Attr("dtype: type")
__anona2f81a441502(InferenceContext* c) 691 .SetShapeFn([](InferenceContext* c) {
692 ShapeHandle handle;
693 DimensionHandle unused_dim;
694 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
695 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
696 ShapeHandle unused;
697 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
698 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
699 auto shapes = c->input_handle_shapes_and_types(0);
700 if (shapes != nullptr && !shapes->empty()) {
701 ShapeHandle tensor_shape = shapes->at(0).shape;
702 c->set_output(0, tensor_shape);
703 return Status::OK();
704 } else {
705 return shape_inference::UnknownShape(c);
706 }
707 });
708
709 REGISTER_OP("TensorArrayGatherV3")
710 .Input("handle: resource")
711 .Input("indices: int32")
712 .Input("flow_in: float")
713 .Output("value: dtype")
714 .Attr("dtype: type")
715 .Attr("element_shape: shape = { unknown_rank: true }")
__anona2f81a441602(InferenceContext* c) 716 .SetShapeFn([](InferenceContext* c) {
717 ShapeHandle indices;
718 ShapeHandle unused;
719 DimensionHandle unused_dim;
720 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
721 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
722 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
723 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
724 auto shapes = c->input_handle_shapes_and_types(0);
725 if (shapes != nullptr && !shapes->empty()) {
726 ShapeHandle tensor_shape = shapes->at(0).shape;
727 ShapeHandle output_shape;
728 TF_RETURN_IF_ERROR(
729 c->Concatenate(indices, tensor_shape, &output_shape));
730 c->set_output(0, output_shape);
731 return Status::OK();
732 } else {
733 PartialTensorShape p;
734 TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
735 ShapeHandle s;
736 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
737 ShapeHandle output_shape;
738 TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape));
739 c->set_output(0, output_shape);
740 return Status::OK();
741 }
742 });
743
744 REGISTER_OP("TensorArrayScatterV3")
745 .Input("handle: resource")
746 .Input("indices: int32")
747 .Input("value: T")
748 .Input("flow_in: float")
749 .Output("flow_out: float")
750 .Attr("T: type")
__anona2f81a441702(InferenceContext* c) 751 .SetShapeFn([](InferenceContext* c) {
752 ShapeHandle indices;
753 ShapeHandle unused;
754 DimensionHandle unused_dim;
755 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
756 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
757 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
758 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
759 ShapeHandle value_shape;
760 // Assert that the length of the indices tensor is equal to the first
761 // dimension of the value tensor.
762 TF_RETURN_IF_ERROR(
763 c->MergePrefix(c->input(2), indices, &value_shape, &indices));
764 auto shapes = c->input_handle_shapes_and_types(0);
765 if (shapes != nullptr && !shapes->empty()) {
766 ShapeHandle tensor_shape = shapes->at(0).shape;
767 ShapeHandle fed_shape;
768 TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape));
769 TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape));
770 }
771 return shape_inference::ScalarShape(c);
772 });
773
774 REGISTER_OP("TensorArrayConcatV3")
775 .Input("handle: resource")
776 .Input("flow_in: float")
777 .Output("value: dtype")
778 .Output("lengths: int64")
779 .Attr("dtype: type")
780 .Attr("element_shape_except0: shape = { unknown_rank: true }")
__anona2f81a441802(InferenceContext* c) 781 .SetShapeFn([](InferenceContext* c) {
782 ShapeHandle handle;
783 DimensionHandle unused_dim;
784 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
785 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
786 ShapeHandle unused;
787 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
788 c->set_output(0, c->UnknownShape());
789 c->set_output(1, c->Vector(c->UnknownDim()));
790 return Status::OK();
791 });
792
793 REGISTER_OP("TensorArraySplitV3")
794 .Input("handle: resource")
795 .Input("value: T")
796 .Input("lengths: int64")
797 .Input("flow_in: float")
798 .Output("flow_out: float")
799 .Attr("T: type")
__anona2f81a441902(InferenceContext* c) 800 .SetShapeFn([](InferenceContext* c) {
801 ShapeHandle handle;
802 DimensionHandle unused_dim;
803 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
804 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
805 ShapeHandle unused;
806 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
807 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
808 return shape_inference::ScalarShape(c);
809 });
810
811 REGISTER_OP("TensorArraySizeV3")
812 .Input("handle: resource")
813 .Input("flow_in: float")
814 .Output("size: int32")
__anona2f81a441a02(InferenceContext* c) 815 .SetShapeFn([](InferenceContext* c) {
816 ShapeHandle handle;
817 DimensionHandle unused_dim;
818 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
819 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
820 return shape_inference::ScalarShape(c);
821 });
822
823 REGISTER_OP("TensorArrayCloseV3")
824 .Input("handle: resource")
__anona2f81a441b02(InferenceContext* c) 825 .SetShapeFn([](InferenceContext* c) {
826 ShapeHandle handle;
827 DimensionHandle unused_dim;
828 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
829 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
830 return Status::OK();
831 });
832
833 // --------------------------------------------------------------------------
834
835 // Deprecated TensorArray methods
836
// Ref(string)-handle variant, deprecated at version 16. Its shape function
// performs no validation and leaves the output unknown.
REGISTER_OP("TensorArray")
    .Input("size: int32")
    .Attr("dtype: type")
    .Attr("dynamic_size: bool = false")
    .Attr("clear_after_read: bool = true")
    .Attr("tensor_array_name: string = ''")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .Output("handle: Ref(string)")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayV3");
// `string`-handle variant, deprecated at version 26. `size` must be a scalar;
// the handle is a two-element vector.
REGISTER_OP("TensorArrayV2")
    .Input("size: int32")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .Attr("dynamic_size: bool = false")
    .Attr("clear_after_read: bool = true")
    .Attr("tensor_array_name: string = ''")
    .Output("handle: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      c->set_output(0, c->Vector(2));
      return Status::OK();
    })
    .Deprecated(26, "Use TensorArrayV3");
// Takes a `string` handle but returns a Ref(string) gradient handle; no shape
// validation is done.
REGISTER_OP("TensorArrayGrad")
    .Input("handle: string")
    .Input("flow_in: float")
    .Output("grad_handle: Ref(string)")
    .Attr("source: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGradV3");
// Validates the source handle as a two-element vector and emits a handle of
// the same shape. Unlike TensorArrayGradV3, no element metadata is forwarded.
REGISTER_OP("TensorArrayGradV2")
    .Input("handle: string")
    .Input("flow_in: float")
    .Output("grad_handle: string")
    .Attr("source: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      c->set_output(0, c->Vector(2));
      return Status::OK();
    })
    .Deprecated(26, "Use TensorArrayGradV3");
// Ref(string)-handle write, deprecated at version 16; no shape validation.
REGISTER_OP("TensorArrayWrite")
    .Input("handle: Ref(string)")
    .Input("index: int32")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayWriteV3");
// `string`-handle write, deprecated at version 26. It checks that the handle
// is a two-element vector and that index/flow_in are scalars, but does not
// check `value` against an element shape (unlike TensorArrayWriteV3).
REGISTER_OP("TensorArrayWriteV2")
    .Input("handle: string")
    .Input("index: int32")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));

      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    })
    .Deprecated(26, "Use TensorArrayWriteV3");
// Ref(string)-handle read, deprecated at version 16; no shape validation.
REGISTER_OP("TensorArrayRead")
    .Input("handle: Ref(string)")
    .Input("index: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayReadV3");
// `string`-handle read, deprecated at version 26. It validates the handle and
// the scalar inputs; the value's shape is always reported as unknown.
REGISTER_OP("TensorArrayReadV2")
    .Input("handle: string")
    .Input("index: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle handle;
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::UnknownShape(c);
    })
    .Deprecated(26, "Use TensorArrayReadV3");
940 REGISTER_OP("TensorArrayPack")
941 .Input("handle: Ref(string)")
942 .Input("flow_in: float")
943 .Output("value: dtype")
944 .Attr("dtype: type")
945 .Attr("element_shape: shape = { unknown_rank: true }")
946 .SetShapeFn(shape_inference::UnknownShape)
947 .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp");
948 REGISTER_OP("TensorArrayUnpack")
949 .Input("handle: Ref(string)")
950 .Input("value: T")
951 .Input("flow_in: float")
952 .Output("flow_out: float")
953 .Attr("T: type")
954 .SetShapeFn(shape_inference::UnknownShape)
955 .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp");
956 REGISTER_OP("TensorArrayGather")
957 .Input("handle: Ref(string)")
958 .Input("indices: int32")
959 .Input("flow_in: float")
960 .Output("value: dtype")
961 .Attr("dtype: type")
962 .Attr("element_shape: shape = { unknown_rank: true }")
963 .SetShapeFn(shape_inference::UnknownShape)
964 .Deprecated(16, "Use TensorArrayGatherV3");
965 REGISTER_OP("TensorArrayGatherV2")
966 .Input("handle: string")
967 .Input("indices: int32")
968 .Input("flow_in: float")
969 .Output("value: dtype")
970 .Attr("dtype: type")
971 .Attr("element_shape: shape = { unknown_rank: true }")
__anona2f81a442002(InferenceContext* c) 972 .SetShapeFn([](InferenceContext* c) {
973 ShapeHandle unused;
974 DimensionHandle unused_dim;
975 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
976 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
977 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
978 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
979 return shape_inference::UnknownShape(c);
980 })
981 .Deprecated(26, "Use TensorArrayGatherV3");
982 REGISTER_OP("TensorArrayScatter")
983 .Input("handle: Ref(string)")
984 .Input("indices: int32")
985 .Input("value: T")
986 .Input("flow_in: float")
987 .Output("flow_out: float")
988 .Attr("T: type")
989 .SetShapeFn(shape_inference::UnknownShape)
990 .Deprecated(19, "Use TensorArrayGradV3");
991 REGISTER_OP("TensorArrayScatterV2")
992 .Input("handle: string")
993 .Input("indices: int32")
994 .Input("value: T")
995 .Input("flow_in: float")
996 .Output("flow_out: float")
997 .Attr("T: type")
__anona2f81a442102(InferenceContext* c) 998 .SetShapeFn([](InferenceContext* c) {
999 ShapeHandle unused;
1000 DimensionHandle unused_dim;
1001 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1002 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1003 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
1004 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1005 return shape_inference::ScalarShape(c);
1006 })
1007 .Deprecated(26, "Use TensorArrayScatterV3");
1008 REGISTER_OP("TensorArrayConcat")
1009 .Input("handle: Ref(string)")
1010 .Input("flow_in: float")
1011 .Output("value: dtype")
1012 .Output("lengths: int64")
1013 .Attr("dtype: type")
1014 .Attr("element_shape_except0: shape = { unknown_rank: true }")
1015 .SetShapeFn(shape_inference::UnknownShape)
1016 .Deprecated(16, "Use TensorArrayGradV3");
1017 REGISTER_OP("TensorArrayConcatV2")
1018 .Input("handle: string")
1019 .Input("flow_in: float")
1020 .Output("value: dtype")
1021 .Output("lengths: int64")
1022 .Attr("dtype: type")
1023 .Attr("element_shape_except0: shape = { unknown_rank: true }")
__anona2f81a442202(InferenceContext* c) 1024 .SetShapeFn([](InferenceContext* c) {
1025 ShapeHandle handle;
1026 DimensionHandle unused_dim;
1027 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1028 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1029 ShapeHandle unused;
1030 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1031 c->set_output(0, c->UnknownShape());
1032 c->set_output(1, c->Vector(c->UnknownDim()));
1033 return Status::OK();
1034 });
1035 REGISTER_OP("TensorArraySplit")
1036 .Input("handle: Ref(string)")
1037 .Input("value: T")
1038 .Input("lengths: int64")
1039 .Input("flow_in: float")
1040 .Output("flow_out: float")
1041 .Attr("T: type")
1042 .SetShapeFn(shape_inference::UnknownShape)
1043 .Deprecated(16, "Use TensorArraySplitV3");
1044 REGISTER_OP("TensorArraySplitV2")
1045 .Input("handle: string")
1046 .Input("value: T")
1047 .Input("lengths: int64")
1048 .Input("flow_in: float")
1049 .Output("flow_out: float")
1050 .Attr("T: type")
__anona2f81a442302(InferenceContext* c) 1051 .SetShapeFn([](InferenceContext* c) {
1052 ShapeHandle handle;
1053 DimensionHandle unused_dim;
1054 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1055 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1056 ShapeHandle unused;
1057 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
1058 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1059 return shape_inference::ScalarShape(c);
1060 })
1061 .Deprecated(26, "Use TensorArraySplitV3");
1062 REGISTER_OP("TensorArraySize")
1063 .Input("handle: Ref(string)")
1064 .Input("flow_in: float")
1065 .Output("size: int32")
1066 .SetShapeFn(shape_inference::UnknownShape)
1067 .Deprecated(16, "Use TensorArraySizeV3");
1068 REGISTER_OP("TensorArraySizeV2")
1069 .Input("handle: string")
1070 .Input("flow_in: float")
1071 .Output("size: int32")
__anona2f81a442402(InferenceContext* c) 1072 .SetShapeFn([](InferenceContext* c) {
1073 ShapeHandle handle;
1074 DimensionHandle unused_dim;
1075 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1076 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1077 return shape_inference::ScalarShape(c);
1078 })
1079 .Deprecated(26, "Use TensorArraySizeV3");
1080 REGISTER_OP("TensorArrayClose")
1081 .Input("handle: Ref(string)")
__anona2f81a442502(InferenceContext* c) 1082 .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
1083 .Deprecated(16, "Use TensorArrayCloseV3");
1084 REGISTER_OP("TensorArrayCloseV2")
1085 .Input("handle: string")
__anona2f81a442602(InferenceContext* c) 1086 .SetShapeFn([](InferenceContext* c) {
1087 ShapeHandle handle;
1088 DimensionHandle unused_dim;
1089 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1090 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1091 return Status::OK();
1092 })
1093 .Deprecated(26, "Use TensorArrayCloseV3");
1094
1095 // --------------------------------------------------------------------------
1096
1097 REGISTER_OP("Barrier")
1098 .SetIsStateful()
1099 .Output("handle: Ref(string)")
1100 .Attr("component_types: list(type) >= 1")
1101 .Attr("shapes: list(shape) >= 0 = []")
1102 .Attr("capacity: int = -1")
1103 .Attr("container: string = ''")
1104 .Attr("shared_name: string = ''")
1105 .SetShapeFn(TwoElementOutput);
1106
1107 REGISTER_OP("BarrierInsertMany")
1108 .Input("handle: Ref(string)")
1109 .Input("keys: string")
1110 .Input("values: T")
1111 .Attr("T: type")
1112 .Attr("component_index: int")
__anona2f81a442702(InferenceContext* c) 1113 .SetShapeFn([](InferenceContext* c) {
1114 ShapeHandle keys = c->input(1);
1115 ShapeHandle values = c->input(2);
1116 ShapeHandle handle;
1117 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1118 DimensionHandle unused_dim;
1119 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1120 TF_RETURN_IF_ERROR(c->WithRank(keys, 1, &keys));
1121 TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
1122 TF_RETURN_IF_ERROR(c->Merge(keys, c->Vector(c->Dim(values, 0)), &handle));
1123 return Status::OK();
1124 });
1125
1126 REGISTER_OP("BarrierTakeMany")
1127 .Input("handle: Ref(string)")
1128 .Input("num_elements: int32")
1129 .Output("indices: int64")
1130 .Output("keys: string")
1131 .Output("values: component_types")
1132 .Attr("component_types: list(type) >= 1")
1133 .Attr("allow_small_batch: bool = false")
1134 .Attr("wait_for_incomplete: bool = false")
1135 .Attr("timeout_ms: int = -1")
1136 .SetShapeFn(shape_inference::UnknownShape);
1137
1138 REGISTER_OP("BarrierClose")
1139 .Input("handle: Ref(string)")
1140 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
1141 .Attr("cancel_pending_enqueues: bool = false");
1142
1143 REGISTER_OP("BarrierReadySize")
1144 .Input("handle: Ref(string)")
1145 .Output("size: int32")
1146 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
1147
1148 REGISTER_OP("BarrierIncompleteSize")
1149 .Input("handle: Ref(string)")
1150 .Output("size: int32")
1151 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
1152
1153 // --------------------------------------------------------------------------
1154
1155 REGISTER_OP("GetSessionHandle")
1156 .Input("value: T")
1157 .Output("handle: string")
1158 .Attr("T: type")
1159 .SetIsStateful()
1160 .SetShapeFn(shape_inference::ScalarShape);
1161
1162 REGISTER_OP("GetSessionHandleV2")
1163 .Input("value: T")
1164 .Output("handle: resource")
1165 .Attr("T: type")
1166 .SetIsStateful()
1167 .SetShapeFn(shape_inference::ScalarShape);
1168
1169 REGISTER_OP("GetSessionTensor")
1170 .Input("handle: string")
1171 .Output("value: dtype")
1172 .Attr("dtype: type")
1173 .SetIsStateful()
__anona2f81a442802(InferenceContext* c) 1174 .SetShapeFn([](InferenceContext* c) {
1175 ShapeHandle unused;
1176 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1177 return shape_inference::UnknownShape(c);
1178 });
1179
1180 REGISTER_OP("DeleteSessionTensor")
1181 .Input("handle: string")
1182 .SetIsStateful()
__anona2f81a442902(InferenceContext* c) 1183 .SetShapeFn([](InferenceContext* c) {
1184 ShapeHandle unused;
1185 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1186 return Status::OK();
1187 });
1188
1189 REGISTER_OP("Stage")
1190 .Input("values: dtypes")
1191 .Attr("capacity: int >= 0 = 0")
1192 .Attr("memory_limit: int >= 0 = 0")
1193 .Attr("dtypes: list(type)")
1194 .Attr("container: string = ''")
1195 .Attr("shared_name: string = ''")
1196 .SetShapeFn(shape_inference::UnknownShape)
1197 .SetIsStateful();
1198
1199 REGISTER_OP("Unstage")
1200 .Output("values: dtypes")
1201 .Attr("capacity: int >= 0 = 0")
1202 .Attr("memory_limit: int >= 0 = 0")
1203 .Attr("dtypes: list(type)")
1204 .Attr("container: string = ''")
1205 .Attr("shared_name: string = ''")
1206 .SetShapeFn(shape_inference::UnknownShape)
1207 .SetIsStateful();
1208
1209 REGISTER_OP("StagePeek")
1210 .Input("index: int32")
1211 .Output("values: dtypes")
1212 .Attr("capacity: int >= 0 = 0")
1213 .Attr("memory_limit: int >= 0 = 0")
1214 .Attr("dtypes: list(type)")
1215 .Attr("container: string = ''")
1216 .Attr("shared_name: string = ''")
1217 .SetShapeFn(shape_inference::UnknownShape)
1218 .SetIsStateful();
1219
1220 REGISTER_OP("StageSize")
1221 .Output("size: int32")
1222 .Attr("capacity: int >= 0 = 0")
1223 .Attr("memory_limit: int >= 0 = 0")
1224 .Attr("dtypes: list(type)")
1225 .Attr("container: string = ''")
1226 .Attr("shared_name: string = ''")
1227 .SetShapeFn(shape_inference::ScalarShape)
1228 .SetIsStateful();
1229
1230 REGISTER_OP("StageClear")
1231 .Attr("capacity: int >= 0 = 0")
1232 .Attr("memory_limit: int >= 0 = 0")
1233 .Attr("dtypes: list(type)")
1234 .Attr("container: string = ''")
1235 .Attr("shared_name: string = ''")
1236 .SetShapeFn(shape_inference::UnknownShape)
1237 .SetIsStateful();
1238
1239 // UnorderedMap
1240 REGISTER_OP("MapStage")
1241 .Input("key: int64")
1242 .Input("indices: int32")
1243 .Input("values: fake_dtypes")
1244 .Attr("capacity: int >= 0 = 0")
1245 .Attr("memory_limit: int >= 0 = 0")
1246 .Attr("dtypes: list(type)")
1247 .Attr("fake_dtypes: list(type)")
1248 .Attr("container: string = ''")
1249 .Attr("shared_name: string = ''")
1250 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1251 .SetIsStateful();
1252
1253 REGISTER_OP("MapPeek")
1254 .Input("key: int64")
1255 .Input("indices: int32")
1256 .Output("values: dtypes")
1257 .Attr("capacity: int >= 0 = 0")
1258 .Attr("memory_limit: int >= 0 = 0")
1259 .Attr("dtypes: list(type)")
1260 .Attr("container: string = ''")
1261 .Attr("shared_name: string = ''")
1262 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1263 .SetIsStateful();
1264
1265 REGISTER_OP("MapUnstage")
1266 .Input("key: int64")
1267 .Input("indices: int32")
1268 .Output("values: dtypes")
1269 .Attr("capacity: int >= 0 = 0")
1270 .Attr("memory_limit: int >= 0 = 0")
1271 .Attr("dtypes: list(type)")
1272 .Attr("container: string = ''")
1273 .Attr("shared_name: string = ''")
1274 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1275 .SetIsStateful();
1276
1277 REGISTER_OP("MapUnstageNoKey")
1278 .Input("indices: int32")
1279 .Output("key: int64")
1280 .Output("values: dtypes")
1281 .Attr("capacity: int >= 0 = 0")
1282 .Attr("memory_limit: int >= 0 = 0")
1283 .Attr("dtypes: list(type)")
1284 .Attr("container: string = ''")
1285 .Attr("shared_name: string = ''")
1286 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1287 .SetIsStateful();
1288
1289 REGISTER_OP("MapSize")
1290 .Output("size: int32")
1291 .Attr("capacity: int >= 0 = 0")
1292 .Attr("memory_limit: int >= 0 = 0")
1293 .Attr("dtypes: list(type)")
1294 .Attr("container: string = ''")
1295 .Attr("shared_name: string = ''")
1296 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1297 .SetIsStateful();
1298
1299 REGISTER_OP("MapIncompleteSize")
1300 .Output("size: int32")
1301 .Attr("capacity: int >= 0 = 0")
1302 .Attr("memory_limit: int >= 0 = 0")
1303 .Attr("dtypes: list(type)")
1304 .Attr("container: string = ''")
1305 .Attr("shared_name: string = ''")
1306 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1307 .SetIsStateful();
1308
1309 REGISTER_OP("MapClear")
1310 .Attr("capacity: int >= 0 = 0")
1311 .Attr("memory_limit: int >= 0 = 0")
1312 .Attr("dtypes: list(type)")
1313 .Attr("container: string = ''")
1314 .Attr("shared_name: string = ''")
1315 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1316 .SetIsStateful();
1317
1318 // OrderedMap
1319 REGISTER_OP("OrderedMapStage")
1320 .Input("key: int64")
1321 .Input("indices: int32")
1322 .Input("values: fake_dtypes")
1323 .Attr("capacity: int >= 0 = 0")
1324 .Attr("memory_limit: int >= 0 = 0")
1325 .Attr("dtypes: list(type)")
1326 .Attr("fake_dtypes: list(type)")
1327 .Attr("container: string = ''")
1328 .Attr("shared_name: string = ''")
1329 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1330 .SetIsStateful();
1331
1332 REGISTER_OP("OrderedMapPeek")
1333 .Input("key: int64")
1334 .Input("indices: int32")
1335 .Output("values: dtypes")
1336 .Attr("capacity: int >= 0 = 0")
1337 .Attr("memory_limit: int >= 0 = 0")
1338 .Attr("dtypes: list(type)")
1339 .Attr("container: string = ''")
1340 .Attr("shared_name: string = ''")
1341 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1342 .SetIsStateful();
1343
1344 REGISTER_OP("OrderedMapUnstage")
1345 .Input("key: int64")
1346 .Input("indices: int32")
1347 .Output("values: dtypes")
1348 .Attr("capacity: int >= 0 = 0")
1349 .Attr("memory_limit: int >= 0 = 0")
1350 .Attr("dtypes: list(type)")
1351 .Attr("container: string = ''")
1352 .Attr("shared_name: string = ''")
1353 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1354 .SetIsStateful();
1355
1356 REGISTER_OP("OrderedMapUnstageNoKey")
1357 .Input("indices: int32")
1358 .Output("key: int64")
1359 .Output("values: dtypes")
1360 .Attr("capacity: int >= 0 = 0")
1361 .Attr("memory_limit: int >= 0 = 0")
1362 .Attr("dtypes: list(type)")
1363 .Attr("container: string = ''")
1364 .Attr("shared_name: string = ''")
1365 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1366 .SetIsStateful();
1367
1368 REGISTER_OP("OrderedMapSize")
1369 .Output("size: int32")
1370 .Attr("capacity: int >= 0 = 0")
1371 .Attr("memory_limit: int >= 0 = 0")
1372 .Attr("dtypes: list(type)")
1373 .Attr("container: string = ''")
1374 .Attr("shared_name: string = ''")
1375 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1376 .SetIsStateful();
1377
1378 REGISTER_OP("OrderedMapIncompleteSize")
1379 .Output("size: int32")
1380 .Attr("capacity: int >= 0 = 0")
1381 .Attr("memory_limit: int >= 0 = 0")
1382 .Attr("dtypes: list(type)")
1383 .Attr("container: string = ''")
1384 .Attr("shared_name: string = ''")
1385 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1386 .SetIsStateful();
1387
1388 REGISTER_OP("OrderedMapClear")
1389 .Attr("capacity: int >= 0 = 0")
1390 .Attr("memory_limit: int >= 0 = 0")
1391 .Attr("dtypes: list(type)")
1392 .Attr("container: string = ''")
1393 .Attr("shared_name: string = ''")
1394 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1395 .SetIsStateful();
1396
1397 REGISTER_OP("RecordInput")
1398 .Output("records: string")
1399 .Attr("file_pattern: string")
1400 .Attr("file_random_seed: int = 301")
1401 .Attr("file_shuffle_shift_ratio: float = 0")
1402 .Attr("file_buffer_size: int = 10000")
1403 .Attr("file_parallelism: int = 16")
1404 .Attr("batch_size: int = 32")
1405 .Attr("compression_type: string = ''")
1406 .SetIsStateful()
1407 .SetShapeFn(shape_inference::UnknownShape);
1408
1409 } // namespace tensorflow
1410