1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for timeseries.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.compat import v2_compat 24from tensorflow.python.keras.preprocessing import timeseries 25from tensorflow.python.platform import test 26 27 28class TimeseriesDatasetTest(test.TestCase): 29 30 def test_basics(self): 31 # Test ordering, targets, sequence length, batch size 32 data = np.arange(100) 33 targets = data * 2 34 dataset = timeseries.timeseries_dataset_from_array( 35 data, targets, sequence_length=9, batch_size=5) 36 # Expect 19 batches 37 for i, batch in enumerate(dataset): 38 self.assertLen(batch, 2) 39 inputs, targets = batch 40 if i < 18: 41 self.assertEqual(inputs.shape, (5, 9)) 42 if i == 18: 43 # Last batch: size 2 44 self.assertEqual(inputs.shape, (2, 9)) 45 # Check target values 46 self.assertAllClose(targets, inputs[:, 0] * 2) 47 for j in range(min(5, len(inputs))): 48 # Check each sample in the batch 49 self.assertAllClose(inputs[j], np.arange(i * 5 + j, i * 5 + j + 9)) 50 51 def test_timeseries_regression(self): 52 # Test simple timeseries regression use case 53 data = np.arange(10) 54 offset = 3 55 targets = data[offset:] 56 dataset = timeseries.timeseries_dataset_from_array( 57 data, targets, sequence_length=offset, batch_size=1) 58 i = 0 59 for batch in dataset: 60 self.assertLen(batch, 2) 61 inputs, targets = batch 62 self.assertEqual(inputs.shape, (1, 3)) 63 # Check values 64 self.assertAllClose(targets[0], data[offset + i]) 65 self.assertAllClose(inputs[0], data[i : i + offset]) 66 i += 1 67 self.assertEqual(i, 7) # Expect 7 batches 68 69 def test_no_targets(self): 70 data = np.arange(50) 71 dataset = timeseries.timeseries_dataset_from_array( 72 data, None, sequence_length=10, batch_size=5) 73 # Expect 9 batches 74 i = None 75 for i, batch in enumerate(dataset): 76 if i < 8: 77 self.assertEqual(batch.shape, (5, 10)) 78 elif i == 8: 79 self.assertEqual(batch.shape, (1, 10)) 80 for j in range(min(5, len(batch))): 81 # Check each sample in the batch 82 self.assertAllClose(batch[j], np.arange(i * 5 + j, i * 5 + j + 10)) 83 self.assertEqual(i, 8) 84 85 def test_shuffle(self): 86 # Test cross-epoch random order and seed determinism 87 data = np.arange(10) 88 targets = data * 2 89 dataset = timeseries.timeseries_dataset_from_array( 90 data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123) 91 first_seq = None 92 for x, y in dataset.take(1): 93 self.assertNotAllClose(x, np.arange(0, 5)) 94 self.assertAllClose(x[:, 0] * 2, y) 95 first_seq = x 96 # Check that a new iteration with the same dataset yields different results 97 for x, _ in dataset.take(1): 98 self.assertNotAllClose(x, first_seq) 99 # Check determism with same seed 100 dataset = timeseries.timeseries_dataset_from_array( 101 data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123) 102 for x, _ in dataset.take(1): 103 self.assertAllClose(x, first_seq) 104 105 def test_sampling_rate(self): 106 data = np.arange(100) 107 targets = data * 2 108 dataset = timeseries.timeseries_dataset_from_array( 109 data, targets, sequence_length=9, batch_size=5, sampling_rate=2) 110 for i, batch in enumerate(dataset): 111 self.assertLen(batch, 2) 112 inputs, targets = batch 113 if i < 16: 114 self.assertEqual(inputs.shape, (5, 9)) 115 if i == 16: 116 # Last batch: size 3 117 self.assertEqual(inputs.shape, (3, 9)) 118 # Check target values 119 self.assertAllClose(inputs[:, 0] * 2, targets) 120 for j in range(min(5, len(inputs))): 121 # Check each sample in the batch 122 start_index = i * 5 + j 123 end_index = start_index + 9 * 2 124 self.assertAllClose(inputs[j], np.arange(start_index, end_index, 2)) 125 126 def test_sequence_stride(self): 127 data = np.arange(100) 128 targets = data * 2 129 dataset = timeseries.timeseries_dataset_from_array( 130 data, targets, sequence_length=9, batch_size=5, sequence_stride=3) 131 for i, batch in enumerate(dataset): 132 self.assertLen(batch, 2) 133 inputs, targets = batch 134 if i < 6: 135 self.assertEqual(inputs.shape, (5, 9)) 136 if i == 6: 137 # Last batch: size 1 138 self.assertEqual(inputs.shape, (1, 9)) 139 # Check target values 140 self.assertAllClose(inputs[:, 0] * 2, targets) 141 for j in range(min(5, len(inputs))): 142 # Check each sample in the batch 143 start_index = i * 5 * 3 + j * 3 144 end_index = start_index + 9 145 self.assertAllClose(inputs[j], 146 np.arange(start_index, end_index)) 147 148 def test_start_and_end_index(self): 149 data = np.arange(100) 150 dataset = timeseries.timeseries_dataset_from_array( 151 data, None, 152 sequence_length=9, batch_size=5, sequence_stride=3, sampling_rate=2, 153 start_index=10, end_index=90) 154 for batch in dataset: 155 self.assertAllLess(batch[0], 90) 156 self.assertAllGreater(batch[0], 9) 157 158 def test_errors(self): 159 # bad start index 160 with self.assertRaisesRegex(ValueError, 'start_index must be '): 161 _ = timeseries.timeseries_dataset_from_array( 162 np.arange(10), None, 3, start_index=-1) 163 with self.assertRaisesRegex(ValueError, 'start_index must be '): 164 _ = timeseries.timeseries_dataset_from_array( 165 np.arange(10), None, 3, start_index=11) 166 # bad end index 167 with self.assertRaisesRegex(ValueError, 'end_index must be '): 168 _ = timeseries.timeseries_dataset_from_array( 169 np.arange(10), None, 3, end_index=-1) 170 with self.assertRaisesRegex(ValueError, 'end_index must be '): 171 _ = timeseries.timeseries_dataset_from_array( 172 np.arange(10), None, 3, end_index=11) 173 # bad sampling_rate 174 with self.assertRaisesRegex(ValueError, 'sampling_rate must be '): 175 _ = timeseries.timeseries_dataset_from_array( 176 np.arange(10), None, 3, sampling_rate=0) 177 # bad sequence stride 178 with self.assertRaisesRegex(ValueError, 'sequence_stride must be '): 179 _ = timeseries.timeseries_dataset_from_array( 180 np.arange(10), None, 3, sequence_stride=0) 181 182 183if __name__ == '__main__': 184 v2_compat.enable_v2_behavior() 185 test.main() 186