/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

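// KafkaDatasetOp exposes a Kafka consumer as a tf.data dataset that yields
// one DT_STRING scalar per consumed message. Each element of the `topics`
// input is a specification string, "topic:partition:offset:limit", where
// everything after the topic name is optional (see SetupStreamsLocked()).
// A rough Python-side usage sketch, assuming a wrapper named `KafkaDataset`
// (the wrapper name is illustrative, not defined in this file):
//
//   dataset = KafkaDataset(topics=["events:0:0:1000"],
//                          servers="localhost:9092",
//                          group="test-group", eof=True, timeout=1000)
//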
#include "tensorflow/core/framework/dataset.h"

#include "rdkafkacpp.h"

namespace tensorflow {

class KafkaDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    const Tensor* topics_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor));
    OP_REQUIRES(
        ctx, topics_tensor->dims() <= 1,
        errors::InvalidArgument("`topics` must be a scalar or a vector."));

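    // Each `topics` element is a "topic:partition:offset:limit" specification
    // string; see SetupStreamsLocked() in the iterator below for parsing.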
    std::vector<string> topics;
    topics.reserve(topics_tensor->NumElements());
    for (int i = 0; i < topics_tensor->NumElements(); ++i) {
      topics.push_back(topics_tensor->flat<string>()(i));
    }

    std::string servers = "";
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<std::string>(ctx, "servers", &servers));
    std::string group = "";
    OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "group", &group));
    bool eof = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "eof", &eof));
    int64 timeout = -1;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "timeout", &timeout));
    OP_REQUIRES(ctx, (timeout > 0),
                errors::InvalidArgument(
                    "Timeout value should be greater than 0, got ", timeout));
    *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, std::vector<string> topics,
            const string& servers, const string& group, const bool eof,
            const int64 timeout)
        : DatasetBase(DatasetContext(ctx)),
          topics_(std::move(topics)),
          servers_(servers),
          group_(group),
          eof_(eof),
          timeout_(timeout) {}

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() const override { return "KafkaDatasetOp::Dataset"; }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* topics = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics));
      Node* servers = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers));
      Node* group = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(group_, &group));
      Node* eof = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof));
      Node* timeout = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {topics, servers, group, eof, timeout}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

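      // Produces one Kafka message per call as a DT_STRING scalar tensor.
      // When the current topic hits its limit (or partition EOF, if `eof_`
      // is set), advances to the next entry in `topics_`; signals
      // `end_of_sequence` once every entry has been exhausted.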
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        do {
          // We are currently processing a topic, so try to read the next
          // message.
          if (consumer_.get()) {
            while (true) {
              if (limit_ >= 0 &&
                  (topic_partition_->offset() >= limit_ || offset_ >= limit_)) {
                // Reached the limit for the current topic.
                break;
              }
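              // Poll for a single message, waiting at most `timeout_` ms.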
              std::unique_ptr<RdKafka::Message> message(
                  consumer_->consume(dataset()->timeout_));
              if (message->err() == RdKafka::ERR_NO_ERROR) {
                // Produce the message as output.
                Tensor line_tensor(cpu_allocator(), DT_STRING, {});
                line_tensor.scalar<string>()() =
                    std::string(static_cast<const char*>(message->payload()),
                                message->len());
                out_tensors->emplace_back(std::move(line_tensor));
                *end_of_sequence = false;
                // Sync the consumed offset so Save/Restore can resume here.
                offset_ = message->offset();
                return Status::OK();
              }

              if (message->err() == RdKafka::ERR__PARTITION_EOF &&
                  dataset()->eof_) {
                // Partition EOF for the current topic.
                break;
              }
              if (message->err() != RdKafka::ERR__TIMED_OUT) {
                return errors::Internal("Failed to consume: ",
                                        message->errstr());
              }
              message.reset(nullptr);
              consumer_->poll(0);
            }

            // We have reached the end of the current topic, so maybe
            // move on to the next topic.
            ResetStreamsLocked();
            ++current_topic_index_;
          }

          // Iteration ends when there are no more topics to process.
          if (current_topic_index_ == dataset()->topics_.size()) {
            *end_of_sequence = true;
            return Status::OK();
          }

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        } while (true);
      }

     protected:
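      // Checkpoints the topic index and, when a topic is open, the offset of
      // the last consumed message.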
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"),
                                               current_topic_index_));

        // `consumer_` is empty if
        // 1. GetNext has not been called even once.
        // 2. All topics have been read and the iterator has been exhausted.
        if (consumer_.get()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("current_pos"), offset_));
        }
        return Status::OK();
      }

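      // Restores the topic index and, if a position was checkpointed,
      // re-opens that topic and seeks back to the saved offset.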
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        ResetStreamsLocked();
        int64 current_topic_index;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"),
                                              &current_topic_index));
        current_topic_index_ = size_t(current_topic_index);
        // The key "current_pos" is written only if the iterator was saved
        // with an open topic.
        if (reader->Contains(full_name("current_pos"))) {
          int64 current_pos;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("current_pos"), &current_pos));

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
          topic_partition_->set_offset(current_pos);
          if (topic_partition_->offset() != current_pos) {
            return errors::Internal("Failed to restore to offset ",
                                    current_pos);
          }
          offset_ = current_pos;
        }
        return Status::OK();
      }

     private:
      // Sets up Kafka streams to read from the topic at
      // `current_topic_index_`.
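      // A hypothetical entry "events:0:100:200" reads partition 0 of topic
      // "events" starting at offset 100, treating offset 200 as the limit;
      // partition, offset, and limit default to 0, 0, and -1 (no limit).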
      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (current_topic_index_ >= dataset()->topics_.size()) {
          return errors::InvalidArgument(
              "current_topic_index_:", current_topic_index_,
              " >= topics_.size():", dataset()->topics_.size());
        }

        // Actually move on to the next topic.
        string entry = dataset()->topics_[current_topic_index_];

        std::vector<string> parts = str_util::Split(entry, ":");
        if (parts.size() < 1) {
          return errors::InvalidArgument("Invalid parameters: ", entry);
        }
        string topic = parts[0];
        int32 partition = 0;
        if (parts.size() > 1) {
          if (!strings::safe_strto32(parts[1], &partition)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }
        int64 offset = 0;
        if (parts.size() > 2) {
          if (!strings::safe_strto64(parts[2], &offset)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }

        topic_partition_.reset(
            RdKafka::TopicPartition::create(topic, partition, offset));

        offset_ = topic_partition_->offset();
        limit_ = -1;
        if (parts.size() > 3) {
          if (!strings::safe_strto64(parts[3], &limit_)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }

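        // Build the librdkafka configuration (topic defaults, bootstrap
        // brokers, consumer group), then create the consumer from it.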
        std::unique_ptr<RdKafka::Conf> conf(
            RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL));
        std::unique_ptr<RdKafka::Conf> topic_conf(
            RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC));

        std::string errstr;

        RdKafka::Conf::ConfResult result =
            conf->set("default_topic_conf", topic_conf.get(), errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set default_topic_conf: ", errstr);
        }

        result = conf->set("bootstrap.servers", dataset()->servers_, errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set bootstrap.servers ",
                                  dataset()->servers_, ": ", errstr);
        }
        result = conf->set("group.id", dataset()->group_, errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set group.id ", dataset()->group_,
                                  ": ", errstr);
        }

        consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr));
        if (!consumer_.get()) {
          return errors::Internal("Failed to create consumer: ", errstr);
        }

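        // Assign the single partition explicitly (rather than subscribing to
        // the topic) so reads start at the offset parsed from the entry.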
        std::vector<RdKafka::TopicPartition*> partitions;
        partitions.emplace_back(topic_partition_.get());
        RdKafka::ErrorCode err = consumer_->assign(partitions);
        if (err != RdKafka::ERR_NO_ERROR) {
          return errors::Internal(
              "Failed to assign partition [", topic_partition_->topic(), ", ",
              topic_partition_->partition(), ", ", topic_partition_->offset(),
              "]: ", RdKafka::err2str(err));
        }

        return Status::OK();
      }

      // Resets all Kafka streams.
      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // `consumer_` may be null if no topic was ever opened (e.g. when
        // restoring a fresh iterator), so guard before tearing down.
        if (consumer_.get()) {
          consumer_->unassign();
          consumer_->close();
          consumer_.reset(nullptr);
        }
      }

      mutex mu_;
      size_t current_topic_index_ GUARDED_BY(mu_) = 0;
      int64 offset_ GUARDED_BY(mu_) = 0;
      int64 limit_ GUARDED_BY(mu_) = -1;
      std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_);
      std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_);
    };

    const std::vector<string> topics_;
    const std::string servers_;
    const std::string group_;
    const bool eof_;
    const int64 timeout_;
  };
};

REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU),
                        KafkaDatasetOp);

}  // namespace tensorflow