# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================
"""Tests for KafkaDataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class KafkaDatasetTest(test.TestCase):
  """Integration tests for KafkaDataset against a locally running broker."""

  def setUp(self):
    # The Kafka server has to be set up before the test
    # and torn down after the test manually.
    # The docker engine has to be installed.
    #
    # To set up the Kafka server:
    #   $ bash kafka_test.sh start kafka
    #
    # To tear down the Kafka server:
    #   $ bash kafka_test.sh stop kafka
    pass

  def testKafkaDataset(self):
    """Reads messages from the `test` topic in several configurations.

    Feed strings look like "topic:partition:start:end"; the expected
    values below imply that "test:0:0:4" yields offsets 0-4 inclusive
    and that end == -1 means "read to the end of the partition"
    (NOTE(review): inferred from the expected D0-D9 values — confirm
    against the KafkaDataset kernel).
    """
    topics = array_ops.placeholder(dtypes.string, shape=[None])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = kafka_dataset_ops.KafkaDataset(
        topics, group="test", eof=True).repeat(num_epochs)
    batch_dataset = repeat_dataset.batch(batch_size)

    # A reinitializable iterator lets one get_next op serve both the
    # unbatched and the batched pipeline.
    iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(batch_dataset))
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      # Basic test: read the first five messages (offsets 0-4).
      sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
      for i in range(5):
        self.assertEqual("D" + str(i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from offset 5 to the end of the partition.
      sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1})
      for i in range(5):
        self.assertEqual("D" + str(i + 5), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from both offset ranges of the topic.
      sess.run(
          init_op,
          feed_dict={
              topics: ["test:0:0:4", "test:0:5:-1"],
              num_epochs: 1
          })
      for j in range(2):
        for i in range(5):
          self.assertEqual("D" + str(i + j * 5), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test repeated iteration through both offset ranges.
      sess.run(
          init_op,
          feed_dict={
              topics: ["test:0:0:4", "test:0:5:-1"],
              num_epochs: 10
          })
      for _ in range(10):
        for j in range(2):
          for i in range(5):
            self.assertEqual("D" + str(i + j * 5), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test batched and repeated iteration through both offset ranges.
      sess.run(
          init_batch_op,
          feed_dict={
              topics: ["test:0:0:4", "test:0:5:-1"],
              num_epochs: 10,
              batch_size: 5
          })
      for _ in range(10):
        self.assertAllEqual(["D" + str(i) for i in range(5)],
                            sess.run(get_next))
        self.assertAllEqual(["D" + str(i + 5) for i in range(5)],
                            sess.run(get_next))


if __name__ == "__main__":
  test.main()