/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // Create a new DenseToSparseBatchDatasetOp::Dataset, insert it in the
    // step-local container, and return it as the output.
    OP_REQUIRES(
        ctx, input->output_dtypes().size() == 1,
        errors::InvalidArgument("DenseToSparseBatchDataset only supports "
                                "inputs with a single component."));

    int64 batch_size;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
    OP_REQUIRES(
        ctx, batch_size > 0,
        errors::InvalidArgument("Batch size must be greater than zero."));

    const Tensor* row_shape_t;
    OP_REQUIRES_OK(ctx, ctx->input("row_shape", &row_shape_t));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(row_shape_t->shape()),
                errors::InvalidArgument("row_shape must be a vector"));
    PartialTensorShape row_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            row_shape_t->vec<int64>().data(),
                            row_shape_t->NumElements(), &row_shape));

    *output = nullptr;

#define HANDLE_TYPE(T)                                           \
  case DataTypeToEnum<T>::value: {                               \
    *output = new Dataset<T>(ctx, batch_size, row_shape, input); \
    break;                                                       \
  }

    switch (input->output_dtypes()[0]) {
      TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
      default:
        OP_REQUIRES(ctx, false,
                    errors::Unimplemented(
                        "DenseToSparseBatchDataset unhandled data type: ",
                        input->output_dtypes()[0]));
    }
  }

 private:
  // TODO(mrry): Push the templated code down to the raw copying routine.
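
  // Each output element produced by this dataset is a single DT_VARIANT
  // vector of length 3 holding the (indices, values, dense_shape) components
  // of a SparseTensor. Illustrative example: with batch_size = 2 and
  // row_shape = [3], batching the input elements [1, 2, 3] and [4, 5] yields
  //   indices     = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
  //   values      = [1, 2, 3, 4, 5]
  //   dense_shape = [2, 3]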
  template <class T>
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, int64 batch_size,
            const PartialTensorShape& row_shape, const DatasetBase* input)
        : DatasetBase(DatasetContext(ctx)),
          batch_size_(batch_size),
          row_shape_(row_shape),
          input_(input) {
      input_->Ref();

      output_shapes_.reserve(1);
      PartialTensorShape output_shape({-1});
      output_shape.AppendShape(row_shape_);
      output_shapes_.push_back(output_shape);
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(typename Iterator::Params{
          this, strings::StrCat(prefix, "::DenseToSparseBatch")});
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
      return *output_dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return strings::StrCat("DenseToSparseBatchDatasetOp(", batch_size_,
                             ")::Dataset");
    }

    int64 Cardinality() const override {
      int64 n = input_->Cardinality();
      if (n == kInfiniteCardinality || n == kUnknownCardinality) {
        return n;
      }
      return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* batch_size_node;
      TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
      Node* row_shape_node;
      std::vector<int64> row_shape;
      row_shape.reserve(
          row_shape_.dims());  // not an unknown rank PartialTensorShape
      for (int i = 0; i < row_shape_.dims(); i++)
        row_shape.emplace_back(row_shape_.dim_size(i));
      TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_node, batch_size_node, row_shape_node}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset<T>> {
     public:
      explicit Iterator(const typename Iterator::Params& params)
          : DatasetIterator<Dataset<T>>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
            ctx, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        // Each row of the output SparseTensor is an individual tensor
        // from the input iterator.
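        //
        // The batch is assembled in three steps: (1) pull up to `batch_size_`
        // elements from the input iterator, validating each against
        // `row_shape_` while accumulating the total number of values and the
        // maximum extent of every unknown (-1) dimension; (2) allocate the
        // `indices`, `values`, and `dense_shape` output tensors from those
        // counts; (3) scatter each element's values into `values` and
        // `indices` using row-major strides.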
        std::vector<Tensor> batch_elements;
        int64 total_elements = 0;
        batch_elements.reserve(
            DatasetIterator<Dataset<T>>::dataset()->batch_size_);
        const PartialTensorShape& row_shape =
            DatasetIterator<Dataset<T>>::dataset()->row_shape_;
        const int row_ndims = row_shape.dims();

        // Determine the size of the output tensors:
        // * dense_shape will be [`row_shape + 1`].
        Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1});
        auto dense_shape_vec = dense_shape.vec<int64>();
        for (size_t i = 0; i < row_ndims; ++i) {
          if (row_shape.dim_size(i) == -1) {
            dense_shape_vec(i + 1) = 0;
          } else {
            dense_shape_vec(i + 1) = row_shape.dim_size(i);
          }
        }

        {
          mutex_lock l(mu_);
          *end_of_sequence = false;
          for (int i = 0;
               i < DatasetIterator<Dataset<T>>::dataset()->batch_size_ &&
               !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              DCHECK_EQ(1, batch_element_tuple.size());
              batch_elements.push_back(std::move(batch_element_tuple[0]));
              total_elements += batch_element_tuple[0].NumElements();

              // TODO(mrry): Investigate how to hoist this check when we
              // have static information that renders it unnecessary.
              if (batch_element_tuple[0].shape().dims() != row_ndims) {
                return errors::InvalidArgument(
                    "Input element had shape (",
                    batch_element_tuple[0].shape().DebugString(),
                    ") that is incompatible with the row shape (",
                    row_shape.DebugString(), ").");
              }
              for (int j = 0; j < row_ndims; ++j) {
                // Take the maximum in the dimension if -1 is given.
                if (row_shape.dim_size(j) == -1) {
                  dense_shape_vec(j + 1) =
                      std::max(batch_element_tuple[0].dim_size(j),
                               dense_shape_vec(j + 1));
                } else if (batch_element_tuple[0].dim_size(j) >
                           row_shape.dim_size(j)) {
                  return errors::DataLoss(
                      "Input element had shape (",
                      batch_element_tuple[0].shape().DebugString(),
                      ") that is larger than the row shape (",
                      row_shape.DebugString(), ").");
                }
              }
            }
          }
        }

        if (batch_elements.empty()) {
          DCHECK(*end_of_sequence);
          return Status::OK();
        }

        // * indices will be [`total_elements`, `row_shape + 1`].
        // * values will be [`total_elements`].
        Tensor indices(ctx->allocator({}), DT_INT64,
                       {total_elements, row_ndims + 1});
        Tensor values(
            ctx->allocator({}),
            DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0],
            {total_elements});
        auto indices_matrix = indices.matrix<int64>();
        auto values_flat = values.flat<T>();

        int64 current_position_in_values = 0;
        for (int64 i = 0; i < batch_elements.size(); ++i) {
          const Tensor& t = batch_elements[i];
          const auto& t_flat = t.flat<T>();
          // TODO(mrry): Replace with a memcpy or something more
          // efficient. (Maybe an Eigen assign op?)
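          //
          // Compute row-major strides for this element's shape so that each
          // flat position `j` can be decomposed into per-dimension
          // coordinates below. For example, an element of shape [2, 3] has
          // strides [3, 1], so j = 4 maps to coordinates (1, 1) and the full
          // sparse index written for it is (i, 1, 1).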
          gtl::InlinedVector<int64, 4> strides(row_ndims);
          if (!strides.empty()) {
            strides[row_ndims - 1] = 1;
            for (int64_t row_dim = strides.size() - 2; row_dim >= 0;
                 --row_dim) {
              strides[row_dim] =
                  strides[row_dim + 1] * t.shape().dim_size(row_dim + 1);
            }
          }

          for (int64 j = 0; j < t.NumElements(); ++j) {
            values_flat(current_position_in_values) = t_flat(j);
            indices_matrix(current_position_in_values, 0) = i;
            int64 index = j;
            for (size_t k = 0; k < strides.size(); ++k) {
              indices_matrix(current_position_in_values, k + 1) =
                  index / strides[k];
              index %= strides[k];
            }
            ++current_position_in_values;
          }
        }

        dense_shape_vec(0) = batch_elements.size();

        Tensor serialized_sparse(DT_VARIANT, TensorShape({3}));
        auto serialized_sparse_t = serialized_sparse.vec<Variant>();
        serialized_sparse_t(0) = std::move(indices);
        serialized_sparse_t(1) = std::move(values);
        serialized_sparse_t(2) = std::move(dense_shape);
        out_tensors->push_back(std::move(serialized_sparse));

        *end_of_sequence = false;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(
            std::move(args),
            DatasetIterator<Dataset<T>>::dataset()->batch_size_);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(Iterator::SaveInput(writer, input_impl_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(Iterator::RestoreInput(ctx, reader, input_impl_));
        return Status::OK();
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    };

    const int64 batch_size_;
    const PartialTensorShape row_shape_;
    const DatasetBase* const input_;
    std::vector<PartialTensorShape> output_shapes_;
  };
};

REGISTER_KERNEL_BUILDER(
    Name("ExperimentalDenseToSparseBatchDataset").Device(DEVICE_CPU),
    DenseToSparseBatchDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow