1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements Kalman filtering for linear state space models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.timeseries.python.timeseries import math_utils
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import linalg_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import numerics
30
31
32# TODO(allenl): support for always-factored covariance matrices
class KalmanFilter(object):
  """Inference on linear state models.

  The model for observations in a given state is:
    observation(t) = observation_model * state(t)
        + Gaussian(0, observation_noise_covariance)

  State updates take the following form:
    state(t) = state_transition * state(t-1)
        + state_noise_transform * Gaussian(0, state_transition_noise_covariance)

  This is a real-valued analog to hidden Markov models, with linear transitions
  and a Gaussian noise model. Given initial conditions, noise, and state
  transition, Kalman filtering recursively estimates states and observations,
  along with their associated uncertainty. When fed observations, future state
  and uncertainty estimates are conditioned on those observations (in a Bayesian
  sense).

  Typically some of the "given"s mentioned above (the noise covariances) will be
  unknown, and so optimizing the Kalman filter's probabilistic predictions with
  respect to these parameters is a good approach. The state transition and
  observation models are usually known a priori as a modeling decision.

  """

  def __init__(self, dtype=dtypes.float32,
               simplified_posterior_covariance_computation=False):
    """Initialize the Kalman filter.

    Args:
      dtype: The data type to use for floating point tensors. Stored as
        `self.dtype`.
      simplified_posterior_covariance_computation: If True, uses an algebraic
        simplification of the Kalman filtering posterior covariance update,
        which is slightly faster at the cost of numerical stability. The
        simplified update is often stable when using double precision on small
        models or with fixed transition matrices. If False (the default), a
        Joseph form update is used instead.
    """
    self._simplified_posterior_covariance_computation = (
        simplified_posterior_covariance_computation)
    self.dtype = dtype

  def do_filter(
      self, estimated_state, estimated_state_covariance,
      predicted_observation, predicted_observation_covariance,
      observation, observation_model, observation_noise):
    """Convenience function for scoring predictions.

    Scores a prediction against an observation, and computes the updated
    posterior over states.

    Shapes given below for arguments are for single-model Kalman filtering
    (e.g. KalmanFilter). For ensembles, `estimated_state` and
    `estimated_state_covariance` are same-length tuples of values corresponding
    to each model.

    Args:
      estimated_state: A prior mean over states [batch size x state dimension]
      estimated_state_covariance: Covariance of state prior [batch size x D x
          D], with D depending on the Kalman filter implementation (typically
          the state dimension).
      predicted_observation: A prediction for the observed value, such as that
          returned by observed_from_state. A [batch size x num features] Tensor.
      predicted_observation_covariance: A covariance matrix corresponding to
          `predicted_observation`, a [batch size x num features x num features]
          Tensor.
      observation: The observed value corresponding to the predictions
          given [batch size x observation dimension]
      observation_model: The [batch size x observation dimension x model state
          dimension] Tensor indicating how a particular state is mapped to
          (pre-noise) observations for each part of the batch.
      observation_noise: A [batch size x observation dimension x observation
          dimension] Tensor or [observation dimension x observation dimension]
          Tensor with covariance matrices to use for each part of the batch (a
          two-dimensional input will be broadcast).
    Returns:
      posterior_state, posterior_state_var: Posterior mean and covariance,
          updated versions of `estimated_state` and
          `estimated_state_covariance`.
      log_prediction_prob: Log probability of the observations under
          the priors, suitable for optimization (should be maximized).

    When the returned tensors are evaluated, run-time checks
    (`verify_tensor_all_finite` and `Assert`) fail if the symmetrized predicted
    observation covariance is not finite or has a negative diagonal entry.
    """
    # Average with the transpose so the matrix handed to the Cholesky
    # factorization is exactly symmetric.
    symmetrized_observation_covariance = 0.5 * (
        predicted_observation_covariance + array_ops.matrix_transpose(
            predicted_observation_covariance))
    # Shared suffix for the diagnostic messages of both run-time checks below.
    instability_message = (
        "This may occur due to numerically unstable filtering when there is "
        "a large difference in posterior variances, or when inferences are "
        "near-deterministic. Consider tuning the "
        "'filtering_maximum_posterior_variance_ratio' or "
        "'filtering_minimum_posterior_variance' parameters in your "
        "StateSpaceModelConfiguration, or tuning the transition matrix.")
    symmetrized_observation_covariance = numerics.verify_tensor_all_finite(
        symmetrized_observation_covariance,
        "Predicted observation covariance was not finite. {}".format(
            instability_message))
    # A covariance matrix cannot have a negative diagonal entry; check for it
    # explicitly so the failure carries `instability_message`, and gate the
    # Cholesky decomposition on that check.
    diag = array_ops.matrix_diag_part(symmetrized_observation_covariance)
    min_diag = math_ops.reduce_min(diag)
    non_negative_assert = control_flow_ops.Assert(
        min_diag >= 0.,
        [("The predicted observation covariance "
          "has a negative diagonal entry. {}").format(instability_message),
         min_diag])
    with ops.control_dependencies([non_negative_assert]):
      observation_covariance_cholesky = linalg_ops.cholesky(
          symmetrized_observation_covariance)
    log_prediction_prob = math_utils.mvn_tril_log_prob(
        loc=predicted_observation,
        scale_tril=observation_covariance_cholesky,
        x=observation)
    # Note: the posterior update receives the original (un-symmetrized)
    # `predicted_observation_covariance`, not the symmetrized copy above.
    (posterior_state,
     posterior_state_var) = self.posterior_from_prior_state(
         prior_state=estimated_state,
         prior_state_var=estimated_state_covariance,
         observation=observation,
         observation_model=observation_model,
         predicted_observations=(predicted_observation,
                                 predicted_observation_covariance),
         observation_noise=observation_noise)
    return (posterior_state, posterior_state_var, log_prediction_prob)

  def predict_state_mean(self, prior_state, transition_matrices):
    """Compute state transitions.

    Args:
      prior_state: Current estimated state mean [batch_size x state_dimension]
      transition_matrices: A [batch size, state dimension, state dimension]
        batch of matrices (dtype matching the `dtype` argument to the
        constructor) with the transition matrix raised to the power of the
        number of steps to be taken (not element-wise; use
        math_utils.matrix_to_powers if there is no efficient special case) if
        more than one step is desired.
    Returns:
      State mean advanced based on `transition_matrices` (dimensions matching
      first argument).
    """
    # Treat each state as a column vector (trailing unit dimension) so the
    # batched matmul computes transition * state, then drop that dimension.
    advanced_state = array_ops.squeeze(
        math_ops.matmul(
            transition_matrices,
            prior_state[..., None]),
        axis=[-1])
    return advanced_state

  def predict_state_var(
      self, prior_state_var, transition_matrices, transition_noise_sums):
    r"""Compute variance for state transitions.

    Computes a noise estimate corresponding to the value returned by
    predict_state_mean, i.e.
    transition * prior_state_var * transition^T + transition_noise_sums.

    Args:
      prior_state_var: Covariance matrix specifying uncertainty of current state
          estimate [batch size x state dimension x state dimension]
      transition_matrices: A [batch size, state dimension, state dimension]
        batch of matrices (dtype matching the `dtype` argument to the
        constructor) with the transition matrix raised to the power of the
        number of steps to be taken (not element-wise; use
        math_utils.matrix_to_powers if there is no efficient special case).
      transition_noise_sums: A [batch size, state dimension, state dimension]
        Tensor (dtype matching the `dtype` argument to the constructor) with:

          \sum_{i=0}^{num_steps - 1} (
             state_transition_to_powers_fn(i)
             * state_transition_noise_covariance
             * state_transition_to_powers_fn(i)^T
          )

        for the number of steps to be taken in each part of the batch (this
        should match `transition_matrices`). Use math_utils.power_sums_tensor
        with `tf.gather` if there is no efficient special case.
    Returns:
      State variance advanced based on `transition_matrices` and
      `transition_noise_sums` (dimensions matching first argument).
    """
    prior_variance_transitioned = math_ops.matmul(
        math_ops.matmul(transition_matrices, prior_state_var),
        transition_matrices,
        adjoint_b=True)
    return prior_variance_transitioned + transition_noise_sums

  def posterior_from_prior_state(self, prior_state, prior_state_var,
                                 observation, observation_model,
                                 predicted_observations,
                                 observation_noise):
    """Compute a posterior over states given an observation.

    Writing H for `observation_model`, P for `prior_state_var` and S for the
    predicted observation covariance, the gain used here is
    K = P * H^T * S^-1. The code works with its transpose K^T
    (`kalman_gain_transposed`), obtained from a linear solve instead of an
    explicit inverse.

    Args:
      prior_state: Prior state mean [batch size x state dimension]
      prior_state_var: Prior state covariance [batch size x state dimension x
          state dimension]
      observation: The observed value corresponding to the predictions given
          [batch size x observation dimension]
      observation_model: The [batch size x observation dimension x model state
          dimension] Tensor indicating how a particular state is mapped to
          (pre-noise) observations for each part of the batch.
      predicted_observations: An (observation mean, observation variance) tuple
          computed based on the current state, usually the output of
          observed_from_state.
      observation_noise: A [batch size x observation dimension x observation
          dimension] or [observation dimension x observation dimension] Tensor
          with covariance matrices to use for each part of the batch (a
          two-dimensional input will be broadcast). Only used by the Joseph
          form update (i.e. when the simplified computation is disabled).
    Returns:
      Posterior mean and covariance (dimensions matching the first two
      arguments).

    """
    observed_mean, observed_var = predicted_observations
    residual = observation - observed_mean
    # TODO(allenl): Can more of this be done using matrix_solve_ls?
    # H * P^T (equal to H * P when P is symmetric).
    kalman_solve_rhs = math_ops.matmul(
        observation_model, prior_state_var, adjoint_b=True)
    # Solves S^T * X = H * P^T, so X = K^T.
    # This matrix_solve adjoint doesn't make a difference symbolically (since
    # observed_var is a covariance matrix, and should be symmetric), but
    # filtering on multivariate series is unstable without it. See
    # test_multivariate_symmetric_covariance_float64 in kalman_filter_test.py
    # for an example of the instability (fails with adjoint=False).
    kalman_gain_transposed = linalg_ops.matrix_solve(
        matrix=observed_var, rhs=kalman_solve_rhs, adjoint=True)
    # posterior mean = prior mean + K * residual
    posterior_state = prior_state + array_ops.squeeze(
        math_ops.matmul(
            kalman_gain_transposed,
            array_ops.expand_dims(residual, -1),
            adjoint_a=True),
        axis=[-1])
    # K * H, one [state dimension x state dimension] matrix per batch element.
    gain_obs = math_ops.matmul(
        kalman_gain_transposed, observation_model, adjoint_a=True)
    # Identity with a leading batch dimension of 1, broadcast against gain_obs.
    identity_extradim = linalg_ops.eye(
        array_ops.shape(gain_obs)[1], dtype=gain_obs.dtype)[None]
    identity_minus_factor = identity_extradim - gain_obs
    if self._simplified_posterior_covariance_computation:
      # posterior covariance =
      #   (I - kalman_gain * observation_model) * prior_state_var
      posterior_state_var = math_ops.matmul(identity_minus_factor,
                                            prior_state_var)
    else:
      observation_noise = ops.convert_to_tensor(observation_noise)
      # A Joseph form update, which provides better numeric stability than the
      # simplified optimal Kalman gain update, at the cost of a few extra
      # operations. Joseph form updates are valid for any gain (not just the
      # optimal Kalman gain), and so are more forgiving of numerical errors in
      # computing the optimal Kalman gain.
      #
      # posterior covariance =
      #   (I - kalman_gain * observation_model) * prior_state_var
      #     * (I - kalman_gain * observation_model)^T
      #   + kalman_gain * observation_noise * kalman_gain^T
      left_multiplied_state_var = math_ops.matmul(identity_minus_factor,
                                                  prior_state_var)
      multiplied_state_var = math_ops.matmul(
          identity_minus_factor, left_multiplied_state_var, adjoint_b=True)
      def _batch_observation_noise_update():
        # observation_noise has its own batch dimension.
        return (multiplied_state_var + math_ops.matmul(
            math_ops.matmul(
                kalman_gain_transposed, observation_noise, adjoint_a=True),
            kalman_gain_transposed))
      def _matrix_observation_noise_update():
        # observation_noise is a single matrix shared by every batch element.
        return (multiplied_state_var + math_ops.matmul(
            math_utils.batch_times_matrix(
                kalman_gain_transposed, observation_noise, adj_x=True),
            kalman_gain_transposed))
      if observation_noise.get_shape().ndims is None:
        posterior_state_var = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(observation_noise), 2),
            _matrix_observation_noise_update, _batch_observation_noise_update)
      else:
        # If static shape information exists, it gets checked in each cond()
        # branch, so we need a special case to avoid graph-build-time
        # exceptions.
        if observation_noise.get_shape().ndims == 2:
          posterior_state_var = _matrix_observation_noise_update()
        else:
          posterior_state_var = _batch_observation_noise_update()
    return posterior_state, posterior_state_var

  def observed_from_state(self, state_mean, state_var, observation_model,
                          observation_noise):
    """Compute an observation distribution given a state distribution.

    Computes observation_model * state_mean and
    observation_model * state_var * observation_model^T + observation_noise.

    Args:
      state_mean: State mean vector [batch size x state dimension]
      state_var: State covariance [batch size x state dimension x state
          dimension]
      observation_model: The [batch size x observation dimension x model state
          dimension] Tensor indicating how a particular state is mapped to
          (pre-noise) observations for each part of the batch.
      observation_noise: A [batch size x observation dimension x observation
          dimension] Tensor with covariance matrices to use for each part of the
          batch. Because it is added with broadcasting, an [observation
          dimension x observation dimension] Tensor is applied to every part of
          the batch. To remove observation noise, pass a Tensor of zeros (or
          simply 0, which will broadcast).
    Returns:
      observed_mean: Observation mean vector [batch size x observation
          dimension]
      observed_var: Observation covariance [batch size x observation dimension x
          observation dimension]

    """
    # Row-vector form: state_mean^T * observation_model^T, then drop the
    # inserted unit dimension.
    observed_mean = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(state_mean, 1),
            observation_model,
            adjoint_b=True),
        axis=[1])
    observed_var = math_ops.matmul(
        math_ops.matmul(observation_model, state_var),
        observation_model,
        adjoint_b=True)
    observed_var += observation_noise
    return observed_mean, observed_var
341