1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for timeseries."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.compat import v2_compat
24from tensorflow.python.keras.preprocessing import timeseries
25from tensorflow.python.platform import test
26
27
class TimeseriesDatasetTest(test.TestCase):
  """Tests for `timeseries.timeseries_dataset_from_array`.

  Tests that validate batches inside a `for` loop also assert how many batches
  were iterated, so that a dataset yielding nothing (or too little) fails
  instead of silently skipping the per-batch checks.
  """

  def test_basics(self):
    """Checks ordering, targets, sequence length and batch size."""
    data = np.arange(100)
    targets = data * 2
    dataset = timeseries.timeseries_dataset_from_array(
        data, targets, sequence_length=9, batch_size=5)
    num_batches = 0
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      # Use a distinct name so the array passed to the dataset is not shadowed.
      inputs, batch_targets = batch
      if i < 18:
        self.assertEqual(inputs.shape, (5, 9))
      if i == 18:
        # Last batch: size 2
        self.assertEqual(inputs.shape, (2, 9))
      # Each target is twice the first element of its input sequence.
      self.assertAllClose(batch_targets, inputs[:, 0] * 2)
      for j in range(min(5, len(inputs))):
        # Check each sample in the batch
        self.assertAllClose(inputs[j], np.arange(i * 5 + j, i * 5 + j + 9))
      num_batches += 1
    # 18 full batches plus the final batch of size 2.
    self.assertEqual(num_batches, 19)

  def test_timeseries_regression(self):
    """Checks the simple regression use case: predict the next value."""
    data = np.arange(10)
    offset = 3
    targets = data[offset:]
    dataset = timeseries.timeseries_dataset_from_array(
        data, targets, sequence_length=offset, batch_size=1)
    num_batches = 0
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      inputs, batch_targets = batch
      self.assertEqual(inputs.shape, (1, 3))
      # The target is the value that immediately follows the input window.
      self.assertAllClose(batch_targets[0], data[offset + i])
      self.assertAllClose(inputs[0], data[i : i + offset])
      num_batches += 1
    self.assertEqual(num_batches, 7)  # Expect 7 batches

  def test_no_targets(self):
    """Checks that `targets=None` yields bare input batches."""
    data = np.arange(50)
    dataset = timeseries.timeseries_dataset_from_array(
        data, None, sequence_length=10, batch_size=5)
    num_batches = 0
    for i, batch in enumerate(dataset):
      if i < 8:
        self.assertEqual(batch.shape, (5, 10))
      elif i == 8:
        # Last batch: size 1
        self.assertEqual(batch.shape, (1, 10))
      for j in range(min(5, len(batch))):
        # Check each sample in the batch
        self.assertAllClose(batch[j], np.arange(i * 5 + j, i * 5 + j + 10))
      num_batches += 1
    self.assertEqual(num_batches, 9)  # Expect 9 batches

  def test_shuffle(self):
    """Checks cross-epoch random order and seed determinism."""
    data = np.arange(10)
    targets = data * 2
    dataset = timeseries.timeseries_dataset_from_array(
        data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123)
    first_seq = None
    for x, y in dataset.take(1):
      self.assertNotAllClose(x, np.arange(0, 5))
      self.assertAllClose(x[:, 0] * 2, y)
      first_seq = x
    # Guard against an empty dataset, which would skip the checks above.
    self.assertIsNotNone(first_seq)
    # Check that a new iteration with the same dataset yields different results
    for x, _ in dataset.take(1):
      self.assertNotAllClose(x, first_seq)
    # Check determinism with same seed
    dataset = timeseries.timeseries_dataset_from_array(
        data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123)
    for x, _ in dataset.take(1):
      self.assertAllClose(x, first_seq)

  def test_sampling_rate(self):
    """Checks sequences built with `sampling_rate=2`."""
    data = np.arange(100)
    targets = data * 2
    dataset = timeseries.timeseries_dataset_from_array(
        data, targets, sequence_length=9, batch_size=5, sampling_rate=2)
    num_batches = 0
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      inputs, batch_targets = batch
      if i < 16:
        self.assertEqual(inputs.shape, (5, 9))
      if i == 16:
        # Last batch: size 3
        self.assertEqual(inputs.shape, (3, 9))
      # Check target values
      self.assertAllClose(inputs[:, 0] * 2, batch_targets)
      for j in range(min(5, len(inputs))):
        # Check each sample in the batch
        start_index = i * 5 + j
        end_index = start_index + 9 * 2
        self.assertAllClose(inputs[j], np.arange(start_index, end_index, 2))
      num_batches += 1
    # 16 full batches plus the final partial batch.
    self.assertEqual(num_batches, 17)

  def test_sequence_stride(self):
    """Checks sequences built with `sequence_stride=3`."""
    data = np.arange(100)
    targets = data * 2
    dataset = timeseries.timeseries_dataset_from_array(
        data, targets, sequence_length=9, batch_size=5, sequence_stride=3)
    num_batches = 0
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      inputs, batch_targets = batch
      if i < 6:
        self.assertEqual(inputs.shape, (5, 9))
      if i == 6:
        # Last batch: size 1
        self.assertEqual(inputs.shape, (1, 9))
      # Check target values
      self.assertAllClose(inputs[:, 0] * 2, batch_targets)
      for j in range(min(5, len(inputs))):
        # Consecutive sequences start 3 steps apart.
        start_index = i * 5 * 3 + j * 3
        end_index = start_index + 9
        self.assertAllClose(inputs[j],
                            np.arange(start_index, end_index))
      num_batches += 1
    # 6 full batches plus the final batch of size 1.
    self.assertEqual(num_batches, 7)

  def test_start_and_end_index(self):
    """Checks that produced values stay within `start_index`/`end_index`."""
    data = np.arange(100)
    dataset = timeseries.timeseries_dataset_from_array(
        data, None,
        sequence_length=9, batch_size=5, sequence_stride=3, sampling_rate=2,
        start_index=10, end_index=90)
    num_batches = 0
    for batch in dataset:
      # Validate every sequence in the batch, not only the first one.
      self.assertAllLess(batch, 90)
      self.assertAllGreater(batch, 9)
      num_batches += 1
    # The range checks above are meaningless if nothing was iterated.
    self.assertGreater(num_batches, 0)

  def test_errors(self):
    """Checks that invalid arguments raise `ValueError` with a clear message."""
    # bad start index
    with self.assertRaisesRegex(ValueError, 'start_index must be '):
      _ = timeseries.timeseries_dataset_from_array(
          np.arange(10), None, 3, start_index=-1)
    with self.assertRaisesRegex(ValueError, 'start_index must be '):
      _ = timeseries.timeseries_dataset_from_array(
          np.arange(10), None, 3, start_index=11)
    # bad end index
    with self.assertRaisesRegex(ValueError, 'end_index must be '):
      _ = timeseries.timeseries_dataset_from_array(
          np.arange(10), None, 3, end_index=-1)
    with self.assertRaisesRegex(ValueError, 'end_index must be '):
      _ = timeseries.timeseries_dataset_from_array(
          np.arange(10), None, 3, end_index=11)
    # bad sampling_rate
    with self.assertRaisesRegex(ValueError, 'sampling_rate must be '):
      _ = timeseries.timeseries_dataset_from_array(
          np.arange(10), None, 3, sampling_rate=0)
    # bad sequence stride
    with self.assertRaisesRegex(ValueError, 'sequence_stride must be '):
      _ = timeseries.timeseries_dataset_from_array(
          np.arange(10), None, 3, sequence_stride=0)
181
182
if __name__ == '__main__':
  # Script entry point: switch on v2 behavior through `v2_compat` first, then
  # hand control to the test runner so the tests execute under that mode.
  v2_compat.enable_v2_behavior()
  test.main()
186