1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/util/saved_tensor_slice_util.h"
20
21 namespace tensorflow {
22
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 namespace {
28
ScalarInputsAndOutputs(InferenceContext * c)29 Status ScalarInputsAndOutputs(InferenceContext* c) {
30 ShapeHandle unused;
31 for (int i = 0; i < c->num_inputs(); ++i) {
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
33 }
34 for (int i = 0; i < c->num_outputs(); ++i) {
35 c->set_output(i, c->Scalar());
36 }
37 return Status::OK();
38 }
39
TwoElementVectorAndScalarOutputs(InferenceContext * c)40 Status TwoElementVectorAndScalarOutputs(InferenceContext* c) {
41 ShapeHandle handle;
42 DimensionHandle unused_handle;
43 for (int i = 0; i < c->num_inputs(); ++i) {
44 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
45 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
46 }
47 for (int i = 0; i < c->num_outputs(); ++i) {
48 c->set_output(i, c->Scalar());
49 }
50 return Status::OK();
51 }
52
TwoElementOutput(InferenceContext * c)53 Status TwoElementOutput(InferenceContext* c) {
54 c->set_output(0, c->Vector(2));
55 return Status::OK();
56 }
57
58 } // namespace
59
60 REGISTER_OP("SaveV2")
61 .Input("prefix: string")
62 .Input("tensor_names: string")
63 .Input("shape_and_slices: string")
64 .Input("tensors: dtypes")
65 .Attr("dtypes: list(type)")
66 .SetIsStateful()
__anonec23b74b0202(InferenceContext* c) 67 .SetShapeFn([](InferenceContext* c) {
68 ShapeHandle unused;
69 ShapeHandle s;
70 DimensionHandle unused_dim;
71
72 // Validate prefix.
73 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
74
75 // Validate tensor_names and shapes_and_slices.
76 for (int i = 1; i <= 2; ++i) {
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
78 TF_RETURN_IF_ERROR(
79 c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
80 }
81 // TODO(mrry): Attempt to parse the shapes_and_slices values and use
82 // them to constrain the shape of the remaining inputs.
83 return Status::OK();
84 });
85
86 REGISTER_OP("RestoreV2")
87 .Input("prefix: string")
88 .Input("tensor_names: string")
89 .Input("shape_and_slices: string")
90 .Output("tensors: dtypes")
91 .Attr("dtypes: list(type)")
92 .SetIsStateful()
__anonec23b74b0302(InferenceContext* c) 93 .SetShapeFn([](InferenceContext* c) {
94 ShapeHandle shape0, shape1, shape2;
95 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
96 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
97 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
98 TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));
99
100 // Attempt to infer output shapes from its shape_and_slice input.
101 const Tensor* shape_and_slices_tensor = c->input_tensor(2);
102 if (shape_and_slices_tensor) {
103 const auto& shape_and_slices_flat =
104 shape_and_slices_tensor->flat<string>();
105 if (shape_and_slices_flat.size() != c->num_outputs()) {
106 return errors::InvalidArgument(
107 "The number of shape_and_slice doesn't match tensor outputs.");
108 }
109 for (int i = 0; i < shape_and_slices_flat.size(); ++i) {
110 const string& shape_and_slice = shape_and_slices_flat(i);
111 if (shape_and_slice.empty()) {
112 c->set_output(i, c->UnknownShape());
113 continue;
114 }
115 TensorShape parsed_full_shape;
116 TensorSlice parsed_slice;
117 TensorShape parsed_slice_shape;
118 TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
119 shape_and_slice, &parsed_full_shape, &parsed_slice,
120 &parsed_slice_shape));
121 ShapeHandle shape_handle;
122 TF_RETURN_IF_ERROR(
123 c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
124 c->set_output(i, shape_handle);
125 }
126 return Status::OK();
127 } else {
128 return UnknownShape(c);
129 }
130 });
131
132 REGISTER_OP("MergeV2Checkpoints")
133 .Input("checkpoint_prefixes: string")
134 .Input("destination_prefix: string")
135 .Attr("delete_old_dirs: bool = true")
136 .SetIsStateful()
__anonec23b74b0402(InferenceContext* c) 137 .SetShapeFn([](InferenceContext* c) {
138 ShapeHandle unused;
139 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
140 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
141 return Status::OK();
142 });
143
144 REGISTER_OP("Save")
145 .Input("filename: string")
146 .Input("tensor_names: string")
147 .Input("data: T")
148 .Attr("T: list(type)")
149 .SetIsStateful()
__anonec23b74b0502(InferenceContext* c) 150 .SetShapeFn([](InferenceContext* c) {
151 ShapeHandle unused;
152 ShapeHandle s;
153 DimensionHandle unused_dim;
154
155 // Validate filename.
156 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
157
158 // Validate tensor_names.
159 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &s));
160 TF_RETURN_IF_ERROR(
161 c->WithValue(c->Dim(s, 0), c->num_inputs() - 2, &unused_dim));
162
163 return Status::OK();
164 });
165
166 REGISTER_OP("SaveSlices")
167 .Input("filename: string")
168 .Input("tensor_names: string")
169 .Input("shapes_and_slices: string")
170 .Input("data: T")
171 .Attr("T: list(type)")
172 .SetIsStateful()
__anonec23b74b0602(InferenceContext* c) 173 .SetShapeFn([](InferenceContext* c) {
174 ShapeHandle unused;
175 ShapeHandle s;
176 DimensionHandle unused_dim;
177
178 // Validate filename.
179 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
180
181 // Validate tensor_names and unused_shapes_and_slices.
182 for (int i = 1; i <= 2; ++i) {
183 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
184 TF_RETURN_IF_ERROR(
185 c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
186 }
187 // TODO(mrry): Attempt to parse the shapes_and_slices values and use
188 // them to constrain the shape of the remaining inputs.
189 return Status::OK();
190 });
191
192 REGISTER_OP("Restore")
193 .Input("file_pattern: string")
194 .Input("tensor_name: string")
195 .Output("tensor: dt")
196 .Attr("dt: type")
197 .Attr("preferred_shard: int = -1")
198 .SetIsStateful()
__anonec23b74b0702(InferenceContext* c) 199 .SetShapeFn([](InferenceContext* c) {
200 ShapeHandle unused;
201 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
202 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
203 c->set_output(0, c->UnknownShape());
204 return Status::OK();
205 });
206
207 REGISTER_OP("RestoreSlice")
208 .Input("file_pattern: string")
209 .Input("tensor_name: string")
210 .Input("shape_and_slice: string")
211 .Output("tensor: dt")
212 .Attr("dt: type")
213 .Attr("preferred_shard: int = -1")
214 .SetIsStateful()
__anonec23b74b0802(InferenceContext* c) 215 .SetShapeFn([](InferenceContext* c) {
216 ShapeHandle unused;
217 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
218 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
219 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
220
221 // Attempt to infer output shapes from its shape_and_slice input.
222 const Tensor* shape_and_slices_tensor = c->input_tensor(2);
223 if (shape_and_slices_tensor) {
224 const auto& shape_and_slice =
225 shape_and_slices_tensor->flat<string>()(0);
226 if (shape_and_slice.empty()) {
227 c->set_output(0, c->UnknownShape());
228 } else {
229 TensorShape parsed_full_shape;
230 TensorSlice parsed_slice;
231 TensorShape parsed_slice_shape;
232 TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
233 shape_and_slice, &parsed_full_shape, &parsed_slice,
234 &parsed_slice_shape));
235 ShapeHandle shape_handle;
236 TF_RETURN_IF_ERROR(
237 c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
238 c->set_output(0, shape_handle);
239 }
240 } else {
241 c->set_output(0, c->UnknownShape());
242 }
243 return Status::OK();
244 });
245
// Filename helpers: every input must be a scalar and the single `filename`
// output is a scalar (shape function ScalarInputsAndOutputs).
REGISTER_OP("ShardedFilename")
    .Input("basename: string")
    .Input("shard: int32")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ShardedFilespec")
    .Input("basename: string")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);
258
// Reader source ops ----------------------------------------------------------
//
// Each reader comes in two registrations:
//  * the original op, whose `reader_handle` output is Ref(string) and whose
//    shape function (TwoElementOutput) makes it a two-element vector;
//  * a "V2" op, whose `reader_handle` output is a `resource` scalar
//    (shape_inference::ScalarShape).
// Most originals are marked `.Deprecated(26, ...)` pointing at their V2
// replacement; LMDBReader has no V2 variant here and is not deprecated.

REGISTER_OP("WholeFileReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("WholeFileReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TextLineReaderV2");

REGISTER_OP("TextLineReaderV2")
    .Output("reader_handle: resource")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// `record_bytes` has no default and so must be supplied by the caller.
REGISTER_OP("FixedLengthRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use FixedLengthRecordReaderV2");

// Unlike the deprecated variant, the V2 op also accepts an `encoding` attr.
REGISTER_OP("FixedLengthRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("encoding: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("TFRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TFRecordReaderV2");

REGISTER_OP("TFRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Only a Ref(string) variant is registered in this file.
REGISTER_OP("LMDBReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("IdentityReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use IdentityReaderV2");

REGISTER_OP("IdentityReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
354
// Ops that operate on Readers ------------------------------------------------
//
// Ref(string)-handle variants validate their handles as two-element vectors
// (TwoElementVectorAndScalarOutputs); resource-handle "V2" variants validate
// them as scalars (ScalarInputsAndOutputs). Both produce scalar outputs.

REGISTER_OP("ReaderRead")
    .Input("reader_handle: Ref(string)")
    .Input("queue_handle: Ref(string)")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderReadV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(ScalarInputsAndOutputs);
370
371 REGISTER_OP("ReaderReadUpTo")
372 .Input("reader_handle: Ref(string)")
373 .Input("queue_handle: Ref(string)")
374 .Input("num_records: int64")
375 .Output("keys: string")
376 .Output("values: string")
__anonec23b74b0902(InferenceContext* c) 377 .SetShapeFn([](InferenceContext* c) {
378 ShapeHandle unused;
379 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
380 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
381 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
382 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
383 c->set_output(0, out);
384 c->set_output(1, out);
385 return Status::OK();
386 });
387
388 REGISTER_OP("ReaderReadUpToV2")
389 .Input("reader_handle: resource")
390 .Input("queue_handle: resource")
391 .Input("num_records: int64")
392 .Output("keys: string")
393 .Output("values: string")
__anonec23b74b0a02(InferenceContext* c) 394 .SetShapeFn([](InferenceContext* c) {
395 ShapeHandle unused;
396 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
397 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
398 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
399 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
400 c->set_output(0, out);
401 c->set_output(1, out);
402 return Status::OK();
403 });
404
// Reader query ops: each takes only a reader handle and returns one scalar.
// Ref(string) variants require a two-element-vector handle; V2 variants
// require a scalar resource handle.
REGISTER_OP("ReaderNumRecordsProduced")
    .Input("reader_handle: Ref(string)")
    .Output("records_produced: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderNumRecordsProducedV2")
    .Input("reader_handle: resource")
    .Output("records_produced: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderNumWorkUnitsCompleted")
    .Input("reader_handle: Ref(string)")
    .Output("units_completed: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
    .Input("reader_handle: resource")
    .Output("units_completed: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderSerializeState")
    .Input("reader_handle: Ref(string)")
    .Output("state: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderSerializeStateV2")
    .Input("reader_handle: resource")
    .Output("state: string")
    .SetShapeFn(ScalarInputsAndOutputs);
434
435 REGISTER_OP("ReaderRestoreState")
436 .Input("reader_handle: Ref(string)")
437 .Input("state: string")
__anonec23b74b0b02(InferenceContext* c) 438 .SetShapeFn([](InferenceContext* c) {
439 ShapeHandle unused;
440 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
441 DimensionHandle unused_handle;
442 TF_RETURN_IF_ERROR(
443 c->WithValue(c->Dim(c->input(0), 0), 2, &unused_handle));
444
445 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
446 return Status::OK();
447 });
448
449 REGISTER_OP("ReaderRestoreStateV2")
450 .Input("reader_handle: resource")
451 .Input("state: string")
__anonec23b74b0c02(InferenceContext* c) 452 .SetShapeFn([](InferenceContext* c) {
453 ShapeHandle unused;
454 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
455 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
456 return Status::OK();
457 });
458
// ReaderReset has no outputs; its shape function only validates the handle
// (two-element vector for Ref(string), scalar for the resource V2 variant).
REGISTER_OP("ReaderReset")
    .Input("reader_handle: Ref(string)")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderResetV2")
    .Input("reader_handle: resource")
    .SetShapeFn(ScalarInputsAndOutputs);
466
// Other input Ops ----------------------------------------------------------

// `filename` must be a scalar and `contents` is a scalar string.
REGISTER_OP("ReadFile")
    .Input("filename: string")
    .Output("contents: string")
    .SetShapeFn(ScalarInputsAndOutputs);
473
474 REGISTER_OP("WriteFile")
475 .Input("filename: string")
476 .Input("contents: string")
__anonec23b74b0d02(InferenceContext* c) 477 .SetShapeFn([](InferenceContext* c) {
478 ShapeHandle unused;
479 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
480 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
481 return Status::OK();
482 });
483
484 REGISTER_OP("MatchingFiles")
485 .Input("pattern: string")
486 .Output("filenames: string")
__anonec23b74b0e02(InferenceContext* c) 487 .SetShapeFn([](InferenceContext* c) {
488 ShapeHandle unused;
489 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
490 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
491 return Status::OK();
492 });
493
494 } // namespace tensorflow
495