1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/dataset.h" 17 #include "tensorflow/core/framework/partial_tensor_shape.h" 18 #include "tensorflow/core/framework/tensor.h" 19 #include "tensorflow/core/kernels/data/window_dataset.h" 20 21 namespace tensorflow { 22 namespace data { 23 namespace { 24 25 // See documentation in ../../ops/dataset_ops.cc for a high-level 26 // description of the following op. 27 28 class WindowDatasetOp : public UnaryDatasetOpKernel { 29 public: WindowDatasetOp(OpKernelConstruction * ctx)30 explicit WindowDatasetOp(OpKernelConstruction* ctx) 31 : UnaryDatasetOpKernel(ctx) {} 32 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)33 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 34 DatasetBase** output) override { 35 int64 window_size = 0; 36 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size)); 37 OP_REQUIRES( 38 ctx, window_size > 0, 39 errors::InvalidArgument("Window size must be greater than zero.")); 40 41 int64 window_shift = 0; 42 OP_REQUIRES_OK(ctx, 43 ParseScalarArgument<int64>(ctx, "shift", &window_shift)); 44 OP_REQUIRES( 45 ctx, window_shift > 0, 46 errors::InvalidArgument("Window shift must be greater than zero.")); 47 48 int64 window_stride = 0; 49 OP_REQUIRES_OK(ctx, 50 ParseScalarArgument<int64>(ctx, "stride", &window_stride)); 51 OP_REQUIRES( 52 ctx, window_stride > 0, 53 errors::InvalidArgument("Window stride must be greater than zero.")); 54 55 bool drop_remainder; 56 OP_REQUIRES_OK( 57 ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder)); 58 59 *output = new Dataset(ctx, input, window_size, window_shift, window_stride, 60 drop_remainder); 61 } 62 63 private: 64 class Dataset : public DatasetBase { 65 public: Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 window_size,int64 window_shift,int64 window_stride,bool drop_remainder)66 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size, 67 int64 window_shift, int64 window_stride, bool drop_remainder) 68 : DatasetBase(DatasetContext(ctx)), 69 input_(input), 70 window_size_(window_size), 71 window_shift_(window_shift), 72 window_stride_(window_stride), 73 drop_remainder_(drop_remainder) { 74 input_->Ref(); 75 } 76 ~Dataset()77 ~Dataset() override { input_->Unref(); } 78 MakeIteratorInternal(const string & prefix) const79 std::unique_ptr<IteratorBase> MakeIteratorInternal( 80 const string& prefix) const override { 81 return absl::make_unique<Iterator>( 82 Iterator::Params{this, strings::StrCat(prefix, "::Window")}); 83 } 84 output_dtypes() const85 const DataTypeVector& output_dtypes() const override { 86 static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT}); 87 return *output_dtypes; 88 } 89 output_shapes() const90 const std::vector<PartialTensorShape>& output_shapes() const override { 91 static std::vector<PartialTensorShape>* output_shapes = 92 new std::vector<PartialTensorShape>({TensorShape({})}); 93 return *output_shapes; 94 } 95 DebugString() const96 string DebugString() const override { 97 return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_, 98 window_stride_, drop_remainder_, ")::Dataset"); 99 } 100 Cardinality() const101 int64 Cardinality() const override { 102 int64 n = input_->Cardinality(); 103 if (n == kInfiniteCardinality || n == kUnknownCardinality) { 104 return n; 105 } 106 return n / window_shift_ + 107 (n % window_shift_ == 0 || drop_remainder_ ? 0 : 1); 108 } 109 110 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const111 Status AsGraphDefInternal(SerializationContext* ctx, 112 DatasetGraphDefBuilder* b, 113 Node** output) const override { 114 Node* input_graph_node = nullptr; 115 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); 116 Node* window_size_node = nullptr; 117 TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node)); 118 Node* window_shift_node = nullptr; 119 TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node)); 120 Node* window_stride_node = nullptr; 121 TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node)); 122 Node* drop_remainder_node = nullptr; 123 TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); 124 TF_RETURN_IF_ERROR( 125 b->AddDataset(this, 126 {input_graph_node, window_size_node, window_shift_node, 127 window_stride_node, drop_remainder_node}, 128 output)); 129 return Status::OK(); 130 } 131 132 private: 133 class Iterator : public DatasetIterator<Dataset> { 134 public: Iterator(const Params & params)135 explicit Iterator(const Params& params) 136 : DatasetIterator<Dataset>(params) {} 137 Initialize(IteratorContext * ctx)138 Status Initialize(IteratorContext* ctx) override { 139 return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); 140 } 141 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)142 Status GetNextInternal(IteratorContext* ctx, 143 std::vector<Tensor>* out_tensors, 144 bool* end_of_sequence) override { 145 const int64 window_size = dataset()->window_size_; 146 const int64 window_shift = dataset()->window_shift_; 147 const int64 window_stride = dataset()->window_stride_; 148 std::vector<std::vector<Tensor>> window_elements; 149 Status status = Status::OK(); 150 { 151 mutex_lock l(mu_); 152 if (!input_impl_ && buffer_.empty()) { 153 *end_of_sequence = true; 154 return Status::OK(); 155 } 156 157 // Add elements to the buffer. 158 size_t target_size = TargetBufferSize(window_size, window_stride); 159 if (input_impl_) { 160 *end_of_sequence = false; 161 for (size_t i = buffer_.size(); 162 i < target_size && !*end_of_sequence; ++i) { 163 std::vector<Tensor> element; 164 Status status = 165 input_impl_->GetNext(ctx, &element, end_of_sequence); 166 if (!*end_of_sequence) { 167 RecordBufferEnqueue(ctx, element); 168 buffer_.emplace_back(std::move(element), status); 169 } else { 170 input_impl_.reset(); 171 } 172 } 173 } 174 175 // If there are not enough elements and `drop_remainder` is set, we do 176 // not wish to return a smaller window. 177 if (buffer_.empty() || 178 (dataset()->drop_remainder_ && buffer_.size() < target_size)) { 179 DCHECK(*end_of_sequence); 180 return Status::OK(); 181 } 182 183 int num_elements = 1 + (buffer_.size() - 1) / window_stride; 184 window_elements.reserve(num_elements); 185 for (size_t i = 0; i < num_elements; ++i) { 186 status.Update(buffer_[window_stride * i].status); 187 if (!status.ok()) { 188 break; 189 } 190 window_elements.emplace_back(buffer_[window_stride * i].result); 191 } 192 193 // Shift the window, discarding elements if necessary. 194 int buffer_size = buffer_.size(); 195 if (window_shift >= buffer_size) { 196 for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) { 197 bool end_of_input; 198 std::vector<Tensor> element; 199 // Ignore non-error status of discarded elements. 200 input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError(); 201 if (end_of_input) { 202 input_impl_.reset(); 203 } 204 } 205 for (size_t i = 0; i < buffer_.size(); ++i) { 206 RecordBufferDequeue(ctx, buffer_.at(i).result); 207 } 208 buffer_.clear(); 209 } else { 210 for (size_t i = 0; i < window_shift; ++i) { 211 RecordBufferDequeue(ctx, buffer_.at(i).result); 212 } 213 buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift); 214 } 215 } 216 217 if (!status.ok()) { 218 return status; 219 } 220 221 // Construct output tensors. 222 const size_t num_tuple_components = window_elements[0].size(); 223 const int64 num_window_elements = window_elements.size(); 224 *end_of_sequence = false; 225 for (size_t idx = 0; idx < num_tuple_components; ++idx) { 226 DatasetBase* window_dataset; 227 std::vector<std::vector<Tensor>> window_component_elements; 228 window_component_elements.reserve(num_window_elements); 229 // Build the output tuple component by copying one slice 230 // from each input element in the window. 231 for (size_t i = 0; i < num_window_elements; ++i) { 232 std::vector<Tensor> component_element; 233 component_element.push_back(std::move(window_elements[i][idx])); 234 window_component_elements.push_back(component_element); 235 } 236 DataTypeVector output_types( 237 {dataset()->input_->output_dtypes()[idx]}); 238 std::vector<PartialTensorShape> output_shapes( 239 {dataset()->input_->output_shapes()[idx]}); 240 TF_RETURN_IF_ERROR(NewWindowDataset(window_component_elements, 241 output_types, output_shapes, 242 &window_dataset)); 243 out_tensors->emplace_back(DT_VARIANT, TensorShape({})); 244 TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset, 245 &out_tensors->back())); 246 } 247 return Status::OK(); 248 } 249 250 protected: CreateNode(IteratorContext * ctx,model::Node::Args args) const251 std::shared_ptr<model::Node> CreateNode( 252 IteratorContext* ctx, model::Node::Args args) const override { 253 return model::MakeKnownRatioNode(std::move(args), 254 dataset()->window_shift_); 255 } 256 SaveInternal(IteratorStateWriter * writer)257 Status SaveInternal(IteratorStateWriter* writer) override { 258 mutex_lock l(mu_); 259 if (!input_impl_) { 260 TF_RETURN_IF_ERROR( 261 writer->WriteScalar(full_name("input_impl_empty"), "")); 262 } else { 263 TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); 264 } 265 // Save buffer. 266 TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"), 267 buffer_.size())); 268 for (int64 i = 0; i < buffer_.size(); i++) { 269 TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status)); 270 TF_RETURN_IF_ERROR( 271 writer->WriteScalar(strings::StrCat("buffer[", i, "].size"), 272 buffer_[i].result.size())); 273 for (int64 j = 0; j < buffer_[i].result.size(); j++) { 274 TF_RETURN_IF_ERROR( 275 writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"), 276 buffer_[i].result[j])); 277 } 278 } 279 return Status::OK(); 280 } 281 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)282 Status RestoreInternal(IteratorContext* ctx, 283 IteratorStateReader* reader) override { 284 mutex_lock l(mu_); 285 if (!reader->Contains(full_name("input_impl_empty"))) { 286 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); 287 } else { 288 input_impl_.reset(); 289 } 290 // Restore buffer. 291 int64 buffer_size; 292 TF_RETURN_IF_ERROR( 293 reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size)); 294 buffer_.resize(buffer_size); 295 for (int64 i = 0; i < buffer_size; i++) { 296 int64 vector_size; 297 TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status)); 298 TF_RETURN_IF_ERROR(reader->ReadScalar( 299 strings::StrCat("buffer[", i, "].size"), &vector_size)); 300 buffer_[i].result.resize(vector_size); 301 for (int64 j = 0; j < vector_size; j++) { 302 TF_RETURN_IF_ERROR( 303 reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"), 304 &buffer_[i].result[j])); 305 } 306 } 307 return Status::OK(); 308 } 309 310 private: 311 struct InvocationResult { 312 InvocationResult() = default; InvocationResulttensorflow::data::__anon4ffa73280111::WindowDatasetOp::Dataset::Iterator::InvocationResult313 InvocationResult(std::vector<Tensor>&& result, const Status& status) 314 : result(result), status(status) {} 315 316 std::vector<Tensor> result; 317 Status status; 318 }; 319 WriteStatusLocked(IteratorStateWriter * writer,size_t index,const Status & status)320 Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, 321 const Status& status) 322 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 323 TF_RETURN_IF_ERROR(writer->WriteScalar( 324 CodeKey(index), static_cast<int64>(status.code()))); 325 if (!status.ok()) { 326 TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), 327 status.error_message())); 328 } 329 return Status::OK(); 330 } 331 ReadStatusLocked(IteratorStateReader * reader,size_t index,Status * status)332 Status ReadStatusLocked(IteratorStateReader* reader, size_t index, 333 Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 334 int64 code_int; 335 TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); 336 error::Code code = static_cast<error::Code>(code_int); 337 338 if (code != error::Code::OK) { 339 string error_message; 340 TF_RETURN_IF_ERROR( 341 reader->ReadScalar(ErrorMessageKey(index), &error_message)); 342 *status = Status(code, error_message); 343 } else { 344 *status = Status::OK(); 345 } 346 return Status::OK(); 347 } 348 CodeKey(size_t index)349 string CodeKey(size_t index) { 350 return full_name(strings::StrCat("buffer[", index, "].code")); 351 } 352 ErrorMessageKey(size_t index)353 string ErrorMessageKey(size_t index) { 354 return full_name(strings::StrCat("buffer[", index, "].error_message")); 355 } 356 TargetBufferSize(int64 window_size,int64 window_stride)357 size_t TargetBufferSize(int64 window_size, int64 window_stride) { 358 return (window_size - 1) * window_stride + 1; 359 } 360 361 mutex mu_; 362 std::deque<InvocationResult> buffer_ GUARDED_BY(mu_); 363 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 364 }; 365 366 const DatasetBase* const input_; 367 const int64 window_size_; 368 const int64 window_shift_; 369 const int64 window_stride_; 370 const bool drop_remainder_; 371 }; 372 }; 373 374 REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU), 375 WindowDatasetOp); 376 } // namespace 377 } // namespace data 378 } // namespace tensorflow 379