1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A more advanced example, of building an RNN-based time series model."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22from os import path
23import tempfile
24
25import numpy
26import tensorflow as tf
27
28from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
29from tensorflow.contrib.timeseries.python.timeseries import model as ts_model
30from tensorflow.contrib.timeseries.python.timeseries import state_management
31
# matplotlib is optional. The unit test that runs this example may execute
# where it is not installed (it is not a build dependency), and the
# TensorFlow-dependent parts should still be testable there. Only the
# plotting in main() needs it.
try:
  import matplotlib  # pylint: disable=g-import-not-at-top
  matplotlib.use("TkAgg")  # Need Tk for interactive plots.
  from matplotlib import pyplot  # pylint: disable=g-import-not-at-top
except ImportError:
  HAS_MATPLOTLIB = False
else:
  HAS_MATPLOTLIB = True

# Example data is looked up relative to this module's directory.
_MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/multivariate_periods.csv")
46
47
class _LSTMModel(ts_model.SequentialTimeSeriesModel):
  """Example time series model whose dynamics are driven by an LSTM RNNCell.

  The state threaded through the sequential model is a 4-tuple:
  (time the state refers to, last observation or prediction in model scale,
  most recent exogenous embedding, list of LSTM state tensors).
  """

  def __init__(self, num_units, num_features, exogenous_feature_columns=None,
               dtype=tf.float32):
    """Store configuration; no graph construction happens here.

    The object is a configurable factory for the TensorFlow graphs that an
    Estimator builds and runs.

    Args:
      num_units: Number of units in the model's LSTMCell.
      num_features: Dimensionality of the time series (values per timestep).
      exogenous_feature_columns: Optional list of `tf.feature_column`s for
          features that are inputs to the model but are not predicted by it.
          When given, they must be present for training, evaluation, and
          prediction.
      dtype: Floating point data type of the model.
    """
    # The only output registered for training and prediction is "mean".
    super(_LSTMModel, self).__init__(
        num_features=num_features,
        exogenous_feature_columns=exogenous_feature_columns,
        dtype=dtype,
        train_output_names=["mean"],
        predict_output_names=["mean"])
    self._num_units = num_units
    # Populated per-graph by initialize_graph().
    self._lstm_cell = None
    self._lstm_cell_run = None
    self._predict_from_lstm_output = None

  def initialize_graph(self, input_statistics=None):
    """Create the templates used by the filtering/prediction steps.

    Invoked each time a new graph is created. Ops may be added to the current
    default graph here, but everything must be built from scratch.

    Args:
      input_statistics: A math_utils.InputStatistics object, forwarded to the
        parent class.
    """
    super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics)
    # ResourceVariables are used to avoid race conditions.
    with tf.variable_scope("", use_resource=True):
      self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
      # Wrapping in templates means repeated calls share variables without
      # manual reuse handling.
      self._lstm_cell_run = tf.make_template(
          name_="lstm_cell", func_=self._lstm_cell, create_scope_now_=True)
      # Dense projection from LSTM output to one mean per feature.
      dense_to_features = functools.partial(
          tf.layers.dense, units=self.num_features)
      self._predict_from_lstm_output = tf.make_template(
          name_="predict_from_lstm_output",
          func_=dense_to_features,
          create_scope_now_=True)

  def get_start_state(self):
    """Return the initial state tuple, without a batch dimension.

    Returns:
      (int64 scalar time, zero vector of length num_features, zero exogenous
      embedding, list of LSTM zero-state tensors with the batch dimension
      squeezed out because the parent class broadcasts over the batch).
    """
    start_time = tf.zeros([], dtype=tf.int64)
    last_value = tf.zeros([self.num_features], dtype=self.dtype)
    exogenous = tf.zeros(
        self._get_exogenous_embedding_shape(), dtype=self.dtype)
    batched_lstm_state = self._lstm_cell.zero_state(
        batch_size=1, dtype=self.dtype)
    lstm_state = [tf.squeeze(part, axis=0) for part in batched_lstm_state]
    return (start_time, last_value, exogenous, lstm_state)

  def _filtering_step(self, current_times, current_values, state, predictions):
    """Record a new observation and score the previous prediction against it.

    The RNN state is not advanced here. `_prediction_step` advances it for
    both observations (stored in state by this method) and the model's own
    predictions. This split matters for probabilistic models, where repeated
    prediction without filtering should lead to low-confidence predictions.

    Args:
      current_times: A [batch size] integer Tensor.
      current_values: A [batch size, self.num_features] floating point Tensor
        with new observations.
      state: The model's state tuple.
      predictions: The output of the previous `_prediction_step`.
    Returns:
      A (new state tuple, predictions) pair. `predictions` gains a "loss"
      entry: the mean squared error across features, in model scale. Other
      goodness-of-fit measures could be added, but only "loss" is optimized.
    """
    expected_time, predicted_value, exogenous, lstm_state = state
    # The state must describe the same time step as the new observation.
    time_check = tf.assert_equal(current_times, expected_time)
    with tf.control_dependencies([time_check]):
      # Normalize with the series statistics (mean subtraction, division by
      # variance). Doing this for a whole window via the normalize_features
      # argument of SequentialTimeSeriesModel would be slightly cheaper.
      scaled_values = self._scale_data(current_values)
      squared_error = (predicted_value - scaled_values) ** 2
      predictions["loss"] = tf.reduce_mean(squared_error, axis=-1)
      # Keep the observation in state; _prediction_step feeds it to the LSTM.
      new_state = (current_times, scaled_values, exogenous, lstm_state)
    return new_state, predictions

  def _prediction_step(self, current_times, state):
    """Run one LSTM step on the latest observation or prediction.

    Returns:
      A (new state tuple, outputs) pair. `outputs["mean"]` is the prediction
      mapped back from model scale to the data's scale.
    """
    _, last_value, exogenous, lstm_state = state
    # Endogenous (last value) and exogenous features jointly drive the LSTM.
    lstm_inputs = tf.concat([last_value, exogenous], axis=-1)
    lstm_output, next_lstm_state = self._lstm_cell_run(
        inputs=lstm_inputs, state=lstm_state)
    scaled_prediction = self._predict_from_lstm_output(lstm_output)
    next_state = (
        current_times, scaled_prediction, exogenous, next_lstm_state)
    outputs = {"mean": self._scale_back_data(scaled_prediction)}
    return next_state, outputs

  def _imputation_step(self, current_times, state):
    """Carry the state across a gap in the series unchanged.

    Gap size is ignored. More advanced (especially probabilistic) models would
    want to treat the state differently depending on how large the gap is.
    """
    del current_times  # Unused: gaps get no special treatment.
    return state

  def _exogenous_input_step(
      self, current_times, current_exogenous_regressors, state):
    """Save exogenous regressors in model state for use in _prediction_step."""
    state_time, last_value, _, lstm_state = state
    return (state_time, last_value, current_exogenous_regressors, lstm_state)
185
186
def train_and_predict(
    csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None,
    export_directory=None):
  """Train and predict using a custom time series model.

  Trains `_LSTMModel` on the CSV data and evaluates it on the whole series.
  It then predicts beyond the end of the series, exports a SavedModel, and
  checks that serving from the SavedModel reproduces the Estimator's
  predictions.

  Args:
    csv_file_name: Path to a CSV file whose columns are, in order: an int64
      time, 5 float32 values, 2 float32 exogenous features, and 1 string
      categorical exogenous feature.
    training_steps: Number of training steps to run.
    estimator_config: Optional config passed through to the Estimator.
    export_directory: Directory to export the SavedModel into. If None, a
      temporary directory is created for the export and deleted before this
      function returns.

  Returns:
    A tuple (times, observed, all_times, predicted_mean):
      times: Times of the evaluated series.
      observed: Observed values at `times`.
      all_times: `times` followed by the times of the predictions.
      predicted_mean: In-sample means followed by predicted means, squeezed.
  """
  import shutil  # pylint: disable=g-import-not-at-top

  # Number of endogenous value columns in the CSV; also the model's
  # num_features. The reader schema and the model must agree on it.
  num_values = 5
  # How many steps past the end of the data to predict.
  predict_steps = 100

  # Construct an Estimator from our LSTM model.
  categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
      key="categorical_exogenous_feature", hash_bucket_size=16)
  exogenous_feature_columns = [
      # Exogenous features are not part of the loss, but can inform
      # predictions. In this example the features have no extra information, but
      # are included as an API example.
      tf.feature_column.numeric_column(
          "2d_exogenous_feature", shape=(2,)),
      tf.feature_column.embedding_column(
          categorical_column=categorical_column, dimension=10)]
  estimator = ts_estimators.TimeSeriesRegressor(
      model=_LSTMModel(num_features=num_values, num_units=128,
                       exogenous_feature_columns=exogenous_feature_columns),
      optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config,
      # Set state to be saved across windows.
      state_manager=state_management.ChainingStateManager())
  train_eval_features = tf.contrib.timeseries.TrainEvalFeatures
  column_names = ((train_eval_features.TIMES,)
                  + (train_eval_features.VALUES,) * num_values
                  + ("2d_exogenous_feature",) * 2
                  + ("categorical_exogenous_feature",))
  # Data types other than for `times` need to be specified if they aren't
  # float32. In this case one of our exogenous features has string dtype.
  # The float32 columns are the values plus the two numeric exogenous ones.
  column_dtypes = ((tf.int64,) + (tf.float32,) * (num_values + 2)
                   + (tf.string,))
  reader = tf.contrib.timeseries.CSVReader(
      csv_file_name, column_names=column_names, column_dtypes=column_dtypes)
  train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
      reader, batch_size=4, window_size=32)
  estimator.train(input_fn=train_input_fn, steps=training_steps)
  evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
  evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
  # Predict starting after the evaluation. The numeric exogenous feature is a
  # constant [1, 0] pair at every predicted step.
  predict_exogenous_features = {
      "2d_exogenous_feature": numpy.concatenate(
          [numpy.ones([1, predict_steps, 1]),
           numpy.zeros([1, predict_steps, 1])],
          axis=-1),
      "categorical_exogenous_feature": numpy.array(
          ["strkey"] * predict_steps)[None, :, None]}
  (predictions,) = tuple(estimator.predict(
      input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
          evaluation, steps=predict_steps,
          exogenous_features=predict_exogenous_features)))
  times = evaluation["times"][0]
  observed = evaluation["observed"][0, :, :]
  predicted_mean = numpy.squeeze(numpy.concatenate(
      [evaluation["mean"][0], predictions["mean"]], axis=0))
  all_times = numpy.concatenate([times, predictions["times"]], axis=0)

  # Export the model in SavedModel format. We include a bit of extra boilerplate
  # for "cold starting" as if we didn't have any state from the Estimator, which
  # is the case when serving from a SavedModel. If Estimator output is
  # available, the result of "Estimator.evaluate" can be passed directly to
  # `tf.contrib.timeseries.saved_model_utils.predict_continuation` as the
  # `continue_from` argument.
  with tf.Graph().as_default():
    filter_feature_tensors, _ = evaluation_input_fn()
    with tf.train.MonitoredSession() as session:
      # Fetch the series to "warm up" our state, which will allow us to make
      # predictions for its future values. This is just a dictionary of times,
      # values, and exogenous features mapping to numpy arrays. The use of an
      # input_fn is just a convenience for the example; they can also be
      # specified manually.
      filter_features = session.run(filter_feature_tensors)
  # The caller cannot reach a directory we create, so we must remove it.
  delete_export_directory = export_directory is None
  if delete_export_directory:
    export_directory = tempfile.mkdtemp()
  try:
    input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
    export_location = estimator.export_saved_model(export_directory,
                                                   input_receiver_fn)
    # Warm up and predict using the SavedModel.
    with tf.Graph().as_default():
      with tf.Session() as session:
        signatures = tf.saved_model.loader.load(
            session, [tf.saved_model.tag_constants.SERVING], export_location)
        state = tf.contrib.timeseries.saved_model_utils.cold_start_filter(
            signatures=signatures, session=session, features=filter_features)
        saved_model_output = (
            tf.contrib.timeseries.saved_model_utils.predict_continuation(
                continue_from=state, signatures=signatures,
                session=session, steps=predict_steps,
                exogenous_features=predict_exogenous_features))
        # The exported model gives the same results as the Estimator.predict()
        # call above.
        numpy.testing.assert_allclose(
            predictions["mean"],
            numpy.squeeze(saved_model_output["mean"], axis=0))
  finally:
    if delete_export_directory:
      # tempfile.mkdtemp() leaves deletion to its caller. Cleanup is
      # best-effort so it never masks an exception raised above.
      shutil.rmtree(export_directory, ignore_errors=True)
  return times, observed, all_times, predicted_mean
277
278
def main(unused_argv):
  """Train the example model and plot observations against predictions.

  Args:
    unused_argv: Arguments passed in by `tf.app.run`; ignored.

  Raises:
    ImportError: If matplotlib could not be imported when the module loaded.
  """
  if not HAS_MATPLOTLIB:
    raise ImportError(
        "Please install matplotlib to generate a plot from this example.")
  (observed_times, observations,
   all_times, predictions) = train_and_predict()
  # Mark where the observed data ends; later points are forecasts. The position
  # comes from the data rather than a hard-coded time so the marker stays
  # correct for series of other lengths.
  pyplot.axvline(observed_times[-1], linestyle="dotted")
  observed_lines = pyplot.plot(
      observed_times, observations, label="Observed", color="k")
  predicted_lines = pyplot.plot(
      all_times, predictions, label="Predicted", color="b")
  pyplot.legend(handles=[observed_lines[0], predicted_lines[0]],
                loc="upper left")
  pyplot.show()
293
294
if __name__ == "__main__":
  # Hand control to TensorFlow's app runner, which invokes main().
  tf.app.run(main)
297