1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" 17 #include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" 18 #include "tensorflow/core/framework/op_kernel.h" 19 20 namespace tensorflow { 21 namespace data { 22 namespace { 23 24 class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { 25 public: 26 using DatasetOpKernel::DatasetOpKernel; 27 MakeDataset(OpKernelContext * ctx,DatasetBase ** output)28 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 29 string prefix; 30 OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix)); 31 32 string start_key; 33 OP_REQUIRES_OK(ctx, 34 ParseScalarArgument<string>(ctx, "start_key", &start_key)); 35 string end_key; 36 OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key)); 37 38 BigtableTableResource* resource; 39 OP_REQUIRES_OK(ctx, 40 LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); 41 core::ScopedUnref scoped_unref(resource); 42 43 OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), 44 errors::InvalidArgument( 45 "Only one of prefix and start_key can be provided")); 46 if (!prefix.empty()) { 47 OP_REQUIRES(ctx, end_key.empty(), 48 errors::InvalidArgument( 49 "If prefix is specified, end_key must be empty.")); 50 } 51 52 *output = new Dataset(ctx, resource, std::move(prefix), 53 std::move(start_key), std::move(end_key)); 54 } 55 56 private: 57 class Dataset : public DatasetBase { 58 public: Dataset(OpKernelContext * ctx,BigtableTableResource * table,string prefix,string start_key,string end_key)59 explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, 60 string prefix, string start_key, string end_key) 61 : DatasetBase(DatasetContext(ctx)), 62 table_(table), 63 key_range_(MakeMultiModeKeyRange( 64 std::move(prefix), std::move(start_key), std::move(end_key))) { 65 table_->Ref(); 66 } 67 ~Dataset()68 ~Dataset() override { table_->Unref(); } 69 MakeIteratorInternal(const string & prefix) const70 std::unique_ptr<IteratorBase> MakeIteratorInternal( 71 const string& prefix) const override { 72 return std::unique_ptr<IteratorBase>(new Iterator( 73 {this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")})); 74 } 75 output_dtypes() const76 const DataTypeVector& output_dtypes() const override { 77 static DataTypeVector* dtypes = 78 new DataTypeVector({DT_STRING, DT_STRING}); 79 return *dtypes; 80 } 81 output_shapes() const82 const std::vector<PartialTensorShape>& output_shapes() const override { 83 static std::vector<PartialTensorShape>* shapes = 84 new std::vector<PartialTensorShape>({{}, {}}); 85 return *shapes; 86 } 87 DebugString() const88 string DebugString() const override { 89 return "BigtableSampleKeyPairsDatasetOp::Dataset"; 90 } 91 92 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const93 Status AsGraphDefInternal(SerializationContext* ctx, 94 DatasetGraphDefBuilder* b, 95 Node** output) const override { 96 return errors::Unimplemented("%s does not support serialization", 97 DebugString()); 98 } 99 100 private: MakeMultiModeKeyRange(string prefix,string start_key,string end_key)101 static MultiModeKeyRange MakeMultiModeKeyRange(string prefix, 102 string start_key, 103 string end_key) { 104 if (!start_key.empty()) { 105 return MultiModeKeyRange::FromRange(std::move(start_key), 106 std::move(end_key)); 107 } 108 return MultiModeKeyRange::FromPrefix(std::move(prefix)); 109 } 110 table() const111 BigtableTableResource& table() const { return *table_; } 112 113 class Iterator : public DatasetIterator<Dataset> { 114 public: Iterator(const Params & params)115 explicit Iterator(const Params& params) 116 : DatasetIterator<Dataset>(params) {} 117 118 // Computes split points (`keys_`) to use when scanning the table. 119 // 120 // Initialize first retrieves the sample keys from the table (`row_keys`), 121 // as these often form good split points within the table. We then iterate 122 // over them, and copy them to `keys_` if they fall within the requested 123 // range to scan (`dataset()->key_range_`). Because the requested range 124 // might start between elements of the sampled keys list, care is taken to 125 // ensure we don't accidentally miss any subsets of the requested range by 126 // including `begin_key()` and `end_key()` as appropriate. Initialize(IteratorContext * ctx)127 Status Initialize(IteratorContext* ctx) override { 128 grpc::Status status; 129 std::vector<google::cloud::bigtable::RowKeySample> row_keys = 130 dataset()->table().table().SampleRows(status); 131 if (!status.ok()) { 132 return GrpcStatusToTfStatus(status); 133 } 134 135 for (size_t i = 0; i < row_keys.size(); ++i) { 136 string row_key(row_keys[i].row_key); 137 if (dataset()->key_range_.contains_key(row_key)) { 138 // First key: check to see if we need to add the begin_key. 139 if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) { 140 keys_.push_back(dataset()->key_range_.begin_key()); 141 } 142 keys_.push_back(std::move(row_key)); 143 } else if (!keys_.empty()) { 144 // If !keys_.empty(), then we have found at least one element of 145 // `row_keys` that is within our requested range 146 // (`dataset()->key_range_`). Because `row_keys` is sorted, if we 147 // have found an element that's not within our key range, then we 148 // are after our requested range (ranges are contiguous) and can end 149 // iteration early. 150 break; 151 } 152 } 153 154 // Handle the case where we skip over the selected range entirely. 155 if (keys_.empty()) { 156 keys_.push_back(dataset()->key_range_.begin_key()); 157 } 158 159 // Last key: check to see if we need to add the end_key. 160 if (keys_.back() != dataset()->key_range_.end_key()) { 161 keys_.push_back(dataset()->key_range_.end_key()); 162 } 163 return Status::OK(); 164 } 165 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)166 Status GetNextInternal(IteratorContext* ctx, 167 std::vector<Tensor>* out_tensors, 168 bool* end_of_sequence) override { 169 mutex_lock l(mu_); 170 if (index_ + 2 > keys_.size()) { 171 *end_of_sequence = true; 172 return Status::OK(); 173 } 174 175 *end_of_sequence = false; 176 out_tensors->emplace_back(ctx->allocator({}), DT_STRING, 177 TensorShape({})); 178 out_tensors->back().scalar<string>()() = keys_[index_]; 179 180 out_tensors->emplace_back(ctx->allocator({}), DT_STRING, 181 TensorShape({})); 182 out_tensors->back().scalar<string>()() = keys_[index_ + 1]; 183 ++index_; 184 185 return Status::OK(); 186 } 187 188 private: 189 mutex mu_; 190 size_t index_ GUARDED_BY(mu_) = 0; 191 // Note: we store the keys_ on the iterator instead of the dataset 192 // because we want to re-sample the row keys in case there have been 193 // tablet rebalancing operations since the dataset was created. 194 // 195 // Note: keys_ is readonly after Initialize, and thus does not need a 196 // guarding lock. 197 std::vector<string> keys_; 198 }; 199 200 BigtableTableResource* const table_; 201 const MultiModeKeyRange key_range_; 202 }; 203 }; 204 205 REGISTER_KERNEL_BUILDER( 206 Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU), 207 BigtableSampleKeyPairsDatasetOp); 208 209 } // namespace 210 } // namespace data 211 } // namespace tensorflow 212