1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for data_utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from itertools import cycle 22import os 23import tarfile 24import threading 25import unittest 26import zipfile 27 28import numpy as np 29from six.moves.urllib.parse import urljoin 30from six.moves.urllib.request import pathname2url 31 32from tensorflow.python import keras 33from tensorflow.python.platform import test 34 35 36class TestGetFileAndValidateIt(test.TestCase): 37 38 def test_get_file_and_validate_it(self): 39 """Tests get_file from a url, plus extraction and validation. 40 """ 41 dest_dir = self.get_temp_dir() 42 orig_dir = self.get_temp_dir() 43 44 text_file_path = os.path.join(orig_dir, 'test.txt') 45 zip_file_path = os.path.join(orig_dir, 'test.zip') 46 tar_file_path = os.path.join(orig_dir, 'test.tar.gz') 47 48 with open(text_file_path, 'w') as text_file: 49 text_file.write('Float like a butterfly, sting like a bee.') 50 51 with tarfile.open(tar_file_path, 'w:gz') as tar_file: 52 tar_file.add(text_file_path) 53 54 with zipfile.ZipFile(zip_file_path, 'w') as zip_file: 55 zip_file.write(text_file_path) 56 57 origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path))) 58 59 path = keras.utils.data_utils.get_file('test.txt', origin, 60 untar=True, cache_subdir=dest_dir) 61 filepath = path + '.tar.gz' 62 hashval_sha256 = keras.utils.data_utils._hash_file(filepath) 63 hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5') 64 path = keras.utils.data_utils.get_file( 65 'test.txt', origin, md5_hash=hashval_md5, 66 untar=True, cache_subdir=dest_dir) 67 path = keras.utils.data_utils.get_file( 68 filepath, origin, file_hash=hashval_sha256, 69 extract=True, cache_subdir=dest_dir) 70 self.assertTrue(os.path.exists(filepath)) 71 self.assertTrue(keras.utils.data_utils.validate_file(filepath, 72 hashval_sha256)) 73 self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5)) 74 os.remove(filepath) 75 76 origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path))) 77 78 hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path) 79 hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path, 80 algorithm='md5') 81 path = keras.utils.data_utils.get_file( 82 'test', origin, md5_hash=hashval_md5, 83 extract=True, cache_subdir=dest_dir) 84 path = keras.utils.data_utils.get_file( 85 'test', origin, file_hash=hashval_sha256, 86 extract=True, cache_subdir=dest_dir) 87 self.assertTrue(os.path.exists(path)) 88 self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256)) 89 self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5)) 90 91 92class ThreadsafeIter(object): 93 94 def __init__(self, it): 95 self.it = it 96 self.lock = threading.Lock() 97 98 def __iter__(self): 99 return self 100 101 def __next__(self): 102 return self.next() 103 104 def next(self): 105 with self.lock: 106 return next(self.it) 107 108 109def threadsafe_generator(f): 110 111 def g(*a, **kw): 112 return ThreadsafeIter(f(*a, **kw)) 113 114 return g 115 116 117class TestSequence(keras.utils.data_utils.Sequence): 118 119 def __init__(self, shape, value=1.): 120 self.shape = shape 121 self.inner = value 122 123 def __getitem__(self, item): 124 return np.ones(self.shape, dtype=np.uint32) * item * self.inner 125 126 def __len__(self): 127 return 100 128 129 def on_epoch_end(self): 130 self.inner *= 5.0 131 132 133class FaultSequence(keras.utils.data_utils.Sequence): 134 135 def __getitem__(self, item): 136 raise IndexError(item, 'item is not present') 137 138 def __len__(self): 139 return 100 140 141 142@threadsafe_generator 143def create_generator_from_sequence_threads(ds): 144 for i in cycle(range(len(ds))): 145 yield ds[i] 146 147 148def create_generator_from_sequence_pcs(ds): 149 for i in cycle(range(len(ds))): 150 yield ds[i] 151 152 153class TestEnqueuers(test.TestCase): 154 155 def test_generator_enqueuer_threads(self): 156 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 157 create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])), 158 use_multiprocessing=False) 159 enqueuer.start(3, 10) 160 gen_output = enqueuer.get() 161 acc = [] 162 for _ in range(100): 163 acc.append(int(next(gen_output)[0, 0, 0, 0])) 164 165 self.assertEqual(len(set(acc) - set(range(100))), 0) 166 enqueuer.stop() 167 168 @unittest.skipIf( 169 os.name == 'nt', 170 'use_multiprocessing=True does not work on windows properly.') 171 def test_generator_enqueuer_processes(self): 172 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 173 create_generator_from_sequence_pcs(TestSequence([3, 200, 200, 3])), 174 use_multiprocessing=True) 175 enqueuer.start(3, 10) 176 gen_output = enqueuer.get() 177 acc = [] 178 for _ in range(100): 179 acc.append(int(next(gen_output)[0, 0, 0, 0])) 180 self.assertNotEqual(acc, list(range(100))) 181 enqueuer.stop() 182 183 def test_generator_enqueuer_fail_threads(self): 184 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 185 create_generator_from_sequence_threads(FaultSequence()), 186 use_multiprocessing=False) 187 enqueuer.start(3, 10) 188 gen_output = enqueuer.get() 189 with self.assertRaises(IndexError): 190 next(gen_output) 191 192 @unittest.skipIf( 193 os.name == 'nt', 194 'use_multiprocessing=True does not work on windows properly.') 195 def test_generator_enqueuer_fail_processes(self): 196 enqueuer = keras.utils.data_utils.GeneratorEnqueuer( 197 create_generator_from_sequence_pcs(FaultSequence()), 198 use_multiprocessing=True) 199 enqueuer.start(3, 10) 200 gen_output = enqueuer.get() 201 with self.assertRaises(IndexError): 202 next(gen_output) 203 204 def test_ordered_enqueuer_threads(self): 205 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 206 TestSequence([3, 200, 200, 3]), use_multiprocessing=False) 207 enqueuer.start(3, 10) 208 gen_output = enqueuer.get() 209 acc = [] 210 for _ in range(100): 211 acc.append(next(gen_output)[0, 0, 0, 0]) 212 self.assertEqual(acc, list(range(100))) 213 enqueuer.stop() 214 215 def test_ordered_enqueuer_processes(self): 216 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 217 TestSequence([3, 200, 200, 3]), use_multiprocessing=True) 218 enqueuer.start(3, 10) 219 gen_output = enqueuer.get() 220 acc = [] 221 for _ in range(100): 222 acc.append(next(gen_output)[0, 0, 0, 0]) 223 self.assertEqual(acc, list(range(100))) 224 enqueuer.stop() 225 226 def test_ordered_enqueuer_fail_threads(self): 227 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 228 FaultSequence(), use_multiprocessing=False) 229 enqueuer.start(3, 10) 230 gen_output = enqueuer.get() 231 with self.assertRaises(IndexError): 232 next(gen_output) 233 234 def test_ordered_enqueuer_fail_processes(self): 235 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 236 FaultSequence(), use_multiprocessing=True) 237 enqueuer.start(3, 10) 238 gen_output = enqueuer.get() 239 with self.assertRaises(IndexError): 240 next(gen_output) 241 242 def test_on_epoch_end_processes(self): 243 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 244 TestSequence([3, 200, 200, 3]), use_multiprocessing=True) 245 enqueuer.start(3, 10) 246 gen_output = enqueuer.get() 247 acc = [] 248 for _ in range(200): 249 acc.append(next(gen_output)[0, 0, 0, 0]) 250 # Check that order was keep in GeneratorEnqueuer with processes 251 self.assertEqual(acc[100:], list([k * 5 for k in range(100)])) 252 enqueuer.stop() 253 254 def test_context_switch(self): 255 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 256 TestSequence([3, 200, 200, 3]), use_multiprocessing=True) 257 enqueuer2 = keras.utils.data_utils.OrderedEnqueuer( 258 TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True) 259 enqueuer.start(3, 10) 260 enqueuer2.start(3, 10) 261 gen_output = enqueuer.get() 262 gen_output2 = enqueuer2.get() 263 acc = [] 264 for _ in range(100): 265 acc.append(next(gen_output)[0, 0, 0, 0]) 266 self.assertEqual(acc[-1], 99) 267 # One epoch is completed so enqueuer will switch the Sequence 268 269 acc = [] 270 for _ in range(100): 271 acc.append(next(gen_output2)[0, 0, 0, 0]) 272 self.assertEqual(acc[-1], 99 * 15) 273 # One epoch has been completed so enqueuer2 will switch 274 275 # Be sure that both Sequence were updated 276 self.assertEqual(next(gen_output)[0, 0, 0, 0], 0) 277 self.assertEqual(next(gen_output)[0, 0, 0, 0], 5) 278 self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0) 279 self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5) 280 281 # Tear down everything 282 enqueuer.stop() 283 enqueuer2.stop() 284 285 def test_on_epoch_end_threads(self): 286 enqueuer = keras.utils.data_utils.OrderedEnqueuer( 287 TestSequence([3, 200, 200, 3]), use_multiprocessing=False) 288 enqueuer.start(3, 10) 289 gen_output = enqueuer.get() 290 acc = [] 291 for _ in range(100): 292 acc.append(next(gen_output)[0, 0, 0, 0]) 293 acc = [] 294 for _ in range(100): 295 acc.append(next(gen_output)[0, 0, 0, 0]) 296 # Check that order was keep in GeneratorEnqueuer with processes 297 self.assertEqual(acc, list([k * 5 for k in range(100)])) 298 enqueuer.stop() 299 300 301if __name__ == '__main__': 302 # Bazel sets these environment variables to very long paths. 303 # Tempfile uses them to create long paths, and in turn multiprocessing 304 # library tries to create sockets named after paths. Delete whatever bazel 305 # writes to these to avoid tests failing due to socket addresses being too 306 # long. 307 for var in ('TMPDIR', 'TMP', 'TEMP'): 308 if var in os.environ: 309 del os.environ[var] 310 311 test.main() 312