1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for data_utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from itertools import cycle
22import os
23import tarfile
24import threading
25import unittest
26import zipfile
27
28import numpy as np
29from six.moves.urllib.parse import urljoin
30from six.moves.urllib.request import pathname2url
31
32from tensorflow.python import keras
33from tensorflow.python.platform import test
34
35
36class TestGetFileAndValidateIt(test.TestCase):
37
38  def test_get_file_and_validate_it(self):
39    """Tests get_file from a url, plus extraction and validation.
40    """
41    dest_dir = self.get_temp_dir()
42    orig_dir = self.get_temp_dir()
43
44    text_file_path = os.path.join(orig_dir, 'test.txt')
45    zip_file_path = os.path.join(orig_dir, 'test.zip')
46    tar_file_path = os.path.join(orig_dir, 'test.tar.gz')
47
48    with open(text_file_path, 'w') as text_file:
49      text_file.write('Float like a butterfly, sting like a bee.')
50
51    with tarfile.open(tar_file_path, 'w:gz') as tar_file:
52      tar_file.add(text_file_path)
53
54    with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
55      zip_file.write(text_file_path)
56
57    origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path)))
58
59    path = keras.utils.data_utils.get_file('test.txt', origin,
60                                           untar=True, cache_subdir=dest_dir)
61    filepath = path + '.tar.gz'
62    hashval_sha256 = keras.utils.data_utils._hash_file(filepath)
63    hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5')
64    path = keras.utils.data_utils.get_file(
65        'test.txt', origin, md5_hash=hashval_md5,
66        untar=True, cache_subdir=dest_dir)
67    path = keras.utils.data_utils.get_file(
68        filepath, origin, file_hash=hashval_sha256,
69        extract=True, cache_subdir=dest_dir)
70    self.assertTrue(os.path.exists(filepath))
71    self.assertTrue(keras.utils.data_utils.validate_file(filepath,
72                                                         hashval_sha256))
73    self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5))
74    os.remove(filepath)
75
76    origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path)))
77
78    hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path)
79    hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path,
80                                                    algorithm='md5')
81    path = keras.utils.data_utils.get_file(
82        'test', origin, md5_hash=hashval_md5,
83        extract=True, cache_subdir=dest_dir)
84    path = keras.utils.data_utils.get_file(
85        'test', origin, file_hash=hashval_sha256,
86        extract=True, cache_subdir=dest_dir)
87    self.assertTrue(os.path.exists(path))
88    self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256))
89    self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5))
90
91
92class ThreadsafeIter(object):
93
94  def __init__(self, it):
95    self.it = it
96    self.lock = threading.Lock()
97
98  def __iter__(self):
99    return self
100
101  def __next__(self):
102    return self.next()
103
104  def next(self):
105    with self.lock:
106      return next(self.it)
107
108
109def threadsafe_generator(f):
110
111  def g(*a, **kw):
112    return ThreadsafeIter(f(*a, **kw))
113
114  return g
115
116
117class TestSequence(keras.utils.data_utils.Sequence):
118
119  def __init__(self, shape, value=1.):
120    self.shape = shape
121    self.inner = value
122
123  def __getitem__(self, item):
124    return np.ones(self.shape, dtype=np.uint32) * item * self.inner
125
126  def __len__(self):
127    return 100
128
129  def on_epoch_end(self):
130    self.inner *= 5.0
131
132
133class FaultSequence(keras.utils.data_utils.Sequence):
134
135  def __getitem__(self, item):
136    raise IndexError(item, 'item is not present')
137
138  def __len__(self):
139    return 100
140
141
142@threadsafe_generator
143def create_generator_from_sequence_threads(ds):
144  for i in cycle(range(len(ds))):
145    yield ds[i]
146
147
148def create_generator_from_sequence_pcs(ds):
149  for i in cycle(range(len(ds))):
150    yield ds[i]
151
152
153class TestEnqueuers(test.TestCase):
154
155  def test_generator_enqueuer_threads(self):
156    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
157        create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])),
158        use_multiprocessing=False)
159    enqueuer.start(3, 10)
160    gen_output = enqueuer.get()
161    acc = []
162    for _ in range(100):
163      acc.append(int(next(gen_output)[0, 0, 0, 0]))
164
165    self.assertEqual(len(set(acc) - set(range(100))), 0)
166    enqueuer.stop()
167
168  @unittest.skipIf(
169      os.name == 'nt',
170      'use_multiprocessing=True does not work on windows properly.')
171  def test_generator_enqueuer_processes(self):
172    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
173        create_generator_from_sequence_pcs(TestSequence([3, 200, 200, 3])),
174        use_multiprocessing=True)
175    enqueuer.start(3, 10)
176    gen_output = enqueuer.get()
177    acc = []
178    for _ in range(100):
179      acc.append(int(next(gen_output)[0, 0, 0, 0]))
180    self.assertNotEqual(acc, list(range(100)))
181    enqueuer.stop()
182
183  def test_generator_enqueuer_fail_threads(self):
184    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
185        create_generator_from_sequence_threads(FaultSequence()),
186        use_multiprocessing=False)
187    enqueuer.start(3, 10)
188    gen_output = enqueuer.get()
189    with self.assertRaises(IndexError):
190      next(gen_output)
191
192  @unittest.skipIf(
193      os.name == 'nt',
194      'use_multiprocessing=True does not work on windows properly.')
195  def test_generator_enqueuer_fail_processes(self):
196    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
197        create_generator_from_sequence_pcs(FaultSequence()),
198        use_multiprocessing=True)
199    enqueuer.start(3, 10)
200    gen_output = enqueuer.get()
201    with self.assertRaises(IndexError):
202      next(gen_output)
203
204  def test_ordered_enqueuer_threads(self):
205    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
206        TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
207    enqueuer.start(3, 10)
208    gen_output = enqueuer.get()
209    acc = []
210    for _ in range(100):
211      acc.append(next(gen_output)[0, 0, 0, 0])
212    self.assertEqual(acc, list(range(100)))
213    enqueuer.stop()
214
215  def test_ordered_enqueuer_processes(self):
216    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
217        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
218    enqueuer.start(3, 10)
219    gen_output = enqueuer.get()
220    acc = []
221    for _ in range(100):
222      acc.append(next(gen_output)[0, 0, 0, 0])
223    self.assertEqual(acc, list(range(100)))
224    enqueuer.stop()
225
226  def test_ordered_enqueuer_fail_threads(self):
227    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
228        FaultSequence(), use_multiprocessing=False)
229    enqueuer.start(3, 10)
230    gen_output = enqueuer.get()
231    with self.assertRaises(IndexError):
232      next(gen_output)
233
234  def test_ordered_enqueuer_fail_processes(self):
235    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
236        FaultSequence(), use_multiprocessing=True)
237    enqueuer.start(3, 10)
238    gen_output = enqueuer.get()
239    with self.assertRaises(IndexError):
240      next(gen_output)
241
242  def test_on_epoch_end_processes(self):
243    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
244        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
245    enqueuer.start(3, 10)
246    gen_output = enqueuer.get()
247    acc = []
248    for _ in range(200):
249      acc.append(next(gen_output)[0, 0, 0, 0])
250    # Check that order was keep in GeneratorEnqueuer with processes
251    self.assertEqual(acc[100:], list([k * 5 for k in range(100)]))
252    enqueuer.stop()
253
254  def test_context_switch(self):
255    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
256        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
257    enqueuer2 = keras.utils.data_utils.OrderedEnqueuer(
258        TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True)
259    enqueuer.start(3, 10)
260    enqueuer2.start(3, 10)
261    gen_output = enqueuer.get()
262    gen_output2 = enqueuer2.get()
263    acc = []
264    for _ in range(100):
265      acc.append(next(gen_output)[0, 0, 0, 0])
266    self.assertEqual(acc[-1], 99)
267    # One epoch is completed so enqueuer will switch the Sequence
268
269    acc = []
270    for _ in range(100):
271      acc.append(next(gen_output2)[0, 0, 0, 0])
272    self.assertEqual(acc[-1], 99 * 15)
273    # One epoch has been completed so enqueuer2 will switch
274
275    # Be sure that both Sequence were updated
276    self.assertEqual(next(gen_output)[0, 0, 0, 0], 0)
277    self.assertEqual(next(gen_output)[0, 0, 0, 0], 5)
278    self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0)
279    self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5)
280
281    # Tear down everything
282    enqueuer.stop()
283    enqueuer2.stop()
284
285  def test_on_epoch_end_threads(self):
286    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
287        TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
288    enqueuer.start(3, 10)
289    gen_output = enqueuer.get()
290    acc = []
291    for _ in range(100):
292      acc.append(next(gen_output)[0, 0, 0, 0])
293    acc = []
294    for _ in range(100):
295      acc.append(next(gen_output)[0, 0, 0, 0])
296    # Check that order was keep in GeneratorEnqueuer with processes
297    self.assertEqual(acc, list([k * 5 for k in range(100)]))
298    enqueuer.stop()
299
300
301if __name__ == '__main__':
302  # Bazel sets these environment variables to very long paths.
303  # Tempfile uses them to create long paths, and in turn multiprocessing
304  # library tries to create sockets named after paths. Delete whatever bazel
305  # writes to these to avoid tests failing due to socket addresses being too
306  # long.
307  for var in ('TMPDIR', 'TMP', 'TEMP'):
308    if var in os.environ:
309      del os.environ[var]
310
311  test.main()
312