1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Kalman filtering."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy
22
23from tensorflow.contrib.timeseries.python.timeseries import math_utils
24from tensorflow.contrib.timeseries.python.timeseries.state_space_models import kalman_filter
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import linalg_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.platform import test
32
33
# Shared fixture: a two-dimensional state model whose state vector is
# [level, slope].
STATE_TRANSITION = [
    [1., 1.],  # New level = old level + slope
    [0., 1.]   # Slope carries over unchanged
]
# Transition noise covariance. It is diagonal, so the noise added to each
# state component is independent of the other.
STATE_TRANSITION_NOISE = [[0.1, 0.0], [0.0, 0.2]]
# Observation model, nested one level deeper than the other matrices (1x2x2);
# the batch test tiles it along that leading dimension. The first observed
# value is half the level, the second is the slope itself.
OBSERVATION_MODEL = [[[0.5, 0.0], [0.0, 1.0]]]
# Observation noise covariance (diagonal).
OBSERVATION_NOISE = [[0.0001, 0.], [0., 0.0002]]
# Identity transform, so the transition noise covariance above is used as-is.
STATE_NOISE_TRANSFORM = [[1.0, 0.0], [0.0, 1.0]]
44
45
46def _powers_and_sums_from_transition_matrix(
47    state_transition, state_transition_noise_covariance,
48    state_noise_transform, max_gap=1):
49  def _transition_matrix_powers(powers):
50    return math_utils.matrix_to_powers(state_transition, powers)
51  def _power_sums(num_steps):
52    power_sums_tensor = math_utils.power_sums_tensor(
53        max_gap + 1, state_transition,
54        math_ops.matmul(state_noise_transform,
55                        math_ops.matmul(
56                            state_transition_noise_covariance,
57                            state_noise_transform,
58                            adjoint_b=True)))
59    return array_ops.gather(power_sums_tensor, indices=num_steps)
60  return (_transition_matrix_powers, _power_sums)
61
62
class MultivariateTests(test.TestCase):
  """Symmetry checks for covariances in a model with two observed features."""

  def _multivariate_symmetric_covariance_test_template(
      self, dtype, simplified_posterior_variance_computation):
    """Check that errors aren't building up asymmetries in covariances.

    Builds one predict/update step of a four-dimensional state model with a
    two-dimensional observation, then runs it 500 times, feeding each
    posterior back in as the next prior. After every step both the predicted
    observation covariance and the posterior state covariance must equal
    their own transposes.

    Args:
      dtype: Floating point dtype used for the filter and every tensor.
      simplified_posterior_variance_computation: Forwarded to the
        `KalmanFilter` constructor so that the calling test really exercises
        the variance-update variant it names.
    """
    # The flag must be passed through; otherwise every caller silently tests
    # the constructor's default variance computation.
    kf = kalman_filter.KalmanFilter(
        dtype=dtype,
        simplified_posterior_variance_computation=(
            simplified_posterior_variance_computation))
    observation_noise_covariance = constant_op.constant(
        [[1., 0.5], [0.5, 1.]], dtype=dtype)
    # Observes state elements 0 and 2.
    observation_model = constant_op.constant(
        [[[1., 0., 0., 0.], [0., 0., 1., 0.]]], dtype=dtype)
    state = array_ops.placeholder(shape=[1, 4], dtype=dtype)
    state_var = array_ops.placeholder(shape=[1, 4, 4], dtype=dtype)
    observation = array_ops.placeholder(shape=[1, 2], dtype=dtype)
    transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix(
        state_transition=constant_op.constant(
            [[1., 1., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 1.],
             [0., 0., 0., 1.]],
            dtype=dtype),
        state_noise_transform=linalg_ops.eye(4, dtype=dtype),
        state_transition_noise_covariance=constant_op.constant(
            [[1., 0., 0.5, 0.], [0., 1., 0., 0.5], [0.5, 0., 1., 0.],
             [0., 0.5, 0., 1.]],
            dtype=dtype))
    # The same one-step transition drives both the mean and the covariance.
    one_step_transition = transition_fn([1])
    pred_state = kf.predict_state_mean(
        prior_state=state, transition_matrices=one_step_transition)
    pred_state_var = kf.predict_state_var(
        prior_state_var=state_var, transition_matrices=one_step_transition,
        transition_noise_sums=power_sum_fn([1]))
    observed_mean, observed_var = kf.observed_from_state(
        state_mean=pred_state, state_var=pred_state_var,
        observation_model=observation_model,
        observation_noise=observation_noise_covariance)
    post_state, post_state_var = kf.posterior_from_prior_state(
        prior_state=pred_state, prior_state_var=pred_state_var,
        observation=observation,
        observation_model=observation_model,
        predicted_observations=(observed_mean, observed_var),
        observation_noise=observation_noise_covariance)
    with self.cached_session() as session:
      evaled_state = numpy.array([[1., 1., 1., 1.]])
      evaled_state_var = numpy.eye(4)[None]
      for i in range(500):
        # Each posterior becomes the prior of the next step; the observation
        # grows linearly with the step index.
        evaled_state, evaled_state_var, evaled_observed_var = session.run(
            [post_state, post_state_var, observed_var],
            feed_dict={state: evaled_state,
                       state_var: evaled_state_var,
                       observation: [[float(i), float(i)]]})
        self.assertAllClose(evaled_observed_var[0],
                            evaled_observed_var[0].T)
        self.assertAllClose(evaled_state_var[0],
                            evaled_state_var[0].T)

  def test_multivariate_symmetric_covariance_float32(self):
    """Symmetry holds in float32 without the simplified variance update."""
    self._multivariate_symmetric_covariance_test_template(
        dtypes.float32, simplified_posterior_variance_computation=False)

  def test_multivariate_symmetric_covariance_float64(self):
    """Symmetry holds in float64 with the simplified variance update."""
    self._multivariate_symmetric_covariance_test_template(
        dtypes.float64, simplified_posterior_variance_computation=True)
122
123
class KalmanFilterNonBatchTest(test.TestCase):
  """Single-batch KalmanFilter tests."""

  def setUp(self):
    """The basic model defined above, with unit batches."""
    # Chain up first: overriding setUp without calling the base class would
    # skip whatever per-test initialization `test.TestCase` performs.
    super(KalmanFilterNonBatchTest, self).setUp()
    self.kalman_filter = kalman_filter.KalmanFilter()
    self.transition_fn, self.power_sum_fn = (
        _powers_and_sums_from_transition_matrix(
            state_transition=STATE_TRANSITION,
            state_transition_noise_covariance=STATE_TRANSITION_NOISE,
            state_noise_transform=STATE_NOISE_TRANSFORM,
            max_gap=5))

  def test_observed_from_state(self):
    """Compare observation mean and noise to hand-computed values."""
    with self.cached_session():
      state = constant_op.constant([[2., 1.]])
      state_var = constant_op.constant([[[4., 0.], [0., 3.]]])
      observed_mean, observed_var = self.kalman_filter.observed_from_state(
          state, state_var,
          observation_model=OBSERVATION_MODEL,
          observation_noise=OBSERVATION_NOISE)
      # Same computation with the observation noise scaled up by 100 and given
      # an extra leading dimension.
      observed_mean_override, observed_var_override = (
          self.kalman_filter.observed_from_state(
              state, state_var,
              observation_model=OBSERVATION_MODEL,
              observation_noise=100 * constant_op.constant(
                  OBSERVATION_NOISE)[None]))
      # Mean: [0.5 * 2, 1 * 1]; it does not depend on the noise.
      self.assertAllClose(numpy.array([[1., 1.]]),
                          observed_mean.eval())
      self.assertAllClose(numpy.array([[1., 1.]]),
                          observed_mean_override.eval())
      # Variance: diag(0.5^2 * 4, 1^2 * 3) plus the observation noise.
      self.assertAllClose(numpy.array([[[1.0001, 0.], [0., 3.0002]]]),
                          observed_var.eval())
      self.assertAllClose(numpy.array([[[1.01, 0.], [0., 3.02]]]),
                          observed_var_override.eval())

  def _posterior_from_prior_state_test_template(
      self, state, state_var, observation, observation_model, observation_noise,
      expected_state, expected_state_var):
    """Test that repeated observations converge to the expected value.

    Builds a single posterior update, then applies it 300 times with the same
    observation, feeding each posterior back in as the next prior.

    Args:
      state: Tensor holding the initial state mean; fed on later iterations.
      state_var: Tensor holding the initial state covariance; fed on later
        iterations.
      observation: The observation presented on every iteration.
      observation_model: Observation model passed to the filter.
      observation_noise: Observation noise covariance passed to the filter.
      expected_state: State mean expected after the last iteration.
      expected_state_var: State covariance expected after the last iteration.
    """
    predicted_observations = self.kalman_filter.observed_from_state(
        state, state_var, observation_model,
        observation_noise=observation_noise)
    state_update, state_var_update = (
        self.kalman_filter.posterior_from_prior_state(
            state, state_var, observation,
            observation_model=observation_model,
            predicted_observations=predicted_observations,
            observation_noise=observation_noise))
    with self.cached_session() as session:
      # Start from the values baked into the `state`/`state_var` tensors.
      evaled_state, evaled_state_var = session.run([state, state_var])
      for _ in range(300):
        # Feeding `state` and `state_var` overrides their values, so every
        # run applies one more update on top of the previous posterior.
        evaled_state, evaled_state_var = session.run(
            [state_update, state_var_update],
            feed_dict={state: evaled_state, state_var: evaled_state_var})
    self.assertAllClose(expected_state,
                        evaled_state,
                        atol=1e-5)
    self.assertAllClose(
        expected_state_var,
        evaled_state_var,
        atol=1e-5)

  def test_posterior_from_prior_state_univariate(self):
    """Scalar state: observing 1. through a model of 2. converges to 0.5."""
    self._posterior_from_prior_state_test_template(
        state=constant_op.constant([[0.3]]),
        state_var=constant_op.constant([[[1.]]]),
        observation=constant_op.constant([[1.]]),
        observation_model=[[[2.]]],
        observation_noise=[[[0.01]]],
        expected_state=numpy.array([[0.5]]),
        expected_state_var=[[[0.]]])

  def test_posterior_from_prior_state_univariate_unit_noise(self):
    """Scalar state with unit observation noise and a very wide prior."""
    self._posterior_from_prior_state_test_template(
        state=constant_op.constant([[0.3]]),
        state_var=constant_op.constant([[[1e10]]]),
        observation=constant_op.constant([[1.]]),
        observation_model=[[[2.]]],
        observation_noise=[[[1.0]]],
        expected_state=numpy.array([[0.5]]),
        # 300 observations, each scaled by the observation model (2.).
        expected_state_var=[[[1. / (300. * 2. ** 2)]]])

  def test_posterior_from_prior_state_multivariate_2d(self):
    """Two-dimensional state observed through the shared fixture model."""
    self._posterior_from_prior_state_test_template(
        state=constant_op.constant([[1.9, 1.]]),
        state_var=constant_op.constant([[[1., 0.], [0., 2.]]]),
        observation=constant_op.constant([[1., 1.]]),
        observation_model=OBSERVATION_MODEL,
        observation_noise=OBSERVATION_NOISE,
        # Observation divided by the model's diagonal: [1 / 0.5, 1 / 1].
        expected_state=numpy.array([[2., 1.]]),
        expected_state_var=[[[0., 0.], [0., 0.]]])

  def test_posterior_from_prior_state_multivariate_3d(self):
    """Three-dimensional state with widely differing observation scales."""
    self._posterior_from_prior_state_test_template(
        state=constant_op.constant([[1.9, 1., 5.]]),
        state_var=constant_op.constant(
            [[[200., 0., 1.], [0., 2000., 0.], [1., 0., 40000.]]]),
        observation=constant_op.constant([[1., 1., 3.]]),
        observation_model=constant_op.constant(
            [[[0.5, 0., 0.],
              [0., 10., 0.],
              [0., 0., 100.]]]),
        observation_noise=linalg_ops.eye(3) / 10000.,
        # Observation divided by the model's diagonal:
        # [1 / 0.5, 1 / 10, 3 / 100].
        expected_state=numpy.array([[2., .1, .03]]),
        expected_state_var=numpy.zeros([1, 3, 3]))

  def test_predict_state_mean(self):
    """Compare state mean transitions with simple hand-computed values."""
    with self.cached_session():
      state = constant_op.constant([[4., 2.]])
      # Three single-step transitions in total.
      for _ in range(3):
        state = self.kalman_filter.predict_state_mean(
            state, self.transition_fn([1]))
      self.assertAllClose(
          numpy.array([[2. * 3. + 4.,  # Slope * time + base
                        2.]]),
          state.eval())

  def test_predict_state_var(self):
    """Compare a variance transition with simple hand-computed values."""
    with self.cached_session():
      state_var = constant_op.constant([[[1., 0.], [0., 2.]]])
      state_var = self.kalman_filter.predict_state_var(
          state_var, self.transition_fn([1]), self.power_sum_fn([1]))
      # F * P * F^T = [[3, 2], [2, 2]], plus transition noise diag(0.1, 0.2).
      self.assertAllClose(
          numpy.array([[[3.1, 2.0], [2.0, 2.2]]]),
          state_var.eval())

  def test_do_filter(self):
    """Tests do_filter.

    Tests that an observation matching the predicted state gets a high
    log probability when there is low uncertainty. The low-probability case
    is not checked here.
    """
    with self.cached_session():
      state = constant_op.constant([[4., 2.]])
      state_var = constant_op.constant([[[0.0001, 0.], [0., 0.0001]]])
      observation = constant_op.constant([[
          .5 * (
              4.  # Base
              + 2.),  # State transition
          2.
      ]])
      estimated_state = self.kalman_filter.predict_state_mean(
          state, self.transition_fn([1]))
      estimated_state_covariance = self.kalman_filter.predict_state_var(
          state_var, self.transition_fn([1]), self.power_sum_fn([1]))
      (predicted_observation,
       predicted_observation_covariance) = (
           self.kalman_filter.observed_from_state(
               estimated_state, estimated_state_covariance,
               observation_model=OBSERVATION_MODEL,
               observation_noise=OBSERVATION_NOISE))
      # Only the log probability is inspected; the posterior is discarded.
      (_, _, first_log_prob) = self.kalman_filter.do_filter(
          estimated_state=estimated_state,
          estimated_state_covariance=estimated_state_covariance,
          predicted_observation=predicted_observation,
          predicted_observation_covariance=predicted_observation_covariance,
          observation=observation,
          observation_model=OBSERVATION_MODEL,
          observation_noise=OBSERVATION_NOISE)
      self.assertGreater(first_log_prob.eval()[0], numpy.log(0.99))

  def test_predict_n_ahead_mean(self):
    """An i-step mean prediction matches i chained single-step predictions."""
    with self.cached_session():
      original_state = constant_op.constant([[4., 2.]])
      n = 5
      iterative_state = original_state
      for i in range(n):
        # On entry, `iterative_state` has been advanced i single steps.
        self.assertAllClose(
            iterative_state.eval(),
            self.kalman_filter.predict_state_mean(
                original_state,
                self.transition_fn([i])).eval())
        iterative_state = self.kalman_filter.predict_state_mean(
            iterative_state,
            self.transition_fn([1]))

  def test_predict_n_ahead_var(self):
    """An i-step variance prediction matches i chained single steps."""
    with self.cached_session():
      original_var = constant_op.constant([[[2., 3.], [4., 5.]]])
      n = 5
      iterative_var = original_var
      for i in range(n):
        # On entry, `iterative_var` has been advanced i single steps.
        self.assertAllClose(
            iterative_var.eval(),
            self.kalman_filter.predict_state_var(
                original_var,
                self.transition_fn([i]),
                self.power_sum_fn([i])).eval())
        iterative_var = self.kalman_filter.predict_state_var(
            iterative_var,
            self.transition_fn([1]),
            self.power_sum_fn([1]))
322
323
class KalmanFilterBatchTest(test.TestCase):
  """KalmanFilter tests with more than one element batches."""

  def test_do_filter_batch(self):
    """Tests do_filter, in batch mode.

    With low uncertainty, the two batch elements whose observations match the
    predicted state must get a high log probability, and the third element,
    whose observation does not match, a low one.
    """
    with self.cached_session():
      prior_means = constant_op.constant([[4., 2.], [5., 3.], [6., 4.]])
      prior_covariances = constant_op.constant(
          3 * [[[0.0001, 0.], [0., 0.0001]]])
      # Each matching row is [.5 * (base + slope), slope], i.e. what the
      # observation model yields after one state transition.
      matching_first = [.5 * (4. + 2.), 2.]
      matching_second = [.5 * (5. + 3.), 3.]
      mismatched_third = [3.14, 2.71]  # Low probability observation
      observations = constant_op.constant(
          [matching_first, matching_second, mismatched_third])
      batch_filter = kalman_filter.KalmanFilter()
      powers_fn, sums_fn = _powers_and_sums_from_transition_matrix(
          state_transition=STATE_TRANSITION,
          state_transition_noise_covariance=STATE_TRANSITION_NOISE,
          state_noise_transform=STATE_NOISE_TRANSFORM,
          max_gap=2)
      predicted_means = batch_filter.predict_state_mean(
          prior_means, powers_fn(3*[1]))
      predicted_covariances = batch_filter.predict_state_var(
          prior_covariances, powers_fn(3*[1]), sums_fn(3*[1]))
      # One copy of the observation model per batch element.
      tiled_observation_model = array_ops.tile(OBSERVATION_MODEL, [3, 1, 1])
      predicted_observation_moments = batch_filter.observed_from_state(
          predicted_means, predicted_covariances,
          observation_model=tiled_observation_model,
          observation_noise=OBSERVATION_NOISE)
      observation_mean, observation_covariance = predicted_observation_moments
      # Only the log probabilities are inspected; the posterior is discarded.
      _, _, log_probabilities = batch_filter.do_filter(
          estimated_state=predicted_means,
          estimated_state_covariance=predicted_covariances,
          predicted_observation=observation_mean,
          predicted_observation_covariance=observation_covariance,
          observation=observations,
          observation_model=tiled_observation_model,
          observation_noise=OBSERVATION_NOISE)
      first_log_prob, second_log_prob, third_log_prob = (
          log_probabilities.eval())
      high_threshold = numpy.log(0.99)
      self.assertGreater(first_log_prob.sum(), high_threshold)
      self.assertGreater(second_log_prob.sum(), high_threshold)
      self.assertLess(third_log_prob.sum(), numpy.log(0.01))

  def test_predict_n_ahead_mean(self):
    """Per-element step counts agree with chained single-step predictions."""
    with self.cached_session():
      batch_filter = kalman_filter.KalmanFilter()
      powers_fn, _ = _powers_and_sums_from_transition_matrix(
          state_transition=STATE_TRANSITION,
          state_transition_noise_covariance=STATE_TRANSITION_NOISE,
          state_noise_transform=STATE_NOISE_TRANSFORM,
          max_gap=2)
      initial_means = constant_op.constant([[4., 2.], [3., 1.], [6., 2.]])
      # Whole batch advanced by zero, one and two single steps.
      after_zero_steps = initial_means
      after_one_step = batch_filter.predict_state_mean(
          after_zero_steps, powers_fn(3 * [1]))
      after_two_steps = batch_filter.predict_state_mean(
          after_one_step, powers_fn(3 * [1]))
      # A different number of steps for each batch element: 1, 0 and 2.
      mixed_steps = batch_filter.predict_state_mean(
          initial_means, powers_fn([1, 0, 2])).eval()
      self.assertAllClose(after_zero_steps.eval()[1], mixed_steps[1])
      self.assertAllClose(after_one_step.eval()[0], mixed_steps[0])
      self.assertAllClose(after_two_steps.eval()[2], mixed_steps[2])

  def test_predict_n_ahead_var(self):
    """Per-element step counts agree with chained single-step variances."""
    with self.cached_session():
      batch_filter = kalman_filter.KalmanFilter()
      powers_fn, sums_fn = _powers_and_sums_from_transition_matrix(
          state_transition=STATE_TRANSITION,
          state_transition_noise_covariance=STATE_TRANSITION_NOISE,
          state_noise_transform=STATE_NOISE_TRANSFORM,
          max_gap=2)
      unit_covariance = 2.0 * numpy.identity(2) + numpy.ones([2, 2])
      # Batch element k starts from (k + 1) times the base covariance.
      stacked_covariances = numpy.array(
          [unit_covariance, 2.0 * unit_covariance, 3.0 * unit_covariance],
          dtype=numpy.float32)
      initial_covariances = constant_op.constant(stacked_covariances)
      # Whole batch advanced by zero, one and two single steps.
      after_zero_steps = initial_covariances
      after_one_step = batch_filter.predict_state_var(
          after_zero_steps, powers_fn(3 * [1]), sums_fn(3 * [1]))
      after_two_steps = batch_filter.predict_state_var(
          after_one_step, powers_fn(3 * [1]), sums_fn(3 * [1]))
      # A different number of steps for each batch element: 1, 0 and 2.
      mixed_steps = batch_filter.predict_state_var(
          initial_covariances,
          powers_fn([1, 0, 2]),
          sums_fn([1, 0, 2])).eval()
      self.assertAllClose(after_zero_steps.eval()[1], mixed_steps[1])
      self.assertAllClose(after_one_step.eval()[0], mixed_steps[0])
      self.assertAllClose(after_two_steps.eval()[2], mixed_steps[2])
422
423
# When executed as a script (rather than imported), hand control to the
# TensorFlow test runner.
if __name__ == "__main__":
  test.main()
426