1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "absl/memory/memory.h"
20 #include "pybind11/pybind11.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/lib/io/record_reader.h"
24 #include "tensorflow/core/lib/io/record_writer.h"
25 #include "tensorflow/core/lib/io/zlib_compression_options.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/file_system.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/python/lib/core/pybind11_absl.h"
30 #include "tensorflow/python/lib/core/pybind11_status.h"
31 
32 namespace {
33 
34 namespace py = ::pybind11;
35 
36 class PyRecordReader {
37  public:
38   // NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
39   // RecordReaderOptions, if this changes the API can be updated at that time.
New(const std::string & filename,const std::string & compression_type,PyRecordReader ** out)40   static tensorflow::Status New(const std::string& filename,
41                                 const std::string& compression_type,
42                                 PyRecordReader** out) {
43     std::unique_ptr<tensorflow::RandomAccessFile> file;
44     TF_RETURN_IF_ERROR(
45         tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
46     auto options =
47         tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions(
48             compression_type);
49     options.buffer_size = kReaderBufferSize;
50     auto reader =
51         absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
52     *out = new PyRecordReader(std::move(file), std::move(reader));
53     return tensorflow::Status::OK();
54   }
55 
56   PyRecordReader() = delete;
~PyRecordReader()57   ~PyRecordReader() { Close(); }
58 
ReadNextRecord(tensorflow::tstring * out)59   tensorflow::Status ReadNextRecord(tensorflow::tstring* out) {
60     if (IsClosed()) {
61       return tensorflow::errors::FailedPrecondition("Reader is closed.");
62     }
63 
64     return reader_->ReadRecord(&offset_, out);
65   }
66 
IsClosed() const67   bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
68 
Close()69   void Close() {
70     reader_ = nullptr;
71     file_ = nullptr;
72   }
73 
74  private:
75   static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
76 
PyRecordReader(std::unique_ptr<tensorflow::RandomAccessFile> file,std::unique_ptr<tensorflow::io::RecordReader> reader)77   PyRecordReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
78                  std::unique_ptr<tensorflow::io::RecordReader> reader)
79       : offset_(0), file_(std::move(file)), reader_(std::move(reader)) {}
80 
81   tensorflow::uint64 offset_;
82   std::unique_ptr<tensorflow::RandomAccessFile> file_;
83   std::unique_ptr<tensorflow::io::RecordReader> reader_;
84 
85   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
86 };
87 
88 class PyRecordRandomReader {
89  public:
New(const std::string & filename,PyRecordRandomReader ** out)90   static tensorflow::Status New(const std::string& filename,
91                                 PyRecordRandomReader** out) {
92     std::unique_ptr<tensorflow::RandomAccessFile> file;
93     TF_RETURN_IF_ERROR(
94         tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
95     auto options =
96         tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("");
97     options.buffer_size = kReaderBufferSize;
98     auto reader =
99         absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
100     *out = new PyRecordRandomReader(std::move(file), std::move(reader));
101     return tensorflow::Status::OK();
102   }
103 
104   PyRecordRandomReader() = delete;
~PyRecordRandomReader()105   ~PyRecordRandomReader() { Close(); }
106 
ReadRecord(tensorflow::uint64 * offset,tensorflow::tstring * out)107   tensorflow::Status ReadRecord(tensorflow::uint64* offset,
108                                 tensorflow::tstring* out) {
109     if (IsClosed()) {
110       return tensorflow::errors::FailedPrecondition(
111           "Random TFRecord Reader is closed.");
112     }
113     return reader_->ReadRecord(offset, out);
114   }
115 
IsClosed() const116   bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
117 
Close()118   void Close() {
119     reader_ = nullptr;
120     file_ = nullptr;
121   }
122 
123  private:
124   static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
125 
PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,std::unique_ptr<tensorflow::io::RecordReader> reader)126   PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
127                        std::unique_ptr<tensorflow::io::RecordReader> reader)
128       : file_(std::move(file)), reader_(std::move(reader)) {}
129 
130   std::unique_ptr<tensorflow::RandomAccessFile> file_;
131   std::unique_ptr<tensorflow::io::RecordReader> reader_;
132 
133   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordRandomReader);
134 };
135 
136 class PyRecordWriter {
137  public:
New(const std::string & filename,const tensorflow::io::RecordWriterOptions & options,PyRecordWriter ** out)138   static tensorflow::Status New(
139       const std::string& filename,
140       const tensorflow::io::RecordWriterOptions& options,
141       PyRecordWriter** out) {
142     std::unique_ptr<tensorflow::WritableFile> file;
143     TF_RETURN_IF_ERROR(
144         tensorflow::Env::Default()->NewWritableFile(filename, &file));
145     auto writer =
146         absl::make_unique<tensorflow::io::RecordWriter>(file.get(), options);
147     *out = new PyRecordWriter(std::move(file), std::move(writer));
148     return tensorflow::Status::OK();
149   }
150 
151   PyRecordWriter() = delete;
~PyRecordWriter()152   ~PyRecordWriter() { Close(); }
153 
WriteRecord(tensorflow::StringPiece record)154   tensorflow::Status WriteRecord(tensorflow::StringPiece record) {
155     if (IsClosed()) {
156       return tensorflow::errors::FailedPrecondition("Writer is closed.");
157     }
158     return writer_->WriteRecord(record);
159   }
160 
Flush()161   tensorflow::Status Flush() {
162     if (IsClosed()) {
163       return tensorflow::errors::FailedPrecondition("Writer is closed.");
164     }
165 
166     auto status = writer_->Flush();
167     if (status.ok()) {
168       // Per the RecordWriter contract, flushing the RecordWriter does not
169       // flush the underlying file.  Here we need to do both.
170       return file_->Flush();
171     }
172     return status;
173   }
174 
IsClosed() const175   bool IsClosed() const { return file_ == nullptr && writer_ == nullptr; }
176 
Close()177   tensorflow::Status Close() {
178     if (writer_ != nullptr) {
179       auto status = writer_->Close();
180       writer_ = nullptr;
181       if (!status.ok()) return status;
182     }
183     if (file_ != nullptr) {
184       auto status = file_->Close();
185       file_ = nullptr;
186       if (!status.ok()) return status;
187     }
188     return tensorflow::Status::OK();
189   }
190 
191  private:
PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,std::unique_ptr<tensorflow::io::RecordWriter> writer)192   PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,
193                  std::unique_ptr<tensorflow::io::RecordWriter> writer)
194       : file_(std::move(file)), writer_(std::move(writer)) {}
195 
196   std::unique_ptr<tensorflow::WritableFile> file_;
197   std::unique_ptr<tensorflow::io::RecordWriter> writer_;
198 
199   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
200 };
201 
PYBIND11_MODULE(_pywrap_record_io,m)202 PYBIND11_MODULE(_pywrap_record_io, m) {
203   py::class_<PyRecordReader>(m, "RecordIterator")
204       .def(py::init(
205           [](const std::string& filename, const std::string& compression_type) {
206             tensorflow::Status status;
207             PyRecordReader* self = nullptr;
208             {
209               py::gil_scoped_release release;
210               status = PyRecordReader::New(filename, compression_type, &self);
211             }
212             MaybeRaiseRegisteredFromStatus(status);
213             return self;
214           }))
215       .def("__iter__", [](const py::object& self) { return self; })
216       .def("__next__",
217            [](PyRecordReader* self) {
218              if (self->IsClosed()) {
219                throw py::stop_iteration();
220              }
221 
222              tensorflow::tstring record;
223              tensorflow::Status status;
224              {
225                py::gil_scoped_release release;
226                status = self->ReadNextRecord(&record);
227              }
228              if (tensorflow::errors::IsOutOfRange(status)) {
229                // Don't close because the file being read could be updated
230                // in-between
231                // __next__ calls.
232                throw py::stop_iteration();
233              }
234              MaybeRaiseRegisteredFromStatus(status);
235              return py::bytes(record);
236            })
237       .def("close", [](PyRecordReader* self) { self->Close(); });
238 
239   py::class_<PyRecordRandomReader>(m, "RandomRecordReader")
240       .def(py::init([](const std::string& filename) {
241         tensorflow::Status status;
242         PyRecordRandomReader* self = nullptr;
243         {
244           py::gil_scoped_release release;
245           status = PyRecordRandomReader::New(filename, &self);
246         }
247         MaybeRaiseRegisteredFromStatus(status);
248         return self;
249       }))
250       .def("read",
251            [](PyRecordRandomReader* self, tensorflow::uint64 offset) {
252              tensorflow::uint64 temp_offset = offset;
253              tensorflow::tstring record;
254              tensorflow::Status status;
255              {
256                py::gil_scoped_release release;
257                status = self->ReadRecord(&temp_offset, &record);
258              }
259              if (tensorflow::errors::IsOutOfRange(status)) {
260                throw py::index_error(tensorflow::strings::StrCat(
261                    "Out of range at reading offset ", offset));
262              }
263              MaybeRaiseRegisteredFromStatus(status);
264              return py::make_tuple(py::bytes(record), temp_offset);
265            })
266       .def("close", [](PyRecordRandomReader* self) { self->Close(); });
267 
268   using tensorflow::io::ZlibCompressionOptions;
269   py::class_<ZlibCompressionOptions>(m, "ZlibCompressionOptions")
270       .def_readwrite("flush_mode", &ZlibCompressionOptions::flush_mode)
271       .def_readwrite("input_buffer_size",
272                      &ZlibCompressionOptions::input_buffer_size)
273       .def_readwrite("output_buffer_size",
274                      &ZlibCompressionOptions::output_buffer_size)
275       .def_readwrite("window_bits", &ZlibCompressionOptions::window_bits)
276       .def_readwrite("compression_level",
277                      &ZlibCompressionOptions::compression_level)
278       .def_readwrite("compression_method",
279                      &ZlibCompressionOptions::compression_method)
280       .def_readwrite("mem_level", &ZlibCompressionOptions::mem_level)
281       .def_readwrite("compression_strategy",
282                      &ZlibCompressionOptions::compression_strategy);
283 
284   using tensorflow::io::RecordWriterOptions;
285   py::class_<RecordWriterOptions>(m, "RecordWriterOptions")
286       .def(py::init(&RecordWriterOptions::CreateRecordWriterOptions))
287       .def_readonly("compression_type", &RecordWriterOptions::compression_type)
288       .def_readonly("zlib_options", &RecordWriterOptions::zlib_options);
289 
290   using tensorflow::MaybeRaiseRegisteredFromStatus;
291 
292   py::class_<PyRecordWriter>(m, "RecordWriter")
293       .def(py::init(
294           [](const std::string& filename, const RecordWriterOptions& options) {
295             PyRecordWriter* self = nullptr;
296             tensorflow::Status status;
297             {
298               py::gil_scoped_release release;
299               status = PyRecordWriter::New(filename, options, &self);
300             }
301             MaybeRaiseRegisteredFromStatus(status);
302             return self;
303           }))
304       .def("__enter__", [](const py::object& self) { return self; })
305       .def("__exit__",
306            [](PyRecordWriter* self, py::args) {
307              MaybeRaiseRegisteredFromStatus(self->Close());
308            })
309       .def(
310           "write",
311           [](PyRecordWriter* self, tensorflow::StringPiece record) {
312             tensorflow::Status status;
313             {
314               py::gil_scoped_release release;
315               status = self->WriteRecord(record);
316             }
317             MaybeRaiseRegisteredFromStatus(status);
318           },
319           py::arg("record"))
320       .def("flush",
321            [](PyRecordWriter* self) {
322              MaybeRaiseRegisteredFromStatus(self->Flush());
323            })
324       .def("close", [](PyRecordWriter* self) {
325         MaybeRaiseRegisteredFromStatus(self->Close());
326       });
327 }
328 
329 }  // namespace
330