# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU datasets tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.tpu import datasets
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

37_NUM_FILES = 10
38_NUM_ENTRIES = 20
39
40
41class DatasetsTest(test.TestCase):
42
43  def setUp(self):
44    super(DatasetsTest, self).setUp()
45    self._coord = server_lib.Server.create_local_server()
46    self._worker = server_lib.Server.create_local_server()
47
48    self._cluster_def = cluster_pb2.ClusterDef()
49    worker_job = self._cluster_def.job.add()
50    worker_job.name = 'worker'
51    worker_job.tasks[0] = self._worker.target[len('grpc://'):]
52    coord_job = self._cluster_def.job.add()
53    coord_job.name = 'coordinator'
54    coord_job.tasks[0] = self._coord.target[len('grpc://'):]
55
56    session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)
57
58    self._sess = session.Session(self._worker.target, config=session_config)
59    self._worker_device = '/job:' + worker_job.name
60
61  def testTextLineDataset(self):
62    all_contents = []
63    for i in range(_NUM_FILES):
64      filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i)
65      contents = []
66      for j in range(_NUM_ENTRIES):
67        contents.append(compat.as_bytes('%d: %d' % (i, j)))
68      with open(filename, 'wb') as f:
69        f.write(b'\n'.join(contents))
70      all_contents.extend(contents)
71
72    dataset = datasets.StreamingFilesDataset(
73        os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')
74
75    with ops.device(self._worker_device):
76      iterator = dataset_ops.make_initializable_iterator(dataset)
77    self._sess.run(iterator.initializer)
78    get_next = iterator.get_next()
79
80    retrieved_values = []
81    for _ in range(4 * len(all_contents)):
82      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
83
84    self.assertEqual(set(all_contents), set(retrieved_values))
85
86  def testTFRecordDataset(self):
87    all_contents = []
88    for i in range(_NUM_FILES):
89      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
90      writer = python_io.TFRecordWriter(filename)
91      for j in range(_NUM_ENTRIES):
92        record = compat.as_bytes('Record %d of file %d' % (j, i))
93        writer.write(record)
94        all_contents.append(record)
95      writer.close()
96
97    dataset = datasets.StreamingFilesDataset(
98        os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')
99
100    with ops.device(self._worker_device):
101      iterator = dataset_ops.make_initializable_iterator(dataset)
102    self._sess.run(iterator.initializer)
103    get_next = iterator.get_next()
104
105    retrieved_values = []
106    for _ in range(4 * len(all_contents)):
107      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
108
109    self.assertEqual(set(all_contents), set(retrieved_values))
110
111  def testTFRecordDatasetFromDataset(self):
112    filenames = []
113    all_contents = []
114    for i in range(_NUM_FILES):
115      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
116      filenames.append(filename)
117      writer = python_io.TFRecordWriter(filename)
118      for j in range(_NUM_ENTRIES):
119        record = compat.as_bytes('Record %d of file %d' % (j, i))
120        writer.write(record)
121        all_contents.append(record)
122      writer.close()
123
124    filenames = dataset_ops.Dataset.from_tensor_slices(filenames)
125
126    dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')
127
128    with ops.device(self._worker_device):
129      iterator = dataset_ops.make_initializable_iterator(dataset)
130    self._sess.run(iterator.initializer)
131    get_next = iterator.get_next()
132
133    retrieved_values = []
134    for _ in range(4 * len(all_contents)):
135      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
136
137    self.assertEqual(set(all_contents), set(retrieved_values))
138
139  def testArbitraryReaderFunc(self):
140
141    def MakeRecord(i, j):
142      return compat.as_bytes('%04d-%04d' % (i, j))
143
144    record_bytes = len(MakeRecord(10, 200))
145
146    all_contents = []
147    for i in range(_NUM_FILES):
148      filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
149      with open(filename, 'wb') as f:
150        for j in range(_NUM_ENTRIES):
151          record = MakeRecord(i, j)
152          f.write(record)
153          all_contents.append(record)
154
155    def FixedLengthFile(filename):
156      return readers.FixedLengthRecordDataset(filename, record_bytes)
157
158    dataset = datasets.StreamingFilesDataset(
159        os.path.join(self.get_temp_dir(), 'fixed_length*'),
160        filetype=FixedLengthFile)
161
162    with ops.device(self._worker_device):
163      iterator = dataset_ops.make_initializable_iterator(dataset)
164    self._sess.run(iterator.initializer)
165    get_next = iterator.get_next()
166
167    retrieved_values = []
168    for _ in range(4 * len(all_contents)):
169      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
170
171    self.assertEqual(set(all_contents), set(retrieved_values))
172
173  def testArbitraryReaderFuncFromDatasetGenerator(self):
174
175    def my_generator():
176      yield (1, [1] * 10)
177
178    def gen_dataset(dummy):
179      return dataset_ops.Dataset.from_generator(
180          my_generator, (dtypes.int64, dtypes.int64),
181          (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))
182
183    dataset = datasets.StreamingFilesDataset(
184        dataset_ops.Dataset.range(10), filetype=gen_dataset)
185
186    with ops.device(self._worker_device):
187      iterator = dataset_ops.make_initializable_iterator(dataset)
188    self._sess.run(iterator.initializer)
189    get_next = iterator.get_next()
190
191    retrieved_values = self._sess.run(get_next)
192
193    self.assertIsInstance(retrieved_values, (list, tuple))
194    self.assertEqual(len(retrieved_values), 2)
195    self.assertEqual(retrieved_values[0], 1)
196    self.assertItemsEqual(retrieved_values[1], [1] * 10)
197
198  def testUnexpectedFiletypeString(self):
199    with self.assertRaises(ValueError):
200      datasets.StreamingFilesDataset(
201          os.path.join(self.get_temp_dir(), '*'), filetype='foo')
202
203  def testUnexpectedFiletypeType(self):
204    with self.assertRaises(ValueError):
205      datasets.StreamingFilesDataset(
206          os.path.join(self.get_temp_dir(), '*'), filetype=3)
207
208  def testUnexpectedFilesType(self):
209    with self.assertRaises(ValueError):
210      datasets.StreamingFilesDataset(123, filetype='tfrecord')
211
212
213if __name__ == '__main__':
214  test.main()
215