1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""RNN helpers for TensorFlow models."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import context
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import control_flow_util
29from tensorflow.python.ops import control_flow_util_v2
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import rnn_cell_impl
32from tensorflow.python.ops import tensor_array_ops
33from tensorflow.python.ops import variable_scope as vs
34from tensorflow.python.util import deprecation
35from tensorflow.python.util import dispatch
36from tensorflow.python.util import nest
37from tensorflow.python.util.tf_export import tf_export
38
39# pylint: disable=protected-access
40_concat = rnn_cell_impl._concat
41# pylint: enable=protected-access
42
43
44def _transpose_batch_time(x):
45  """Transposes the batch and time dimensions of a Tensor.
46
  If the input tensor has rank < 2, it returns the original tensor. Retains as
48  much of the static shape information as possible.
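
  For example, a batch-major tensor of shape `[32, 10, 8]` becomes a
  time-major tensor of shape `[10, 32, 8]`. A minimal sketch (the shape
  values are illustrative):

  ```python
  x = array_ops.zeros([32, 10, 8])   # [batch, time, depth]
  x_t = _transpose_batch_time(x)     # static shape: [10, 32, 8]
  ```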
49
50  Args:
51    x: A Tensor.
52
53  Returns:
54    x transposed along the first two dimensions.
55  """
56  x_static_shape = x.get_shape()
57  if x_static_shape.rank is not None and x_static_shape.rank < 2:
58    return x
59
60  x_rank = array_ops.rank(x)
61  x_t = array_ops.transpose(
62      x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
63  x_t.set_shape(
64      tensor_shape.TensorShape(
65          [x_static_shape.dims[1].value,
66           x_static_shape.dims[0].value]).concatenate(x_static_shape[2:]))
67  return x_t
68
69
70def _best_effort_input_batch_size(flat_input):
71  """Get static input batch size if available, with fallback to the dynamic one.
72
73  Args:
74    flat_input: An iterable of time major input Tensors of shape `[max_time,
75      batch_size, ...]`. All inputs should have compatible batch sizes.
76
77  Returns:
    The batch size as a Python integer if available, or a scalar Tensor otherwise.
79
80  Raises:
81    ValueError: if there is any input with an invalid shape.
82  """
83  for input_ in flat_input:
84    shape = input_.shape
85    if shape.rank is None:
86      continue
87    if shape.rank < 2:
88      raise ValueError("Expected input tensor %s to have rank at least 2" %
89                       input_)
90    batch_size = shape.dims[1].value
91    if batch_size is not None:
92      return batch_size
93  # Fallback to the dynamic batch size of the first input.
94  return array_ops.shape(flat_input[0])[1]
95
96
97def _infer_state_dtype(explicit_dtype, state):
98  """Infer the dtype of an RNN state.
99
100  Args:
101    explicit_dtype: explicitly declared dtype or None.
102    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
103      Tensors.
104
105  Returns:
106    dtype: inferred dtype of hidden state.
107
108  Raises:
109    ValueError: if `state` has heterogeneous dtypes or is empty.
110  """
111  if explicit_dtype is not None:
112    return explicit_dtype
113  elif nest.is_sequence(state):
114    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
115    if not inferred_dtypes:
116      raise ValueError("Unable to infer dtype from empty state.")
117    all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes)
118    if not all_same:
119      raise ValueError(
120          "State has tensors of different inferred_dtypes. Unable to infer a "
121          "single representative dtype.")
122    return inferred_dtypes[0]
123  else:
124    return state.dtype
125
126
127def _maybe_tensor_shape_from_tensor(shape):
128  if isinstance(shape, ops.Tensor):
129    return tensor_shape.as_shape(tensor_util.constant_value(shape))
130  else:
131    return shape
132
133
134def _should_cache():
135  """Returns True if a default caching device should be set, otherwise False."""
136  if context.executing_eagerly():
137    return False
138  # Don't set a caching device when running in a loop, since it is possible that
139  # train steps could be wrapped in a tf.while_loop. In that scenario caching
140  # prevents forward computations in loop iterations from re-reading the
141  # updated weights.
142  graph = ops.get_default_graph()
143  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
144  in_v1_while_loop = (
145      control_flow_util.GetContainingWhileContext(ctxt) is not None)
146  in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph)
147  return not in_v1_while_loop and not in_v2_while_loop
148
149
150# pylint: disable=unused-argument
151def _rnn_step(time,
152              sequence_length,
153              min_sequence_length,
154              max_sequence_length,
155              zero_output,
156              state,
157              call_cell,
158              state_size,
159              skip_conditionals=False):
160  """Calculate one step of a dynamic RNN minibatch.
161
162  Returns an (output, state) pair conditioned on `sequence_length`.
163  When skip_conditionals=False, the pseudocode is something like:
164
  ```python
  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on whether we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_length[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_length[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)
  ```
182
183  Args:
184    time: int32 `Tensor` scalar.
185    sequence_length: int32 `Tensor` vector of size [batch_size].
186    min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
187    max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
188    zero_output: `Tensor` vector of shape [output_size].
189    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
190      or a list/tuple of such tensors.
191    call_cell: lambda returning tuple of (new_output, new_state) where
192      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
193      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
194    state_size: The `cell.state_size` associated with the state.
195    skip_conditionals: Python bool, whether to skip using the conditional
196      calculations.  This is useful for `dynamic_rnn`, where the input tensor
197      matches `max_sequence_length`, and using conditionals just slows
198      everything down.
199
200  Returns:
201    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
202      final_output is a `Tensor` matrix of shape [batch_size, output_size]
203      final_state is either a single `Tensor` matrix, or a tuple of such
204        matrices (matching length and shapes of input `state`).
205
206  Raises:
207    ValueError: If the cell returns a state tuple whose length does not match
208      that returned by `state_size`.
209  """
210
211  # Convert state to a list for ease of use
212  flat_state = nest.flatten(state)
213  flat_zero_output = nest.flatten(zero_output)
214
215  # Vector describing which batch entries are finished.
216  copy_cond = time >= sequence_length
217
218  def _copy_one_through(output, new_output):
219    # TensorArray and scalar get passed through.
220    if isinstance(output, tensor_array_ops.TensorArray):
221      return new_output
222    if output.shape.rank == 0:
223      return new_output
224    # Otherwise propagate the old or the new value.
225    with ops.colocate_with(new_output):
226      return array_ops.where(copy_cond, output, new_output)
227
228  def _copy_some_through(flat_new_output, flat_new_state):
229    # Use broadcasting select to determine which values should get
230    # the previous state & zero output, and which values should get
231    # a calculated state & output.
232    flat_new_output = [
233        _copy_one_through(zero_output, new_output)
234        for zero_output, new_output in zip(flat_zero_output, flat_new_output)
235    ]
236    flat_new_state = [
237        _copy_one_through(state, new_state)
238        for state, new_state in zip(flat_state, flat_new_state)
239    ]
240    return flat_new_output + flat_new_state
241
242  def _maybe_copy_some_through():
243    """Run RNN step.  Pass through either no or some past state."""
244    new_output, new_state = call_cell()
245
246    nest.assert_same_structure(zero_output, new_output)
247    nest.assert_same_structure(state, new_state)
248
249    flat_new_state = nest.flatten(new_state)
250    flat_new_output = nest.flatten(new_output)
251    return control_flow_ops.cond(
252        # if t < min_seq_len: calculate and return everything
253        time < min_sequence_length,
254        lambda: flat_new_output + flat_new_state,
255        # else copy some of it through
256        lambda: _copy_some_through(flat_new_output, flat_new_state))
257
258  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
259  # but benefits from removing cond() and its gradient.  We should
260  # profile with and without this switch here.
261  if skip_conditionals:
262    # Instead of using conditionals, perform the selective copy at all time
263    # steps.  This is faster when max_seq_len is equal to the number of unrolls
264    # (which is typical for dynamic_rnn).
265    new_output, new_state = call_cell()
266    nest.assert_same_structure(zero_output, new_output)
267    nest.assert_same_structure(state, new_state)
268    new_state = nest.flatten(new_state)
269    new_output = nest.flatten(new_output)
270    final_output_and_state = _copy_some_through(new_output, new_state)
271  else:
272    empty_update = lambda: flat_zero_output + flat_state
273    final_output_and_state = control_flow_ops.cond(
274        # if t >= max_seq_len: copy all state through, output zeros
275        time >= max_sequence_length,
276        empty_update,
277        # otherwise calculation is required: copy some or all of it through
278        _maybe_copy_some_through)
279
280  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
281    raise ValueError("Internal error: state and output were not concatenated "
282                     "correctly.")
283  final_output = final_output_and_state[:len(flat_zero_output)]
284  final_state = final_output_and_state[len(flat_zero_output):]
285
286  for output, flat_output in zip(final_output, flat_zero_output):
287    output.set_shape(flat_output.get_shape())
288  for substate, flat_substate in zip(final_state, flat_state):
289    if not isinstance(substate, tensor_array_ops.TensorArray):
290      substate.set_shape(flat_substate.get_shape())
291
292  final_output = nest.pack_sequence_as(
293      structure=zero_output, flat_sequence=final_output)
294  final_state = nest.pack_sequence_as(
295      structure=state, flat_sequence=final_state)
296
297  return final_output, final_state
298
299
300def _reverse_seq(input_seq, lengths):
301  """Reverse a list of Tensors up to specified lengths.
302
303  Args:
304    input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
305      or nested tuples of tensors.
    lengths:   A `Tensor` of dimension `[batch_size]`, containing lengths for
      each sequence in the batch. If `None`, the list is simply reversed.
308
309  Returns:
310    time-reversed sequence
311  """
312  if lengths is None:
313    return list(reversed(input_seq))
314
315  flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)
316
317  flat_results = [[] for _ in range(len(input_seq))]
318  for sequence in zip(*flat_input_seq):
319    input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank)
320    for input_ in sequence:
321      input_shape.assert_is_compatible_with(input_.get_shape())
322      input_.set_shape(input_shape)
323
324    # Join into (time, batch_size, depth)
325    s_joined = array_ops.stack(sequence)
326
327    # Reverse along dimension 0
328    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
329    # Split again into list
330    result = array_ops.unstack(s_reversed)
331    for r, flat_result in zip(result, flat_results):
332      r.set_shape(input_shape)
333      flat_result.append(r)
334
335  results = [
336      nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
337      for input_, flat_result in zip(input_seq, flat_results)
338  ]
339  return results
340
341
342@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
343                        "keras.layers.RNN(cell))`, which is equivalent to "
344                        "this API")
345@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
346@dispatch.add_dispatch_support
347def bidirectional_dynamic_rnn(cell_fw,
348                              cell_bw,
349                              inputs,
350                              sequence_length=None,
351                              initial_state_fw=None,
352                              initial_state_bw=None,
353                              dtype=None,
354                              parallel_iterations=None,
355                              swap_memory=False,
356                              time_major=False,
357                              scope=None):
358  """Creates a dynamic version of bidirectional recurrent neural network.
359
  Takes input and builds independent forward and backward RNNs. The input_size
  of the forward and backward cells must match. The initial state for both
  directions is zero by default (but can be set optionally), and no
  intermediate states are ever returned -- the network is fully unrolled for
  the given (passed-in) length(s) of the sequence(s), or completely unrolled
  if length(s) are not given.
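
  For example, a minimal sketch (the cell sizes, shapes, and placeholders are
  illustrative):

  ```python
  cell_fw = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
  cell_bw = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
  inputs = tf.compat.v1.placeholder(tf.float32, [batch_size, max_time, depth])

  (output_fw, output_bw), (state_fw, state_bw) = (
      tf.compat.v1.nn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs,
          sequence_length=sequence_length, dtype=tf.float32))

  # Concatenate both directions along the feature axis if a single output
  # tensor is desired: shape [batch_size, max_time, 128].
  outputs = tf.concat([output_fw, output_bw], 2)
  ```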
366
367  Args:
368    cell_fw: An instance of RNNCell, to be used for forward direction.
369    cell_bw: An instance of RNNCell, to be used for backward direction.
370    inputs: The RNN inputs.
371      If time_major == False (default), this must be a tensor of shape:
372        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
373      If time_major == True, this must be a tensor of shape: `[max_time,
374        batch_size, ...]`, or a nested tuple of such elements.
375    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
376      containing the actual lengths for each of the sequences in the batch. If
377      not provided, all batch entries are assumed to be full sequences; and time
378      reversal is applied from time `0` to `max_time` for each sequence.
379    initial_state_fw: (optional) An initial state for the forward RNN. This must
380      be a tensor of appropriate type and shape `[batch_size,
381      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
382      tuple of tensors having shapes `[batch_size, s] for s in
383      cell_fw.state_size`.
384    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
385      corresponding properties of `cell_bw`.
386    dtype: (optional) The data type for the initial states and expected output.
387      Required if initial_states are not provided or RNN states have a
388      heterogeneous dtype.
389    parallel_iterations: (Default: 32).  The number of iterations to run in
390      parallel.  Those operations which do not have any temporal dependency and
391      can be run in parallel, will be.  This parameter trades off time for
392      space.  Values >> 1 use more memory but take less time, while smaller
393      values use less memory but computations take longer.
394    swap_memory: Transparently swap the tensors produced in forward inference
395      but needed for back prop from GPU to CPU.  This allows training RNNs which
396      would typically not fit on a single GPU, with very minimal (or no)
397      performance penalty.
398    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
399      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
400      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
401      `time_major = True` is a bit more efficient because it avoids transposes
402      at the beginning and end of the RNN calculation.  However, most TensorFlow
403      data is batch-major, so by default this function accepts input and emits
404      output in batch-major form.
405    scope: VariableScope for the created subgraph; defaults to
406      "bidirectional_rnn"
407
408  Returns:
409    A tuple (outputs, output_states) where:
410      outputs: A tuple (output_fw, output_bw) containing the forward and
411        the backward rnn output `Tensor`.
412        If time_major == False (default),
413          output_fw will be a `Tensor` shaped:
414          `[batch_size, max_time, cell_fw.output_size]`
415          and output_bw will be a `Tensor` shaped:
416          `[batch_size, max_time, cell_bw.output_size]`.
417        If time_major == True,
418          output_fw will be a `Tensor` shaped:
419          `[max_time, batch_size, cell_fw.output_size]`
420          and output_bw will be a `Tensor` shaped:
421          `[max_time, batch_size, cell_bw.output_size]`.
        A tuple is returned instead of a single concatenated `Tensor`, unlike
        in `bidirectional_rnn`. If a single concatenated `Tensor` is
        preferred, the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
426      output_states: A tuple (output_state_fw, output_state_bw) containing
427        the forward and the backward final states of bidirectional rnn.
428
429  Raises:
430    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
431  """
432  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
433  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
434
435  with vs.variable_scope(scope or "bidirectional_rnn"):
436    # Forward direction
437    with vs.variable_scope("fw") as fw_scope:
438      output_fw, output_state_fw = dynamic_rnn(
439          cell=cell_fw,
440          inputs=inputs,
441          sequence_length=sequence_length,
442          initial_state=initial_state_fw,
443          dtype=dtype,
444          parallel_iterations=parallel_iterations,
445          swap_memory=swap_memory,
446          time_major=time_major,
447          scope=fw_scope)
448
449    # Backward direction
450    if not time_major:
451      time_axis = 1
452      batch_axis = 0
453    else:
454      time_axis = 0
455      batch_axis = 1
456
457    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
458      if seq_lengths is not None:
459        return array_ops.reverse_sequence(
460            input=input_,
461            seq_lengths=seq_lengths,
462            seq_axis=seq_axis,
463            batch_axis=batch_axis)
464      else:
465        return array_ops.reverse(input_, axis=[seq_axis])
466
467    with vs.variable_scope("bw") as bw_scope:
468
469      def _map_reverse(inp):
470        return _reverse(
471            inp,
472            seq_lengths=sequence_length,
473            seq_axis=time_axis,
474            batch_axis=batch_axis)
475
476      inputs_reverse = nest.map_structure(_map_reverse, inputs)
477      tmp, output_state_bw = dynamic_rnn(
478          cell=cell_bw,
479          inputs=inputs_reverse,
480          sequence_length=sequence_length,
481          initial_state=initial_state_bw,
482          dtype=dtype,
483          parallel_iterations=parallel_iterations,
484          swap_memory=swap_memory,
485          time_major=time_major,
486          scope=bw_scope)
487
488  output_bw = _reverse(
489      tmp,
490      seq_lengths=sequence_length,
491      seq_axis=time_axis,
492      batch_axis=batch_axis)
493
494  outputs = (output_fw, output_bw)
495  output_states = (output_state_fw, output_state_bw)
496
497  return (outputs, output_states)
498
499
500@deprecation.deprecated(
501    None,
502    "Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
503@tf_export(v1=["nn.dynamic_rnn"])
504@dispatch.add_dispatch_support
505def dynamic_rnn(cell,
506                inputs,
507                sequence_length=None,
508                initial_state=None,
509                dtype=None,
510                parallel_iterations=None,
511                swap_memory=False,
512                time_major=False,
513                scope=None):
514  """Creates a recurrent neural network specified by RNNCell `cell`.
515
516  Performs fully dynamic unrolling of `inputs`.
517
518  Example:
519
520  ```python
521  # create a BasicRNNCell
522  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
523
524  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
525
526  # defining initial state
527  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
528
529  # 'state' is a tensor of shape [batch_size, cell_state_size]
530  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
531                                     initial_state=initial_state,
532                                     dtype=tf.float32)
533  ```
534
535  ```python
536  # create 2 LSTMCells
537  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
538
539  # create a RNN cell composed sequentially of a number of RNNCells
540  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
541
542  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is an N-tuple, where N is the number of LSTMCells, containing a
  # tf.nn.rnn_cell.LSTMStateTuple for each cell
545  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
546                                     inputs=data,
547                                     dtype=tf.float32)
548  ```
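
  A sketch of handling variable-length sequences (reusing `rnn_cell` and
  `input_data` from the first example; the placeholder is illustrative).
  Outputs past each entry's length are zeroed, and `state` is the state at
  the last valid step:

  ```python
  # 'sequence_length' holds the valid length of each batch entry
  sequence_length = tf.compat.v1.placeholder(tf.int32, [batch_size])

  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
                                     sequence_length=sequence_length,
                                     dtype=tf.float32)
  ```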
549
550
551  Args:
552    cell: An instance of RNNCell.
553    inputs: The RNN inputs.
554      If `time_major == False` (default), this must be a `Tensor` of shape:
555        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
556      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
557        batch_size, ...]`, or a nested tuple of such elements. This may also be
558        a (possibly nested) tuple of Tensors satisfying this property.  The
559        first two dimensions must match across all the inputs, but otherwise the
560        ranks and other shape components may differ. In this case, input to
561        `cell` at each time-step will replicate the structure of these tuples,
562        except for the time dimension (from which the time is taken). The input
563        to `cell` at each time step will be a `Tensor` or (possibly nested)
564        tuple of Tensors each with dimensions `[batch_size, ...]`.
565    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
566      to copy-through state and zero-out outputs when past a batch element's
567      sequence length.  This parameter enables users to extract the last valid
568      state and properly padded outputs, so it is provided for correctness.
569    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
570      is an integer, this must be a `Tensor` of appropriate type and shape
571      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
572      should be a tuple of tensors having shapes `[batch_size, s] for s in
573      cell.state_size`.
574    dtype: (optional) The data type for the initial state and expected output.
575      Required if initial_state is not provided or RNN state has a heterogeneous
576      dtype.
577    parallel_iterations: (Default: 32).  The number of iterations to run in
578      parallel.  Those operations which do not have any temporal dependency and
579      can be run in parallel, will be.  This parameter trades off time for
580      space.  Values >> 1 use more memory but take less time, while smaller
581      values use less memory but computations take longer.
582    swap_memory: Transparently swap the tensors produced in forward inference
583      but needed for back prop from GPU to CPU.  This allows training RNNs which
584      would typically not fit on a single GPU, with very minimal (or no)
585      performance penalty.
586    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
587      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
588      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
589      `time_major = True` is a bit more efficient because it avoids transposes
590      at the beginning and end of the RNN calculation.  However, most TensorFlow
591      data is batch-major, so by default this function accepts input and emits
592      output in batch-major form.
593    scope: VariableScope for the created subgraph; defaults to "rnn".
594
595  Returns:
596    A pair (outputs, state) where:
597
598    outputs: The RNN output `Tensor`.
599
600      If time_major == False (default), this will be a `Tensor` shaped:
601        `[batch_size, max_time, cell.output_size]`.
602
603      If time_major == True, this will be a `Tensor` shaped:
604        `[max_time, batch_size, cell.output_size]`.
605
606      Note, if `cell.output_size` is a (possibly nested) tuple of integers
607      or `TensorShape` objects, then `outputs` will be a tuple having the
608      same structure as `cell.output_size`, containing Tensors having shapes
609      corresponding to the shape data in `cell.output_size`.
610
611    state: The final state.  If `cell.state_size` is an int, this
612      will be shaped `[batch_size, cell.state_size]`.  If it is a
613      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
614      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
615      be a tuple having the corresponding shapes. If cells are `LSTMCells`
616      `state` will be a tuple containing a `LSTMStateTuple` for each cell.
617
618  Raises:
619    TypeError: If `cell` is not an instance of RNNCell.
620    ValueError: If inputs is None or an empty list.
621  """
622  rnn_cell_impl.assert_like_rnncell("cell", cell)
623
624  with vs.variable_scope(scope or "rnn") as varscope:
625    # Create a new scope in which the caching device is either
626    # determined by the parent scope, or is set to place the cached
627    # Variable using the same placement as for the rest of the RNN.
628    if _should_cache():
629      if varscope.caching_device is None:
630        varscope.set_caching_device(lambda op: op.device)
631
632    # By default, time_major==False and inputs are batch-major: shaped
633    #   [batch, time, depth]
634    # For internal calculations, we transpose to [time, batch, depth]
635    flat_input = nest.flatten(inputs)
636
637    if not time_major:
638      # (B,T,D) => (T,B,D)
639      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
640      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
641
642    parallel_iterations = parallel_iterations or 32
643    if sequence_length is not None:
644      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
645      if sequence_length.get_shape().rank not in (None, 1):
646        raise ValueError(
647            "sequence_length must be a vector of length batch_size, "
648            "but saw shape: %s" % sequence_length.get_shape())
649      sequence_length = array_ops.identity(  # Just to find it in the graph.
650          sequence_length,
651          name="sequence_length")
652
653    batch_size = _best_effort_input_batch_size(flat_input)
654
655    if initial_state is not None:
656      state = initial_state
657    else:
658      if not dtype:
659        raise ValueError("If there is no initial_state, you must give a dtype.")
660      if getattr(cell, "get_initial_state", None) is not None:
661        state = cell.get_initial_state(
662            inputs=None, batch_size=batch_size, dtype=dtype)
663      else:
664        state = cell.zero_state(batch_size, dtype)
665
666    def _assert_has_shape(x, shape):
667      x_shape = array_ops.shape(x)
668      packed_shape = array_ops.stack(shape)
669      return control_flow_ops.Assert(
670          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
671              "Expected shape for Tensor %s is " % x.name, packed_shape,
672              " but saw shape: ", x_shape
673          ])
674
675    if not context.executing_eagerly() and sequence_length is not None:
676      # Perform some shape validation
677      with ops.control_dependencies(
678          [_assert_has_shape(sequence_length, [batch_size])]):
679        sequence_length = array_ops.identity(
680            sequence_length, name="CheckSeqLen")
681
682    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
683
684    (outputs, final_state) = _dynamic_rnn_loop(
685        cell,
686        inputs,
687        state,
688        parallel_iterations=parallel_iterations,
689        swap_memory=swap_memory,
690        sequence_length=sequence_length,
691        dtype=dtype)
692
693    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
694    # If we are performing batch-major calculations, transpose output back
695    # to shape [batch, time, depth]
696    if not time_major:
697      # (T,B,D) => (B,T,D)
698      outputs = nest.map_structure(_transpose_batch_time, outputs)
699
700    return (outputs, final_state)
701
702
703def _dynamic_rnn_loop(cell,
704                      inputs,
705                      initial_state,
706                      parallel_iterations,
707                      swap_memory,
708                      sequence_length=None,
709                      dtype=None):
710  """Internal implementation of Dynamic RNN.
711
712  Args:
713    cell: An instance of RNNCell.
714    inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
715      tuple of such elements.
716    initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
717      `cell.state_size` is a tuple, then this should be a tuple of tensors
718      having shapes `[batch_size, s] for s in cell.state_size`.
719    parallel_iterations: Positive Python int.
720    swap_memory: A Python boolean
721    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
722    dtype: (optional) Expected dtype of output. If not specified, inferred from
723      initial_state.
724
725  Returns:
726    Tuple `(final_outputs, final_state)`.
727    final_outputs:
728      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
729      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
730      objects, then this returns a (possibly nested) tuple of Tensors matching
731      the corresponding shapes.
732    final_state:
733      A `Tensor`, or possibly nested tuple of Tensors, matching in length
734      and shapes to `initial_state`.
735
736  Raises:
737    ValueError: If the input depth cannot be inferred via shape inference
738      from the inputs.
739    ValueError: If time_step is not the same for all the elements in the
740      inputs.
741    ValueError: If batch_size is not the same for all the elements in the
742      inputs.
743  """
744  state = initial_state
745  assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
746
747  state_size = cell.state_size
748
749  flat_input = nest.flatten(inputs)
750  flat_output_size = nest.flatten(cell.output_size)
751
752  # Construct an initial output
753  input_shape = array_ops.shape(flat_input[0])
754  time_steps = input_shape[0]
755  batch_size = _best_effort_input_batch_size(flat_input)
756
757  inputs_got_shape = tuple(
758      input_.get_shape().with_rank_at_least(3) for input_ in flat_input)
759
760  const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]
761
762  for shape in inputs_got_shape:
763    if not shape[2:].is_fully_defined():
764      raise ValueError(
765          "Input size (depth of inputs) must be accessible via shape inference,"
766          " but saw value None.")
767    got_time_steps = shape.dims[0].value
768    got_batch_size = shape.dims[1].value
    if const_time_steps != got_time_steps:
      raise ValueError(
          "The number of time steps is not the same for all the elements in "
          "the input batch.")
    if const_batch_size != got_batch_size:
      raise ValueError(
          "The batch size is not the same for all the elements in the input "
          "batch.")
776
777  # Prepare dynamic conditional copying of state & output
778  def _create_zero_arrays(size):
779    size = _concat(batch_size, size)
780    return array_ops.zeros(
781        array_ops.stack(size), _infer_state_dtype(dtype, state))
782
783  flat_zero_output = tuple(
784      _create_zero_arrays(output) for output in flat_output_size)
785  zero_output = nest.pack_sequence_as(
786      structure=cell.output_size, flat_sequence=flat_zero_output)
787
788  if sequence_length is not None:
789    min_sequence_length = math_ops.reduce_min(sequence_length)
790    max_sequence_length = math_ops.reduce_max(sequence_length)
791  else:
792    max_sequence_length = time_steps
793
794  time = array_ops.constant(0, dtype=dtypes.int32, name="time")
795
796  with ops.name_scope("dynamic_rnn") as scope:
797    base_name = scope
798
799  def _create_ta(name, element_shape, dtype):
800    return tensor_array_ops.TensorArray(
801        dtype=dtype,
802        size=time_steps,
803        element_shape=element_shape,
804        tensor_array_name=base_name + name)
805
806  in_graph_mode = not context.executing_eagerly()
807  if in_graph_mode:
808    output_ta = tuple(
809        _create_ta(
810            "output_%d" % i,
811            element_shape=(
812                tensor_shape.TensorShape([const_batch_size]).concatenate(
813                    _maybe_tensor_shape_from_tensor(out_size))),
814            dtype=_infer_state_dtype(dtype, state))
815        for i, out_size in enumerate(flat_output_size))
816    input_ta = tuple(
817        _create_ta(
818            "input_%d" % i,
819            element_shape=flat_input_i.shape[1:],
820            dtype=flat_input_i.dtype)
821        for i, flat_input_i in enumerate(flat_input))
822    input_ta = tuple(
823        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))
824  else:
825    output_ta = tuple([0 for _ in range(time_steps.numpy())]
826                      for i in range(len(flat_output_size)))
827    input_ta = flat_input
828
829  def _time_step(time, output_ta_t, state):
830    """Take a time step of the dynamic RNN.
831
832    Args:
833      time: int32 scalar Tensor.
834      output_ta_t: List of `TensorArray`s that represent the output.
835      state: nested tuple of vector tensors that represent the state.
836
837    Returns:
838      The tuple (time + 1, output_ta_t with updated flow, new_state).
839    """
840
841    if in_graph_mode:
842      input_t = tuple(ta.read(time) for ta in input_ta)
843      # Restore some shape information
844      for input_, shape in zip(input_t, inputs_got_shape):
845        input_.set_shape(shape[1:])
846    else:
847      input_t = tuple(ta[time.numpy()] for ta in input_ta)
848
849    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
850    # Keras RNN cells only accept state as list, even if it's a single tensor.
851    call_cell = lambda: cell(input_t, state)
852
853    if sequence_length is not None:
854      (output, new_state) = _rnn_step(
855          time=time,
856          sequence_length=sequence_length,
857          min_sequence_length=min_sequence_length,
858          max_sequence_length=max_sequence_length,
859          zero_output=zero_output,
860          state=state,
861          call_cell=call_cell,
862          state_size=state_size,
863          skip_conditionals=True)
864    else:
865      (output, new_state) = call_cell()
866
867    # Pack state if using state tuples
868    output = nest.flatten(output)
869
870    if in_graph_mode:
871      output_ta_t = tuple(
872          ta.write(time, out) for ta, out in zip(output_ta_t, output))
873    else:
874      for ta, out in zip(output_ta_t, output):
875        ta[time.numpy()] = out
876
877    return (time + 1, output_ta_t, new_state)
878
879  if in_graph_mode:
880    # Make sure that we run at least 1 step, if necessary, to ensure
881    # the TensorArrays pick up the dynamic shape.
882    loop_bound = math_ops.minimum(time_steps,
883                                  math_ops.maximum(1, max_sequence_length))
884  else:
885    # Using max_sequence_length isn't currently supported in the Eager branch.
886    loop_bound = time_steps
887
888  _, output_final_ta, final_state = control_flow_ops.while_loop(
889      cond=lambda time, *_: time < loop_bound,
890      body=_time_step,
891      loop_vars=(time, output_ta, state),
892      parallel_iterations=parallel_iterations,
893      maximum_iterations=time_steps,
894      swap_memory=swap_memory)
895
896  # Unpack final output if not using output tuples.
897  if in_graph_mode:
898    final_outputs = tuple(ta.stack() for ta in output_final_ta)
899    # Restore some shape information
900    for output, output_size in zip(final_outputs, flat_output_size):
901      shape = _concat([const_time_steps, const_batch_size],
902                      output_size,
903                      static=True)
904      output.set_shape(shape)
905  else:
906    final_outputs = output_final_ta
907
908  final_outputs = nest.pack_sequence_as(
909      structure=cell.output_size, flat_sequence=final_outputs)
910  if not in_graph_mode:
911    final_outputs = nest.map_structure_up_to(
912        cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
913
914  return (final_outputs, final_state)
915
916
917@tf_export(v1=["nn.raw_rnn"])
918@dispatch.add_dispatch_support
919def raw_rnn(cell,
920            loop_fn,
921            parallel_iterations=None,
922            swap_memory=False,
923            scope=None):
924  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.
925
926  **NOTE: This method is still in testing, and the API may change.**
927
928  This function is a more primitive version of `dynamic_rnn` that provides
929  more direct access to the inputs each iteration.  It also provides more
930  control over when to start and finish reading the sequence, and
931  what to emit for the output.
932
933  For example, it can be used to implement the dynamic decoder of a seq2seq
934  model.
935
936  Instead of working with `Tensor` objects, most operations work with
937  `TensorArray` objects directly.
938
939  The operation of `raw_rnn`, in pseudo-code, is basically the following:
940
941  ```python
942  time = tf.constant(0, dtype=tf.int32)
943  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
944      time=time, cell_output=None, cell_state=None, loop_state=None)
945  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
946  state = initial_state
947  while not all(finished):
948    (output, cell_state) = cell(next_input, state)
949    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
950        time=time + 1, cell_output=output, cell_state=cell_state,
951        loop_state=loop_state)
952    # Emit zeros and copy forward state for minibatch entries that are finished.
953    state = tf.where(finished, state, next_state)
954    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
955    emit_ta = emit_ta.write(time, emit)
956    # If any new minibatch entries are marked as finished, mark these.
957    finished = tf.logical_or(finished, next_finished)
958    time += 1
959  return (emit_ta, state, loop_state)
960  ```
961
962  with the additional properties that output and state may be (possibly nested)
963  tuples, as determined by `cell.output_size` and `cell.state_size`, and
964  as a result the final `state` and `emit_ta` may themselves be tuples.
965
966  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:
967
968  ```python
969  inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
970                          dtype=tf.float32)
  sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
                                             dtype=tf.int32)
973  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
974  inputs_ta = inputs_ta.unstack(inputs)
975
976  cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
977
978  def loop_fn(time, cell_output, cell_state, loop_state):
979    emit_output = cell_output  # == None for time == 0
980    if cell_output is None:  # time == 0
981      next_cell_state = cell.zero_state(batch_size, tf.float32)
982    else:
983      next_cell_state = cell_state
984    elements_finished = (time >= sequence_length)
985    finished = tf.reduce_all(elements_finished)
986    next_input = tf.cond(
987        finished,
988        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
989        lambda: inputs_ta.read(time))
990    next_loop_state = None
991    return (elements_finished, next_input, next_cell_state,
992            emit_output, next_loop_state)
993
994  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
995  outputs = outputs_ta.stack()
996  ```
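
  The `loop_state` can carry auxiliary bookkeeping through the loop. A hedged
  variant of the `loop_fn` above that counts the number of completed steps
  (purely illustrative; `raw_rnn` itself never inspects the value):

  ```python
  def loop_fn_with_counter(time, cell_output, cell_state, loop_state):
    if cell_output is None:  # time == 0
      next_cell_state = cell.zero_state(batch_size, tf.float32)
      next_loop_state = tf.zeros([], dtype=tf.int32)  # steps taken so far
    else:
      next_cell_state = cell_state
      next_loop_state = loop_state + 1
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    return (elements_finished, next_input, next_cell_state,
            cell_output, next_loop_state)

  outputs_ta, final_state, steps_taken = raw_rnn(cell, loop_fn_with_counter)
  ```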
997
998  Args:
999    cell: An instance of RNNCell.
1000    loop_fn: A callable that takes inputs `(time, cell_output, cell_state,
1001      loop_state)` and returns the tuple `(finished, next_input,
1002      next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32
1003      scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of
1004      tensors as determined by `cell.output_size`, and `cell_state` is a
1005      `Tensor` or (possibly nested) tuple of tensors, as determined by the
1006      `loop_fn` on its first call (and should match `cell.state_size`).
1007      The outputs are: `finished`, a boolean `Tensor` of
1008      shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
1009      `next_cell_state`: the next state to feed to `cell`,
1010      and `emit_output`: the output to store for this iteration.  Note that
1011        `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors
1012        which is aggregated in the `emit_ta` inside the `while_loop`. For the
1013        first call to `loop_fn`, the `emit_output` corresponds to the
1014        `emit_structure` which is then used to determine the size of the
1015        `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For
1016        the subsequent calls to the `loop_fn`, the `emit_output` corresponds to
1017        the actual output tensor that is to be aggregated in the `emit_ta`. The
1018        parameter `cell_state` and output `next_cell_state` may be either a
1019        single or (possibly nested) tuple of tensors.  The parameter
1020        `loop_state` and output `next_loop_state` may be either a single or
1021        (possibly nested) tuple of `Tensor` and `TensorArray` objects.  This
1022        last parameter may be ignored by `loop_fn` and the return value may be
1023        `None`.  If it is not `None`, then the `loop_state` will be propagated
1024        through the RNN loop, for use purely by `loop_fn` to keep track of its
1025        own state. The `next_loop_state` parameter returned may be `None`.  The
1026        first call to `loop_fn` will be `time = 0`, `cell_output = None`,
1027      `cell_state = None`, and `loop_state = None`.  For this call: The
1028        `next_cell_state` value should be the value with which to initialize the
1029        cell's state.  It may be a final state from a previous RNN or it may be
1030        the output of `cell.zero_state()`.  It should be a (possibly nested)
1031        tuple structure of tensors. If `cell.state_size` is an integer, this
1032        must be a `Tensor` of appropriate type and shape `[batch_size,
1033        cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be
1034        a `Tensor` of appropriate type and shape `[batch_size] +
1035        cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of
1036        ints or `TensorShape`, this will be a tuple having the corresponding
1037        shapes. The `emit_output` value may be either `None` or a (possibly
1038        nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0,
1039        dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first
1040        `emit_output` return value is `None`, then the `emit_ta` result of
1041        `raw_rnn` will have the same structure and dtypes as `cell.output_size`.
1042        Otherwise `emit_ta` will have the same structure, shapes (prepended with
1043        a `batch_size` dimension), and dtypes as `emit_output`.  The actual
1044        values returned for `emit_output` at this initializing call are ignored.
1045        Note, this emit structure must be consistent across all time steps.
1046    parallel_iterations: (Default: 32).  The number of iterations to run in
1047      parallel.  Those operations which do not have any temporal dependency and
1048      can be run in parallel, will be.  This parameter trades off time for
1049      space.  Values >> 1 use more memory but take less time, while smaller
1050      values use less memory but computations take longer.
1051    swap_memory: Transparently swap the tensors produced in forward inference
1052      but needed for back prop from GPU to CPU.  This allows training RNNs which
1053      would typically not fit on a single GPU, with very minimal (or no)
1054      performance penalty.
1055    scope: VariableScope for the created subgraph; defaults to "rnn".
1056
1057  Returns:
1058    A tuple `(emit_ta, final_state, final_loop_state)` where:
1059
1060    `emit_ta`: The RNN output `TensorArray`.
1061       If `loop_fn` returns a (possibly nested) set of Tensors for
1062       `emit_output` during initialization, (inputs `time = 0`,
1063       `cell_output = None`, and `loop_state = None`), then `emit_ta` will
1064       have the same structure, dtypes, and shapes as `emit_output` instead.
1065       If `loop_fn` returns `emit_output = None` during this call,
1066       the structure of `cell.output_size` is used:
1067       If `cell.output_size` is a (possibly nested) tuple of integers
1068       or `TensorShape` objects, then `emit_ta` will be a tuple having the
1069       same structure as `cell.output_size`, containing TensorArrays whose
1070       elements' shapes correspond to the shape data in `cell.output_size`.
1071
1072    `final_state`: The final cell state.  If `cell.state_size` is an int, this
1073      will be shaped `[batch_size, cell.state_size]`.  If it is a
1074      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
1075      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
1076      be a tuple having the corresponding shapes.
1077
1078    `final_loop_state`: The final loop state as returned by `loop_fn`.
1079
1080  Raises:
1081    TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
1082      a `callable`.
1083  """
1084  rnn_cell_impl.assert_like_rnncell("cell", cell)
1085
1086  if not callable(loop_fn):
1087    raise TypeError("loop_fn must be a callable")
1088
1089  parallel_iterations = parallel_iterations or 32
1090
1091  # Create a new scope in which the caching device is either
1092  # determined by the parent scope, or is set to place the cached
1093  # Variable using the same placement as for the rest of the RNN.
1094  with vs.variable_scope(scope or "rnn") as varscope:
1095    if _should_cache():
1096      if varscope.caching_device is None:
1097        varscope.set_caching_device(lambda op: op.device)
1098
1099    time = constant_op.constant(0, dtype=dtypes.int32)
1100    (elements_finished, next_input,
1101     initial_state, emit_structure, init_loop_state) = loop_fn(
1102         time, None, None, None)  # time, cell_output, cell_state, loop_state
1103    flat_input = nest.flatten(next_input)
1104
1105    # Need a surrogate loop state for the while_loop if none is available.
1106    loop_state = (
1107        init_loop_state if init_loop_state is not None else
1108        constant_op.constant(0, dtype=dtypes.int32))
1109
1110    input_shape = [input_.get_shape() for input_ in flat_input]
1111    static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0)
1112
1113    for input_shape_i in input_shape:
1114      # Static verification that batch sizes all match
1115      static_batch_size.assert_is_compatible_with(
1116          tensor_shape.dimension_at_index(input_shape_i, 0))
1117
1118    batch_size = tensor_shape.dimension_value(static_batch_size)
1119    const_batch_size = batch_size
1120    if batch_size is None:
1121      batch_size = array_ops.shape(flat_input[0])[0]
1122
1123    nest.assert_same_structure(initial_state, cell.state_size)
1124    state = initial_state
1125    flat_state = nest.flatten(state)
1126    flat_state = [ops.convert_to_tensor(s) for s in flat_state]
1127    state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)
1128
1129    if emit_structure is not None:
1130      flat_emit_structure = nest.flatten(emit_structure)
1131      flat_emit_size = [
1132          emit.shape if emit.shape.is_fully_defined() else array_ops.shape(emit)
1133          for emit in flat_emit_structure
1134      ]
1135      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
1136    else:
1137      emit_structure = cell.output_size
1138      flat_emit_size = nest.flatten(emit_structure)
1139      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
1140
1141    flat_emit_ta = [
1142        tensor_array_ops.TensorArray(
1143            dtype=dtype_i,
1144            dynamic_size=True,
1145            element_shape=(tensor_shape.TensorShape([
1146                const_batch_size
1147            ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
1148            size=0,
1149            name="rnn_output_%d" % i)
1150        for i, (dtype_i,
1151                size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
1152    ]
1153    emit_ta = nest.pack_sequence_as(
1154        structure=emit_structure, flat_sequence=flat_emit_ta)
1155    flat_zero_emit = [
1156        array_ops.zeros(_concat(batch_size, size_i), dtype_i)
1157        for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
1158    ]
1159    zero_emit = nest.pack_sequence_as(
1160        structure=emit_structure, flat_sequence=flat_zero_emit)
1161
1162    def condition(unused_time, elements_finished, *_):
1163      return math_ops.logical_not(math_ops.reduce_all(elements_finished))
1164
1165    def body(time, elements_finished, current_input, emit_ta, state,
1166             loop_state):
1167      """Internal while loop body for raw_rnn.
1168
1169      Args:
1170        time: time scalar.
1171        elements_finished: batch-size vector.
1172        current_input: possibly nested tuple of input tensors.
1173        emit_ta: possibly nested tuple of output TensorArrays.
1174        state: possibly nested tuple of state tensors.
1175        loop_state: possibly nested tuple of loop state tensors.
1176
1177      Returns:
1178        Tuple having the same size as Args but with updated values.
1179      """
1180      (next_output, cell_state) = cell(current_input, state)
1181
1182      nest.assert_same_structure(state, cell_state)
1183      nest.assert_same_structure(cell.output_size, next_output)
1184
1185      next_time = time + 1
1186      (next_finished, next_input, next_state, emit_output,
1187       next_loop_state) = loop_fn(next_time, next_output, cell_state,
1188                                  loop_state)
1189
1190      nest.assert_same_structure(state, next_state)
1191      nest.assert_same_structure(current_input, next_input)
1192      nest.assert_same_structure(emit_ta, emit_output)
1193
1194      # If loop_fn returns None for next_loop_state, just reuse the
1195      # previous one.
1196      loop_state = loop_state if next_loop_state is None else next_loop_state
1197
1198      def _copy_some_through(current, candidate):
1199        """Copy some tensors through via array_ops.where."""
1200
1201        def copy_fn(cur_i, cand_i):
1202          # TensorArray and scalar get passed through.
1203          if isinstance(cur_i, tensor_array_ops.TensorArray):
1204            return cand_i
1205          if cur_i.shape.rank == 0:
1206            return cand_i
1207          # Otherwise propagate the old or the new value.
1208          with ops.colocate_with(cand_i):
1209            return array_ops.where(elements_finished, cur_i, cand_i)
1210
1211        return nest.map_structure(copy_fn, current, candidate)
1212
1213      emit_output = _copy_some_through(zero_emit, emit_output)
1214      next_state = _copy_some_through(state, next_state)
1215
1216      emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
1217                                   emit_ta, emit_output)
1218
1219      elements_finished = math_ops.logical_or(elements_finished, next_finished)
1220
1221      return (next_time, elements_finished, next_input, emit_ta, next_state,
1222              loop_state)
1223
1224    returned = control_flow_ops.while_loop(
1225        condition,
1226        body,
1227        loop_vars=[
1228            time, elements_finished, next_input, emit_ta, state, loop_state
1229        ],
1230        parallel_iterations=parallel_iterations,
1231        swap_memory=swap_memory)
1232
1233    (emit_ta, final_state, final_loop_state) = returned[-3:]
1234
1235    if init_loop_state is None:
1236      final_loop_state = None
1237
1238    return (emit_ta, final_state, final_loop_state)
1239
1240
1241@deprecation.deprecated(None,
1242                        "Please use `keras.layers.RNN(cell, unroll=True)`, "
1243                        "which is equivalent to this API")
1244@tf_export(v1=["nn.static_rnn"])
1245@dispatch.add_dispatch_support
1246def static_rnn(cell,
1247               inputs,
1248               initial_state=None,
1249               dtype=None,
1250               sequence_length=None,
1251               scope=None):
1252  """Creates a recurrent neural network specified by RNNCell `cell`.
1253
1254  The simplest form of RNN network generated is:
1255
1256  ```python
1257    state = cell.zero_state(...)
1258    outputs = []
1259    for input_ in inputs:
1260      output, state = cell(input_, state)
1261      outputs.append(output)
1262    return (outputs, state)
1263  ```
1264  However, a few other options are available:
1265
1266  An initial state can be provided.
1267  If the sequence_length vector is provided, dynamic calculation is performed.
1268  This method of calculation does not compute the RNN steps past the maximum
1269  sequence length of the minibatch (thus saving computational time),
1270  and properly propagates the state at an example's sequence length
1271  to the final state output.
1272
1273  The dynamic calculation performed is, at time `t` for batch row `b`,
1274
1275  ```python
1276    (output, state)(b, t) =
1277      (t >= sequence_length(b))
1278        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
1279        : cell(input(b, t), state(b, t - 1))
1280  ```
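
  For example, a minimal sketch (the shapes and the `tf.unstack` step are
  illustrative):

  ```python
  cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units)

  # static_rnn expects a Python list of [batch_size, input_size] tensors,
  # one per time step.
  inputs = tf.unstack(input_data, num=max_time, axis=1)

  outputs, state = tf.compat.v1.nn.static_rnn(cell, inputs, dtype=tf.float32)
  ```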
1281
1282  Args:
1283    cell: An instance of RNNCell.
1284    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
1285      input_size]`, or a nested tuple of such elements.
1286    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
1287      is an integer, this must be a `Tensor` of appropriate type and shape
1288      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
1289      should be a tuple of tensors having shapes `[batch_size, s] for s in
1290      cell.state_size`.
1291    dtype: (optional) The data type for the initial state and expected output.
1292      Required if initial_state is not provided or RNN state has a heterogeneous
1293      dtype.
1294    sequence_length: Specifies the length of each sequence in inputs. An int32
1295      or int64 vector (tensor) of size `[batch_size]`, with values in `[0, T)`.
1296    scope: VariableScope for the created subgraph; defaults to "rnn".
1297
1298  Returns:
1299    A pair (outputs, state) where:
1300
1301    - outputs is a length T list of outputs (one for each input), or a nested
1302      tuple of such elements.
1303    - state is the final state.
1304
1305  Raises:
1306    TypeError: If `cell` is not an instance of RNNCell.
1307    ValueError: If `inputs` is `None` or an empty list, or if the input depth
1308      (column size) cannot be inferred from inputs via shape inference.
1309  """
1310  rnn_cell_impl.assert_like_rnncell("cell", cell)
1311  if not nest.is_sequence(inputs):
1312    raise TypeError("inputs must be a sequence")
1313  if not inputs:
1314    raise ValueError("inputs must not be empty")
1315
1316  outputs = []
1317  # Create a new scope in which the caching device is either
1318  # determined by the parent scope, or is set to place the cached
1319  # Variable using the same placement as for the rest of the RNN.
1320  with vs.variable_scope(scope or "rnn") as varscope:
1321    if _should_cache():
1322      if varscope.caching_device is None:
1323        varscope.set_caching_device(lambda op: op.device)
1324
1325    # Obtain the first input tensor from the (possibly nested) inputs.
1326    first_input = inputs
1327    while nest.is_sequence(first_input):
1328      first_input = first_input[0]
1329
1330    # Temporarily avoid EmbeddingWrapper and seq2seq badness
1331    # TODO(lukaszkaiser): remove EmbeddingWrapper
1332    if first_input.get_shape().rank != 1:
1333
1334      input_shape = first_input.get_shape().with_rank_at_least(2)
1335      fixed_batch_size = input_shape.dims[0]
1336
1337      flat_inputs = nest.flatten(inputs)
1338      for flat_input in flat_inputs:
1339        input_shape = flat_input.get_shape().with_rank_at_least(2)
1340        batch_size, input_size = tensor_shape.dimension_at_index(
1341            input_shape, 0), input_shape[1:]
1342        fixed_batch_size.assert_is_compatible_with(batch_size)
1343        for i, size in enumerate(input_size.dims):
1344          if tensor_shape.dimension_value(size) is None:
1345            raise ValueError(
1346                "Input size (dimension %d of inputs) must be accessible via "
1347                "shape inference, but saw value None." % i)
1348    else:
1349      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
1350
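    # Use the static batch size when it is known; otherwise fall back to the
    # dynamic batch dimension of the first input.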
1351    if tensor_shape.dimension_value(fixed_batch_size):
1352      batch_size = tensor_shape.dimension_value(fixed_batch_size)
1353    else:
1354      batch_size = array_ops.shape(first_input)[0]
1355    if initial_state is not None:
1356      state = initial_state
1357    else:
1358      if not dtype:
1359        raise ValueError("If no initial_state is provided, "
1360                         "dtype must be specified")
1361      if getattr(cell, "get_initial_state", None) is not None:
1362        state = cell.get_initial_state(
1363            inputs=None, batch_size=batch_size, dtype=dtype)
1364      else:
1365        state = cell.zero_state(batch_size, dtype)
1366
1367    if sequence_length is not None:  # Prepare variables
1368      sequence_length = ops.convert_to_tensor(
1369          sequence_length, name="sequence_length")
1370      if sequence_length.get_shape().rank not in (None, 1):
1371        raise ValueError(
1372            "sequence_length must be a vector of length batch_size")
1373
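      # Pre-compute an all-zero output with the cell's output structure;
      # _rnn_step emits it for batch rows whose sequence has already finished.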
1374      def _create_zero_output(output_size):
1375        # convert int to TensorShape if necessary
1376        size = _concat(batch_size, output_size)
1377        output = array_ops.zeros(
1378            array_ops.stack(size), _infer_state_dtype(dtype, state))
1379        shape = _concat(
1380            tensor_shape.dimension_value(fixed_batch_size),
1381            output_size,
1382            static=True)
1383        output.set_shape(tensor_shape.TensorShape(shape))
1384        return output
1385
1386      output_size = cell.output_size
1387      flat_output_size = nest.flatten(output_size)
1388      flat_zero_output = tuple(
1389          _create_zero_output(size) for size in flat_output_size)
1390      zero_output = nest.pack_sequence_as(
1391          structure=output_size, flat_sequence=flat_zero_output)
1392
1393      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
1394      min_sequence_length = math_ops.reduce_min(sequence_length)
1395      max_sequence_length = math_ops.reduce_max(sequence_length)
1396
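    # Unroll the RNN one step per input, reusing variables after the first
    # step. When sequence_length is given, _rnn_step emits zero_output and
    # copies the previous state through for rows that have already finished.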
1397    for time, input_ in enumerate(inputs):
1398      if time > 0:
1399        varscope.reuse_variables()
1400      # pylint: disable=cell-var-from-loop
1401      call_cell = lambda: cell(input_, state)
1402      # pylint: enable=cell-var-from-loop
1403      if sequence_length is not None:
1404        (output, state) = _rnn_step(
1405            time=time,
1406            sequence_length=sequence_length,
1407            min_sequence_length=min_sequence_length,
1408            max_sequence_length=max_sequence_length,
1409            zero_output=zero_output,
1410            state=state,
1411            call_cell=call_cell,
1412            state_size=cell.state_size)
1413      else:
1414        (output, state) = call_cell()
1415      outputs.append(output)
1416
1417    return (outputs, state)
1418
1419
1420@deprecation.deprecated(None,
1421                        "Please use `keras.layers.RNN(cell, stateful=True)`, "
1422                        "which is equivalent to this API")
1423@tf_export(v1=["nn.static_state_saving_rnn"])
1424@dispatch.add_dispatch_support
1425def static_state_saving_rnn(cell,
1426                            inputs,
1427                            state_saver,
1428                            state_name,
1429                            sequence_length=None,
1430                            scope=None):
1431  """RNN that accepts a state saver for time-truncated RNN calculation.
1432
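  A state saver is any object that exposes `state(name)` to read a saved state
  and `save_state(name, value)` to write it back. For illustration only, a
  minimal in-memory saver might look like the (hypothetical) sketch below:

  ```python
    class SimpleStateSaver(object):
      # Hypothetical in-memory saver: one tf.Variable per state name.

      def __init__(self, variables):
        self._vars = variables  # dict mapping state name -> tf.Variable

      def state(self, name):
        # Read the currently saved state for `name`.
        return self._vars[name].read_value()

      def save_state(self, name, value):
        # Return the op that writes `value` back; this function adds it as a
        # control dependency of the returned outputs and state.
        return self._vars[name].assign(value)
  ```
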
1433  Args:
1434    cell: An instance of `RNNCell`.
1435    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
1436      input_size]`.
1437    state_saver: A state saver object with methods `state` and `save_state`.
1438    state_name: Python string or tuple of strings.  The name to use with the
1439      state_saver. If the cell returns tuples of states (i.e., `cell.state_size`
1440      is a tuple) then `state_name` should be a tuple of strings having the same
1441      length as `cell.state_size`.  Otherwise it should be a single string.
1442    sequence_length: (optional) An int32/int64 vector of size `[batch_size]`. See
1443      the documentation for `static_rnn` for more details about sequence_length.
1444    scope: VariableScope for the created subgraph; defaults to "rnn".
1445
1446  Returns:
1447    A pair (outputs, state) where:
1448      outputs is a length T list of outputs (one for each input)
1449      state is the final state
1450
1451  Raises:
1452    TypeError: If `cell` is not an instance of RNNCell.
1453    ValueError: If `inputs` is `None` or an empty list, or if the arity and
1454      type of `state_name` do not match that of `cell.state_size`.
1455  """
1456  state_size = cell.state_size
1457  state_is_tuple = nest.is_sequence(state_size)
1458  state_name_tuple = nest.is_sequence(state_name)
1459
1460  if state_is_tuple != state_name_tuple:
1461    raise ValueError("state_name should be the same type as cell.state_size.  "
1462                     "state_name: %s, cell.state_size: %s" %
1463                     (str(state_name), str(state_size)))
1464
1465  if state_is_tuple:
1466    state_name_flat = nest.flatten(state_name)
1467    state_size_flat = nest.flatten(state_size)
1468
1469    if len(state_name_flat) != len(state_size_flat):
1470      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
1471                       (len(state_name_flat), len(state_size_flat)))
1472
1473    initial_state = nest.pack_sequence_as(
1474        structure=state_size,
1475        flat_sequence=[state_saver.state(s) for s in state_name_flat])
1476  else:
1477    initial_state = state_saver.state(state_name)
1478
1479  (outputs, state) = static_rnn(
1480      cell,
1481      inputs,
1482      initial_state=initial_state,
1483      sequence_length=sequence_length,
1484      scope=scope)
1485
1486  if state_is_tuple:
1487    flat_state = nest.flatten(state)
1488    state_name = nest.flatten(state_name)
1489    save_state = [
1490        state_saver.save_state(name, substate)
1491        for name, substate in zip(state_name, flat_state)
1492    ]
1493  else:
1494    save_state = [state_saver.save_state(state_name, state)]
1495
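  # Force the state-saving ops to run whenever the returned tensors are
  # evaluated by wrapping the last output and the final state in identity ops
  # under a control dependency on save_state.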
1496  with ops.control_dependencies(save_state):
1497    last_output = outputs[-1]
1498    flat_last_output = nest.flatten(last_output)
1499    flat_last_output = [
1500        array_ops.identity(output) for output in flat_last_output
1501    ]
1502    outputs[-1] = nest.pack_sequence_as(
1503        structure=last_output, flat_sequence=flat_last_output)
1504
1505    if state_is_tuple:
1506      state = nest.pack_sequence_as(
1507          structure=state,
1508          flat_sequence=[array_ops.identity(s) for s in flat_state])
1509    else:
1510      state = array_ops.identity(state)
1511
1512  return (outputs, state)
1513
1514
1515@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
1516                        "keras.layers.RNN(cell, unroll=True))`, which is "
1517                        "equivalent to this API")
1518@tf_export(v1=["nn.static_bidirectional_rnn"])
1519@dispatch.add_dispatch_support
1520def static_bidirectional_rnn(cell_fw,
1521                             cell_bw,
1522                             inputs,
1523                             initial_state_fw=None,
1524                             initial_state_bw=None,
1525                             dtype=None,
1526                             sequence_length=None,
1527                             scope=None):
1528  """Creates a bidirectional recurrent neural network.
1529
1530  Similar to the unidirectional case above (`static_rnn`), but builds
1531  independent forward and backward RNNs over the input. The final forward and
1532  backward outputs are depth-concatenated, so the output has the format
1533  [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
1534  the forward and backward cells must match. The initial state for both
1535  directions is zero by default (but can be set optionally), and no intermediate
1536  states are ever returned -- the network is fully unrolled for the given
1537  length(s) of the sequence(s), or completely unrolled if no lengths are given.
1538
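  For example, a minimal usage sketch (sizes and cell choice are illustrative;
  assumes TF1-style graph execution):

  ```python
    import tensorflow as tf

    T, batch_size, input_size, num_units = 5, 32, 8, 16  # illustrative sizes
    inputs = [tf.zeros([batch_size, input_size]) for _ in range(T)]
    cell_fw = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units)
    cell_bw = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units)
    outputs, state_fw, state_bw = tf.compat.v1.nn.static_bidirectional_rnn(
        cell_fw, cell_bw, inputs, dtype=tf.float32)
    # Each element of `outputs` has shape [batch_size, 2 * num_units].
  ```
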
1539  Args:
1540    cell_fw: An instance of RNNCell, to be used for forward direction.
1541    cell_bw: An instance of RNNCell, to be used for backward direction.
1542    inputs: A length T list of inputs, each a tensor of shape [batch_size,
1543      input_size], or a nested tuple of such elements.
1544    initial_state_fw: (optional) An initial state for the forward RNN. This must
1545      be a tensor of appropriate type and shape `[batch_size,
1546      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
1547      tuple of tensors having shapes `[batch_size, s] for s in
1548      cell_fw.state_size`.
1549    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
1550      corresponding properties of `cell_bw`.
1551    dtype: (optional) The data type for the initial state.  Required if either
1552      of the initial states are not provided.
1553    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
1554      containing the actual lengths for each of the sequences.
1555    scope: VariableScope for the created subgraph; defaults to
1556      "bidirectional_rnn"
1557
1558  Returns:
1559    A tuple (outputs, output_state_fw, output_state_bw) where:
1560      outputs is a length `T` list of outputs (one for each input), which
1561        are depth-concatenated forward and backward outputs.
1562      output_state_fw is the final state of the forward rnn.
1563      output_state_bw is the final state of the backward rnn.
1564
1565  Raises:
1566    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
1567    ValueError: If `inputs` is `None` or an empty list.
1568  """
1569  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
1570  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
1571  if not nest.is_sequence(inputs):
1572    raise TypeError("inputs must be a sequence")
1573  if not inputs:
1574    raise ValueError("inputs must not be empty")
1575
1576  with vs.variable_scope(scope or "bidirectional_rnn"):
1577    # Forward direction
1578    with vs.variable_scope("fw") as fw_scope:
1579      output_fw, output_state_fw = static_rnn(
1580          cell_fw,
1581          inputs,
1582          initial_state_fw,
1583          dtype,
1584          sequence_length,
1585          scope=fw_scope)
1586
1587    # Backward direction
1588    with vs.variable_scope("bw") as bw_scope:
1589      reversed_inputs = _reverse_seq(inputs, sequence_length)
1590      tmp, output_state_bw = static_rnn(
1591          cell_bw,
1592          reversed_inputs,
1593          initial_state_bw,
1594          dtype,
1595          sequence_length,
1596          scope=bw_scope)
1597
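  # Reverse the backward outputs back into the original time order so that the
  # forward and backward outputs line up per time step before concatenation.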
1598  output_bw = _reverse_seq(tmp, sequence_length)
1599  # Concat each of the forward/backward outputs
1600  flat_output_fw = nest.flatten(output_fw)
1601  flat_output_bw = nest.flatten(output_bw)
1602
1603  flat_outputs = tuple(
1604      array_ops.concat([fw, bw], 1)
1605      for fw, bw in zip(flat_output_fw, flat_output_bw))
1606
1607  outputs = nest.pack_sequence_as(
1608      structure=output_fw, flat_sequence=flat_outputs)
1609
1610  return (outputs, output_state_fw, output_state_bw)
1611