1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for time series models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23
24import six
25
26from tensorflow.contrib.timeseries.python.timeseries import math_utils
27from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
28from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
29
30from tensorflow.python.feature_column import feature_column_lib as feature_column
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import parsing_ops
38from tensorflow.python.ops import tensor_array_ops
39from tensorflow.python.ops import variable_scope
40
41from tensorflow.python.util import nest
42
43
# Result bundle returned by loss computations. Fields, in order:
#   loss: The scalar value to be minimized during training.
#   end_state: A nested tuple specifying the model's state after running on
#     the specified data.
#   predictions: A dictionary of predictions, each with shape prefixed by the
#     shape of `prediction_times`.
#   prediction_times: A [batch size x window size] integer Tensor indicating
#     times for which values in `predictions` were computed.
ModelOutputs = collections.namedtuple(  # pylint: disable=invalid-name
    "ModelOutputs",
    ["loss", "end_state", "predictions", "prediction_times"])
56
57
@six.add_metaclass(abc.ABCMeta)
class TimeSeriesModel(object):
  """Base class for creating generative time series models.

  Subclasses implement `get_start_state`, `get_batch_loss` and `predict`.
  `initialize_graph` must be called before `define_loss` is used.
  """

  def __init__(self,
               num_features,
               exogenous_feature_columns=None,
               dtype=dtypes.float32):
    """Constructor for generative models.

    Args:
      num_features: Number of features for the time series
      exogenous_feature_columns: A list of `tf.feature_column`s (for example
           `tf.feature_column.embedding_column`) corresponding to exogenous
           features which provide extra information to the model but are not
           part of the series to be predicted. Passed to
           `tf.feature_column.input_layer`.
      dtype: The floating point datatype to use.
    """
    # Normalize "no columns" (None or any empty container) to an empty list so
    # later truthiness checks and iteration behave uniformly.
    if exogenous_feature_columns:
      self._exogenous_feature_columns = exogenous_feature_columns
    else:
      self._exogenous_feature_columns = []
    self.num_features = num_features
    self.dtype = dtype
    # The attributes below are populated by initialize_graph().
    self._input_statistics = None
    self._graph_initialized = False
    self._stats_means = None
    self._stats_sigmas = None

  @property
  def exogenous_feature_columns(self):
    """`tf.feature_column`s for features which are not predicted."""
    return self._exogenous_feature_columns

  # TODO(allenl): Move more of the generic machinery for generating and
  # predicting into TimeSeriesModel, and possibly share it between generate()
  # and predict()
  def generate(self, number_of_series, series_length,
               model_parameters=None, seed=None):
    """Sample synthetic data from model parameters, with optional substitutions.

    Returns `number_of_series` possible sequences of future values, sampled from
    the generative model with each conditioned on the previous. Samples are
    based on trained parameters, except for those parameters explicitly
    overridden in `model_parameters`.

    For distributions over future observations, see predict().

    Args:
      number_of_series: Number of time series to create.
      series_length: Length of each time series.
      model_parameters: A dictionary mapping model parameters to values, which
          replace trained parameters when generating data.
      seed: If specified, return deterministic time series according to this
          value.
    Returns:
      A dictionary with keys TrainEvalFeatures.TIMES (mapping to an array with
      shape [number_of_series, series_length]) and TrainEvalFeatures.VALUES
      (mapping to an array with shape [number_of_series, series_length,
      num_features]).
    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError("This model does not support generation.")

  def initialize_graph(self, input_statistics=None):
    """Define ops for the model, not depending on any previously defined ops.

    Args:
      input_statistics: A math_utils.InputStatistics object containing input
          statistics. If None, data-independent defaults are used, which may
          result in longer or unstable training.
    """
    self._graph_initialized = True
    self._input_statistics = input_statistics
    # Use the same `is not None` test as the _scale_* helpers, so that the
    # cached moments exist exactly when those helpers decide to use them.
    if self._input_statistics is not None:
      self._stats_means, variances = (
          self._input_statistics.overall_feature_moments)
      self._stats_sigmas = math_ops.sqrt(variances)
    else:
      # Drop moments cached by an earlier call so they cannot outlive the
      # statistics they were derived from.
      self._stats_means = None
      self._stats_sigmas = None

  def _scale_data(self, data):
    """Scale data according to stats (input scale -> model scale)."""
    if self._input_statistics is not None:
      return (data - self._stats_means) / self._stats_sigmas
    else:
      return data

  def _scale_variance(self, variance):
    """Scale variances according to stats (input scale -> model scale)."""
    if self._input_statistics is not None:
      return variance / self._input_statistics.overall_feature_moments.variance
    else:
      return variance

  def _scale_back_data(self, data):
    """Scale back data according to stats (model scale -> input scale)."""
    if self._input_statistics is not None:
      return (data * self._stats_sigmas) + self._stats_means
    else:
      return data

  def _scale_back_variance(self, variance):
    """Scale back variances according to stats (model scale -> input scale)."""
    if self._input_statistics is not None:
      return variance * self._input_statistics.overall_feature_moments.variance
    else:
      return variance

  def _check_graph_initialized(self):
    """Raises ValueError unless initialize_graph() has already been called."""
    if not self._graph_initialized:
      raise ValueError(
          "TimeSeriesModels require initialize_graph() to be called before "
          "use. This defines variables and ops in the default graph, and "
          "allows Tensor-valued input statistics to be specified.")

  def define_loss(self, features, mode):
    """Default loss definition with state replicated across a batch.

    Time series passed to this model have a batch dimension, and each series in
    a batch can be operated on in parallel. This loss definition assumes that
    each element of the batch represents an independent sample conditioned on
    the same initial state (i.e. it is simply replicated across the batch). A
    batch size of one provides sequential operations on a single time series.

    More complex processing may operate instead on get_start_state() and
    get_batch_loss() directly.

    Args:
      features: A dictionary (such as is produced by a chunker) with at minimum
        the following key/value pairs (others corresponding to the
        `exogenous_feature_columns` argument to `__init__` may be included
        representing exogenous regressors):
        TrainEvalFeatures.TIMES: A [batch size x window size] integer Tensor
            with times for each observation. If there is no artificial chunking,
            the window size is simply the length of the time series.
        TrainEvalFeatures.VALUES: A [batch size x window size x num features]
            Tensor with values for each observation.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL). For INFER,
        see predict().
    Returns:
      A ModelOutputs object.
    Raises:
      ValueError: If initialize_graph() has not been called.
    """
    self._check_graph_initialized()
    # The batch size is read dynamically from the leading dimension of the
    # times feature and used to tile the (batch-free) start state.
    start_state = math_utils.replicate_state(
        start_state=self.get_start_state(),
        batch_size=array_ops.shape(features[TrainEvalFeatures.TIMES])[0])
    return self.get_batch_loss(features=features, mode=mode, state=start_state)

  # TODO(vitalyk,allenl): Better documentation surrounding options for chunking,
  # references to papers, etc.
  @abc.abstractmethod
  def get_start_state(self):
    """Returns a tuple of state for the start of the time series.

    For example, a mean and covariance. State should not have a batch
    dimension, and will often be TensorFlow Variables to be learned along with
    the rest of the model parameters.
    """
    pass

  @abc.abstractmethod
  def get_batch_loss(self, features, mode, state):
    """Return predictions, losses, and end state for a time series.

    Args:
      features: A dictionary with times, values, and (optionally) exogenous
          regressors. See `define_loss`.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      A ModelOutputs object.
    """
    pass

  @abc.abstractmethod
  def predict(self, features):
    """Returns predictions of future observations given an initial state.

    Computes distributions for future observations. For sampled draws from the
    model where each is conditioned on the previous, see generate().

    Args:
      features: A dictionary with at minimum the following key/value pairs
        (others corresponding to the `exogenous_feature_columns` argument to
        `__init__` may be included representing exogenous regressors):
        PredictionFeatures.TIMES: A [batch size x window size] Tensor with
          times to make predictions for. Times must be increasing within each
          part of the batch, and must be greater than the last time `state` was
          updated.
        PredictionFeatures.STATE_TUPLE: Model-dependent state, each with size
          [batch size x ...]. The number and type will typically be fixed by the
          model (for example a mean and variance). Typically these will be the
          end state returned by get_batch_loss, predicting beyond that data.
    Returns:
      A dictionary with model-dependent predictions corresponding to the
      requested times. Keys indicate the type of prediction, and values have
      shape [batch size x window size x ...]. For example state space models
      return a "predicted_mean" and "predicted_covariance".
    """
    pass

  def _get_exogenous_embedding_shape(self):
    """Computes the shape of the vector returned by _process_exogenous_features.

    Returns:
      The shape as a list, or the tuple `(0,)` when the model has no exogenous
      feature columns. Does not include a batch dimension.
    """
    if not self._exogenous_feature_columns:
      return (0,)
    # Build the embedding in a throwaway graph so that computing the shape
    # leaves the caller's default graph untouched.
    with ops.Graph().as_default():
      parsed_features = (
          feature_column.make_parse_example_spec(
              self._exogenous_feature_columns))
      placeholder_features = parsing_ops.parse_example(
          serialized=array_ops.placeholder(shape=[None], dtype=dtypes.string),
          features=parsed_features)
      embedded = feature_column.input_layer(
          features=placeholder_features,
          feature_columns=self._exogenous_feature_columns)
      # Strip the leading (batch) dimension.
      return embedded.get_shape().as_list()[1:]

  def _process_exogenous_features(self, times, features):
    """Create a single vector from exogenous features.

    Args:
      times: A [batch size, window size] vector of times for this batch,
          primarily used to check the shape information of exogenous features.
      features: A dictionary of exogenous features corresponding to the columns
          in self._exogenous_feature_columns. Each value should have a shape
          prefixed by [batch size, window size].
    Returns:
      A Tensor with shape [batch size, window size, exogenous dimension], where
      the size of the exogenous dimension depends on the exogenous feature
      columns passed to the model's constructor. None if the model has no
      exogenous feature columns.
    Raises:
      ValueError: If an exogenous feature has an unknown rank.
    """
    if self._exogenous_feature_columns:
      exogenous_features_single_batch_dimension = {}
      for name, tensor in features.items():
        if tensor.get_shape().ndims is None:
          # The feature column input layer does not support completely unknown
          # feature shapes, so we save on a bit of logic and provide a better
          # error message by checking that here.
          raise ValueError(
              ("Features with unknown rank are not supported. Got shape {} for "
               "feature {}.").format(tensor.get_shape(), name))
        # Merge the batch and window dimensions into a single leading
        # dimension, keeping any trailing feature dimensions as they are.
        tensor_shape_dynamic = array_ops.shape(tensor)
        tensor = array_ops.reshape(
            tensor,
            array_ops.concat([[tensor_shape_dynamic[0]
                               * tensor_shape_dynamic[1]],
                              tensor_shape_dynamic[2:]], axis=0))
        # Avoid shape warnings when embedding "scalar" exogenous features (those
        # with only batch and window dimensions); the feature column input layer
        # expects input ranks to match the embedded rank.
        if tensor.get_shape().ndims == 1 and tensor.dtype != dtypes.string:
          exogenous_features_single_batch_dimension[name] = tensor[:, None]
        else:
          exogenous_features_single_batch_dimension[name] = tensor
      embedded_exogenous_features_single_batch_dimension = (
          feature_column.input_layer(
              features=exogenous_features_single_batch_dimension,
              feature_columns=self._exogenous_feature_columns,
              trainable=True))
      # Split the merged leading dimension back into [batch size, window size]
      # using the dynamic shape of `times`.
      exogenous_regressors = array_ops.reshape(
          embedded_exogenous_features_single_batch_dimension,
          array_ops.concat(
              [
                  array_ops.shape(times), array_ops.shape(
                      embedded_exogenous_features_single_batch_dimension)[1:]
              ],
              axis=0))
      # Restore whatever static shape information is available.
      exogenous_regressors.set_shape(times.get_shape().concatenate(
          embedded_exogenous_features_single_batch_dimension.get_shape()[1:]))
      exogenous_regressors = math_ops.cast(
          exogenous_regressors, dtype=self.dtype)
    else:
      # Not having any exogenous features is a special case so that models can
      # avoid superfluous updates, which may not be free of side effects due to
      # bias terms in transformations.
      exogenous_regressors = None
    return exogenous_regressors
342
343
344# TODO(allenl): Add a superclass of SequentialTimeSeriesModel which fuses
345# filtering/prediction/exogenous into one step, and move looping constructs to
346# that class.
347class SequentialTimeSeriesModel(TimeSeriesModel):
348  """Base class for recurrent generative models.
349
  Models implementing this interface have four main functions, corresponding to
  abstract methods:
352    _filtering_step: Updates state based on observations and computes a loss.
353    _prediction_step: Predicts a batch of observations and new model state.
354    _imputation_step: Updates model state across a gap.
355    _exogenous_input_step: Updates state to account for exogenous regressors.
356
357  Models may also specify a _window_initializer to prepare for a window of data.
358
359  See StateSpaceModel for a concrete example of a model implementing this
360  interface.
361
362  """
363
364  def __init__(self,
365               train_output_names,
366               predict_output_names,
367               num_features,
368               normalize_features=False,
369               dtype=dtypes.float32,
370               exogenous_feature_columns=None,
371               exogenous_update_condition=None,
372               static_unrolling_window_size_threshold=None):
373    """Initialize a SequentialTimeSeriesModel.
374
375    Args:
376      train_output_names: A list of products/predictions returned from
377          _filtering_step.
378      predict_output_names: A list of products/predictions returned from
379          _prediction_step.
380      num_features: Number of features for the time series
381      normalize_features: Boolean. If True, `values` are passed normalized to
382          the model (via self._scale_data). Scaling is done for the whole window
383          as a batch, which is slightly more efficient than scaling inside the
384          window loop. The model must then define _scale_back_predictions, which
385          may use _scale_back_data or _scale_back_variance to return predictions
386          to the input scale.
387      dtype: The floating point datatype to use.
388      exogenous_feature_columns: A list of `tf.feature_column`s objects. See
389          `TimeSeriesModel`.
390      exogenous_update_condition: A function taking two Tensor arguments `times`
391          (shape [batch size]) and `features` (a dictionary mapping exogenous
392          feature keys to Tensors with shapes [batch size, ...]) and returning a
393          boolean Tensor with shape [batch size] indicating whether state should
394          be updated using exogenous features for each part of the batch. Where
395          it is False, no exogenous update is performed. If None (default),
396          exogenous updates are always performed. Useful for avoiding "leaky"
397          frequent exogenous updates when sparse updates are desired. Called
398          only during graph construction.
399      static_unrolling_window_size_threshold: Controls whether a `tf.while_loop`
400          is used when looping over a window of data. If
401          `static_unrolling_window_size_threshold` is None, a `tf.while_loop` is
402          always used. Otherwise it must be an integer, and the graph is
403          replicated for each step taken whenever the window size is less than
404          or equal to this value (if the window size is available in the static
405          shape information of the TrainEvalFeatures.TIMES feature). Static
406          unrolling generally decreases the per-step time for small window/batch
407          sizes, but increases graph construction time.
408    """
409    super(SequentialTimeSeriesModel, self).__init__(
410        num_features=num_features, dtype=dtype,
411        exogenous_feature_columns=exogenous_feature_columns)
412    self._exogenous_update_condition = exogenous_update_condition
413    self._train_output_names = train_output_names
414    self._predict_output_names = predict_output_names
415    self._normalize_features = normalize_features
416    self._static_unrolling_window_size_threshold = (
417        static_unrolling_window_size_threshold)
418
419  def _scale_back_predictions(self, predictions):
420    """Return a window of predictions to input scale.
421
422    Args:
423      predictions: A dictionary mapping from prediction names to Tensors.
424    Returns:
425      A dictionary with values corrected for input normalization (e.g. with
426      self._scale_back_mean and possibly self._scale_back_variance). May be a
427      mutated version of the argument.
428    """
429    raise NotImplementedError(
430        "SequentialTimeSeriesModel normalized input data"
431        " (normalize_features=True), but no method was provided to transform "
432        "the predictions back to the input scale.")
433
434  @abc.abstractmethod
435  def _filtering_step(self, current_times, current_values, state, predictions):
436    """Compute a single-step loss for a batch of data.
437
438    Args:
439      current_times: A [batch size] Tensor of times for each observation.
440      current_values: A [batch size] Tensor of values for each observation.
441      state: Model state, updated to current_times.
442      predictions: The outputs of _prediction_step
443    Returns:
444      A tuple of (updated state, outputs):
445        updated state: Model state taking current_values into account.
446        outputs: A dictionary of Tensors with keys corresponding to
447            self._train_output_names, plus a special "loss" key. The value
448            corresponding to "loss" is minimized during training. Other outputs
449            may include one-step-ahead predictions, for example a predicted
450            location and scale.
451    """
452    pass
453
454  @abc.abstractmethod
455  def _prediction_step(self, current_times, state):
456    """Compute a batch of single-step predictions.
457
458    Args:
459      current_times: A [batch size] Tensor of times for each observation.
460      state: Model state, imputed to one step before current_times.
461    Returns:
462      A tuple of (updated state, outputs):
463        updated state: Model state updated to current_times.
464        outputs: A dictionary of Tensors with keys corresponding to
465            self._predict_output_names.
466    """
467    pass
468
469  @abc.abstractmethod
470  def _imputation_step(self, current_times, state):
471    """Update model state across missing values.
472
473    Called to prepare model state for _filtering_step and _prediction_step.
474
475    Args:
476      current_times: A [batch size] Tensor; state will be imputed up to, but not
477          including, these timesteps.
478      state: The pre-imputation model state, Tensors with shape [batch size x
479          ...].
480    Returns:
481      Updated/imputed model state, corresponding to `state`.
482    """
483    pass
484
485  @abc.abstractmethod
486  def _exogenous_input_step(
487      self, current_times, current_exogenous_regressors, state):
488    """Update state to account for exogenous regressors.
489
490    Args:
491      current_times: A [batch size] Tensor of times for the exogenous values
492          being input.
493      current_exogenous_regressors: A [batch size x exogenous input dimension]
494          Tensor of exogenous values for each part of the batch.
495      state: Model state, a possibly nested list of Tensors, each with shape
496          [batch size x ...].
497    Returns:
498      Updated model state, structure and shapes matching the `state` argument.
499    """
500    pass
501
502  # TODO(allenl): Move regularization to a separate object (optional and
503  # configurable)
504  def _loss_additions(self, times, values, mode):
505    """Additions to per-observation normalized loss, e.g. regularization.
506
507    Args:
508      times: A [batch size x window size] Tensor with times for each
509          observation.
510      values: A [batch size x window size x num features] Tensor with values for
511          each observation.
512      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
513    Returns:
514      A scalar value to add to the per-observation normalized loss.
515    """
516    del times, values, mode
517    return 0.
518
519  def _window_initializer(self, times, state):
520    """Prepare for training or prediction on a window of data.
521
522    Args:
523      times: A [batch size x window size] Tensor with times for each
524          observation.
525      state: Model-dependent state, each with size [batch size x ...]. The
526          number and type will typically be fixed by the model (for example a
527          mean and variance).
528    Returns:
529      Nothing
530    """
531    pass
532
533  def get_batch_loss(self, features, mode, state):
534    """Calls self._filtering_step. See TimeSeriesModel.get_batch_loss."""
535    per_observation_loss, state, outputs = self.per_step_batch_loss(
536        features, mode, state)
537    # per_step_batch_loss returns [batch size, window size, ...] state, whereas
538    # get_batch_loss is expected to return [batch size, ...] state for the last
539    # element of a window
540    state = nest.pack_sequence_as(
541        state,
542        [state_element[:, -1] for state_element in nest.flatten(state)])
543    outputs["observed"] = features[TrainEvalFeatures.VALUES]
544    return ModelOutputs(
545        loss=per_observation_loss,
546        end_state=state,
547        predictions=outputs,
548        prediction_times=features[TrainEvalFeatures.TIMES])
549
550  def _apply_exogenous_update(
551      self, current_times, step_number, state, raw_features,
552      embedded_exogenous_regressors):
553    """Performs a conditional state update based on exogenous features."""
554    if embedded_exogenous_regressors is None:
555      return state
556    else:
557      current_exogenous_regressors = embedded_exogenous_regressors[
558          :, step_number, :]
559      exogenous_updated_state = self._exogenous_input_step(
560          current_times=current_times,
561          current_exogenous_regressors=current_exogenous_regressors,
562          state=state)
563      if self._exogenous_update_condition is not None:
564        current_raw_exogenous_features = {
565            key: value[:, step_number] for key, value in raw_features.items()
566            if key not in [PredictionFeatures.STATE_TUPLE,
567                           TrainEvalFeatures.TIMES,
568                           TrainEvalFeatures.VALUES]}
569        conditionally_updated_state_flat = []
570        for updated_state_element, original_state_element in zip(
571            nest.flatten(exogenous_updated_state),
572            nest.flatten(state)):
573          conditionally_updated_state_flat.append(
574              array_ops.where(
575                  self._exogenous_update_condition(
576                      times=current_times,
577                      features=current_raw_exogenous_features),
578                  updated_state_element,
579                  original_state_element))
580        return nest.pack_sequence_as(state, conditionally_updated_state_flat)
581      else:
582        return exogenous_updated_state
583
584  def per_step_batch_loss(self, features, mode, state):
585    """Computes predictions, losses, and intermediate model states.
586
587    Args:
588      features: A dictionary with times, values, and (optionally) exogenous
589          regressors. See `define_loss`.
590      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
591      state: Model-dependent state, each with size [batch size x ...]. The
592          number and type will typically be fixed by the model (for example a
593          mean and variance).
594    Returns:
595      A tuple of (loss, filtered_states, predictions)
596        loss: Average loss values across the batch.
597        filtered_states: For each Tensor in `state` with shape [batch size x
598            ...], `filtered_states` has a Tensor with shape [batch size x window
599            size x ...] with filtered state for each part of the batch and
600            window.
601        predictions: A dictionary with model-dependent one-step-ahead (or
602            at-least-one-step-ahead with missing values) predictions, with keys
603            indicating the type of prediction and values having shape [batch
604            size x window size x ...]. For example state space models provide
605            "mean", "covariance", and "log_likelihood".
606
607    """
608    self._check_graph_initialized()
609    times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
610    values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
611    if self._normalize_features:
612      values = self._scale_data(values)
613    exogenous_regressors = self._process_exogenous_features(
614        times=times,
615        features={key: value for key, value in features.items()
616                  if key not in [TrainEvalFeatures.TIMES,
617                                 TrainEvalFeatures.VALUES]})
618    def _batch_loss_filtering_step(step_number, current_times, state):
619      """Make a prediction and update it based on data."""
620      current_values = values[:, step_number, :]
621      state = self._apply_exogenous_update(
622          step_number=step_number, current_times=current_times, state=state,
623          raw_features=features,
624          embedded_exogenous_regressors=exogenous_regressors)
625      predicted_state, predictions = self._prediction_step(
626          current_times=current_times,
627          state=state)
628      filtered_state, outputs = self._filtering_step(
629          current_times=current_times,
630          current_values=current_values,
631          state=predicted_state,
632          predictions=predictions)
633      return filtered_state, outputs
634    state, outputs = self._state_update_loop(
635        times=times, state=state, state_update_fn=_batch_loss_filtering_step,
636        outputs=["loss"] + self._train_output_names)
637    outputs["loss"].set_shape(times.get_shape())
638    loss_sum = math_ops.reduce_sum(outputs["loss"])
639    per_observation_loss = (loss_sum / math_ops.cast(
640        math_ops.reduce_prod(array_ops.shape(times)), dtype=self.dtype))
641    per_observation_loss += self._loss_additions(times, values, mode)
642    # Since we have window-level additions to the loss, its per-step value is
643    # misleading, so we avoid returning it.
644    del outputs["loss"]
645    if self._normalize_features:
646      outputs = self._scale_back_predictions(outputs)
647    return per_observation_loss, state, outputs
648
649  def predict(self, features):
650    """Calls self._prediction_step in a loop. See TimeSeriesModel.predict."""
651    predict_times = ops.convert_to_tensor(features[PredictionFeatures.TIMES],
652                                          dtypes.int64)
653    start_state = features[PredictionFeatures.STATE_TUPLE]
654    exogenous_regressors = self._process_exogenous_features(
655        times=predict_times,
656        features={
657            key: value
658            for key, value in features.items()
659            if key not in
660            [PredictionFeatures.TIMES, PredictionFeatures.STATE_TUPLE]
661        })
662    def _call_prediction_step(step_number, current_times, state):
663      state = self._apply_exogenous_update(
664          step_number=step_number, current_times=current_times, state=state,
665          raw_features=features,
666          embedded_exogenous_regressors=exogenous_regressors)
667      state, outputs = self._prediction_step(
668          current_times=current_times, state=state)
669      return state, outputs
670    _, predictions = self._state_update_loop(
671        times=predict_times, state=start_state,
672        state_update_fn=_call_prediction_step,
673        outputs=self._predict_output_names)
674    if self._normalize_features:
675      predictions = self._scale_back_predictions(predictions)
676    return predictions
677
678  class _FakeTensorArray(object):
679    """An interface for Python lists that is similar to TensorArray.
680
681    Used for easy switching between static and dynamic looping.
682    """
683
684    def __init__(self):
685      self.values = []
686
687    def write(self, unused_position, value):
688      del unused_position
689      self.values.append(value)
690      return self
691
  def _state_update_loop(self, times, state, state_update_fn, outputs):
    """Iterates over `times`, calling `state_update_fn` to collect outputs.

    The loop is unrolled in Python when
    `self._static_unrolling_window_size_threshold` is set and the window size
    is statically known and no larger than that threshold; otherwise it is
    built with `control_flow_ops.while_loop`.

    Args:
      times: A [batch size x window size] Tensor of integers to iterate over.
      state: A list of model-specific state Tensors, each with shape [batch size
          x ...].
      state_update_fn: A callback taking the following arguments
            step_number; A scalar integer Tensor indicating the current position
              in the window.
            current_times; A [batch size] vector of Integers indicating times
              for each part of the batch.
            state; Current model state.
          It returns a tuple of (updated state, output_values), output_values
          being a dictionary of Tensors with keys corresponding to `outputs`.
          The callback is invoked with these arguments passed by keyword.
      outputs: A list of strings indicating values which will be saved while
          iterating. Must match the keys of the dictionary returned by
          state_update_fn.
    Returns:
      A tuple of (state, output_dict)
      state: A structure matching the `state` argument in which each element
        holds the state computed at every step, stacked along the window
        dimension ([batch size x window size x ...]); it is not only the state
        after the last step.
      output_dict: A dictionary of outputs corresponding to those specified in
        `outputs` and computed in state_update_fn, stacked over steps in the
        same [batch size x window size x ...] layout.
    """
    times = ops.convert_to_tensor(times, dtype=dtypes.int64)
    # Static window size, or None when it is not known at graph construction.
    window_static_shape = tensor_shape.dimension_value(times.shape[1])
    if self._static_unrolling_window_size_threshold is None:
      static_unroll = False
    else:
      # The user has specified a threshold for static loop unrolling.
      if window_static_shape is None:
        # We don't have static shape information for the window size, so dynamic
        # looping is our only option.
        static_unroll = False
      elif window_static_shape <= self._static_unrolling_window_size_threshold:
        # The threshold is satisfied; unroll statically
        static_unroll = True
      else:
        # A threshold was set but not satisfied
        static_unroll = False

    # Called once before any step is taken; its return value is not used here.
    # NOTE(review): presumably lets the model prepare per-window values -
    # confirm against the implementation.
    self._window_initializer(times, state)

    # `window_size` is assigned further down in this method; the closure only
    # looks it up when the condition is called, which happens after that
    # assignment (the condition is only handed to while_loop below).
    def _run_condition(step_number, *unused):
      del unused  # not part of while loop run condition
      return math_ops.less(step_number, window_size)

    # `reuse` is only supplied by the static-unrolling loop below; while_loop
    # calls the body with the loop variables alone, so it stays False there.
    def _state_update_step(
        step_number, state, state_accumulators, output_accumulators,
        reuse=False):
      """Impute, then take one state_update_fn step, accumulating outputs."""
      with variable_scope.variable_scope("state_update_step", reuse=reuse):
        current_times = times[:, step_number]
        # From here on `state` refers to the imputed state.
        state = self._imputation_step(current_times=current_times, state=state)
        # Pair each requested output name with its accumulator; both lists are
        # in the order given by `outputs`.
        output_accumulators_dict = {
            accumulator_key: accumulator
            for accumulator_key, accumulator
            in zip(outputs, output_accumulators)}
        step_state, output_values = state_update_fn(
            step_number=step_number,
            current_times=current_times,
            state=state)
        # state_update_fn must return exactly the requested outputs.
        assert set(output_values.keys()) == set(outputs)
        new_output_accumulators = []
        for output_key in outputs:
          accumulator = output_accumulators_dict[output_key]
          output_value = output_values[output_key]
          new_output_accumulators.append(
              accumulator.write(step_number, output_value))
        flat_step_state = nest.flatten(step_state)
        # There is one accumulator per flattened state element.
        assert len(state_accumulators) == len(flat_step_state)
        new_state_accumulators = []
        new_state_flat = []
        for step_state_value, state_accumulator, original_state in zip(
            flat_step_state, state_accumulators, nest.flatten(state)):
          # Make sure the static shape information is complete so while_loop
          # does not complain about shape information changing.
          step_state_value.set_shape(original_state.get_shape())
          new_state_flat.append(step_state_value)
          # Record this step's state so every step can be returned later.
          new_state_accumulators.append(state_accumulator.write(
              step_number, step_state_value))
        step_state = nest.pack_sequence_as(state, new_state_flat)
        # Same layout as the loop variables: (next step number, new state,
        # state accumulators, output accumulators).
        return (step_number + 1, step_state,
                new_state_accumulators, new_output_accumulators)

    # Dynamic window size; read by the loop condition and used to size the
    # TensorArrays in the dynamic case.
    window_size = array_ops.shape(times)[1]

    # Accumulator factory: a Python-list wrapper when unrolling statically,
    # otherwise a fixed-size TensorArray with one slot per step.
    def _window_size_tensor_array(dtype):
      if static_unroll:
        return self._FakeTensorArray()
      else:
        return tensor_array_ops.TensorArray(
            dtype=dtype, size=window_size, dynamic_size=False)

    # Loop variables: step counter, current state, one accumulator per
    # flattened state element, and one accumulator per requested output (all
    # outputs are accumulated with `self.dtype`).
    initial_loop_arguments = [
        array_ops.zeros([], dtypes.int32),
        state,
        [_window_size_tensor_array(element.dtype)
         for element in nest.flatten(state)],
        [_window_size_tensor_array(self.dtype) for _ in outputs]]
    if static_unroll:
      arguments = initial_loop_arguments
      for step_number in range(tensor_shape.dimension_value(times.shape[1])):
        # The previous step counter (arguments[0]) is dropped and a fresh
        # constant is passed in its place.
        arguments = _state_update_step(
            array_ops.constant(step_number, dtypes.int32), *arguments[1:],
            reuse=(step_number > 0))  # Variable sharing between steps
    else:
      arguments = control_flow_ops.while_loop(
          cond=_run_condition,
          body=_state_update_step,
          loop_vars=initial_loop_arguments)
    # Only the accumulators are kept; the final step counter and the final
    # state are discarded.
    (_, _, state_loop_result, outputs_loop_result) = arguments

    def _stack_and_transpose(tensor_array):
      """Stack and re-order the dimensions of a TensorArray."""
      if static_unroll:
        # Stacking the Python list on axis 1 places the window dimension
        # second, so no transpose is needed.
        return array_ops.stack(tensor_array.values, axis=1)
      else:
        # TensorArrays from while_loop stack with window size as the first
        # dimension, so this function swaps it and the batch dimension to
        # maintain the [batch x window size x ...] convention used elsewhere.
        stacked = tensor_array.stack()
        # The permutation is [1, 0, 2, ..., rank - 1], built dynamically.
        return array_ops.transpose(
            stacked,
            perm=array_ops.concat([[1, 0], math_ops.range(
                2, array_ops.rank(stacked))], 0))

    outputs_dict = {output_key: _stack_and_transpose(output)
                    for output_key, output
                    in zip(outputs, outputs_loop_result)}
    # Re-nest the stacked per-step states to match the structure of the
    # `state` argument.
    full_state = nest.pack_sequence_as(
        state,
        [_stack_and_transpose(state_element)
         for state_element in state_loop_result])
    return full_state, outputs_dict
827