1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for state space model infrastructure."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23import numpy
24
25from tensorflow.contrib import layers
26
27from tensorflow.contrib.timeseries.python.timeseries import estimators
28from tensorflow.contrib.timeseries.python.timeseries import feature_keys
29from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
30from tensorflow.contrib.timeseries.python.timeseries import math_utils
31from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
32from tensorflow.contrib.timeseries.python.timeseries import state_management
33from tensorflow.contrib.timeseries.python.timeseries import test_utils
34from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model
35
36from tensorflow.python.estimator import estimator_lib
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import random_seed
41from tensorflow.python.framework import tensor_shape
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import linalg_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import variable_scope
46from tensorflow.python.ops import variables
47from tensorflow.python.platform import test
48from tensorflow.python.saved_model import loader
49from tensorflow.python.saved_model import tag_constants
50from tensorflow.python.training import coordinator as coordinator_lib
51from tensorflow.python.training import gradient_descent
52from tensorflow.python.training import queue_runner_impl
53
54
class RandomStateSpaceModel(state_space_model.StateSpaceModel):
  """State space model whose matrices are fixed random draws.

  The transition, noise transform and observation matrices are sampled from a
  standard normal distribution at construction time and cast to the
  configured dtype. The configuration's `covariance_prior_fn` is replaced with
  one that always returns 0.
  """

  def __init__(self,
               state_dimension,
               state_noise_dimension,
               configuration=state_space_model.StateSpaceModelConfiguration()):
    numpy_dtype = configuration.dtype.as_numpy_dtype

    def _draw(shape):
      return numpy.random.normal(size=shape).astype(numpy_dtype)

    # Draw order matters for tests that seed numpy's global RNG: keep
    # transition, then noise transform, then observation model.
    self.transition = _draw((state_dimension, state_dimension))
    self.noise_transform = _draw((state_dimension, state_noise_dimension))
    # No leading batch dimension here, which exercises batch broadcasting.
    self.observation_model = _draw(
        (configuration.num_features, state_dimension))
    zero_prior_configuration = configuration._replace(
        covariance_prior_fn=lambda _: 0.)
    super(RandomStateSpaceModel, self).__init__(
        configuration=zero_prior_configuration)

  def get_state_transition(self):
    """Returns the fixed random state transition matrix."""
    return self.transition

  def get_noise_transform(self):
    """Returns the fixed random noise transform matrix."""
    return self.noise_transform

  def get_observation_model(self, times):
    """Returns the fixed observation matrix; `times` is ignored."""
    del times  # Unused: the observation model does not vary over time.
    return self.observation_model
83
84
class ConstructionTests(test.TestCase):
  """Checks that losses cannot be defined before `initialize_graph`."""

  def _two_step_features(self):
    """Returns a single two-step univariate series as constant tensors."""
    return {
        feature_keys.TrainEvalFeatures.TIMES: constant_op.constant([[1, 2]]),
        feature_keys.TrainEvalFeatures.VALUES:
            constant_op.constant([[[1.], [2.]]]),
    }

  def _evaluate_loss(self, outputs):
    """Initializes all global variables and evaluates `outputs.loss`."""
    init_op = variables.global_variables_initializer()
    with self.cached_session() as sess:
      sess.run([init_op])
      outputs.loss.eval()

  def test_initialize_graph_error(self):
    with self.assertRaisesRegexp(ValueError, "initialize_graph"):
      uninitialized_model = RandomStateSpaceModel(2, 2)
      loss_outputs = uninitialized_model.define_loss(
          features=self._two_step_features(),
          mode=estimator_lib.ModeKeys.TRAIN)
      self._evaluate_loss(loss_outputs)

  def test_initialize_graph_state_manager_error(self):
    with self.assertRaisesRegexp(ValueError, "initialize_graph"):
      uninitialized_model = RandomStateSpaceModel(2, 2)
      chaining_manager = state_management.ChainingStateManager()
      loss_outputs = chaining_manager.define_loss(
          model=uninitialized_model,
          features=self._two_step_features(),
          mode=estimator_lib.ModeKeys.TRAIN)
      self._evaluate_loss(loss_outputs)
120
121
class GapTests(test.TestCase):
  """Checks that the loss can be evaluated when `times` has gaps."""

  def _gap_test_template(self, times, values):
    """Evaluates a random model's batch loss on one (possibly gappy) series.

    Args:
      times: Sequence of time steps for the series; the tests below pass
        increasing integers with gaps at the start, middle and/or end.
      values: Sequence of observed values, one per entry of `times`.
    """
    random_model = RandomStateSpaceModel(
        state_dimension=1, state_noise_dimension=1,
        configuration=state_space_model.StateSpaceModelConfiguration(
            num_features=1))
    random_model.initialize_graph()
    input_fn = input_pipeline.WholeDatasetInputFn(
        input_pipeline.NumpyReader({
            feature_keys.TrainEvalFeatures.TIMES: times,
            feature_keys.TrainEvalFeatures.VALUES: values
        }))
    features, _ = input_fn()
    # Keep the tensors separate from the raw `times`/`values` arguments.
    times_tensor = features[feature_keys.TrainEvalFeatures.TIMES]
    values_tensor = features[feature_keys.TrainEvalFeatures.VALUES]
    model_outputs = random_model.get_batch_loss(
        features={
            feature_keys.TrainEvalFeatures.TIMES: times_tensor,
            feature_keys.TrainEvalFeatures.VALUES: values_tensor
        },
        mode=None,
        state=math_utils.replicate_state(
            start_state=random_model.get_start_state(),
            batch_size=array_ops.shape(times_tensor)[0]))
    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      try:
        model_outputs.loss.eval()
      finally:
        # Stop and join the queue-runner threads even when evaluation fails,
        # so they do not outlive the session.
        coordinator.request_stop()
        coordinator.join()

  def test_start_gap(self):
    self._gap_test_template(times=[20, 21, 22], values=numpy.arange(3))

  def test_mid_gap(self):
    self._gap_test_template(times=[2, 60, 61], values=numpy.arange(3))

  def test_end_gap(self):
    self._gap_test_template(times=[2, 3, 73], values=numpy.arange(3))

  def test_all_gaps(self):
    self._gap_test_template(times=[2, 4, 8, 16, 32, 64, 128],
                            values=numpy.arange(7))
167
168
class StateSpaceEquivalenceTests(test.TestCase):
  """Checks that split/chained filtering matches filtering in one pass."""

  def test_savedmodel_state_override(self):
    """Split filtering through a SavedModel matches combined filtering.

    Trains a random model for one step and exports it. Starting from the
    model's evaluated start state, it compares two runs:
      * filtering times [1, 2], then continuing with [3, 4];
      * filtering [1, 2, 3, 4] in a single call.
    Both the final filtering state and a one-step prediction must agree.
    """
    random_model = RandomStateSpaceModel(
        state_dimension=5,
        state_noise_dimension=4,
        configuration=state_space_model.StateSpaceModelConfiguration(
            exogenous_feature_columns=[layers.real_valued_column("exogenous")],
            dtype=dtypes.float64, num_features=1))
    estimator = estimators.StateSpaceRegressor(
        model=random_model,
        optimizer=gradient_descent.GradientDescentOptimizer(0.1))
    combined_input_fn = input_pipeline.WholeDatasetInputFn(
        input_pipeline.NumpyReader({
            feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
            feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.],
            "exogenous": [-1., -2., -3., -4.]
        }))
    estimator.train(combined_input_fn, steps=1)
    export_location = estimator.export_saved_model(
        self.get_temp_dir(), estimator.build_raw_serving_input_receiver_fn())
    # Evaluate a start state in a fresh graph (from freshly initialized
    # variables). The same concrete state is fed to both the split and the
    # combined filtering runs below.
    with ops.Graph().as_default() as graph:
      random_model.initialize_graph()
      with self.session(graph=graph) as session:
        variables.global_variables_initializer().run()
        evaled_start_state = session.run(random_model.get_start_state())
    # Add a leading batch dimension of size 1 to every state element.
    evaled_start_state = [
        state_element[None, ...] for state_element in evaled_start_state]
    with ops.Graph().as_default() as graph:
      with self.session(graph=graph) as session:
        signatures = loader.load(
            session, [tag_constants.SERVING], export_location)
        # Split run, part one: times 1-2 starting from the start state.
        first_split_filtering = saved_model_utils.filter_continuation(
            continue_from={
                feature_keys.FilteringResults.STATE_TUPLE: evaled_start_state},
            signatures=signatures,
            session=session,
            features={
                feature_keys.FilteringFeatures.TIMES: [1, 2],
                feature_keys.FilteringFeatures.VALUES: [1., 2.],
                "exogenous": [[-1.], [-2.]]})
        # Split run, part two: times 3-4 continuing from part one's output.
        second_split_filtering = saved_model_utils.filter_continuation(
            continue_from=first_split_filtering,
            signatures=signatures,
            session=session,
            features={
                feature_keys.FilteringFeatures.TIMES: [3, 4],
                feature_keys.FilteringFeatures.VALUES: [3., 4.],
                "exogenous": [[-3.], [-4.]]
            })
        # Combined run: times 1-4 in one call from the same start state.
        combined_filtering = saved_model_utils.filter_continuation(
            continue_from={
                feature_keys.FilteringResults.STATE_TUPLE: evaled_start_state},
            signatures=signatures,
            session=session,
            features={
                feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
                feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.],
                "exogenous": [[-1.], [-2.], [-3.], [-4.]]
            })
        # One-step predictions from each run's final state, using the same
        # exogenous value for both.
        split_predict = saved_model_utils.predict_continuation(
            continue_from=second_split_filtering,
            signatures=signatures,
            session=session,
            steps=1,
            exogenous_features={
                "exogenous": [[[-5.]]]})
        combined_predict = saved_model_utils.predict_continuation(
            continue_from=combined_filtering,
            signatures=signatures,
            session=session,
            steps=1,
            exogenous_features={
                "exogenous": [[[-5.]]]})
    for state_key, combined_state_value in combined_filtering.items():
      # Skip the TIMES entry. It presumably reflects the input times, which
      # differ between the runs ([1, 2, 3, 4] vs. [3, 4]). Compare the rest.
      if state_key == feature_keys.FilteringResults.TIMES:
        continue
      self.assertAllClose(
          combined_state_value, second_split_filtering[state_key])
    for prediction_key, combined_value in combined_predict.items():
      self.assertAllClose(combined_value, split_predict[prediction_key])

  def _equivalent_to_single_model_test_template(self, model_generator):
    """Compares one-pass filtering with the outputs of `model_generator`.

    Args:
      model_generator: Callable taking `(random_model, model_data)` and
        returning a function of a session. That function evaluates
        `(loss, posteriors, predictions)`, where `posteriors` is
        `(means, variances, times)`. Index 0 of `means`/`variances` is
        compared against the single-pass end state.
    """
    with self.cached_session() as session:
      random_model = RandomStateSpaceModel(
          state_dimension=5,
          state_noise_dimension=4,
          configuration=state_space_model.StateSpaceModelConfiguration(
              dtype=dtypes.float64, num_features=1))
      random_model.initialize_graph()
      series_length = 10
      model_data = random_model.generate(
          number_of_series=1, series_length=series_length,
          model_parameters=random_model.random_model_parameters())
      input_fn = input_pipeline.WholeDatasetInputFn(
          input_pipeline.NumpyReader(model_data))
      features, _ = input_fn()
      # Reference: filter the whole generated series in a single pass.
      model_outputs = random_model.get_batch_loss(
          features=features,
          mode=None,
          state=math_utils.replicate_state(
              start_state=random_model.get_start_state(),
              batch_size=array_ops.shape(
                  features[feature_keys.TrainEvalFeatures.TIMES])[0]))
      variables.global_variables_initializer().run()
      compare_outputs_evaled_fn = model_generator(
          random_model, model_data)
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      compare_outputs_evaled = compare_outputs_evaled_fn(session)
      model_outputs_evaled = session.run(
          (model_outputs.end_state, model_outputs.predictions))
      coordinator.request_stop()
      coordinator.join()
      model_posteriors, model_predictions = model_outputs_evaled
      (_, compare_posteriors,
       compare_predictions) = compare_outputs_evaled
      (model_posterior_mean, model_posterior_var,
       model_from_time) = model_posteriors
      (compare_posterior_mean, compare_posterior_var,
       compare_from_time) = compare_posteriors
      self.assertAllClose(model_posterior_mean, compare_posterior_mean[0])
      self.assertAllClose(model_posterior_var, compare_posterior_var[0])
      self.assertAllClose(model_from_time, compare_from_time)
      self.assertEqual(sorted(model_predictions.keys()),
                       sorted(compare_predictions.keys()))
      for prediction_name in model_predictions:
        if prediction_name == "loss":
          # Chunking means that losses will be different; skip testing them.
          continue
        # Compare the last chunk to their corresponding un-chunked model
        # predictions
        last_prediction_chunk = compare_predictions[prediction_name][-1]
        comparison_values = last_prediction_chunk.shape[0]
        model_prediction = (
            model_predictions[prediction_name][0, -comparison_values:])
        self.assertAllClose(model_prediction,
                            last_prediction_chunk)

  def _model_equivalent_to_chained_model_test_template(self, chunk_size):
    """Chained filtering over `chunk_size` windows matches one-pass filtering.

    Args:
      chunk_size: Window size fed through a `ChainingStateManager`.
    """
    def chained_model_outputs(original_model, data):
      input_fn = test_utils.AllWindowInputFn(
          input_pipeline.NumpyReader(data), window_size=chunk_size)
      state_manager = state_management.ChainingStateManager(
          state_saving_interval=1)
      features, _ = input_fn()
      state_manager.initialize_graph(original_model)
      model_outputs = state_manager.define_loss(
          model=original_model,
          features=features,
          mode=estimator_lib.ModeKeys.TRAIN)
      def _eval_outputs(session):
        for _ in range(50):
          # Warm up saved state
          # NOTE(review): presumably 50 evaluations are enough for the window
          # input to cover the series so chained state reaches its end.
          model_outputs.loss.eval()
        (posterior_mean, posterior_var,
         priors_from_time) = model_outputs.end_state
        # Wrap mean/variance in 1-tuples to match the unpacking (index 0) in
        # `_equivalent_to_single_model_test_template`.
        posteriors = ((posterior_mean,), (posterior_var,), priors_from_time)
        outputs = (model_outputs.loss, posteriors,
                   model_outputs.predictions)
        chunked_outputs_evaled = session.run(outputs)
        return chunked_outputs_evaled
      return _eval_outputs
    self._equivalent_to_single_model_test_template(chained_model_outputs)

  def test_model_equivalent_to_chained_model_chunk_size_one(self):
    numpy.random.seed(2)
    random_seed.set_random_seed(3)
    self._model_equivalent_to_chained_model_test_template(1)

  def test_model_equivalent_to_chained_model_chunk_size_five(self):
    numpy.random.seed(4)
    random_seed.set_random_seed(5)
    self._model_equivalent_to_chained_model_test_template(5)
343
344
class PredictionTests(test.TestCase):
  """Checks shapes and uncertainty growth of `predict` outputs."""

  def _check_predictions(
      self, predicted_mean, predicted_covariance, window_size):
    """Asserts prediction shapes and that variance grows two steps ahead.

    Args:
      predicted_mean: Evaluated means; must have shape [1, window_size, 1].
      predicted_covariance: Evaluated covariances; must have shape
        [1, window_size, 1, 1].
      window_size: Number of predicted time steps.
    """
    batch_size = 1
    num_features = 1
    self.assertAllEqual(
        predicted_covariance.shape,
        [batch_size, window_size, num_features, num_features])
    self.assertAllEqual(
        predicted_mean.shape, [batch_size, window_size, num_features])
    variances = predicted_covariance[0, :, 0, 0]
    # Each variance must exceed the one two positions earlier.
    for earlier, later in zip(variances, variances[2:]):
      self.assertGreater(later, earlier)

  def _random_model(self, dtype):
    """Builds a 5-dimensional single-feature random model in `dtype`."""
    return RandomStateSpaceModel(
        state_dimension=5, state_noise_dimension=4,
        configuration=state_space_model.StateSpaceModelConfiguration(
            dtype=dtype, num_features=1))

  def test_predictions_direct(self):
    dtype = dtypes.float64
    with variable_scope.variable_scope(dtype.name):
      model = self._random_model(dtype)
      model.initialize_graph()
      batched_start_state = math_utils.replicate_state(
          start_state=model.get_start_state(), batch_size=1)
      prediction_dict = model.predict(features={
          feature_keys.PredictionFeatures.TIMES: [[1, 3, 5, 6]],
          feature_keys.PredictionFeatures.STATE_TUPLE: batched_start_state,
      })
      with self.cached_session():
        variables.global_variables_initializer().run()
        mean_value = prediction_dict["mean"].eval()
        covariance_value = prediction_dict["covariance"].eval()
      self._check_predictions(mean_value, covariance_value, window_size=4)

  def test_predictions_after_loss(self):
    dtype = dtypes.float32
    with variable_scope.variable_scope(dtype.name):
      model = self._random_model(dtype)
      features = {
          feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3, 4]],
          feature_keys.TrainEvalFeatures.VALUES:
              array_ops.ones([1, 4, 1], dtype=dtype),
      }
      state_manager = state_management.PassthroughStateManager()
      model.initialize_graph()
      state_manager.initialize_graph(model)
      loss_outputs = state_manager.define_loss(
          model=model,
          features=features,
          mode=estimator_lib.ModeKeys.EVAL)
      # Predict forward from the state reached after filtering `features`.
      prediction_dict = model.predict({
          feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]],
          feature_keys.PredictionFeatures.STATE_TUPLE: loss_outputs.end_state,
      })
      with self.cached_session():
        variables.global_variables_initializer().run()
        mean_value = prediction_dict["mean"].eval()
        covariance_value = prediction_dict["covariance"].eval()
      self._check_predictions(mean_value, covariance_value, window_size=3)
412
413
class ExogenousTests(test.TestCase):
  """Checks that exogenous updates move state variances as expected."""

  def test_noise_increasing(self):
    """`_exogenous_noise_increasing` raises every state variance."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with variable_scope.variable_scope(dtype.name):
        random_model = RandomStateSpaceModel(
            state_dimension=5, state_noise_dimension=4,
            configuration=state_space_model.StateSpaceModelConfiguration(
                dtype=dtype, num_features=1))
        random_model.initialize_graph()
        # Build every input in `dtype`. With the default float32 tensors and
        # a plain Python list, the float64 iteration never actually runs in
        # float64. This mirrors test_noise_decreasing.
        original_covariance = array_ops.diag(
            array_ops.ones(shape=[5], dtype=dtype))
        _, new_covariance, _ = random_model._exogenous_noise_increasing(
            current_times=[[1]],
            exogenous_values=constant_op.constant([[5.]], dtype=dtype),
            state=[
                array_ops.ones(shape=[1, 5], dtype=dtype),
                original_covariance[None], [0]
            ])
        with self.cached_session() as session:
          variables.global_variables_initializer().run()
          evaled_new_covariance, evaled_original_covariance = session.run(
              [new_covariance[0], original_covariance])
          new_variances = numpy.diag(evaled_new_covariance)
          original_variances = numpy.diag(evaled_original_covariance)
          for i in range(5):
            self.assertGreater(new_variances[i], original_variances[i])

  def test_noise_decreasing(self):
    """`_exogenous_noise_decreasing` lowers every state variance."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with variable_scope.variable_scope(dtype.name):
        random_model = RandomStateSpaceModel(
            state_dimension=5, state_noise_dimension=4,
            configuration=state_space_model.StateSpaceModelConfiguration(
                dtype=dtype, num_features=1))
        random_model.initialize_graph()
        original_covariance = array_ops.diag(
            array_ops.ones(shape=[5], dtype=dtype))
        _, new_covariance, _ = random_model._exogenous_noise_decreasing(
            current_times=[[1]],
            exogenous_values=constant_op.constant([[-2.]], dtype=dtype),
            state=[
                -array_ops.ones(shape=[1, 5], dtype=dtype),
                original_covariance[None], [0]
            ])
        with self.cached_session() as session:
          variables.global_variables_initializer().run()
          evaled_new_covariance, evaled_original_covariance = session.run(
              [new_covariance[0], original_covariance])
          new_variances = numpy.diag(evaled_new_covariance)
          original_variances = numpy.diag(evaled_original_covariance)
          for i in range(5):
            self.assertLess(new_variances[i], original_variances[i])
464
465
class StubStateSpaceModel(state_space_model.StateSpaceModel):
  """State space model built around a caller-supplied transition matrix.

  The noise transform and a 1-D observation vector are drawn from a standard
  normal distribution, in that order, and cast to float32.
  """

  def __init__(self,
               transition,
               state_noise_dimension,
               configuration=state_space_model.StateSpaceModelConfiguration()):
    self.transition = transition
    state_size = transition.shape[0]
    noise_shape = (state_size, state_noise_dimension)
    self.noise_transform = numpy.random.normal(size=noise_shape).astype(
        numpy.float32)
    # A 1-D observation vector (no feature or batch dimension) exercises
    # broadcasting over both.
    self.observation_model = numpy.random.normal(size=state_size).astype(
        numpy.float32)
    super(StubStateSpaceModel, self).__init__(configuration=configuration)

  def get_state_transition(self):
    """Returns the transition matrix supplied at construction."""
    return self.transition

  def get_noise_transform(self):
    """Returns the randomly drawn noise transform."""
    return self.noise_transform

  def get_observation_model(self, times):
    """Returns the random observation vector; `times` is ignored."""
    del times  # Unused: the observation model does not vary over time.
    return self.observation_model
489
490
# Bundles a model with the data generated from it and the parameter values
# used for that generation.
GeneratedModel = collections.namedtuple(
    "GeneratedModel", "model data true_parameters")
493
494
class PosteriorTests(test.TestCase):
  """Checks exact posterior recovery for a noiseless cyclic stub model."""

  def _get_cycle_transition(self, period):
    """Builds a (period - 1) x (period - 1) cyclic transition matrix.

    The first row is all -1 and the remaining rows shift the state down by one.
    The new first component is therefore minus the sum of the current
    components, so any `period` consecutive values of the first component sum
    to zero.
    """
    cycle_transition = numpy.zeros([period - 1, period - 1],
                                   dtype=numpy.float32)
    cycle_transition[0, :] = -1
    cycle_transition[1:, :-1] = numpy.identity(period - 2)
    return cycle_transition

  # 2x2 transition whose first component accumulates the second
  # ([[1, 1], [0, 1]]); not referenced by the tests in this class.
  _adder_transition = numpy.array([[1, 1],
                                   [0, 1]], dtype=numpy.float32)

  def _get_single_model(self):
    """Builds a seeded period-5 stub model and 1000 steps of generated data.

    The model has no transition noise (`state_noise_dimension=0`).

    Returns:
      A `GeneratedModel` holding the initialized stub model, the generated
      data, and the parameter values used to generate it.
    """
    numpy.random.seed(8)
    stub_model = StubStateSpaceModel(
        transition=self._get_cycle_transition(5), state_noise_dimension=0)
    series_length = 1000
    stub_model.initialize_graph()
    true_params = stub_model.random_model_parameters()
    data = stub_model.generate(
        number_of_series=1, series_length=series_length,
        model_parameters=true_params)
    return GeneratedModel(
        model=stub_model, data=data, true_parameters=true_params)

  def test_exact_posterior_recovery_no_transition_noise(self):
    """Single-pass filtering recovers the noiseless state exactly."""
    with self.cached_session() as session:
      stub_model, data, true_params = self._get_single_model()
      input_fn = input_pipeline.WholeDatasetInputFn(
          input_pipeline.NumpyReader(data))
      features, _ = input_fn()
      model_outputs = stub_model.get_batch_loss(
          features=features,
          mode=None,
          state=math_utils.replicate_state(
              start_state=stub_model.get_start_state(),
              batch_size=array_ops.shape(
                  features[feature_keys.TrainEvalFeatures.TIMES])[0]))
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      posterior_mean, posterior_var, posterior_times = session.run(
          # Feed the true model parameters so that this test doesn't depend on
          # the generated parameters being close to the variable initializations
          # (an alternative would be training steps to fit the noise values,
          # which would be slow).
          model_outputs.end_state, feed_dict=true_params)
      coordinator.request_stop()
      coordinator.join()

      # Period 5 gives a 4-dimensional state; its posterior should collapse.
      self.assertAllClose(numpy.zeros([1, 4, 4]), posterior_var,
                          atol=1e-2)
      # Without transition noise, the final state is the true prior mean
      # propagated through the transition once per time step.
      self.assertAllClose(
          numpy.dot(
              numpy.linalg.matrix_power(
                  stub_model.transition,
                  data[feature_keys.TrainEvalFeatures.TIMES].shape[1]),
              true_params[stub_model.prior_state_mean]),
          posterior_mean[0],
          rtol=1e-1)
      self.assertAllClose(
          math_utils.batch_end_time(
              features[feature_keys.TrainEvalFeatures.TIMES]).eval(),
          posterior_times)

  def test_chained_exact_posterior_recovery_no_transition_noise(self):
    """Chained filtering over fixed-size windows recovers the same state."""
    with self.cached_session() as session:
      stub_model, data, true_params = self._get_single_model()
      chunk_size = 10
      input_fn = test_utils.AllWindowInputFn(
          input_pipeline.NumpyReader(data), window_size=chunk_size)
      features, _ = input_fn()
      state_manager = state_management.ChainingStateManager(
          state_saving_interval=1)
      state_manager.initialize_graph(stub_model)
      model_outputs = state_manager.define_loss(
          model=stub_model,
          features=features,
          mode=estimator_lib.ModeKeys.TRAIN)
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      # One loss evaluation per window of `chunk_size` steps in the series.
      for _ in range(
          data[feature_keys.TrainEvalFeatures.TIMES].shape[1] // chunk_size):
        model_outputs.loss.eval()
      posterior_mean, posterior_var, posterior_times = session.run(
          model_outputs.end_state, feed_dict=true_params)
      coordinator.request_stop()
      coordinator.join()
      self.assertAllClose(numpy.zeros([1, 4, 4]), posterior_var,
                          atol=1e-2)
      self.assertAllClose(
          numpy.dot(
              numpy.linalg.matrix_power(
                  stub_model.transition,
                  data[feature_keys.TrainEvalFeatures.TIMES].shape[1]),
              true_params[stub_model.prior_state_mean]),
          posterior_mean[0],
          rtol=1e-1)
      self.assertAllClose(data[feature_keys.TrainEvalFeatures.TIMES][:, -1],
                          posterior_times)
596
597
class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel):
  """A mostly trivial model which predicts values = times + 1.

  The single state component starts at 1 with zero prior covariance, and the
  observation matrix at each time is `times + 1`.
  """

  def __init__(self, static_unrolling_window_size_threshold=None):
    model_configuration = state_space_model.StateSpaceModelConfiguration(
        use_observation_noise=False,
        transition_covariance_initial_log_scale_bias=5.,
        static_unrolling_window_size_threshold=(
            static_unrolling_window_size_threshold))
    super(TimeDependentStateSpaceModel, self).__init__(
        configuration=model_configuration)

  def get_state_transition(self):
    """1x1 all-ones transition: the state carries over unchanged."""
    return array_ops.ones(shape=[1, 1])

  def get_noise_transform(self):
    """1x1 all-ones noise transform."""
    return array_ops.ones(shape=[1, 1])

  def get_observation_model(self, times):
    """Returns `times + 1` as float32, reshaped to one 1x1 matrix per time."""
    shifted_times = math_ops.cast(times + 1, dtypes.float32)
    return array_ops.reshape(tensor=shifted_times, shape=[-1, 1, 1])

  def make_priors(self):
    """Returns a prior state mean of [1.] with covariance [[0.]]."""
    prior_mean = ops.convert_to_tensor([1.])
    prior_covariance = ops.convert_to_tensor([[0.]])
    return (prior_mean, prior_covariance)
621
622
class UnknownShapeModel(TimeDependentStateSpaceModel):
  """Parent's observation model with its static shape information hidden."""

  def get_observation_model(self, times):
    """Wraps the parent's observation model in an unknown-shape placeholder."""
    known_shape_observation = super(
        UnknownShapeModel, self).get_observation_model(times)
    # The default value passes through unchanged; only the static shape is
    # erased, so the model must cope with an unknown-shape observation matrix.
    return array_ops.placeholder_with_default(
        input=known_shape_observation, shape=tensor_shape.unknown_shape())
629
630
class TimeDependentTests(test.TestCase):
  """End-to-end Estimator tests for time-dependent observation models."""

  def _time_dependency_test_template(self, model_type):
    """Test that a time-dependent observation model influences predictions."""
    estimator = estimators.StateSpaceRegressor(
        model=model_type(),
        optimizer=gradient_descent.GradientDescentOptimizer(0.1))
    values = numpy.array([1., 2., 3., 4.]).reshape([1, 4, 1])
    input_fn = input_pipeline.WholeDatasetInputFn(
        input_pipeline.NumpyReader({
            feature_keys.TrainEvalFeatures.TIMES: [[0, 1, 2, 3]],
            feature_keys.TrainEvalFeatures.VALUES: values
        }))
    estimator.train(input_fn=input_fn, max_steps=1)
    evaluation = estimator.evaluate(input_fn=input_fn, steps=1)
    predicted_values = evaluation["mean"]
    # Throw out the first value so we don't test the prior.
    # NOTE(review): `values` has shape [1, 4, 1], so `values[1:]` drops the
    # whole (size-1) batch and is empty; the assertion looks vacuous.
    # `[:, 1:]` was probably intended. Confirm the shape of the evaluated
    # "mean" before changing it.
    self.assertAllEqual(values[1:], predicted_values[1:])

  def test_undefined_shape_time_dependency(self):
    self._time_dependency_test_template(UnknownShapeModel)

  def test_loop_unrolling(self):
    """Tests running/restoring from a checkpoint with static unrolling."""
    # Unroll during training (windows of 2), but not during evaluation (the
    # whole 100-step series).
    estimator = estimators.StateSpaceRegressor(
        model=TimeDependentStateSpaceModel(
            static_unrolling_window_size_threshold=2))
    dataset = {
        feature_keys.TrainEvalFeatures.TIMES: numpy.arange(100),
        feature_keys.TrainEvalFeatures.VALUES: numpy.arange(100),
    }
    train_input_fn = input_pipeline.RandomWindowInputFn(
        input_pipeline.NumpyReader(dataset), batch_size=16, window_size=2)
    eval_input_fn = input_pipeline.WholeDatasetInputFn(
        input_pipeline.NumpyReader(dataset))
    estimator.train(input_fn=train_input_fn, max_steps=1)
    estimator.evaluate(input_fn=eval_input_fn, steps=1)
671
672
class LevelOnlyModel(state_space_model.StateSpaceModel):
  """Univariate model with a single level state and identity dynamics."""

  def get_state_transition(self):
    """1x1 identity transition: the level carries over unchanged."""
    return linalg_ops.eye(num_rows=1, dtype=self.dtype)

  def get_noise_transform(self):
    """1x1 identity noise transform: noise is added directly to the level."""
    return linalg_ops.eye(num_rows=1, dtype=self.dtype)

  def get_observation_model(self, times):
    """Observes the level directly, regardless of `times`."""
    del times  # Unused: the observation model does not vary over time.
    return [1]
683
684
class MultivariateLevelModel(
    state_space_model.StateSpaceCorrelatedFeaturesEnsemble):
  """Correlated ensemble with one `LevelOnlyModel` per feature."""

  def __init__(self, configuration):
    single_feature_configuration = configuration._replace(num_features=1)

    def build_component(feature_index):
      # Each component gets its own variable scope ("feature0", ...).
      with variable_scope.variable_scope("feature{}".format(feature_index)):
        return LevelOnlyModel(configuration=single_feature_configuration)

    members = [build_component(index)
               for index in range(configuration.num_features)]
    super(MultivariateLevelModel, self).__init__(
        ensemble_members=members, configuration=configuration)
698
699
class MultivariateTests(test.TestCase):
  """Tests for the correlated-features ensemble of level models."""

  def test_multivariate(self):
    """Trains on correlated random walks; statistics reach every member."""
    dtype = dtypes.float32
    num_features = 3
    # Unit variances plus one non-zero off-diagonal pair (first and last
    # features) in the true transition noise covariance.
    covariance = numpy.eye(num_features)
    covariance[[0, -1], [-1, 0]] = 1.
    dataset_size = 100
    increments = numpy.random.multivariate_normal(
        mean=numpy.zeros(num_features),
        cov=covariance,
        size=dataset_size)
    values = numpy.cumsum(increments, axis=0)
    times = numpy.arange(dataset_size)
    model = MultivariateLevelModel(
        configuration=state_space_model.StateSpaceModelConfiguration(
            num_features=num_features,
            dtype=dtype,
            use_observation_noise=False,
            transition_covariance_initial_log_scale_bias=5.))
    estimator = estimators.StateSpaceRegressor(
        model=model, optimizer=gradient_descent.GradientDescentOptimizer(0.1))
    reader = input_pipeline.NumpyReader({
        feature_keys.TrainEvalFeatures.TIMES: times,
        feature_keys.TrainEvalFeatures.VALUES: values,
    })
    train_input_fn = input_pipeline.RandomWindowInputFn(
        reader, batch_size=16, window_size=16)
    estimator.train(input_fn=train_input_fn, steps=1)
    # Input statistics must have been propagated to each component model.
    for member in model._ensemble_members:
      self.assertTrue(member._input_statistics)

  def test_ensemble_observation_noise(self):
    """A default-configured ensemble can define and evaluate its loss."""
    model = MultivariateLevelModel(
        configuration=state_space_model.StateSpaceModelConfiguration())
    model.initialize_graph()
    two_step_features = {
        feature_keys.TrainEvalFeatures.TIMES: constant_op.constant([[1, 2]]),
        feature_keys.TrainEvalFeatures.VALUES:
            constant_op.constant([[[1.], [2.]]]),
    }
    loss_outputs = model.define_loss(
        features=two_step_features, mode=estimator_lib.ModeKeys.TRAIN)
    init_op = variables.global_variables_initializer()
    with self.cached_session() as sess:
      sess.run([init_op])
      loss_outputs.loss.eval()
753
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
756