1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" 17 #include "tensorflow/core/framework/op_kernel.h" 18 19 namespace tensorflow { 20 namespace data { 21 namespace { 22 23 class BigtableSampleKeysDatasetOp : public DatasetOpKernel { 24 public: 25 using DatasetOpKernel::DatasetOpKernel; 26 MakeDataset(OpKernelContext * ctx,DatasetBase ** output)27 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 28 BigtableTableResource* resource; 29 OP_REQUIRES_OK(ctx, 30 LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); 31 core::ScopedUnref scoped_unref(resource); 32 *output = new Dataset(ctx, resource); 33 } 34 35 private: 36 class Dataset : public DatasetBase { 37 public: Dataset(OpKernelContext * ctx,BigtableTableResource * table)38 explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table) 39 : DatasetBase(DatasetContext(ctx)), table_(table) { 40 table_->Ref(); 41 } 42 ~Dataset()43 ~Dataset() override { table_->Unref(); } 44 MakeIteratorInternal(const string & prefix) const45 std::unique_ptr<IteratorBase> MakeIteratorInternal( 46 const string& prefix) const override { 47 return std::unique_ptr<IteratorBase>(new Iterator( 48 {this, strings::StrCat(prefix, "::BigtableSampleKeys")})); 49 } 50 output_dtypes() const51 const DataTypeVector& output_dtypes() const override { 52 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 53 return *dtypes; 54 } 55 output_shapes() const56 const std::vector<PartialTensorShape>& output_shapes() const override { 57 static std::vector<PartialTensorShape>* shapes = 58 new std::vector<PartialTensorShape>({{}}); 59 return *shapes; 60 } 61 DebugString() const62 string DebugString() const override { 63 return "BigtableRangeKeyDatasetOp::Dataset"; 64 } 65 table() const66 BigtableTableResource* table() const { return table_; } 67 68 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const69 Status AsGraphDefInternal(SerializationContext* ctx, 70 DatasetGraphDefBuilder* b, 71 Node** output) const override { 72 return errors::Unimplemented("%s does not support serialization", 73 DebugString()); 74 } 75 76 private: 77 class Iterator : public DatasetIterator<Dataset> { 78 public: Iterator(const Params & params)79 explicit Iterator(const Params& params) 80 : DatasetIterator<Dataset>(params) {} 81 Initialize(IteratorContext * ctx)82 Status Initialize(IteratorContext* ctx) override { 83 ::grpc::Status status; 84 row_keys_ = dataset()->table()->table().SampleRows(status); 85 if (!status.ok()) { 86 row_keys_.clear(); 87 return GrpcStatusToTfStatus(status); 88 } 89 return Status::OK(); 90 } 91 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)92 Status GetNextInternal(IteratorContext* ctx, 93 std::vector<Tensor>* out_tensors, 94 bool* end_of_sequence) override { 95 mutex_lock l(mu_); 96 if (index_ < row_keys_.size()) { 97 out_tensors->emplace_back(ctx->allocator({}), DT_STRING, 98 TensorShape({})); 99 out_tensors->back().scalar<string>()() = 100 string(row_keys_[index_].row_key); 101 *end_of_sequence = false; 102 index_++; 103 } else { 104 *end_of_sequence = true; 105 } 106 return Status::OK(); 107 } 108 109 private: 110 mutex mu_; 111 size_t index_ = 0; 112 std::vector<::google::cloud::bigtable::RowKeySample> row_keys_; 113 }; 114 115 BigtableTableResource* const table_; 116 }; 117 }; 118 119 REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU), 120 BigtableSampleKeysDatasetOp); 121 122 } // namespace 123 } // namespace data 124 } // namespace tensorflow 125