1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" 17 #include "tensorflow/core/framework/op_kernel.h" 18 19 namespace tensorflow { 20 namespace data { 21 namespace { 22 23 class BigtableRangeKeyDatasetOp : public DatasetOpKernel { 24 public: 25 using DatasetOpKernel::DatasetOpKernel; 26 MakeDataset(OpKernelContext * ctx,DatasetBase ** output)27 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 28 string start_key; 29 OP_REQUIRES_OK(ctx, 30 ParseScalarArgument<string>(ctx, "start_key", &start_key)); 31 string end_key; 32 OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key)); 33 34 BigtableTableResource* resource; 35 OP_REQUIRES_OK(ctx, 36 LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); 37 core::ScopedUnref scoped_unref(resource); 38 39 *output = 40 new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); 41 } 42 43 private: 44 class Dataset : public DatasetBase { 45 public: Dataset(OpKernelContext * ctx,BigtableTableResource * table,string start_key,string end_key)46 explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, 47 string start_key, string end_key) 48 : DatasetBase(DatasetContext(ctx)), 49 table_(table), 50 start_key_(std::move(start_key)), 51 end_key_(std::move(end_key)) { 52 table_->Ref(); 53 } 54 ~Dataset()55 ~Dataset() override { table_->Unref(); } 56 MakeIteratorInternal(const string & prefix) const57 std::unique_ptr<IteratorBase> MakeIteratorInternal( 58 const string& prefix) const override { 59 return std::unique_ptr<IteratorBase>( 60 new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")})); 61 } 62 output_dtypes() const63 const DataTypeVector& output_dtypes() const override { 64 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 65 return *dtypes; 66 } 67 output_shapes() const68 const std::vector<PartialTensorShape>& output_shapes() const override { 69 static std::vector<PartialTensorShape>* shapes = 70 new std::vector<PartialTensorShape>({{}}); 71 return *shapes; 72 } 73 DebugString() const74 string DebugString() const override { 75 return "BigtableRangeKeyDatasetOp::Dataset"; 76 } 77 table() const78 BigtableTableResource* table() const { return table_; } 79 80 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const81 Status AsGraphDefInternal(SerializationContext* ctx, 82 DatasetGraphDefBuilder* b, 83 Node** output) const override { 84 return errors::Unimplemented("%s does not support serialization", 85 DebugString()); 86 } 87 88 private: 89 class Iterator : public BigtableReaderDatasetIterator<Dataset> { 90 public: Iterator(const Params & params)91 explicit Iterator(const Params& params) 92 : BigtableReaderDatasetIterator<Dataset>(params) {} 93 MakeRowRange()94 ::google::cloud::bigtable::RowRange MakeRowRange() override { 95 return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_, 96 dataset()->end_key_); 97 } MakeFilter()98 ::google::cloud::bigtable::Filter MakeFilter() override { 99 return ::google::cloud::bigtable::Filter::Chain( 100 ::google::cloud::bigtable::Filter::CellsRowLimit(1), 101 ::google::cloud::bigtable::Filter::StripValueTransformer()); 102 } ParseRow(IteratorContext * ctx,const::google::cloud::bigtable::Row & row,std::vector<Tensor> * out_tensors)103 Status ParseRow(IteratorContext* ctx, 104 const ::google::cloud::bigtable::Row& row, 105 std::vector<Tensor>* out_tensors) override { 106 Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); 107 output_tensor.scalar<string>()() = string(row.row_key()); 108 out_tensors->emplace_back(std::move(output_tensor)); 109 return Status::OK(); 110 } 111 }; 112 113 BigtableTableResource* const table_; 114 const string start_key_; 115 const string end_key_; 116 }; 117 }; 118 119 REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), 120 BigtableRangeKeyDatasetOp); 121 } // namespace 122 } // namespace data 123 } // namespace tensorflow 124