# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for working with time series saved_models.

@@predict_continuation
@@cold_start_filter
@@filter_continuation
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as _head
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils

from tensorflow.python.util.all_util import remove_undocumented


def _input_tensor_for_key(input_feed_tensors_by_name, key):
  """Looks up the signature input graph element that `key` should be fed into.

  Args:
    input_feed_tensors_by_name: Dictionary mapping signature input keys to
      graph elements.
    key: The state or feature key to look up.
  Returns:
    The graph element registered under `key`.
  Raises:
    KeyError: If `key` is not one of the signature's inputs. Unlike the bare
      dictionary lookup, the message also lists the available input keys.
  """
  try:
    return input_feed_tensors_by_name[key]
  except KeyError:
    raise KeyError(
        "'{}' is not an input of the saved model signature; available inputs "
        "are {}.".format(key, sorted(input_feed_tensors_by_name)))


def _colate_features_to_feeds_and_fetches(signature, features, graph,
                                          continue_from=None):
  """Uses a saved model signature to construct feed and fetch dictionaries.

  Args:
    signature: A `SignatureDef` whose `inputs` and `outputs` map keys to
      tensor info objects carrying the `name` of a graph element.
    features: Dictionary mapping signature input keys to the values to feed.
    graph: The graph into which the saved model was loaded; used to resolve
      tensor names into graph elements.
    continue_from: Optional model state to feed alongside `features`. One of:
      None (no state is fed); a dictionary containing a
      `FilteringResults.STATE_TUPLE` entry (e.g. Estimator evaluation output),
      which is flattened with `head.state_to_dictionary`; or a dictionary whose
      keys are already signature input keys (e.g. a previous filtering result).
  Returns:
    A tuple `(output_tensors_by_name, feed_dict)`: the graph elements for every
    signature output keyed by output key, and a dictionary mapping input graph
    elements to the values to feed.
  Raises:
    KeyError: If a state or feature key is not an input of `signature`.
  """
  if continue_from is None:
    state_values = {}
  elif _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
    # We're continuing from an evaluation, so we need to unpack/flatten state.
    state_values = _head.state_to_dictionary(
        continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
  else:
    state_values = continue_from
  input_feed_tensors_by_name = {
      input_key: graph.as_graph_element(input_value.name)
      for input_key, input_value in signature.inputs.items()
  }
  output_tensors_by_name = {
      output_key: graph.as_graph_element(output_value.name)
      for output_key, output_value in signature.outputs.items()
  }
  feed_dict = {}
  # State is fed first so that a key present in both `state_values` and
  # `features` ends up with the value from `features`. Presumably this is what
  # lets the times recorded in a previous filtering result be replaced by the
  # newly supplied times -- TODO confirm the two TIMES keys coincide.
  for state_key, state_value in state_values.items():
    feed_dict[_input_tensor_for_key(
        input_feed_tensors_by_name, state_key)] = state_value
  for feature_key, feature_value in features.items():
    feed_dict[_input_tensor_for_key(
        input_feed_tensors_by_name, feature_key)] = feature_value
  return output_tensors_by_name, feed_dict


def predict_continuation(continue_from,
                         signatures,
                         session,
                         steps=None,
                         times=None,
                         exogenous_features=None):
  """Perform prediction using an exported saved model.

  Analogous to _input_pipeline.predict_continuation_input_fn, but operates on a
  saved model rather than feeding into Estimator's predict method.

  Args:
    continue_from: A dictionary containing the results of an Estimator's
      evaluate method, of cold_start_filter, or of filter_continuation. Used to
      determine the model state to make predictions starting from.
    signatures: The `MetaGraphDef` protocol buffer returned from
      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
      feed and fetch. Must be from the same model as `continue_from`.
    session: The session to use. The session's graph must be the one into which
      `tf.saved_model.loader.load` loaded the model.
    steps: The number of steps to predict (scalar), starting after the
      evaluation or filtering. If `times` is specified, `steps` must not be; one
      is required.
    times: A [batch_size x window_size] array of integers (not a Tensor)
      indicating times to make predictions for. These times must be after the
      corresponding evaluation or filtering. If `steps` is specified, `times`
      must not be; one is required. If the batch dimension is omitted, it is
      assumed to be 1.
    exogenous_features: Optional dictionary. If specified, indicates exogenous
      features for the model to use while making the predictions. Values must
      have shape [batch_size x window_size x ...], where `batch_size` matches
      the batch dimension used when creating `continue_from`, and `window_size`
      is either the `steps` argument or the `window_size` of the `times`
      argument (depending on which was specified).
  Returns:
    A dictionary with model-specific predictions (typically having keys "mean"
    and "covariance") and a feature_keys.PredictionResults.TIMES key indicating
    the times for which the predictions were computed.
  Raises:
    ValueError: If `times` or `steps` are misspecified.
  """
  if exogenous_features is None:
    exogenous_features = {}
  predict_times = _model_utils.canonicalize_times_or_steps_from_output(
      times=times, steps=steps, previous_model_output=continue_from)
  features = {_feature_keys.PredictionFeatures.TIMES: predict_times}
  features.update(exogenous_features)
  predict_signature = signatures.signature_def[
      _feature_keys.SavedModelLabels.PREDICT]
  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
      continue_from=continue_from,
      signature=predict_signature,
      features=features,
      graph=session.graph)
  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
  output[_feature_keys.PredictionResults.TIMES] = features[
      _feature_keys.PredictionFeatures.TIMES]
  return output


def _run_filter(signature_label, signatures, session, features,
                continue_from=None):
  """Runs a filtering signature and records the filtered times in the result.

  Shared implementation of `cold_start_filter` and `filter_continuation`, which
  differ only in which signature they run and whether they feed prior state.

  Args:
    signature_label: Key into `signatures.signature_def` selecting the
      filtering signature to run.
    signatures: The `MetaGraphDef` returned from `tf.saved_model.loader.load`.
    session: A session whose graph holds the loaded model.
    features: Observations to filter; canonicalized with
      `input_pipeline._canonicalize_numpy_data` before being fed.
    continue_from: Optional state to start from; None feeds no state.
  Returns:
    The fetched signature outputs plus a `FilteringResults.TIMES` entry holding
    the (canonicalized) `FilteringFeatures.TIMES` values that were filtered.
  """
  filter_signature = signatures.signature_def[signature_label]
  features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
      data=features,
      require_single_batch=False)
  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
      continue_from=continue_from,
      signature=filter_signature,
      features=features,
      graph=session.graph)
  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
  # Make it easier to chain filter -> predict by keeping track of the current
  # time.
  output[_feature_keys.FilteringResults.TIMES] = features[
      _feature_keys.FilteringFeatures.TIMES]
  return output


def cold_start_filter(signatures, session, features):
  """Perform filtering using an exported saved model.

  Filtering refers to updating model state based on new observations.
  Predictions based on the returned model state will be conditioned on these
  observations.

  Starts from the model's default/uninformed state.

  Args:
    signatures: The `MetaGraphDef` protocol buffer returned from
      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
      feed and fetch. Must contain a `SavedModelLabels.COLD_START_FILTER`
      signature.
    session: The session to use. The session's graph must be the one into which
      `tf.saved_model.loader.load` loaded the model.
    features: A dictionary mapping keys to Numpy arrays, with several possible
      shapes (requires keys `FilteringFeatures.TIMES` and
      `FilteringFeatures.VALUES`):
        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
          vector of length [number of features].
        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
          has shape [series length] (univariate) or [series length x number of
          features] (multivariate).
        Batch of sequences; `TIMES` is a matrix of shape [batch size x series
          length], `VALUES` has shape [batch size x series length] or [batch
          size x series length x number of features].
      In any case, `VALUES` and any exogenous features must have their shapes
      prefixed by the shape of the value corresponding to the `TIMES` key.
  Returns:
    A dictionary containing model state updated to account for the observations
    in `features`, plus a `FilteringResults.TIMES` entry with the filtered
    times (so the result can be passed to `predict_continuation` or
    `filter_continuation` as `continue_from`).
  """
  return _run_filter(
      signature_label=_feature_keys.SavedModelLabels.COLD_START_FILTER,
      signatures=signatures,
      session=session,
      features=features)


def filter_continuation(continue_from, signatures, session, features):
  """Perform filtering using an exported saved model.

  Filtering refers to updating model state based on new observations.
  Predictions based on the returned model state will be conditioned on these
  observations.

  Args:
    continue_from: A dictionary containing the results of either an Estimator's
      evaluate method or a previous filter step (cold start or
      continuation). Used to determine the model state to start filtering from.
    signatures: The `MetaGraphDef` protocol buffer returned from
      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
      feed and fetch. Must be from the same model as `continue_from`.
    session: The session to use. The session's graph must be the one into which
      `tf.saved_model.loader.load` loaded the model.
    features: A dictionary mapping keys to Numpy arrays, with several possible
      shapes (requires keys `FilteringFeatures.TIMES` and
      `FilteringFeatures.VALUES`):
        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
          vector of length [number of features].
        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
          has shape [series length] (univariate) or [series length x number of
          features] (multivariate).
        Batch of sequences; `TIMES` is a matrix of shape [batch size x series
          length], `VALUES` has shape [batch size x series length] or [batch
          size x series length x number of features].
      In any case, `VALUES` and any exogenous features must have their shapes
      prefixed by the shape of the value corresponding to the `TIMES` key.
  Returns:
    A dictionary containing model state updated to account for the observations
    in `features`, plus a `FilteringResults.TIMES` entry with the filtered
    times.
  Raises:
    KeyError: If `continue_from` or `features` contains a key that is not an
      input of the `SavedModelLabels.FILTER` signature.
  """
  return _run_filter(
      signature_label=_feature_keys.SavedModelLabels.FILTER,
      signatures=signatures,
      session=session,
      features=features,
      continue_from=continue_from)


remove_undocumented(module_name=__name__)