1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""CTC (Connectionist Temporal Classification) Operations.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import function 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.framework import tensor_shape 28 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import functional_ops 31from tensorflow.python.ops import gen_ctc_ops 32from tensorflow.python.ops import inplace_ops 33from tensorflow.python.ops import linalg_ops 34from tensorflow.python.ops import map_fn 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops import nn_ops 37from tensorflow.python.ops import sparse_ops 38from tensorflow.python.ops.nn_grad import _BroadcastMul 39from tensorflow.python.util import deprecation 40from tensorflow.python.util import nest 41from tensorflow.python.util.tf_export import tf_export 42 43 44# pylint: disable=protected-access, invalid-name 45@tf_export(v1=["nn.ctc_loss"]) 46def ctc_loss(labels, inputs=None, sequence_length=None, 47 preprocess_collapse_repeated=False, 48 ctc_merge_repeated=True, 49 ignore_longer_outputs_than_inputs=False, time_major=True, 50 logits=None): 51 """Computes the CTC (Connectionist Temporal Classification) Loss. 52 53 This op implements the CTC loss as presented in the article: 54 55 [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber. 56 Connectionist Temporal Classification: Labeling Unsegmented Sequence Data 57 with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, 58 pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf) 59 60 Input requirements: 61 62 ``` 63 sequence_length(b) <= time for all b 64 65 max(labels.indices(labels.indices[:, 1] == b, 2)) 66 <= sequence_length(b) for all b. 67 ``` 68 69 Notes: 70 71 This class performs the softmax operation for you, so inputs should 72 be e.g. linear projections of outputs by an LSTM. 73 74 The `inputs` Tensor's innermost dimension size, `num_classes`, represents 75 `num_labels + 1` classes, where num_labels is the number of true labels, and 76 the largest value `(num_classes - 1)` is reserved for the blank label. 77 78 For example, for a vocabulary containing 3 labels `[a, b, c]`, 79 `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`. 80 81 Regarding the arguments `preprocess_collapse_repeated` and 82 `ctc_merge_repeated`: 83 84 If `preprocess_collapse_repeated` is True, then a preprocessing step runs 85 before loss calculation, wherein repeated labels passed to the loss 86 are merged into single labels. This is useful if the training labels come 87 from, e.g., forced alignments and therefore have unnecessary repetitions. 88 89 If `ctc_merge_repeated` is set False, then deep within the CTC calculation, 90 repeated non-blank labels will not be merged and are interpreted 91 as individual labels. This is a simplified (non-standard) version of CTC. 92 93 Here is a table of the (roughly) expected first order behavior: 94 95 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True` 96 97 Classical CTC behavior: Outputs true repeated classes with blanks in 98 between, and can also output repeated classes with no blanks in 99 between that need to be collapsed by the decoder. 100 101 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False` 102 103 Never learns to output repeated classes, as they are collapsed 104 in the input labels before training. 105 106 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False` 107 108 Outputs repeated classes with blanks in between, but generally does not 109 require the decoder to collapse/merge repeated classes. 110 111 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True` 112 113 Untested. Very likely will not learn to output repeated classes. 114 115 The `ignore_longer_outputs_than_inputs` option allows to specify the behavior 116 of the CTCLoss when dealing with sequences that have longer outputs than 117 inputs. If true, the CTCLoss will simply return zero gradient for those 118 items, otherwise an InvalidArgument error is returned, stopping training. 119 120 Args: 121 labels: An `int32` `SparseTensor`. 122 `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores 123 the id for (batch b, time t). 124 `labels.values[i]` must take on values in `[0, num_labels)`. 125 See `core/ops/ctc_ops.cc` for more details. 126 inputs: 3-D `float` `Tensor`. 127 If time_major == False, this will be a `Tensor` shaped: 128 `[batch_size, max_time, num_classes]`. 129 If time_major == True (default), this will be a `Tensor` shaped: 130 `[max_time, batch_size, num_classes]`. 131 The logits. 132 sequence_length: 1-D `int32` vector, size `[batch_size]`. 133 The sequence lengths. 134 preprocess_collapse_repeated: Boolean. Default: False. 135 If True, repeated labels are collapsed prior to the CTC calculation. 136 ctc_merge_repeated: Boolean. Default: True. 137 ignore_longer_outputs_than_inputs: Boolean. Default: False. 138 If True, sequences with longer outputs than inputs will be ignored. 139 time_major: The shape format of the `inputs` Tensors. 140 If True, these `Tensors` must be shaped `[max_time, batch_size, 141 num_classes]`. 142 If False, these `Tensors` must be shaped `[batch_size, max_time, 143 num_classes]`. 144 Using `time_major = True` (default) is a bit more efficient because it 145 avoids transposes at the beginning of the ctc_loss calculation. However, 146 most TensorFlow data is batch-major, so by this function also accepts 147 inputs in batch-major form. 148 logits: Alias for inputs. 149 150 Returns: 151 A 1-D `float` `Tensor`, size `[batch]`, containing the negative log 152 probabilities. 153 154 Raises: 155 TypeError: if labels is not a `SparseTensor`. 156 """ 157 # The second, third, etc output tensors contain the gradients. We use it in 158 # _CTCLossGrad() below. 159 if not isinstance(labels, sparse_tensor.SparseTensor): 160 raise TypeError("Expected labels (first argument) to be a SparseTensor") 161 162 # For internal calculations, we transpose to [time, batch, num_classes] 163 inputs = deprecation.deprecated_argument_lookup( 164 "logits", logits, "inputs", inputs) 165 if not time_major: 166 inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N) 167 168 loss, _ = gen_ctc_ops.ctc_loss( 169 inputs, 170 labels.indices, 171 labels.values, 172 sequence_length, 173 preprocess_collapse_repeated=preprocess_collapse_repeated, 174 ctc_merge_repeated=ctc_merge_repeated, 175 ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs) 176 177 return loss 178 179 180# pylint: disable=unused-argument 181@ops.RegisterGradient("CTCLoss") 182def _CTCLossGrad(op, grad_loss, _): 183 """The derivative provided by CTC Loss. 184 185 Args: 186 op: the CTCLoss op. 187 grad_loss: The backprop for cost. 188 189 Returns: 190 The CTC Loss gradient. 191 """ 192 # Outputs are: loss, grad 193 # 194 # Currently there is no way to take the second derivative of this op 195 # due to the fused implementation's interaction with tf.gradients(), 196 # so we make sure we prevent silently incorrect results by raising 197 # an error if the second derivative is requested via prevent_gradient. 198 grad_without_gradient = array_ops.prevent_gradient( 199 op.outputs[1], message="Currently there is no way to take the second " 200 " derivative of ctc_loss due to the fused implementation's interaction " 201 " with tf.gradients()") 202 # Return gradient for inputs and None for 203 # labels_indices, labels_values and sequence_length 204 return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None] 205 206 207@tf_export("nn.ctc_greedy_decoder") 208def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): 209 """Performs greedy decoding on the logits given in input (best path). 210 211 Note: Regardless of the value of merge_repeated, if the maximum index of a 212 given time and batch corresponds to the blank index `(num_classes - 1)`, no 213 new element is emitted. 214 215 If `merge_repeated` is `True`, merge repeated classes in output. 216 This means that if consecutive logits' maximum indices are the same, 217 only the first of these is emitted. The sequence `A B B * B * B` (where '*' 218 is the blank label) becomes 219 220 * `A B B B` if `merge_repeated=True`. 221 * `A B B B B` if `merge_repeated=False`. 222 223 Args: 224 inputs: 3-D `float` `Tensor` sized 225 `[max_time, batch_size, num_classes]`. The logits. 226 sequence_length: 1-D `int32` vector containing sequence lengths, 227 having size `[batch_size]`. 228 merge_repeated: Boolean. Default: True. 229 230 Returns: 231 A tuple `(decoded, neg_sum_logits)` where 232 233 decoded: A single-element list. `decoded[0]` 234 is an `SparseTensor` containing the decoded outputs s.t.: 235 236 `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`. 237 The rows store: `[batch, time]`. 238 239 `decoded.values`: Values vector, size `(total_decoded_outputs)`. 240 The vector stores the decoded classes. 241 242 `decoded.dense_shape`: Shape vector, size `(2)`. 243 The shape values are: `[batch_size, max_decoded_length]` 244 245 neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the 246 sequence found, the negative of the sum of the greatest logit at each 247 timeframe. 248 """ 249 outputs = gen_ctc_ops.ctc_greedy_decoder( 250 inputs, sequence_length, merge_repeated=merge_repeated) 251 (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs 252 return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val, decoded_shape)], 253 log_probabilities) 254 255 256@tf_export(v1=["nn.ctc_beam_search_decoder"]) 257def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, 258 top_paths=1, merge_repeated=True): 259 """Performs beam search decoding on the logits given in input. 260 261 **Note** The `ctc_greedy_decoder` is a special case of the 262 `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but 263 that decoder is faster for this special case). 264 265 If `merge_repeated` is `True`, merge repeated classes in the output beams. 266 This means that if consecutive entries in a beam are the same, 267 only the first of these is emitted. That is, when the sequence is 268 `A B B * B * B` (where '*' is the blank label), the return value is: 269 270 * `A B` if `merge_repeated = True`. 271 * `A B B B` if `merge_repeated = False`. 272 273 Args: 274 inputs: 3-D `float` `Tensor`, size 275 `[max_time x batch_size x num_classes]`. The logits. 276 sequence_length: 1-D `int32` vector containing sequence lengths, 277 having size `[batch_size]`. 278 beam_width: An int scalar >= 0 (beam search beam width). 279 top_paths: An int scalar >= 0, <= beam_width (controls output size). 280 merge_repeated: Boolean. Default: True. 281 282 Returns: 283 A tuple `(decoded, log_probabilities)` where 284 285 decoded: A list of length top_paths, where `decoded[j]` 286 is a `SparseTensor` containing the decoded outputs: 287 288 `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)` 289 The rows store: [batch, time]. 290 291 `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`. 292 The vector stores the decoded classes for beam j. 293 294 `decoded[j].dense_shape`: Shape vector, size `(2)`. 295 The shape values are: `[batch_size, max_decoded_length[j]]`. 296 297 log_probability: A `float` matrix `(batch_size x top_paths)` containing 298 sequence log-probabilities. 299 """ 300 301 decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = ( 302 gen_ctc_ops.ctc_beam_search_decoder( 303 inputs, sequence_length, beam_width=beam_width, top_paths=top_paths, 304 merge_repeated=merge_repeated)) 305 306 return ( 307 [sparse_tensor.SparseTensor(ix, val, shape) for (ix, val, shape) 308 in zip(decoded_ixs, decoded_vals, decoded_shapes)], 309 log_probabilities) 310 311 312@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"]) 313def ctc_beam_search_decoder_v2(inputs, sequence_length, beam_width=100, 314 top_paths=1): 315 """Performs beam search decoding on the logits given in input. 316 317 **Note** The `ctc_greedy_decoder` is a special case of the 318 `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but 319 that decoder is faster for this special case). 320 321 Args: 322 inputs: 3-D `float` `Tensor`, size 323 `[max_time, batch_size, num_classes]`. The logits. 324 sequence_length: 1-D `int32` vector containing sequence lengths, 325 having size `[batch_size]`. 326 beam_width: An int scalar >= 0 (beam search beam width). 327 top_paths: An int scalar >= 0, <= beam_width (controls output size). 328 329 Returns: 330 A tuple `(decoded, log_probabilities)` where 331 332 decoded: A list of length top_paths, where `decoded[j]` 333 is a `SparseTensor` containing the decoded outputs: 334 335 `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`; 336 The rows store: `[batch, time]`. 337 338 `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`. 339 The vector stores the decoded classes for beam `j`. 340 341 `decoded[j].dense_shape`: Shape vector, size `(2)`. 342 The shape values are: `[batch_size, max_decoded_length[j]]`. 343 344 log_probability: A `float` matrix `[batch_size, top_paths]` containing 345 sequence log-probabilities. 346 """ 347 348 # Note, merge_repeated is an invalid optimization that is removed from the 349 # public API: it returns low probability paths. 350 return ctc_beam_search_decoder(inputs, sequence_length=sequence_length, 351 beam_width=beam_width, top_paths=top_paths, 352 merge_repeated=False) 353 354 355ops.NotDifferentiable("CTCGreedyDecoder") 356ops.NotDifferentiable("CTCBeamSearchDecoder") 357 358 359def _ctc_state_trans(label_seq): 360 """Compute CTC alignment model transition matrix. 361 362 Args: 363 label_seq: tensor of shape [batch_size, max_seq_length] 364 365 Returns: 366 tensor of shape [batch_size, states, states] with a state transition matrix 367 computed for each sequence of the batch. 368 """ 369 370 with ops.name_scope("ctc_state_trans"): 371 label_seq = ops.convert_to_tensor(label_seq, name="label_seq") 372 batch_size = _get_dim(label_seq, 0) 373 num_labels = _get_dim(label_seq, 1) 374 375 num_label_states = num_labels + 1 376 num_states = 2 * num_label_states 377 378 label_states = math_ops.range(num_label_states) 379 blank_states = label_states + num_label_states 380 381 # Start state to first label. 382 start_to_label = [[1, 0]] 383 384 # Blank to label transitions. 385 blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1) 386 387 # Label to blank transitions. 388 label_to_blank = array_ops.stack([blank_states, label_states], 1) 389 390 # Scatter transitions that don't depend on sequence. 391 indices = array_ops.concat( 392 [start_to_label, blank_to_label, label_to_blank], 0) 393 values = array_ops.ones([_get_dim(indices, 0)]) 394 trans = array_ops.scatter_nd( 395 indices, values, shape=[num_states, num_states]) 396 trans += linalg_ops.eye(num_states) # Self-loops. 397 398 # Label to label transitions. Disallow transitions between repeated labels 399 # with no blank state in between. 400 batch_idx = array_ops.zeros_like(label_states[2:]) 401 indices = array_ops.stack( 402 [batch_idx, label_states[2:], label_states[1:-1]], 1) 403 indices = array_ops.tile( 404 array_ops.expand_dims(indices, 0), [batch_size, 1, 1]) 405 batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0] 406 indices += array_ops.expand_dims(batch_idx, 1) 407 repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:]) 408 values = 1.0 - math_ops.cast(repeats, dtypes.float32) 409 batched_shape = [batch_size, num_states, num_states] 410 label_to_label = array_ops.scatter_nd(indices, values, batched_shape) 411 412 return array_ops.expand_dims(trans, 0) + label_to_label 413 414 415def ctc_state_log_probs(seq_lengths, max_seq_length): 416 """Computes CTC alignment initial and final state log probabilities. 417 418 Create the initial/final state values directly as log values to avoid 419 having to take a float64 log on tpu (which does not exist). 420 421 Args: 422 seq_lengths: int tensor of shape [batch_size], seq lengths in the batch. 423 max_seq_length: int, max sequence length possible. 424 425 Returns: 426 initial_state_log_probs, final_state_log_probs 427 """ 428 429 batch_size = _get_dim(seq_lengths, 0) 430 num_label_states = max_seq_length + 1 431 num_duration_states = 2 432 num_states = num_duration_states * num_label_states 433 log_0 = math_ops.cast( 434 math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), 435 dtypes.float32) 436 437 initial_state_log_probs = array_ops.one_hot( 438 indices=array_ops.zeros([batch_size], dtype=dtypes.int32), 439 depth=num_states, 440 on_value=0.0, 441 off_value=log_0, axis=1) 442 443 label_final_state_mask = array_ops.one_hot( 444 seq_lengths, depth=num_label_states, axis=0) 445 duration_final_state_mask = array_ops.ones( 446 [num_duration_states, 1, batch_size]) 447 final_state_mask = duration_final_state_mask * label_final_state_mask 448 final_state_log_probs = (1.0 - final_state_mask) * log_0 449 final_state_log_probs = array_ops.reshape( 450 final_state_log_probs, [num_states, batch_size]) 451 452 return initial_state_log_probs, array_ops.transpose(final_state_log_probs) 453 454 455def _ilabel_to_state(labels, num_labels, ilabel_log_probs): 456 """Project ilabel log probs to state log probs.""" 457 458 num_label_states = _get_dim(labels, 1) 459 blank = ilabel_log_probs[:, :, :1] 460 blank = array_ops.tile(blank, [1, 1, num_label_states + 1]) 461 one_hot = array_ops.one_hot(labels, depth=num_labels) 462 one_hot = array_ops.expand_dims(one_hot, axis=0) 463 ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2) 464 state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3) 465 state_log_probs = array_ops.concat([state_log_probs, blank], axis=2) 466 return array_ops.pad( 467 state_log_probs, [[0, 0], [0, 0], [1, 0]], 468 constant_values=math_ops.log(0.0)) 469 470 471def _state_to_olabel(labels, num_labels, states): 472 """Sum state log probs to ilabel log probs.""" 473 474 num_label_states = _get_dim(labels, 1) + 1 475 label_states = states[:, :, 1:num_label_states] 476 blank_states = states[:, :, num_label_states:] 477 one_hot = array_ops.one_hot( 478 labels - 1, depth=(num_labels - 1), 479 on_value=0.0, off_value=math_ops.log(0.0)) 480 one_hot = array_ops.expand_dims(one_hot, axis=0) 481 label_states = array_ops.expand_dims(label_states, axis=3) 482 label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2) 483 blank_olabels = math_ops.reduce_logsumexp( 484 blank_states, axis=2, keepdims=True) 485 return array_ops.concat([blank_olabels, label_olabels], axis=-1) 486 487 488# pylint: disable=redefined-outer-name 489def _state_to_olabel_unique(labels, num_labels, states, unique): 490 """Sum state log probs to ilabel log probs using unique label indices.""" 491 492 num_label_states = _get_dim(labels, 1) + 1 493 label_states = states[:, :, 1:num_label_states] 494 blank_states = states[:, :, num_label_states:] 495 496 unique_y, unique_idx = unique 497 mul_reduce = _sum_states(unique_idx, label_states) 498 499 num_frames = states.shape[0] 500 batch_size = states.shape[1] 501 num_states = num_label_states - 1 502 batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0]) 503 batch_state_major = array_ops.reshape( 504 batch_state_major, [batch_size * num_states, num_frames]) 505 batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels 506 indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1) 507 indices = array_ops.reshape(indices, [-1, 1]) 508 scatter = array_ops.scatter_nd( 509 indices=indices, 510 updates=batch_state_major, 511 shape=[batch_size * num_labels, num_frames]) 512 scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames]) 513 scatter = array_ops.where( 514 math_ops.equal(scatter, 0.0), 515 array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)), 516 scatter) 517 label_olabels = array_ops.transpose(scatter, [2, 0, 1]) 518 label_olabels = label_olabels[:, :, 1:] 519 520 blank_olabels = math_ops.reduce_logsumexp( 521 blank_states, axis=2, keepdims=True) 522 523 return array_ops.concat([blank_olabels, label_olabels], axis=-1) 524 525 526def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None): 527 """Computes the CTC loss and gradients. 528 529 Most users will want fwd_bwd.ctc_loss 530 531 This function returns the computed gradient, it does not have a gradient 532 of its own defined. 533 534 Args: 535 logits: tensor of shape [frames, batch_size, num_labels] 536 labels: tensor of shape [batch_size, max_label_seq_length] 537 label_length: tensor of shape [batch_size] 538 Length of reference label sequence in labels. 539 logit_length: tensor of shape [batch_size] 540 Length of input sequence in logits. 541 unique: (optional) unique label indices as computed by unique(labels) 542 If supplied, enables an implementation that is faster and more memory 543 efficient on TPU. 544 545 Returns: 546 loss: tensor of shape [batch_size] 547 gradient: tensor of shape [frames, batch_size, num_labels] 548 """ 549 550 num_labels = _get_dim(logits, 2) 551 max_label_seq_length = _get_dim(labels, 1) 552 553 ilabel_log_probs = nn_ops.log_softmax(logits) 554 state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs) 555 state_trans_probs = _ctc_state_trans(labels) 556 initial_state_log_probs, final_state_log_probs = ctc_state_log_probs( 557 label_length, max_label_seq_length) 558 fwd_bwd_log_probs, log_likelihood = _forward_backward_log( 559 state_trans_log_probs=math_ops.log(state_trans_probs), 560 initial_state_log_probs=initial_state_log_probs, 561 final_state_log_probs=final_state_log_probs, 562 observed_log_probs=state_log_probs, 563 sequence_length=logit_length) 564 565 if unique: 566 olabel_log_probs = _state_to_olabel_unique( 567 labels, num_labels, fwd_bwd_log_probs, unique) 568 else: 569 olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs) 570 571 grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs) 572 loss = -log_likelihood 573 return loss, grad 574 575 576def _ctc_loss_grad(op, grad_loss, _): 577 grad = op.outputs[1] 578 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad] 579 grad += [None] * (len(op.inputs) - len(grad)) 580 return grad 581 582 583def _ctc_loss_shape(op): 584 return [op.inputs[2].get_shape(), op.inputs[0].get_shape()] 585 586 587@tf_export("nn.ctc_loss", v1=["nn.ctc_loss_v2"]) 588def ctc_loss_v2(labels, logits, label_length, logit_length, 589 logits_time_major=True, unique=None, 590 blank_index=None, name=None): 591 """Computes CTC (Connectionist Temporal Classification) loss. 592 593 This op implements the CTC loss as presented in the article: 594 595 [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber. 596 Connectionist Temporal Classification: Labeling Unsegmented Sequence Data 597 with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, 598 pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf) 599 600 Notes: 601 - Same as the "Classic CTC" in TensorFlow 1.x's tf.nn.ctc_loss setting of 602 preprocess_collapse_repeated=False, ctc_merge_repeated=True 603 - Labels may be supplied as either a dense, zero-padded tensor with a 604 vector of label sequence lengths OR as a SparseTensor. 605 - On TPU and GPU: 606 - Only dense padded labels are supported. 607 - On CPU: 608 - Caller may use SparseTensor or dense padded labels but calling with 609 a SparseTensor will be significantly faster. 610 - Default blank label is 0 rather num_classes - 1, unless overridden by 611 blank_index. 612 613 Args: 614 labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor 615 logits: tensor of shape [frames, batch_size, num_labels], 616 if logits_time_major == False, shape is [batch_size, frames, num_labels]. 617 label_length: tensor of shape [batch_size], None if labels is SparseTensor 618 Length of reference label sequence in labels. 619 logit_length: tensor of shape [batch_size] 620 Length of input sequence in logits. 621 logits_time_major: (optional) If True (default), logits is shaped 622 [time, batch, logits]. If False, shape is [batch, time, logits] 623 unique: (optional) Unique label indices as computed by 624 ctc_unique_labels(labels). If supplied, enable a faster, memory 625 efficient implementation on TPU. 626 blank_index: (optional) Set the class index to use for the blank label. 627 Negative values will start from num_classes, ie, -1 will reproduce the 628 ctc_loss behavior of using num_classes - 1 for the blank symbol. 629 There is some memory/performance overhead to switching from the default 630 of 0 as an additional shifted copy of the logits may be created. 631 name: A name for this `Op`. Defaults to "ctc_loss_dense". 632 633 Returns: 634 loss: tensor of shape [batch_size], negative log probabilities. 635 """ 636 if isinstance(labels, sparse_tensor.SparseTensor): 637 if blank_index is None: 638 raise ValueError( 639 "blank_index must be given when using SparseTensor labels.") 640 641 if blank_index < 0: 642 blank_index += _get_dim(logits, 2) 643 644 if blank_index != _get_dim(logits, 2) - 1: 645 logits = array_ops.concat([ 646 logits[:, :, :blank_index], 647 logits[:, :, blank_index+1:], 648 logits[:, :, blank_index:blank_index+1], 649 ], axis=2) 650 labels = sparse_tensor.SparseTensor( 651 labels.indices, 652 array_ops.where(labels.values < blank_index, 653 labels.values, 654 labels.values - 1), 655 labels.dense_shape) 656 657 return ctc_loss(labels=labels, 658 inputs=logits, 659 sequence_length=logit_length, 660 time_major=logits_time_major) 661 662 if blank_index is None: 663 blank_index = 0 664 665 return ctc_loss_dense(labels=labels, 666 logits=logits, 667 label_length=label_length, 668 logit_length=logit_length, 669 logits_time_major=logits_time_major, 670 unique=unique, 671 blank_index=blank_index, 672 name=name) 673 674 675def ctc_loss_dense(labels, logits, label_length, logit_length, 676 logits_time_major=True, unique=None, 677 blank_index=0, name=None): 678 """Computes CTC (Connectionist Temporal Classification) loss. 679 680 This op implements the CTC loss as presented in the article: 681 682 [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber. 683 Connectionist Temporal Classification: Labeling Unsegmented Sequence Data 684 with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, 685 pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf) 686 687 Using the batched forward backward algorithm described in: 688 689 [Sim, K. C., Narayanan, A., Bagby, T., Sainath, T. N., & Bacchiani, M. 690 Improving the efficiency of forward-backward algorithm using batched 691 computation in TensorFlow. 692 Automatic Speech Recognition and Understanding Workshop (ASRU), 693 2017 IEEE (pp. 258-264). 694 ](https://ieeexplore.ieee.org/iel7/8260578/8268903/08268944.pdf) 695 696 Notes: 697 Significant differences from tf.nn.ctc_loss: 698 Supports GPU and TPU (tf.nn.ctc_loss supports CPU only): 699 For batched operations, GPU and TPU are significantly faster than using 700 ctc_loss on CPU. 701 This implementation runs on CPU, but significantly slower than ctc_loss. 702 Blank label is 0 rather num_classes - 1, unless overridden by blank_index. 703 Logits and labels are dense arrays with padding rather than SparseTensor. 704 The only mode supported is the same as: 705 preprocess_collapse_repeated=False, ctc_merge_repeated=True 706 To collapse labels, the caller can preprocess label sequence first. 707 708 The dense implementation supports both CPU, GPU and TPU. A fast path is 709 provided that significantly improves memory use for large vocabulary if the 710 caller preprocesses label sequences to get unique label indices on the CPU 711 (eg. in the data input pipeline) using ctc_ops.unique and simplies this in 712 the optional "unique" kwarg. This is especially useful for TPU and GPU but 713 also works with if used on CPU. 714 715 Args: 716 labels: tensor of shape [batch_size, max_label_seq_length] 717 logits: tensor of shape [frames, batch_size, num_labels], 718 if logits_time_major == False, shape is [batch_size, frames, num_labels]. 719 label_length: tensor of shape [batch_size] 720 Length of reference label sequence in labels. 721 logit_length: tensor of shape [batch_size] 722 Length of input sequence in logits. 723 logits_time_major: (optional) If True (default), logits is shaped 724 [time, batch, logits]. If False, shape is [batch, time, logits] 725 unique: (optional) Unique label indices as computed by unique(labels). 726 If supplied, enable a faster, memory efficient implementation on TPU. 727 blank_index: (optional) Set the class index to use for the blank label. 728 Negative values will start from num_classes, ie, -1 will reproduce the 729 ctc_loss behavior of using num_classes - 1 for the blank symbol. 730 There is some memory/performance overhead to switching from the default 731 of 0 as an additional shifted copy of the logits may be created. 732 name: A name for this `Op`. Defaults to "ctc_loss_dense". 733 734 Returns: 735 loss: tensor of shape [batch_size], negative log probabilities. 736 """ 737 738 with ops.name_scope(name, "ctc_loss_dense", 739 [logits, labels, label_length, logit_length]): 740 logits = ops.convert_to_tensor(logits, name="logits") 741 labels = ops.convert_to_tensor(labels, name="labels") 742 label_length = ops.convert_to_tensor(label_length, name="label_length") 743 logit_length = ops.convert_to_tensor(logit_length, name="logit_length") 744 745 if not logits_time_major: 746 logits = array_ops.transpose(logits, perm=[1, 0, 2]) 747 748 if blank_index != 0: 749 if blank_index < 0: 750 blank_index += _get_dim(logits, 2) 751 logits = array_ops.concat([ 752 logits[:, :, blank_index:blank_index+1], 753 logits[:, :, :blank_index], 754 logits[:, :, blank_index+1:], 755 ], axis=2) 756 labels = array_ops.where(labels < blank_index, labels + 1, labels) 757 758 args = [logits, labels, label_length, logit_length] 759 760 if unique: 761 unique_y, unique_idx = unique 762 args.extend([unique_y, unique_idx]) 763 764 # TODO(tombagby): Update to tfe.defun 765 @function.Defun(*[x.dtype for x in args], 766 python_grad_func=_ctc_loss_grad, 767 shape_func=_ctc_loss_shape) 768 def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t, 769 *unique_t): 770 """Compute CTC loss.""" 771 logits_t.set_shape(logits.shape) 772 labels_t.set_shape(labels.shape) 773 label_length_t.set_shape(label_length.shape) 774 logit_length_t.set_shape(logit_length.shape) 775 kwargs = dict( 776 logits=logits_t, 777 labels=labels_t, 778 label_length=label_length_t, 779 logit_length=logit_length_t) 780 if unique_t: 781 kwargs["unique"] = unique_t 782 return ctc_loss_and_grad(**kwargs) 783 784 return compute_ctc_loss(*args)[0] 785 786 787@tf_export("nn.collapse_repeated") 788def collapse_repeated(labels, seq_length, name=None): 789 """Merge repeated labels into single labels. 790 791 Args: 792 labels: Tensor of shape (batch, max value in seq_length) 793 seq_length: Tensor of shape (batch), sequence length of each batch element. 794 name: A name for this `Op`. Defaults to "collapse_repeated_labels". 795 796 Returns: 797 tuple of Tensor of shape (batch, max_seq_length) with repeated labels 798 collapsed and padded to max_seq_length, eg: 799 [[A, A, B, B, A], 800 [A, B, C, D, E]] => [[A, B, A, 0, 0], 801 [A, B, C, D, E]] 802 and int tensor of shape [batch] with new sequence lengths. 803 """ 804 805 with ops.name_scope(name, "collapse_repeated_labels", 806 [labels, seq_length]): 807 labels = ops.convert_to_tensor(labels, name="labels") 808 seq_length = ops.convert_to_tensor(seq_length, name="seq_length") 809 810 # Mask labels that don't equal previous label. 811 label_mask = array_ops.concat( 812 [array_ops.ones_like(labels[:, :1], dtypes.bool), 813 math_ops.not_equal(labels[:, 1:], labels[:, :-1])], 814 axis=1) 815 816 # Filter labels that aren't in the original sequence. 817 maxlen = _get_dim(labels, 1) 818 seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen) 819 label_mask = math_ops.logical_and(label_mask, seq_mask) 820 821 # Count masks for new sequence lengths. 822 new_seq_len = math_ops.reduce_sum( 823 math_ops.cast(label_mask, dtypes.int32), axis=1) 824 825 # Mask indexes based on sequence length mask. 826 new_maxlen = math_ops.reduce_max(new_seq_len) 827 idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen) 828 829 # Flatten everything and mask out labels to keep and sparse indices. 830 flat_labels = array_ops.reshape(labels, [-1]) 831 flat_label_mask = array_ops.reshape(label_mask, [-1]) 832 flat_idx_mask = array_ops.reshape(idx_mask, [-1]) 833 idx = math_ops.range(_get_dim(flat_idx_mask, 0)) 834 835 # Scatter to flat shape. 836 flat = array_ops.scatter_nd( 837 indices=array_ops.expand_dims( 838 array_ops.boolean_mask(idx, flat_idx_mask), axis=1), 839 updates=array_ops.boolean_mask(flat_labels, flat_label_mask), 840 shape=array_ops.shape(flat_idx_mask)) 841 842 # Reshape back to square batch. 843 batch_size = _get_dim(labels, 0) 844 new_shape = [batch_size, new_maxlen] 845 return (array_ops.reshape(flat, new_shape), 846 math_ops.cast(new_seq_len, seq_length.dtype)) 847 848 849def dense_labels_to_sparse(dense, length): 850 """Convert dense labels with sequence lengths to sparse tensor. 851 852 Args: 853 dense: tensor of shape [batch, max_length] 854 length: int tensor of shape [batch] 855 The length of each sequence in dense. 856 857 Returns: 858 tf.SparseTensor with values only for the valid elements of sequences. 859 """ 860 861 flat_values = array_ops.reshape(dense, [-1]) 862 flat_indices = math_ops.range( 863 array_ops.shape(flat_values, out_type=dtypes.int64)[0]) 864 mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1]) 865 flat_mask = array_ops.reshape(mask, [-1]) 866 indices = array_ops.expand_dims( 867 array_ops.boolean_mask(flat_indices, flat_mask), 1) 868 values = array_ops.boolean_mask(flat_values, flat_mask) 869 sparse = sparse_tensor.SparseTensor( 870 indices=indices, values=math_ops.cast(values, dtypes.int32), 871 dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64)) 872 reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense)) 873 max_length = math_ops.reduce_max(length) 874 return sparse_tensor.SparseTensor( 875 indices=reshaped.indices, 876 values=reshaped.values, 877 dense_shape=[ 878 math_ops.cast(reshaped.dense_shape[0], dtypes.int64), 879 math_ops.cast(max_length, dtypes.int64)]) 880 881 882@tf_export("nn.ctc_unique_labels") 883def ctc_unique_labels(labels, name=None): 884 """Get unique labels and indices for batched labels for tf.nn.ctc_loss. 885 886 For use with tf.nn.ctc_loss_v2 optional argument `unique`: This op can be 887 used to preprocess labels in input pipeline to for better speed/memory use 888 computing the ctc loss on TPU. 889 890 Example: 891 ctc_unique_labels([[3, 4, 4, 3]]) -> 892 unique labels padded with 0: [[3, 4, 0, 0]] 893 indices of original labels in unique: [0, 1, 1, 0] 894 895 Args: 896 labels: tensor of shape [batch_size, max_label_length] padded with 0. 897 name: A name for this `Op`. Defaults to "ctc_unique_labels". 898 899 Returns: 900 tuple of 901 - unique labels, tensor of shape `[batch_size, max_label_length]` 902 - indices into unique labels, shape `[batch_size, max_label_length]` 903 """ 904 905 with ops.name_scope(name, "ctc_unique_labels", [labels]): 906 labels = ops.convert_to_tensor(labels, name="labels") 907 def _unique(x): 908 u = array_ops.unique(x) 909 y = array_ops.pad( 910 u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]]) 911 y = math_ops.cast(y, dtypes.int64) 912 return [y, u.idx] 913 return map_fn.map_fn( 914 _unique, labels, dtype=[dtypes.int64, dtypes.int32]) 915 916 917def _sum_states(idx, states): 918 """Take logsumexp for each unique state out of all label states. 919 920 Args: 921 idx: tensor of shape [batch, label_length] 922 For each sequence, indices into a set of unique labels as computed by 923 calling unique. 924 states: tensor of shape [frames, batch, label_length] 925 Log probabilities for each label state. 926 927 Returns: 928 tensor of shape [frames, batch_size, label_length], log probabilites summed 929 for each unique label of the sequence. 930 """ 931 932 with ops.name_scope("sum_states"): 933 idx = ops.convert_to_tensor(idx, name="idx") 934 num_states = _get_dim(states, 2) 935 states = array_ops.expand_dims(states, axis=2) 936 one_hot = array_ops.one_hot( 937 idx, depth=num_states, on_value=0.0, off_value=math_ops.log(0.0), 938 axis=1) 939 return math_ops.reduce_logsumexp(states + one_hot, axis=-1) 940 941 942def _forward_backward_log(state_trans_log_probs, initial_state_log_probs, 943 final_state_log_probs, observed_log_probs, 944 sequence_length): 945 """Forward-backward algorithm computed in log domain. 946 947 Args: 948 state_trans_log_probs: tensor of shape [states, states] or 949 if different transition matrix per batch [batch_size, states, states] 950 initial_state_log_probs: tensor of shape [batch_size, states] 951 final_state_log_probs: tensor of shape [batch_size, states] 952 observed_log_probs: tensor of shape [frames, batch_size, states] 953 sequence_length: tensor of shape [batch_size] 954 955 Returns: 956 forward backward log probabilites: tensor of shape [frames, batch, states] 957 log_likelihood: tensor of shape [batch_size] 958 959 Raises: 960 ValueError: If state_trans_log_probs has unknown or incorrect rank. 961 """ 962 963 if state_trans_log_probs.shape.ndims == 2: 964 perm = [1, 0] 965 elif state_trans_log_probs.shape.ndims == 3: 966 perm = [0, 2, 1] 967 else: 968 raise ValueError( 969 "state_trans_log_probs rank must be known and == 2 or 3, is: %s" % 970 state_trans_log_probs.shape.ndims) 971 972 bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm) 973 batch_size = _get_dim(observed_log_probs, 1) 974 975 def _forward(state_log_prob, obs_log_prob): 976 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. 977 state_log_prob += state_trans_log_probs 978 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) 979 state_log_prob += obs_log_prob 980 log_prob_sum = math_ops.reduce_logsumexp( 981 state_log_prob, axis=-1, keepdims=True) 982 state_log_prob -= log_prob_sum 983 return state_log_prob 984 985 fwd = _scan(_forward, observed_log_probs, initial_state_log_probs, 986 inclusive=True) 987 988 def _backward(accs, elems): 989 """Calculate log probs and cumulative sum masked for sequence length.""" 990 state_log_prob, cum_log_sum = accs 991 obs_log_prob, mask = elems 992 state_log_prob += obs_log_prob 993 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. 994 state_log_prob += bwd_state_trans_log_probs 995 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) 996 997 log_prob_sum = math_ops.reduce_logsumexp( 998 state_log_prob, axis=-1, keepdims=True) 999 state_log_prob -= log_prob_sum 1000 1001 cum_log_sum += array_ops.squeeze(log_prob_sum) * mask 1002 batched_mask = array_ops.expand_dims(mask, axis=1) 1003 out = state_log_prob * batched_mask 1004 out += final_state_log_probs * (1.0 - batched_mask) 1005 return out, cum_log_sum 1006 1007 zero_log_sum = array_ops.zeros([batch_size]) 1008 maxlen = _get_dim(observed_log_probs, 0) 1009 mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32) 1010 mask = array_ops.transpose(mask, perm=[1, 0]) 1011 1012 bwd, cum_log_sum = _scan(_backward, (observed_log_probs, mask), 1013 (final_state_log_probs, zero_log_sum), 1014 reverse=True, inclusive=True) 1015 1016 fwd_bwd_log_probs = fwd[1:] + bwd[1:] 1017 fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp( 1018 fwd_bwd_log_probs, axis=2, keepdims=True) 1019 fwd_bwd_log_probs -= fwd_bwd_log_probs_sum 1020 fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2)) 1021 1022 log_likelihood = bwd[0, :, 0] + cum_log_sum[0] 1023 1024 return fwd_bwd_log_probs, log_likelihood 1025 1026 1027# TODO(tombagby): This is currently faster for the ctc implementation than using 1028# functional_ops.scan, but could be replaced by that or something similar if 1029# things change. 1030def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False): 1031 """Repeatedly applies callable `fn` to a sequence of elements. 1032 1033 Implemented by functional_ops.While, tpu friendly, no gradient. 1034 1035 This is similar to functional_ops.scan but significantly faster on tpu/gpu 1036 for the forward backward use case. 1037 1038 Examples: 1039 scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0] 1040 1041 Multiple accumulators: 1042 scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0)) 1043 1044 Multiple inputs: 1045 scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0) 1046 1047 Args: 1048 fn: callable, fn(accumulators, element) return new accumulator values. 1049 The (possibly nested) sequence of accumulators is the same as `initial` 1050 and the return value must have the same structure. 1051 elems: A (possibly nested) tensor which will be unpacked along the first 1052 dimension. The resulting slices will be the second argument to fn. The 1053 first dimension of all nested input tensors must be the same. 1054 initial: A tensor or (possibly nested) sequence of tensors with initial 1055 values for the accumulators. 1056 reverse: (optional) True enables scan and output elems in reverse order. 1057 inclusive: (optional) True includes the initial accumulator values in the 1058 output. Length of output will be len(elem sequence) + 1. Not meaningful 1059 if final_only is True. 1060 final_only: (optional) When True, return only the final accumulated values, 1061 not the concatenation of accumulated values for each input. 1062 1063 Returns: 1064 A (possibly nested) sequence of tensors with the results of applying fn 1065 to tensors unpacked from elems and previous accumulator values. 1066 """ 1067 1068 flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)] 1069 num_elems = array_ops.shape(flat_elems[0])[0] 1070 pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x) 1071 flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)] 1072 pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x) 1073 accum_dtypes = [x.dtype for x in flat_initial] 1074 num_accums = len(flat_initial) 1075 1076 # Types for counter, [outputs], [accumulators] loop arguments. 1077 if final_only: 1078 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes 1079 else: 1080 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes 1081 1082 # TODO(tombagby): Update to tfe.defun 1083 @function.Defun(*loop_dtypes) 1084 def cond(i, num_elems, *args): 1085 del args 1086 return i >= 0 if reverse else i < num_elems 1087 1088 # The loop *args are [output tensors] + [accumulator tensors] which must 1089 # be paired. Each output corresponds to one accumulator. 1090 @function.Defun(*loop_dtypes) 1091 def body(i, num_elems, *args): 1092 """Loop body.""" 1093 i.set_shape([]) 1094 if final_only: 1095 accum = args 1096 else: 1097 out, accum = args[:num_accums], args[num_accums:] 1098 slices = [array_ops.gather(e, i) for e in flat_elems] 1099 accum = fn(pack(accum), pack_elems(slices)) 1100 flat_accum = nest.flatten(accum) 1101 if final_only: 1102 new_out = [] 1103 else: 1104 update_i = i + 1 if inclusive and not reverse else i 1105 new_out = [inplace_ops.alias_inplace_update(x, update_i, y) 1106 for x, y in zip(out, flat_accum)] 1107 i = i - 1 if reverse else i + 1 1108 return [i, num_elems] + new_out + flat_accum 1109 1110 init_i = (array_ops.shape(flat_elems[0])[0] - 1 if reverse 1111 else constant_op.constant(0, dtype=dtypes.int32)) 1112 outputs = [] 1113 if not final_only: 1114 num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0) 1115 for initial_accum in flat_initial: 1116 out_shape = array_ops.concat( 1117 [[num_outputs], array_ops.shape(initial_accum)], 0) 1118 out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True) 1119 if inclusive: 1120 out = inplace_ops.alias_inplace_add( 1121 out, init_i + (1 if reverse else 0), initial_accum) 1122 outputs.append(out) 1123 loop_in = [init_i, num_elems] + outputs + flat_initial 1124 hostmem = [ 1125 i for i, x in enumerate(loop_in) 1126 if x.dtype.base_dtype in (dtypes.int32, dtypes.int64) 1127 ] 1128 1129 # TODO(tombagby): Update to while_v2. 1130 loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem) 1131 out = loop_results[2:num_accums + 2] 1132 return pack(out) 1133 1134 1135def _get_dim(tensor, i): 1136 """Get value of tensor shape[i] preferring static value if available.""" 1137 return tensor_shape.dimension_value( 1138 tensor.shape[i]) or array_ops.shape(tensor)[i] 1139