# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU datasets tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.tpu import datasets
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

_NUM_FILES = 10
_NUM_ENTRIES = 20


class DatasetsTest(test.TestCase):

  def setUp(self):
    super(DatasetsTest, self).setUp()
    # Start two in-process servers and describe them as a two-job cluster
    # ('worker' and 'coordinator'), the topology StreamingFilesDataset uses
    # to stream file contents from one job to another.
    self._coord = server_lib.Server.create_local_server()
    self._worker = server_lib.Server.create_local_server()

    self._cluster_def = cluster_pb2.ClusterDef()
    worker_job = self._cluster_def.job.add()
    worker_job.name = 'worker'
    # Strip the 'grpc://' prefix to get the bare host:port address.
    worker_job.tasks[0] = self._worker.target[len('grpc://'):]
    coord_job = self._cluster_def.job.add()
    coord_job.name = 'coordinator'
    coord_job.tasks[0] = self._coord.target[len('grpc://'):]

    session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)

    self._sess = session.Session(self._worker.target, config=session_config)
    self._worker_device = '/job:' + worker_job.name

  def testTextLineDataset(self):
    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i)
      contents = []
      for j in range(_NUM_ENTRIES):
        contents.append(compat.as_bytes('%d: %d' % (i, j)))
      with open(filename, 'wb') as f:
        f.write(b'\n'.join(contents))
      all_contents.extend(contents)

    dataset = datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    # The dataset cycles through the files indefinitely, so read several
    # epochs' worth of elements and compare as sets.
    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))
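  # The TFRecord variants below repeat this read-and-compare pattern for a
  # binary record format, constructing the dataset first from a file glob and
  # then from a dataset of filenames.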
  def testTFRecordDataset(self):
    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
      writer = python_io.TFRecordWriter(filename)
      for j in range(_NUM_ENTRIES):
        record = compat.as_bytes('Record %d of file %d' % (j, i))
        writer.write(record)
        all_contents.append(record)
      writer.close()

    dataset = datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testTFRecordDatasetFromDataset(self):
    filenames = []
    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
      filenames.append(filename)
      writer = python_io.TFRecordWriter(filename)
      for j in range(_NUM_ENTRIES):
        record = compat.as_bytes('Record %d of file %d' % (j, i))
        writer.write(record)
        all_contents.append(record)
      writer.close()

    filenames = dataset_ops.Dataset.from_tensor_slices(filenames)

    dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testArbitraryReaderFunc(self):

    def MakeRecord(i, j):
      return compat.as_bytes('%04d-%04d' % (i, j))

    record_bytes = len(MakeRecord(10, 200))

    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
      with open(filename, 'wb') as f:
        for j in range(_NUM_ENTRIES):
          record = MakeRecord(i, j)
          f.write(record)
          all_contents.append(record)

    def FixedLengthFile(filename):
      return readers.FixedLengthRecordDataset(filename, record_bytes)

    dataset = datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), 'fixed_length*'),
        filetype=FixedLengthFile)

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testArbitraryReaderFuncFromDatasetGenerator(self):

    def my_generator():
      yield (1, [1] * 10)

    def gen_dataset(dummy):
      return dataset_ops.Dataset.from_generator(
          my_generator, (dtypes.int64, dtypes.int64),
          (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))

    dataset = datasets.StreamingFilesDataset(
        dataset_ops.Dataset.range(10), filetype=gen_dataset)

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = self._sess.run(get_next)

    self.assertIsInstance(retrieved_values, (list, tuple))
    self.assertEqual(len(retrieved_values), 2)
    self.assertEqual(retrieved_values[0], 1)
    self.assertItemsEqual(retrieved_values[1], [1] * 10)
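  # The remaining tests check that unsupported 'filetype' and 'files'
  # arguments are rejected with a ValueError.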
  def testUnexpectedFiletypeString(self):
    with self.assertRaises(ValueError):
      datasets.StreamingFilesDataset(
          os.path.join(self.get_temp_dir(), '*'), filetype='foo')

  def testUnexpectedFiletypeType(self):
    with self.assertRaises(ValueError):
      datasets.StreamingFilesDataset(
          os.path.join(self.get_temp_dir(), '*'), filetype=3)

  def testUnexpectedFilesType(self):
    with self.assertRaises(ValueError):
      datasets.StreamingFilesDataset(123, filetype='tfrecord')


if __name__ == '__main__':
  test.main()