# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TfLite LSTMCell wrapper.
16
17TODO(renjieliu): Find a better home for this one.
18"""
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23from tensorflow.lite.python.op_hint import OpHint
24from tensorflow.python.eager import context
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import control_flow_util
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import rnn_cell_impl
32from tensorflow.python.ops import variable_scope as vs
33from tensorflow.python.ops.rnn import _best_effort_input_batch_size
34from tensorflow.python.ops.rnn import _dynamic_rnn_loop
35from tensorflow.python.ops.rnn import _should_cache
36from tensorflow.python.ops.rnn import _transpose_batch_time
37from tensorflow.python.util import deprecation
38from tensorflow.python.util import nest
39from tensorflow.python.util.tf_export import tf_export
40
41
42@tf_export(v1=["lite.experimental.nn.dynamic_rnn"])
43@deprecation.deprecated(
44    None, "Use `keras.layers.LSTM` instead.")
45def dynamic_rnn(cell,
46                inputs,
47                sequence_length=None,
48                initial_state=None,
49                dtype=None,
50                parallel_iterations=None,
51                swap_memory=False,
52                time_major=True,
53                scope=None):
54  """Creates a recurrent neural network specified by RNNCell `cell`.
55
56  Performs fully dynamic unrolling of `inputs`.
57
58  Example:
59
60  ```python
61  # create a BasicRNNCell
62  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
63
64  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
65
66  # defining initial state
67  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
68
69  # 'state' is a tensor of shape [batch_size, cell_state_size]
70  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
71                                     initial_state=initial_state,
72                                     dtype=tf.float32)
73  ```
74
75  ```python
76  # create 2 LSTMCells
77  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
78
79  # create a RNN cell composed sequentially of a number of RNNCells
80  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
81
82  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
83  # 'state' is a N-tuple where N is the number of LSTMCells containing a
84  # tf.nn.rnn_cell.LSTMStateTuple for each cell
85  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
86                                     inputs=data,
87                                     dtype=tf.float32)
88  ```
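
  Unlike `tf.compat.v1.nn.dynamic_rnn`, this TFLite wrapper is time-major by
  default and requires control flow v2. A minimal sketch (the cell,
  `time_major_input`, and the shapes below are illustrative assumptions, not
  part of this API's contract):

  ```python
  tf.compat.v1.enable_control_flow_v2()  # The OpHint dynamic_rnn needs this.

  lstm_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=64)
  # 'time_major_input' is assumed to be shaped [max_time, batch_size, depth].
  outputs, state = tf.compat.v1.lite.experimental.nn.dynamic_rnn(
      lstm_cell, time_major_input, dtype=tf.float32)
  # 'outputs' is shaped [max_time, batch_size, 64].
  ```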


  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False`, this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True` (default), this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`, or a nested tuple of such elements. This
        may also be a (possibly nested) tuple of Tensors satisfying this
        property.  The first two dimensions must match across all the inputs,
        but otherwise the ranks and other shape components may differ. In this
        case, input to `cell` at each time-step will replicate the structure of
        these tuples, except for the time dimension (from which the time is
        taken). The input to `cell` at each time step will be a `Tensor` or
        (possibly nested) tuple of Tensors each with dimensions
        `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length.  So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  Unlike the standard
      `tf.nn.dynamic_rnn`, this wrapper defaults to `time_major = True` and
      currently only supports that case.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False, this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True (default), this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`,
      `state` will be a tuple containing an `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

  # Currently, only the time_major == True case is supported.
  assert time_major

  # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
  # TfLiteRNNCells.
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

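  # The children_inputs_mappings below describe, for the level-2 OpHint, how
  # the fused parent op and the chained child cell OpHints connect when the
  # hinted ops are stitched back together during TFLite conversion: the
  # parent's input feeds the first child's input, each child's last output
  # feeds the next child's input, and the last child's last output becomes the
  # parent's output.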
  parent_first_child_input = [{
      "parent_ophint_input_index": 0,
      "first_child_ophint_input_index": 0
  }]
  parent_last_child_output = [{
      "parent_output_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1, meaning the last output.
      "child_output_index": -1
  }]
  internal_children_input_output = [{
      "child_input_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1, meaning the last output.
      "child_output_index": -1
  }]
  inputs_outputs_mappings = {
      "parent_first_child_input": parent_first_child_input,
      "parent_last_child_output": parent_last_child_output,
      "internal_children_input_output": internal_children_input_output
  }
  tflite_wrapper = OpHint(
      "TfLiteDynamicRnn",
      level=2,
      children_inputs_mappings=inputs_outputs_mappings)
  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)

    # If time_major == False, inputs are batch-major: shaped
    #   [batch, time, depth]
    # For internal calculations, we transpose to [time, batch, depth].
    flat_input = nest.flatten(inputs)

    if not time_major:
      # (batch, time, depth) => (time, batch, depth)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      if sequence_length.shape.rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.shape)
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

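    # Use the caller-provided initial state if there is one; otherwise ask the
    # cell for a default (zero) state, which requires an explicit dtype.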
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    outputs, final_state = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    if not time_major:
      # (time, batch, depth) => (batch, time, depth)
      outputs = nest.map_structure(_transpose_batch_time, outputs)
    outputs = tflite_wrapper.add_output(outputs, name="outputs")

    return outputs, final_state


def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
  """Creates a dynamic version of bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of the forward and backward cells must match. The initial state for both
  directions is zero by default (but can be set optionally), and no
  intermediate states are ever returned -- the network is fully unrolled for
  the given (passed in) length(s) of the sequence(s) or completely unrolled if
  length(s) is not given.

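  For example, a minimal sketch (the cell sizes and the `data` tensor are
  illustrative assumptions; note that the wrapped `dynamic_rnn` currently only
  supports `time_major=True`, so time-major inputs are used here):

  ```python
  fw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(32)
  bw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(32)
  # 'data' is assumed to be shaped [max_time, batch_size, depth].
  (out_fw, out_bw), _ = bidirectional_dynamic_rnn(
      fw_cell, bw_cell, data, dtype=tf.float32, time_major=True)
  # Concatenate the forward and backward outputs along the feature axis.
  outputs = tf.concat((out_fw, out_bw), 2)
  ```
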
  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn".

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction
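    # The backward RNN consumes a time-reversed copy of the inputs; its outputs
    # are reversed back afterwards so both directions are aligned in time.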
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_,
            seq_lengths=seq_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

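  # Restore the original time order of the backward outputs.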
  output_bw = _reverse(
      tmp,
      seq_lengths=sequence_length,
      seq_axis=time_axis,
      batch_axis=batch_axis)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)