1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper functions for training and constructing time series Models."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21import numpy
23from tensorflow.contrib.timeseries.python.timeseries import feature_keys
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import init_ops
28from tensorflow.python.ops import nn_ops
29from tensorflow.python.ops import variable_scope
32# TODO(agarwal): Remove and replace with functionality from tf.slim
def fully_connected(inp,
                    inp_size,
                    layer_size,
                    name,
                    activation=nn_ops.relu,
                    dtype=dtypes.float32):
  """Helper method to create a fully connected hidden layer.

  Obtains a weight variable named "<name>_weight" and a bias variable named
  "<name>_bias" through `variable_scope.get_variable`, computes
  `inp * weight + bias` with `nn_ops.xw_plus_b`, and optionally applies an
  activation function to the result.

  Args:
    inp: Input to the layer, passed as the first argument of
        `nn_ops.xw_plus_b` (presumably a [batch, inp_size] Tensor with the
        same dtype as `dtype`; confirm against callers).
    inp_size: First dimension of the weight matrix.
    layer_size: Second dimension of the weight matrix and length of the bias.
    name: Prefix used to build the weight and bias variable names.
    activation: Callable applied to the affine output, or None to return the
        affine output unchanged.
    dtype: Data type requested for both the weight and the bias variables.
  Returns:
    The output of the layer: `activation(inp * weight + bias)`, or just
    `inp * weight + bias` when `activation` is None.
  """
  wt = variable_scope.get_variable(
      name="{}_weight".format(name), shape=[inp_size, layer_size], dtype=dtype)
  # The bias must be requested with the same dtype as the weight; otherwise it
  # falls back to get_variable's default dtype and the affine op below receives
  # mismatched operand types whenever `dtype` is not the default.
  bias = variable_scope.get_variable(
      name="{}_bias".format(name),
      shape=[layer_size],
      dtype=dtype,
      initializer=init_ops.zeros_initializer())
  output = nn_ops.xw_plus_b(inp, wt, bias)
  if activation is not None:
    assert callable(activation)
    output = activation(output)
  return output
def parameter_switch(parameter_overrides):
  """Build a lookup that prefers explicit overrides to model parameters.

  Args:
    parameter_overrides: A dictionary of explicit model parameter overrides,
        keyed by Tensor and mapping to the value to use instead.
  Returns:
    A one-argument function. Given a Tensor, it looks the Tensor up in
        `parameter_overrides`, falls back to the Tensor itself when there is
        no entry, converts the chosen object to a Tensor, and returns the
        result of calling `.eval()` on it (i.e. the value given the current
        Variable values).
  """
  def get_passed_or_trained_value(parameter):
    # Use the override when one was supplied; otherwise the parameter itself.
    chosen = parameter_overrides.get(parameter, parameter)
    as_tensor = ops.convert_to_tensor(chosen)
    return as_tensor.eval()

  return get_passed_or_trained_value
def canonicalize_times_or_steps_from_output(times, steps,
                                            previous_model_output):
  """Turn either relative steps or absolute times into prediction times.

  Exactly one of `times` and `steps` must be provided.

  Args:
    times: Absolute prediction times, or None. Converted with `numpy.array`;
        if the result is not two-dimensional, a leading dimension is
        prepended.
    steps: Number of prediction steps, or None. The prediction times are then
        the last previous time plus 1, plus `numpy.arange(steps)`.
    previous_model_output: A mapping indexed with
        `feature_keys.FilteringResults.TIMES`; the stored value is sliced as a
        two-dimensional array (presumably [batch, time]; confirm against
        callers).
  Returns:
    The prediction times: the (possibly expanded) `times` array, or the times
    generated from `steps`.
  Raises:
    ValueError: If both or neither of `times` and `steps` are given, if the
        leading dimension of `times` differs from that of the previous times,
        or if any first prediction time is not strictly greater than the
        corresponding last previous time.
  """
  if times is not None and steps is not None:
    raise ValueError("Only one of `steps` and `times` may be specified.")
  if times is None and steps is None:
    raise ValueError("One of `steps` and `times` must be specified.")

  if times is None:
    # Relative mode: continue counting from the last previous time.
    previous_times = previous_model_output[feature_keys.FilteringResults.TIMES]
    first_predicted = previous_times[:, -1:] + 1
    offsets = numpy.arange(steps)[None, ...]
    return first_predicted + offsets

  # Absolute mode: validate the requested times against the previous output.
  requested = numpy.array(times)
  if requested.ndim != 2:
    requested = requested[None, ...]
  previous_times = previous_model_output[feature_keys.FilteringResults.TIMES]
  previous_batch = previous_times.shape[0]
  requested_batch = requested.shape[0]
  if previous_batch != requested_batch:
    message = ("`times` must have a batch dimension matching"
               " the previous model output (got a batch dimension of {} for"
               " `times`"
               " and {} for the previous model output).")
    raise ValueError(message.format(requested_batch, previous_batch))
  starts_after_previous = previous_times[:, -1] < requested[:, 0]
  if not starts_after_previous.all():
    raise ValueError("Prediction times must be after the corresponding "
                     "previous model output.")
  return requested