1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Convenience functions for working with time series saved_models.
16
17@@predict_continuation
18@@cold_start_filter
19@@filter_continuation
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
27from tensorflow.contrib.timeseries.python.timeseries import head as _head
28from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline
29from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils
30
31from tensorflow.python.util.all_util import remove_undocumented
32
33
34def _colate_features_to_feeds_and_fetches(signature, features, graph,
35                                          continue_from=None):
36  """Uses a saved model signature to construct feed and fetch dictionaries."""
37  if continue_from is None:
38    state_values = {}
39  elif _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
40    # We're continuing from an evaluation, so we need to unpack/flatten state.
41    state_values = _head.state_to_dictionary(
42        continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
43  else:
44    state_values = continue_from
45  input_feed_tensors_by_name = {
46      input_key: graph.as_graph_element(input_value.name)
47      for input_key, input_value in signature.inputs.items()
48  }
49  output_tensors_by_name = {
50      output_key: graph.as_graph_element(output_value.name)
51      for output_key, output_value in signature.outputs.items()
52  }
53  feed_dict = {}
54  for state_key, state_value in state_values.items():
55    feed_dict[input_feed_tensors_by_name[state_key]] = state_value
56  for feature_key, feature_value in features.items():
57    feed_dict[input_feed_tensors_by_name[feature_key]] = feature_value
58  return output_tensors_by_name, feed_dict
59
60
61def predict_continuation(continue_from,
62                         signatures,
63                         session,
64                         steps=None,
65                         times=None,
66                         exogenous_features=None):
67  """Perform prediction using an exported saved model.
68
69  Analogous to _input_pipeline.predict_continuation_input_fn, but operates on a
70  saved model rather than feeding into Estimator's predict method.
71
72  Args:
73    continue_from: A dictionary containing the results of either an Estimator's
74      evaluate method or filter_continuation. Used to determine the model
75      state to make predictions starting from.
76    signatures: The `MetaGraphDef` protocol buffer returned from
77      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
78      feed and fetch. Must be from the same model as `continue_from`.
79    session: The session to use. The session's graph must be the one into which
80      `tf.saved_model.loader.load` loaded the model.
81    steps: The number of steps to predict (scalar), starting after the
82      evaluation or filtering. If `times` is specified, `steps` must not be; one
83      is required.
84    times: A [batch_size x window_size] array of integers (not a Tensor)
85      indicating times to make predictions for. These times must be after the
86      corresponding evaluation or filtering. If `steps` is specified, `times`
87      must not be; one is required. If the batch dimension is omitted, it is
88      assumed to be 1.
89    exogenous_features: Optional dictionary. If specified, indicates exogenous
90      features for the model to use while making the predictions. Values must
91      have shape [batch_size x window_size x ...], where `batch_size` matches
92      the batch dimension used when creating `continue_from`, and `window_size`
93      is either the `steps` argument or the `window_size` of the `times`
94      argument (depending on which was specified).
95  Returns:
96    A dictionary with model-specific predictions (typically having keys "mean"
97    and "covariance") and a feature_keys.PredictionResults.TIMES key indicating
98    the times for which the predictions were computed.
99  Raises:
100    ValueError: If `times` or `steps` are misspecified.
101  """
102  if exogenous_features is None:
103    exogenous_features = {}
104  predict_times = _model_utils.canonicalize_times_or_steps_from_output(
105      times=times, steps=steps, previous_model_output=continue_from)
106  features = {_feature_keys.PredictionFeatures.TIMES: predict_times}
107  features.update(exogenous_features)
108  predict_signature = signatures.signature_def[
109      _feature_keys.SavedModelLabels.PREDICT]
110  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
111      continue_from=continue_from,
112      signature=predict_signature,
113      features=features,
114      graph=session.graph)
115  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
116  output[_feature_keys.PredictionResults.TIMES] = features[
117      _feature_keys.PredictionFeatures.TIMES]
118  return output
119
120
121def cold_start_filter(signatures, session, features):
122  """Perform filtering using an exported saved model.
123
124  Filtering refers to updating model state based on new observations.
125  Predictions based on the returned model state will be conditioned on these
126  observations.
127
128  Starts from the model's default/uninformed state.
129
130  Args:
131    signatures: The `MetaGraphDef` protocol buffer returned from
132      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
133      feed and fetch. Must be from the same model as `continue_from`.
134    session: The session to use. The session's graph must be the one into which
135      `tf.saved_model.loader.load` loaded the model.
136    features: A dictionary mapping keys to Numpy arrays, with several possible
137      shapes (requires keys `FilteringFeatures.TIMES` and
138      `FilteringFeatures.VALUES`):
139        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
140          vector of length [number of features].
141        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
142          has shape [series length] (univariate) or [series length x number of
143          features] (multivariate).
144        Batch of sequences; `TIMES` is a vector of shape [batch size x series
145          length], `VALUES` has shape [batch size x series length] or [batch
146          size x series length x number of features].
147      In any case, `VALUES` and any exogenous features must have their shapes
148      prefixed by the shape of the value corresponding to the `TIMES` key.
149  Returns:
150    A dictionary containing model state updated to account for the observations
151    in `features`.
152  """
153  filter_signature = signatures.signature_def[
154      _feature_keys.SavedModelLabels.COLD_START_FILTER]
155  features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
156      data=features,
157      require_single_batch=False)
158  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
159      signature=filter_signature,
160      features=features,
161      graph=session.graph)
162  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
163  # Make it easier to chain filter -> predict by keeping track of the current
164  # time.
165  output[_feature_keys.FilteringResults.TIMES] = features[
166      _feature_keys.FilteringFeatures.TIMES]
167  return output
168
169
170def filter_continuation(continue_from, signatures, session, features):
171  """Perform filtering using an exported saved model.
172
173  Filtering refers to updating model state based on new observations.
174  Predictions based on the returned model state will be conditioned on these
175  observations.
176
177  Args:
178    continue_from: A dictionary containing the results of either an Estimator's
179      evaluate method or a previous filter step (cold start or
180      continuation). Used to determine the model state to start filtering from.
181    signatures: The `MetaGraphDef` protocol buffer returned from
182      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
183      feed and fetch. Must be from the same model as `continue_from`.
184    session: The session to use. The session's graph must be the one into which
185      `tf.saved_model.loader.load` loaded the model.
186    features: A dictionary mapping keys to Numpy arrays, with several possible
187      shapes (requires keys `FilteringFeatures.TIMES` and
188      `FilteringFeatures.VALUES`):
189        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
190          vector of length [number of features].
191        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
192          has shape [series length] (univariate) or [series length x number of
193          features] (multivariate).
194        Batch of sequences; `TIMES` is a vector of shape [batch size x series
195          length], `VALUES` has shape [batch size x series length] or [batch
196          size x series length x number of features].
197      In any case, `VALUES` and any exogenous features must have their shapes
198      prefixed by the shape of the value corresponding to the `TIMES` key.
199  Returns:
200    A dictionary containing model state updated to account for the observations
201    in `features`.
202  """
203  filter_signature = signatures.signature_def[
204      _feature_keys.SavedModelLabels.FILTER]
205  features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
206      data=features,
207      require_single_batch=False)
208  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
209      continue_from=continue_from,
210      signature=filter_signature,
211      features=features,
212      graph=session.graph)
213  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
214  # Make it easier to chain filter -> predict by keeping track of the current
215  # time.
216  output[_feature_keys.FilteringResults.TIMES] = features[
217      _feature_keys.FilteringFeatures.TIMES]
218  return output
219
220remove_undocumented(module_name=__name__)
221