# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TfLite LSTMCell wrapper.

TODO(renjieliu): Find a better home for this one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.lite.python.op_hint as op_hint
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn import _best_effort_input_batch_size
from tensorflow.python.ops.rnn import _dynamic_rnn_loop
from tensorflow.python.ops.rnn import _should_cache
from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("lite.experimental.nn.dynamic_rnn")
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                     initial_state=initial_state,
                                     dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.contrib.rnn.LSTMStateTuple for each cell
  outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                     inputs=data,
                                     dtype=tf.float32)
  ```
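
  The examples above use the standard `tf.nn` cells with `tf.nn.dynamic_rnn`.
  A rough sketch of calling this TfLite wrapper directly is shown below; note
  that the input must be time-major, since only `time_major=True` is currently
  supported here, and control flow v2 must be enabled. `num_units` and
  `time_steps` are placeholder names.

  ```python
  # 'inputs' is a time-major tensor of shape [time_steps, batch_size, depth]
  lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units)

  # 'outputs' is a tensor of shape [time_steps, batch_size, num_units]
  # 'state' is an LSTMStateTuple holding the final cell and hidden state
  outputs, state = tf.lite.experimental.nn.dynamic_rnn(
      lstm_cell, inputs, dtype=tf.float32)
  ```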

  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False`, this must be a `Tensor` of shape:
      `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True` (the default), this must be a `Tensor` of shape:
      `[max_time, batch_size, ...]`, or a nested tuple of such elements.
      This may also be a (possibly nested) tuple of Tensors satisfying this
      property. The first two dimensions must match across all the inputs, but
      otherwise the ranks and other shape components may differ. In this case,
      input to `cell` at each time-step will replicate the structure of these
      tuples, except for the time dimension (from which the time is taken).
      The input to `cell` at each time step will be a `Tensor` or (possibly
      nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
      Used to copy-through state and zero-out outputs when past a batch
      element's sequence length. This is provided for performance rather than
      correctness.
    initial_state: (optional) An initial state for the RNN. If
      `cell.state_size` is an integer, this must be a `Tensor` of appropriate
      type and shape `[batch_size, cell.state_size]`. If `cell.state_size` is
      a tuple, this should be a tuple of tensors having shapes
      `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or the RNN state has a
      heterogeneous dtype.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency and
      can be run in parallel will be. This parameter trades off time for
      space. Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If
      true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. If
      false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation. Unlike
      `tf.nn.dynamic_rnn`, this wrapper defaults to `time_major = True`, and
      currently only that case is supported.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False, this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True (the default), this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state. If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`. If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`,
      `state` will be a tuple containing an `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

  # Currently, only the time_major == True case is supported.
  assert time_major

  # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
  # TfLiteRNNCells.
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

  parent_first_child_input = [{
      "parent_ophint_input_index": 0,
      "first_child_ophint_input_index": 0
  }]
  parent_last_child_output = [{
      "parent_output_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1, meaning it's the last one.
      "child_output_index": -1
  }]
  internal_children_input_output = [{
      "child_input_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1, meaning it's the last one.
      "child_output_index": -1
  }]
  inputs_outputs_mappings = {
      "parent_first_child_input": parent_first_child_input,
      "parent_last_child_output": parent_last_child_output,
      "internal_children_input_output": internal_children_input_output
  }
  tflite_wrapper = op_hint.OpHint(
      "TfLiteDynamicRnn",
      level=2,
      children_inputs_mappings=inputs_outputs_mappings)
  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)

    # If time_major == False, inputs are batch-major: shaped
    # [batch, time, depth].
    # For internal calculations, we transpose to [time, batch, depth].
    flat_input = nest.flatten(inputs)

    if not time_major:
      # (batch, time, depth) => (time, batch, depth)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      if sequence_length.shape.rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.shape)
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    outputs, final_state = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth].
    if not time_major:
      # (time, batch, depth) => (batch, time, depth)
      outputs = nest.map_structure(_transpose_batch_time, outputs)
    outputs = tflite_wrapper.add_output(outputs, name="outputs")

    return outputs, final_state


def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
  """Creates a dynamic version of a bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of the forward and backward cells must match. The initial state for both
  directions is zero by default (but can be set optionally), and no
  intermediate states are ever returned -- the network is fully unrolled for
  the given (passed in) length(s) of the sequence(s) or completely unrolled if
  length(s) is not given.
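
  As a rough sketch (where `num_units` is a placeholder size, and the inputs
  are passed time-major because the underlying `dynamic_rnn` wrapper currently
  only supports `time_major=True`):

  ```python
  fw_cell = tf.nn.rnn_cell.LSTMCell(num_units)
  bw_cell = tf.nn.rnn_cell.LSTMCell(num_units)

  # 'inputs' is a time-major tensor of shape [time_steps, batch_size, depth]
  (output_fw, output_bw), _ = bidirectional_dynamic_rnn(
      fw_cell, bw_cell, inputs, dtype=tf.float32, time_major=True)

  # Concatenate along the feature axis if a single tensor is preferred.
  outputs = tf.concat([output_fw, output_bw], 2)
  ```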

  Args:
    cell_fw: An instance of RNNCell, to be used for the forward direction.
    cell_bw: An instance of RNNCell, to be used for the backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
      `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
      batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and
      time reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This
      must be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be
      a tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected
      output. Required if initial_states are not provided or the RNN states
      have a heterogeneous dtype.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency and
      can be run in parallel will be. This parameter trades off time for
      space. Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If
      true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. If
      false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation. However,
      most TensorFlow data is batch-major, so by default this function accepts
      input and emits output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn".

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in `bidirectional_rnn`. If the concatenated one is preferred, the
        forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of the bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_,
            seq_lengths=seq_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

  output_bw = _reverse(
      tmp,
      seq_lengths=sequence_length,
      seq_axis=time_axis,
      batch_axis=batch_axis)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)