/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace data {
namespace {

class BigtableScanDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    string prefix;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
    string start_key;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<string>(ctx, "start_key", &start_key));
    string end_key;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));

    // Exactly one of `prefix` and `start_key` must be provided; `end_key`
    // may only be combined with `start_key`.
    OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
                errors::InvalidArgument(
                    "Either prefix or start_key must be specified"));
    OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
                errors::InvalidArgument(
                    "Only one of prefix and start_key can be provided"));
    if (!prefix.empty()) {
      OP_REQUIRES(ctx, end_key.empty(),
                  errors::InvalidArgument(
                      "If prefix is specified, end_key must be empty."));
    }

    std::vector<string> column_families;
    std::vector<string> columns;
    OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
                                                    &column_families));
    OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
    OP_REQUIRES(
        ctx, column_families.size() == columns.size(),
        errors::InvalidArgument("len(columns) != len(column_families)"));
    OP_REQUIRES(ctx, !column_families.empty(),
                errors::InvalidArgument("`column_families` is empty"));

    float probability = 0;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<float>(ctx, "probability", &probability));
    OP_REQUIRES(
        ctx, probability > 0 && probability <= 1,
        errors::InvalidArgument(
            "Probability outside the range of (0, 1]. Got: ", probability));

    BigtableTableResource* resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
    core::ScopedUnref scoped_unref(resource);

    // One scalar string output per requested column, plus one for the row
    // key.
    const uint64 num_outputs = columns.size() + 1;
    std::vector<PartialTensorShape> output_shapes;
    output_shapes.reserve(num_outputs);
    DataTypeVector output_types;
    output_types.reserve(num_outputs);
    for (uint64 i = 0; i < num_outputs; ++i) {
      output_shapes.push_back({});
      output_types.push_back(DT_STRING);
    }

    *output = new Dataset(ctx, resource, std::move(prefix),
                          std::move(start_key), std::move(end_key),
                          std::move(column_families), std::move(columns),
                          probability, output_types, std::move(output_shapes));
  }

 private:
  class Dataset : public DatasetBase {
   public:
    explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
                     string prefix, string start_key, string end_key,
                     std::vector<string> column_families,
                     std::vector<string> columns, float probability,
                     const DataTypeVector& output_types,
                     std::vector<PartialTensorShape> output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          table_(table),
          prefix_(std::move(prefix)),
          start_key_(std::move(start_key)),
          end_key_(std::move(end_key)),
          column_families_(std::move(column_families)),
          columns_(std::move(columns)),
          column_family_regex_(RegexFromStringSet(column_families_)),
          column_regex_(RegexFromStringSet(columns_)),
          probability_(probability),
          output_types_(output_types),
          output_shapes_(std::move(output_shapes)) {
      table_->Ref();
    }

    ~Dataset() override { table_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::BigtableScan")}));
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return "BigtableScanDatasetOp::Dataset";
    }

    BigtableTableResource* table() const { return table_; }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      // Note: errors::Unimplemented concatenates its arguments (StrCat
      // semantics), so printf-style "%s" formatting must not be used here.
      return errors::Unimplemented(DebugString(),
                                   " does not support serialization");
    }

   private:
    class Iterator : public BigtableReaderDatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : BigtableReaderDatasetIterator<Dataset>(params) {}

      ::google::cloud::bigtable::RowRange MakeRowRange() override {
        if (!dataset()->prefix_.empty()) {
          DCHECK(dataset()->start_key_.empty());
          return ::google::cloud::bigtable::RowRange::Prefix(
              dataset()->prefix_);
        } else {
          DCHECK(!dataset()->start_key_.empty())
              << "Both prefix and start_key were empty!";
          return ::google::cloud::bigtable::RowRange::Range(
              dataset()->start_key_, dataset()->end_key_);
        }
      }

      ::google::cloud::bigtable::Filter MakeFilter() override {
        // TODO(saeta): Investigate optimal ordering here.
        return ::google::cloud::bigtable::Filter::Chain(
            ::google::cloud::bigtable::Filter::Latest(1),
            ::google::cloud::bigtable::Filter::FamilyRegex(
                dataset()->column_family_regex_),
            ::google::cloud::bigtable::Filter::ColumnRegex(
                dataset()->column_regex_),
            dataset()->probability_ != 1.0
                ? ::google::cloud::bigtable::Filter::RowSample(
                      dataset()->probability_)
                : ::google::cloud::bigtable::Filter::PassAllFilter());
      }

      Status ParseRow(IteratorContext* ctx,
                      const ::google::cloud::bigtable::Row& row,
                      std::vector<Tensor>* out_tensors) override {
        out_tensors->reserve(dataset()->columns_.size() + 1);
        Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
        row_key_tensor.scalar<string>()() = string(row.row_key());
        out_tensors->emplace_back(std::move(row_key_tensor));

        // The filter should return at most one cell per requested column;
        // warn if the server sent back more than twice that many.
        if (row.cells().size() > 2 * dataset()->columns_.size()) {
          LOG(WARNING) << "An excessive number of columns ("
                       << row.cells().size()
                       << ") were retrieved when reading row: "
                       << row.row_key();
        }

        for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
          Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
          bool found_column = false;
          for (auto cell_itr = row.cells().begin();
               !found_column && cell_itr != row.cells().end(); ++cell_itr) {
            if (cell_itr->family_name() == dataset()->column_families_[i] &&
                string(cell_itr->column_qualifier()) ==
                    dataset()->columns_[i]) {
              col_tensor.scalar<string>()() = string(cell_itr->value());
              found_column = true;
            }
          }
          if (!found_column) {
            return errors::InvalidArgument(
                "Column ", dataset()->column_families_[i], ":",
                dataset()->columns_[i], " not found in row: ", row.row_key());
          }
          out_tensors->emplace_back(std::move(col_tensor));
        }
        return Status::OK();
      }
    };

    BigtableTableResource* table_;
    const string prefix_;
    const string start_key_;
    const string end_key_;
    const std::vector<string> column_families_;
    const std::vector<string> columns_;
    const string column_family_regex_;
    const string column_regex_;
    const float probability_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
  };
};

REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
                        BigtableScanDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow