15"""Timeseries head."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
20import re
22from tensorflow.contrib.timeseries.python.timeseries import feature_keys
23from tensorflow.python.estimator import estimator_lib
24from tensorflow.python.estimator.canned import head as head_lib
25from tensorflow.python.estimator.canned import metric_keys
26from tensorflow.python.estimator.export import export_lib
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import sparse_tensor
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import metrics_impl
34from tensorflow.python.ops import state_ops
35from tensorflow.python.ops import variable_scope
36from tensorflow.python.summary import summary
37from tensorflow.python.training import training_util
38from tensorflow.python.util import nest
41class _NoStatePredictOutput(export_lib.PredictOutput):
43  def as_signature_def(self, receiver_tensors):
44    no_state_receiver_tensors = {
45        key: value for key, value in receiver_tensors.items()
46        if not key.startswith(feature_keys.State.STATE_PREFIX)}
47    return super(_NoStatePredictOutput, self).as_signature_def(
48        receiver_tensors=no_state_receiver_tensors)
51class TimeSeriesRegressionHead(head_lib._Head):  # pylint:disable=protected-access
52  """Determines input and output signatures for a time series model."""
54  def __init__(self,
55               model,
56               state_manager,
57               optimizer,
58               input_statistics_generator=None,
59               name=None):
60    """Creates a `_Head` for time series regression.
62    Args:
63      model: A model for time series regression.
64      state_manager: A state manager.
65      optimizer: An optimizer.
66      input_statistics_generator: A input statistics generator.
67      name: An optional name for the model.
68    """
69    self.model = model
70    self.state_manager = state_manager
71    self.optimizer = optimizer
72    self.input_statistics_generator = input_statistics_generator
73    self._name = name
75  @property
76  def name(self):
77    return self._name
79  # TODO(terrytangyuan): consolidate `model_outputs` and `_Head.LossSpec`
80  # once `_Head.create_loss` becomes extendable
81  def create_loss(self, features, mode, logits=None, labels=None):
82    """See `_Head`."""
83    model_outputs = self.state_manager.define_loss(
84        self.model, features, mode)
85    summary.scalar(
86        head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),
87        model_outputs.loss)
88    return model_outputs
90  @property
91  def logits_dimension(self):
92    """See `_Head`."""
93    return 1
95  def _train_ops(self, features):
96    """Add training ops to the graph."""
97    mode = estimator_lib.ModeKeys.TRAIN
98    with variable_scope.variable_scope(
99        "model",
100        # Use ResourceVariables to avoid race conditions.
101        use_resource=True):
102      model_outputs = self.create_loss(features, mode)
104    train_op = self.optimizer.minimize(
105        model_outputs.loss,
106        global_step=training_util.get_global_step())
107    return estimator_lib.EstimatorSpec(
108        loss=model_outputs.loss,
109        mode=mode,
110        train_op=train_op)
112  def _evaluate_ops(self, features):
113    """Add ops for evaluation (aka filtering) to the graph."""
114    mode = estimator_lib.ModeKeys.EVAL
115    with variable_scope.variable_scope("model", use_resource=True):
116      model_outputs = self.create_loss(features, mode)
117    metrics = {}
118    # Just output in-sample predictions for the last chunk seen
119    for prediction_key, prediction_value in model_outputs.predictions.items():
120      metrics[prediction_key] = _identity_metric_single(prediction_key,
121                                                        prediction_value)
122    metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single(
123        feature_keys.FilteringResults.TIMES, model_outputs.prediction_times)
124    metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
125        _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
126                                model_outputs.end_state))
127    metrics[metric_keys.MetricKeys.LOSS_MEAN] = metrics_impl.mean(
128        model_outputs.loss, name="average_loss")
129    return estimator_lib.EstimatorSpec(
130        loss=model_outputs.loss,
131        mode=mode,
132        eval_metric_ops=metrics,
133        # needed for custom metrics.
134        predictions=model_outputs.predictions)
136  def _predict_ops(self, features):
137    """Add ops for prediction to the graph."""
138    with variable_scope.variable_scope("model", use_resource=True):
139      prediction = self.model.predict(features=features)
140    prediction[feature_keys.PredictionResults.TIMES] = features[
141        feature_keys.PredictionFeatures.TIMES]
142    return estimator_lib.EstimatorSpec(
143        predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT)
145  def _serving_ops(self, features):
146    """Add ops for serving to the graph."""
147    with variable_scope.variable_scope("model", use_resource=True):
148      prediction_outputs = self.model.predict(features=features)
149    with variable_scope.variable_scope("model", reuse=True):
150      filtering_outputs = self.create_loss(
151          features, estimator_lib.ModeKeys.EVAL)
152    with variable_scope.variable_scope("model", reuse=True):
153      no_state_features = {
154          k: v for k, v in features.items()
155          if not k.startswith(feature_keys.State.STATE_PREFIX)}
156      # Ignore any state management when cold-starting. The model's default
157      # start state is replicated across the batch.
158      cold_filtering_outputs = self.model.define_loss(
159          features=no_state_features, mode=estimator_lib.ModeKeys.EVAL)
160    return estimator_lib.EstimatorSpec(
161        mode=estimator_lib.ModeKeys.PREDICT,
162        export_outputs={
163            feature_keys.SavedModelLabels.PREDICT:
164                export_lib.PredictOutput(prediction_outputs),
165            feature_keys.SavedModelLabels.FILTER:
166                export_lib.PredictOutput(
167                    state_to_dictionary(filtering_outputs.end_state)),
168            feature_keys.SavedModelLabels.COLD_START_FILTER:
169                _NoStatePredictOutput(
170                    state_to_dictionary(cold_filtering_outputs.end_state))
171        },
172        # Likely unused, but it is necessary to return `predictions` to satisfy
173        # the Estimator's error checking.
174        predictions={})
176  def _convert_feature_to_tensor(self, name, value):
177    """Casts features to the correct dtype based on their name."""
178    if name in [
179        feature_keys.TrainEvalFeatures.TIMES,
180        feature_keys.PredictionFeatures.TIMES
181    ]:
182      return math_ops.cast(value, dtypes.int64)
183    if name == feature_keys.TrainEvalFeatures.VALUES:
184      return math_ops.cast(value, self.model.dtype)
185    if name == feature_keys.PredictionFeatures.STATE_TUPLE:
186      return value  # Correct dtypes are model-dependent
187    return sparse_tensor.convert_to_tensor_or_sparse_tensor(value)
189  def _gather_state(self, features):
190    """Returns `features` with state packed, indicates if packing was done."""
191    prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX +
192                                   r"_(\d+)$")
193    numbered_state = []
194    for key, tensor in features.items():
195      search_result = prefixed_state_re.search(key)
196      if search_result:
197        numbered_state.append((int(search_result.group(1)), key, tensor))
198    if not numbered_state:
199      return features, False
200    features = features.copy()
201    for _, key, _ in numbered_state:
202      del features[key]
203    numbered_state.sort(key=lambda number, *_: number)
204    features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as(
205        structure=self.model.get_start_state(),
206        flat_sequence=[tensor for _, _, tensor in numbered_state])
207    return features, True
209  def _check_predict_features(self, features):
210    """Raises errors if features are not suitable for prediction."""
211    if feature_keys.PredictionFeatures.TIMES not in features:
212      raise ValueError("Expected a '{}' feature for prediction.".format(
213          feature_keys.PredictionFeatures.TIMES))
214    if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
215      raise ValueError("Expected a '{}' feature for prediction.".format(
216          feature_keys.PredictionFeatures.STATE_TUPLE))
217    times_feature = features[feature_keys.PredictionFeatures.TIMES]
218    if not times_feature.get_shape().is_compatible_with([None, None]):
219      raise ValueError(
220          ("Expected shape (batch dimension, window size) for feature '{}' "
221           "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
222                                    times_feature.get_shape()))
223    _check_feature_shapes_compatible_with(
224        features=features,
225        compatible_with_name=feature_keys.PredictionFeatures.TIMES,
226        compatible_with_value=times_feature,
227        ignore=set([
228            # Model-dependent shapes
229            feature_keys.PredictionFeatures.STATE_TUPLE
230        ]))
232  def create_estimator_spec(self, features, mode, labels=None):
233    """Performs basic error checking and returns an EstimatorSpec."""
234    with ops.name_scope(self._name, "head"):
235      if labels is not None and labels != {}:  # for better error messages.
236        raise ValueError(
237            "The model received a `labels`, which is not supported. "
238            "Pass '{}' and '{}' as features.".format(
239                feature_keys.TrainEvalFeatures.TIMES,
240                feature_keys.TrainEvalFeatures.VALUES))
241      del labels
242      features = {
243          name: self._convert_feature_to_tensor(name=name, value=value)
244          for name, value in features.items()
245      }
246      if self.input_statistics_generator is not None:
247        input_statistics = self.input_statistics_generator.initialize_graph(
248            features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN))
249      else:
250        input_statistics = None
251      self.model.initialize_graph(input_statistics=input_statistics)
253      # _gather_state requires the model to have its graph initialized (so it
254      # has access to the structure of the model's state)
255      features, passed_flat_state = self._gather_state(features)
256      if (mode == estimator_lib.ModeKeys.TRAIN or
257          mode == estimator_lib.ModeKeys.EVAL):
258        _check_train_eval_features(features, self.model)
259      elif mode == estimator_lib.ModeKeys.PREDICT:
260        self._check_predict_features(features)
261      else:
262        raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
264      self.state_manager.initialize_graph(
265          model=self.model, input_statistics=input_statistics)
267      if mode == estimator_lib.ModeKeys.TRAIN:
268        return self._train_ops(features)
269      elif mode == estimator_lib.ModeKeys.EVAL:
270        return self._evaluate_ops(features)
271      elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state:
272        return self._predict_ops(features)
273      elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state:
274        # The mode is PREDICT, but we're actually in export_savedmodel for
275        # serving. We want to return two graphs: one for filtering (state + data
276        # -> state) and one for predicting (state -> prediction).
277        return self._serving_ops(features)
280class OneShotPredictionHead(TimeSeriesRegressionHead):
281  """A time series head which exports a single stateless serving signature.
283  The serving default signature exported by this head expects `times`, `values`,
284  and any exogenous features, but no state. `values` has shape `[batch_size,
285  filter_length, num_features]` and `times` has shape `[batch_size,
286  total_length]`, where `total_length > filter_length`. Any exogenous features
287  must have their shapes prefixed by the shape of the `times` feature.
289  When serving, first performs filtering on the series up to `filter_length`
290  starting from the default start state for the model, then computes predictions
291  on the remainder of the series, returning them.
293  Model state is neither accepted nor returned, so filtering must be performed
294  each time predictions are requested when using this head.
295  """
297  def _check_predict_features(self, features):
298    """Raises errors if features are not suitable for one-shot prediction."""
299    if feature_keys.PredictionFeatures.TIMES not in features:
300      raise ValueError("Expected a '{}' feature for prediction.".format(
301          feature_keys.PredictionFeatures.TIMES))
302    if feature_keys.TrainEvalFeatures.VALUES not in features:
303      raise ValueError("Expected a '{}' feature for prediction.".format(
304          feature_keys.TrainEvalFeatures.VALUES))
305    if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
306      raise ValueError("Expected a '{}' feature for prediction.".format(
307          feature_keys.PredictionFeatures.STATE_TUPLE))
308    times_feature = features[feature_keys.PredictionFeatures.TIMES]
309    if not times_feature.get_shape().is_compatible_with([None, None]):
310      raise ValueError(
311          ("Expected shape (batch dimension, window size) for feature '{}' "
312           "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
313                                    times_feature.get_shape()))
314    _check_feature_shapes_compatible_with(
315        features=features,
316        compatible_with_name=feature_keys.PredictionFeatures.TIMES,
317        compatible_with_value=times_feature,
318        ignore=set([
319            # Model-dependent shapes
320            feature_keys.PredictionFeatures.STATE_TUPLE,
321            # One shot prediction head relies on values being shorter than
322            # times. Even though we're predicting eventually, we need values for
323            # the filtering phase.
324            feature_keys.TrainEvalFeatures.VALUES,
325        ]))
327  def _evaluate_ops(self, features):
328    """Add ops for evaluation (aka filtering) to the graph."""
329    spec = super(OneShotPredictionHead, self)._evaluate_ops(features)
330    # No state is fed to OneShotPredictionHead, so we don't return it; it being
331    # a tuple can cause issues for downstream infrastructure.
332    del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE]
333    return spec
335  def _serving_ops(self, features):
336    """Add ops for serving to the graph."""
337    with variable_scope.variable_scope("model", use_resource=True):
338      filtering_features = {}
339      prediction_features = {}
340      values_length = array_ops.shape(
341          features[feature_keys.FilteringFeatures.VALUES])[1]
342      for key, value in features.items():
343        if key == feature_keys.State.STATE_TUPLE:
344          # Ignore state input. The model's default start state is replicated
345          # across the batch.
346          continue
347        if key == feature_keys.FilteringFeatures.VALUES:
348          filtering_features[key] = value
349        else:
350          filtering_features[key] = value[:, :values_length]
351          prediction_features[key] = value[:, values_length:]
352      cold_filtering_outputs = self.model.define_loss(
353          features=filtering_features, mode=estimator_lib.ModeKeys.EVAL)
354      prediction_features[feature_keys.State.STATE_TUPLE] = (
355          cold_filtering_outputs.end_state)
356    with variable_scope.variable_scope("model", reuse=True):
357      prediction_outputs = self.model.predict(
358          features=prediction_features)
359    return estimator_lib.EstimatorSpec(
360        mode=estimator_lib.ModeKeys.PREDICT,
361        export_outputs={
362            feature_keys.SavedModelLabels.PREDICT:
363                _NoStatePredictOutput(prediction_outputs),
364        },
365        # Likely unused, but it is necessary to return `predictions` to satisfy
366        # the Estimator's error checking.
367        predictions={})
370def _check_feature_shapes_compatible_with(features,
371                                          compatible_with_name,
372                                          compatible_with_value,
373                                          ignore=None):
374  """Checks all features are compatible with the given time-like feature."""
375  if ignore is None:
376    ignore = set()
377  for name, value in features.items():
378    if name in ignore:
379      continue
380    feature_shape = value.get_shape()
381    if feature_shape.ndims is None:
382      continue
383    if feature_shape.ndims < 2:
384      raise ValueError(
385          ("Features must have shape (batch dimension, window size, ...) "
386           "(got rank {} for feature '{}')").format(feature_shape.ndims, name))
387    if not feature_shape[:2].is_compatible_with(
388        compatible_with_value.get_shape()):
389      raise ValueError(
390          ("Features must have shape (batch dimension, window size, ...) "
391           "where batch dimension and window size match the "
392           "'{times_feature}' feature (got shape {feature_shape} for "
393           "feature '{feature_name}' but shape {times_shape} for feature "
394           "'{times_feature}')").format(
395               times_feature=compatible_with_name,
396               feature_shape=feature_shape,
397               feature_name=name,
398               times_shape=compatible_with_value.get_shape()))
401def _check_train_eval_features(features, model):
402  """Raise errors if features are not suitable for training/evaluation."""
403  if feature_keys.TrainEvalFeatures.TIMES not in features:
404    raise ValueError("Expected a '{}' feature for training/evaluation.".format(
405        feature_keys.TrainEvalFeatures.TIMES))
406  if feature_keys.TrainEvalFeatures.VALUES not in features:
407    raise ValueError("Expected a '{}' feature for training/evaluation.".format(
408        feature_keys.TrainEvalFeatures.VALUES))
409  times_feature = features[feature_keys.TrainEvalFeatures.TIMES]
410  if not times_feature.get_shape().is_compatible_with([None, None]):
411    raise ValueError(
412        ("Expected shape (batch dimension, window size) for feature '{}' "
413         "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES,
414                                  times_feature.get_shape()))
415  values_feature = features[feature_keys.TrainEvalFeatures.VALUES]
416  if not values_feature.get_shape().is_compatible_with(
417      [None, None, model.num_features]):
418    raise ValueError(
419        ("Expected shape (batch dimension, window size, {num_features}) "
420         "for feature '{feature_name}', since the model was configured "
421         "with num_features={num_features} (got shape {got_shape})").format(
422             num_features=model.num_features,
423             feature_name=feature_keys.TrainEvalFeatures.VALUES,
424             got_shape=times_feature.get_shape()))
425  _check_feature_shapes_compatible_with(
426      features=features,
427      compatible_with_name=feature_keys.TrainEvalFeatures.TIMES,
428      compatible_with_value=times_feature,
429      ignore=set([
430          feature_keys.State.STATE_TUPLE  # Model-dependent shapes
431      ]))
434def _identity_metric_single(name, input_tensor):
435  """A metric which takes on its last updated value.
437  This keeps evaluation metrics in sync with one another, since update ops are
438  run separately from their result Tensors. Simply returning (input_tensor,
439  no_op) as a metric with a value but no update means that a metric will come
440  from a different batch of data than metrics which cache values in a Variable
441  (e.g. the default loss metric).
443  Args:
444    name: A name for the metric.
445    input_tensor: Any Tensor.
446  Returns:
447    A tuple of (value, update_op).
448  """
449  metric_variable = variable_scope.variable(
450      name="{}_identity_metric".format(name),
451      initial_value=array_ops.zeros([], dtype=input_tensor.dtype),
452      collections=[ops.GraphKeys.LOCAL_VARIABLES],
453      validate_shape=False)
454  update_op = state_ops.assign(
455      metric_variable, input_tensor, validate_shape=False)
456  # This shape will be correct once the first update runs (but may be
457  # incomplete, so is not helpful for initializing the variable).
458  metric_variable.set_shape(input_tensor.get_shape())
459  return (metric_variable.value(), update_op)
462def _identity_metric_nested(name, input_tensors):
463  """Create identity metrics for a nested tuple of Tensors."""
464  update_ops = []
465  value_tensors = []
466  for tensor_number, tensor in enumerate(nest.flatten(input_tensors)):
467    value_tensor, update_op = _identity_metric_single(
468        name="{}_{}".format(name, tensor_number), input_tensor=tensor)
469    update_ops.append(update_op)
470    value_tensors.append(value_tensor)
471  return (nest.pack_sequence_as(input_tensors, value_tensors),
472          control_flow_ops.group(*update_ops))
475def state_to_dictionary(state_tuple):
476  """Flatten model state into a dictionary with string keys."""
477  flattened = {}
478  for state_number, state_value in enumerate(nest.flatten(state_tuple)):
479    prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX,
480                                             state_number)
481    flattened[prefixed_state_name] = state_value
482  return flattened