1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
17 #include "tensorflow/core/framework/op_kernel.h"
18 
19 namespace tensorflow {
20 namespace data {
21 namespace {
22 
23 class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
24  public:
25   using DatasetOpKernel::DatasetOpKernel;
26 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)27   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
28     string start_key;
29     OP_REQUIRES_OK(ctx,
30                    ParseScalarArgument<string>(ctx, "start_key", &start_key));
31     string end_key;
32     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
33 
34     BigtableTableResource* resource;
35     OP_REQUIRES_OK(ctx,
36                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
37     core::ScopedUnref scoped_unref(resource);
38 
39     *output =
40         new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
41   }
42 
43  private:
44   class Dataset : public DatasetBase {
45    public:
Dataset(OpKernelContext * ctx,BigtableTableResource * table,string start_key,string end_key)46     explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
47                      string start_key, string end_key)
48         : DatasetBase(DatasetContext(ctx)),
49           table_(table),
50           start_key_(std::move(start_key)),
51           end_key_(std::move(end_key)) {
52       table_->Ref();
53     }
54 
~Dataset()55     ~Dataset() override { table_->Unref(); }
56 
MakeIteratorInternal(const string & prefix) const57     std::unique_ptr<IteratorBase> MakeIteratorInternal(
58         const string& prefix) const override {
59       return std::unique_ptr<IteratorBase>(
60           new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")}));
61     }
62 
output_dtypes() const63     const DataTypeVector& output_dtypes() const override {
64       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
65       return *dtypes;
66     }
67 
output_shapes() const68     const std::vector<PartialTensorShape>& output_shapes() const override {
69       static std::vector<PartialTensorShape>* shapes =
70           new std::vector<PartialTensorShape>({{}});
71       return *shapes;
72     }
73 
DebugString() const74     string DebugString() const override {
75       return "BigtableRangeKeyDatasetOp::Dataset";
76     }
77 
table() const78     BigtableTableResource* table() const { return table_; }
79 
80    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const81     Status AsGraphDefInternal(SerializationContext* ctx,
82                               DatasetGraphDefBuilder* b,
83                               Node** output) const override {
84       return errors::Unimplemented("%s does not support serialization",
85                                    DebugString());
86     }
87 
88    private:
89     class Iterator : public BigtableReaderDatasetIterator<Dataset> {
90      public:
Iterator(const Params & params)91       explicit Iterator(const Params& params)
92           : BigtableReaderDatasetIterator<Dataset>(params) {}
93 
MakeRowRange()94       ::google::cloud::bigtable::RowRange MakeRowRange() override {
95         return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_,
96                                                           dataset()->end_key_);
97       }
MakeFilter()98       ::google::cloud::bigtable::Filter MakeFilter() override {
99         return ::google::cloud::bigtable::Filter::Chain(
100             ::google::cloud::bigtable::Filter::CellsRowLimit(1),
101             ::google::cloud::bigtable::Filter::StripValueTransformer());
102       }
ParseRow(IteratorContext * ctx,const::google::cloud::bigtable::Row & row,std::vector<Tensor> * out_tensors)103       Status ParseRow(IteratorContext* ctx,
104                       const ::google::cloud::bigtable::Row& row,
105                       std::vector<Tensor>* out_tensors) override {
106         Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
107         output_tensor.scalar<string>()() = string(row.row_key());
108         out_tensors->emplace_back(std::move(output_tensor));
109         return Status::OK();
110       }
111     };
112 
113     BigtableTableResource* const table_;
114     const string start_key_;
115     const string end_key_;
116   };
117 };
118 
// Registers BigtableRangeKeyDatasetOp as the CPU kernel implementation of the
// "BigtableRangeKeyDataset" op.
REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
                        BigtableRangeKeyDatasetOp);
121 }  // namespace
122 }  // namespace data
123 }  // namespace tensorflow
124