1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
17 
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/threadpool.h"
20 
21 namespace tensorflow {
22 namespace {
23 
24 class BigtableClientOp : public OpKernel {
25  public:
BigtableClientOp(OpKernelConstruction * ctx)26   explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
27     OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
28     OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
29     OP_REQUIRES(ctx, !project_id_.empty(),
30                 errors::InvalidArgument("project_id must be non-empty"));
31     OP_REQUIRES(ctx, !instance_id_.empty(),
32                 errors::InvalidArgument("instance_id must be non-empty"));
33 
34     OP_REQUIRES_OK(
35         ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_));
36     // If left unset by the client code, set it to a default of 100. Note: the
37     // cloud-cpp default of 4 concurrent connections is far too low for high
38     // performance streaming.
39     if (connection_pool_size_ == -1) {
40       connection_pool_size_ = 100;
41     }
42 
43     OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size",
44                                      &max_receive_message_size_));
45     // If left unset by the client code, set it to a default of 100. Note: the
46     // cloud-cpp default of 4 concurrent connections is far too low for high
47     // performance streaming.
48     if (max_receive_message_size_ == -1) {
49       max_receive_message_size_ = 1 << 24;  // 16 MBytes
50     }
51     OP_REQUIRES(ctx, max_receive_message_size_ > 0,
52                 errors::InvalidArgument("connection_pool_size must be > 0"));
53   }
54 
~BigtableClientOp()55   ~BigtableClientOp() override {
56     if (cinfo_.resource_is_private_to_kernel()) {
57       if (!cinfo_.resource_manager()
58                ->Delete<BigtableClientResource>(cinfo_.container(),
59                                                 cinfo_.name())
60                .ok()) {
61         // Do nothing; the resource can have been deleted by session resets.
62       }
63     }
64   }
65 
Compute(OpKernelContext * ctx)66   void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
67     mutex_lock l(mu_);
68     if (!initialized_) {
69       ResourceMgr* mgr = ctx->resource_manager();
70       OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
71       BigtableClientResource* resource;
72       OP_REQUIRES_OK(
73           ctx,
74           mgr->LookupOrCreate<BigtableClientResource>(
75               cinfo_.container(), cinfo_.name(), &resource,
76               [this, ctx](
77                   BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
78                 auto client_options =
79                     google::cloud::bigtable::ClientOptions()
80                         .set_connection_pool_size(connection_pool_size_)
81                         .set_data_endpoint("batch-bigtable.googleapis.com");
82                 auto channel_args = client_options.channel_arguments();
83                 channel_args.SetMaxReceiveMessageSize(
84                     max_receive_message_size_);
85                 channel_args.SetUserAgentPrefix("tensorflow");
86                 channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0);
87                 channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000);
88                 client_options.set_channel_arguments(channel_args);
89                 std::shared_ptr<google::cloud::bigtable::DataClient> client =
90                     google::cloud::bigtable::CreateDefaultDataClient(
91                         project_id_, instance_id_, std::move(client_options));
92                 *ret = new BigtableClientResource(project_id_, instance_id_,
93                                                   std::move(client));
94                 return Status::OK();
95               }));
96       core::ScopedUnref resource_cleanup(resource);
97       initialized_ = true;
98     }
99     OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
100                             ctx, 0, cinfo_.container(), cinfo_.name(),
101                             MakeTypeIndex<BigtableClientResource>()));
102   }
103 
104  private:
105   string project_id_;
106   string instance_id_;
107   int64 connection_pool_size_;
108   int32 max_receive_message_size_;
109 
110   mutex mu_;
111   ContainerInfo cinfo_ GUARDED_BY(mu_);
112   bool initialized_ GUARDED_BY(mu_) = false;
113 };
114 
115 REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU),
116                         BigtableClientOp);
117 
118 class BigtableTableOp : public OpKernel {
119  public:
BigtableTableOp(OpKernelConstruction * ctx)120   explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
121     OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_));
122     OP_REQUIRES(ctx, !table_.empty(),
123                 errors::InvalidArgument("table_name must be non-empty"));
124   }
125 
~BigtableTableOp()126   ~BigtableTableOp() override {
127     if (cinfo_.resource_is_private_to_kernel()) {
128       if (!cinfo_.resource_manager()
129                ->Delete<BigtableTableResource>(cinfo_.container(),
130                                                cinfo_.name())
131                .ok()) {
132         // Do nothing; the resource can have been deleted by session resets.
133       }
134     }
135   }
136 
Compute(OpKernelContext * ctx)137   void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
138     mutex_lock l(mu_);
139     if (!initialized_) {
140       ResourceMgr* mgr = ctx->resource_manager();
141       OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
142 
143       BigtableClientResource* client_resource;
144       OP_REQUIRES_OK(
145           ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
146       core::ScopedUnref unref_client(client_resource);
147 
148       BigtableTableResource* resource;
149       OP_REQUIRES_OK(
150           ctx, mgr->LookupOrCreate<BigtableTableResource>(
151                    cinfo_.container(), cinfo_.name(), &resource,
152                    [this, client_resource](BigtableTableResource** ret) {
153                      *ret = new BigtableTableResource(client_resource, table_);
154                      return Status::OK();
155                    }));
156       initialized_ = true;
157     }
158     OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
159                             ctx, 0, cinfo_.container(), cinfo_.name(),
160                             MakeTypeIndex<BigtableTableResource>()));
161   }
162 
163  private:
164   string table_;  // Note: this is const after construction.
165 
166   mutex mu_;
167   ContainerInfo cinfo_ GUARDED_BY(mu_);
168   bool initialized_ GUARDED_BY(mu_) = false;
169 };
170 
171 REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
172                         BigtableTableOp);
173 
174 }  // namespace
175 
176 namespace data {
177 namespace {
178 
179 class ToBigtableOp : public AsyncOpKernel {
180  public:
ToBigtableOp(OpKernelConstruction * ctx)181   explicit ToBigtableOp(OpKernelConstruction* ctx)
182       : AsyncOpKernel(ctx),
183         thread_pool_(new thread::ThreadPool(
184             ctx->env(), ThreadOptions(),
185             strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())),
186             /* num_threads = */ 1, /* low_latency_hint = */ false)) {}
187 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)188   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
189     // The call to `iterator->GetNext()` may block and depend on an
190     // inter-op thread pool thread, so we issue the call from the
191     // owned thread pool.
192     thread_pool_->Schedule([this, ctx, done]() {
193       const Tensor* column_families_tensor;
194       OP_REQUIRES_OK_ASYNC(
195           ctx, ctx->input("column_families", &column_families_tensor), done);
196       OP_REQUIRES_ASYNC(
197           ctx, column_families_tensor->dims() == 1,
198           errors::InvalidArgument("`column_families` must be a vector."), done);
199 
200       const Tensor* columns_tensor;
201       OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done);
202       OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1,
203                         errors::InvalidArgument("`columns` must be a vector."),
204                         done);
205       OP_REQUIRES_ASYNC(
206           ctx,
207           columns_tensor->NumElements() ==
208               column_families_tensor->NumElements(),
209           errors::InvalidArgument("len(column_families) != len(columns)"),
210           done);
211 
212       std::vector<string> column_families;
213       column_families.reserve(column_families_tensor->NumElements());
214       std::vector<string> columns;
215       columns.reserve(column_families_tensor->NumElements());
216       for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) {
217         column_families.push_back(column_families_tensor->flat<string>()(i));
218         columns.push_back(columns_tensor->flat<string>()(i));
219       }
220 
221       DatasetBase* dataset;
222       OP_REQUIRES_OK_ASYNC(
223           ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
224 
225       std::unique_ptr<IteratorBase> iterator;
226       OP_REQUIRES_OK_ASYNC(
227           ctx,
228           dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator",
229                                 &iterator),
230           done);
231 
232       int64 timestamp_int;
233       OP_REQUIRES_OK_ASYNC(
234           ctx, ParseScalarArgument<int64>(ctx, "timestamp", &timestamp_int),
235           done);
236       OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1,
237                         errors::InvalidArgument("timestamp must be >= -1"),
238                         done);
239 
240       BigtableTableResource* resource;
241       OP_REQUIRES_OK_ASYNC(
242           ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
243       core::ScopedUnref resource_cleanup(resource);
244 
245       std::vector<Tensor> components;
246       components.reserve(dataset->output_dtypes().size());
247       bool end_of_sequence = false;
248       do {
249         ::google::cloud::bigtable::BulkMutation mutation;
250         // TODO(saeta): Make # of mutations configurable.
251         for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
252           OP_REQUIRES_OK_ASYNC(ctx,
253                                iterator->GetNext(IteratorContext(ctx),
254                                                  &components, &end_of_sequence),
255                                done);
256           if (!end_of_sequence) {
257             OP_REQUIRES_OK_ASYNC(
258                 ctx,
259                 CreateMutation(std::move(components), column_families, columns,
260                                timestamp_int, &mutation),
261                 done);
262           }
263           components.clear();
264         }
265         grpc::Status mutation_status;
266         std::vector<::google::cloud::bigtable::FailedMutation> failures =
267             resource->table().BulkApply(std::move(mutation), mutation_status);
268         if (!mutation_status.ok()) {
269           LOG(ERROR) << "Failure applying mutation: "
270                      << mutation_status.error_code() << " - "
271                      << mutation_status.error_message() << " ("
272                      << mutation_status.error_details() << ").";
273         }
274         if (!failures.empty()) {
275           for (const auto& failure : failures) {
276             LOG(ERROR) << "Failure applying mutation on row ("
277                        << failure.original_index()
278                        << "): " << failure.mutation().row_key()
279                        << " - error: " << failure.status().message() << ".";
280           }
281         }
282         OP_REQUIRES_ASYNC(
283             ctx, failures.empty() && mutation_status.ok(),
284             errors::Unknown("Failure while writing to Cloud Bigtable: ",
285                             mutation_status.error_code(), " - ",
286                             mutation_status.error_message(), " (",
287                             mutation_status.error_details(),
288                             "), # of mutation failures: ", failures.size(),
289                             ". See the log for the specific error details."),
290             done);
291       } while (!end_of_sequence);
292       done();
293     });
294   }
295 
296  private:
SanitizeThreadSuffix(string suffix)297   static string SanitizeThreadSuffix(string suffix) {
298     string clean;
299     for (int i = 0; i < suffix.size(); ++i) {
300       const char ch = suffix[i];
301       if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
302           (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
303         clean += ch;
304       } else {
305         clean += '_';
306       }
307     }
308     return clean;
309   }
310 
CreateMutation(std::vector<Tensor> tensors,const std::vector<string> & column_families,const std::vector<string> & columns,int64 timestamp_int,::google::cloud::bigtable::BulkMutation * bulk_mutation)311   Status CreateMutation(
312       std::vector<Tensor> tensors, const std::vector<string>& column_families,
313       const std::vector<string>& columns, int64 timestamp_int,
314       ::google::cloud::bigtable::BulkMutation* bulk_mutation) {
315     if (tensors.size() != column_families.size() + 1) {
316       return errors::InvalidArgument(
317           "Iterator produced a set of Tensors shorter than expected");
318     }
319     ::google::cloud::bigtable::SingleRowMutation mutation(
320         std::move(tensors[0].scalar<string>()()));
321     std::chrono::milliseconds timestamp(timestamp_int);
322     for (size_t i = 1; i < tensors.size(); ++i) {
323       if (!TensorShapeUtils::IsScalar(tensors[i].shape())) {
324         return errors::Internal("Output tensor ", i, " was not a scalar");
325       }
326       if (timestamp_int == -1) {
327         mutation.emplace_back(::google::cloud::bigtable::SetCell(
328             column_families[i - 1], columns[i - 1],
329             std::move(tensors[i].scalar<string>()())));
330       } else {
331         mutation.emplace_back(::google::cloud::bigtable::SetCell(
332             column_families[i - 1], columns[i - 1], timestamp,
333             std::move(tensors[i].scalar<string>()())));
334       }
335     }
336     bulk_mutation->emplace_back(std::move(mutation));
337     return Status::OK();
338   }
339 
340   template <typename T>
ParseScalarArgument(OpKernelContext * ctx,StringPiece argument_name,T * output)341   Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
342                              T* output) {
343     const Tensor* argument_t;
344     TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
345     if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
346       return errors::InvalidArgument(argument_name, " must be a scalar");
347     }
348     *output = argument_t->scalar<T>()();
349     return Status::OK();
350   }
351 
352   std::unique_ptr<thread::ThreadPool> thread_pool_;
353 };
354 
355 REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
356                         ToBigtableOp);
357 
358 }  // namespace
359 }  // namespace data
360 }  // namespace tensorflow
361