1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for text_dataset.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import random 23import shutil 24import string 25 26from tensorflow.python.compat import v2_compat 27from tensorflow.python.keras import keras_parameterized 28from tensorflow.python.keras.preprocessing import text_dataset 29from tensorflow.python.platform import test 30 31 32class TextDatasetFromDirectoryTest(keras_parameterized.TestCase): 33 34 def _prepare_directory(self, 35 num_classes=2, 36 nested_dirs=False, 37 count=16, 38 length=20): 39 # Get a unique temp directory 40 temp_dir = os.path.join(self.get_temp_dir(), str(random.randint(0, 1e6))) 41 os.mkdir(temp_dir) 42 self.addCleanup(shutil.rmtree, temp_dir) 43 44 # Generate paths to class subdirectories 45 paths = [] 46 for class_index in range(num_classes): 47 class_directory = 'class_%s' % (class_index,) 48 if nested_dirs: 49 class_paths = [ 50 class_directory, os.path.join(class_directory, 'subfolder_1'), 51 os.path.join(class_directory, 'subfolder_2'), os.path.join( 52 class_directory, 'subfolder_1', 'sub-subfolder') 53 ] 54 else: 55 class_paths = [class_directory] 56 for path in class_paths: 57 os.mkdir(os.path.join(temp_dir, path)) 58 paths += class_paths 59 60 for i in range(count): 61 path = paths[i % len(paths)] 62 filename = os.path.join(path, 'text_%s.txt' % (i,)) 63 f = open(os.path.join(temp_dir, filename), 'w') 64 text = ''.join([random.choice(string.printable) for _ in range(length)]) 65 f.write(text) 66 f.close() 67 return temp_dir 68 69 def test_text_dataset_from_directory_standalone(self): 70 # Test retrieving txt files without labels from a directory and its subdirs. 71 # Save a few extra files in the parent directory. 72 directory = self._prepare_directory(count=7, num_classes=2) 73 for i in range(3): 74 filename = 'text_%s.txt' % (i,) 75 f = open(os.path.join(directory, filename), 'w') 76 text = ''.join([random.choice(string.printable) for _ in range(20)]) 77 f.write(text) 78 f.close() 79 80 dataset = text_dataset.text_dataset_from_directory( 81 directory, batch_size=5, label_mode=None, max_length=10) 82 batch = next(iter(dataset)) 83 # We just return the texts, no labels 84 self.assertEqual(batch.shape, (5,)) 85 self.assertEqual(batch.dtype.name, 'string') 86 # Count samples 87 batch_count = 0 88 sample_count = 0 89 for batch in dataset: 90 batch_count += 1 91 sample_count += batch.shape[0] 92 self.assertEqual(batch_count, 2) 93 self.assertEqual(sample_count, 10) 94 95 def test_text_dataset_from_directory_binary(self): 96 directory = self._prepare_directory(num_classes=2) 97 dataset = text_dataset.text_dataset_from_directory( 98 directory, batch_size=8, label_mode='int', max_length=10) 99 batch = next(iter(dataset)) 100 self.assertLen(batch, 2) 101 self.assertEqual(batch[0].shape, (8,)) 102 self.assertEqual(batch[0].dtype.name, 'string') 103 self.assertEqual(len(batch[0].numpy()[0]), 10) # Test max_length 104 self.assertEqual(batch[1].shape, (8,)) 105 self.assertEqual(batch[1].dtype.name, 'int32') 106 107 dataset = text_dataset.text_dataset_from_directory( 108 directory, batch_size=8, label_mode='binary') 109 batch = next(iter(dataset)) 110 self.assertLen(batch, 2) 111 self.assertEqual(batch[0].shape, (8,)) 112 self.assertEqual(batch[0].dtype.name, 'string') 113 self.assertEqual(batch[1].shape, (8, 1)) 114 self.assertEqual(batch[1].dtype.name, 'float32') 115 116 dataset = text_dataset.text_dataset_from_directory( 117 directory, batch_size=8, label_mode='categorical') 118 batch = next(iter(dataset)) 119 self.assertLen(batch, 2) 120 self.assertEqual(batch[0].shape, (8,)) 121 self.assertEqual(batch[0].dtype.name, 'string') 122 self.assertEqual(batch[1].shape, (8, 2)) 123 self.assertEqual(batch[1].dtype.name, 'float32') 124 125 def test_sample_count(self): 126 directory = self._prepare_directory(num_classes=4, count=15) 127 dataset = text_dataset.text_dataset_from_directory( 128 directory, batch_size=8, label_mode=None) 129 sample_count = 0 130 for batch in dataset: 131 sample_count += batch.shape[0] 132 self.assertEqual(sample_count, 15) 133 134 def test_text_dataset_from_directory_multiclass(self): 135 directory = self._prepare_directory(num_classes=4, count=15) 136 137 dataset = text_dataset.text_dataset_from_directory( 138 directory, batch_size=8, label_mode=None) 139 batch = next(iter(dataset)) 140 self.assertEqual(batch.shape, (8,)) 141 142 dataset = text_dataset.text_dataset_from_directory( 143 directory, batch_size=8, label_mode=None) 144 sample_count = 0 145 iterator = iter(dataset) 146 for batch in dataset: 147 sample_count += next(iterator).shape[0] 148 self.assertEqual(sample_count, 15) 149 150 dataset = text_dataset.text_dataset_from_directory( 151 directory, batch_size=8, label_mode='int') 152 batch = next(iter(dataset)) 153 self.assertLen(batch, 2) 154 self.assertEqual(batch[0].shape, (8,)) 155 self.assertEqual(batch[0].dtype.name, 'string') 156 self.assertEqual(batch[1].shape, (8,)) 157 self.assertEqual(batch[1].dtype.name, 'int32') 158 159 dataset = text_dataset.text_dataset_from_directory( 160 directory, batch_size=8, label_mode='categorical') 161 batch = next(iter(dataset)) 162 self.assertLen(batch, 2) 163 self.assertEqual(batch[0].shape, (8,)) 164 self.assertEqual(batch[0].dtype.name, 'string') 165 self.assertEqual(batch[1].shape, (8, 4)) 166 self.assertEqual(batch[1].dtype.name, 'float32') 167 168 def test_text_dataset_from_directory_validation_split(self): 169 directory = self._prepare_directory(num_classes=2, count=10) 170 dataset = text_dataset.text_dataset_from_directory( 171 directory, batch_size=10, validation_split=0.2, subset='training', 172 seed=1337) 173 batch = next(iter(dataset)) 174 self.assertLen(batch, 2) 175 self.assertEqual(batch[0].shape, (8,)) 176 dataset = text_dataset.text_dataset_from_directory( 177 directory, batch_size=10, validation_split=0.2, subset='validation', 178 seed=1337) 179 batch = next(iter(dataset)) 180 self.assertLen(batch, 2) 181 self.assertEqual(batch[0].shape, (2,)) 182 183 def test_text_dataset_from_directory_manual_labels(self): 184 directory = self._prepare_directory(num_classes=2, count=2) 185 dataset = text_dataset.text_dataset_from_directory( 186 directory, batch_size=8, labels=[0, 1], shuffle=False) 187 batch = next(iter(dataset)) 188 self.assertLen(batch, 2) 189 self.assertAllClose(batch[1], [0, 1]) 190 191 def test_text_dataset_from_directory_follow_links(self): 192 directory = self._prepare_directory(num_classes=2, count=25, 193 nested_dirs=True) 194 dataset = text_dataset.text_dataset_from_directory( 195 directory, batch_size=8, label_mode=None, follow_links=True) 196 sample_count = 0 197 for batch in dataset: 198 sample_count += batch.shape[0] 199 self.assertEqual(sample_count, 25) 200 201 def test_text_dataset_from_directory_no_files(self): 202 directory = self._prepare_directory(num_classes=2, count=0) 203 with self.assertRaisesRegex(ValueError, 'No text files found.'): 204 _ = text_dataset.text_dataset_from_directory(directory) 205 206 def test_text_dataset_from_directory_errors(self): 207 directory = self._prepare_directory(num_classes=3, count=5) 208 209 with self.assertRaisesRegex(ValueError, '`labels` argument should be'): 210 _ = text_dataset.text_dataset_from_directory( 211 directory, labels='other') 212 213 with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'): 214 _ = text_dataset.text_dataset_from_directory( 215 directory, label_mode='other') 216 217 with self.assertRaisesRegex( 218 ValueError, 'only pass `class_names` if the labels are inferred'): 219 _ = text_dataset.text_dataset_from_directory( 220 directory, labels=[0, 0, 1, 1, 1], 221 class_names=['class_0', 'class_1', 'class_2']) 222 223 with self.assertRaisesRegex( 224 ValueError, 225 'Expected the lengths of `labels` to match the number of files'): 226 _ = text_dataset.text_dataset_from_directory( 227 directory, labels=[0, 0, 1, 1]) 228 229 with self.assertRaisesRegex( 230 ValueError, '`class_names` passed did not match'): 231 _ = text_dataset.text_dataset_from_directory( 232 directory, class_names=['class_0', 'class_2']) 233 234 with self.assertRaisesRegex(ValueError, 'there must exactly 2 classes'): 235 _ = text_dataset.text_dataset_from_directory( 236 directory, label_mode='binary') 237 238 with self.assertRaisesRegex(ValueError, 239 '`validation_split` must be between 0 and 1'): 240 _ = text_dataset.text_dataset_from_directory( 241 directory, validation_split=2) 242 243 with self.assertRaisesRegex(ValueError, 244 '`subset` must be either "training" or'): 245 _ = text_dataset.text_dataset_from_directory( 246 directory, validation_split=0.2, subset='other') 247 248 with self.assertRaisesRegex(ValueError, '`validation_split` must be set'): 249 _ = text_dataset.text_dataset_from_directory( 250 directory, validation_split=0, subset='training') 251 252 with self.assertRaisesRegex(ValueError, 'must provide a `seed`'): 253 _ = text_dataset.text_dataset_from_directory( 254 directory, validation_split=0.2, subset='training') 255 256 257if __name__ == '__main__': 258 v2_compat.enable_v2_behavior() 259 test.main() 260