1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Abstract base for state space models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23
24import numpy
25
26from tensorflow.contrib.layers.python.layers import layers
27
28from tensorflow.contrib.timeseries.python.timeseries import math_utils
29from tensorflow.contrib.timeseries.python.timeseries import model
30from tensorflow.contrib.timeseries.python.timeseries import model_utils
31from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
32from tensorflow.contrib.timeseries.python.timeseries.state_space_models import kalman_filter
33
34from tensorflow.python.estimator import estimator_lib
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import gen_math_ops
42from tensorflow.python.ops import linalg_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import variable_scope
45
46
class StateSpaceModelConfiguration(
    collections.namedtuple(
        "StateSpaceModelConfiguration",
        [
            "num_features",
            "use_observation_noise",
            "dtype",
            "covariance_prior_fn",
            "bayesian_prior_weighting",
            "filtering_postprocessor",
            "trainable_start_state",
            "exogenous_noise_increases",
            "exogenous_noise_decreases",
            "exogenous_feature_columns",
            "exogenous_update_condition",
            "filtering_maximum_posterior_variance_ratio",
            "filtering_minimum_posterior_variance",
            "transition_covariance_initial_log_scale_bias",
            "static_unrolling_window_size_threshold",
        ])):
  """Immutable bundle of options used to build StateSpaceModels."""

  def __new__(
      cls,
      num_features=1,
      use_observation_noise=True,
      dtype=dtypes.float32,
      covariance_prior_fn=math_utils.log_noninformative_covariance_prior,
      bayesian_prior_weighting=True,
      filtering_postprocessor=None,
      trainable_start_state=False,
      exogenous_noise_increases=True,
      exogenous_noise_decreases=False,
      exogenous_feature_columns=None,
      exogenous_update_condition=None,
      filtering_maximum_posterior_variance_ratio=1e6,
      filtering_minimum_posterior_variance=0.,
      transition_covariance_initial_log_scale_bias=-5.,
      static_unrolling_window_size_threshold=None):
    """Create a configuration, applying defaults for omitted options.

    Args:
      num_features: Output dimension of the model.
      use_observation_noise: If True, observations are noisy functions of the
          current state; if False, they are a deterministic function of it.
          Only applies to the top-level model of an ensemble. When disabling
          observation noise, consider also changing
          transition_covariance_initial_log_scale_bias, whose default assumes
          observation noise is part of the model.
      dtype: Floating point dtype used when defining the model.
      covariance_prior_fn: Maps a covariance matrix to a scalar (e.g. a log
          likelihood) that can be summed across matrices. The default is an
          independent Jeffreys prior on the diagonal elements (regularizing as
          log(1. / variance)). This is a relatively uninformative prior on
          transition and observation noise that encourages low-noise solutions
          with confident predictions where possible. Without regularization,
          transition noise tends to stay high and multi-step predictions are
          under-confident. Use `lambda _: 0.` for a flat prior (no
          regularization).
      bayesian_prior_weighting: If True, the prior (covariance_prior_fn) is
          weighted using an estimate of the full dataset size. If False, it is
          weighted using the mini-batch window size. That is statistically
          improper, but can give more desirable low-noise solutions when the
          full dataset is large enough to overwhelm the prior.
      filtering_postprocessor: Optional FilteringStepPostprocessor, useful
          e.g. for ignoring anomalies in training data.
      trainable_start_state: Whether the start state may depend on trainable
          Variables.
      exogenous_noise_increases: If True, exogenous regressors can add to
          model state, increasing uncertainty. If this and
          exogenous_noise_decreases are both False, exogenous regressors are
          ignored.
      exogenous_noise_decreases: If True, exogenous regressors can "set" model
          state, decreasing uncertainty. If this and exogenous_noise_increases
          are both False, exogenous regressors are ignored.
      exogenous_feature_columns: List of `tf.feature_column`s (for example
          `tf.feature_column.embedding_column`) for exogenous features that
          give the model extra information but are not part of the series
          being predicted. Passed to `tf.feature_column.input_layer`. None
          (the default) is replaced by an empty list.
      exogenous_update_condition: Optional function taking `times` (shape
          [batch size]) and `features` (a dictionary mapping exogenous feature
          keys to Tensors with shapes [batch size, ...]). It returns a boolean
          Tensor of shape [batch size] saying where state should be updated
          from exogenous features; where it is False no exogenous update is
          done. None (the default) means updates always happen. Useful for
          avoiding "leaky" frequent exogenous updates when sparse updates are
          desired. Only called during graph construction.
      filtering_maximum_posterior_variance_ratio: Largest allowed ratio
          between two diagonal entries of a state covariance matrix just
          before filtering. Lower values make filtering more numerically
          stable, at the cost of sometimes inflating estimated uncertainty.
          Can be important when learning a transition matrix.
      filtering_minimum_posterior_variance: Smallest allowed diagonal value of
          a state covariance matrix just before filtering. It prevents
          numerical instability from deterministic beliefs (sometimes an issue
          when learning transition matrices). Set it several orders of
          magnitude below any expected minimum state uncertainty.
      transition_covariance_initial_log_scale_bias: Initial log-scale tradeoff
          between transition and observation noise. Elements of the transition
          noise covariance are proportional to
          `e^{X + transition_covariance_initial_log_scale_bias}`, where `X` is
          learned and may depend on input statistics. Observation noise
          covariance is proportional to
          `e^{Y - transition_covariance_initial_log_scale_bias}`. -5 is
          reasonable for models *with* observation noise. Models without
          observation noise, and not in an ensemble that has it, should use 0
          or more to avoid numerical issues from filtering with too little
          noise.
      static_unrolling_window_size_threshold: Only relevant for the top-level
          StateSpaceModel in an ensemble. If not None, windows of this size
          and smaller are unrolled statically and larger ones loop
          dynamically. None (the default) disables static unrolling. See the
          SequentialTimeSeriesModel constructor for details.
    Returns:
      A StateSpaceModelConfiguration object.
    """
    # Build a fresh list per instance rather than sharing a mutable default.
    feature_columns = (
        [] if exogenous_feature_columns is None else exogenous_feature_columns)
    return super(StateSpaceModelConfiguration, cls).__new__(
        cls,
        num_features=num_features,
        use_observation_noise=use_observation_noise,
        dtype=dtype,
        covariance_prior_fn=covariance_prior_fn,
        bayesian_prior_weighting=bayesian_prior_weighting,
        filtering_postprocessor=filtering_postprocessor,
        trainable_start_state=trainable_start_state,
        exogenous_noise_increases=exogenous_noise_increases,
        exogenous_noise_decreases=exogenous_noise_decreases,
        exogenous_feature_columns=feature_columns,
        exogenous_update_condition=exogenous_update_condition,
        filtering_maximum_posterior_variance_ratio=(
            filtering_maximum_posterior_variance_ratio),
        filtering_minimum_posterior_variance=(
            filtering_minimum_posterior_variance),
        transition_covariance_initial_log_scale_bias=(
            transition_covariance_initial_log_scale_bias),
        static_unrolling_window_size_threshold=(
            static_unrolling_window_size_threshold))
174
175
176class StateSpaceModel(model.SequentialTimeSeriesModel):
177  """Base class for linear state space models.
178
179  Sub-classes can specify the model to be learned by overriding
180  get_state_transition, get_noise_transform, and get_observation_model.
181
182  See kalman_filter.py for a detailed description of the class of models covered
183  by StateSpaceModel.
184
185  Briefly, state space models are defined by a state transition equation:
186
187  state[t] = StateTransition * state[t-1] + NoiseTransform * StateNoise[t]
188             + ExogenousNoiseIncreasing[t]
189  StateNoise[t] ~ Gaussian(0, StateNoiseCovariance)
190  ExogenousNoiseIncreasing[t] ~ Gaussian(ExogenousNoiseIncreasingMean[t],
191                                         ExogenousNoiseIncreasingCovariance[t])
192
193  And an observation model:
194
195  observation[t] = ObservationModel * state[t] + ObservationNoise[t]
196  ObservationNoise[t] ~ Gaussian(0, ObservationNoiseCovariance)
197
198  Additionally, exogenous regressors can act as observations, decreasing
199  uncertainty:
200
201  ExogenousNoiseDecreasingObservation[t] ~ Gaussian(
202      ExogenousNoiseDecreasingMean[t], ExogenousNoiseDecreasingCovariance[t])
203
204  Attributes:
205    kalman_filter: If initialize_graph has been called, the initialized
206        KalmanFilter to use for inference. None otherwise.
207    prior_state_mean: If initialize_graph has been called, a
208        Variable-parameterized Tensor with shape [state dimension];
209        the initial prior mean for one or more time series. None otherwise.
210    prior_state_var: If initialize_graph has been called, a
211        Variable-parameterized Tensor with shape [state dimension x state
212        dimension]; the initial prior covariance. None otherwise.
213    state_transition_noise_covariance: If initialize_graph has been called, a
214        Variable-parameterized Tensor with shape [state noise dimension x state
215        noise dimension] indicating the amount of noise added at each
216        transition.
217  """
218
219  def __init__(self, configuration):
220    """Initialize a state space model.
221
222    Args:
223      configuration: A StateSpaceModelConfiguration object.
224    """
225    self._configuration = configuration
226    if configuration.filtering_postprocessor is not None:
227      filtering_postprocessor_names = (
228          configuration.filtering_postprocessor.output_names)
229    else:
230      filtering_postprocessor_names = []
231    super(StateSpaceModel, self).__init__(
232        train_output_names=(["mean", "covariance", "log_likelihood"]
233                            + filtering_postprocessor_names),
234        predict_output_names=["mean", "covariance"],
235        num_features=configuration.num_features,
236        normalize_features=True,
237        dtype=configuration.dtype,
238        exogenous_feature_columns=configuration.exogenous_feature_columns,
239        exogenous_update_condition=configuration.exogenous_update_condition,
240        static_unrolling_window_size_threshold=
241        configuration.static_unrolling_window_size_threshold)
242    self._kalman_filter = None
243    self.prior_state_mean = None
244    self.prior_state_var = None
245    self.state_transition_noise_covariance = None
246    self._total_observation_count = None
247    self._observation_noise_covariance = None
248    # Capture the current variable scope and use it to define all model
249    # variables. Especially useful for ensembles, where variables may be defined
250    # for every component model in one function call, which would otherwise
251    # prevent the user from separating variables from different models into
252    # different scopes.
253    self._variable_scope = variable_scope.get_variable_scope()
254
255  def transition_power_noise_accumulator(self, num_steps):
256    r"""Sum a transitioned covariance matrix over a number of steps.
257
258    Computes
259
260      \sum_{i=0}^{num_steps - 1} (
261        state_transition^i
262        * state_transition_noise_covariance
263        * (state_transition^i)^T)
264
265    If special cases are available, overriding this function can lead to more
266    efficient inferences.
267
268    Args:
269      num_steps: A [...] shape integer Tensor with numbers of steps to compute
270        power sums for.
271    Returns:
272      The computed power sum, with shape [..., state dimension, state
273      dimension].
274    """
275    # TODO(allenl): This general case should use cumsum if transition_to_powers
276    # can be computed in constant time (important for correlated ensembles,
277    # where transition_power_noise_accumulator special cases cannot be
278    # aggregated from member models).
279    noise_transform = ops.convert_to_tensor(self.get_noise_transform(),
280                                            self.dtype)
281    noise_transformed = math_ops.matmul(
282        math_ops.matmul(noise_transform,
283                        self.state_transition_noise_covariance),
284        noise_transform,
285        transpose_b=True)
286    noise_additions = math_utils.power_sums_tensor(
287        math_ops.reduce_max(num_steps) + 1,
288        ops.convert_to_tensor(self.get_state_transition(), dtype=self.dtype),
289        noise_transformed)
290    return array_ops.gather(noise_additions, indices=num_steps)
291
292  def transition_to_powers(self, powers):
293    """Raise the transition matrix to a batch of powers.
294
295    Computes state_transition^powers. If special cases are available, overriding
296    this function can lead to more efficient inferences.
297
298    Args:
299      powers: A [...] shape integer Tensor with powers to raise the transition
300        matrix to.
301    Returns:
302      The computed matrix powers, with shape [..., state dimension, state
303      dimension].
304    """
305    return math_utils.matrix_to_powers(
306        ops.convert_to_tensor(self.get_state_transition(), dtype=self.dtype),
307        powers)
308
309  def _window_initializer(self, times, state):
310    """Prepare to impute across the gaps in a window."""
311    _, _, priors_from_time = state
312    times = ops.convert_to_tensor(times)
313    priors_from_time = ops.convert_to_tensor(priors_from_time)
314    intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1])
315    # Ignore negative starting gaps, since there will be transient start times
316    # as inputs statistics are computed.
317    starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0)
318    # Pre-define transition matrices raised to powers (and their sums) for every
319    # gap in this window. This avoids duplicate computation (for example many
320    # steps will use the transition matrix raised to the first power) and
321    # batches the computation rather than doing it inside the per-step loop.
322    unique_gaps, _ = array_ops.unique(
323        array_ops.concat([intra_batch_gaps, starting_gaps], axis=0))
324    self._window_power_sums = self.transition_power_noise_accumulator(
325        unique_gaps)
326    self._window_transition_powers = self.transition_to_powers(unique_gaps)
327    self._window_gap_sizes = unique_gaps
328
329  def _lookup_window_caches(self, caches, indices):
330    _, window_power_ids = array_ops.unique(
331        array_ops.concat(
332            [
333                self._window_gap_sizes, math_ops.cast(
334                    indices, self._window_gap_sizes.dtype)
335            ],
336            axis=0))
337    all_gathered_indices = []
338    for cache in caches:
339      gathered_indices = array_ops.gather(
340          cache, window_power_ids[-array_ops.shape(indices)[0]:])
341      gathered_indices.set_shape(indices.get_shape().concatenate(
342          gathered_indices.get_shape()[-2:]))
343      all_gathered_indices.append(gathered_indices)
344    return all_gathered_indices
345
346  def _cached_transition_powers_and_sums(self, num_steps):
347    return self._lookup_window_caches(
348        caches=[self._window_transition_powers, self._window_power_sums],
349        indices=num_steps)
350
351  def _imputation_step(self, current_times, state):
352    """Add state transition noise to catch `state` up to `current_times`.
353
354    State space models are inherently sequential, so we need to "predict
355    through" any missing time steps to catch up each element of the batch to its
356    next observation/prediction time.
357
358    Args:
359      current_times: A [batch size] Tensor of times to impute up to, not
360          inclusive.
361      state: A tuple of (mean, covariance, previous_times) having shapes
362          mean; [batch size x state dimension]
363          covariance; [batch size x state dimension x state dimension]
364          previous_times; [batch size]
365    Returns:
366      Imputed model state corresponding to the `state` argument.
367    """
368    estimated_state, estimated_state_var, previous_times = state
369    # Ignore negative imputation intervals due to transient start time
370    # estimates.
371    catchup_times = math_ops.maximum(current_times - previous_times, 0)
372    transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
373        self._cached_transition_powers_and_sums(catchup_times))
374    estimated_state = self._kalman_filter.predict_state_mean(
375        estimated_state, transition_matrices)
376    estimated_state_var = self._kalman_filter.predict_state_var(
377        estimated_state_var, transition_matrices, transition_noise_sums)
378    return (estimated_state, estimated_state_var,
379            previous_times + catchup_times)
380
381  def _filtering_step(self, current_times, current_values, state, predictions):
382    """Compute posteriors and accumulate one-step-ahead predictions.
383
384    Args:
385      current_times: A [batch size] Tensor for times for each observation.
386      current_values: A [batch size] Tensor of values for each observation.
387      state: A tuple of (mean, covariance, previous_times) having shapes
388          mean; [batch size x state dimension]
389          covariance; [batch size x state dimension x state dimension]
390          previous_times; [batch size]
391      predictions: A dictionary containing mean and covariance Tensors, the
392          output of _prediction_step.
393    Returns:
394      A tuple of (posteriors, outputs):
395        posteriors: Model state updated to take `current_values` into account.
396        outputs: The `predictions` dictionary updated to include "loss" and
397            "log_likelihood" entries (loss simply being negative log
398            likelihood).
399    """
400    estimated_state, estimated_state_covariance, previous_times = state
401    observation_model = self.get_broadcasted_observation_model(current_times)
402    imputed_to_current_step_assert = control_flow_ops.Assert(
403        math_ops.reduce_all(math_ops.equal(current_times, previous_times)),
404        ["Attempted to perform filtering without imputation/prediction"])
405    with ops.control_dependencies([imputed_to_current_step_assert]):
406      estimated_state_covariance = math_utils.clip_covariance(
407          estimated_state_covariance,
408          self._configuration.filtering_maximum_posterior_variance_ratio,
409          self._configuration.filtering_minimum_posterior_variance)
410      (filtered_state, filtered_state_covariance,
411       log_prob) = self._kalman_filter.do_filter(
412           estimated_state=estimated_state,
413           estimated_state_covariance=estimated_state_covariance,
414           predicted_observation=predictions["mean"],
415           predicted_observation_covariance=predictions["covariance"],
416           observation=current_values,
417           observation_model=observation_model,
418           observation_noise=self._observation_noise_covariance)
419    filtered_state = (filtered_state, filtered_state_covariance, current_times)
420    log_prob.set_shape(current_times.get_shape())
421    predictions["loss"] = -log_prob
422    predictions["log_likelihood"] = log_prob
423    if self._configuration.filtering_postprocessor is not None:
424      return self._configuration.filtering_postprocessor.process_filtering_step(
425          current_times=current_times,
426          current_values=current_values,
427          predicted_state=state,
428          filtered_state=filtered_state,
429          outputs=predictions)
430    return (filtered_state, predictions)
431
432  def _scale_back_predictions(self, predictions):
433    """Return a window of predictions to input scale."""
434    predictions["mean"] = self._scale_back_data(predictions["mean"])
435    predictions["covariance"] = self._scale_back_variance(
436        predictions["covariance"])
437    return predictions
438
439  def _prediction_step(self, current_times, state):
440    """Make a prediction based on `state`.
441
442    Computes predictions based on the current `state`, checking that it has
443    already been updated (in `_imputation_step`) to `current_times`.
444
445    Args:
446      current_times: A [batch size] Tensor for times to make predictions for.
447      state: A tuple of (mean, covariance, previous_times) having shapes
448          mean; [batch size x state dimension]
449          covariance; [batch size x state dimension x state dimension]
450          previous_times; [batch size]
451    Returns:
452      A tuple of (updated state, predictions):
453        updated state: Model state with added transition noise.
454        predictions: A dictionary with "mean" and "covariance", having shapes
455            "mean": [batch size x num features]
456            "covariance: [batch size x num features x num features]
457    """
458    estimated_state, estimated_state_var, previous_times = state
459    advanced_to_current_assert = control_flow_ops.Assert(
460        math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)),
461        ["Attempted to predict without imputation"])
462    with ops.control_dependencies([advanced_to_current_assert]):
463      observation_model = self.get_broadcasted_observation_model(current_times)
464      predicted_obs, predicted_obs_var = (
465          self._kalman_filter.observed_from_state(
466              state_mean=estimated_state,
467              state_var=estimated_state_var,
468              observation_model=observation_model,
469              observation_noise=self._observation_noise_covariance))
470      predicted_obs_var.set_shape(
471          ops.convert_to_tensor(current_times).get_shape()
472          .concatenate([self.num_features, self.num_features]))
473    predicted_obs.set_shape(current_times.get_shape().concatenate(
474        (self.num_features,)))
475    predicted_obs_var.set_shape(current_times.get_shape().concatenate(
476        (self.num_features, self.num_features)))
477    # Not scaled back to input-scale, since this also feeds into the
478    # loss. Instead, predictions are scaled back before being returned to the
479    # user in _scale_back_predictions.
480    predictions = {
481        "mean": predicted_obs,
482        "covariance": predicted_obs_var}
483    state = (estimated_state, estimated_state_var, current_times)
484    return (state, predictions)
485
486  def _exogenous_noise_decreasing(self, current_times, exogenous_values, state):
487    """Update state with exogenous regressors, decreasing uncertainty.
488
489    Constructs a mean and covariance based on transformations of
490    `exogenous_values`, then performs Bayesian inference on the constructed
491    observation. This has the effect of lowering uncertainty.
492
493    This update refines or overrides previous inferences, useful for modeling
494    exogenous inputs which "set" state, e.g. we dumped boiling water on the
495    thermometer so we're pretty sure it's 100 degrees C.
496
497    Args:
498      current_times: A [batch size] Tensor of times for the exogenous values
499          being input.
500      exogenous_values: A [batch size x exogenous input dimension] Tensor of
501          exogenous values for each part of the batch.
502      state: A tuple of (mean, covariance, previous_times) having shapes
503          mean; [batch size x state dimension]
504          covariance; [batch size x state dimension x state dimension]
505          previous_times; [batch size]
506    Returns:
507      Updated state taking the exogenous regressors into account (with lower
508      uncertainty than the input state).
509
510    """
511    estimated_state, estimated_state_covariance, previous_times = state
512    state_transition = ops.convert_to_tensor(
513        self.get_state_transition(), dtype=self.dtype)
514    state_dimension = tensor_shape.dimension_value(state_transition.shape[0])
515    # Learning the observation model would be redundant since we transform
516    # `exogenous_values` to the state space via a linear transformation anyway.
517    observation_model = linalg_ops.eye(
518        state_dimension,
519        batch_shape=array_ops.shape(exogenous_values)[:-1],
520        dtype=self.dtype)
521    with variable_scope.variable_scope("exogenous_noise_decreasing_covariance"):
522      observation_noise = math_utils.transform_to_covariance_matrices(
523          exogenous_values, state_dimension)
524    with variable_scope.variable_scope(
525        "exogenous_noise_decreasing_observation"):
526      observation = layers.fully_connected(
527          exogenous_values, state_dimension, activation_fn=None)
528    # Pretend that we are making an observation with an observation model equal
529    # to the identity matrix (i.e. a direct observation of the latent state),
530    # with learned observation noise.
531    posterior_state, posterior_state_var = (
532        self._kalman_filter.posterior_from_prior_state(
533            prior_state=estimated_state,
534            prior_state_var=estimated_state_covariance,
535            observation=observation,
536            observation_model=observation_model,
537            predicted_observations=(
538                estimated_state,
539                # The predicted noise covariance is noise due to current state
540                # uncertainty plus noise learned based on the exogenous
541                # observation (a somewhat trivial call to
542                # self._kalman_filter.observed_from_state has been omitted).
543                observation_noise + estimated_state_covariance),
544            observation_noise=observation_noise))
545    return (posterior_state, posterior_state_var, previous_times)
546
547  def _exogenous_noise_increasing(self, current_times, exogenous_values, state):
548    """Update state with exogenous regressors, increasing uncertainty.
549
550    Adds to the state mean a linear transformation of `exogenous_values`, and
551    increases uncertainty by constructing a covariance matrix based on
552    `exogenous_values` and adding it to the state covariance.
553
554    This update is useful for modeling changes relative to current state,
555    e.g. the furnace turned on so the temperature will be increasing at an
556    additional 1 degree per minute with some uncertainty, this uncertainty being
557    added to our current uncertainty in the per-minute change in temperature.
558
559    Args:
560      current_times: A [batch size] Tensor of times for the exogenous values
561          being input.
562      exogenous_values: A [batch size x exogenous input dimension] Tensor of
563          exogenous values for each part of the batch.
564      state: A tuple of (mean, covariance, previous_times) having shapes
565          mean; [batch size x state dimension]
566          covariance; [batch size x state dimension x state dimension]
567          previous_times; [batch size]
568    Returns:
569      Updated state taking the exogenous regressors into account (with higher
570      uncertainty than the input state).
571
572    """
573    start_mean, start_covariance, previous_times = state
574    with variable_scope.variable_scope("exogenous_noise_increasing_mean"):
575      mean_addition = layers.fully_connected(
576          exogenous_values,
577          tensor_shape.dimension_value(start_mean.shape[1]), activation_fn=None)
578    state_dimension = tensor_shape.dimension_value(start_covariance.shape[1])
579    with variable_scope.variable_scope("exogenous_noise_increasing_covariance"):
580      covariance_addition = (
581          math_utils.transform_to_covariance_matrices(
582              exogenous_values, state_dimension))
583    return (start_mean + mean_addition,
584            start_covariance + covariance_addition,
585            previous_times)
586
587  def _exogenous_input_step(
588      self, current_times, current_exogenous_regressors, state):
589    """Update state with exogenous regressors.
590
591    Allows both increases and decreases in uncertainty.
592
593    Args:
594      current_times: A [batch size] Tensor of times for the exogenous values
595          being input.
596      current_exogenous_regressors: A [batch size x exogenous input dimension]
597          Tensor of exogenous values for each part of the batch.
598      state: A tuple of (mean, covariance, previous_times) having shapes
599          mean; [batch size x state dimension]
600          covariance; [batch size x state dimension x state dimension]
601          previous_times; [batch size]
602    Returns:
603      Updated state taking the exogenous regressors into account.
604    """
605    if self._configuration.exogenous_noise_decreases:
606      state = self._exogenous_noise_decreasing(
607          current_times, current_exogenous_regressors, state)
608    if self._configuration.exogenous_noise_increases:
609      state = self._exogenous_noise_increasing(
610          current_times, current_exogenous_regressors, state)
611    return state
612
613  def _loss_additions(self, times, values, mode):
614    """Add regularization during training."""
615    if mode == estimator_lib.ModeKeys.TRAIN:
616      if (self._input_statistics is not None
617          and self._configuration.bayesian_prior_weighting):
618        normalization = 1. / math_ops.cast(
619            self._input_statistics.total_observation_count, self.dtype)
620      else:
621        # If there is no total observation count recorded, or if we are not
622        # doing a Bayesian prior weighting, assumes/pretends that the full
623        # dataset size is the window size.
624        normalization = 1. / math_ops.cast(
625            array_ops.shape(times)[1], self.dtype)
626      transition_contribution = ops.convert_to_tensor(
627          self._configuration.covariance_prior_fn(
628              self.state_transition_noise_covariance),
629          dtype=self.dtype)
630      if (self._configuration.use_observation_noise
631          and self._observation_noise_covariance is not None):
632        observation_contribution = ops.convert_to_tensor(
633            self._configuration.covariance_prior_fn(
634                self._observation_noise_covariance),
635            dtype=self.dtype)
636        regularization_sum = transition_contribution + observation_contribution
637      else:
638        regularization_sum = transition_contribution
639      return -normalization * regularization_sum
640    else:
641      return array_ops.zeros([], dtype=self.dtype)
642
643  def _variable_observation_transition_tradeoff_log(self):
644    """Define a variable to trade off observation and transition noise."""
645    return variable_scope.get_variable(
646        name="observation_transition_tradeoff_log_scale",
647        initializer=constant_op.constant(
648            -self._configuration.transition_covariance_initial_log_scale_bias,
649            dtype=self.dtype),
650        dtype=self.dtype)
651
652  def _define_parameters(self, observation_transition_tradeoff_log=None):
653    """Define extra model-specific parameters.
654
655    Models should wrap any variables defined here in the model's variable scope.
656
657    Args:
658      observation_transition_tradeoff_log: An ensemble-global parameter
659        controlling the tradeoff between observation noise and transition
660        noise. If its value is not None, component transition noise should scale
661        with e^-observation_transition_tradeoff_log.
662    """
663    with variable_scope.variable_scope(self._variable_scope):
664      # A scalar which allows the optimizer to quickly shift from observation
665      # noise to transition noise (this value is subtracted from log transition
666      # noise and added to log observation noise).
667      if observation_transition_tradeoff_log is None:
668        self._observation_transition_tradeoff_log_scale = (
669            self._variable_observation_transition_tradeoff_log())
670      else:
671        self._observation_transition_tradeoff_log_scale = (
672            observation_transition_tradeoff_log)
673      self.state_transition_noise_covariance = (
674          self.get_state_transition_noise_covariance())
675
676  def _set_input_statistics(self, input_statistics=None):
677    super(StateSpaceModel, self).initialize_graph(
678        input_statistics=input_statistics)
679
680  def initialize_graph(self, input_statistics=None):
681    """Define variables and ops relevant to the top-level model in an ensemble.
682
683    For generic model parameters, _define_parameters() is called recursively on
684    all members of an ensemble.
685
686    Args:
687      input_statistics: A math_utils.InputStatistics object containing input
688          statistics. If None, data-independent defaults are used, which may
689          result in longer or unstable training.
690    """
691    self._set_input_statistics(input_statistics=input_statistics)
692    self._define_parameters()
693    with variable_scope.variable_scope(self._variable_scope):
694      self._observation_noise_covariance = ops.convert_to_tensor(
695          self.get_observation_noise_covariance(), dtype=self.dtype)
696    self._kalman_filter = kalman_filter.KalmanFilter(dtype=self.dtype)
697    (self.prior_state_mean,
698     self.prior_state_var) = self._make_priors()
699
700  def _make_priors(self):
701    """Creates and returns model priors."""
702    prior_state_covariance = self.get_prior_covariance()
703    prior_state_mean = self.get_prior_mean()
704    return (prior_state_mean, prior_state_covariance)
705
706  def get_prior_covariance(self):
707    """Constructs a variable prior covariance with data-based initialization.
708
709    Models should wrap any variables defined here in the model's variable scope.
710
711    Returns:
712      A two-dimensional [state dimension, state dimension] floating point Tensor
713      with a (positive definite) prior state covariance matrix.
714    """
715    with variable_scope.variable_scope(self._variable_scope):
716      state_dimension = ops.convert_to_tensor(
717          self.get_state_transition()).get_shape().dims[0].value
718      if self._configuration.trainable_start_state:
719        base_covariance = math_utils.variable_covariance_matrix(
720            state_dimension, "prior_state_var",
721            dtype=self.dtype)
722      else:
723        return linalg_ops.eye(state_dimension, dtype=self.dtype)
724      if self._input_statistics is not None:
725        # Make sure initial latent value uncertainty is at least on the same
726        # scale as noise in the data.
727        covariance_multiplier = math_ops.reduce_max(
728            self._scale_variance(
729                self._input_statistics.series_start_moments.variance))
730        return base_covariance * gen_math_ops.maximum(
731            covariance_multiplier, 1.0)
732      else:
733        return base_covariance
734
735  def get_prior_mean(self):
736    """Constructs a Variable-parameterized prior mean.
737
738    Models should wrap any variables defined here in the model's variable scope.
739
740    Returns:
741      A one-dimensional floating point Tensor with shape [state dimension]
742      indicating the prior mean.
743    """
744    with variable_scope.variable_scope(self._variable_scope):
745      state_transition = ops.convert_to_tensor(
746          self.get_state_transition(), dtype=self.dtype)
747      state_dimension = state_transition.get_shape().dims[0].value
748      return variable_scope.get_variable(
749          name="prior_state_mean",
750          shape=[state_dimension],
751          dtype=self.dtype,
752          trainable=self._configuration.trainable_start_state)
753
754  # TODO(allenl): It would be nice if the generation were done with TensorFlow
755  # ops, and if the model parameters were somehow set instead of being passed
756  # around in a dictionary. Maybe unconditional generation should be through a
757  # special set of initializers?
758  def random_model_parameters(self, seed=None):
759    if self.num_features != 1:
760      raise NotImplementedError("Generation for multivariate state space models"
761                                " is not currently implemented.")
762    if seed:
763      numpy.random.seed(seed)
764    state_dimension, noise_dimension = ops.convert_to_tensor(
765        self.get_noise_transform()).get_shape().as_list()
766    transition_var = 1.0 / numpy.random.gamma(shape=10., scale=10.,
767                                              size=[noise_dimension])
768    initial_state = numpy.random.normal(size=[state_dimension])
769    params_dict = {}
770    if self.prior_state_mean is not None:
771      params_dict[self.prior_state_mean] = initial_state
772    if self.state_transition_noise_covariance is not None:
773      params_dict[self.state_transition_noise_covariance] = numpy.diag(
774          transition_var)
775    if self.prior_state_var is not None:
776      params_dict[self.prior_state_var] = numpy.zeros(
777          [state_dimension, state_dimension])
778    if self._configuration.use_observation_noise:
779      observation_var = 1.0 / numpy.random.gamma(shape=4, scale=4)
780      params_dict[self._observation_noise_covariance] = [[observation_var]]
781    return params_dict
782
  def generate(self, number_of_series, series_length,
               model_parameters=None, seed=None, add_observation_noise=None):
    """Samples univariate series by simulating this state space model in numpy.

    Parameter Tensors are evaluated with `Tensor.eval(feed_dict=...)`, so a
    default TensorFlow session is assumed to be active.

    Args:
      number_of_series: Number of series to sample (first output dimension).
      series_length: Number of timesteps in each series.
      model_parameters: Optional dictionary used as the `feed_dict` when
          evaluating Tensors. Transition noise covariance, observation noise
          covariance and prior mean values are read through
          `model_utils.parameter_switch(model_parameters)`. Presumably this
          falls back to the trained values for absent keys -- TODO confirm.
      seed: If not None, numpy's global random generator is re-seeded with it.
      add_observation_noise: Whether to add Gaussian observation noise. If
          None, defaults to `configuration.use_observation_noise`.
    Returns:
      A dictionary with:
        TrainEvalFeatures.TIMES: [number_of_series x series_length] integer
            array; each row is 0, 1, ..., series_length - 1.
        TrainEvalFeatures.VALUES: [number_of_series x series_length x 1]
            array of sampled observations.
    Raises:
      NotImplementedError: If the model has more than one feature.
    """
    # Seeding happens before the univariate check below.
    if seed is not None:
      numpy.random.seed(seed)
    if self.num_features != 1:
      raise NotImplementedError("Generation for multivariate state space models"
                                " is not currently implemented.")
    if add_observation_noise is None:
      add_observation_noise = self._configuration.use_observation_noise
    if model_parameters is None:
      model_parameters = {}
    transitions = ops.convert_to_tensor(
        self.get_state_transition(), dtype=self.dtype).eval(
            feed_dict=model_parameters)
    noise_transform = ops.convert_to_tensor(self.get_noise_transform()).eval(
        feed_dict=model_parameters)

    noise_dimension = noise_transform.shape[1]
    get_passed_or_trained_value = model_utils.parameter_switch(model_parameters)
    # Only the diagonal of the transition noise covariance is used, so noise
    # components are sampled independently of each other.
    transition_var = numpy.diag(get_passed_or_trained_value(
        self.state_transition_noise_covariance))
    transition_std = numpy.sqrt(transition_var)
    if add_observation_noise:
      # Univariate model: the observation noise covariance has a single entry.
      observation_var = get_passed_or_trained_value(
          self._observation_noise_covariance)[0][0]
      observation_std = numpy.sqrt(observation_var)
    # Every series starts at the prior mean; the prior covariance is not used
    # when sampling.
    initial_state = get_passed_or_trained_value(self.prior_state_mean)
    current_state = numpy.tile(numpy.expand_dims(initial_state, 0),
                               [number_of_series, 1])
    observations = numpy.zeros([number_of_series, series_length])
    # One [1 x state dimension] observation model per timestep.
    observation_models = self.get_broadcasted_observation_model(
        times=math_ops.range(series_length)).eval(feed_dict=model_parameters)
    for timestep, observation_model in enumerate(observation_models):
      # Transition first (row-vector form: state <- state A^T + noise B^T),
      # then observe, so the first observation follows one transition.
      current_state = numpy.dot(current_state, transitions.T)
      current_state += numpy.dot(
          numpy.random.normal(
              loc=numpy.zeros([number_of_series, noise_dimension]),
              scale=numpy.tile(numpy.expand_dims(transition_std, 0),
                               [number_of_series, 1])),
          noise_transform.T)
      # observation_model[0] is the single feature's [state dimension] row.
      observation_mean = numpy.dot(current_state, observation_model[0].T)
      if add_observation_noise:
        observations[:, timestep] = numpy.random.normal(loc=observation_mean,
                                                        scale=observation_std)
      else:
        observations[:, timestep] = observation_mean
    # Add a trailing feature dimension of size 1.
    observations = numpy.expand_dims(observations, -1)
    times = numpy.tile(
        numpy.expand_dims(numpy.arange(observations.shape[1]), 0),
        [observations.shape[0], 1])
    return {TrainEvalFeatures.TIMES: times,
            TrainEvalFeatures.VALUES: observations}
835
836  @abc.abstractmethod
837  def get_state_transition(self):
838    """Specifies the state transition model to use.
839
840    Returns:
841      A [state dimension x state dimension] Tensor specifying how states
842      transition from one timestep to the next.
843    """
844    pass
845
846  @abc.abstractmethod
847  def get_noise_transform(self):
848    """Specifies the noise transition model to use.
849
850    Returns:
851      A [state dimension x state noise dimension] Tensor specifying how noise
852      (generated with shape [state noise dimension]) affects the model's state.
853    """
854    pass
855
856  @abc.abstractmethod
857  def get_observation_model(self, times):
858    """Specifies the observation model to use.
859
860    Args:
861      times: A [batch dimension] int32 Tensor with times for each part of the
862          batch, on which the observation model can depend.
863    Returns:
864      This function, when overridden, has three possible return values:
865        - A [state dimension] Tensor with a static, univariate observation
866          model.
867        - A [self.num_features x state dimension] static, multivariate model.
868        - A [batch dimension x self.num_features x state dimension] observation
869          model, which may depend on `times`.
870      See get_broadcasted_observation_model for details of the broadcasting.
871    """
872    pass
873
874  def get_broadcasted_observation_model(self, times):
875    """Broadcast this model's observation model if necessary.
876
877    The model can define a univariate observation model which will be broadcast
878    over both self.num_features and the batch dimension of `times`.
879
880    The model can define a multi-variate observation model which does not depend
881    on `times`, and it will be broadcast over the batch dimension of `times`.
882
883    Finally, the model can define a multi-variate observation model with a batch
884    dimension, which will not be broadcast.
885
886    Args:
887      times: A [batch dimension] int32 Tensor with times for each part of the
888          batch, on which the observation model can depend.
889    Returns:
890      A [batch dimension x self.num_features x state dimension] Tensor
891      specifying the observation model to use for each time in `times` and each
892      feature.
893    """
894    unbroadcasted_model = ops.convert_to_tensor(
895        self.get_observation_model(times), dtype=self.dtype)
896    unbroadcasted_shape = (unbroadcasted_model.get_shape()
897                           .with_rank_at_least(1).with_rank_at_most(3))
898    if unbroadcasted_shape.ndims is None:
899      # Pass through fully undefined shapes, but make sure they're rank 3 at
900      # graph eval time
901      assert_op = control_flow_ops.Assert(
902          math_ops.equal(array_ops.rank(unbroadcasted_model), 3),
903          [array_ops.shape(unbroadcasted_model)])
904      with ops.control_dependencies([assert_op]):
905        return array_ops.identity(unbroadcasted_model)
906    if unbroadcasted_shape.ndims == 1:
907      # Unbroadcasted shape [state dimension]
908      broadcasted_model = array_ops.tile(
909          array_ops.reshape(tensor=unbroadcasted_model, shape=[1, 1, -1]),
910          [array_ops.shape(times)[0], self.num_features, 1])
911    elif unbroadcasted_shape.ndims == 2:
912      # Unbroadcasted shape [num features x state dimension]
913      broadcasted_model = array_ops.tile(
914          array_ops.expand_dims(unbroadcasted_model, axis=0),
915          [array_ops.shape(times)[0], 1, 1])
916    elif unbroadcasted_shape.ndims == 3:
917      broadcasted_model = unbroadcasted_model
918    broadcasted_model.get_shape().assert_has_rank(3)
919    return broadcasted_model
920
921  def get_state_transition_noise_covariance(
922      self, minimum_initial_variance=1e-5):
923    state_noise_transform = ops.convert_to_tensor(
924        self.get_noise_transform(), dtype=self.dtype)
925    state_noise_dimension = state_noise_transform.get_shape().dims[1].value
926    if self._input_statistics is not None:
927      feature_variance = self._scale_variance(
928          self._input_statistics.series_start_moments.variance)
929      initial_transition_noise_scale = math_ops.log(
930          gen_math_ops.maximum(
931              math_ops.reduce_mean(feature_variance) / math_ops.cast(
932                  self._input_statistics.total_observation_count, self.dtype),
933              minimum_initial_variance))
934    else:
935      initial_transition_noise_scale = 0.
936    # Generally high transition noise is undesirable; we want to set it quite
937    # low to start so that we don't need too much training to get to good
938    # solutions (i.e. with confident predictions into the future if possible),
939    # but not so low that training can't yield a high transition noise if the
940    # data demands it.
941    initial_transition_noise_scale -= (
942        self._observation_transition_tradeoff_log_scale)
943    return math_utils.variable_covariance_matrix(
944        state_noise_dimension, "state_transition_noise",
945        dtype=self.dtype,
946        initial_overall_scale_log=initial_transition_noise_scale)
947
948  def get_observation_noise_covariance(self, minimum_initial_variance=1e-5):
949    if self._configuration.use_observation_noise:
950      if self._input_statistics is not None:
951        # Get variance across the first few values in each batch for each
952        # feature, for an initial observation noise (over-)estimate.
953        feature_variance = self._scale_variance(
954            self._input_statistics.series_start_moments.variance)
955      else:
956        feature_variance = None
957      if feature_variance is not None:
958        feature_variance = gen_math_ops.maximum(feature_variance,
959                                                minimum_initial_variance)
960      return math_utils.variable_covariance_matrix(
961          size=self.num_features,
962          dtype=self.dtype,
963          name="observation_noise_covariance",
964          initial_diagonal_values=feature_variance,
965          initial_overall_scale_log=(
966              self._observation_transition_tradeoff_log_scale))
967    else:
968      return array_ops.zeros(
969          shape=[self.num_features, self.num_features],
970          name="observation_noise_covariance",
971          dtype=self.dtype)
972
973  def get_start_state(self):
974    """Defines and returns a non-batched prior state and covariance."""
975    # TODO(allenl,vitalyk): Add an option for non-Gaussian priors once extended
976    # Kalman filtering is implemented (ideally any Distribution object).
977    if self._input_statistics is not None:
978      start_time = self._input_statistics.start_time
979    else:
980      start_time = array_ops.zeros([], dtype=dtypes.int64)
981    return (self.prior_state_mean,
982            self.prior_state_var,
983            start_time - 1)
984
985  def get_features_for_timesteps(self, timesteps):
986    """Get features for a batch of timesteps. Default to no features."""
987    return array_ops.zeros([array_ops.shape(timesteps)[0], 0], dtype=self.dtype)
988
989
class StateSpaceEnsemble(StateSpaceModel):
  """Base class for combinations of state space models.

  The ensemble's state concatenates its members' states:
  - prior means are concatenated;
  - state transitions, noise transforms and transition powers are stacked
    block-diagonally;
  - member observation models are concatenated along the state axis, so
    member predictions are summed. Subclasses may override this.
  """

  def __init__(self, ensemble_members, configuration):
    """Initialize the ensemble by specifying its members.

    Args:
      ensemble_members: A list of StateSpaceModel objects which will be included
          in this ensemble.
      configuration: A StateSpaceModelConfiguration object.
    """
    self._ensemble_members = ensemble_members
    super(StateSpaceEnsemble, self).__init__(configuration=configuration)

  def _set_input_statistics(self, input_statistics=None):
    """Records input statistics on the ensemble and on every member.

    Args:
      input_statistics: A math_utils.InputStatistics object, or None. The
          default matches StateSpaceModel._set_input_statistics.
    """
    super(StateSpaceEnsemble, self)._set_input_statistics(input_statistics)
    for member in self._ensemble_members:
      member._set_input_statistics(input_statistics)  # pylint: disable=protected-access

  def _loss_additions(self, times, values, mode):
    """Ensemble-level regularization plus the sum of every member's."""
    # Allow sub-models to regularize
    return (super(StateSpaceEnsemble, self)._loss_additions(
        times, values, mode) + math_ops.add_n([
            member._loss_additions(times, values, mode)  # pylint: disable=protected-access
            for member in self._ensemble_members
        ]))

  def _compute_blocked(self, member_fn, name):
    """Stacks `member_fn(member)` for each member into a block-diagonal matrix.

    Args:
      member_fn: Callable taking a member model and returning its block.
      name: Name for the resulting op.
    Returns:
      The `math_utils.block_diagonal` of the member blocks, built inside the
      ensemble's variable scope.
    """
    with variable_scope.variable_scope(self._variable_scope):
      return math_utils.block_diagonal(
          [member_fn(member)
           for member in self._ensemble_members],
          dtype=self.dtype,
          name=name)

  def transition_to_powers(self, powers):
    """Block-diagonal stack of each member's `transition_to_powers(powers)`."""
    return self._compute_blocked(
        member_fn=lambda member: member.transition_to_powers(powers),
        name="ensemble_transition_to_powers")

  def _define_parameters(self, observation_transition_tradeoff_log=None):
    """Defines member parameters, then the ensemble's own.

    Args:
      observation_transition_tradeoff_log: Optional externally-defined tradeoff
          parameter. If None, one is created here and shared by all members.
    """
    with variable_scope.variable_scope(self._variable_scope):
      if observation_transition_tradeoff_log is None:
        # Define the tradeoff parameter between observation and transition noise
        # once for the whole ensemble, and pass it down to members.
        observation_transition_tradeoff_log = (
            self._variable_observation_transition_tradeoff_log())
      for member in self._ensemble_members:
        member._define_parameters(observation_transition_tradeoff_log=(  # pylint: disable=protected-access
            observation_transition_tradeoff_log))
      super(StateSpaceEnsemble, self)._define_parameters(
          observation_transition_tradeoff_log
          =observation_transition_tradeoff_log)

  def random_model_parameters(self, seed=None):
    """Merges random parameters sampled for each member and for the ensemble.

    Args:
      seed: Optional integer seed. Member i is seeded with `seed + i` and the
          ensemble-level sampling with `seed`. None means no re-seeding. A
          seed of 0 is honored; previously 0 was treated as "no seed".
    Returns:
      A dictionary mapping parameter Tensors to sampled values. Ensemble-level
      entries override member entries with the same key.
    """
    param_union = {}
    for i, member in enumerate(self._ensemble_members):
      member_seed = seed + i if seed is not None else None
      param_union.update(member.random_model_parameters(seed=member_seed))
    param_union.update(
        super(StateSpaceEnsemble, self).random_model_parameters(seed=seed))
    return param_union

  def get_prior_mean(self):
    """Concatenates member prior means along the state dimension."""
    return array_ops.concat(
        values=[member.get_prior_mean() for member in self._ensemble_members],
        axis=0,
        name="ensemble_prior_state_mean")

  def get_state_transition(self):
    """Block-diagonal stack of member state transition matrices."""
    return self._compute_blocked(
        member_fn=
        lambda member: member.get_state_transition(),
        name="ensemble_state_transition")

  def get_noise_transform(self):
    """Block-diagonal stack of member noise transforms."""
    return self._compute_blocked(
        member_fn=
        lambda member: member.get_noise_transform(),
        name="ensemble_noise_transform")

  def get_observation_model(self, times):
    """Not defined for ensembles; use get_broadcasted_observation_model."""
    raise NotImplementedError("No un-broadcasted observation model defined for"
                              " ensembles.")

  def get_broadcasted_observation_model(self, times):
    """Computes a combined observation model based on member models.

    The effect is that predicted observations from each model are summed.

    Args:
      times: A [batch dimension] int32 Tensor with times for each part of the
          batch, on which member observation models can depend.
    Returns:
      A [batch dimension x num features x combined state dimension] Tensor with
      the combined observation model.
    """
    member_observation_models = [
        ops.convert_to_tensor(
            member.get_broadcasted_observation_model(times), dtype=self.dtype)
        for member in self._ensemble_members
    ]
    return array_ops.concat(values=member_observation_models, axis=2)
1094
1095
class StateSpaceIndependentEnsemble(StateSpaceEnsemble):
  """Ensemble of mutually independent state space models.

  Lets several independent state space models be fitted together while their
  specifications stay decoupled. The ensemble is itself a state space model:
  member observation models are concatenated, and transition matrices and
  noise transforms are stacked block-diagonally. The ensemble's state
  dimension is therefore the sum of its members'. This can make training and
  inference slow and memory-hungry as the [state dimension x state dimension]
  posterior grows.

  Each individual model j's state at time t is defined by:

  state[t, j] = StateTransition[j] * state[t-1, j]
      + NoiseTransform[j] * StateNoise[t, j]
  StateNoise[t, j] ~ Gaussian(0, StateNoiseCovariance[j])

  and the ensemble observation model is:

  observation[t] = Sum { ObservationModel[j] * state[t, j] }
      + ObservationNoise[t]
  ObservationNoise[t] ~ Gaussian(0, ObservationNoiseCovariance)
  """

  def transition_power_noise_accumulator(self, num_steps):
    """Block-diagonal stack of members' noise accumulators for `num_steps`."""
    def _member_accumulator(member):
      return member.transition_power_noise_accumulator(num_steps)
    return self._compute_blocked(
        member_fn=_member_accumulator,
        name="ensemble_power_noise_accumulator")

  def get_prior_covariance(self):
    """Block-diagonal stack of member prior covariances."""
    def _member_prior_covariance(member):
      return member.get_prior_covariance()
    return self._compute_blocked(
        member_fn=_member_prior_covariance,
        name="ensemble_prior_state_covariance")

  def get_state_transition_noise_covariance(self):
    """Block-diagonal stack of member transition noise covariances.

    Reads each member's `state_transition_noise_covariance` attribute. That
    attribute is set when the member's `_define_parameters` runs, which the
    ensemble's `_define_parameters` does before its own.
    """
    def _member_noise_covariance(member):
      return member.state_transition_noise_covariance
    return self._compute_blocked(
        member_fn=_member_noise_covariance,
        name="ensemble_state_transition_noise")
1139
1140
1141# TODO(allenl): It would be nice to have replicated feature models which are
1142# identical batched together to reduce the graph size.
1143# TODO(allenl): Support for sharing M independent models across N features, with
1144# N > M.
1145# TODO(allenl): Stack component prior covariances while allowing cross-model
1146# correlations to be learned (currently a full covariance prior is learned, but
1147# custom component model covariances are not used).
class StateSpaceCorrelatedFeaturesEnsemble(StateSpaceEnsemble):
  """A correlated ensemble in which each member models one feature.

  Unlike `StateSpaceIndependentEnsemble`, the members are not assumed to be
  independent: one full state transition noise covariance matrix is learned
  for the whole ensemble. Observation models are stacked diagonally rather
  than concatenated. Concatenation would sum every member's contribution into
  each feature; here member j corresponds to feature j of the series.

  Behaves like (and is) a single state space model where:

  StateTransition = Diag(StateTransition[j] for models j)
  ObservationModel = Diag(ObservationModel[j] for models j)

  Each ObservationModel[j] is a [1 x S_j] matrix (S_j being the state
  dimension of model j), i.e. a univariate model. The combined model is
  multivariate, with one feature per component model.
  """

  def __init__(self, ensemble_members, configuration):
    """Validates and stores the component models.

    Args:
      ensemble_members: A list of `StateSpaceModel` objects, one per feature
        (so its length must equal `configuration.num_features`). Each must be
        univariate.
      configuration: A StateSpaceModelConfiguration object.
    Raises:
      ValueError: If the number of members differs from the number of
        features, or if any member is not univariate.
    """
    if len(ensemble_members) != configuration.num_features:
      raise ValueError(
          "The number of members in a StateSpaceCorrelatedFeaturesEnsemble "
          "must equal the number of features in the time series.")
    if any(member.num_features != 1 for member in ensemble_members):
      raise ValueError(
          "StateSpaceCorrelatedFeaturesEnsemble components must be "
          "univariate.")
    super(StateSpaceCorrelatedFeaturesEnsemble, self).__init__(
        ensemble_members=ensemble_members, configuration=configuration)

  def transition_power_noise_accumulator(self, num_steps):
    """Uses the lone member's special case when the series is univariate."""
    members = self._ensemble_members
    if len(members) != 1:
      # Several features mean several models with correlated noise, so member
      # special cases cannot simply be aggregated; use the general case.
      return super(StateSpaceCorrelatedFeaturesEnsemble,
                   self).transition_power_noise_accumulator(num_steps=num_steps)
    # A single feature means a single model whose own special casing applies.
    return members[0].transition_power_noise_accumulator(num_steps=num_steps)

  def get_broadcasted_observation_model(self, times):
    """Stacks member observation models block-diagonally, one per feature."""
    return self._compute_blocked(
        member_fn=lambda member: ops.convert_to_tensor(
            member.get_broadcasted_observation_model(times), dtype=self.dtype),
        name="feature_ensemble_observation_model")
1214