1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
17 #include "tensorflow/core/framework/op_kernel.h"
18 
19 namespace tensorflow {
20 namespace data {
21 namespace {
22 
23 class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
24  public:
25   using DatasetOpKernel::DatasetOpKernel;
26 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)27   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
28     BigtableTableResource* resource;
29     OP_REQUIRES_OK(ctx,
30                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
31     core::ScopedUnref scoped_unref(resource);
32     *output = new Dataset(ctx, resource);
33   }
34 
35  private:
36   class Dataset : public DatasetBase {
37    public:
Dataset(OpKernelContext * ctx,BigtableTableResource * table)38     explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
39         : DatasetBase(DatasetContext(ctx)), table_(table) {
40       table_->Ref();
41     }
42 
~Dataset()43     ~Dataset() override { table_->Unref(); }
44 
MakeIteratorInternal(const string & prefix) const45     std::unique_ptr<IteratorBase> MakeIteratorInternal(
46         const string& prefix) const override {
47       return std::unique_ptr<IteratorBase>(new Iterator(
48           {this, strings::StrCat(prefix, "::BigtableSampleKeys")}));
49     }
50 
output_dtypes() const51     const DataTypeVector& output_dtypes() const override {
52       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
53       return *dtypes;
54     }
55 
output_shapes() const56     const std::vector<PartialTensorShape>& output_shapes() const override {
57       static std::vector<PartialTensorShape>* shapes =
58           new std::vector<PartialTensorShape>({{}});
59       return *shapes;
60     }
61 
DebugString() const62     string DebugString() const override {
63       return "BigtableRangeKeyDatasetOp::Dataset";
64     }
65 
table() const66     BigtableTableResource* table() const { return table_; }
67 
68    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const69     Status AsGraphDefInternal(SerializationContext* ctx,
70                               DatasetGraphDefBuilder* b,
71                               Node** output) const override {
72       return errors::Unimplemented("%s does not support serialization",
73                                    DebugString());
74     }
75 
76    private:
77     class Iterator : public DatasetIterator<Dataset> {
78      public:
Iterator(const Params & params)79       explicit Iterator(const Params& params)
80           : DatasetIterator<Dataset>(params) {}
81 
Initialize(IteratorContext * ctx)82       Status Initialize(IteratorContext* ctx) override {
83         ::grpc::Status status;
84         row_keys_ = dataset()->table()->table().SampleRows(status);
85         if (!status.ok()) {
86           row_keys_.clear();
87           return GrpcStatusToTfStatus(status);
88         }
89         return Status::OK();
90       }
91 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)92       Status GetNextInternal(IteratorContext* ctx,
93                              std::vector<Tensor>* out_tensors,
94                              bool* end_of_sequence) override {
95         mutex_lock l(mu_);
96         if (index_ < row_keys_.size()) {
97           out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
98                                     TensorShape({}));
99           out_tensors->back().scalar<string>()() =
100               string(row_keys_[index_].row_key);
101           *end_of_sequence = false;
102           index_++;
103         } else {
104           *end_of_sequence = true;
105         }
106         return Status::OK();
107       }
108 
109      private:
110       mutex mu_;
111       size_t index_ = 0;
112       std::vector<::google::cloud::bigtable::RowKeySample> row_keys_;
113     };
114 
115     BigtableTableResource* const table_;
116   };
117 };
118 
119 REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
120                         BigtableSampleKeysDatasetOp);
121 
122 }  // namespace
123 }  // namespace data
124 }  // namespace tensorflow
125