1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/lib/core/threadpool.h" 20 21 namespace tensorflow { 22 namespace { 23 24 class BigtableClientOp : public OpKernel { 25 public: BigtableClientOp(OpKernelConstruction * ctx)26 explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 27 OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_)); 28 OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_)); 29 OP_REQUIRES(ctx, !project_id_.empty(), 30 errors::InvalidArgument("project_id must be non-empty")); 31 OP_REQUIRES(ctx, !instance_id_.empty(), 32 errors::InvalidArgument("instance_id must be non-empty")); 33 34 OP_REQUIRES_OK( 35 ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_)); 36 // If left unset by the client code, set it to a default of 100. Note: the 37 // cloud-cpp default of 4 concurrent connections is far too low for high 38 // performance streaming. 39 if (connection_pool_size_ == -1) { 40 connection_pool_size_ = 100; 41 } 42 43 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size", 44 &max_receive_message_size_)); 45 // If left unset by the client code, set it to a default of 100. Note: the 46 // cloud-cpp default of 4 concurrent connections is far too low for high 47 // performance streaming. 48 if (max_receive_message_size_ == -1) { 49 max_receive_message_size_ = 1 << 24; // 16 MBytes 50 } 51 OP_REQUIRES(ctx, max_receive_message_size_ > 0, 52 errors::InvalidArgument("connection_pool_size must be > 0")); 53 } 54 ~BigtableClientOp()55 ~BigtableClientOp() override { 56 if (cinfo_.resource_is_private_to_kernel()) { 57 if (!cinfo_.resource_manager() 58 ->Delete<BigtableClientResource>(cinfo_.container(), 59 cinfo_.name()) 60 .ok()) { 61 // Do nothing; the resource can have been deleted by session resets. 62 } 63 } 64 } 65 Compute(OpKernelContext * ctx)66 void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { 67 mutex_lock l(mu_); 68 if (!initialized_) { 69 ResourceMgr* mgr = ctx->resource_manager(); 70 OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); 71 BigtableClientResource* resource; 72 OP_REQUIRES_OK( 73 ctx, 74 mgr->LookupOrCreate<BigtableClientResource>( 75 cinfo_.container(), cinfo_.name(), &resource, 76 [this, ctx]( 77 BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 78 auto client_options = 79 google::cloud::bigtable::ClientOptions() 80 .set_connection_pool_size(connection_pool_size_) 81 .set_data_endpoint("batch-bigtable.googleapis.com"); 82 auto channel_args = client_options.channel_arguments(); 83 channel_args.SetMaxReceiveMessageSize( 84 max_receive_message_size_); 85 channel_args.SetUserAgentPrefix("tensorflow"); 86 channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); 87 channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000); 88 client_options.set_channel_arguments(channel_args); 89 std::shared_ptr<google::cloud::bigtable::DataClient> client = 90 google::cloud::bigtable::CreateDefaultDataClient( 91 project_id_, instance_id_, std::move(client_options)); 92 *ret = new BigtableClientResource(project_id_, instance_id_, 93 std::move(client)); 94 return Status::OK(); 95 })); 96 core::ScopedUnref resource_cleanup(resource); 97 initialized_ = true; 98 } 99 OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( 100 ctx, 0, cinfo_.container(), cinfo_.name(), 101 MakeTypeIndex<BigtableClientResource>())); 102 } 103 104 private: 105 string project_id_; 106 string instance_id_; 107 int64 connection_pool_size_; 108 int32 max_receive_message_size_; 109 110 mutex mu_; 111 ContainerInfo cinfo_ GUARDED_BY(mu_); 112 bool initialized_ GUARDED_BY(mu_) = false; 113 }; 114 115 REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU), 116 BigtableClientOp); 117 118 class BigtableTableOp : public OpKernel { 119 public: BigtableTableOp(OpKernelConstruction * ctx)120 explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 121 OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_)); 122 OP_REQUIRES(ctx, !table_.empty(), 123 errors::InvalidArgument("table_name must be non-empty")); 124 } 125 ~BigtableTableOp()126 ~BigtableTableOp() override { 127 if (cinfo_.resource_is_private_to_kernel()) { 128 if (!cinfo_.resource_manager() 129 ->Delete<BigtableTableResource>(cinfo_.container(), 130 cinfo_.name()) 131 .ok()) { 132 // Do nothing; the resource can have been deleted by session resets. 133 } 134 } 135 } 136 Compute(OpKernelContext * ctx)137 void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { 138 mutex_lock l(mu_); 139 if (!initialized_) { 140 ResourceMgr* mgr = ctx->resource_manager(); 141 OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); 142 143 BigtableClientResource* client_resource; 144 OP_REQUIRES_OK( 145 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); 146 core::ScopedUnref unref_client(client_resource); 147 148 BigtableTableResource* resource; 149 OP_REQUIRES_OK( 150 ctx, mgr->LookupOrCreate<BigtableTableResource>( 151 cinfo_.container(), cinfo_.name(), &resource, 152 [this, client_resource](BigtableTableResource** ret) { 153 *ret = new BigtableTableResource(client_resource, table_); 154 return Status::OK(); 155 })); 156 initialized_ = true; 157 } 158 OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( 159 ctx, 0, cinfo_.container(), cinfo_.name(), 160 MakeTypeIndex<BigtableTableResource>())); 161 } 162 163 private: 164 string table_; // Note: this is const after construction. 165 166 mutex mu_; 167 ContainerInfo cinfo_ GUARDED_BY(mu_); 168 bool initialized_ GUARDED_BY(mu_) = false; 169 }; 170 171 REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), 172 BigtableTableOp); 173 174 } // namespace 175 176 namespace data { 177 namespace { 178 179 class ToBigtableOp : public AsyncOpKernel { 180 public: ToBigtableOp(OpKernelConstruction * ctx)181 explicit ToBigtableOp(OpKernelConstruction* ctx) 182 : AsyncOpKernel(ctx), 183 thread_pool_(new thread::ThreadPool( 184 ctx->env(), ThreadOptions(), 185 strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())), 186 /* num_threads = */ 1, /* low_latency_hint = */ false)) {} 187 ComputeAsync(OpKernelContext * ctx,DoneCallback done)188 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 189 // The call to `iterator->GetNext()` may block and depend on an 190 // inter-op thread pool thread, so we issue the call from the 191 // owned thread pool. 192 thread_pool_->Schedule([this, ctx, done]() { 193 const Tensor* column_families_tensor; 194 OP_REQUIRES_OK_ASYNC( 195 ctx, ctx->input("column_families", &column_families_tensor), done); 196 OP_REQUIRES_ASYNC( 197 ctx, column_families_tensor->dims() == 1, 198 errors::InvalidArgument("`column_families` must be a vector."), done); 199 200 const Tensor* columns_tensor; 201 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done); 202 OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1, 203 errors::InvalidArgument("`columns` must be a vector."), 204 done); 205 OP_REQUIRES_ASYNC( 206 ctx, 207 columns_tensor->NumElements() == 208 column_families_tensor->NumElements(), 209 errors::InvalidArgument("len(column_families) != len(columns)"), 210 done); 211 212 std::vector<string> column_families; 213 column_families.reserve(column_families_tensor->NumElements()); 214 std::vector<string> columns; 215 columns.reserve(column_families_tensor->NumElements()); 216 for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { 217 column_families.push_back(column_families_tensor->flat<string>()(i)); 218 columns.push_back(columns_tensor->flat<string>()(i)); 219 } 220 221 DatasetBase* dataset; 222 OP_REQUIRES_OK_ASYNC( 223 ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); 224 225 std::unique_ptr<IteratorBase> iterator; 226 OP_REQUIRES_OK_ASYNC( 227 ctx, 228 dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator", 229 &iterator), 230 done); 231 232 int64 timestamp_int; 233 OP_REQUIRES_OK_ASYNC( 234 ctx, ParseScalarArgument<int64>(ctx, "timestamp", ×tamp_int), 235 done); 236 OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1, 237 errors::InvalidArgument("timestamp must be >= -1"), 238 done); 239 240 BigtableTableResource* resource; 241 OP_REQUIRES_OK_ASYNC( 242 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); 243 core::ScopedUnref resource_cleanup(resource); 244 245 std::vector<Tensor> components; 246 components.reserve(dataset->output_dtypes().size()); 247 bool end_of_sequence = false; 248 do { 249 ::google::cloud::bigtable::BulkMutation mutation; 250 // TODO(saeta): Make # of mutations configurable. 251 for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { 252 OP_REQUIRES_OK_ASYNC(ctx, 253 iterator->GetNext(IteratorContext(ctx), 254 &components, &end_of_sequence), 255 done); 256 if (!end_of_sequence) { 257 OP_REQUIRES_OK_ASYNC( 258 ctx, 259 CreateMutation(std::move(components), column_families, columns, 260 timestamp_int, &mutation), 261 done); 262 } 263 components.clear(); 264 } 265 grpc::Status mutation_status; 266 std::vector<::google::cloud::bigtable::FailedMutation> failures = 267 resource->table().BulkApply(std::move(mutation), mutation_status); 268 if (!mutation_status.ok()) { 269 LOG(ERROR) << "Failure applying mutation: " 270 << mutation_status.error_code() << " - " 271 << mutation_status.error_message() << " (" 272 << mutation_status.error_details() << ")."; 273 } 274 if (!failures.empty()) { 275 for (const auto& failure : failures) { 276 LOG(ERROR) << "Failure applying mutation on row (" 277 << failure.original_index() 278 << "): " << failure.mutation().row_key() 279 << " - error: " << failure.status().message() << "."; 280 } 281 } 282 OP_REQUIRES_ASYNC( 283 ctx, failures.empty() && mutation_status.ok(), 284 errors::Unknown("Failure while writing to Cloud Bigtable: ", 285 mutation_status.error_code(), " - ", 286 mutation_status.error_message(), " (", 287 mutation_status.error_details(), 288 "), # of mutation failures: ", failures.size(), 289 ". See the log for the specific error details."), 290 done); 291 } while (!end_of_sequence); 292 done(); 293 }); 294 } 295 296 private: SanitizeThreadSuffix(string suffix)297 static string SanitizeThreadSuffix(string suffix) { 298 string clean; 299 for (int i = 0; i < suffix.size(); ++i) { 300 const char ch = suffix[i]; 301 if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || 302 (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { 303 clean += ch; 304 } else { 305 clean += '_'; 306 } 307 } 308 return clean; 309 } 310 CreateMutation(std::vector<Tensor> tensors,const std::vector<string> & column_families,const std::vector<string> & columns,int64 timestamp_int,::google::cloud::bigtable::BulkMutation * bulk_mutation)311 Status CreateMutation( 312 std::vector<Tensor> tensors, const std::vector<string>& column_families, 313 const std::vector<string>& columns, int64 timestamp_int, 314 ::google::cloud::bigtable::BulkMutation* bulk_mutation) { 315 if (tensors.size() != column_families.size() + 1) { 316 return errors::InvalidArgument( 317 "Iterator produced a set of Tensors shorter than expected"); 318 } 319 ::google::cloud::bigtable::SingleRowMutation mutation( 320 std::move(tensors[0].scalar<string>()())); 321 std::chrono::milliseconds timestamp(timestamp_int); 322 for (size_t i = 1; i < tensors.size(); ++i) { 323 if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { 324 return errors::Internal("Output tensor ", i, " was not a scalar"); 325 } 326 if (timestamp_int == -1) { 327 mutation.emplace_back(::google::cloud::bigtable::SetCell( 328 column_families[i - 1], columns[i - 1], 329 std::move(tensors[i].scalar<string>()()))); 330 } else { 331 mutation.emplace_back(::google::cloud::bigtable::SetCell( 332 column_families[i - 1], columns[i - 1], timestamp, 333 std::move(tensors[i].scalar<string>()()))); 334 } 335 } 336 bulk_mutation->emplace_back(std::move(mutation)); 337 return Status::OK(); 338 } 339 340 template <typename T> ParseScalarArgument(OpKernelContext * ctx,StringPiece argument_name,T * output)341 Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, 342 T* output) { 343 const Tensor* argument_t; 344 TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); 345 if (!TensorShapeUtils::IsScalar(argument_t->shape())) { 346 return errors::InvalidArgument(argument_name, " must be a scalar"); 347 } 348 *output = argument_t->scalar<T>()(); 349 return Status::OK(); 350 } 351 352 std::unique_ptr<thread::ThreadPool> thread_pool_; 353 }; 354 355 REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), 356 ToBigtableOp); 357 358 } // namespace 359 } // namespace data 360 } // namespace tensorflow 361