1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18
19 #include "absl/memory/memory.h"
20 #include "pybind11/pybind11.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/lib/io/record_reader.h"
24 #include "tensorflow/core/lib/io/record_writer.h"
25 #include "tensorflow/core/lib/io/zlib_compression_options.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/file_system.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/python/lib/core/pybind11_absl.h"
30 #include "tensorflow/python/lib/core/pybind11_status.h"
31
32 namespace {
33
34 namespace py = ::pybind11;
35
36 class PyRecordReader {
37 public:
38 // NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
39 // RecordReaderOptions, if this changes the API can be updated at that time.
New(const std::string & filename,const std::string & compression_type,PyRecordReader ** out)40 static tensorflow::Status New(const std::string& filename,
41 const std::string& compression_type,
42 PyRecordReader** out) {
43 std::unique_ptr<tensorflow::RandomAccessFile> file;
44 TF_RETURN_IF_ERROR(
45 tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
46 auto options =
47 tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions(
48 compression_type);
49 options.buffer_size = kReaderBufferSize;
50 auto reader =
51 absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
52 *out = new PyRecordReader(std::move(file), std::move(reader));
53 return tensorflow::Status::OK();
54 }
55
56 PyRecordReader() = delete;
~PyRecordReader()57 ~PyRecordReader() { Close(); }
58
ReadNextRecord(tensorflow::tstring * out)59 tensorflow::Status ReadNextRecord(tensorflow::tstring* out) {
60 if (IsClosed()) {
61 return tensorflow::errors::FailedPrecondition("Reader is closed.");
62 }
63
64 return reader_->ReadRecord(&offset_, out);
65 }
66
IsClosed() const67 bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
68
Close()69 void Close() {
70 reader_ = nullptr;
71 file_ = nullptr;
72 }
73
74 private:
75 static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
76
PyRecordReader(std::unique_ptr<tensorflow::RandomAccessFile> file,std::unique_ptr<tensorflow::io::RecordReader> reader)77 PyRecordReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
78 std::unique_ptr<tensorflow::io::RecordReader> reader)
79 : offset_(0), file_(std::move(file)), reader_(std::move(reader)) {}
80
81 tensorflow::uint64 offset_;
82 std::unique_ptr<tensorflow::RandomAccessFile> file_;
83 std::unique_ptr<tensorflow::io::RecordReader> reader_;
84
85 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
86 };
87
88 class PyRecordRandomReader {
89 public:
New(const std::string & filename,PyRecordRandomReader ** out)90 static tensorflow::Status New(const std::string& filename,
91 PyRecordRandomReader** out) {
92 std::unique_ptr<tensorflow::RandomAccessFile> file;
93 TF_RETURN_IF_ERROR(
94 tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
95 auto options =
96 tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("");
97 options.buffer_size = kReaderBufferSize;
98 auto reader =
99 absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
100 *out = new PyRecordRandomReader(std::move(file), std::move(reader));
101 return tensorflow::Status::OK();
102 }
103
104 PyRecordRandomReader() = delete;
~PyRecordRandomReader()105 ~PyRecordRandomReader() { Close(); }
106
ReadRecord(tensorflow::uint64 * offset,tensorflow::tstring * out)107 tensorflow::Status ReadRecord(tensorflow::uint64* offset,
108 tensorflow::tstring* out) {
109 if (IsClosed()) {
110 return tensorflow::errors::FailedPrecondition(
111 "Random TFRecord Reader is closed.");
112 }
113 return reader_->ReadRecord(offset, out);
114 }
115
IsClosed() const116 bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
117
Close()118 void Close() {
119 reader_ = nullptr;
120 file_ = nullptr;
121 }
122
123 private:
124 static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
125
PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,std::unique_ptr<tensorflow::io::RecordReader> reader)126 PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
127 std::unique_ptr<tensorflow::io::RecordReader> reader)
128 : file_(std::move(file)), reader_(std::move(reader)) {}
129
130 std::unique_ptr<tensorflow::RandomAccessFile> file_;
131 std::unique_ptr<tensorflow::io::RecordReader> reader_;
132
133 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordRandomReader);
134 };
135
136 class PyRecordWriter {
137 public:
New(const std::string & filename,const tensorflow::io::RecordWriterOptions & options,PyRecordWriter ** out)138 static tensorflow::Status New(
139 const std::string& filename,
140 const tensorflow::io::RecordWriterOptions& options,
141 PyRecordWriter** out) {
142 std::unique_ptr<tensorflow::WritableFile> file;
143 TF_RETURN_IF_ERROR(
144 tensorflow::Env::Default()->NewWritableFile(filename, &file));
145 auto writer =
146 absl::make_unique<tensorflow::io::RecordWriter>(file.get(), options);
147 *out = new PyRecordWriter(std::move(file), std::move(writer));
148 return tensorflow::Status::OK();
149 }
150
151 PyRecordWriter() = delete;
~PyRecordWriter()152 ~PyRecordWriter() { Close(); }
153
WriteRecord(tensorflow::StringPiece record)154 tensorflow::Status WriteRecord(tensorflow::StringPiece record) {
155 if (IsClosed()) {
156 return tensorflow::errors::FailedPrecondition("Writer is closed.");
157 }
158 return writer_->WriteRecord(record);
159 }
160
Flush()161 tensorflow::Status Flush() {
162 if (IsClosed()) {
163 return tensorflow::errors::FailedPrecondition("Writer is closed.");
164 }
165
166 auto status = writer_->Flush();
167 if (status.ok()) {
168 // Per the RecordWriter contract, flushing the RecordWriter does not
169 // flush the underlying file. Here we need to do both.
170 return file_->Flush();
171 }
172 return status;
173 }
174
IsClosed() const175 bool IsClosed() const { return file_ == nullptr && writer_ == nullptr; }
176
Close()177 tensorflow::Status Close() {
178 if (writer_ != nullptr) {
179 auto status = writer_->Close();
180 writer_ = nullptr;
181 if (!status.ok()) return status;
182 }
183 if (file_ != nullptr) {
184 auto status = file_->Close();
185 file_ = nullptr;
186 if (!status.ok()) return status;
187 }
188 return tensorflow::Status::OK();
189 }
190
191 private:
PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,std::unique_ptr<tensorflow::io::RecordWriter> writer)192 PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,
193 std::unique_ptr<tensorflow::io::RecordWriter> writer)
194 : file_(std::move(file)), writer_(std::move(writer)) {}
195
196 std::unique_ptr<tensorflow::WritableFile> file_;
197 std::unique_ptr<tensorflow::io::RecordWriter> writer_;
198
199 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
200 };
201
// Python bindings for TFRecord I/O.
//
// Opening files, reading records, and writing, flushing or closing a writer
// all run with the GIL released so other Python threads can make progress.
// The resulting statuses are turned into Python exceptions only after the
// GIL has been reacquired.
PYBIND11_MODULE(_pywrap_record_io, m) {
  // Declared once up front so every binding below resolves the helper the
  // same way, instead of the earliest lambdas depending on ADL.
  using tensorflow::MaybeRaiseRegisteredFromStatus;

  // Sequential reader, used as a Python iterator over record payloads.
  py::class_<PyRecordReader>(m, "RecordIterator")
      .def(py::init(
          [](const std::string& filename, const std::string& compression_type) {
            tensorflow::Status status;
            PyRecordReader* self = nullptr;
            {
              py::gil_scoped_release release;
              status = PyRecordReader::New(filename, compression_type, &self);
            }
            MaybeRaiseRegisteredFromStatus(status);
            return self;
          }))
      .def("__iter__", [](const py::object& self) { return self; })
      .def("__next__",
           [](PyRecordReader* self) {
             // A closed iterator behaves as exhausted.
             if (self->IsClosed()) {
               throw py::stop_iteration();
             }

             tensorflow::tstring record;
             tensorflow::Status status;
             {
               py::gil_scoped_release release;
               status = self->ReadNextRecord(&record);
             }
             if (tensorflow::errors::IsOutOfRange(status)) {
               // Don't close because the file being read could be updated
               // in-between __next__ calls.
               throw py::stop_iteration();
             }
             MaybeRaiseRegisteredFromStatus(status);
             return py::bytes(record);
           })
      .def("close", [](PyRecordReader* self) { self->Close(); });

  // Random-access reader: read(offset) -> (record bytes, offset after read).
  py::class_<PyRecordRandomReader>(m, "RandomRecordReader")
      .def(py::init([](const std::string& filename) {
        tensorflow::Status status;
        PyRecordRandomReader* self = nullptr;
        {
          py::gil_scoped_release release;
          status = PyRecordRandomReader::New(filename, &self);
        }
        MaybeRaiseRegisteredFromStatus(status);
        return self;
      }))
      .def("read",
           [](PyRecordRandomReader* self, tensorflow::uint64 offset) {
             // ReadRecord updates its offset argument in place; work on a
             // copy so the error message can still report the requested one.
             tensorflow::uint64 temp_offset = offset;
             tensorflow::tstring record;
             tensorflow::Status status;
             {
               py::gil_scoped_release release;
               status = self->ReadRecord(&temp_offset, &record);
             }
             if (tensorflow::errors::IsOutOfRange(status)) {
               throw py::index_error(tensorflow::strings::StrCat(
                   "Out of range at reading offset ", offset));
             }
             MaybeRaiseRegisteredFromStatus(status);
             return py::make_tuple(py::bytes(record), temp_offset);
           })
      .def("close", [](PyRecordRandomReader* self) { self->Close(); });

  using tensorflow::io::ZlibCompressionOptions;
  py::class_<ZlibCompressionOptions>(m, "ZlibCompressionOptions")
      .def_readwrite("flush_mode", &ZlibCompressionOptions::flush_mode)
      .def_readwrite("input_buffer_size",
                     &ZlibCompressionOptions::input_buffer_size)
      .def_readwrite("output_buffer_size",
                     &ZlibCompressionOptions::output_buffer_size)
      .def_readwrite("window_bits", &ZlibCompressionOptions::window_bits)
      .def_readwrite("compression_level",
                     &ZlibCompressionOptions::compression_level)
      .def_readwrite("compression_method",
                     &ZlibCompressionOptions::compression_method)
      .def_readwrite("mem_level", &ZlibCompressionOptions::mem_level)
      .def_readwrite("compression_strategy",
                     &ZlibCompressionOptions::compression_strategy);

  using tensorflow::io::RecordWriterOptions;
  py::class_<RecordWriterOptions>(m, "RecordWriterOptions")
      .def(py::init(&RecordWriterOptions::CreateRecordWriterOptions))
      .def_readonly("compression_type", &RecordWriterOptions::compression_type)
      .def_readonly("zlib_options", &RecordWriterOptions::zlib_options);

  // Writer, also usable as a context manager (closed on __exit__).
  py::class_<PyRecordWriter>(m, "RecordWriter")
      .def(py::init(
          [](const std::string& filename, const RecordWriterOptions& options) {
            PyRecordWriter* self = nullptr;
            tensorflow::Status status;
            {
              py::gil_scoped_release release;
              status = PyRecordWriter::New(filename, options, &self);
            }
            MaybeRaiseRegisteredFromStatus(status);
            return self;
          }))
      .def("__enter__", [](const py::object& self) { return self; })
      .def("__exit__",
           [](PyRecordWriter* self, py::args) {
             tensorflow::Status status;
             {
               // Closing flushes buffered data to the file and may block on
               // I/O, so do it without holding the GIL, as `write` does.
               py::gil_scoped_release release;
               status = self->Close();
             }
             MaybeRaiseRegisteredFromStatus(status);
           })
      .def(
          "write",
          [](PyRecordWriter* self, tensorflow::StringPiece record) {
            tensorflow::Status status;
            {
              py::gil_scoped_release release;
              status = self->WriteRecord(record);
            }
            MaybeRaiseRegisteredFromStatus(status);
          },
          py::arg("record"))
      .def("flush",
           [](PyRecordWriter* self) {
             tensorflow::Status status;
             {
               py::gil_scoped_release release;
               status = self->Flush();
             }
             MaybeRaiseRegisteredFromStatus(status);
           })
      .def("close", [](PyRecordWriter* self) {
        tensorflow::Status status;
        {
          py::gil_scoped_release release;
          status = self->Close();
        }
        MaybeRaiseRegisteredFromStatus(status);
      });
}
328
329 } // namespace
330