/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "rdkafkacpp.h"

namespace tensorflow {

class KafkaDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    const Tensor* topics_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor));
    OP_REQUIRES(
        ctx, topics_tensor->dims() <= 1,
        errors::InvalidArgument("`topics` must be a scalar or a vector."));

    std::vector<string> topics;
    topics.reserve(topics_tensor->NumElements());
    for (int i = 0; i < topics_tensor->NumElements(); ++i) {
      topics.push_back(topics_tensor->flat<string>()(i));
    }

    std::string servers = "";
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<std::string>(ctx, "servers", &servers));
    std::string group = "";
    OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "group", &group));
    bool eof = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "eof", &eof));
    int64 timeout = -1;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "timeout", &timeout));
    OP_REQUIRES(ctx, (timeout > 0),
                errors::InvalidArgument(
                    "Timeout value should be greater than 0, got ", timeout));
    *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout);
  }
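
  // Note: each element of `topics` accepted above is parsed later (see
  // SetupStreamsLocked in the iterator below) as
  // "topic[:partition[:offset[:limit]]]"; partition, offset and limit default
  // to 0, 0 and -1 (no limit) respectively. For example, a hypothetical entry
  // "logs:0:500:1000" reads topic "logs", partition 0, starting at offset 500
  // and stopping once offset 1000 is reached.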

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, std::vector<string> topics,
            const string& servers, const string& group, const bool eof,
            const int64 timeout)
        : DatasetBase(DatasetContext(ctx)),
          topics_(std::move(topics)),
          servers_(servers),
          group_(group),
          eof_(eof),
          timeout_(timeout) {}

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() const override { return "KafkaDatasetOp::Dataset"; }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* topics = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics));
      Node* servers = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers));
      Node* group = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(group_, &group));
      Node* eof = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof));
      Node* timeout = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {topics, servers, group, eof, timeout}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        do {
          // We are currently processing a topic, so try to read the next
          // message.
          if (consumer_.get()) {
            while (true) {
              if (limit_ >= 0 &&
                  (topic_partition_->offset() >= limit_ || offset_ >= limit_)) {
                // EOF current topic
                break;
              }
              std::unique_ptr<RdKafka::Message> message(
                  consumer_->consume(dataset()->timeout_));
              if (message->err() == RdKafka::ERR_NO_ERROR) {
                // Produce the message payload as output.
                Tensor line_tensor(cpu_allocator(), DT_STRING, {});
                line_tensor.scalar<string>()() =
                    std::string(static_cast<const char*>(message->payload()),
                                message->len());
                out_tensors->emplace_back(std::move(line_tensor));
                *end_of_sequence = false;
                // Sync offset
                offset_ = message->offset();
                return Status::OK();
              }

              if (message->err() == RdKafka::ERR__PARTITION_EOF &&
                  dataset()->eof_) {
                // EOF current topic
                break;
              }
              if (message->err() != RdKafka::ERR__TIMED_OUT) {
                return errors::Internal("Failed to consume: ",
                                        message->errstr());
              }
              message.reset(nullptr);
              consumer_->poll(0);
            }

            // We have reached the end of the current topic, so maybe
            // move on to the next topic.
            ResetStreamsLocked();
            ++current_topic_index_;
          }

          // Iteration ends when there are no more topics to process.
          if (current_topic_index_ == dataset()->topics_.size()) {
            *end_of_sequence = true;
            return Status::OK();
          }

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        } while (true);
      }
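
      // Checkpointing note: SaveInternal below records the index of the topic
      // entry currently being read under "current_topic_index" and, while a
      // consumer is open, the last consumed offset under "current_pos";
      // RestoreInternal re-creates the consumer for that entry and seeks back
      // to the saved offset.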

     protected:
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"),
                                               current_topic_index_));

        // `consumer_` is empty if
        // 1. GetNext has not been called even once.
        // 2. All topics have been read and iterator has been exhausted.
        if (consumer_.get()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("current_pos"), offset_));
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        ResetStreamsLocked();
        int64 current_topic_index;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"),
                                              &current_topic_index));
        current_topic_index_ = size_t(current_topic_index);
        // The key "current_pos" is written only if the iterator was saved
        // with an open topic.
        if (reader->Contains(full_name("current_pos"))) {
          int64 current_pos;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("current_pos"), &current_pos));

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
          topic_partition_->set_offset(current_pos);
          if (topic_partition_->offset() != current_pos) {
            return errors::Internal("Failed to restore to offset ",
                                    current_pos);
          }
          offset_ = current_pos;
        }
        return Status::OK();
      }

     private:
      // Sets up Kafka streams to read from the topic at
      // `current_topic_index_`.
      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (current_topic_index_ >= dataset()->topics_.size()) {
          return errors::InvalidArgument(
              "current_topic_index_:", current_topic_index_,
              " >= topics_.size():", dataset()->topics_.size());
        }

        // Actually move on to next topic.
        string entry = dataset()->topics_[current_topic_index_];

        std::vector<string> parts = str_util::Split(entry, ":");
        if (parts.size() < 1) {
          return errors::InvalidArgument("Invalid parameters: ", entry);
        }
        string topic = parts[0];
        int32 partition = 0;
        if (parts.size() > 1) {
          if (!strings::safe_strto32(parts[1], &partition)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }
        int64 offset = 0;
        if (parts.size() > 2) {
          if (!strings::safe_strto64(parts[2], &offset)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }

        topic_partition_.reset(
            RdKafka::TopicPartition::create(topic, partition, offset));

        offset_ = topic_partition_->offset();
        limit_ = -1;
        if (parts.size() > 3) {
          if (!strings::safe_strto64(parts[3], &limit_)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }

        std::unique_ptr<RdKafka::Conf> conf(
            RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL));
        std::unique_ptr<RdKafka::Conf> topic_conf(
            RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC));

        std::string errstr;

        RdKafka::Conf::ConfResult result =
            conf->set("default_topic_conf", topic_conf.get(), errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set default_topic_conf:", errstr);
        }

        result = conf->set("bootstrap.servers", dataset()->servers_, errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set bootstrap.servers ",
                                  dataset()->servers_, ":", errstr);
        }
        result = conf->set("group.id", dataset()->group_, errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set group.id ", dataset()->group_,
                                  ":", errstr);
        }

        consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr));
        if (!consumer_.get()) {
          return errors::Internal("Failed to create consumer:", errstr);
        }

        std::vector<RdKafka::TopicPartition*> partitions;
        partitions.emplace_back(topic_partition_.get());
        RdKafka::ErrorCode err = consumer_->assign(partitions);
        if (err != RdKafka::ERR_NO_ERROR) {
          return errors::Internal(
              "Failed to assign partition [", topic_partition_->topic(), ", ",
              topic_partition_->partition(), ", ", topic_partition_->offset(),
              "]:", RdKafka::err2str(err));
        }

        return Status::OK();
      }

      // Resets all Kafka streams.
      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // Guard against the consumer not having been created yet, e.g. when
        // RestoreInternal runs on a freshly created iterator.
        if (consumer_.get()) {
          consumer_->unassign();
          consumer_->close();
          consumer_.reset(nullptr);
        }
      }

      mutex mu_;
      size_t current_topic_index_ GUARDED_BY(mu_) = 0;
      int64 offset_ GUARDED_BY(mu_) = 0;
      int64 limit_ GUARDED_BY(mu_) = -1;
      std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_);
      std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_);
    };

    const std::vector<string> topics_;
    const std::string servers_;
    const std::string group_;
    const bool eof_;
    const int64 timeout_;
  };
};

REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU),
                        KafkaDatasetOp);

}  // namespace tensorflow
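
// For reference, a sketch of the op registration this kernel pairs with. The
// actual REGISTER_OP call lives in a separate ops file and is not part of this
// translation unit; the input names below simply mirror the tensors read in
// MakeDataset above, and the variant scalar output follows the usual
// dataset-op convention, so treat this as an assumption rather than the
// authoritative definition:
//
//   REGISTER_OP("KafkaDataset")
//       .Input("topics: string")
//       .Input("servers: string")
//       .Input("group: string")
//       .Input("eof: bool")
//       .Input("timeout: int64")
//       .Output("handle: variant")
//       .SetIsStateful()
//       .SetShapeFn(shape_inference::ScalarShape);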