1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" 17 #include "tensorflow/core/framework/op_kernel.h" 18 19 namespace tensorflow { 20 namespace data { 21 namespace { 22 23 class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { 24 public: 25 using UnaryDatasetOpKernel::UnaryDatasetOpKernel; 26 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)27 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 28 DatasetBase** output) override { 29 BigtableTableResource* table; 30 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); 31 core::ScopedUnref scoped_unref(table); 32 33 std::vector<string> column_families; 34 std::vector<string> columns; 35 OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families", 36 &column_families)); 37 OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns)); 38 OP_REQUIRES( 39 ctx, column_families.size() == columns.size(), 40 errors::InvalidArgument("len(columns) != len(column_families)")); 41 42 const uint64 num_outputs = columns.size() + 1; 43 std::vector<PartialTensorShape> output_shapes; 44 output_shapes.reserve(num_outputs); 45 DataTypeVector output_types; 46 output_types.reserve(num_outputs); 47 for (uint64 i = 0; i < num_outputs; ++i) { 48 output_shapes.push_back({}); 49 output_types.push_back(DT_STRING); 50 } 51 52 *output = 53 new Dataset(ctx, input, table, std::move(column_families), 54 std::move(columns), output_types, std::move(output_shapes)); 55 } 56 57 private: 58 class Dataset : public DatasetBase { 59 public: Dataset(OpKernelContext * ctx,const DatasetBase * input,BigtableTableResource * table,std::vector<string> column_families,std::vector<string> columns,const DataTypeVector & output_types,std::vector<PartialTensorShape> output_shapes)60 explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, 61 BigtableTableResource* table, 62 std::vector<string> column_families, 63 std::vector<string> columns, 64 const DataTypeVector& output_types, 65 std::vector<PartialTensorShape> output_shapes) 66 : DatasetBase(DatasetContext(ctx)), 67 input_(input), 68 table_(table), 69 column_families_(std::move(column_families)), 70 columns_(std::move(columns)), 71 output_types_(output_types), 72 output_shapes_(std::move(output_shapes)), 73 filter_(MakeFilter(column_families_, columns_)) { 74 table_->Ref(); 75 input_->Ref(); 76 } 77 ~Dataset()78 ~Dataset() override { 79 table_->Unref(); 80 input_->Unref(); 81 } 82 MakeIteratorInternal(const string & prefix) const83 std::unique_ptr<IteratorBase> MakeIteratorInternal( 84 const string& prefix) const override { 85 return std::unique_ptr<IteratorBase>( 86 new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")})); 87 } 88 output_dtypes() const89 const DataTypeVector& output_dtypes() const override { 90 return output_types_; 91 } 92 output_shapes() const93 const std::vector<PartialTensorShape>& output_shapes() const override { 94 return output_shapes_; 95 } 96 DebugString() const97 string DebugString() const override { 98 return "BigtableLookupDatasetOp::Dataset"; 99 } 100 101 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const102 Status AsGraphDefInternal(SerializationContext* ctx, 103 DatasetGraphDefBuilder* b, 104 Node** output) const override { 105 return errors::Unimplemented("%s does not support serialization", 106 DebugString()); 107 } 108 109 private: MakeFilter(const std::vector<string> & column_families,const std::vector<string> & columns)110 static ::google::cloud::bigtable::Filter MakeFilter( 111 const std::vector<string>& column_families, 112 const std::vector<string>& columns) { 113 string column_family_regex = RegexFromStringSet(column_families); 114 string column_regex = RegexFromStringSet(columns); 115 116 return ::google::cloud::bigtable::Filter::Chain( 117 ::google::cloud::bigtable::Filter::Latest(1), 118 ::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex), 119 ::google::cloud::bigtable::Filter::ColumnRegex(column_regex)); 120 } 121 122 class Iterator : public DatasetIterator<Dataset> { 123 public: Iterator(const Params & params)124 explicit Iterator(const Params& params) 125 : DatasetIterator<Dataset>(params) {} 126 Initialize(IteratorContext * ctx)127 Status Initialize(IteratorContext* ctx) override { 128 return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); 129 } 130 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)131 Status GetNextInternal(IteratorContext* ctx, 132 std::vector<Tensor>* out_tensors, 133 bool* end_of_sequence) override { 134 mutex_lock l(mu_); // Sequence requests. 135 std::vector<Tensor> input_tensors; 136 TF_RETURN_IF_ERROR( 137 input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); 138 if (*end_of_sequence) { 139 return Status::OK(); 140 } 141 if (input_tensors.size() != 1) { 142 return errors::InvalidArgument( 143 "Upstream iterator (", dataset()->input_->DebugString(), 144 ") did not produce a single `tf.string` `tf.Tensor`. It " 145 "produced ", 146 input_tensors.size(), " tensors."); 147 } 148 if (input_tensors[0].NumElements() == 0) { 149 return errors::InvalidArgument("Upstream iterator (", 150 dataset()->input_->DebugString(), 151 ") return an empty set of keys."); 152 } 153 if (input_tensors[0].NumElements() == 1) { 154 // Single key lookup. 155 ::google::cloud::Status status; 156 auto pair = dataset()->table_->table().ReadRow( 157 input_tensors[0].scalar<string>()(), dataset()->filter_, status); 158 if (!status.ok()) { 159 return GcpStatusToTfStatus(status); 160 } 161 if (!pair.first) { 162 return errors::DataLoss("Row key '", 163 input_tensors[0].scalar<string>()(), 164 "' not found."); 165 } 166 TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); 167 } else { 168 // Batched get. 169 return errors::Unimplemented( 170 "BigtableLookupDataset doesn't yet support batched retrieval."); 171 } 172 return Status::OK(); 173 } 174 175 private: ParseRow(IteratorContext * ctx,const::google::cloud::bigtable::Row & row,std::vector<Tensor> * out_tensors)176 Status ParseRow(IteratorContext* ctx, 177 const ::google::cloud::bigtable::Row& row, 178 std::vector<Tensor>* out_tensors) { 179 out_tensors->reserve(dataset()->columns_.size() + 1); 180 Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); 181 row_key_tensor.scalar<string>()() = string(row.row_key()); 182 out_tensors->emplace_back(std::move(row_key_tensor)); 183 184 if (row.cells().size() > 2 * dataset()->columns_.size()) { 185 LOG(WARNING) << "An excessive number of columns (" 186 << row.cells().size() 187 << ") were retrieved when reading row: " 188 << row.row_key(); 189 } 190 191 for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { 192 Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); 193 bool found_column = false; 194 for (auto cell_itr = row.cells().begin(); 195 !found_column && cell_itr != row.cells().end(); ++cell_itr) { 196 if (cell_itr->family_name() == dataset()->column_families_[i] && 197 string(cell_itr->column_qualifier()) == 198 dataset()->columns_[i]) { 199 col_tensor.scalar<string>()() = string(cell_itr->value()); 200 found_column = true; 201 } 202 } 203 if (!found_column) { 204 return errors::DataLoss("Column ", dataset()->column_families_[i], 205 ":", dataset()->columns_[i], 206 " not found in row: ", row.row_key()); 207 } 208 out_tensors->emplace_back(std::move(col_tensor)); 209 } 210 return Status::OK(); 211 } 212 213 mutex mu_; 214 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 215 }; 216 217 const DatasetBase* const input_; 218 BigtableTableResource* table_; 219 const std::vector<string> column_families_; 220 const std::vector<string> columns_; 221 const DataTypeVector output_types_; 222 const std::vector<PartialTensorShape> output_shapes_; 223 const ::google::cloud::bigtable::Filter filter_; 224 }; 225 }; 226 227 REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), 228 BigtableLookupDatasetOp); 229 230 } // namespace 231 } // namespace data 232 } // namespace tensorflow 233