1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for io_utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23 24import numpy as np 25import six 26 27from tensorflow.python import keras 28from tensorflow.python.platform import test 29 30try: 31 import h5py # pylint:disable=g-import-not-at-top 32except ImportError: 33 h5py = None 34 35 36def create_dataset(h5_path='test.h5'): 37 x = np.random.randn(200, 10).astype('float32') 38 y = np.random.randint(0, 2, size=(200, 1)) 39 f = h5py.File(h5_path, 'w') 40 # Creating dataset to store features 41 x_dset = f.create_dataset('my_data', (200, 10), dtype='f') 42 x_dset[:] = x 43 # Creating dataset to store labels 44 y_dset = f.create_dataset('my_labels', (200, 1), dtype='i') 45 y_dset[:] = y 46 f.close() 47 48 49class TestIOUtils(test.TestCase): 50 51 def test_HDF5Matrix(self): 52 if h5py is None: 53 return 54 55 temp_dir = self.get_temp_dir() 56 self.addCleanup(shutil.rmtree, temp_dir) 57 58 h5_path = os.path.join(temp_dir, 'test.h5') 59 create_dataset(h5_path) 60 61 # Instantiating HDF5Matrix for the training set, 62 # which is a slice of the first 150 elements 63 x_train = keras.utils.io_utils.HDF5Matrix( 64 h5_path, 'my_data', start=0, end=150) 65 y_train = keras.utils.io_utils.HDF5Matrix( 66 h5_path, 'my_labels', start=0, end=150) 67 68 # Likewise for the test set 69 x_test = keras.utils.io_utils.HDF5Matrix( 70 h5_path, 'my_data', start=150, end=200) 71 y_test = keras.utils.io_utils.HDF5Matrix( 72 h5_path, 'my_labels', start=150, end=200) 73 74 # HDF5Matrix behave more or less like Numpy matrices 75 # with regard to indexing 76 self.assertEqual(y_train.shape, (150, 1)) 77 # But they do not support negative indices, so don't try print(x_train[-1]) 78 79 self.assertEqual(y_train.dtype, np.dtype('i')) 80 self.assertEqual(y_train.ndim, 2) 81 self.assertEqual(y_train.size, 150) 82 83 model = keras.models.Sequential() 84 model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu')) 85 model.add(keras.layers.Dense(1, activation='sigmoid')) 86 model.compile(loss='binary_crossentropy', optimizer='sgd') 87 88 # Note: you have to use shuffle='batch' or False with HDF5Matrix 89 model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False) 90 # test that evalutation and prediction 91 # don't crash and return reasonable results 92 out_pred = model.predict(x_test, batch_size=32, verbose=False) 93 out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False) 94 95 self.assertEqual(out_pred.shape, (50, 1)) 96 self.assertEqual(out_eval.shape, ()) 97 self.assertGreater(out_eval, 0) 98 99 # test slicing for shortened array 100 self.assertEqual(len(x_train[0:]), len(x_train)) 101 102 # test __getitem__ invalid use cases 103 with self.assertRaises(IndexError): 104 _ = x_train[1000] 105 with self.assertRaises(IndexError): 106 _ = x_train[1000: 1001] 107 with self.assertRaises(IndexError): 108 _ = x_train[[1000, 1001]] 109 with self.assertRaises(IndexError): 110 _ = x_train[six.moves.range(1000, 1001)] 111 with self.assertRaises(IndexError): 112 _ = x_train[np.array([1000])] 113 with self.assertRaises(TypeError): 114 _ = x_train[None] 115 116 # test normalizer 117 normalizer = lambda x: x + 1 118 normalized_x_train = keras.utils.io_utils.HDF5Matrix( 119 h5_path, 'my_data', start=0, end=150, normalizer=normalizer) 120 self.assertAllClose(normalized_x_train[0][0], x_train[0][0] + 1) 121 122 123if __name__ == '__main__': 124 test.main() 125