1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
17 #include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 
20 namespace tensorflow {
21 namespace data {
22 namespace {
23 
24 class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
25  public:
26   using DatasetOpKernel::DatasetOpKernel;
27 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)28   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
29     string prefix;
30     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
31 
32     string start_key;
33     OP_REQUIRES_OK(ctx,
34                    ParseScalarArgument<string>(ctx, "start_key", &start_key));
35     string end_key;
36     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
37 
38     BigtableTableResource* resource;
39     OP_REQUIRES_OK(ctx,
40                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
41     core::ScopedUnref scoped_unref(resource);
42 
43     OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
44                 errors::InvalidArgument(
45                     "Only one of prefix and start_key can be provided"));
46     if (!prefix.empty()) {
47       OP_REQUIRES(ctx, end_key.empty(),
48                   errors::InvalidArgument(
49                       "If prefix is specified, end_key must be empty."));
50     }
51 
52     *output = new Dataset(ctx, resource, std::move(prefix),
53                           std::move(start_key), std::move(end_key));
54   }
55 
56  private:
57   class Dataset : public DatasetBase {
58    public:
Dataset(OpKernelContext * ctx,BigtableTableResource * table,string prefix,string start_key,string end_key)59     explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
60                      string prefix, string start_key, string end_key)
61         : DatasetBase(DatasetContext(ctx)),
62           table_(table),
63           key_range_(MakeMultiModeKeyRange(
64               std::move(prefix), std::move(start_key), std::move(end_key))) {
65       table_->Ref();
66     }
67 
~Dataset()68     ~Dataset() override { table_->Unref(); }
69 
MakeIteratorInternal(const string & prefix) const70     std::unique_ptr<IteratorBase> MakeIteratorInternal(
71         const string& prefix) const override {
72       return std::unique_ptr<IteratorBase>(new Iterator(
73           {this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")}));
74     }
75 
output_dtypes() const76     const DataTypeVector& output_dtypes() const override {
77       static DataTypeVector* dtypes =
78           new DataTypeVector({DT_STRING, DT_STRING});
79       return *dtypes;
80     }
81 
output_shapes() const82     const std::vector<PartialTensorShape>& output_shapes() const override {
83       static std::vector<PartialTensorShape>* shapes =
84           new std::vector<PartialTensorShape>({{}, {}});
85       return *shapes;
86     }
87 
DebugString() const88     string DebugString() const override {
89       return "BigtableSampleKeyPairsDatasetOp::Dataset";
90     }
91 
92    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const93     Status AsGraphDefInternal(SerializationContext* ctx,
94                               DatasetGraphDefBuilder* b,
95                               Node** output) const override {
96       return errors::Unimplemented("%s does not support serialization",
97                                    DebugString());
98     }
99 
100    private:
MakeMultiModeKeyRange(string prefix,string start_key,string end_key)101     static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
102                                                    string start_key,
103                                                    string end_key) {
104       if (!start_key.empty()) {
105         return MultiModeKeyRange::FromRange(std::move(start_key),
106                                             std::move(end_key));
107       }
108       return MultiModeKeyRange::FromPrefix(std::move(prefix));
109     }
110 
table() const111     BigtableTableResource& table() const { return *table_; }
112 
113     class Iterator : public DatasetIterator<Dataset> {
114      public:
Iterator(const Params & params)115       explicit Iterator(const Params& params)
116           : DatasetIterator<Dataset>(params) {}
117 
118       // Computes split points (`keys_`) to use when scanning the table.
119       //
120       // Initialize first retrieves the sample keys from the table (`row_keys`),
121       // as these often form good split points within the table. We then iterate
122       // over them, and copy them to `keys_` if they fall within the requested
123       // range to scan (`dataset()->key_range_`). Because the requested range
124       // might start between elements of the sampled keys list, care is taken to
125       // ensure we don't accidentally miss any subsets of the requested range by
126       // including `begin_key()` and `end_key()` as appropriate.
Initialize(IteratorContext * ctx)127       Status Initialize(IteratorContext* ctx) override {
128         grpc::Status status;
129         std::vector<google::cloud::bigtable::RowKeySample> row_keys =
130             dataset()->table().table().SampleRows(status);
131         if (!status.ok()) {
132           return GrpcStatusToTfStatus(status);
133         }
134 
135         for (size_t i = 0; i < row_keys.size(); ++i) {
136           string row_key(row_keys[i].row_key);
137           if (dataset()->key_range_.contains_key(row_key)) {
138             // First key: check to see if we need to add the begin_key.
139             if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
140               keys_.push_back(dataset()->key_range_.begin_key());
141             }
142             keys_.push_back(std::move(row_key));
143           } else if (!keys_.empty()) {
144             // If !keys_.empty(), then we have found at least one element of
145             // `row_keys` that is within our requested range
146             // (`dataset()->key_range_`). Because `row_keys` is sorted, if we
147             // have found an element that's not within our key range, then we
148             // are after our requested range (ranges are contiguous) and can end
149             // iteration early.
150             break;
151           }
152         }
153 
154         // Handle the case where we skip over the selected range entirely.
155         if (keys_.empty()) {
156           keys_.push_back(dataset()->key_range_.begin_key());
157         }
158 
159         // Last key: check to see if we need to add the end_key.
160         if (keys_.back() != dataset()->key_range_.end_key()) {
161           keys_.push_back(dataset()->key_range_.end_key());
162         }
163         return Status::OK();
164       }
165 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)166       Status GetNextInternal(IteratorContext* ctx,
167                              std::vector<Tensor>* out_tensors,
168                              bool* end_of_sequence) override {
169         mutex_lock l(mu_);
170         if (index_ + 2 > keys_.size()) {
171           *end_of_sequence = true;
172           return Status::OK();
173         }
174 
175         *end_of_sequence = false;
176         out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
177                                   TensorShape({}));
178         out_tensors->back().scalar<string>()() = keys_[index_];
179 
180         out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
181                                   TensorShape({}));
182         out_tensors->back().scalar<string>()() = keys_[index_ + 1];
183         ++index_;
184 
185         return Status::OK();
186       }
187 
188      private:
189       mutex mu_;
190       size_t index_ GUARDED_BY(mu_) = 0;
191       // Note: we store the keys_ on the iterator instead of the dataset
192       // because we want to re-sample the row keys in case there have been
193       // tablet rebalancing operations since the dataset was created.
194       //
195       // Note: keys_ is readonly after Initialize, and thus does not need a
196       // guarding lock.
197       std::vector<string> keys_;
198     };
199 
200     BigtableTableResource* const table_;
201     const MultiModeKeyRange key_range_;
202   };
203 };
204 
205 REGISTER_KERNEL_BUILDER(
206     Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU),
207     BigtableSampleKeyPairsDatasetOp);
208 
209 }  // namespace
210 }  // namespace data
211 }  // namespace tensorflow
212