1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/util/saved_tensor_slice_util.h"
20 
21 namespace tensorflow {
22 
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26 
27 namespace {
28 
ScalarInputsAndOutputs(InferenceContext * c)29 Status ScalarInputsAndOutputs(InferenceContext* c) {
30   ShapeHandle unused;
31   for (int i = 0; i < c->num_inputs(); ++i) {
32     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
33   }
34   for (int i = 0; i < c->num_outputs(); ++i) {
35     c->set_output(i, c->Scalar());
36   }
37   return Status::OK();
38 }
39 
TwoElementVectorAndScalarOutputs(InferenceContext * c)40 Status TwoElementVectorAndScalarOutputs(InferenceContext* c) {
41   ShapeHandle handle;
42   DimensionHandle unused_handle;
43   for (int i = 0; i < c->num_inputs(); ++i) {
44     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
45     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
46   }
47   for (int i = 0; i < c->num_outputs(); ++i) {
48     c->set_output(i, c->Scalar());
49   }
50   return Status::OK();
51 }
52 
TwoElementOutput(InferenceContext * c)53 Status TwoElementOutput(InferenceContext* c) {
54   c->set_output(0, c->Vector(2));
55   return Status::OK();
56 }
57 
58 }  // namespace
59 
60 REGISTER_OP("SaveV2")
61     .Input("prefix: string")
62     .Input("tensor_names: string")
63     .Input("shape_and_slices: string")
64     .Input("tensors: dtypes")
65     .Attr("dtypes: list(type)")
66     .SetIsStateful()
__anonec23b74b0202(InferenceContext* c) 67     .SetShapeFn([](InferenceContext* c) {
68       ShapeHandle unused;
69       ShapeHandle s;
70       DimensionHandle unused_dim;
71 
72       // Validate prefix.
73       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
74 
75       // Validate tensor_names and shapes_and_slices.
76       for (int i = 1; i <= 2; ++i) {
77         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
78         TF_RETURN_IF_ERROR(
79             c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
80       }
81       // TODO(mrry): Attempt to parse the shapes_and_slices values and use
82       // them to constrain the shape of the remaining inputs.
83       return Status::OK();
84     });
85 
86 REGISTER_OP("RestoreV2")
87     .Input("prefix: string")
88     .Input("tensor_names: string")
89     .Input("shape_and_slices: string")
90     .Output("tensors: dtypes")
91     .Attr("dtypes: list(type)")
92     .SetIsStateful()
__anonec23b74b0302(InferenceContext* c) 93     .SetShapeFn([](InferenceContext* c) {
94       ShapeHandle shape0, shape1, shape2;
95       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
96       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
97       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
98       TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));
99 
100       // Attempt to infer output shapes from its shape_and_slice input.
101       const Tensor* shape_and_slices_tensor = c->input_tensor(2);
102       if (shape_and_slices_tensor) {
103         const auto& shape_and_slices_flat =
104             shape_and_slices_tensor->flat<string>();
105         if (shape_and_slices_flat.size() != c->num_outputs()) {
106           return errors::InvalidArgument(
107               "The number of shape_and_slice doesn't match tensor outputs.");
108         }
109         for (int i = 0; i < shape_and_slices_flat.size(); ++i) {
110           const string& shape_and_slice = shape_and_slices_flat(i);
111           if (shape_and_slice.empty()) {
112             c->set_output(i, c->UnknownShape());
113             continue;
114           }
115           TensorShape parsed_full_shape;
116           TensorSlice parsed_slice;
117           TensorShape parsed_slice_shape;
118           TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
119               shape_and_slice, &parsed_full_shape, &parsed_slice,
120               &parsed_slice_shape));
121           ShapeHandle shape_handle;
122           TF_RETURN_IF_ERROR(
123               c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
124           c->set_output(i, shape_handle);
125         }
126         return Status::OK();
127       } else {
128         return UnknownShape(c);
129       }
130     });
131 
132 REGISTER_OP("MergeV2Checkpoints")
133     .Input("checkpoint_prefixes: string")
134     .Input("destination_prefix: string")
135     .Attr("delete_old_dirs: bool = true")
136     .SetIsStateful()
__anonec23b74b0402(InferenceContext* c) 137     .SetShapeFn([](InferenceContext* c) {
138       ShapeHandle unused;
139       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
140       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
141       return Status::OK();
142     });
143 
144 REGISTER_OP("Save")
145     .Input("filename: string")
146     .Input("tensor_names: string")
147     .Input("data: T")
148     .Attr("T: list(type)")
149     .SetIsStateful()
__anonec23b74b0502(InferenceContext* c) 150     .SetShapeFn([](InferenceContext* c) {
151       ShapeHandle unused;
152       ShapeHandle s;
153       DimensionHandle unused_dim;
154 
155       // Validate filename.
156       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
157 
158       // Validate tensor_names.
159       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &s));
160       TF_RETURN_IF_ERROR(
161           c->WithValue(c->Dim(s, 0), c->num_inputs() - 2, &unused_dim));
162 
163       return Status::OK();
164     });
165 
166 REGISTER_OP("SaveSlices")
167     .Input("filename: string")
168     .Input("tensor_names: string")
169     .Input("shapes_and_slices: string")
170     .Input("data: T")
171     .Attr("T: list(type)")
172     .SetIsStateful()
__anonec23b74b0602(InferenceContext* c) 173     .SetShapeFn([](InferenceContext* c) {
174       ShapeHandle unused;
175       ShapeHandle s;
176       DimensionHandle unused_dim;
177 
178       // Validate filename.
179       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
180 
181       // Validate tensor_names and unused_shapes_and_slices.
182       for (int i = 1; i <= 2; ++i) {
183         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
184         TF_RETURN_IF_ERROR(
185             c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
186       }
187       // TODO(mrry): Attempt to parse the shapes_and_slices values and use
188       // them to constrain the shape of the remaining inputs.
189       return Status::OK();
190     });
191 
192 REGISTER_OP("Restore")
193     .Input("file_pattern: string")
194     .Input("tensor_name: string")
195     .Output("tensor: dt")
196     .Attr("dt: type")
197     .Attr("preferred_shard: int = -1")
198     .SetIsStateful()
__anonec23b74b0702(InferenceContext* c) 199     .SetShapeFn([](InferenceContext* c) {
200       ShapeHandle unused;
201       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
202       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
203       c->set_output(0, c->UnknownShape());
204       return Status::OK();
205     });
206 
207 REGISTER_OP("RestoreSlice")
208     .Input("file_pattern: string")
209     .Input("tensor_name: string")
210     .Input("shape_and_slice: string")
211     .Output("tensor: dt")
212     .Attr("dt: type")
213     .Attr("preferred_shard: int = -1")
214     .SetIsStateful()
__anonec23b74b0802(InferenceContext* c) 215     .SetShapeFn([](InferenceContext* c) {
216       ShapeHandle unused;
217       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
218       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
219       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
220 
221       // Attempt to infer output shapes from its shape_and_slice input.
222       const Tensor* shape_and_slices_tensor = c->input_tensor(2);
223       if (shape_and_slices_tensor) {
224         const auto& shape_and_slice =
225             shape_and_slices_tensor->flat<string>()(0);
226         if (shape_and_slice.empty()) {
227           c->set_output(0, c->UnknownShape());
228         } else {
229           TensorShape parsed_full_shape;
230           TensorSlice parsed_slice;
231           TensorShape parsed_slice_shape;
232           TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
233               shape_and_slice, &parsed_full_shape, &parsed_slice,
234               &parsed_slice_shape));
235           ShapeHandle shape_handle;
236           TF_RETURN_IF_ERROR(
237               c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
238           c->set_output(0, shape_handle);
239         }
240       } else {
241         c->set_output(0, c->UnknownShape());
242       }
243       return Status::OK();
244     });
245 
// ShardedFilename: all three inputs must be scalars and the `filename` output
// is a scalar (shape function: ScalarInputsAndOutputs).
246 REGISTER_OP("ShardedFilename")
247     .Input("basename: string")
248     .Input("shard: int32")
249     .Input("num_shards: int32")
250     .Output("filename: string")
251     .SetShapeFn(ScalarInputsAndOutputs);
252 
// ShardedFilespec: both inputs must be scalars and the `filename` output is a
// scalar (shape function: ScalarInputsAndOutputs).
253 REGISTER_OP("ShardedFilespec")
254     .Input("basename: string")
255     .Input("num_shards: int32")
256     .Output("filename: string")
257     .SetShapeFn(ScalarInputsAndOutputs);
258 
259 // Reader source ops ----------------------------------------------------------
260 
// Reader source ops come in two flavors:
//  - legacy ops output a Ref(string) `reader_handle` whose shape is a
//    length-2 vector (shape function: TwoElementOutput);
//  - the *V2 ops output a `resource` handle whose shape is a scalar
//    (shape function: shape_inference::ScalarShape).
// All of them are registered as stateful and take `container` and
// `shared_name` attrs, both defaulting to ''.
261 REGISTER_OP("WholeFileReader")
262     .Output("reader_handle: Ref(string)")
263     .Attr("container: string = ''")
264     .Attr("shared_name: string = ''")
265     .SetIsStateful()
266     .SetShapeFn(TwoElementOutput);
267 
268 REGISTER_OP("WholeFileReaderV2")
269     .Output("reader_handle: resource")
270     .Attr("container: string = ''")
271     .Attr("shared_name: string = ''")
272     .SetIsStateful()
273     .SetShapeFn(shape_inference::ScalarShape);
274 
// Marked deprecated (version 26) in favor of TextLineReaderV2.
275 REGISTER_OP("TextLineReader")
276     .Output("reader_handle: Ref(string)")
277     .Attr("skip_header_lines: int = 0")
278     .Attr("container: string = ''")
279     .Attr("shared_name: string = ''")
280     .SetIsStateful()
281     .SetShapeFn(TwoElementOutput)
282     .Deprecated(26, "Use TextLineReaderV2");
283 
284 REGISTER_OP("TextLineReaderV2")
285     .Output("reader_handle: resource")
286     .Attr("skip_header_lines: int = 0")
287     .Attr("container: string = ''")
288     .Attr("shared_name: string = ''")
289     .SetIsStateful()
290     .SetShapeFn(shape_inference::ScalarShape);
291 
// Marked deprecated (version 26) in favor of FixedLengthRecordReaderV2.
// `record_bytes` has no default and must always be supplied.
292 REGISTER_OP("FixedLengthRecordReader")
293     .Output("reader_handle: Ref(string)")
294     .Attr("header_bytes: int = 0")
295     .Attr("record_bytes: int")
296     .Attr("footer_bytes: int = 0")
297     .Attr("hop_bytes: int = 0")
298     .Attr("container: string = ''")
299     .Attr("shared_name: string = ''")
300     .SetIsStateful()
301     .SetShapeFn(TwoElementOutput)
302     .Deprecated(26, "Use FixedLengthRecordReaderV2");
303 
// Same attrs as FixedLengthRecordReader plus an `encoding` attr (default '').
304 REGISTER_OP("FixedLengthRecordReaderV2")
305     .Output("reader_handle: resource")
306     .Attr("header_bytes: int = 0")
307     .Attr("record_bytes: int")
308     .Attr("footer_bytes: int = 0")
309     .Attr("hop_bytes: int = 0")
310     .Attr("container: string = ''")
311     .Attr("shared_name: string = ''")
312     .Attr("encoding: string = ''")
313     .SetIsStateful()
314     .SetShapeFn(shape_inference::ScalarShape);
315 
// Marked deprecated (version 26) in favor of TFRecordReaderV2.
316 REGISTER_OP("TFRecordReader")
317     .Output("reader_handle: Ref(string)")
318     .Attr("container: string = ''")
319     .Attr("shared_name: string = ''")
320     .Attr("compression_type: string = ''")
321     .SetIsStateful()
322     .SetShapeFn(TwoElementOutput)
323     .Deprecated(26, "Use TFRecordReaderV2");
324 
325 REGISTER_OP("TFRecordReaderV2")
326     .Output("reader_handle: resource")
327     .Attr("container: string = ''")
328     .Attr("shared_name: string = ''")
329     .Attr("compression_type: string = ''")
330     .SetIsStateful()
331     .SetShapeFn(shape_inference::ScalarShape);
332 
// Only the legacy Ref(string) form of LMDBReader is registered in this
// section; no V2 variant appears here.
333 REGISTER_OP("LMDBReader")
334     .Output("reader_handle: Ref(string)")
335     .Attr("container: string = ''")
336     .Attr("shared_name: string = ''")
337     .SetIsStateful()
338     .SetShapeFn(TwoElementOutput);
339 
// Marked deprecated (version 26) in favor of IdentityReaderV2.
340 REGISTER_OP("IdentityReader")
341     .Output("reader_handle: Ref(string)")
342     .Attr("container: string = ''")
343     .Attr("shared_name: string = ''")
344     .SetIsStateful()
345     .SetShapeFn(TwoElementOutput)
346     .Deprecated(26, "Use IdentityReaderV2");
347 
348 REGISTER_OP("IdentityReaderV2")
349     .Output("reader_handle: resource")
350     .Attr("container: string = ''")
351     .Attr("shared_name: string = ''")
352     .SetIsStateful()
353     .SetShapeFn(shape_inference::ScalarShape);
354 
355 // Ops that operate on Readers ------------------------------------------------
356 
// ReaderRead: both Ref(string) handles must be length-2 vectors; `key` and
// `value` are scalars (shape function: TwoElementVectorAndScalarOutputs).
357 REGISTER_OP("ReaderRead")
358     .Input("reader_handle: Ref(string)")
359     .Input("queue_handle: Ref(string)")
360     .Output("key: string")
361     .Output("value: string")
362     .SetShapeFn(TwoElementVectorAndScalarOutputs);
363 
// ReaderReadV2: both resource handles must be scalars; `key` and `value` are
// scalars (shape function: ScalarInputsAndOutputs).
364 REGISTER_OP("ReaderReadV2")
365     .Input("reader_handle: resource")
366     .Input("queue_handle: resource")
367     .Output("key: string")
368     .Output("value: string")
369     .SetShapeFn(ScalarInputsAndOutputs);
370 
371 REGISTER_OP("ReaderReadUpTo")
372     .Input("reader_handle: Ref(string)")
373     .Input("queue_handle: Ref(string)")
374     .Input("num_records: int64")
375     .Output("keys: string")
376     .Output("values: string")
__anonec23b74b0902(InferenceContext* c) 377     .SetShapeFn([](InferenceContext* c) {
378       ShapeHandle unused;
379       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
380       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
381       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
382       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
383       c->set_output(0, out);
384       c->set_output(1, out);
385       return Status::OK();
386     });
387 
388 REGISTER_OP("ReaderReadUpToV2")
389     .Input("reader_handle: resource")
390     .Input("queue_handle: resource")
391     .Input("num_records: int64")
392     .Output("keys: string")
393     .Output("values: string")
__anonec23b74b0a02(InferenceContext* c) 394     .SetShapeFn([](InferenceContext* c) {
395       ShapeHandle unused;
396       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
397       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
398       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
399       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
400       c->set_output(0, out);
401       c->set_output(1, out);
402       return Status::OK();
403     });
404 
// Query/serialize ops on readers. Each legacy op takes a Ref(string) handle
// that must be a length-2 vector and produces a scalar
// (TwoElementVectorAndScalarOutputs); each V2 op takes a scalar resource
// handle and produces a scalar (ScalarInputsAndOutputs).
405 REGISTER_OP("ReaderNumRecordsProduced")
406     .Input("reader_handle: Ref(string)")
407     .Output("records_produced: int64")
408     .SetShapeFn(TwoElementVectorAndScalarOutputs);
409 
410 REGISTER_OP("ReaderNumRecordsProducedV2")
411     .Input("reader_handle: resource")
412     .Output("records_produced: int64")
413     .SetShapeFn(ScalarInputsAndOutputs);
414 
415 REGISTER_OP("ReaderNumWorkUnitsCompleted")
416     .Input("reader_handle: Ref(string)")
417     .Output("units_completed: int64")
418     .SetShapeFn(TwoElementVectorAndScalarOutputs);
419 
420 REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
421     .Input("reader_handle: resource")
422     .Output("units_completed: int64")
423     .SetShapeFn(ScalarInputsAndOutputs);
424 
425 REGISTER_OP("ReaderSerializeState")
426     .Input("reader_handle: Ref(string)")
427     .Output("state: string")
428     .SetShapeFn(TwoElementVectorAndScalarOutputs);
429 
430 REGISTER_OP("ReaderSerializeStateV2")
431     .Input("reader_handle: resource")
432     .Output("state: string")
433     .SetShapeFn(ScalarInputsAndOutputs);
434 
435 REGISTER_OP("ReaderRestoreState")
436     .Input("reader_handle: Ref(string)")
437     .Input("state: string")
__anonec23b74b0b02(InferenceContext* c) 438     .SetShapeFn([](InferenceContext* c) {
439       ShapeHandle unused;
440       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
441       DimensionHandle unused_handle;
442       TF_RETURN_IF_ERROR(
443           c->WithValue(c->Dim(c->input(0), 0), 2, &unused_handle));
444 
445       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
446       return Status::OK();
447     });
448 
449 REGISTER_OP("ReaderRestoreStateV2")
450     .Input("reader_handle: resource")
451     .Input("state: string")
__anonec23b74b0c02(InferenceContext* c) 452     .SetShapeFn([](InferenceContext* c) {
453       ShapeHandle unused;
454       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
455       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
456       return Status::OK();
457     });
458 
// ReaderReset: the Ref(string) handle must be a length-2 vector; no outputs.
459 REGISTER_OP("ReaderReset")
460     .Input("reader_handle: Ref(string)")
461     .SetShapeFn(TwoElementVectorAndScalarOutputs);
462 
// ReaderResetV2: the resource handle must be a scalar; no outputs.
463 REGISTER_OP("ReaderResetV2")
464     .Input("reader_handle: resource")
465     .SetShapeFn(ScalarInputsAndOutputs);
466 
467 // Other input Ops ----------------------------------------------------------
468 
// ReadFile: `filename` must be a scalar and `contents` is a scalar
// (shape function: ScalarInputsAndOutputs).
469 REGISTER_OP("ReadFile")
470     .Input("filename: string")
471     .Output("contents: string")
472     .SetShapeFn(ScalarInputsAndOutputs);
473 
474 REGISTER_OP("WriteFile")
475     .Input("filename: string")
476     .Input("contents: string")
__anonec23b74b0d02(InferenceContext* c) 477     .SetShapeFn([](InferenceContext* c) {
478       ShapeHandle unused;
479       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
480       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
481       return Status::OK();
482     });
483 
484 REGISTER_OP("MatchingFiles")
485     .Input("pattern: string")
486     .Output("filenames: string")
__anonec23b74b0e02(InferenceContext* c) 487     .SetShapeFn([](InferenceContext* c) {
488       ShapeHandle unused;
489       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
490       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
491       return Status::OK();
492     });
493 
494 }  // namespace tensorflow
495