1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""RNN helpers for TensorFlow models."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.keras.engine import base_layer
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import control_flow_util
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import rnn_cell_impl
33from tensorflow.python.ops import tensor_array_ops
34from tensorflow.python.ops import variable_scope as vs
35from tensorflow.python.util import deprecation
36from tensorflow.python.util import nest
37from tensorflow.python.util.tf_export import tf_export
38
39
# pylint: disable=protected-access
# Module-local alias for the private `rnn_cell_impl._concat` helper; within
# this file it is used to combine the batch size with a cell output size when
# building zero-filled output tensors.
_concat = rnn_cell_impl._concat
# pylint: enable=protected-access
43
44
def _transpose_batch_time(x):
  """Transposes the batch and time dimensions of a Tensor.

  If the input tensor has a statically known rank < 2 it is returned
  unchanged. Retains as much of the static shape information as possible;
  when the static rank is unknown, no static shape is attached to the result.

  Args:
    x: A Tensor.

  Returns:
    x transposed along the first two dimensions.
  """
  x_static_shape = x.get_shape()
  static_rank = x_static_shape.rank
  if static_rank is not None and static_rank < 2:
    return x

  # Build the permutation [1, 0, 2, ..., rank - 1] from the dynamic rank so
  # that tensors whose rank is only known at run time are handled as well.
  x_rank = array_ops.rank(x)
  x_t = array_ops.transpose(
      x, array_ops.concat(
          ([1, 0], math_ops.range(2, x_rank)), axis=0))
  # `TensorShape.dims` is None for a shape of unknown rank, so indexing it
  # would raise a TypeError. Only propagate static shape information when
  # there is some to propagate.
  if static_rank is not None:
    x_t.set_shape(
        tensor_shape.TensorShape([
            x_static_shape.dims[1].value, x_static_shape.dims[0].value
        ]).concatenate(x_static_shape[2:]))
  return x_t
70
71
def _best_effort_input_batch_size(flat_input):
  """Returns the batch size of time-major inputs, statically if possible.

  The inputs are scanned in order and the first statically known batch
  dimension (dimension 1) is returned. If none of them has one, the batch size
  is read dynamically from the first input.

  Args:
    flat_input: An iterable of time major input Tensors of shape
      `[max_time, batch_size, ...]`. All inputs should have compatible batch
      sizes.

  Returns:
    A Python integer when a static batch size is available, otherwise a scalar
    Tensor holding the dynamic batch size.

  Raises:
    ValueError: if an inspected input has a known rank smaller than 2.
  """
  for tensor in flat_input:
    static_shape = tensor.shape
    known_rank = static_shape.rank
    if known_rank is not None:
      if known_rank < 2:
        raise ValueError(
            "Expected input tensor %s to have rank at least 2" % tensor)
      static_batch = static_shape.dims[1].value
      if static_batch is not None:
        return static_batch

  # No input exposed a static batch size: use the dynamic shape of the first.
  dynamic_shape = array_ops.shape(flat_input[0])
  return dynamic_shape[1]
98
99
def _infer_state_dtype(explicit_dtype, state):
  """Determines the dtype an RNN state should be treated as having.

  An explicitly supplied dtype always wins. Otherwise the dtype is read from
  the state itself; for a nested state every leaf must agree on it.

  Args:
    explicit_dtype: The dtype requested by the caller, or None to infer it.
    state: RNN's hidden state. Either a single Tensor or a nested iterable
      containing Tensors.

  Returns:
    `explicit_dtype` when it is given, otherwise the dtype shared by `state`.

  Raises:
    ValueError: if `state` is a nested structure without any leaves, or if its
      leaves do not all have the same dtype.
  """
  if explicit_dtype is not None:
    return explicit_dtype
  if not nest.is_sequence(state):
    return state.dtype

  leaf_dtypes = [leaf.dtype for leaf in nest.flatten(state)]
  if not leaf_dtypes:
    raise ValueError("Unable to infer dtype from empty state.")

  reference_dtype = leaf_dtypes[0]
  for leaf_dtype in leaf_dtypes:
    if not leaf_dtype == reference_dtype:
      raise ValueError(
          "State has tensors of different inferred_dtypes. Unable to infer a "
          "single representative dtype.")
  return reference_dtype
128
129
def _maybe_tensor_shape_from_tensor(shape):
  """Converts a shape given as a Tensor into a static shape object.

  Args:
    shape: Either an `ops.Tensor` describing a shape, or any other object.

  Returns:
    For an `ops.Tensor`, the result of `tensor_shape.as_shape` applied to the
    tensor's constant value as reported by `tensor_util.constant_value`. Any
    other argument is returned unchanged.
  """
  if not isinstance(shape, ops.Tensor):
    return shape
  static_value = tensor_util.constant_value(shape)
  return tensor_shape.as_shape(static_value)
135
136
def _should_cache():
  """Tells whether a default caching device should be installed.

  Returns:
    False when executing eagerly or when the default graph is currently inside
    a while-loop context; True otherwise.
  """
  if context.executing_eagerly():
    return False

  # A train step may itself be wrapped in a tf.while_loop. With a caching
  # device set, forward computations in later loop iterations would not
  # re-read the updated weights, so caching is disabled inside loops.
  graph = ops.get_default_graph()
  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  enclosing_while = control_flow_util.GetContainingWhileContext(ctxt)
  return enclosing_while is None
147
148
def _is_keras_rnn_cell(rnn_cell):
  """Decides whether `rnn_cell` follows the Keras RNN cell interface.

  Keras RNN cells take their state as a list even when it is a single tensor,
  while TF RNN cells do not wrap a single state tensor in a list. This
  difference in behavior should be unified in a future version.

  Args:
    rnn_cell: An RNN cell instance that follows either the Keras interface or
      the TF RNN interface.

  Returns:
    Boolean, True if the cell is a Keras RNN cell.
  """
  # A plain type check is not strict enough: cells from other libraries (for
  # example Deepmind's) do not inherit from tf.nn.rnn_cell.RNNCell. Keras cells
  # never had the `zero_state` method of the original TF RNN cell interface,
  # so its absence is used as the final discriminator.
  if isinstance(rnn_cell, rnn_cell_impl.RNNCell):
    return False
  if not isinstance(rnn_cell, base_layer.Layer):
    return False
  return getattr(rnn_cell, "zero_state", None) is None
169
170
171# pylint: disable=unused-argument
def _rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):
  """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on `sequence_length`.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on whether we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_length[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_length[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)

  When skip_conditionals=True the cell is always called and the per-row
  selection above is applied at every step, without any `cond`.

  Args:
    time: int32 `Tensor` scalar.
    sequence_length: int32 `Tensor` vector of size [batch_size].
    min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
    max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
    zero_output: `Tensor`, or nested structure of `Tensor`s (it is flattened
      with `nest.flatten`), emitted for rows that have already finished. Its
      static shape is applied to the returned output, so it is expected to
      have the same shape as the cell output.
    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
      or a list/tuple of such tensors.
    call_cell: lambda returning tuple of (new_output, new_state) where
      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
    state_size: The `cell.state_size` associated with the state. Currently not
      read by the function body.
    skip_conditionals: Python bool, whether to skip using the conditional
      calculations.  This is useful for `dynamic_rnn`, where the input tensor
      matches `max_sequence_length`, and using conditionals just slows
      everything down.

  Returns:
    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
      final_output is a `Tensor` matrix of shape [batch_size, output_size]
      final_state is either a single `Tensor` matrix, or a tuple of such
        matrices (matching length and shapes of input `state`).

  Raises:
    ValueError: If the combined flat result does not hold exactly one entry
      per flattened output and per flattened state element (internal error).
      A mismatch between the structure of `state` and the state returned by
      `call_cell` is reported by `nest.assert_same_structure`.
  """

  # Convert state to a list for ease of use
  flat_state = nest.flatten(state)
  flat_zero_output = nest.flatten(zero_output)

  # Vector describing which batch entries are finished.
  copy_cond = time >= sequence_length

  def _copy_one_through(output, new_output):
    # `output` is the value kept for finished rows (zero output or previous
    # state); `new_output` is the value freshly computed by the cell.
    # TensorArray and scalar get passed through.
    if isinstance(output, tensor_array_ops.TensorArray):
      return new_output
    if output.shape.rank == 0:
      return new_output
    # Otherwise propagate the old or the new value.
    with ops.colocate_with(new_output):
      return array_ops.where(copy_cond, output, new_output)

  def _copy_some_through(flat_new_output, flat_new_state):
    # Use broadcasting select to determine which values should get
    # the previous state & zero output, and which values should get
    # a calculated state & output.
    flat_new_output = [
        _copy_one_through(zero_output, new_output)
        for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
    flat_new_state = [
        _copy_one_through(state, new_state)
        for state, new_state in zip(flat_state, flat_new_state)]
    # Outputs and states travel together as a single flat list; the caller
    # splits them again by the number of flat outputs.
    return flat_new_output + flat_new_state

  def _maybe_copy_some_through():
    """Run RNN step.  Pass through either no or some past state."""
    new_output, new_state = call_cell()

    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    return control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length, lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))

  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
  # but benefits from removing cond() and its gradient.  We should
  # profile with and without this switch here.
  if skip_conditionals:
    # Instead of using conditionals, perform the selective copy at all time
    # steps.  This is faster when max_seq_len is equal to the number of unrolls
    # (which is typical for dynamic_rnn).
    new_output, new_state = call_cell()
    nest.assert_same_structure(state, new_state)
    new_state = nest.flatten(new_state)
    new_output = nest.flatten(new_output)
    final_output_and_state = _copy_some_through(new_output, new_state)
  else:
    empty_update = lambda: flat_zero_output + flat_state
    final_output_and_state = control_flow_ops.cond(
        # if t >= max_seq_len: copy all state through, output zeros
        time >= max_sequence_length, empty_update,
        # otherwise calculation is required: copy some or all of it through
        _maybe_copy_some_through)

  # Sanity check before splitting the combined list back into its two parts.
  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
    raise ValueError("Internal error: state and output were not concatenated "
                     "correctly.")
  final_output = final_output_and_state[:len(flat_zero_output)]
  final_state = final_output_and_state[len(flat_zero_output):]

  # Re-attach the static shapes of the zero output and the incoming state to
  # the results of the select/cond above.
  for output, flat_output in zip(final_output, flat_zero_output):
    output.set_shape(flat_output.get_shape())
  for substate, flat_substate in zip(final_state, flat_state):
    if not isinstance(substate, tensor_array_ops.TensorArray):
      substate.set_shape(flat_substate.get_shape())

  # Restore the original (possibly nested) structures.
  final_output = nest.pack_sequence_as(
      structure=zero_output, flat_sequence=final_output)
  final_state = nest.pack_sequence_as(
      structure=state, flat_sequence=final_state)

  return final_output, final_state
307
308
def _reverse_seq(input_seq, lengths):
  """Reverse a list of Tensors up to specified lengths.

  Args:
    input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
               or nested tuples of tensors.
    lengths:   A `Tensor` of dimension batch_size, containing lengths for each
               sequence in the batch. If "None" is specified, simply reverses
               the list.

  Returns:
    time-reversed sequence, as a list with one entry per time step. Each entry
    has the same (possibly nested) structure as the corresponding input entry.

  Raises:
    ValueError: if the static shapes of the tensors that occupy the same
      position at different time steps cannot be merged.
  """
  if lengths is None:
    return list(reversed(input_seq))

  # One flat list of tensors per time step.
  flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)

  flat_results = [[] for _ in input_seq]
  # Each `sequence` holds, for one flat position, the tensors of all time
  # steps.
  for sequence in zip(*flat_input_seq):
    input_shape = tensor_shape.unknown_shape(
        rank=sequence[0].get_shape().rank)
    for input_ in sequence:
      # `TensorShape.merge_with` returns the merged shape instead of updating
      # the receiver, so the result must be kept for the static shape
      # information to accumulate across time steps.
      input_shape = input_shape.merge_with(input_.get_shape())
      input_.set_shape(input_shape)

    # Join into (time, batch_size, depth)
    s_joined = array_ops.stack(sequence)

    # Reverse along dimension 0
    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
    # Split again into list
    result = array_ops.unstack(s_reversed)
    for r, flat_result in zip(result, flat_results):
      r.set_shape(input_shape)
      flat_result.append(r)

  results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
             for input_, flat_result in zip(input_seq, flat_results)]
  return results
349
350
@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell))`, which is equivalent to "
                        "this API")
@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
  """Creates a dynamic version of bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of forward and backward cell must match. The initial state for both directions
  is zero by default (but can be set optionally) and no intermediate states are
  ever returned -- the network is fully unrolled for the given (passed in)
  length(s) of the sequence(s) or completely unrolled if length(s) is not
  given.

  The backward RNN is run on the time-reversed inputs and its outputs are
  reversed again, so both outputs are aligned in time.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape:
        `[max_time, batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch.
      If not provided, all batch entries are assumed to be full sequences; and
      time reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN.
      This must be a tensor of appropriate type and shape
      `[batch_size, cell_fw.state_size]`.
      If `cell_fw.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using
      the corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation.  However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        If a cell has a nested `output_size`, the matching output is a nested
        structure of such `Tensor`s.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=fw_scope)

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      # With known lengths only the valid prefix of every sequence is
      # reversed; otherwise the whole time axis is flipped.
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_axis=seq_axis, batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    def _map_reverse(inp):
      return _reverse(
          inp,
          seq_lengths=sequence_length,
          seq_axis=time_axis,
          batch_axis=batch_axis)

    with vs.variable_scope("bw") as bw_scope:
      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
          initial_state=initial_state_bw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=bw_scope)

  # The backward outputs may be a nested structure (when `cell_bw.output_size`
  # is nested), so they are reversed leaf by leaf exactly like the inputs. For
  # a single Tensor this is the same as reversing it directly.
  output_bw = nest.map_structure(_map_reverse, tmp)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)
485
486
@deprecation.deprecated(
    None,
    "Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
@tf_export(v1=["nn.dynamic_rnn"])
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                     initial_state=initial_state,
                                     dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.contrib.rnn.LSTMStateTuple for each cell
  outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                     inputs=data,
                                     dtype=tf.float32)
  ```


  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such
        elements.
      If `time_major == True`, this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`, or a nested tuple of such
        elements.
      This may also be a (possibly nested) tuple of Tensors satisfying
      this property.  The first two dimensions must match across all the inputs,
      but otherwise the ranks and other shape components may differ.
      In this case, input to `cell` at each time-step will replicate the
      structure of these tuples, except for the time dimension (from which the
      time is taken).
      The input to `cell` at each time step will be a `Tensor` or (possibly
      nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
      Used to copy-through state and zero-out outputs when past a batch
      element's sequence length.  This parameter enables users to extract the
      last valid state and properly padded outputs, so it is provided for
      correctness.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation.  However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False (default), this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True, this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list, if `sequence_length` has a
      known rank other than 1, or if neither `initial_state` nor `dtype` is
      given.
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  # Flattening is plain Python and creates no graph ops, so the documented
  # "None or empty inputs" contract is enforced before any scope is opened.
  # Without this check an empty structure fails later with an IndexError.
  flat_input = nest.flatten(inputs)
  if inputs is None or not flat_input:
    raise ValueError(
        "inputs must be a Tensor or a non-empty nested structure of Tensors.")

  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # By default, time_major==False and inputs are batch-major: shaped
    #   [batch, time, depth]
    # For internal calculations, we transpose to [time, batch, depth]
    if not time_major:
      # (B,T,D) => (T,B,D)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      if sequence_length.get_shape().rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.get_shape())
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length, name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      # Prefer the `get_initial_state` interface when the cell offers it and
      # fall back to the `zero_state` method otherwise.
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      # Builds a runtime assertion that the dynamic shape of `x` equals
      # `shape`.
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)),
          ["Expected shape for Tensor %s is " % x.name,
           packed_shape, " but saw shape: ", x_shape])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    (outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    if not time_major:
      # (T,B,D) => (B,T,D)
      outputs = nest.map_structure(_transpose_batch_time, outputs)

    return (outputs, final_state)
681
682
def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Steps `cell` across time-major `inputs` inside a `while_loop`.

  Args:
    cell: An instance of RNNCell.
    inputs: A `Tensor` shaped [time, batch_size, input_size], or a nested
      tuple of such tensors.
    initial_state: A `Tensor` shaped `[batch_size, state_size]`.  When
      `cell.state_size` is a tuple, this should be a tuple of tensors shaped
      `[batch_size, s] for s in cell.state_size`.
    parallel_iterations: Positive Python int.
    swap_memory: A Python boolean.
    sequence_length: (optional) An `int32` `Tensor` shaped [batch_size].
    dtype: (optional) Expected dtype of the output.  When omitted it is
      inferred from `initial_state`.

  Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` shaped `[time, batch_size, cell.output_size]`.  When
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, a (possibly nested) tuple of Tensors with the corresponding
      shapes is returned instead.
    final_state:
      A `Tensor`, or possibly nested tuple of Tensors, matching
      `initial_state` in length and shapes.

  Raises:
    ValueError: If the input depth cannot be inferred via shape inference
      from the inputs.
    ValueError: If the elements of the inputs disagree on time_step.
    ValueError: If the elements of the inputs disagree on batch_size.
  """
  assert isinstance(parallel_iterations, int), "parallel_iterations must be int"

  state_size = cell.state_size

  flat_input = nest.flatten(inputs)
  flat_output_size = nest.flatten(cell.output_size)

  # Dynamic extents are read from the first flattened input.
  time_steps = array_ops.shape(flat_input[0])[0]
  batch_size = _best_effort_input_batch_size(flat_input)

  # Static shapes, used for validation now and for shape restoration later.
  static_input_shapes = tuple(
      tensor.get_shape().with_rank_at_least(3) for tensor in flat_input)
  static_time_steps, static_batch_size = (
      static_input_shapes[0].as_list()[:2])

  for static_shape in static_input_shapes:
    if not static_shape[2:].is_fully_defined():
      raise ValueError(
          "Input size (depth of inputs) must be accessible via shape inference,"
          " but saw value None.")
    if static_shape.dims[0].value != static_time_steps:
      raise ValueError(
          "Time steps is not the same for all the elements in the input in a "
          "batch.")
    if static_shape.dims[1].value != static_batch_size:
      raise ValueError(
          "Batch_size is not the same for all the elements in the input.")

  # Zero tensors handed to _rnn_step when a sequence_length is supplied.
  def _zeros_for_output(output_size):
    full_shape = _concat(batch_size, output_size)
    return array_ops.zeros(
        array_ops.stack(full_shape), _infer_state_dtype(dtype, initial_state))

  flat_zero_output = tuple(
      _zeros_for_output(out_size) for out_size in flat_output_size)
  zero_output = nest.pack_sequence_as(
      structure=cell.output_size, flat_sequence=flat_zero_output)

  if sequence_length is None:
    max_sequence_length = time_steps
  else:
    min_sequence_length = math_ops.reduce_min(sequence_length)
    max_sequence_length = math_ops.reduce_max(sequence_length)

  time = array_ops.constant(0, dtype=dtypes.int32, name="time")

  # Only the resolved scope string is needed, as a TensorArray name prefix.
  with ops.name_scope("dynamic_rnn") as base_name:
    pass

  def _make_tensor_array(suffix, element_shape, element_dtype):
    return tensor_array_ops.TensorArray(
        dtype=element_dtype,
        size=time_steps,
        element_shape=element_shape,
        tensor_array_name=base_name + suffix)

  in_graph_mode = not context.executing_eagerly()
  if in_graph_mode:
    output_ta = tuple(
        _make_tensor_array(
            "output_%d" % index,
            tensor_shape.TensorShape([static_batch_size]).concatenate(
                _maybe_tensor_shape_from_tensor(out_size)),
            _infer_state_dtype(dtype, initial_state))
        for index, out_size in enumerate(flat_output_size))
    empty_input_ta = tuple(
        _make_tensor_array(
            "input_%d" % index, tensor.shape[1:], tensor.dtype)
        for index, tensor in enumerate(flat_input))
    input_ta = tuple(
        ta.unstack(tensor) for ta, tensor in zip(empty_input_ta, flat_input))
  else:
    # Eager mode: plain Python lists stand in for the output TensorArrays and
    # the input tensors are indexed directly.
    output_ta = tuple([0 for _ in range(time_steps.numpy())]
                      for _ in range(len(flat_output_size)))
    input_ta = flat_input

  def _step(time, output_ta_t, state):
    """Runs the cell for one time index.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: List of `TensorArray`s that represent the output.
      state: nested tuple of vector tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """
    if in_graph_mode:
      flat_input_t = tuple(ta.read(time) for ta in input_ta)
      # Restore some shape information
      for tensor, static_shape in zip(flat_input_t, static_input_shapes):
        tensor.set_shape(static_shape[1:])
    else:
      flat_input_t = tuple(ta[time.numpy()] for ta in input_ta)

    input_t = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_input_t)
    # Keras RNN cells take their state as a list, even for a single tensor.
    is_keras_rnn_cell = _is_keras_rnn_cell(cell)
    if is_keras_rnn_cell and not nest.is_sequence(state):
      state = [state]

    def _call_cell():
      return cell(input_t, state)

    if sequence_length is None:
      (output, new_state) = _call_cell()
    else:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=_call_cell,
          state_size=state_size,
          skip_conditionals=True)

    # Keras cells hand back a list as well; unwrap a single-element state.
    if is_keras_rnn_cell and len(new_state) == 1:
      new_state = new_state[0]
    flat_output = nest.flatten(output)

    if in_graph_mode:
      output_ta_t = tuple(
          ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
    else:
      for ta, out in zip(output_ta_t, flat_output):
        ta[time.numpy()] = out

    return (time + 1, output_ta_t, new_state)

  if in_graph_mode:
    # Make sure that we run at least 1 step, if necessary, to ensure
    # the TensorArrays pick up the dynamic shape.
    loop_bound = math_ops.minimum(
        time_steps, math_ops.maximum(1, max_sequence_length))
  else:
    # Using max_sequence_length isn't currently supported in the Eager branch.
    loop_bound = time_steps

  def _not_done(time, *_):
    return time < loop_bound

  _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=_not_done,
      body=_step,
      loop_vars=(time, output_ta, initial_state),
      parallel_iterations=parallel_iterations,
      maximum_iterations=time_steps,
      swap_memory=swap_memory)

  if in_graph_mode:
    flat_final_outputs = tuple(ta.stack() for ta in output_final_ta)
    # Restore some shape information
    for stacked, out_size in zip(flat_final_outputs, flat_output_size):
      stacked.set_shape(
          _concat([static_time_steps, static_batch_size], out_size,
                  static=True))
    final_outputs = nest.pack_sequence_as(
        structure=cell.output_size, flat_sequence=flat_final_outputs)
  else:
    # Each eager "array" is a list of per-step outputs; stack every list
    # along a new leading axis after restoring the output structure.
    final_outputs = nest.pack_sequence_as(
        structure=cell.output_size, flat_sequence=output_final_ta)
    final_outputs = nest.map_structure_up_to(
        cell.output_size, lambda steps: array_ops.stack(steps, axis=0),
        final_outputs)

  return (final_outputs, final_state)
899
900
@tf_export(v1=["nn.raw_rnn"])
def raw_rnn(cell, loop_fn,
            parallel_iterations=None, swap_memory=False, scope=None):
  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.

  **NOTE: This method is still in testing, and the API may change.**

  This function is a more primitive version of `dynamic_rnn` that provides
  more direct access to the inputs each iteration.  It also provides more
  control over when to start and finish reading the sequence, and
  what to emit for the output.

  For example, it can be used to implement the dynamic decoder of a seq2seq
  model.

  Instead of working with `Tensor` objects, most operations work with
  `TensorArray` objects directly.

  The operation of `raw_rnn`, in pseudo-code, is basically the following:

  ```python
  time = tf.constant(0, dtype=tf.int32)
  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
      time=time, cell_output=None, cell_state=None, loop_state=None)
  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
  state = initial_state
  while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
  return (emit_ta, state, loop_state)
  ```

  with the additional properties that output and state may be (possibly nested)
  tuples, as determined by `cell.output_size` and `cell.state_size`, and
  as a result the final `state` and `emit_ta` may themselves be tuples.

  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:

  ```python
  inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
                          dtype=tf.float32)
  sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
  inputs_ta = inputs_ta.unstack(inputs)

  cell = tf.contrib.rnn.LSTMCell(num_units)

  def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
      next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
      next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
  outputs = outputs_ta.stack()
  ```

  Args:
    cell: An instance of RNNCell.
    loop_fn: A callable that takes inputs
      `(time, cell_output, cell_state, loop_state)`
      and returns the tuple
      `(finished, next_input, next_cell_state, emit_output, next_loop_state)`.
      Here `time` is an int32 scalar `Tensor`, `cell_output` is a
      `Tensor` or (possibly nested) tuple of tensors as determined by
      `cell.output_size`, and `cell_state` is a `Tensor`
      or (possibly nested) tuple of tensors, as determined by the `loop_fn`
      on its first call (and should match `cell.state_size`).
      The outputs are: `finished`, a boolean `Tensor` of
      shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
      `next_cell_state`: the next state to feed to `cell`,
      and `emit_output`: the output to store for this iteration.

      Note that `emit_output` should be a `Tensor` or (possibly nested)
      tuple of tensors which is aggregated in the `emit_ta` inside the
      `while_loop`. For the first call to `loop_fn`, the `emit_output`
      corresponds to the `emit_structure` which is then used to determine the
      size of the `zero_tensor` for the `emit_ta` (defaults to
      `cell.output_size`). For the subsequent calls to the `loop_fn`, the
      `emit_output` corresponds to the actual output tensor
      that is to be aggregated in the `emit_ta`. The parameter `cell_state`
      and output `next_cell_state` may be either a single or (possibly nested)
      tuple of tensors.  The parameter `loop_state` and
      output `next_loop_state` may be either a single or (possibly nested) tuple
      of `Tensor` and `TensorArray` objects.  This last parameter
      may be ignored by `loop_fn` and the return value may be `None`.  If it
      is not `None`, then the `loop_state` will be propagated through the RNN
      loop, for use purely by `loop_fn` to keep track of its own state.
      The `next_loop_state` parameter returned may be `None`.

      The first call to `loop_fn` will be `time = 0`, `cell_output = None`,
      `cell_state = None`, and `loop_state = None`.  For this call:
      The `next_cell_state` value should be the value with which to initialize
      the cell's state.  It may be a final state from a previous RNN or it
      may be the output of `cell.zero_state()`.  It should be a
      (possibly nested) tuple structure of tensors.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a `TensorShape`, this must be a `Tensor` of
      appropriate type and shape `[batch_size] + cell.state_size`.
      If `cell.state_size` is a (possibly nested) tuple of ints or
      `TensorShape`, this will be a tuple having the corresponding shapes.
      The `emit_output` value may be either `None` or a (possibly nested)
      tuple structure of tensors, e.g.,
      `(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`.
      If this first `emit_output` return value is `None`,
      then the `emit_ta` result of `raw_rnn` will have the same structure and
      dtypes as `cell.output_size`.  Otherwise `emit_ta` will have the same
      structure, shapes (prepended with a `batch_size` dimension), and dtypes
      as `emit_output`.  The actual values returned for `emit_output` at this
      initializing call are ignored.  Note, this emit structure must be
      consistent across all time steps.

    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A tuple `(emit_ta, final_state, final_loop_state)` where:

    `emit_ta`: The RNN output `TensorArray`.
       If `loop_fn` returns a (possibly nested) set of Tensors for
       `emit_output` during initialization, (inputs `time = 0`,
       `cell_output = None`, and `loop_state = None`), then `emit_ta` will
       have the same structure, dtypes, and shapes as `emit_output` instead.
       If `loop_fn` returns `emit_output = None` during this call,
       the structure of `cell.output_size` is used:
       If `cell.output_size` is a (possibly nested) tuple of integers
       or `TensorShape` objects, then `emit_ta` will be a tuple having the
       same structure as `cell.output_size`, containing TensorArrays whose
       elements' shapes correspond to the shape data in `cell.output_size`.

    `final_state`: The final cell state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes.

    `final_loop_state`: The final loop state as returned by `loop_fn`.  It is
      `None` when the first call to `loop_fn` returned `None` for it.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
      a `callable`.
    ValueError: If the statically known batch sizes of the inputs returned by
      the first call to `loop_fn` are not compatible with one another.
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not callable(loop_fn):
    raise TypeError("loop_fn must be a callable")

  parallel_iterations = parallel_iterations or 32

  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Initializing call: it supplies the first input, the initial state and
    # (optionally) the emit structure and the user's loop state.
    time = constant_op.constant(0, dtype=dtypes.int32)
    (elements_finished, next_input, initial_state, emit_structure,
     init_loop_state) = loop_fn(
         time, None, None, None)  # time, cell_output, cell_state, loop_state
    flat_input = nest.flatten(next_input)

    # Need a surrogate loop state for the while_loop if none is available.
    loop_state = (init_loop_state if init_loop_state is not None
                  else constant_op.constant(0, dtype=dtypes.int32))

    input_shape = [input_.get_shape() for input_ in flat_input]
    static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0)

    # Static verification that batch sizes all match.  The merged dimension
    # is kept (rather than discarded) so that every input is checked against
    # everything seen so far, not just the first input, and so that a batch
    # size known only on a later input is still picked up statically.
    for input_shape_i in input_shape:
      static_batch_size = static_batch_size.merge_with(
          tensor_shape.dimension_at_index(input_shape_i, 0))

    # `const_batch_size` is the static value (possibly None); `batch_size`
    # falls back to a dynamic tensor when no static value is available.
    batch_size = tensor_shape.dimension_value(static_batch_size)
    const_batch_size = batch_size
    if batch_size is None:
      batch_size = array_ops.shape(flat_input[0])[0]

    nest.assert_same_structure(initial_state, cell.state_size)
    state = initial_state
    flat_state = nest.flatten(state)
    flat_state = [ops.convert_to_tensor(s) for s in flat_state]
    state = nest.pack_sequence_as(structure=state,
                                  flat_sequence=flat_state)

    if emit_structure is not None:
      # Use the user-provided emit tensors only for their shapes and dtypes;
      # fall back to a dynamic shape when the static one is incomplete.
      flat_emit_structure = nest.flatten(emit_structure)
      flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                        array_ops.shape(emit) for emit in flat_emit_structure]
      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
    else:
      # No emit structure given: mirror cell.output_size and reuse the dtype
      # of the first state tensor for every emitted value.
      emit_structure = cell.output_size
      flat_emit_size = nest.flatten(emit_structure)
      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

    flat_emit_ta = [
        tensor_array_ops.TensorArray(
            dtype=dtype_i,
            dynamic_size=True,
            element_shape=(tensor_shape.TensorShape([const_batch_size])
                           .concatenate(
                               _maybe_tensor_shape_from_tensor(size_i))),
            size=0,
            name="rnn_output_%d" % i)
        for i, (dtype_i, size_i)
        in enumerate(zip(flat_emit_dtypes, flat_emit_size))]
    emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                    flat_sequence=flat_emit_ta)
    # Zeros emitted in place of real output for already-finished entries.
    flat_zero_emit = [
        array_ops.zeros(_concat(batch_size, size_i), dtype_i)
        for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]
    zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                      flat_sequence=flat_zero_emit)

    def condition(unused_time, elements_finished, *_):
      # Keep looping until every batch entry is marked finished.
      return math_ops.logical_not(math_ops.reduce_all(elements_finished))

    def body(time, elements_finished, current_input,
             emit_ta, state, loop_state):
      """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
      (next_output, cell_state) = cell(current_input, state)

      nest.assert_same_structure(state, cell_state)
      nest.assert_same_structure(cell.output_size, next_output)

      next_time = time + 1
      (next_finished, next_input, next_state, emit_output,
       next_loop_state) = loop_fn(
           next_time, next_output, cell_state, loop_state)

      nest.assert_same_structure(state, next_state)
      nest.assert_same_structure(current_input, next_input)
      nest.assert_same_structure(emit_ta, emit_output)

      # If loop_fn returns None for next_loop_state, just reuse the
      # previous one.
      loop_state = loop_state if next_loop_state is None else next_loop_state

      def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""
        def copy_fn(cur_i, cand_i):
          # TensorArray and scalar get passed through.
          if isinstance(cur_i, tensor_array_ops.TensorArray):
            return cand_i
          if cur_i.shape.rank == 0:
            return cand_i
          # Otherwise propagate the old or the new value.
          with ops.colocate_with(cand_i):
            return array_ops.where(elements_finished, cur_i, cand_i)
        return nest.map_structure(copy_fn, current, candidate)

      # Finished entries emit zeros and keep their previous state.
      emit_output = _copy_some_through(zero_emit, emit_output)
      next_state = _copy_some_through(state, next_state)

      emit_ta = nest.map_structure(
          lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)

      elements_finished = math_ops.logical_or(elements_finished, next_finished)

      return (next_time, elements_finished, next_input,
              emit_ta, next_state, loop_state)

    returned = control_flow_ops.while_loop(
        condition, body, loop_vars=[
            time, elements_finished, next_input,
            emit_ta, state, loop_state],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    (emit_ta, final_state, final_loop_state) = returned[-3:]

    # Hide the surrogate loop state from callers that never supplied one.
    if init_loop_state is None:
      final_loop_state = None

    return (emit_ta, final_state, final_loop_state)
1218
1219
@deprecation.deprecated(
    None, "Please use `keras.layers.RNN(cell, unroll=True)`, "
    "which is equivalent to this API")
@tf_export(v1=["nn.static_rnn"])
def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  The simplest form of RNN network generated is:

  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time `t` for batch row `b`,

  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T]`
      (a value of `T` means every time step of that row is computed).
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state

  Raises:
    TypeError: If `cell` is not an instance of RNNCell, or if `inputs` is not
      a sequence.
    ValueError: If `inputs` is `None` or an empty list, or if the input depth
      (column size) cannot be inferred from inputs via shape inference.
    ValueError: If neither `initial_state` nor `dtype` is provided, or if
      `sequence_length` has a known rank other than 1.
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().rank != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape.dims[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        input_batch_size, input_size = tensor_shape.dimension_at_index(
            input_shape, 0), input_shape[1:]
        # Keep the merged dimension instead of discarding it: every input is
        # then validated against all inputs seen so far (not only the first),
        # and a batch size that is statically known only on a later input is
        # still recovered.
        fixed_batch_size = fixed_batch_size.merge_with(input_batch_size)
        for i, size in enumerate(input_size.dims):
          if tensor_shape.dimension_value(size) is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    # Prefer the static batch size; otherwise read it from the first input at
    # run time.
    if tensor_shape.dimension_value(fixed_batch_size):
      batch_size = tensor_shape.dimension_value(fixed_batch_size)
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        # Attach whatever static shape is known to the zero tensor.
        shape = _concat(tensor_shape.dimension_value(fixed_batch_size),
                        output_size,
                        static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    # Keras RNN cells only accept state as list, even if it's a single tensor.
    is_keras_rnn_cell = _is_keras_rnn_cell(cell)
    if is_keras_rnn_cell and not nest.is_sequence(state):
      state = [state]
    for time, input_ in enumerate(inputs):
      # Variables are created on the first step and shared by later steps.
      if time > 0:
        varscope.reuse_variables()
      # The closure is consumed within this iteration, so capturing the loop
      # variables is safe here.
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()
      outputs.append(output)
    # Keras RNN cells only return state as list, even if it's a single tensor.
    if is_keras_rnn_cell and len(state) == 1:
      state = state[0]

    return (outputs, state)
1402
1403
@tf_export("nn.static_state_saving_rnn")
def static_state_saving_rnn(cell,
                            inputs,
                            state_saver,
                            state_name,
                            sequence_length=None,
                            scope=None):
  """Runs `static_rnn` with state loaded from and stored to a state saver.

  The initial state is read through `state_saver.state(...)`, the unrolled
  network is built by `static_rnn`, and the resulting final state is handed to
  `state_saver.save_state(...)`. The last output and the returned state are
  then re-created as identity ops under a control dependency on those save
  ops.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings.  The name(s) passed to the
      state_saver. When `cell.state_size` is a (possibly nested) sequence,
      `state_name` must also be a sequence with the same number of flattened
      elements; otherwise it must be a single string.
    sequence_length: (optional) An int32/int64 vector size [batch_size].
      Forwarded unchanged to `static_rnn`.
    scope: VariableScope for the created subgraph; forwarded unchanged to
      `static_rnn`.

  Returns:
    A pair (outputs, state) where:
      outputs is the list returned by `static_rnn`, with its last element
        replaced by an identity of itself that depends on the save ops.
      state is the final state, likewise wrapped in identity ops.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the arity and
     type of `state_name` does not match that of `cell.state_size`.
  """
  cell_state_size = cell.state_size
  size_is_nested = nest.is_sequence(cell_state_size)
  name_is_nested = nest.is_sequence(state_name)

  # Either both the state size and the state name are sequences, or neither.
  if size_is_nested != name_is_nested:
    raise ValueError("state_name should be the same type as cell.state_size.  "
                     "state_name: %s, cell.state_size: %s" %
                     (str(state_name), str(cell_state_size)))

  if not size_is_nested:
    initial_state = state_saver.state(state_name)
  else:
    flat_names = nest.flatten(state_name)
    flat_sizes = nest.flatten(cell_state_size)

    if len(flat_names) != len(flat_sizes):
      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
                       (len(flat_names), len(flat_sizes)))

    # Fetch one saved tensor per name and rebuild the cell's state structure.
    restored = [state_saver.state(name) for name in flat_names]
    initial_state = nest.pack_sequence_as(
        structure=cell_state_size, flat_sequence=restored)

  outputs, final_state = static_rnn(
      cell,
      inputs,
      initial_state=initial_state,
      sequence_length=sequence_length,
      scope=scope)

  if size_is_nested:
    flat_final_state = nest.flatten(final_state)
    save_ops = [
        state_saver.save_state(name, piece)
        for name, piece in zip(flat_names, flat_final_state)
    ]
  else:
    save_ops = [state_saver.save_state(state_name, final_state)]

  # Everything created below carries a control dependency on the save ops.
  with ops.control_dependencies(save_ops):
    final_output = outputs[-1]
    gated_output = [
        array_ops.identity(tensor) for tensor in nest.flatten(final_output)
    ]
    # Replace the last output in place, keeping its original structure.
    outputs[-1] = nest.pack_sequence_as(
        structure=final_output, flat_sequence=gated_output)

    if size_is_nested:
      gated_state = [array_ops.identity(piece) for piece in flat_final_state]
      final_state = nest.pack_sequence_as(
          structure=final_state, flat_sequence=gated_state)
    else:
      final_state = array_ops.identity(final_state)

  return (outputs, final_state)
1494
1495
@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell, unroll=True))`, which is "
                        "equivalent to this API")
@tf_export(v1=["nn.static_bidirectional_rnn"])
def static_bidirectional_rnn(cell_fw,
                             cell_bw,
                             inputs,
                             initial_state_fw=None,
                             initial_state_bw=None,
                             dtype=None,
                             sequence_length=None,
                             scope=None):
  """Builds an unrolled bidirectional RNN from two cells.

  Two independent `static_rnn` networks are created: one consumes `inputs` as
  given (under the "fw" variable scope), the other consumes `inputs` reversed
  with `_reverse_seq` (under the "bw" variable scope). The backward outputs are
  reversed back with `_reverse_seq`, and each forward output is concatenated
  with the matching backward output along axis 1, so the result has the format
  [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
  the forward and backward cell must match. Only the final states of the two
  networks are returned; intermediate states are not.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: A length T list of inputs, each a tensor of shape
      [batch_size, input_size], or a nested tuple of such elements.
    initial_state_fw: (optional) An initial state for the forward RNN.
      This must be a tensor of appropriate type and shape
      `[batch_size, cell_fw.state_size]`.
      If `cell_fw.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using
      the corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial state.  Required if
      either of the initial states are not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences. It is passed
      both to the two `static_rnn` calls and to `_reverse_seq`.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_state_fw, output_state_bw) where:
      outputs is a length `T` list of outputs (one for each input), which
        are depth-concatenated forward and backward outputs.
      output_state_fw is the final state of the forward rnn.
      output_state_bw is the final state of the backward rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`, or
      if `inputs` is not a sequence.
    ValueError: If inputs is None or an empty list.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward pass over the inputs in their given order.
    with vs.variable_scope("fw") as forward_scope:
      forward_outputs, output_state_fw = static_rnn(
          cell_fw,
          inputs,
          initial_state_fw,
          dtype,
          sequence_length,
          scope=forward_scope)

    # Backward pass: the same unrolled RNN, fed the time-reversed inputs.
    with vs.variable_scope("bw") as backward_scope:
      inputs_reversed = _reverse_seq(inputs, sequence_length)
      backward_outputs_reversed, output_state_bw = static_rnn(
          cell_bw,
          inputs_reversed,
          initial_state_bw,
          dtype,
          sequence_length,
          scope=backward_scope)

  # Undo the reversal so backward outputs line up with the forward time steps.
  backward_outputs = _reverse_seq(backward_outputs_reversed, sequence_length)

  # Pair up matching forward/backward tensors and join them along axis 1.
  forward_flat = nest.flatten(forward_outputs)
  backward_flat = nest.flatten(backward_outputs)
  joined = []
  for forward_tensor, backward_tensor in zip(forward_flat, backward_flat):
    joined.append(array_ops.concat([forward_tensor, backward_tensor], 1))

  outputs = nest.pack_sequence_as(
      structure=forward_outputs, flat_sequence=tuple(joined))

  return (outputs, output_state_fw, output_state_bw)
1591