15"""Utilities for testing time series models."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21from tensorflow.contrib.timeseries.python.timeseries import estimators
22from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
23from tensorflow.contrib.timeseries.python.timeseries import state_management
24from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
26from tensorflow.python.client import session
27from tensorflow.python.estimator import estimator_lib
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import random_seed
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.training import adam
35from tensorflow.python.training import basic_session_run_hooks
36from tensorflow.python.training import coordinator as coordinator_lib
37from tensorflow.python.training import queue_runner_impl
38from tensorflow.python.util import nest
class AllWindowInputFn(input_pipeline.TimeSeriesInputFn):
  """Yields every contiguous window of a full dataset as one batch.

  `WholeDatasetInputFn` does basic shape checking but keeps the data as a single
  flat sequence. This `TimeSeriesInputFn` instead cuts the data into windows.
  Unlike `RandomWindowInputFn`, the windows are deterministic: one window starts
  at every possible offset, so the batch holds
  `series_length - window_size + 1` windows.
  """

  def __init__(self, time_series_reader, window_size):
    """Sets up the input function.

    Args:
      time_series_reader: An `input_pipeline.TimeSeriesReader`; its
          `read_full()` output is what gets windowed.
      window_size: Length of each contiguous window.
    """
    self._reader = time_series_reader
    self._window_size = window_size
    super(AllWindowInputFn, self).__init__()

  def create_batch(self):
    """Returns `(features, None)` with every feature split into windows.

    Each feature value `v` is reshaped to
    `[num_windows, window_size] + shape(v)[1:]`. The window count is derived
    from the length of the TIMES feature (presumably all features share that
    leading dimension).
    """
    full_data = self._reader.read_full()
    series_length = array_ops.shape(full_data[TrainEvalFeatures.TIMES])[0]
    window_count = series_length - self._window_size + 1
    # Column of window start offsets: shape [window_count, 1].
    window_starts = array_ops.reshape(
        math_ops.range(window_count), [window_count, 1])
    # Row of positions inside a single window: shape [1, window_size].
    offsets_in_window = array_ops.reshape(
        math_ops.range(self._window_size), [1, -1])
    # Broadcast addition yields every (window, position) index; flatten it so a
    # single gather along axis 0 picks out all windowed elements.
    gather_indices = array_ops.reshape(window_starts + offsets_in_window, [-1])
    windowed_data = {}
    for feature_name, feature_value in full_data.items():
      gathered = array_ops.gather(feature_value, gather_indices)
      windowed_shape = array_ops.concat(
          [[window_count, self._window_size],
           array_ops.shape(feature_value)[1:]],
          axis=0)
      windowed_data[feature_name] = array_ops.reshape(gathered, windowed_shape)
    return windowed_data, None
class _SavingTensorHook(basic_session_run_hooks.LoggingTensorHook):
  """Records, instead of logging, Tensor values fetched during training.

  Triggering (`every_n_iter` / `every_n_secs`) is inherited from
  `LoggingTensorHook`. Whenever the hook triggers, the fetched values are stored
  in `self.tensor_values` under the tags from `self._current_tensors`. Each
  trigger overwrites earlier entries, so after training the dict holds the most
  recently recorded values. The base class's `after_run` is not called, so
  nothing is logged.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None):
    # Maps tag -> most recently recorded value; filled in by after_run().
    self.tensor_values = {}
    super(_SavingTensorHook, self).__init__(
        tensors=tensors,
        every_n_iter=every_n_iter,
        every_n_secs=every_n_secs)

  def after_run(self, run_context, run_values):
    del run_context  # Unused.
    if self._should_trigger:
      fetched = run_values.results
      self.tensor_values.update(
          (tag, fetched[tag]) for tag in self._current_tensors)
      self._timer.update_last_triggered_step(self._iter_count)
    self._iter_count += 1
def _train_on_generated_data(
    generate_fn, generative_model, train_iterations, seed,
    learning_rate=0.1, ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=None):
  """The training portion of parameter recovery tests.

  Generates data from `generative_model`, evaluates the model's loss with the
  generating ("true") parameter values fed in, then trains a
  `TimeSeriesRegressor` on the generated data and evaluates its loss.

  Args:
    generate_fn: Function taking the model (after `initialize_graph()` has been
        called on it) and returning `(time_series_reader, parameters)`, where
        `parameters` maps parameter Tensors to their true values.
    generative_model: The `TimeSeriesModel` used to generate data and to train.
    train_iterations: Number of training steps (`max_steps` for training).
    seed: Passed to `random_seed.set_random_seed` and used as the estimator's
        `tf_random_seed`.
    learning_rate: Step size for the Adam optimizer.
    ignore_params_fn: Function called with the model inside the true-parameter
        evaluation graph; its result is returned unchanged.
    derived_param_test_fn: Function called with the model inside the
        true-parameter evaluation graph, returning derived Tensors to evaluate
        under the true parameters.
    train_input_fn_type: `TimeSeriesInputFn` type used for training. If None,
        `WholeDatasetInputFn` is used.
    train_state_manager: State manager used for training. If None (the
        default), a new `PassthroughStateManager` is created for this call.

  Returns:
    A tuple `(ignore_params, true_parameters, true_transformed_params,
    trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph)`:
      ignore_params: Whatever `ignore_params_fn` returned.
      true_parameters: Dict mapping parameter Tensor names to true values.
      true_transformed_params: Dict mapping each Tensor from
          `derived_param_test_fn` to its value under the true parameters.
      trained_loss: The "loss" from evaluating the trained estimator for one
          step with a `WholeDatasetInputFn` over the generated data.
      true_param_loss: The loss with the true parameter values fed in.
      saving_hook: The `_SavingTensorHook` holding parameter values recorded
          during training, keyed by Tensor name.
      true_parameter_eval_graph: The graph containing the derived Tensors.
  """
  if train_input_fn_type is None:
    train_input_fn_type = input_pipeline.WholeDatasetInputFn
  if train_state_manager is None:
    # Built per call: a default-argument instance would be one object shared
    # by every call that omits this argument.
    train_state_manager = state_management.PassthroughStateManager()

  # NOTE(review): set_random_seed applies to the graph that is default when this
  # function is called, not to generate_graph below, so it does not directly
  # seed data generation. Left unchanged so generated data stays stable.
  random_seed.set_random_seed(seed)

  # Generate data and the true parameter values in a dedicated graph.
  generate_graph = ops.Graph()
  with generate_graph.as_default():
    with session.Session(graph=generate_graph):
      generative_model.initialize_graph()
      time_series_reader, true_parameters = generate_fn(generative_model)
      # Key by Tensor name so the values can be fed into (and matched against)
      # the other graphs below, which rebuild the same model.
      true_parameters = {
          tensor.name: value for tensor, value in true_parameters.items()}

  # Evaluate the loss and any derived parameters with the true values fed in.
  eval_input_fn = input_pipeline.WholeDatasetInputFn(time_series_reader)
  eval_state_manager = state_management.PassthroughStateManager()
  true_parameter_eval_graph = ops.Graph()
  with true_parameter_eval_graph.as_default():
    generative_model.initialize_graph()
    ignore_params = ignore_params_fn(generative_model)
    feature_dict, _ = eval_input_fn()
    eval_state_manager.initialize_graph(generative_model)
    feature_dict[TrainEvalFeatures.VALUES] = math_ops.cast(
        feature_dict[TrainEvalFeatures.VALUES], generative_model.dtype)
    model_outputs = eval_state_manager.define_loss(
        model=generative_model,
        features=feature_dict,
        mode=estimator_lib.ModeKeys.EVAL)
    with session.Session(graph=true_parameter_eval_graph) as sess:
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(sess, coord=coordinator)
      try:
        true_param_loss = model_outputs.loss.eval(feed_dict=true_parameters)
        true_transformed_params = {
            param: param.eval(feed_dict=true_parameters)
            for param in derived_param_test_fn(generative_model)}
      finally:
        # Stop and join the queue-runner threads even if evaluation raises.
        coordinator.request_stop()
        coordinator.join()

  # Train an estimator on the generated data, recording parameter values.
  # NOTE(review): with every_n_iter = train_iterations - 1 the hook presumably
  # fires on the first and the final step, so the recorded values come from
  # the end of training. Confirm against LoggingTensorHook's trigger logic.
  saving_hook = _SavingTensorHook(
      tensors=true_parameters.keys(),
      every_n_iter=train_iterations - 1)

  class _RunConfig(estimator_lib.RunConfig):
    # Pin the estimator's graph-level random seed to `seed`.

    @property
    def tf_random_seed(self):
      return seed

  estimator = estimators.TimeSeriesRegressor(
      model=generative_model,
      config=_RunConfig(),
      state_manager=train_state_manager,
      optimizer=adam.AdamOptimizer(learning_rate))
  train_input_fn = train_input_fn_type(time_series_reader=time_series_reader)
  trained_loss = (estimator.train(
      input_fn=train_input_fn,
      max_steps=train_iterations,
      hooks=[saving_hook]).evaluate(
          input_fn=eval_input_fn, steps=1))["loss"]
  logging.info("Final trained loss: %f", trained_loss)
  logging.info("True parameter loss: %f", true_param_loss)
  return (ignore_params, true_parameters, true_transformed_params,
          trained_loss, true_param_loss, saving_hook,
          true_parameter_eval_graph)
def test_parameter_recovery(
    generate_fn, generative_model, train_iterations, test_case, seed,
    learning_rate=0.1, rtol=0.2, atol=0.1, train_loss_tolerance_coeff=0.99,
    ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=None):
  """Test that a generative model fits generated data.

  Args:
    generate_fn: A function taking a model and returning a `TimeSeriesReader`
        object and dictionary mapping parameters to their
        values. model.initialize_graph() will have been called on the model
        before it is passed to this function.
    generative_model: A timeseries.model.TimeSeriesModel instance to test.
    train_iterations: Number of training steps.
    test_case: A tf.test.TestCase to run assertions on.
    seed: Same as for TimeSeriesModel.unconditional_generate().
    learning_rate: Step size for optimization.
    rtol: Relative tolerance for tests.
    atol: Absolute tolerance for tests.
    train_loss_tolerance_coeff: Trained loss times this value must be less
        than the loss evaluated using the generated parameters.
    ignore_params_fn: Function mapping from a Model to a list of parameters
        which are not tested for accurate recovery.
    derived_param_test_fn: Function returning a list of derived parameters
        (Tensors) which are checked for accurate recovery (comparing the value
        evaluated with trained parameters to the value under the true
        parameters).

        As an example, for VARMA, in addition to checking AR and MA parameters,
        this function can be used to also check lagged covariance. See
        varma_ssm.py for details.
    train_input_fn_type: The `TimeSeriesInputFn` type to use when training
        (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). If None, use
        `WholeDatasetInputFn`.
    train_state_manager: The state manager to use when training (likely
        `PassthroughStateManager` or `ChainingStateManager`). If None (the
        default), a new `PassthroughStateManager` is created for this call.

  Raises:
    Whatever failure `test_case` raises. Its assertions check that:
    - each derived parameter, evaluated with the trained values, is within
      `rtol`/`atol` of its true value;
    - the tolerance-adjusted trained loss is less than the true-parameter
      loss;
    - each non-ignored parameter is within `rtol`/`atol` of its true value.
  """
  # Resolve the documented None fallbacks here. The state manager is built per
  # call because a default-argument instance would be shared by every caller.
  if train_input_fn_type is None:
    train_input_fn_type = input_pipeline.WholeDatasetInputFn
  if train_state_manager is None:
    train_state_manager = state_management.PassthroughStateManager()
  (ignore_params, true_parameters, true_transformed_params,
   trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph
  ) = _train_on_generated_data(
      generate_fn=generate_fn, generative_model=generative_model,
      train_iterations=train_iterations, seed=seed, learning_rate=learning_rate,
      ignore_params_fn=ignore_params_fn,
      derived_param_test_fn=derived_param_test_fn,
      train_input_fn_type=train_input_fn_type,
      train_state_manager=train_state_manager)
  # Values recorded by the hook during training. They use the same Tensor-name
  # keys as true_parameters, so they can be fed into the evaluation graph in
  # place of the true values.
  trained_parameter_substitutions = {}
  for param, true_value in true_parameters.items():
    evaled_value = saving_hook.tensor_values[param]
    trained_parameter_substitutions[param] = evaled_value
    logging.info("True %s: %s, learned: %s",
                 param, true_value, evaled_value)
  # Derived parameters: evaluate with the trained values fed in and compare with
  # their values under the true parameters.
  with session.Session(graph=true_parameter_eval_graph):
    for transformed_param, true_value in true_transformed_params.items():
      trained_value = transformed_param.eval(
          feed_dict=trained_parameter_substitutions)
      logging.info("True %s [transformed parameter]: %s, learned: %s",
                   transformed_param, true_value, trained_value)
      test_case.assertAllClose(true_value, trained_value,
                               rtol=rtol, atol=atol)

  if ignore_params is None:
    ignore_params = []
  else:
    ignore_params = nest.flatten(ignore_params)
  ignored_param_names = set(tensor.name for tensor in ignore_params)
  # For 0 < train_loss_tolerance_coeff < 1, both branches make the left-hand
  # side smaller than trained_loss (multiplying shrinks a positive loss,
  # dividing pushes a negative loss further down). The coefficient therefore
  # loosens the check whatever the loss's sign.
  if trained_loss > 0:
    test_case.assertLess(trained_loss * train_loss_tolerance_coeff,
                         true_param_loss)
  else:
    test_case.assertLess(trained_loss / train_loss_tolerance_coeff,
                         true_param_loss)
  for param, true_value in true_parameters.items():
    if param in ignored_param_names:
      continue
    test_case.assertAllClose(true_value, trained_parameter_substitutions[param],
                             rtol=rtol, atol=atol)
def parameter_recovery_dry_run(
    generate_fn, generative_model, seed,
    learning_rate=0.1,
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=None):
  """Test that a generative model can train on generated data.

  Runs data generation, true-parameter evaluation, two training steps and
  evaluation. The results are discarded and no assertions are made, so this
  only checks that those steps run without raising.

  Args:
    generate_fn: A function taking a model and returning a
        `input_pipeline.TimeSeriesReader` object and a dictionary mapping
        parameters to their values. model.initialize_graph() will have been
        called on the model before it is passed to this function.
    generative_model: A timeseries.model.TimeSeriesModel instance to test.
    seed: Same as for TimeSeriesModel.unconditional_generate().
    learning_rate: Step size for optimization.
    train_input_fn_type: The type of `TimeSeriesInputFn` to use when training
        (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). If None, use
        `WholeDatasetInputFn`.
    train_state_manager: The state manager to use when training (likely
        `PassthroughStateManager` or `ChainingStateManager`). If None (the
        default), a new `PassthroughStateManager` is created for this call.
  """
  # Resolve the documented None fallbacks here. The state manager is built per
  # call because a default-argument instance would be shared by every caller.
  if train_input_fn_type is None:
    train_input_fn_type = input_pipeline.WholeDatasetInputFn
  if train_state_manager is None:
    train_state_manager = state_management.PassthroughStateManager()
  _train_on_generated_data(
      generate_fn=generate_fn, generative_model=generative_model,
      seed=seed, learning_rate=learning_rate,
      train_input_fn_type=train_input_fn_type,
      train_state_manager=train_state_manager,
      # Presumably two (not one) so the saving hook's every_n_iter, computed as
      # train_iterations - 1, stays positive.
      train_iterations=2)