1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Miscellaneous utilities used by time series models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import math
23
24import numpy as np
25
26from tensorflow.contrib import lookup
27from tensorflow.contrib.layers.python.layers import layers
28
29from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
30
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import functional_ops
38from tensorflow.python.ops import gen_math_ops
39from tensorflow.python.ops import init_ops
40from tensorflow.python.ops import linalg_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import nn
43from tensorflow.python.ops import state_ops
44from tensorflow.python.ops import variable_scope
45from tensorflow.python.util import nest
46
47
def normal_log_prob(loc, scale, x):
  """Computes the log density of a univariate Normal distribution.

  Evaluates, element-wise,
      -0.5 * ((x - loc) / scale)^2 - 0.5 * log(2 * pi) - log(scale)

  Args:
    loc: The mean of the distribution.
    scale: The standard deviation of the distribution (it divides the
      residual `x - loc` directly, so it is not a variance).
    x: The value(s) at which to evaluate the log density.
  Returns:
    The Normal log density of `x`, with the broadcast shape of the arguments.
  """
  z = (x - loc) / scale
  # The normalizing constant of the density is 1 / (scale * sqrt(2 * pi)), so
  # log(scale) enters with a coefficient of -1. Only the log(2 * pi) term is
  # halved along with the squared standardized residual.
  return (-0.5 * (math_ops.square(z) + np.log(2. * np.pi))
          - math_ops.log(scale))
53
54
def cauchy_log_prob(loc, scale, x):
  """Computes the log density of a univariate Cauchy distribution.

  Evaluates, element-wise,
      -log(pi) - log(scale) - log(1 + ((x - loc) / scale)^2)

  Args:
    loc: The location parameter of the distribution.
    scale: The scale parameter of the distribution.
    x: The value(s) at which to evaluate the log density.
  Returns:
    The Cauchy log density of `x`.
  """
  standardized = (x - loc) / scale
  log_normalizer = np.log(np.pi) + math_ops.log(scale)
  return -log_normalizer - math_ops.log1p(math_ops.square(standardized))
60
61
def mvn_tril_log_prob(loc, scale_tril, x):
  """Computes the multivariate Normal log pdf given a triangular scale.

  Doesn't handle batches.

  Args:
    loc: The mean vector.
    scale_tril: The triangular scale matrix; the covariance is taken to be
      scale_tril * scale_tril^T (its log determinant is computed as twice the
      sum of the log of the diagonal of `scale_tril`).
    x: The point at which to evaluate the log density.
  Returns:
    The log density of `x`.
  """
  residual = x - loc
  # Whiten the residual by solving scale_tril * z = residual; the solver works
  # on matrices, so a trailing unit dimension is added and then dropped.
  whitened = linalg_ops.matrix_triangular_solve(
      scale_tril, residual[..., array_ops.newaxis])
  whitened = whitened[..., 0]
  log_scale_diagonal = math_ops.log(array_ops.matrix_diag_part(scale_tril))
  log_det_cov = 2. * math_ops.reduce_sum(log_scale_diagonal, axis=-1)
  num_dimensions = math_ops.cast(
      array_ops.shape(scale_tril)[-1], log_det_cov.dtype)
  squared_norm = math_ops.reduce_sum(math_ops.square(whitened), axis=-1)
  return -0.5 * (squared_norm
                 + num_dimensions * np.log(2. * np.pi) + log_det_cov)
72
73
def clip_covariance(
    covariance_matrix, maximum_variance_ratio, minimum_variance):
  """Clips the diagonal of a covariance matrix for numerical stability.

  Only the diagonal is modified; off-diagonal entries are passed through.

  Args:
    covariance_matrix: A [..., N, N] batch of covariance matrices.
    maximum_variance_ratio: The largest ratio allowed between two diagonal
      entries. Entries smaller than the largest diagonal entry divided by this
      ratio are raised to that quotient.
    minimum_variance: A lower bound applied to every diagonal entry of the
      result.
  Returns:
    A new covariance matrix with the requested constraints enforced. If the
    input was positive definite, the output will be too.
  """
  # TODO(allenl): Smarter scaling here so that correlations are preserved when
  # fiddling with diagonal elements.
  variances = array_ops.matrix_diag_part(covariance_matrix)
  # keepdims lets the per-matrix maximum broadcast against the diagonal.
  largest_variance = math_ops.reduce_max(variances, axis=-1, keepdims=True)
  ratio_floor = largest_variance / maximum_variance_ratio
  ratio_clipped = gen_math_ops.maximum(variances, ratio_floor)
  clipped_variances = math_ops.maximum(ratio_clipped, minimum_variance)
  return array_ops.matrix_set_diag(covariance_matrix, clipped_variances)
96
97
def block_diagonal(matrices, dtype=dtypes.float32, name="block_diagonal"):
  r"""Stacks batched matrices along the main diagonal of a larger matrix.

  Args:
    matrices: A list of Tensors with shape [..., N_i, M_i] (i.e. a list of
      matrices with the same batch dimension).
    dtype: Data type to use. The Tensors in `matrices` must match this dtype.
    name: A name for the returned op.
  Returns:
    A matrix with the input matrices stacked along its main diagonal, having
    shape [..., \sum_i N_i, \sum_i M_i].
  """
  matrices = [ops.convert_to_tensor(matrix, dtype=dtype) for matrix in matrices]
  # Accumulate whatever is known statically about the result's shape so it can
  # be attached to the output at the end.
  static_rows = tensor_shape.Dimension(0)
  static_cols = tensor_shape.Dimension(0)
  static_batch_shape = tensor_shape.TensorShape(None)
  for matrix in matrices:
    static_shape = matrix.get_shape().with_rank_at_least(2)
    static_batch_shape = static_batch_shape.merge_with(static_shape[:-2])
    static_rows += static_shape[-2]
    static_cols += static_shape[-1]
  # The padding itself is computed from dynamic shapes.
  column_counts = [array_ops.shape(matrix)[-1] for matrix in matrices]
  total_columns = math_ops.add_n(column_counts)
  padded_blocks = []
  columns_before = 0
  for matrix, column_count in zip(matrices, column_counts):
    columns_through = columns_before + column_count
    columns_after = total_columns - columns_through
    # Only the last (column) dimension is padded; every leading dimension
    # gets a (0, 0) padding row.
    leading_paddings = array_ops.zeros(
        [array_ops.rank(matrix) - 1, 2], dtype=dtypes.int32)
    paddings = array_ops.concat(
        [leading_paddings, [(columns_before, columns_after)]], axis=0)
    padded_blocks.append(array_ops.pad(tensor=matrix, paddings=paddings))
    columns_before = columns_through
  # Each padded block is a full-width strip; stacking the strips along the row
  # dimension yields the block-diagonal matrix.
  blocked = array_ops.concat(padded_blocks, -2, name=name)
  blocked.set_shape(static_batch_shape.concatenate((static_rows, static_cols)))
  return blocked
144
145
def power_sums_tensor(array_size, power_matrix, multiplier):
  r"""Pre-computes the partial sums \sum_{i=0}^{N-1} A^i B (A^i)^T.

  Args:
    array_size: The number of non-trivial sums to pre-compute.
    power_matrix: The "A" matrix above.
    multiplier: The "B" matrix above.
  Returns:
    A Tensor S stacked along its first dimension, with
    S[N] = \sum_{i=0}^{N-1} A^i B (A^i)^T:
      S[0] is the zero matrix
      S[1] is B
      S[2] is A B A^T + B
      ...and so on
  """
  array_size = math_ops.cast(array_size, dtypes.int32)
  power_matrix = ops.convert_to_tensor(power_matrix)
  identity = linalg_ops.eye(
      array_ops.shape(power_matrix)[0], dtype=power_matrix.dtype)
  # The identity is built from a dynamic size, so copy the static shape over
  # to keep it compatible with the scan's loop state.
  identity.set_shape(power_matrix.get_shape())

  def _next_power(previous_power, _):
    return math_ops.matmul(previous_power, power_matrix)

  # Successive products with A, starting from the identity.
  transition_powers = functional_ops.scan(
      _next_power, math_ops.range(array_size - 1), initializer=identity)
  higher_order_terms = math_ops.matmul(
      batch_times_matrix(transition_powers, multiplier),
      transition_powers,
      adjoint_b=True)
  # B itself is the i=0 term, which the scan output does not include.
  terms = array_ops.concat(
      [array_ops.expand_dims(multiplier, 0), higher_order_terms], 0)
  partial_sums = math_ops.cumsum(terms)
  empty_sum = array_ops.expand_dims(array_ops.zeros_like(multiplier), 0)
  return array_ops.concat([empty_sum, partial_sums], 0)
179
180
def matrix_to_powers(matrix, powers):
  """Raises one matrix to each of several integer powers.

  Args:
    matrix: The matrix to exponentiate.
    powers: The powers to raise `matrix` to; one result is produced per
      element.
  Returns:
    The result of `batch_matrix_pow` on one copy of `matrix` per power.
  """
  num_powers = array_ops.size(powers)
  batched_matrix = array_ops.expand_dims(matrix, 0)
  replicated = array_ops.tile(batched_matrix, [num_powers, 1, 1])
  return batch_matrix_pow(replicated, powers)
186
187
def batch_matrix_pow(matrices, powers):
  """Raises each matrix in a batch to its own integer power.

  For example A^3 = matmul(matmul(A, A), A). Uses exponentiation by squaring,
  with O(log(p)) matrix multiplications to compute A^p.

  Args:
    matrices: [batch size x N x N]
    powers: Which integer power to raise each matrix to [batch size]
  Returns:
    The matrices raised to their respective powers, same dimensions as the
    "matrices" argument.
  """

  def _any_power_above_one(base, remaining_powers, product):
    """Loop condition: keep iterating while some residual power exceeds 1."""
    del base, product  # not used for condition
    return math_ops.reduce_any(
        math_ops.greater(remaining_powers,
                         array_ops.ones_like(remaining_powers)))

  def _square_and_multiply_step(base, remaining_powers, product):
    """Performs one step of iterative exponentiation by squaring.

    The recursive form is:
      power(A, p) = { power(matmul(A, A), p / 2) for even p
                    { matmul(A, power(matmul(A, A), (p - 1) / 2)) for odd p
      power(A, 0) = I

    Args:
      base: The current first argument (A^2..^2) of the unrolled recursion.
          [batch size x N x N]
      remaining_powers: The current second argument (residual p).
          [batch_size]
      product: The exterior multiplications collected from odd powers so far
          (initially the identity matrix). [batch_size x N x N]
    Returns:
      Updated versions of each argument after one step. Batch entries whose
      residual power is not greater than one are returned unchanged.
    """
    power_is_even = math_ops.equal(
        remaining_powers % 2,
        array_ops.zeros(array_ops.shape(remaining_powers), dtype=dtypes.int32))
    # An odd residual power peels one factor of the base into the product.
    product_if_active = array_ops.where(
        power_is_even, product, math_ops.matmul(product, base))
    squared_base = math_ops.matmul(base, base)
    halved_powers = (remaining_powers - remaining_powers % 2) // 2
    # Stop updating entries which have reached the base case; some batch
    # elements may finish sooner than others.
    still_active = math_ops.greater(remaining_powers, 1)
    return (array_ops.where(still_active, squared_base, base),
            array_ops.where(still_active, halved_powers, remaining_powers),
            array_ops.where(still_active, product_if_active, product))

  matrices = ops.convert_to_tensor(matrices)
  powers = math_ops.cast(powers, dtype=dtypes.int32)
  matrices_shape = array_ops.shape(matrices)
  single_identity = array_ops.diag(
      array_ops.ones([matrices_shape[1]], dtype=matrices.dtype))
  batch_identity = array_ops.tile(
      array_ops.expand_dims(single_identity, 0), [matrices_shape[0], 1, 1])
  final_base, final_powers, final_product = control_flow_ops.while_loop(
      _any_power_above_one, _square_and_multiply_step,
      [matrices, powers, batch_identity])
  # The loop exits with no residual power above one. A residual of zero maps
  # to the identity; otherwise one last multiplication by the base is applied.
  power_is_zero = math_ops.equal(
      final_powers, array_ops.zeros_like(final_powers, dtype=dtypes.int32))
  return array_ops.where(
      power_is_zero, batch_identity, math_ops.matmul(final_base, final_product))
267
268
269# TODO(allenl): would be useful if this was built into batch_matmul
def batch_times_matrix(batch, matrix, adj_x=False, adj_y=False):
  """Right-multiplies every matrix in a batch by one shared matrix.

  Implemented as a single 2D matmul: the batch is flattened to
  [batch_size * N, M], multiplied by `matrix`, and reshaped back. Intended to
  match
  tf.matmul(batch, array_ops.tile(gen_math_ops.expand_dims(matrix, 0),
                                 [array_ops.shape(batch)[0], 1, 1]),
                  adjoint_a=adj_x, adjoint_b=adj_y)

  Args:
    batch: [batch_size x N x M] after optional transpose
    matrix: [M x P] after optional transpose
    adj_x: If true, transpose the second two dimensions of "batch" before
        multiplying.
    adj_y: If true, transpose "matrix" before multiplying (passed to matmul
        as `adjoint_b`).
  Returns:
    [batch_size x N x P]
  """
  batch = ops.convert_to_tensor(batch)
  matrix = ops.convert_to_tensor(matrix)
  assert batch.get_shape().ndims == 3
  assert matrix.get_shape().ndims == 2
  if adj_x:
    batch = array_ops.transpose(batch, [0, 2, 1])
  dynamic_batch_shape = array_ops.shape(batch)
  static_batch_dims = batch.get_shape().dims
  # Prefer statically known sizes for the result shape, falling back to
  # dynamic ones only where the static size is unknown.
  batch_size = static_batch_dims[0].value
  if batch_size is None:
    batch_size = dynamic_batch_shape[0]
  num_rows = static_batch_dims[1].value
  if num_rows is None:
    num_rows = dynamic_batch_shape[1]
  # With adj_y the output columns come from the rows of `matrix`.
  output_axis = 0 if adj_y else 1
  num_output_columns = matrix.get_shape().as_list()[output_axis]
  if num_output_columns is None:
    num_output_columns = array_ops.shape(matrix)[output_axis]
  flattened = array_ops.reshape(batch, [-1, dynamic_batch_shape[2]])
  product = math_ops.matmul(flattened, matrix, adjoint_b=adj_y)
  return array_ops.reshape(product, [batch_size, num_rows, num_output_columns])
312
313
def matrix_times_batch(matrix, batch, adj_x=False, adj_y=False):
  """Like batch_times_matrix, but with the multiplication order swapped.

  Computed through batch_times_matrix with both transpose flags inverted and
  exchanged, followed by a transpose of the last two dimensions of the result.

  Args:
    matrix: The single matrix, used as the left operand.
    batch: The batch of matrices, used as the right operand.
    adj_x: If true, transpose "matrix" before multiplying.
    adj_y: If true, transpose each matrix in "batch" before multiplying.
  Returns:
    The batch of products, one per matrix in "batch".
  """
  transposed_product = batch_times_matrix(
      batch=batch, matrix=matrix, adj_x=not adj_y, adj_y=not adj_x)
  return array_ops.transpose(transposed_product, [0, 2, 1])
320
321
def make_toeplitz_matrix(inputs, name=None):
  """Builds a symmetric block Toeplitz matrix from a stack of blocks.

  The block at block-row i and block-column j of the output is
  inputs[abs(i - j)].

  Args:
    inputs: a 3-D tensor of shape [num_blocks, block_size, block_size].
    name: the name of the operation.

  Returns:
    a symmetric Toeplitz matrix of shape
      [num_blocks*block_size, num_blocks*block_size].
  """
  input_shape = array_ops.shape(inputs)
  num_blocks = input_shape[0]
  block_size = input_shape[1]
  side_length = block_size * num_blocks
  lag_row = array_ops.reshape(math_ops.range(num_blocks), shape=[1, -1])
  lag_column = array_ops.transpose(lag_row)
  # A [num_blocks, num_blocks] table of which input block goes where.
  block_indices = math_ops.abs(lag_row - lag_column)
  blocks = array_ops.gather(inputs, block_indices)
  # Reorder to [block row, row in block, block column, column in block] so
  # that the reshape below lays the blocks out as one flat matrix.
  interleaved = array_ops.transpose(blocks, [0, 2, 1, 3])
  toeplitz = array_ops.reshape(interleaved, [side_length, side_length])
  return array_ops.identity(toeplitz, name=name)
342
343
344# TODO(allenl): Investigate alternative parameterizations.
def sign_magnitude_positive_definite(
    raw, off_diagonal_scale=0., overall_scale=0.):
  """Maps an unconstrained matrix to a positive definite matrix.

  The whole matrix is kept on a log scale while still allowing negative
  off-diagonal elements: the magnitude of each off-diagonal element is read
  from the upper triangle of `raw` and its sign from the lower triangle.
  Specifically:

  for i < j, we have:
    output_cholesky[i, j] = raw[j, i] / (abs(raw[j, i]) + 1) *
        exp((off_diagonal_scale + overall_scale + raw[i, j]) / 2)

  output_cholesky[i, i] = exp((raw[i, i] + overall_scale) / 2)

  output = output_cholesky^T * output_cholesky

  where raw, off_diagonal_scale, and overall_scale are un-constrained
  real-valued variables. The resulting values are stable around zero due to
  the exponential (and the softsign keeps the function smooth).

  Args:
    raw: A [..., M, M] Tensor.
    off_diagonal_scale: A scalar or [...] shaped Tensor controlling the relative
        scale of off-diagonal values in the output matrix.
    overall_scale: A scalar or [...] shaped Tensor controlling the overall scale
        of the output matrix.
  Returns:
    The `output` matrix described above, a [..., M, M] positive definite matrix.

  """
  raw = ops.convert_to_tensor(raw)

  def _append_unit_dimensions(scale, target_rank):
    # Reshape to `target_rank` by adding trailing size-one dimensions, so the
    # scales broadcast even when they have batch dimensions.
    scale = ops.convert_to_tensor(scale, dtype=raw.dtype.base_dtype)
    unit_dimensions = array_ops.ones(
        [target_rank - array_ops.rank(scale)], dtype=target_rank.dtype)
    padded_shape = array_ops.concat(
        [array_ops.shape(scale), unit_dimensions], axis=0)
    return array_ops.reshape(scale, padded_shape)

  matrix_rank = array_ops.rank(raw)
  log_magnitude = (
      raw + _append_unit_dimensions(off_diagonal_scale, matrix_rank)
      + _append_unit_dimensions(overall_scale, matrix_rank))
  # Transposing puts each element's sign source (the lower triangle) over its
  # magnitude source (the upper triangle).
  signs = nn.softsign(array_ops.matrix_transpose(raw))
  # Log values are halved to compensate for the squaring that happens when a
  # Cholesky factor is turned into a positive definite matrix.
  signed_values = gen_math_ops.exp(log_magnitude / 2.) * signs
  signed_values.set_shape(raw.get_shape())
  raw_diagonal = array_ops.matrix_diag_part(raw)
  log_diagonal = raw_diagonal + _append_unit_dimensions(
      overall_scale, array_ops.rank(raw_diagonal))
  cholesky_factor = array_ops.matrix_set_diag(
      input=array_ops.matrix_band_part(signed_values, 0, -1),
      diagonal=gen_math_ops.exp(log_diagonal / 2.))
  return math_ops.matmul(cholesky_factor, cholesky_factor, transpose_a=True)
403
404
def transform_to_covariance_matrices(input_vectors, matrix_size):
  """Maps input vectors to covariance matrices with a learned linear layer.

  A fully connected layer (no activation) produces matrix_size**2 + 2 values
  per input: the raw matrix entries followed by the off-diagonal scale and the
  overall scale, which are then passed to sign_magnitude_positive_definite.

  Args:
    input_vectors: A [batch size x input size] batch of vectors to transform.
    matrix_size: An integer indicating one dimension of the (square) output
        matrix.
  Returns:
    A [batch size x matrix_size x matrix_size] batch of covariance matrices.
  """
  combined_values = layers.fully_connected(
      input_vectors, matrix_size**2 + 2, activation_fn=None)
  leading_shape = array_ops.shape(combined_values)[:-1]
  raw_shape = array_ops.concat([leading_shape, [matrix_size, matrix_size]], 0)
  raw_matrices = array_ops.reshape(combined_values[..., :-2], raw_shape)
  return sign_magnitude_positive_definite(
      raw=raw_matrices,
      off_diagonal_scale=combined_values[..., -2],
      overall_scale=combined_values[..., -1])
425
426
def variable_covariance_matrix(
    size, name, dtype, initial_diagonal_values=None,
    initial_overall_scale_log=0.):
  """Builds a positive definite matrix parameterized by Variables.

  Useful for parameterizing covariance matrices. Three variables are created,
  named `name` plus "_pre_transform", "_off_diagonal_scale" and
  "_overall_scale".

  Args:
    size: The size of the main diagonal, the returned matrix having shape [size
        x size].
    name: The name to use when defining variables and ops.
    dtype: The floating point data type to use.
    initial_diagonal_values: A Tensor with shape [size] with initial values for
        the diagonal values of the returned matrix. Must be positive.
    initial_overall_scale_log: Initial value of the bias term for every element
        of the matrix in log space.
  Returns:
    A Variable-parameterized covariance matrix with shape [size x size].
  """
  raw_values = variable_scope.get_variable(
      name + "_pre_transform",
      dtype=dtype,
      shape=[size, size],
      initializer=init_ops.zeros_initializer())
  if initial_diagonal_values is not None:
    # The transform exponentiates the diagonal, so the requested initial
    # values are added in log space.
    raw_values += array_ops.matrix_diag(math_ops.log(initial_diagonal_values))
  off_diagonal_scale = variable_scope.get_variable(
      name + "_off_diagonal_scale",
      dtype=dtype,
      initializer=constant_op.constant(-5., dtype=dtype))
  overall_scale_offset = ops.convert_to_tensor(
      initial_overall_scale_log, dtype=dtype)
  overall_scale_variable = variable_scope.get_variable(
      name + "_overall_scale",
      dtype=dtype,
      shape=[],
      initializer=init_ops.zeros_initializer())
  covariance = sign_magnitude_positive_definite(
      raw=raw_values,
      off_diagonal_scale=off_diagonal_scale,
      overall_scale=overall_scale_offset + overall_scale_variable)
  return array_ops.identity(covariance, name=name)
468
469
def batch_start_time(times):
  """Returns the first column of `times`, i.e. `times[:, 0]`.

  Args:
    times: A value indexable along two dimensions (presumably
      [batch size x window size] -- confirm against callers).
  Returns:
    The entry at index 0 of the second dimension, for every row.
  """
  first_times = times[:, 0]
  return first_times
472
473
def batch_end_time(times):
  """Returns the last column of `times`, i.e. `times[:, -1]`.

  Args:
    times: A value indexable along two dimensions (presumably
      [batch size x window size] -- confirm against callers).
  Returns:
    The entry at the last index of the second dimension, for every row.
  """
  last_times = times[:, -1]
  return last_times
476
477
def log_noninformative_covariance_prior(covariance):
  """Computes the log of a relatively uninformative covariance prior.

  Helpful for avoiding noise over-estimation, where noise otherwise decreases
  very slowly during optimization.

  See:
    Villegas, C. On the A Priori Distribution of the Covariance Matrix.
    Ann. Math. Statist. 40 (1969), no. 3, 1098--1099.

  Args:
    covariance: A covariance matrix.
  Returns:
    For a [p x p] matrix:
      log(det(covariance)^(-(p + 1) / 2))
    evaluated after adding 1e-8 to the diagonal of `covariance`.
  """
  num_dimensions = array_ops.shape(covariance)[0]
  # Avoid zero/negative determinants due to numerical errors
  jitter = array_ops.diag(
      1e-8 * array_ops.ones(shape=[num_dimensions], dtype=covariance.dtype))
  covariance += jitter
  exponent = -(math_ops.cast(num_dimensions + 1, covariance.dtype) / 2.)
  log_determinant = math_ops.log(linalg_ops.matrix_determinant(covariance))
  return exponent * log_determinant
500
501
def entropy_matched_cauchy_scale(covariance):
  """Picks Cauchy scales whose entropy matches Gaussians' marginal variances.

  Cauchy distributions do not have moments, so entropy matching is used as
  one way to choose a Cauchy scale parameter giving a similar distribution.
  The effect is dividing the standard deviation of an independent Gaussian by
  a constant very near 3.

  Only the diagonal of `covariance` is used. Since this ignores cross terms,
  it overestimates the entropy of the Gaussian. For each variance, the
  (univariate) Gaussian entropy
      0.5 * ln(2 * variance * pi * e)
  is set equal to the Cauchy entropy
      ln(4 * pi * scale)
  which gives scale = sqrt(variance * (e / (8 pi))).

  Args:
    covariance: A [batch size x N x N] batch of covariance matrices to produce
        Cauchy scales for.
  Returns:
    A [batch size x N] set of Cauchy scale parameters for each part of the batch
    and each dimension of the input Gaussians.
  """
  variances = array_ops.matrix_diag_part(covariance)
  variance_to_squared_scale = math.e / (8. * math.pi)
  return math_ops.sqrt(variance_to_squared_scale * variances)
529
530
class TensorValuedMutableDenseHashTable(lookup.MutableDenseHashTable):
  """A MutableDenseHashTable whose values may have arbitrary Tensor shapes.

  MutableDenseHashTable only allows vector values right now, so values are
  flattened on the way in and reshaped to their original shape on the way out.
  """

  def __init__(self, key_dtype, value_dtype, default_value, *args, **kwargs):
    # Remember the un-flattened value shape so lookups can restore it.
    self._non_vector_value_shape = array_ops.shape(default_value)
    flat_default_value = array_ops.reshape(default_value, [-1])
    super(TensorValuedMutableDenseHashTable, self).__init__(
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        default_value=flat_default_value,
        *args,
        **kwargs)

  def insert(self, keys, values, name=None):
    flat_keys = array_ops.reshape(
        ops.convert_to_tensor(keys, dtype=self._key_dtype), [-1])
    # Each key has one corresponding value, so the shape of the tensor of
    # values for every key is key_shape + value_shape; flatten it to one
    # vector per key.
    num_keys = array_ops.shape(flat_keys)[0]
    flat_values = array_ops.reshape(values, [num_keys, -1])
    return super(TensorValuedMutableDenseHashTable, self).insert(
        keys=flat_keys, values=flat_values, name=name)

  def lookup(self, keys, name=None):
    flat_keys = array_ops.reshape(
        ops.convert_to_tensor(keys, dtype=self._key_dtype), [-1])
    flat_values = super(TensorValuedMutableDenseHashTable, self).lookup(
        keys=flat_keys, name=name)
    # Results are shaped as key_shape + value_shape.
    result_shape = array_ops.concat(
        [array_ops.shape(keys), self._non_vector_value_shape], 0)
    return array_ops.reshape(flat_values, result_shape)
565
566
class TupleOfTensorsLookup(lookup.LookupInterface):
  """A LookupInterface whose values are nested tuples of Tensors.

  One MutableDenseHashTable is created per value Tensor, which has some
  unnecessary overhead.
  """

  def __init__(self,
               key_dtype,
               default_values,
               empty_key,
               deleted_key,
               name,
               checkpoint=True):
    tables = []
    for table_number, default_value in enumerate(nest.flatten(default_values)):
      tables.append(
          TensorValuedMutableDenseHashTable(
              key_dtype=key_dtype,
              value_dtype=default_value.dtype.base_dtype,
              default_value=default_value,
              empty_key=empty_key,
              deleted_key=deleted_key,
              name=name + "_{}".format(table_number),
              checkpoint=checkpoint))
    # Mirror the nesting of `default_values` so lookups return values with
    # the same structure.
    self._hash_tables = nest.pack_sequence_as(default_values, tables)
    self._name = name

  def lookup(self, keys):
    looked_up = [
        table.lookup(keys) for table in nest.flatten(self._hash_tables)
    ]
    return nest.pack_sequence_as(self._hash_tables, looked_up)

  def insert(self, keys, values):
    nest.assert_same_structure(self._hash_tables, values)
    # Avoid race conditions by requiring that all inputs are computed before any
    # inserts happen (an issue if one key's update relies on another's value).
    values_flat = [array_ops.identity(value) for value in nest.flatten(values)]
    tables_flat = nest.flatten(self._hash_tables)
    with ops.control_dependencies(values_flat):
      insert_ops = [
          table.insert(keys, value)
          for table, value in zip(tables_flat, values_flat)
      ]
    return control_flow_ops.group(*insert_ops)

  def check_table_dtypes(self, key_dtype, value_dtype):
    # dtype checking is done in the objects in self._hash_tables
    pass
616
617
def replicate_state(start_state, batch_size):
  """Replicates non-batch state across a new leading batch dimension.

  Every Tensor in `start_state` gets a new first dimension and is tiled
  batch_size times along it. Used to replicate the non-batch state returned by
  get_start_state in define_loss.

  Args:
    start_state: Model-defined state to replicate.
    batch_size: Batch dimension for data.
  Returns:
    Replicated versions of the state, with the same nesting as `start_state`.
  """

  def _tile_across_batch(state_nonbatch):
    # Tile batch_size times along the new first dimension and once along each
    # of the original dimensions.
    unit_multiples = array_ops.ones(
        [array_ops.rank(state_nonbatch)], dtype=dtypes.int32)
    multiples = array_ops.concat([[batch_size], unit_multiples], 0)
    return array_ops.tile(array_ops.expand_dims(state_nonbatch, 0), multiples)

  replicated_state = [
      _tile_across_batch(state) for state in nest.flatten(start_state)
  ]
  return nest.pack_sequence_as(start_state, replicated_state)
640
641
# A (mean, variance) pair.
Moments = collections.namedtuple("Moments", ("mean", "variance"))
643
644
# Statistics describing the input series. Currently all of them are computed
# incrementally (i.e. are updated every time a new mini-batch of training data
# is presented) when this object is created in InputStatisticsFromMiniBatch.
#
# Fields:
#   series_start_moments: The mean and variance of each feature in a chunk
#     (with a size configured in the statistics object) at the start of the
#     series. A tuple of (mean, variance), each with shape [number of
#     features], floating point. One use is in state space models, to keep
#     priors calibrated even as earlier parts of the series are presented. If
#     this object was created by InputStatisticsFromMiniBatch, these moments
#     are computed based on the earliest chunk of data presented so far.
#     However, there is a race condition in the update, so these may reflect
#     statistics later in the series, but should eventually reflect statistics
#     in a chunk at the series start.
#   overall_feature_moments: The mean and variance of each feature over the
#     entire series. A tuple of (mean, variance), each with shape [number of
#     features]. If this object was created by InputStatisticsFromMiniBatch,
#     these moments are estimates based on the data seen so far.
#   start_time: The first (lowest) time in the series, a scalar integer. If
#     this object was created by InputStatisticsFromMiniBatch, this is the
#     lowest time seen so far rather than the lowest time that will ever be
#     seen (guaranteed to be at least as low as the lowest time presented in
#     the current minibatch).
#   total_observation_count: Count of data points, a scalar integer. If this
#     object was created by InputStatisticsFromMiniBatch, this is an estimate
#     of the total number of observations in the whole dataset computed based
#     on the density of the series and the minimum and maximum times seen.
InputStatistics = collections.namedtuple(
    "InputStatistics",
    [
        "series_start_moments",
        "overall_feature_moments",
        "start_time",
        "total_observation_count",
    ])
685
686
687# TODO(allenl): It would be nice to do something with full series statistics
688# when the user provides that.
689class InputStatisticsFromMiniBatch(object):
690  """Generate statistics from mini-batch input."""
691
692  def __init__(self, num_features, dtype, starting_variance_window_size=16):
693    """Configure the input statistics object.
694
695    Args:
696      num_features: Number of features for the time series
697      dtype: The floating point data type to use.
698      starting_variance_window_size: The number of datapoints to use when
699          computing the mean and variance at the start of the series.
700    """
701    self._starting_variance_window_size = starting_variance_window_size
702    self._num_features = num_features
703    self._dtype = dtype
704
705  def initialize_graph(self, features, update_statistics=True):
706    """Create any ops needed to provide input statistics.
707
708    Should be called before statistics are requested.
709
710    Args:
711      features: A dictionary, the output of a `TimeSeriesInputFn` (with keys
712          TrainEvalFeatures.TIMES and TrainEvalFeatures.VALUES).
713      update_statistics: Whether `features` should be used to update adaptive
714          statistics. Typically True for training and false for evaluation.
715    Returns:
716      An InputStatistics object composed of Variables, which will be updated
717      based on mini-batches of data if requested.
718    """
719    if (TrainEvalFeatures.TIMES in features
720        and TrainEvalFeatures.VALUES in features):
721      times = features[TrainEvalFeatures.TIMES]
722      values = features[TrainEvalFeatures.VALUES]
723    else:
724      # times and values may not be available, for example during prediction. We
725      # still need to retrieve our variables so that they can be read from, even
726      # if we're not going to update them.
727      times = None
728      values = None
729    # Create/retrieve variables representing input statistics, initialized
730    # without data to avoid deadlocking if variables are initialized before
731    # queue runners are started.
732    with variable_scope.variable_scope("input_statistics", use_resource=True):
733      statistics = self._create_variable_statistics_object()
734    with variable_scope.variable_scope(
735        "input_statistics_auxiliary", use_resource=True):
736      # Secondary statistics, necessary for the incremental computation of the
737      # primary statistics (e.g. counts and sums for computing a mean
738      # incrementally).
739      auxiliary_variables = self._AdaptiveInputAuxiliaryStatistics(
740          num_features=self._num_features, dtype=self._dtype)
741    if update_statistics and times is not None and values is not None:
742      # If we have times and values from mini-batch input, create update ops to
743      # take the new data into account.
744      assign_op = self._update_statistics_from_mini_batch(
745          statistics, auxiliary_variables, times, values)
746      with ops.control_dependencies([assign_op]):
747        stat_variables = nest.pack_sequence_as(statistics, [
748            array_ops.identity(tensor) for tensor in nest.flatten(statistics)
749        ])
750        # Since start time updates have a race condition, ensure that the
751        # reported start time is at least as low as the lowest time in this
752        # mini-batch. The start time should converge on the correct value
753        # eventually even with the race condition, but for example state space
754        # models have an assertion which could fail without this
755        # post-processing.
756        return stat_variables._replace(start_time=gen_math_ops.minimum(
757            stat_variables.start_time, math_ops.reduce_min(times)))
758    else:
759      return statistics
760
761  class _AdaptiveInputAuxiliaryStatistics(collections.namedtuple(
762      "_AdaptiveInputAuxiliaryStatistics",
763      ["max_time_seen",  # The maximum time seen (best effort if updated from
764                         # multiple workers; see notes about race condition
765                         # below).
766       "chunk_count",  # The number of chunks seen.
767       "inter_observation_duration_sum",  # The sum across chunks of their "time
768                                          # density" (number of times per
769                                          # example).
770       "example_count",  # The number of examples seen (each example has a
771                         # single time associated with it and one or more
772                         # real-valued features).
773       "overall_feature_sum",  # The sum of values for each feature. Shape
774                               # [number of features].
775       "overall_feature_sum_of_squares",  # The sum of squared values for each
776                                          # feature. Shape [number of features]
777      ])):
778    """Extra statistics used to incrementally update InputStatistics."""
779
780    def __new__(cls, num_features, dtype):
781      return super(
782          InputStatisticsFromMiniBatch  # pylint: disable=protected-access
783          ._AdaptiveInputAuxiliaryStatistics,
784          cls).__new__(
785              cls,
786              max_time_seen=variable_scope.get_variable(
787                  name="max_time_seen",
788                  initializer=dtypes.int64.min,
789                  dtype=dtypes.int64,
790                  trainable=False),
791              chunk_count=variable_scope.get_variable(
792                  name="chunk_count",
793                  initializer=init_ops.zeros_initializer(),
794                  shape=[],
795                  dtype=dtypes.int64,
796                  trainable=False),
797              inter_observation_duration_sum=variable_scope.get_variable(
798                  name="inter_observation_duration_sum",
799                  initializer=init_ops.zeros_initializer(),
800                  shape=[],
801                  dtype=dtype,
802                  trainable=False),
803              example_count=variable_scope.get_variable(
804                  name="example_count",
805                  shape=[],
806                  dtype=dtypes.int64,
807                  trainable=False),
808              overall_feature_sum=variable_scope.get_variable(
809                  name="overall_feature_sum",
810                  shape=[num_features],
811                  dtype=dtype,
812                  initializer=init_ops.zeros_initializer(),
813                  trainable=False),
814              overall_feature_sum_of_squares=variable_scope.get_variable(
815                  name="overall_feature_sum_of_squares",
816                  shape=[num_features],
817                  dtype=dtype,
818                  initializer=init_ops.zeros_initializer(),
819                  trainable=False))
820
  def _update_statistics_from_mini_batch(
      self, statistics, auxiliary_variables, times, values):
    """Given mini-batch input, update `statistics` and `auxiliary_variables`.

    The mini-batch is first accumulated into the auxiliary counters and sums.
    The primary statistics are then recomputed from those accumulators, under a
    control dependency on the accumulation ops.

    Args:
      statistics: An `InputStatistics` object whose `overall_feature_moments`,
          `series_start_moments`, `start_time` and `total_observation_count`
          members are assigned to.
      auxiliary_variables: An `_AdaptiveInputAuxiliaryStatistics` object whose
          members are assigned to.
      times: A Tensor of times, indexed below as [chunk, position in chunk].
      values: A Tensor of values, cast to `self._dtype` before use. It is summed
          over its first two axes to update the per-feature sums, and indexed
          as [chunk, position in chunk] for the series start moments.
    Returns:
      An op which groups the updates to `statistics`.
    """
    values = math_ops.cast(values, self._dtype)
    # The density (measured in times per observation) that we see in each part
    # of the mini-batch.
    # NOTE(review): the denominator below is zero when the second dimension of
    # `times` has size one -- confirm chunks always hold at least two times.
    batch_inter_observation_duration = (math_ops.cast(
        math_ops.reduce_max(times, axis=1) - math_ops.reduce_min(times, axis=1),
        self._dtype) / math_ops.cast(
            array_ops.shape(times)[1] - 1, self._dtype))
    # Co-locate updates with their variables to minimize race conditions when
    # updating statistics.
    with ops.device(auxiliary_variables.max_time_seen.device):
      # There is a race condition if this value is being updated from multiple
      # workers. However, it should eventually reach the correct value if the
      # last chunk is presented enough times.
      max_time_seen_assign = state_ops.assign(
          auxiliary_variables.max_time_seen,
          gen_math_ops.maximum(auxiliary_variables.max_time_seen,
                               math_ops.reduce_max(times)))
    with ops.device(auxiliary_variables.chunk_count.device):
      # Each entry along the first dimension of `times` counts as one chunk.
      chunk_count_assign = state_ops.assign_add(auxiliary_variables.chunk_count,
                                                array_ops.shape(
                                                    times,
                                                    out_type=dtypes.int64)[0])
    with ops.device(auxiliary_variables.inter_observation_duration_sum.device):
      inter_observation_duration_assign = state_ops.assign_add(
          auxiliary_variables.inter_observation_duration_sum,
          math_ops.reduce_sum(batch_inter_observation_duration))
    with ops.device(auxiliary_variables.example_count.device):
      # Every element of `times` counts as one example.
      example_count_assign = state_ops.assign_add(
          auxiliary_variables.example_count,
          array_ops.size(times, out_type=dtypes.int64))
    # Note: These mean/variance updates assume that all points are equally
    # likely, which is not true if _chunks_ are sampled uniformly from the space
    # of all possible contiguous chunks, since points at the start and end of
    # the series are then members of fewer chunks. For series which are much
    # longer than the chunk size (the usual/expected case), this effect becomes
    # irrelevant.
    with ops.device(auxiliary_variables.overall_feature_sum.device):
      overall_feature_sum_assign = state_ops.assign_add(
          auxiliary_variables.overall_feature_sum,
          math_ops.reduce_sum(values, axis=[0, 1]))
    with ops.device(auxiliary_variables.overall_feature_sum_of_squares.device):
      overall_feature_sum_of_squares_assign = state_ops.assign_add(
          auxiliary_variables.overall_feature_sum_of_squares,
          math_ops.reduce_sum(values**2, axis=[0, 1]))
    per_chunk_aux_updates = control_flow_ops.group(
        max_time_seen_assign, chunk_count_assign,
        inter_observation_duration_assign, example_count_assign,
        overall_feature_sum_assign, overall_feature_sum_of_squares_assign)
    # Everything below is derived from the auxiliary variables, so it is built
    # under a control dependency on the accumulation ops above.
    with ops.control_dependencies([per_chunk_aux_updates]):
      example_count_float = math_ops.cast(auxiliary_variables.example_count,
                                          self._dtype)
      new_feature_mean = (auxiliary_variables.overall_feature_sum /
                          example_count_float)
      overall_feature_mean_update = state_ops.assign(
          statistics.overall_feature_moments.mean, new_feature_mean)
      # The variance is computed as E[x^2] - E[x]^2 from the running sums.
      # NOTE(review): the n / (n - 1) factor has a zero denominator if exactly
      # one example has been counted -- confirm that this cannot happen.
      overall_feature_var_update = state_ops.assign(
          statistics.overall_feature_moments.variance,
          # De-biased n / (n - 1) variance correction
          example_count_float / (example_count_float - 1.) *
          (auxiliary_variables.overall_feature_sum_of_squares /
           example_count_float - new_feature_mean**2))
      # Index of the chunk whose first time is lowest in this mini-batch.
      # TODO(b/35675805): Remove this cast
      min_time_batch = math_ops.cast(math_ops.argmin(times[:, 0]), dtypes.int32)
      def series_start_updates():
        # If this is the lowest-time chunk that we have seen so far, update
        # series start moments to reflect that. Note that these statistics are
        # "best effort", as there are race conditions in the update (however,
        # they should eventually converge if the start of the series is
        # presented enough times).
        # Only the first `self._starting_variance_window_size` values of the
        # selected chunk contribute to the moments.
        mean, variance = nn.moments(
            values[min_time_batch, :self._starting_variance_window_size],
            axes=[0])
        return control_flow_ops.group(
            state_ops.assign(statistics.series_start_moments.mean, mean),
            state_ops.assign(statistics.series_start_moments.variance,
                             variance))
      with ops.device(statistics.start_time.device):
        series_start_update = control_flow_ops.cond(
            # Update moments whenever we even match the lowest time seen so far,
            # to ensure that series start statistics are eventually updated to
            # their correct values, despite race conditions (i.e. eventually
            # statistics.start_time will reflect the global lowest time, and
            # given that we will eventually update the series start moments to
            # their correct values).
            math_ops.less_equal(times[min_time_batch, 0],
                                statistics.start_time),
            series_start_updates,
            control_flow_ops.no_op)
        # The start time is only lowered after the comparison above has been
        # made against its previous value.
        with ops.control_dependencies([series_start_update]):
          # There is a race condition if this update is performed in parallel on
          # multiple workers. Since models may be sensitive to being presented
          # with times before the putative start time, the value of this
          # variable is post-processed above to guarantee that each worker is
          # presented with a start time which is at least as low as the lowest
          # time in its current mini-batch.
          start_time_update = state_ops.assign(statistics.start_time,
                                               gen_math_ops.minimum(
                                                   statistics.start_time,
                                                   math_ops.reduce_min(times)))
      # Average, over all chunks seen so far, of the per-chunk time step
      # between consecutive observations.
      inter_observation_duration_estimate = (
          auxiliary_variables.inter_observation_duration_sum / math_ops.cast(
              auxiliary_variables.chunk_count, self._dtype))
      # Estimate the total number of observations as:
      #   (end time - start time + 1) * average intra-chunk time density
      # Dividing by the average inter-observation duration is the same as
      # multiplying by that density. The outputs of the two assign ops are
      # used so that the estimate is based on the freshly updated times.
      total_observation_count_update = state_ops.assign(
          statistics.total_observation_count,
          math_ops.cast(
              gen_math_ops.round(
                  math_ops.cast(max_time_seen_assign -
                                start_time_update + 1, self._dtype) /
                  inter_observation_duration_estimate), dtypes.int64))
      per_chunk_stat_updates = control_flow_ops.group(
          overall_feature_mean_update, overall_feature_var_update,
          series_start_update, start_time_update,
          total_observation_count_update)
    return per_chunk_stat_updates
940
941  def _create_variable_statistics_object(self):
942    """Creates non-trainable variables representing input statistics."""
943    series_start_moments = Moments(
944        mean=variable_scope.get_variable(
945            name="series_start_mean",
946            shape=[self._num_features],
947            dtype=self._dtype,
948            initializer=init_ops.zeros_initializer(),
949            trainable=False),
950        variance=variable_scope.get_variable(
951            name="series_start_variance",
952            shape=[self._num_features],
953            dtype=self._dtype,
954            initializer=init_ops.ones_initializer(),
955            trainable=False))
956    overall_feature_moments = Moments(
957        mean=variable_scope.get_variable(
958            name="overall_feature_mean",
959            shape=[self._num_features],
960            dtype=self._dtype,
961            initializer=init_ops.zeros_initializer(),
962            trainable=False),
963        variance=variable_scope.get_variable(
964            name="overall_feature_var",
965            shape=[self._num_features],
966            dtype=self._dtype,
967            initializer=init_ops.ones_initializer(),
968            trainable=False))
969    start_time = variable_scope.get_variable(
970        name="start_time",
971        dtype=dtypes.int64,
972        initializer=dtypes.int64.max,
973        trainable=False)
974    total_observation_count = variable_scope.get_variable(
975        name="total_observation_count",
976        shape=[],
977        dtype=dtypes.int64,
978        initializer=init_ops.ones_initializer(),
979        trainable=False)
980    return InputStatistics(
981        series_start_moments=series_start_moments,
982        overall_feature_moments=overall_feature_moments,
983        start_time=start_time,
984        total_observation_count=total_observation_count)
985