1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""RNN helpers for TensorFlow models.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.eager import context 21from tensorflow.python.framework import constant_op 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.framework import tensor_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import control_flow_util 29from tensorflow.python.ops import control_flow_util_v2 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import rnn_cell_impl 32from tensorflow.python.ops import tensor_array_ops 33from tensorflow.python.ops import variable_scope as vs 34from tensorflow.python.util import deprecation 35from tensorflow.python.util import dispatch 36from tensorflow.python.util import nest 37from tensorflow.python.util.tf_export import tf_export 38 39# pylint: disable=protected-access 40_concat = rnn_cell_impl._concat 41# pylint: enable=protected-access 42 43 44def _transpose_batch_time(x): 45 """Transposes the batch and time dimensions of a Tensor. 46 47 If the input tensor has rank < 2 it returns the original tensor. Retains as 48 much of the static shape information as possible. 49 50 Args: 51 x: A Tensor. 52 53 Returns: 54 x transposed along the first two dimensions. 55 """ 56 x_static_shape = x.get_shape() 57 if x_static_shape.rank is not None and x_static_shape.rank < 2: 58 return x 59 60 x_rank = array_ops.rank(x) 61 x_t = array_ops.transpose( 62 x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0)) 63 x_t.set_shape( 64 tensor_shape.TensorShape( 65 [x_static_shape.dims[1].value, 66 x_static_shape.dims[0].value]).concatenate(x_static_shape[2:])) 67 return x_t 68 69 70def _best_effort_input_batch_size(flat_input): 71 """Get static input batch size if available, with fallback to the dynamic one. 72 73 Args: 74 flat_input: An iterable of time major input Tensors of shape `[max_time, 75 batch_size, ...]`. All inputs should have compatible batch sizes. 76 77 Returns: 78 The batch size in Python integer if available, or a scalar Tensor otherwise. 79 80 Raises: 81 ValueError: if there is any input with an invalid shape. 82 """ 83 for input_ in flat_input: 84 shape = input_.shape 85 if shape.rank is None: 86 continue 87 if shape.rank < 2: 88 raise ValueError("Expected input tensor %s to have rank at least 2" % 89 input_) 90 batch_size = shape.dims[1].value 91 if batch_size is not None: 92 return batch_size 93 # Fallback to the dynamic batch size of the first input. 94 return array_ops.shape(flat_input[0])[1] 95 96 97def _infer_state_dtype(explicit_dtype, state): 98 """Infer the dtype of an RNN state. 99 100 Args: 101 explicit_dtype: explicitly declared dtype or None. 102 state: RNN's hidden state. Must be a Tensor or a nested iterable containing 103 Tensors. 104 105 Returns: 106 dtype: inferred dtype of hidden state. 107 108 Raises: 109 ValueError: if `state` has heterogeneous dtypes or is empty. 110 """ 111 if explicit_dtype is not None: 112 return explicit_dtype 113 elif nest.is_sequence(state): 114 inferred_dtypes = [element.dtype for element in nest.flatten(state)] 115 if not inferred_dtypes: 116 raise ValueError("Unable to infer dtype from empty state.") 117 all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes) 118 if not all_same: 119 raise ValueError( 120 "State has tensors of different inferred_dtypes. Unable to infer a " 121 "single representative dtype.") 122 return inferred_dtypes[0] 123 else: 124 return state.dtype 125 126 127def _maybe_tensor_shape_from_tensor(shape): 128 if isinstance(shape, ops.Tensor): 129 return tensor_shape.as_shape(tensor_util.constant_value(shape)) 130 else: 131 return shape 132 133 134def _should_cache(): 135 """Returns True if a default caching device should be set, otherwise False.""" 136 if context.executing_eagerly(): 137 return False 138 # Don't set a caching device when running in a loop, since it is possible that 139 # train steps could be wrapped in a tf.while_loop. In that scenario caching 140 # prevents forward computations in loop iterations from re-reading the 141 # updated weights. 142 graph = ops.get_default_graph() 143 ctxt = graph._get_control_flow_context() # pylint: disable=protected-access 144 in_v1_while_loop = ( 145 control_flow_util.GetContainingWhileContext(ctxt) is not None) 146 in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph) 147 return not in_v1_while_loop and not in_v2_while_loop 148 149 150# pylint: disable=unused-argument 151def _rnn_step(time, 152 sequence_length, 153 min_sequence_length, 154 max_sequence_length, 155 zero_output, 156 state, 157 call_cell, 158 state_size, 159 skip_conditionals=False): 160 """Calculate one step of a dynamic RNN minibatch. 161 162 Returns an (output, state) pair conditioned on `sequence_length`. 163 When skip_conditionals=False, the pseudocode is something like: 164 165 if t >= max_sequence_length: 166 return (zero_output, state) 167 if t < min_sequence_length: 168 return call_cell() 169 170 # Selectively output zeros or output, old state or new state depending 171 # on whether we've finished calculating each row. 172 new_output, new_state = call_cell() 173 final_output = np.vstack([ 174 zero_output if time >= sequence_length[r] else new_output_r 175 for r, new_output_r in enumerate(new_output) 176 ]) 177 final_state = np.vstack([ 178 state[r] if time >= sequence_length[r] else new_state_r 179 for r, new_state_r in enumerate(new_state) 180 ]) 181 return (final_output, final_state) 182 183 Args: 184 time: int32 `Tensor` scalar. 185 sequence_length: int32 `Tensor` vector of size [batch_size]. 186 min_sequence_length: int32 `Tensor` scalar, min of sequence_length. 187 max_sequence_length: int32 `Tensor` scalar, max of sequence_length. 188 zero_output: `Tensor` vector of shape [output_size]. 189 state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 190 or a list/tuple of such tensors. 191 call_cell: lambda returning tuple of (new_output, new_state) where 192 new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 193 new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 194 state_size: The `cell.state_size` associated with the state. 195 skip_conditionals: Python bool, whether to skip using the conditional 196 calculations. This is useful for `dynamic_rnn`, where the input tensor 197 matches `max_sequence_length`, and using conditionals just slows 198 everything down. 199 200 Returns: 201 A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 202 final_output is a `Tensor` matrix of shape [batch_size, output_size] 203 final_state is either a single `Tensor` matrix, or a tuple of such 204 matrices (matching length and shapes of input `state`). 205 206 Raises: 207 ValueError: If the cell returns a state tuple whose length does not match 208 that returned by `state_size`. 209 """ 210 211 # Convert state to a list for ease of use 212 flat_state = nest.flatten(state) 213 flat_zero_output = nest.flatten(zero_output) 214 215 # Vector describing which batch entries are finished. 216 copy_cond = time >= sequence_length 217 218 def _copy_one_through(output, new_output): 219 # TensorArray and scalar get passed through. 220 if isinstance(output, tensor_array_ops.TensorArray): 221 return new_output 222 if output.shape.rank == 0: 223 return new_output 224 # Otherwise propagate the old or the new value. 225 with ops.colocate_with(new_output): 226 return array_ops.where(copy_cond, output, new_output) 227 228 def _copy_some_through(flat_new_output, flat_new_state): 229 # Use broadcasting select to determine which values should get 230 # the previous state & zero output, and which values should get 231 # a calculated state & output. 232 flat_new_output = [ 233 _copy_one_through(zero_output, new_output) 234 for zero_output, new_output in zip(flat_zero_output, flat_new_output) 235 ] 236 flat_new_state = [ 237 _copy_one_through(state, new_state) 238 for state, new_state in zip(flat_state, flat_new_state) 239 ] 240 return flat_new_output + flat_new_state 241 242 def _maybe_copy_some_through(): 243 """Run RNN step. Pass through either no or some past state.""" 244 new_output, new_state = call_cell() 245 246 nest.assert_same_structure(zero_output, new_output) 247 nest.assert_same_structure(state, new_state) 248 249 flat_new_state = nest.flatten(new_state) 250 flat_new_output = nest.flatten(new_output) 251 return control_flow_ops.cond( 252 # if t < min_seq_len: calculate and return everything 253 time < min_sequence_length, 254 lambda: flat_new_output + flat_new_state, 255 # else copy some of it through 256 lambda: _copy_some_through(flat_new_output, flat_new_state)) 257 258 # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 259 # but benefits from removing cond() and its gradient. We should 260 # profile with and without this switch here. 261 if skip_conditionals: 262 # Instead of using conditionals, perform the selective copy at all time 263 # steps. This is faster when max_seq_len is equal to the number of unrolls 264 # (which is typical for dynamic_rnn). 265 new_output, new_state = call_cell() 266 nest.assert_same_structure(zero_output, new_output) 267 nest.assert_same_structure(state, new_state) 268 new_state = nest.flatten(new_state) 269 new_output = nest.flatten(new_output) 270 final_output_and_state = _copy_some_through(new_output, new_state) 271 else: 272 empty_update = lambda: flat_zero_output + flat_state 273 final_output_and_state = control_flow_ops.cond( 274 # if t >= max_seq_len: copy all state through, output zeros 275 time >= max_sequence_length, 276 empty_update, 277 # otherwise calculation is required: copy some or all of it through 278 _maybe_copy_some_through) 279 280 if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 281 raise ValueError("Internal error: state and output were not concatenated " 282 "correctly.") 283 final_output = final_output_and_state[:len(flat_zero_output)] 284 final_state = final_output_and_state[len(flat_zero_output):] 285 286 for output, flat_output in zip(final_output, flat_zero_output): 287 output.set_shape(flat_output.get_shape()) 288 for substate, flat_substate in zip(final_state, flat_state): 289 if not isinstance(substate, tensor_array_ops.TensorArray): 290 substate.set_shape(flat_substate.get_shape()) 291 292 final_output = nest.pack_sequence_as( 293 structure=zero_output, flat_sequence=final_output) 294 final_state = nest.pack_sequence_as( 295 structure=state, flat_sequence=final_state) 296 297 return final_output, final_state 298 299 300def _reverse_seq(input_seq, lengths): 301 """Reverse a list of Tensors up to specified lengths. 302 303 Args: 304 input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 305 or nested tuples of tensors. 306 lengths: A `Tensor` of dimension batch_size, containing lengths for each 307 sequence in the batch. If "None" is specified, simply reverses the list. 308 309 Returns: 310 time-reversed sequence 311 """ 312 if lengths is None: 313 return list(reversed(input_seq)) 314 315 flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 316 317 flat_results = [[] for _ in range(len(input_seq))] 318 for sequence in zip(*flat_input_seq): 319 input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank) 320 for input_ in sequence: 321 input_shape.assert_is_compatible_with(input_.get_shape()) 322 input_.set_shape(input_shape) 323 324 # Join into (time, batch_size, depth) 325 s_joined = array_ops.stack(sequence) 326 327 # Reverse along dimension 0 328 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 329 # Split again into list 330 result = array_ops.unstack(s_reversed) 331 for r, flat_result in zip(result, flat_results): 332 r.set_shape(input_shape) 333 flat_result.append(r) 334 335 results = [ 336 nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 337 for input_, flat_result in zip(input_seq, flat_results) 338 ] 339 return results 340 341 342@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 343 "keras.layers.RNN(cell))`, which is equivalent to " 344 "this API") 345@tf_export(v1=["nn.bidirectional_dynamic_rnn"]) 346@dispatch.add_dispatch_support 347def bidirectional_dynamic_rnn(cell_fw, 348 cell_bw, 349 inputs, 350 sequence_length=None, 351 initial_state_fw=None, 352 initial_state_bw=None, 353 dtype=None, 354 parallel_iterations=None, 355 swap_memory=False, 356 time_major=False, 357 scope=None): 358 """Creates a dynamic version of bidirectional recurrent neural network. 359 360 Takes input and builds independent forward and backward RNNs. The input_size 361 of forward and backward cell must match. The initial state for both directions 362 is zero by default (but can be set optionally) and no intermediate states are 363 ever returned -- the network is fully unrolled for the given (passed in) 364 length(s) of the sequence(s) or completely unrolled if length(s) is not 365 given. 366 367 Args: 368 cell_fw: An instance of RNNCell, to be used for forward direction. 369 cell_bw: An instance of RNNCell, to be used for backward direction. 370 inputs: The RNN inputs. 371 If time_major == False (default), this must be a tensor of shape: 372 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 373 If time_major == True, this must be a tensor of shape: `[max_time, 374 batch_size, ...]`, or a nested tuple of such elements. 375 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 376 containing the actual lengths for each of the sequences in the batch. If 377 not provided, all batch entries are assumed to be full sequences; and time 378 reversal is applied from time `0` to `max_time` for each sequence. 379 initial_state_fw: (optional) An initial state for the forward RNN. This must 380 be a tensor of appropriate type and shape `[batch_size, 381 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 382 tuple of tensors having shapes `[batch_size, s] for s in 383 cell_fw.state_size`. 384 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 385 corresponding properties of `cell_bw`. 386 dtype: (optional) The data type for the initial states and expected output. 387 Required if initial_states are not provided or RNN states have a 388 heterogeneous dtype. 389 parallel_iterations: (Default: 32). The number of iterations to run in 390 parallel. Those operations which do not have any temporal dependency and 391 can be run in parallel, will be. This parameter trades off time for 392 space. Values >> 1 use more memory but take less time, while smaller 393 values use less memory but computations take longer. 394 swap_memory: Transparently swap the tensors produced in forward inference 395 but needed for back prop from GPU to CPU. This allows training RNNs which 396 would typically not fit on a single GPU, with very minimal (or no) 397 performance penalty. 398 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 399 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 400 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 401 `time_major = True` is a bit more efficient because it avoids transposes 402 at the beginning and end of the RNN calculation. However, most TensorFlow 403 data is batch-major, so by default this function accepts input and emits 404 output in batch-major form. 405 scope: VariableScope for the created subgraph; defaults to 406 "bidirectional_rnn" 407 408 Returns: 409 A tuple (outputs, output_states) where: 410 outputs: A tuple (output_fw, output_bw) containing the forward and 411 the backward rnn output `Tensor`. 412 If time_major == False (default), 413 output_fw will be a `Tensor` shaped: 414 `[batch_size, max_time, cell_fw.output_size]` 415 and output_bw will be a `Tensor` shaped: 416 `[batch_size, max_time, cell_bw.output_size]`. 417 If time_major == True, 418 output_fw will be a `Tensor` shaped: 419 `[max_time, batch_size, cell_fw.output_size]` 420 and output_bw will be a `Tensor` shaped: 421 `[max_time, batch_size, cell_bw.output_size]`. 422 It returns a tuple instead of a single concatenated `Tensor`, unlike 423 in the `bidirectional_rnn`. If the concatenated one is preferred, 424 the forward and backward outputs can be concatenated as 425 `tf.concat(outputs, 2)`. 426 output_states: A tuple (output_state_fw, output_state_bw) containing 427 the forward and the backward final states of bidirectional rnn. 428 429 Raises: 430 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 431 """ 432 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 433 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 434 435 with vs.variable_scope(scope or "bidirectional_rnn"): 436 # Forward direction 437 with vs.variable_scope("fw") as fw_scope: 438 output_fw, output_state_fw = dynamic_rnn( 439 cell=cell_fw, 440 inputs=inputs, 441 sequence_length=sequence_length, 442 initial_state=initial_state_fw, 443 dtype=dtype, 444 parallel_iterations=parallel_iterations, 445 swap_memory=swap_memory, 446 time_major=time_major, 447 scope=fw_scope) 448 449 # Backward direction 450 if not time_major: 451 time_axis = 1 452 batch_axis = 0 453 else: 454 time_axis = 0 455 batch_axis = 1 456 457 def _reverse(input_, seq_lengths, seq_axis, batch_axis): 458 if seq_lengths is not None: 459 return array_ops.reverse_sequence( 460 input=input_, 461 seq_lengths=seq_lengths, 462 seq_axis=seq_axis, 463 batch_axis=batch_axis) 464 else: 465 return array_ops.reverse(input_, axis=[seq_axis]) 466 467 with vs.variable_scope("bw") as bw_scope: 468 469 def _map_reverse(inp): 470 return _reverse( 471 inp, 472 seq_lengths=sequence_length, 473 seq_axis=time_axis, 474 batch_axis=batch_axis) 475 476 inputs_reverse = nest.map_structure(_map_reverse, inputs) 477 tmp, output_state_bw = dynamic_rnn( 478 cell=cell_bw, 479 inputs=inputs_reverse, 480 sequence_length=sequence_length, 481 initial_state=initial_state_bw, 482 dtype=dtype, 483 parallel_iterations=parallel_iterations, 484 swap_memory=swap_memory, 485 time_major=time_major, 486 scope=bw_scope) 487 488 output_bw = _reverse( 489 tmp, 490 seq_lengths=sequence_length, 491 seq_axis=time_axis, 492 batch_axis=batch_axis) 493 494 outputs = (output_fw, output_bw) 495 output_states = (output_state_fw, output_state_bw) 496 497 return (outputs, output_states) 498 499 500@deprecation.deprecated( 501 None, 502 "Please use `keras.layers.RNN(cell)`, which is equivalent to this API") 503@tf_export(v1=["nn.dynamic_rnn"]) 504@dispatch.add_dispatch_support 505def dynamic_rnn(cell, 506 inputs, 507 sequence_length=None, 508 initial_state=None, 509 dtype=None, 510 parallel_iterations=None, 511 swap_memory=False, 512 time_major=False, 513 scope=None): 514 """Creates a recurrent neural network specified by RNNCell `cell`. 515 516 Performs fully dynamic unrolling of `inputs`. 517 518 Example: 519 520 ```python 521 # create a BasicRNNCell 522 rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) 523 524 # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 525 526 # defining initial state 527 initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 528 529 # 'state' is a tensor of shape [batch_size, cell_state_size] 530 outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data, 531 initial_state=initial_state, 532 dtype=tf.float32) 533 ``` 534 535 ```python 536 # create 2 LSTMCells 537 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 538 539 # create a RNN cell composed sequentially of a number of RNNCells 540 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 541 542 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 543 # 'state' is a N-tuple where N is the number of LSTMCells containing a 544 # tf.nn.rnn_cell.LSTMStateTuple for each cell 545 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 546 inputs=data, 547 dtype=tf.float32) 548 ``` 549 550 551 Args: 552 cell: An instance of RNNCell. 553 inputs: The RNN inputs. 554 If `time_major == False` (default), this must be a `Tensor` of shape: 555 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 556 If `time_major == True`, this must be a `Tensor` of shape: `[max_time, 557 batch_size, ...]`, or a nested tuple of such elements. This may also be 558 a (possibly nested) tuple of Tensors satisfying this property. The 559 first two dimensions must match across all the inputs, but otherwise the 560 ranks and other shape components may differ. In this case, input to 561 `cell` at each time-step will replicate the structure of these tuples, 562 except for the time dimension (from which the time is taken). The input 563 to `cell` at each time step will be a `Tensor` or (possibly nested) 564 tuple of Tensors each with dimensions `[batch_size, ...]`. 565 sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used 566 to copy-through state and zero-out outputs when past a batch element's 567 sequence length. This parameter enables users to extract the last valid 568 state and properly padded outputs, so it is provided for correctness. 569 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 570 is an integer, this must be a `Tensor` of appropriate type and shape 571 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 572 should be a tuple of tensors having shapes `[batch_size, s] for s in 573 cell.state_size`. 574 dtype: (optional) The data type for the initial state and expected output. 575 Required if initial_state is not provided or RNN state has a heterogeneous 576 dtype. 577 parallel_iterations: (Default: 32). The number of iterations to run in 578 parallel. Those operations which do not have any temporal dependency and 579 can be run in parallel, will be. This parameter trades off time for 580 space. Values >> 1 use more memory but take less time, while smaller 581 values use less memory but computations take longer. 582 swap_memory: Transparently swap the tensors produced in forward inference 583 but needed for back prop from GPU to CPU. This allows training RNNs which 584 would typically not fit on a single GPU, with very minimal (or no) 585 performance penalty. 586 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 587 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 588 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 589 `time_major = True` is a bit more efficient because it avoids transposes 590 at the beginning and end of the RNN calculation. However, most TensorFlow 591 data is batch-major, so by default this function accepts input and emits 592 output in batch-major form. 593 scope: VariableScope for the created subgraph; defaults to "rnn". 594 595 Returns: 596 A pair (outputs, state) where: 597 598 outputs: The RNN output `Tensor`. 599 600 If time_major == False (default), this will be a `Tensor` shaped: 601 `[batch_size, max_time, cell.output_size]`. 602 603 If time_major == True, this will be a `Tensor` shaped: 604 `[max_time, batch_size, cell.output_size]`. 605 606 Note, if `cell.output_size` is a (possibly nested) tuple of integers 607 or `TensorShape` objects, then `outputs` will be a tuple having the 608 same structure as `cell.output_size`, containing Tensors having shapes 609 corresponding to the shape data in `cell.output_size`. 610 611 state: The final state. If `cell.state_size` is an int, this 612 will be shaped `[batch_size, cell.state_size]`. If it is a 613 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 614 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 615 be a tuple having the corresponding shapes. If cells are `LSTMCells` 616 `state` will be a tuple containing a `LSTMStateTuple` for each cell. 617 618 Raises: 619 TypeError: If `cell` is not an instance of RNNCell. 620 ValueError: If inputs is None or an empty list. 621 """ 622 rnn_cell_impl.assert_like_rnncell("cell", cell) 623 624 with vs.variable_scope(scope or "rnn") as varscope: 625 # Create a new scope in which the caching device is either 626 # determined by the parent scope, or is set to place the cached 627 # Variable using the same placement as for the rest of the RNN. 628 if _should_cache(): 629 if varscope.caching_device is None: 630 varscope.set_caching_device(lambda op: op.device) 631 632 # By default, time_major==False and inputs are batch-major: shaped 633 # [batch, time, depth] 634 # For internal calculations, we transpose to [time, batch, depth] 635 flat_input = nest.flatten(inputs) 636 637 if not time_major: 638 # (B,T,D) => (T,B,D) 639 flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 640 flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 641 642 parallel_iterations = parallel_iterations or 32 643 if sequence_length is not None: 644 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 645 if sequence_length.get_shape().rank not in (None, 1): 646 raise ValueError( 647 "sequence_length must be a vector of length batch_size, " 648 "but saw shape: %s" % sequence_length.get_shape()) 649 sequence_length = array_ops.identity( # Just to find it in the graph. 650 sequence_length, 651 name="sequence_length") 652 653 batch_size = _best_effort_input_batch_size(flat_input) 654 655 if initial_state is not None: 656 state = initial_state 657 else: 658 if not dtype: 659 raise ValueError("If there is no initial_state, you must give a dtype.") 660 if getattr(cell, "get_initial_state", None) is not None: 661 state = cell.get_initial_state( 662 inputs=None, batch_size=batch_size, dtype=dtype) 663 else: 664 state = cell.zero_state(batch_size, dtype) 665 666 def _assert_has_shape(x, shape): 667 x_shape = array_ops.shape(x) 668 packed_shape = array_ops.stack(shape) 669 return control_flow_ops.Assert( 670 math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [ 671 "Expected shape for Tensor %s is " % x.name, packed_shape, 672 " but saw shape: ", x_shape 673 ]) 674 675 if not context.executing_eagerly() and sequence_length is not None: 676 # Perform some shape validation 677 with ops.control_dependencies( 678 [_assert_has_shape(sequence_length, [batch_size])]): 679 sequence_length = array_ops.identity( 680 sequence_length, name="CheckSeqLen") 681 682 inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 683 684 (outputs, final_state) = _dynamic_rnn_loop( 685 cell, 686 inputs, 687 state, 688 parallel_iterations=parallel_iterations, 689 swap_memory=swap_memory, 690 sequence_length=sequence_length, 691 dtype=dtype) 692 693 # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 694 # If we are performing batch-major calculations, transpose output back 695 # to shape [batch, time, depth] 696 if not time_major: 697 # (T,B,D) => (B,T,D) 698 outputs = nest.map_structure(_transpose_batch_time, outputs) 699 700 return (outputs, final_state) 701 702 703def _dynamic_rnn_loop(cell, 704 inputs, 705 initial_state, 706 parallel_iterations, 707 swap_memory, 708 sequence_length=None, 709 dtype=None): 710 """Internal implementation of Dynamic RNN. 711 712 Args: 713 cell: An instance of RNNCell. 714 inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested 715 tuple of such elements. 716 initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if 717 `cell.state_size` is a tuple, then this should be a tuple of tensors 718 having shapes `[batch_size, s] for s in cell.state_size`. 719 parallel_iterations: Positive Python int. 720 swap_memory: A Python boolean 721 sequence_length: (optional) An `int32` `Tensor` of shape [batch_size]. 722 dtype: (optional) Expected dtype of output. If not specified, inferred from 723 initial_state. 724 725 Returns: 726 Tuple `(final_outputs, final_state)`. 727 final_outputs: 728 A `Tensor` of shape `[time, batch_size, cell.output_size]`. If 729 `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape` 730 objects, then this returns a (possibly nested) tuple of Tensors matching 731 the corresponding shapes. 732 final_state: 733 A `Tensor`, or possibly nested tuple of Tensors, matching in length 734 and shapes to `initial_state`. 735 736 Raises: 737 ValueError: If the input depth cannot be inferred via shape inference 738 from the inputs. 739 ValueError: If time_step is not the same for all the elements in the 740 inputs. 741 ValueError: If batch_size is not the same for all the elements in the 742 inputs. 743 """ 744 state = initial_state 745 assert isinstance(parallel_iterations, int), "parallel_iterations must be int" 746 747 state_size = cell.state_size 748 749 flat_input = nest.flatten(inputs) 750 flat_output_size = nest.flatten(cell.output_size) 751 752 # Construct an initial output 753 input_shape = array_ops.shape(flat_input[0]) 754 time_steps = input_shape[0] 755 batch_size = _best_effort_input_batch_size(flat_input) 756 757 inputs_got_shape = tuple( 758 input_.get_shape().with_rank_at_least(3) for input_ in flat_input) 759 760 const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2] 761 762 for shape in inputs_got_shape: 763 if not shape[2:].is_fully_defined(): 764 raise ValueError( 765 "Input size (depth of inputs) must be accessible via shape inference," 766 " but saw value None.") 767 got_time_steps = shape.dims[0].value 768 got_batch_size = shape.dims[1].value 769 if const_time_steps != got_time_steps: 770 raise ValueError( 771 "Time steps is not the same for all the elements in the input in a " 772 "batch.") 773 if const_batch_size != got_batch_size: 774 raise ValueError( 775 "Batch_size is not the same for all the elements in the input.") 776 777 # Prepare dynamic conditional copying of state & output 778 def _create_zero_arrays(size): 779 size = _concat(batch_size, size) 780 return array_ops.zeros( 781 array_ops.stack(size), _infer_state_dtype(dtype, state)) 782 783 flat_zero_output = tuple( 784 _create_zero_arrays(output) for output in flat_output_size) 785 zero_output = nest.pack_sequence_as( 786 structure=cell.output_size, flat_sequence=flat_zero_output) 787 788 if sequence_length is not None: 789 min_sequence_length = math_ops.reduce_min(sequence_length) 790 max_sequence_length = math_ops.reduce_max(sequence_length) 791 else: 792 max_sequence_length = time_steps 793 794 time = array_ops.constant(0, dtype=dtypes.int32, name="time") 795 796 with ops.name_scope("dynamic_rnn") as scope: 797 base_name = scope 798 799 def _create_ta(name, element_shape, dtype): 800 return tensor_array_ops.TensorArray( 801 dtype=dtype, 802 size=time_steps, 803 element_shape=element_shape, 804 tensor_array_name=base_name + name) 805 806 in_graph_mode = not context.executing_eagerly() 807 if in_graph_mode: 808 output_ta = tuple( 809 _create_ta( 810 "output_%d" % i, 811 element_shape=( 812 tensor_shape.TensorShape([const_batch_size]).concatenate( 813 _maybe_tensor_shape_from_tensor(out_size))), 814 dtype=_infer_state_dtype(dtype, state)) 815 for i, out_size in enumerate(flat_output_size)) 816 input_ta = tuple( 817 _create_ta( 818 "input_%d" % i, 819 element_shape=flat_input_i.shape[1:], 820 dtype=flat_input_i.dtype) 821 for i, flat_input_i in enumerate(flat_input)) 822 input_ta = tuple( 823 ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input)) 824 else: 825 output_ta = tuple([0 for _ in range(time_steps.numpy())] 826 for i in range(len(flat_output_size))) 827 input_ta = flat_input 828 829 def _time_step(time, output_ta_t, state): 830 """Take a time step of the dynamic RNN. 831 832 Args: 833 time: int32 scalar Tensor. 834 output_ta_t: List of `TensorArray`s that represent the output. 835 state: nested tuple of vector tensors that represent the state. 836 837 Returns: 838 The tuple (time + 1, output_ta_t with updated flow, new_state). 839 """ 840 841 if in_graph_mode: 842 input_t = tuple(ta.read(time) for ta in input_ta) 843 # Restore some shape information 844 for input_, shape in zip(input_t, inputs_got_shape): 845 input_.set_shape(shape[1:]) 846 else: 847 input_t = tuple(ta[time.numpy()] for ta in input_ta) 848 849 input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) 850 # Keras RNN cells only accept state as list, even if it's a single tensor. 851 call_cell = lambda: cell(input_t, state) 852 853 if sequence_length is not None: 854 (output, new_state) = _rnn_step( 855 time=time, 856 sequence_length=sequence_length, 857 min_sequence_length=min_sequence_length, 858 max_sequence_length=max_sequence_length, 859 zero_output=zero_output, 860 state=state, 861 call_cell=call_cell, 862 state_size=state_size, 863 skip_conditionals=True) 864 else: 865 (output, new_state) = call_cell() 866 867 # Pack state if using state tuples 868 output = nest.flatten(output) 869 870 if in_graph_mode: 871 output_ta_t = tuple( 872 ta.write(time, out) for ta, out in zip(output_ta_t, output)) 873 else: 874 for ta, out in zip(output_ta_t, output): 875 ta[time.numpy()] = out 876 877 return (time + 1, output_ta_t, new_state) 878 879 if in_graph_mode: 880 # Make sure that we run at least 1 step, if necessary, to ensure 881 # the TensorArrays pick up the dynamic shape. 882 loop_bound = math_ops.minimum(time_steps, 883 math_ops.maximum(1, max_sequence_length)) 884 else: 885 # Using max_sequence_length isn't currently supported in the Eager branch. 886 loop_bound = time_steps 887 888 _, output_final_ta, final_state = control_flow_ops.while_loop( 889 cond=lambda time, *_: time < loop_bound, 890 body=_time_step, 891 loop_vars=(time, output_ta, state), 892 parallel_iterations=parallel_iterations, 893 maximum_iterations=time_steps, 894 swap_memory=swap_memory) 895 896 # Unpack final output if not using output tuples. 897 if in_graph_mode: 898 final_outputs = tuple(ta.stack() for ta in output_final_ta) 899 # Restore some shape information 900 for output, output_size in zip(final_outputs, flat_output_size): 901 shape = _concat([const_time_steps, const_batch_size], 902 output_size, 903 static=True) 904 output.set_shape(shape) 905 else: 906 final_outputs = output_final_ta 907 908 final_outputs = nest.pack_sequence_as( 909 structure=cell.output_size, flat_sequence=final_outputs) 910 if not in_graph_mode: 911 final_outputs = nest.map_structure_up_to( 912 cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) 913 914 return (final_outputs, final_state) 915 916 917@tf_export(v1=["nn.raw_rnn"]) 918@dispatch.add_dispatch_support 919def raw_rnn(cell, 920 loop_fn, 921 parallel_iterations=None, 922 swap_memory=False, 923 scope=None): 924 """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. 925 926 **NOTE: This method is still in testing, and the API may change.** 927 928 This function is a more primitive version of `dynamic_rnn` that provides 929 more direct access to the inputs each iteration. It also provides more 930 control over when to start and finish reading the sequence, and 931 what to emit for the output. 932 933 For example, it can be used to implement the dynamic decoder of a seq2seq 934 model. 935 936 Instead of working with `Tensor` objects, most operations work with 937 `TensorArray` objects directly. 938 939 The operation of `raw_rnn`, in pseudo-code, is basically the following: 940 941 ```python 942 time = tf.constant(0, dtype=tf.int32) 943 (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn( 944 time=time, cell_output=None, cell_state=None, loop_state=None) 945 emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) 946 state = initial_state 947 while not all(finished): 948 (output, cell_state) = cell(next_input, state) 949 (next_finished, next_input, next_state, emit, loop_state) = loop_fn( 950 time=time + 1, cell_output=output, cell_state=cell_state, 951 loop_state=loop_state) 952 # Emit zeros and copy forward state for minibatch entries that are finished. 953 state = tf.where(finished, state, next_state) 954 emit = tf.where(finished, tf.zeros_like(emit_structure), emit) 955 emit_ta = emit_ta.write(time, emit) 956 # If any new minibatch entries are marked as finished, mark these. 957 finished = tf.logical_or(finished, next_finished) 958 time += 1 959 return (emit_ta, state, loop_state) 960 ``` 961 962 with the additional properties that output and state may be (possibly nested) 963 tuples, as determined by `cell.output_size` and `cell.state_size`, and 964 as a result the final `state` and `emit_ta` may themselves be tuples. 965 966 A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this: 967 968 ```python 969 inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth), 970 dtype=tf.float32) 971 sequence_length = tf.compat.v1.placeholder(shape=(batch_size,), 972 dtype=tf.int32) 973 inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 974 inputs_ta = inputs_ta.unstack(inputs) 975 976 cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) 977 978 def loop_fn(time, cell_output, cell_state, loop_state): 979 emit_output = cell_output # == None for time == 0 980 if cell_output is None: # time == 0 981 next_cell_state = cell.zero_state(batch_size, tf.float32) 982 else: 983 next_cell_state = cell_state 984 elements_finished = (time >= sequence_length) 985 finished = tf.reduce_all(elements_finished) 986 next_input = tf.cond( 987 finished, 988 lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32), 989 lambda: inputs_ta.read(time)) 990 next_loop_state = None 991 return (elements_finished, next_input, next_cell_state, 992 emit_output, next_loop_state) 993 994 outputs_ta, final_state, _ = raw_rnn(cell, loop_fn) 995 outputs = outputs_ta.stack() 996 ``` 997 998 Args: 999 cell: An instance of RNNCell. 1000 loop_fn: A callable that takes inputs `(time, cell_output, cell_state, 1001 loop_state)` and returns the tuple `(finished, next_input, 1002 next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32 1003 scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of 1004 tensors as determined by `cell.output_size`, and `cell_state` is a 1005 `Tensor` or (possibly nested) tuple of tensors, as determined by the 1006 `loop_fn` on its first call (and should match `cell.state_size`). 1007 The outputs are: `finished`, a boolean `Tensor` of 1008 shape `[batch_size]`, `next_input`: the next input to feed to `cell`, 1009 `next_cell_state`: the next state to feed to `cell`, 1010 and `emit_output`: the output to store for this iteration. Note that 1011 `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors 1012 which is aggregated in the `emit_ta` inside the `while_loop`. For the 1013 first call to `loop_fn`, the `emit_output` corresponds to the 1014 `emit_structure` which is then used to determine the size of the 1015 `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For 1016 the subsequent calls to the `loop_fn`, the `emit_output` corresponds to 1017 the actual output tensor that is to be aggregated in the `emit_ta`. The 1018 parameter `cell_state` and output `next_cell_state` may be either a 1019 single or (possibly nested) tuple of tensors. The parameter 1020 `loop_state` and output `next_loop_state` may be either a single or 1021 (possibly nested) tuple of `Tensor` and `TensorArray` objects. This 1022 last parameter may be ignored by `loop_fn` and the return value may be 1023 `None`. If it is not `None`, then the `loop_state` will be propagated 1024 through the RNN loop, for use purely by `loop_fn` to keep track of its 1025 own state. The `next_loop_state` parameter returned may be `None`. The 1026 first call to `loop_fn` will be `time = 0`, `cell_output = None`, 1027 `cell_state = None`, and `loop_state = None`. For this call: The 1028 `next_cell_state` value should be the value with which to initialize the 1029 cell's state. It may be a final state from a previous RNN or it may be 1030 the output of `cell.zero_state()`. It should be a (possibly nested) 1031 tuple structure of tensors. If `cell.state_size` is an integer, this 1032 must be a `Tensor` of appropriate type and shape `[batch_size, 1033 cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be 1034 a `Tensor` of appropriate type and shape `[batch_size] + 1035 cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of 1036 ints or `TensorShape`, this will be a tuple having the corresponding 1037 shapes. The `emit_output` value may be either `None` or a (possibly 1038 nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0, 1039 dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first 1040 `emit_output` return value is `None`, then the `emit_ta` result of 1041 `raw_rnn` will have the same structure and dtypes as `cell.output_size`. 1042 Otherwise `emit_ta` will have the same structure, shapes (prepended with 1043 a `batch_size` dimension), and dtypes as `emit_output`. The actual 1044 values returned for `emit_output` at this initializing call are ignored. 1045 Note, this emit structure must be consistent across all time steps. 1046 parallel_iterations: (Default: 32). The number of iterations to run in 1047 parallel. Those operations which do not have any temporal dependency and 1048 can be run in parallel, will be. This parameter trades off time for 1049 space. Values >> 1 use more memory but take less time, while smaller 1050 values use less memory but computations take longer. 1051 swap_memory: Transparently swap the tensors produced in forward inference 1052 but needed for back prop from GPU to CPU. This allows training RNNs which 1053 would typically not fit on a single GPU, with very minimal (or no) 1054 performance penalty. 1055 scope: VariableScope for the created subgraph; defaults to "rnn". 1056 1057 Returns: 1058 A tuple `(emit_ta, final_state, final_loop_state)` where: 1059 1060 `emit_ta`: The RNN output `TensorArray`. 1061 If `loop_fn` returns a (possibly nested) set of Tensors for 1062 `emit_output` during initialization, (inputs `time = 0`, 1063 `cell_output = None`, and `loop_state = None`), then `emit_ta` will 1064 have the same structure, dtypes, and shapes as `emit_output` instead. 1065 If `loop_fn` returns `emit_output = None` during this call, 1066 the structure of `cell.output_size` is used: 1067 If `cell.output_size` is a (possibly nested) tuple of integers 1068 or `TensorShape` objects, then `emit_ta` will be a tuple having the 1069 same structure as `cell.output_size`, containing TensorArrays whose 1070 elements' shapes correspond to the shape data in `cell.output_size`. 1071 1072 `final_state`: The final cell state. If `cell.state_size` is an int, this 1073 will be shaped `[batch_size, cell.state_size]`. If it is a 1074 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 1075 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 1076 be a tuple having the corresponding shapes. 1077 1078 `final_loop_state`: The final loop state as returned by `loop_fn`. 1079 1080 Raises: 1081 TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not 1082 a `callable`. 1083 """ 1084 rnn_cell_impl.assert_like_rnncell("cell", cell) 1085 1086 if not callable(loop_fn): 1087 raise TypeError("loop_fn must be a callable") 1088 1089 parallel_iterations = parallel_iterations or 32 1090 1091 # Create a new scope in which the caching device is either 1092 # determined by the parent scope, or is set to place the cached 1093 # Variable using the same placement as for the rest of the RNN. 1094 with vs.variable_scope(scope or "rnn") as varscope: 1095 if _should_cache(): 1096 if varscope.caching_device is None: 1097 varscope.set_caching_device(lambda op: op.device) 1098 1099 time = constant_op.constant(0, dtype=dtypes.int32) 1100 (elements_finished, next_input, 1101 initial_state, emit_structure, init_loop_state) = loop_fn( 1102 time, None, None, None) # time, cell_output, cell_state, loop_state 1103 flat_input = nest.flatten(next_input) 1104 1105 # Need a surrogate loop state for the while_loop if none is available. 1106 loop_state = ( 1107 init_loop_state if init_loop_state is not None else 1108 constant_op.constant(0, dtype=dtypes.int32)) 1109 1110 input_shape = [input_.get_shape() for input_ in flat_input] 1111 static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0) 1112 1113 for input_shape_i in input_shape: 1114 # Static verification that batch sizes all match 1115 static_batch_size.assert_is_compatible_with( 1116 tensor_shape.dimension_at_index(input_shape_i, 0)) 1117 1118 batch_size = tensor_shape.dimension_value(static_batch_size) 1119 const_batch_size = batch_size 1120 if batch_size is None: 1121 batch_size = array_ops.shape(flat_input[0])[0] 1122 1123 nest.assert_same_structure(initial_state, cell.state_size) 1124 state = initial_state 1125 flat_state = nest.flatten(state) 1126 flat_state = [ops.convert_to_tensor(s) for s in flat_state] 1127 state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state) 1128 1129 if emit_structure is not None: 1130 flat_emit_structure = nest.flatten(emit_structure) 1131 flat_emit_size = [ 1132 emit.shape if emit.shape.is_fully_defined() else array_ops.shape(emit) 1133 for emit in flat_emit_structure 1134 ] 1135 flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] 1136 else: 1137 emit_structure = cell.output_size 1138 flat_emit_size = nest.flatten(emit_structure) 1139 flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size) 1140 1141 flat_emit_ta = [ 1142 tensor_array_ops.TensorArray( 1143 dtype=dtype_i, 1144 dynamic_size=True, 1145 element_shape=(tensor_shape.TensorShape([ 1146 const_batch_size 1147 ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))), 1148 size=0, 1149 name="rnn_output_%d" % i) 1150 for i, (dtype_i, 1151 size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size)) 1152 ] 1153 emit_ta = nest.pack_sequence_as( 1154 structure=emit_structure, flat_sequence=flat_emit_ta) 1155 flat_zero_emit = [ 1156 array_ops.zeros(_concat(batch_size, size_i), dtype_i) 1157 for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes) 1158 ] 1159 zero_emit = nest.pack_sequence_as( 1160 structure=emit_structure, flat_sequence=flat_zero_emit) 1161 1162 def condition(unused_time, elements_finished, *_): 1163 return math_ops.logical_not(math_ops.reduce_all(elements_finished)) 1164 1165 def body(time, elements_finished, current_input, emit_ta, state, 1166 loop_state): 1167 """Internal while loop body for raw_rnn. 1168 1169 Args: 1170 time: time scalar. 1171 elements_finished: batch-size vector. 1172 current_input: possibly nested tuple of input tensors. 1173 emit_ta: possibly nested tuple of output TensorArrays. 1174 state: possibly nested tuple of state tensors. 1175 loop_state: possibly nested tuple of loop state tensors. 1176 1177 Returns: 1178 Tuple having the same size as Args but with updated values. 1179 """ 1180 (next_output, cell_state) = cell(current_input, state) 1181 1182 nest.assert_same_structure(state, cell_state) 1183 nest.assert_same_structure(cell.output_size, next_output) 1184 1185 next_time = time + 1 1186 (next_finished, next_input, next_state, emit_output, 1187 next_loop_state) = loop_fn(next_time, next_output, cell_state, 1188 loop_state) 1189 1190 nest.assert_same_structure(state, next_state) 1191 nest.assert_same_structure(current_input, next_input) 1192 nest.assert_same_structure(emit_ta, emit_output) 1193 1194 # If loop_fn returns None for next_loop_state, just reuse the 1195 # previous one. 1196 loop_state = loop_state if next_loop_state is None else next_loop_state 1197 1198 def _copy_some_through(current, candidate): 1199 """Copy some tensors through via array_ops.where.""" 1200 1201 def copy_fn(cur_i, cand_i): 1202 # TensorArray and scalar get passed through. 1203 if isinstance(cur_i, tensor_array_ops.TensorArray): 1204 return cand_i 1205 if cur_i.shape.rank == 0: 1206 return cand_i 1207 # Otherwise propagate the old or the new value. 1208 with ops.colocate_with(cand_i): 1209 return array_ops.where(elements_finished, cur_i, cand_i) 1210 1211 return nest.map_structure(copy_fn, current, candidate) 1212 1213 emit_output = _copy_some_through(zero_emit, emit_output) 1214 next_state = _copy_some_through(state, next_state) 1215 1216 emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), 1217 emit_ta, emit_output) 1218 1219 elements_finished = math_ops.logical_or(elements_finished, next_finished) 1220 1221 return (next_time, elements_finished, next_input, emit_ta, next_state, 1222 loop_state) 1223 1224 returned = control_flow_ops.while_loop( 1225 condition, 1226 body, 1227 loop_vars=[ 1228 time, elements_finished, next_input, emit_ta, state, loop_state 1229 ], 1230 parallel_iterations=parallel_iterations, 1231 swap_memory=swap_memory) 1232 1233 (emit_ta, final_state, final_loop_state) = returned[-3:] 1234 1235 if init_loop_state is None: 1236 final_loop_state = None 1237 1238 return (emit_ta, final_state, final_loop_state) 1239 1240 1241@deprecation.deprecated(None, 1242 "Please use `keras.layers.RNN(cell, unroll=True)`, " 1243 "which is equivalent to this API") 1244@tf_export(v1=["nn.static_rnn"]) 1245@dispatch.add_dispatch_support 1246def static_rnn(cell, 1247 inputs, 1248 initial_state=None, 1249 dtype=None, 1250 sequence_length=None, 1251 scope=None): 1252 """Creates a recurrent neural network specified by RNNCell `cell`. 1253 1254 The simplest form of RNN network generated is: 1255 1256 ```python 1257 state = cell.zero_state(...) 1258 outputs = [] 1259 for input_ in inputs: 1260 output, state = cell(input_, state) 1261 outputs.append(output) 1262 return (outputs, state) 1263 ``` 1264 However, a few other options are available: 1265 1266 An initial state can be provided. 1267 If the sequence_length vector is provided, dynamic calculation is performed. 1268 This method of calculation does not compute the RNN steps past the maximum 1269 sequence length of the minibatch (thus saving computational time), 1270 and properly propagates the state at an example's sequence length 1271 to the final state output. 1272 1273 The dynamic calculation performed is, at time `t` for batch row `b`, 1274 1275 ```python 1276 (output, state)(b, t) = 1277 (t >= sequence_length(b)) 1278 ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 1279 : cell(input(b, t), state(b, t - 1)) 1280 ``` 1281 1282 Args: 1283 cell: An instance of RNNCell. 1284 inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size, 1285 input_size]`, or a nested tuple of such elements. 1286 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 1287 is an integer, this must be a `Tensor` of appropriate type and shape 1288 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 1289 should be a tuple of tensors having shapes `[batch_size, s] for s in 1290 cell.state_size`. 1291 dtype: (optional) The data type for the initial state and expected output. 1292 Required if initial_state is not provided or RNN state has a heterogeneous 1293 dtype. 1294 sequence_length: Specifies the length of each sequence in inputs. An int32 1295 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. 1296 scope: VariableScope for the created subgraph; defaults to "rnn". 1297 1298 Returns: 1299 A pair (outputs, state) where: 1300 1301 - outputs is a length T list of outputs (one for each input), or a nested 1302 tuple of such elements. 1303 - state is the final state 1304 1305 Raises: 1306 TypeError: If `cell` is not an instance of RNNCell. 1307 ValueError: If `inputs` is `None` or an empty list, or if the input depth 1308 (column size) cannot be inferred from inputs via shape inference. 1309 """ 1310 rnn_cell_impl.assert_like_rnncell("cell", cell) 1311 if not nest.is_sequence(inputs): 1312 raise TypeError("inputs must be a sequence") 1313 if not inputs: 1314 raise ValueError("inputs must not be empty") 1315 1316 outputs = [] 1317 # Create a new scope in which the caching device is either 1318 # determined by the parent scope, or is set to place the cached 1319 # Variable using the same placement as for the rest of the RNN. 1320 with vs.variable_scope(scope or "rnn") as varscope: 1321 if _should_cache(): 1322 if varscope.caching_device is None: 1323 varscope.set_caching_device(lambda op: op.device) 1324 1325 # Obtain the first sequence of the input 1326 first_input = inputs 1327 while nest.is_sequence(first_input): 1328 first_input = first_input[0] 1329 1330 # Temporarily avoid EmbeddingWrapper and seq2seq badness 1331 # TODO(lukaszkaiser): remove EmbeddingWrapper 1332 if first_input.get_shape().rank != 1: 1333 1334 input_shape = first_input.get_shape().with_rank_at_least(2) 1335 fixed_batch_size = input_shape.dims[0] 1336 1337 flat_inputs = nest.flatten(inputs) 1338 for flat_input in flat_inputs: 1339 input_shape = flat_input.get_shape().with_rank_at_least(2) 1340 batch_size, input_size = tensor_shape.dimension_at_index( 1341 input_shape, 0), input_shape[1:] 1342 fixed_batch_size.assert_is_compatible_with(batch_size) 1343 for i, size in enumerate(input_size.dims): 1344 if tensor_shape.dimension_value(size) is None: 1345 raise ValueError( 1346 "Input size (dimension %d of inputs) must be accessible via " 1347 "shape inference, but saw value None." % i) 1348 else: 1349 fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 1350 1351 if tensor_shape.dimension_value(fixed_batch_size): 1352 batch_size = tensor_shape.dimension_value(fixed_batch_size) 1353 else: 1354 batch_size = array_ops.shape(first_input)[0] 1355 if initial_state is not None: 1356 state = initial_state 1357 else: 1358 if not dtype: 1359 raise ValueError("If no initial_state is provided, " 1360 "dtype must be specified") 1361 if getattr(cell, "get_initial_state", None) is not None: 1362 state = cell.get_initial_state( 1363 inputs=None, batch_size=batch_size, dtype=dtype) 1364 else: 1365 state = cell.zero_state(batch_size, dtype) 1366 1367 if sequence_length is not None: # Prepare variables 1368 sequence_length = ops.convert_to_tensor( 1369 sequence_length, name="sequence_length") 1370 if sequence_length.get_shape().rank not in (None, 1): 1371 raise ValueError( 1372 "sequence_length must be a vector of length batch_size") 1373 1374 def _create_zero_output(output_size): 1375 # convert int to TensorShape if necessary 1376 size = _concat(batch_size, output_size) 1377 output = array_ops.zeros( 1378 array_ops.stack(size), _infer_state_dtype(dtype, state)) 1379 shape = _concat( 1380 tensor_shape.dimension_value(fixed_batch_size), 1381 output_size, 1382 static=True) 1383 output.set_shape(tensor_shape.TensorShape(shape)) 1384 return output 1385 1386 output_size = cell.output_size 1387 flat_output_size = nest.flatten(output_size) 1388 flat_zero_output = tuple( 1389 _create_zero_output(size) for size in flat_output_size) 1390 zero_output = nest.pack_sequence_as( 1391 structure=output_size, flat_sequence=flat_zero_output) 1392 1393 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 1394 min_sequence_length = math_ops.reduce_min(sequence_length) 1395 max_sequence_length = math_ops.reduce_max(sequence_length) 1396 1397 for time, input_ in enumerate(inputs): 1398 if time > 0: 1399 varscope.reuse_variables() 1400 # pylint: disable=cell-var-from-loop 1401 call_cell = lambda: cell(input_, state) 1402 # pylint: enable=cell-var-from-loop 1403 if sequence_length is not None: 1404 (output, state) = _rnn_step( 1405 time=time, 1406 sequence_length=sequence_length, 1407 min_sequence_length=min_sequence_length, 1408 max_sequence_length=max_sequence_length, 1409 zero_output=zero_output, 1410 state=state, 1411 call_cell=call_cell, 1412 state_size=cell.state_size) 1413 else: 1414 (output, state) = call_cell() 1415 outputs.append(output) 1416 1417 return (outputs, state) 1418 1419 1420@deprecation.deprecated(None, 1421 "Please use `keras.layers.RNN(cell, stateful=True)`, " 1422 "which is equivalent to this API") 1423@tf_export(v1=["nn.static_state_saving_rnn"]) 1424@dispatch.add_dispatch_support 1425def static_state_saving_rnn(cell, 1426 inputs, 1427 state_saver, 1428 state_name, 1429 sequence_length=None, 1430 scope=None): 1431 """RNN that accepts a state saver for time-truncated RNN calculation. 1432 1433 Args: 1434 cell: An instance of `RNNCell`. 1435 inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size, 1436 input_size]`. 1437 state_saver: A state saver object with methods `state` and `save_state`. 1438 state_name: Python string or tuple of strings. The name to use with the 1439 state_saver. If the cell returns tuples of states (i.e., `cell.state_size` 1440 is a tuple) then `state_name` should be a tuple of strings having the same 1441 length as `cell.state_size`. Otherwise it should be a single string. 1442 sequence_length: (optional) An int32/int64 vector size [batch_size]. See the 1443 documentation for rnn() for more details about sequence_length. 1444 scope: VariableScope for the created subgraph; defaults to "rnn". 1445 1446 Returns: 1447 A pair (outputs, state) where: 1448 outputs is a length T list of outputs (one for each input) 1449 states is the final state 1450 1451 Raises: 1452 TypeError: If `cell` is not an instance of RNNCell. 1453 ValueError: If `inputs` is `None` or an empty list, or if the arity and 1454 type of `state_name` does not match that of `cell.state_size`. 1455 """ 1456 state_size = cell.state_size 1457 state_is_tuple = nest.is_sequence(state_size) 1458 state_name_tuple = nest.is_sequence(state_name) 1459 1460 if state_is_tuple != state_name_tuple: 1461 raise ValueError("state_name should be the same type as cell.state_size. " 1462 "state_name: %s, cell.state_size: %s" % 1463 (str(state_name), str(state_size))) 1464 1465 if state_is_tuple: 1466 state_name_flat = nest.flatten(state_name) 1467 state_size_flat = nest.flatten(state_size) 1468 1469 if len(state_name_flat) != len(state_size_flat): 1470 raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" % 1471 (len(state_name_flat), len(state_size_flat))) 1472 1473 initial_state = nest.pack_sequence_as( 1474 structure=state_size, 1475 flat_sequence=[state_saver.state(s) for s in state_name_flat]) 1476 else: 1477 initial_state = state_saver.state(state_name) 1478 1479 (outputs, state) = static_rnn( 1480 cell, 1481 inputs, 1482 initial_state=initial_state, 1483 sequence_length=sequence_length, 1484 scope=scope) 1485 1486 if state_is_tuple: 1487 flat_state = nest.flatten(state) 1488 state_name = nest.flatten(state_name) 1489 save_state = [ 1490 state_saver.save_state(name, substate) 1491 for name, substate in zip(state_name, flat_state) 1492 ] 1493 else: 1494 save_state = [state_saver.save_state(state_name, state)] 1495 1496 with ops.control_dependencies(save_state): 1497 last_output = outputs[-1] 1498 flat_last_output = nest.flatten(last_output) 1499 flat_last_output = [ 1500 array_ops.identity(output) for output in flat_last_output 1501 ] 1502 outputs[-1] = nest.pack_sequence_as( 1503 structure=last_output, flat_sequence=flat_last_output) 1504 1505 if state_is_tuple: 1506 state = nest.pack_sequence_as( 1507 structure=state, 1508 flat_sequence=[array_ops.identity(s) for s in flat_state]) 1509 else: 1510 state = array_ops.identity(state) 1511 1512 return (outputs, state) 1513 1514 1515@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 1516 "keras.layers.RNN(cell, unroll=True))`, which is " 1517 "equivalent to this API") 1518@tf_export(v1=["nn.static_bidirectional_rnn"]) 1519@dispatch.add_dispatch_support 1520def static_bidirectional_rnn(cell_fw, 1521 cell_bw, 1522 inputs, 1523 initial_state_fw=None, 1524 initial_state_bw=None, 1525 dtype=None, 1526 sequence_length=None, 1527 scope=None): 1528 """Creates a bidirectional recurrent neural network. 1529 1530 Similar to the unidirectional case above (rnn) but takes input and builds 1531 independent forward and backward RNNs with the final forward and backward 1532 outputs depth-concatenated, such that the output will have the format 1533 [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of 1534 forward and backward cell must match. The initial state for both directions 1535 is zero by default (but can be set optionally) and no intermediate states are 1536 ever returned -- the network is fully unrolled for the given (passed in) 1537 length(s) of the sequence(s) or completely unrolled if length(s) is not given. 1538 1539 Args: 1540 cell_fw: An instance of RNNCell, to be used for forward direction. 1541 cell_bw: An instance of RNNCell, to be used for backward direction. 1542 inputs: A length T list of inputs, each a tensor of shape [batch_size, 1543 input_size], or a nested tuple of such elements. 1544 initial_state_fw: (optional) An initial state for the forward RNN. This must 1545 be a tensor of appropriate type and shape `[batch_size, 1546 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 1547 tuple of tensors having shapes `[batch_size, s] for s in 1548 cell_fw.state_size`. 1549 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 1550 corresponding properties of `cell_bw`. 1551 dtype: (optional) The data type for the initial state. Required if either 1552 of the initial states are not provided. 1553 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 1554 containing the actual lengths for each of the sequences. 1555 scope: VariableScope for the created subgraph; defaults to 1556 "bidirectional_rnn" 1557 1558 Returns: 1559 A tuple (outputs, output_state_fw, output_state_bw) where: 1560 outputs is a length `T` list of outputs (one for each input), which 1561 are depth-concatenated forward and backward outputs. 1562 output_state_fw is the final state of the forward rnn. 1563 output_state_bw is the final state of the backward rnn. 1564 1565 Raises: 1566 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 1567 ValueError: If inputs is None or an empty list. 1568 """ 1569 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 1570 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 1571 if not nest.is_sequence(inputs): 1572 raise TypeError("inputs must be a sequence") 1573 if not inputs: 1574 raise ValueError("inputs must not be empty") 1575 1576 with vs.variable_scope(scope or "bidirectional_rnn"): 1577 # Forward direction 1578 with vs.variable_scope("fw") as fw_scope: 1579 output_fw, output_state_fw = static_rnn( 1580 cell_fw, 1581 inputs, 1582 initial_state_fw, 1583 dtype, 1584 sequence_length, 1585 scope=fw_scope) 1586 1587 # Backward direction 1588 with vs.variable_scope("bw") as bw_scope: 1589 reversed_inputs = _reverse_seq(inputs, sequence_length) 1590 tmp, output_state_bw = static_rnn( 1591 cell_bw, 1592 reversed_inputs, 1593 initial_state_bw, 1594 dtype, 1595 sequence_length, 1596 scope=bw_scope) 1597 1598 output_bw = _reverse_seq(tmp, sequence_length) 1599 # Concat each of the forward/backward outputs 1600 flat_output_fw = nest.flatten(output_fw) 1601 flat_output_bw = nest.flatten(output_bw) 1602 1603 flat_outputs = tuple( 1604 array_ops.concat([fw, bw], 1) 1605 for fw, bw in zip(flat_output_fw, flat_output_bw)) 1606 1607 outputs = nest.pack_sequence_as( 1608 structure=output_fw, flat_sequence=flat_outputs) 1609 1610 return (outputs, output_state_fw, output_state_bw) 1611