# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TfLite LSTMCell wrapper.

TODO(renjieliu): Find a better home for this one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.lite.python.op_hint import OpHint
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn import _best_effort_input_batch_size
from tensorflow.python.ops.rnn import _dynamic_rnn_loop
from tensorflow.python.ops.rnn import _should_cache
from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["lite.experimental.nn.dynamic_rnn"])
@deprecation.deprecated(
    None, "Use `keras.layers.LSTM` instead.")
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
                                               initial_state=initial_state,
                                               dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create an RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is an N-tuple where N is the number of LSTMCells containing a
  # tf.nn.rnn_cell.LSTMStateTuple for each cell
  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
                                               inputs=data,
                                               dtype=tf.float32)
  ```
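
  The examples above use the generic `tf.compat.v1.nn.dynamic_rnn` API. Below
  is a minimal sketch of calling this TFLite wrapper directly; `lstm_size` and
  `inputs` (a time-major `[max_time, batch_size, depth]` tensor) are
  placeholder names, and `time_major=True` is passed explicitly because this
  wrapper currently supports only the time-major case:

  ```python
  # create an LSTMCell (placeholder size)
  lstm_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(lstm_size)

  # 'inputs' is a tensor of shape [max_time, batch_size, depth]
  # 'outputs' is a tensor of shape [max_time, batch_size, lstm_size]
  outputs, state = tf.lite.experimental.nn.dynamic_rnn(
      lstm_cell, inputs, dtype=tf.float32, time_major=True)
  ```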

  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False`, this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True` (default), this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`, or a nested tuple of such elements. This
        may also be a (possibly nested) tuple of Tensors satisfying this
        property. The first two dimensions must match across all the inputs,
        but otherwise the ranks and other shape components may differ. In this
        case, input to `cell` at each time-step will replicate the structure of
        these tuples, except for the time dimension (from which the time is
        taken). The input to `cell` at each time step will be a `Tensor` or
        (possibly nested) tuple of Tensors each with dimensions
        `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length. So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a
      heterogeneous dtype.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency and
      can be run in parallel, will be. This parameter trades off time for
      space. Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation. Note that, unlike the
      standard `dynamic_rnn`, this wrapper currently supports only
      `time_major = True`, which is also the default here.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False, this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True (default), this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state. If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`. If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

  # Currently only support time_major == True case.
  assert time_major

  # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
  # TfLiteRNNCells.
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

  parent_first_child_input = [{
      "parent_ophint_input_index": 0,
      "first_child_ophint_input_index": 0
  }]
  parent_last_child_output = [{
      "parent_output_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  internal_children_input_output = [{
      "child_input_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  inputs_outputs_mappings = {
      "parent_first_child_input": parent_first_child_input,
      "parent_last_child_output": parent_last_child_output,
      "internal_children_input_output": internal_children_input_output
  }
  tflite_wrapper = OpHint(
      "TfLiteDynamicRnn",
      level=2,
      children_inputs_mappings=inputs_outputs_mappings)
  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)

    # By default, time_major==False and inputs are batch-major: shaped
    #   [batch, time, depth]
    # For internal calculations, we transpose to [time, batch, depth]
    flat_input = nest.flatten(inputs)

    if not time_major:
      # (batch, time, depth) => (time, batch, depth)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      if sequence_length.shape.rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.shape)
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    outputs, final_state = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    if not time_major:
      # (time, batch, depth) => (batch, time, depth)
      outputs = nest.map_structure(_transpose_batch_time, outputs)
    outputs = tflite_wrapper.add_output(outputs, name="outputs")

    return outputs, final_state


def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
  """Creates a dynamic version of bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of forward and backward cell must match. The initial state for both
  directions is zero by default (but can be set optionally) and no intermediate
  states are ever returned -- the network is fully unrolled for the given
  (passed in) length(s) of the sequence(s) or completely unrolled if length(s)
  is not given.
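
  Example:

  A minimal sketch; `lstm_size` and `inputs` (a time-major
  `[max_time, batch_size, depth]` tensor) are placeholder names, and
  `time_major=True` is passed explicitly because the underlying `dynamic_rnn`
  wrapper currently supports only the time-major case:

  ```python
  # create one LSTMCell per direction (placeholder size)
  cell_fw = tf.compat.v1.nn.rnn_cell.LSTMCell(lstm_size)
  cell_bw = tf.compat.v1.nn.rnn_cell.LSTMCell(lstm_size)

  # 'output_fw' and 'output_bw' are tensors of shape
  # [max_time, batch_size, lstm_size]
  (output_fw, output_bw), (state_fw, state_bw) = bidirectional_dynamic_rnn(
      cell_fw, cell_bw, inputs, dtype=tf.float32, time_major=True)

  # concatenate both directions along the depth axis if a single
  # tensor is preferred
  outputs = tf.concat((output_fw, output_bw), 2)
  ```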

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency and
      can be run in parallel, will be. This parameter trades off time for
      space. Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation. However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_,
            seq_lengths=seq_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

  output_bw = _reverse(
      tmp,
      seq_lengths=sequence_length,
      seq_axis=time_axis,
      batch_axis=batch_axis)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)