1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
17 #include "tensorflow/core/framework/op_kernel.h"
18 
19 namespace tensorflow {
20 namespace data {
21 namespace {
22 
23 class BigtableScanDatasetOp : public DatasetOpKernel {
24  public:
25   using DatasetOpKernel::DatasetOpKernel;
26 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)27   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
28     string prefix;
29     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
30     string start_key;
31     OP_REQUIRES_OK(ctx,
32                    ParseScalarArgument<string>(ctx, "start_key", &start_key));
33     string end_key;
34     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
35 
36     OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
37                 errors::InvalidArgument(
38                     "Either prefix or start_key must be specified"));
39     OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
40                 errors::InvalidArgument(
41                     "Only one of prefix and start_key can be provided"));
42     if (!prefix.empty()) {
43       OP_REQUIRES(ctx, end_key.empty(),
44                   errors::InvalidArgument(
45                       "If prefix is specified, end_key must be empty."));
46     }
47 
48     std::vector<string> column_families;
49     std::vector<string> columns;
50     OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
51                                                     &column_families));
52     OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
53     OP_REQUIRES(
54         ctx, column_families.size() == columns.size(),
55         errors::InvalidArgument("len(columns) != len(column_families)"));
56     OP_REQUIRES(ctx, !column_families.empty(),
57                 errors::InvalidArgument("`column_families` is empty"));
58 
59     float probability = 0;
60     OP_REQUIRES_OK(
61         ctx, ParseScalarArgument<float>(ctx, "probability", &probability));
62     OP_REQUIRES(
63         ctx, probability > 0 && probability <= 1,
64         errors::InvalidArgument(
65             "Probability outside the range of (0, 1]. Got: ", probability));
66 
67     BigtableTableResource* resource;
68     OP_REQUIRES_OK(ctx,
69                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
70     core::ScopedUnref scoped_unref(resource);
71 
72     const uint64 num_outputs = columns.size() + 1;
73     std::vector<PartialTensorShape> output_shapes;
74     output_shapes.reserve(num_outputs);
75     DataTypeVector output_types;
76     output_types.reserve(num_outputs);
77     for (uint64 i = 0; i < num_outputs; ++i) {
78       output_shapes.push_back({});
79       output_types.push_back(DT_STRING);
80     }
81 
82     *output = new Dataset(ctx, resource, std::move(prefix),
83                           std::move(start_key), std::move(end_key),
84                           std::move(column_families), std::move(columns),
85                           probability, output_types, std::move(output_shapes));
86   }
87 
88  private:
89   class Dataset : public DatasetBase {
90    public:
Dataset(OpKernelContext * ctx,BigtableTableResource * table,string prefix,string start_key,string end_key,std::vector<string> column_families,std::vector<string> columns,float probability,const DataTypeVector & output_types,std::vector<PartialTensorShape> output_shapes)91     explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
92                      string prefix, string start_key, string end_key,
93                      std::vector<string> column_families,
94                      std::vector<string> columns, float probability,
95                      const DataTypeVector& output_types,
96                      std::vector<PartialTensorShape> output_shapes)
97         : DatasetBase(DatasetContext(ctx)),
98           table_(table),
99           prefix_(std::move(prefix)),
100           start_key_(std::move(start_key)),
101           end_key_(std::move(end_key)),
102           column_families_(std::move(column_families)),
103           columns_(std::move(columns)),
104           column_family_regex_(RegexFromStringSet(column_families_)),
105           column_regex_(RegexFromStringSet(columns_)),
106           probability_(probability),
107           output_types_(output_types),
108           output_shapes_(std::move(output_shapes)) {
109       table_->Ref();
110     }
111 
~Dataset()112     ~Dataset() override { table_->Unref(); }
113 
MakeIteratorInternal(const string & prefix) const114     std::unique_ptr<IteratorBase> MakeIteratorInternal(
115         const string& prefix) const override {
116       return std::unique_ptr<IteratorBase>(
117           new Iterator({this, strings::StrCat(prefix, "::BigtableScan")}));
118     }
119 
output_dtypes() const120     const DataTypeVector& output_dtypes() const override {
121       return output_types_;
122     }
123 
output_shapes() const124     const std::vector<PartialTensorShape>& output_shapes() const override {
125       return output_shapes_;
126     }
127 
DebugString() const128     string DebugString() const override {
129       return "BigtableScanDatasetOp::Dataset";
130     }
131 
table() const132     BigtableTableResource* table() const { return table_; }
133 
134    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const135     Status AsGraphDefInternal(SerializationContext* ctx,
136                               DatasetGraphDefBuilder* b,
137                               Node** output) const override {
138       return errors::Unimplemented("%s does not support serialization",
139                                    DebugString());
140     }
141 
142    private:
143     class Iterator : public BigtableReaderDatasetIterator<Dataset> {
144      public:
Iterator(const Params & params)145       explicit Iterator(const Params& params)
146           : BigtableReaderDatasetIterator<Dataset>(params) {}
147 
MakeRowRange()148       ::google::cloud::bigtable::RowRange MakeRowRange() override {
149         if (!dataset()->prefix_.empty()) {
150           DCHECK(dataset()->start_key_.empty());
151           return ::google::cloud::bigtable::RowRange::Prefix(
152               dataset()->prefix_);
153         } else {
154           DCHECK(!dataset()->start_key_.empty())
155               << "Both prefix and start_key were empty!";
156           return ::google::cloud::bigtable::RowRange::Range(
157               dataset()->start_key_, dataset()->end_key_);
158         }
159       }
MakeFilter()160       ::google::cloud::bigtable::Filter MakeFilter() override {
161         // TODO(saeta): Investigate optimal ordering here.
162         return ::google::cloud::bigtable::Filter::Chain(
163             ::google::cloud::bigtable::Filter::Latest(1),
164             ::google::cloud::bigtable::Filter::FamilyRegex(
165                 dataset()->column_family_regex_),
166             ::google::cloud::bigtable::Filter::ColumnRegex(
167                 dataset()->column_regex_),
168             dataset()->probability_ != 1.0
169                 ? ::google::cloud::bigtable::Filter::RowSample(
170                       dataset()->probability_)
171                 : ::google::cloud::bigtable::Filter::PassAllFilter());
172       }
ParseRow(IteratorContext * ctx,const::google::cloud::bigtable::Row & row,std::vector<Tensor> * out_tensors)173       Status ParseRow(IteratorContext* ctx,
174                       const ::google::cloud::bigtable::Row& row,
175                       std::vector<Tensor>* out_tensors) override {
176         out_tensors->reserve(dataset()->columns_.size() + 1);
177         Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
178         row_key_tensor.scalar<string>()() = string(row.row_key());
179         out_tensors->emplace_back(std::move(row_key_tensor));
180 
181         if (row.cells().size() > 2 * dataset()->columns_.size()) {
182           LOG(WARNING) << "An excessive number of columns ("
183                        << row.cells().size()
184                        << ") were retrieved when reading row: "
185                        << row.row_key();
186         }
187 
188         for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
189           Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
190           bool found_column = false;
191           for (auto cell_itr = row.cells().begin();
192                !found_column && cell_itr != row.cells().end(); ++cell_itr) {
193             if (cell_itr->family_name() == dataset()->column_families_[i] &&
194                 string(cell_itr->column_qualifier()) ==
195                     dataset()->columns_[i]) {
196               col_tensor.scalar<string>()() = string(cell_itr->value());
197               found_column = true;
198             }
199           }
200           if (!found_column) {
201             return errors::InvalidArgument(
202                 "Column ", dataset()->column_families_[i], ":",
203                 dataset()->columns_[i], " not found in row: ", row.row_key());
204           }
205           out_tensors->emplace_back(std::move(col_tensor));
206         }
207         return Status::OK();
208       }
209     };
210 
211     BigtableTableResource* table_;
212     const string prefix_;
213     const string start_key_;
214     const string end_key_;
215     const std::vector<string> column_families_;
216     const std::vector<string> columns_;
217     const string column_family_regex_;
218     const string column_regex_;
219     const float probability_;
220     const DataTypeVector output_types_;
221     const std::vector<PartialTensorShape> output_shapes_;
222   };
223 };
224 
// Register the dataset kernel for CPU.
REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
                        BigtableScanDatasetOp);
227 
228 }  // namespace
229 }  // namespace data
230 }  // namespace tensorflow
231