1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CTC (Connectionist Temporal Classification) Operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import uuid
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import function as function_eager
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import device
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import function
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.framework import tensor_shape
33
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import custom_gradient
36from tensorflow.python.ops import functional_ops
37from tensorflow.python.ops import gen_ctc_ops
38from tensorflow.python.ops import inplace_ops
39from tensorflow.python.ops import linalg_ops
40from tensorflow.python.ops import map_fn
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import nn_ops
43from tensorflow.python.ops import sparse_ops
44from tensorflow.python.ops.nn_grad import _BroadcastMul
45from tensorflow.python.util import deprecation
46from tensorflow.python.util import dispatch
47from tensorflow.python.util import nest
48from tensorflow.python.util.tf_export import tf_export
49
# Attribute keys attached to the functions built by `_generate_defun_backend`.
_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
# Device-type names.
_CPU_DEVICE_NAME, _GPU_DEVICE_NAME = "CPU", "GPU"
54
55
def _get_context_device_type():
  """Returns the device type (e.g. CPU/GPU) of the current eager context.

  Returns:
    The `device_type` parsed from the context's device name, or None when the
    context has no device name set.
  """
  device_name = context.context().device_name
  if device_name is not None:
    return device.DeviceSpec.from_string(device_name).device_type
  return None
62
63
def _generate_defun_backend(unique_api_name, preferred_device, func):
  """Wraps `func` in a defun tagged with an API name and a preferred device.

  Args:
    unique_api_name: value stored under the `api_implements` attribute.
    preferred_device: value stored under the `api_preferred_device` attribute.
    func: the Python function to wrap.

  Returns:
    The result of `function_eager.defun_with_attributes`, built with autograph
    disabled.
  """
  return function_eager.defun_with_attributes(
      func=func,
      attributes={
          _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
          _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
      },
      autograph=False)
71
72# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels,
             inputs=None,
             sequence_length=None,
             preprocess_collapse_repeated=False,
             ctc_merge_repeated=True,
             ignore_longer_outputs_than_inputs=False,
             time_major=True,
             logits=None):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Input requirements:

  ```
  sequence_length(b) <= time for all b

  max(labels.indices[labels.indices[:, 0] == b, 1])
    <= sequence_length(b) for all b.
  ```

  Notes:

  This function performs the softmax operation for you, so inputs should
  be e.g. linear projections of outputs by an LSTM.

  The `inputs` Tensor's innermost dimension size, `num_classes`, represents
  `num_labels + 1` classes, where num_labels is the number of true labels, and
  the largest value `(num_classes - 1)` is reserved for the blank label.

  For example, for a vocabulary containing 3 labels `[a, b, c]`,
  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.

  Regarding the arguments `preprocess_collapse_repeated` and
  `ctc_merge_repeated`:

  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
  before loss calculation, wherein repeated labels passed to the loss
  are merged into single labels.  This is useful if the training labels come
  from, e.g., forced alignments and therefore have unnecessary repetitions.

  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
  repeated non-blank labels will not be merged and are interpreted
  as individual labels.  This is a simplified (non-standard) version of CTC.

  Here is a table of the (roughly) expected first order behavior:

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`

    Classical CTC behavior: Outputs true repeated classes with blanks in
    between, and can also output repeated classes with no blanks in
    between that need to be collapsed by the decoder.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`

    Never learns to output repeated classes, as they are collapsed
    in the input labels before training.

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`

    Outputs repeated classes with blanks in between, but generally does not
    require the decoder to collapse/merge repeated classes.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`

    Untested.  Very likely will not learn to output repeated classes.

  The `ignore_longer_outputs_than_inputs` option specifies the behavior of
  the CTCLoss when dealing with sequences that have longer outputs than
  inputs. If true, the CTCLoss will simply return zero gradient for those
  items, otherwise an InvalidArgument error is returned, stopping training.

  Args:
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
        for (batch b, time t). `labels.values[i]` must take on values in `[0,
        num_labels)`. See `core/ops/ctc_ops.cc` for more details.
    inputs: 3-D `float` `Tensor`.
      If time_major == False, this will be a `Tensor` shaped: `[batch_size,
        max_time, num_classes]`.
      If time_major == True (default), this will be a `Tensor` shaped:
        `[max_time, batch_size, num_classes]`. The logits.
    sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
      lengths.
    preprocess_collapse_repeated: Boolean.  Default: False. If True, repeated
      labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean.  Default: True. If False, repeated non-blank
      labels are not merged inside the CTC calculation (see above).
    ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
      sequences with longer outputs than inputs will be ignored.
    time_major: The shape format of the `inputs` Tensors. If True, these
      `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
      these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
      Using `time_major = True` (default) is a bit more efficient because it
      avoids transposes at the beginning of the ctc_loss calculation.  However,
      most TensorFlow data is batch-major, so this function also accepts
      inputs in batch-major form.
    logits: Alias for inputs.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
      probabilities.

  Raises:
    TypeError: if labels is not a `SparseTensor`.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  # The v1 API always uses the non-cuDNN kernel (blank = last class).
  return _ctc_loss_impl(
      labels,
      inputs=inputs,
      sequence_length=sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs,
      time_major=time_major,
      logits=logits,
      use_cudnn=False)
196
197
def _ctc_loss_impl(labels,
                   inputs=None,
                   sequence_length=None,
                   preprocess_collapse_repeated=False,
                   ctc_merge_repeated=True,
                   ignore_longer_outputs_than_inputs=False,
                   time_major=True,
                   logits=None,
                   use_cudnn=False):
  """Shared implementation of `ctc_loss` with an extra `use_cudnn` switch.

  Args mirror `ctc_loss`, plus:
    use_cudnn: if True, dispatch to `gen_ctc_ops.ctc_loss_v2`, which requires
      the blank index to be 0; otherwise use `gen_ctc_ops.ctc_loss`, which
      treats the last class as blank.

  Returns:
    The loss output (output 0) of the selected CTC loss op.

  Raises:
    TypeError: if `labels` is not a `SparseTensor`.
  """
  if not isinstance(labels, sparse_tensor.SparseTensor):
    raise TypeError("Expected labels (first argument) to be a SparseTensor")

  inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                  inputs)
  # The kernels work on time-major [time, batch, num_classes] data.
  if not time_major:
    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)

  ctc_loss_func = (
      gen_ctc_ops.ctc_loss_v2 if use_cudnn else gen_ctc_ops.ctc_loss)

  # Output 1 holds the gradient that the registered gradient functions
  # (_CTCLossGrad / _CTCLossV2Grad) read from the op; it is not returned here.
  loss, _ = ctc_loss_func(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
  return loss
239
240# pylint: disable=unused-argument
# Reason reported by prevent_gradient when a second derivative of the CTC loss
# is requested.
_CTC_SECOND_DERIVATIVE_MESSAGE = (
    "Currently there is no way to take the second derivative of ctc_loss due "
    "to the fused implementation's interaction with tf.gradients()")


def _CTCLossGradImpl(op, grad_loss, _):
  """Shared gradient implementation for the CTCLoss and CTCLossV2 ops.

  Args:
    op: the CTC loss op. Its outputs are (loss, grad), where `grad` is the
      gradient precomputed by the fused kernel.
    grad_loss: the incoming gradient for the loss output.
    _: unused incoming gradient for the op's second output.

  Returns:
    A list with the gradient for the op's first input followed by `None` for
    labels_indices, labels_values and sequence_length.
  """
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1], message=_CTC_SECOND_DERIVATIVE_MESSAGE)
  # Return gradient for inputs and None for
  # labels_indices, labels_values and sequence_length
  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
256
257
258# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """Gradient function registered for the `CTCLoss` op.

  Args:
     op: the CTCLoss op.
     grad_loss: The backprop for cost.
     _: gradient for the op's second output; forwarded but ignored.

  Returns:
     The CTC Loss gradient, as computed by `_CTCLossGradImpl`.
  """
  grads = _CTCLossGradImpl(op, grad_loss, _)
  return grads
271
272
273# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """Gradient function registered for the `CTCLossV2` op.

  Args:
     op: the CTCLossV2 op.
     grad_loss: The backprop for cost.
     _: gradient for the op's second output; forwarded but ignored.

  Returns:
     The CTC Loss V2 gradient, as computed by `_CTCLossGradImpl`.
  """
  grads = _CTCLossGradImpl(op, grad_loss, _)
  return grads
286
287
@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
  """Performs greedy decoding on the logits given in input (best path).

  Note: Regardless of the value of merge_repeated, if the maximum index of a
  given time and batch corresponds to the blank index `(num_classes - 1)`, no
  new element is emitted.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    merge_repeated: Boolean.  Default: True.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded[0].indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded[0].values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded[0].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
  """
  decoded_ix, decoded_val, decoded_shape, log_probabilities = (
      gen_ctc_ops.ctc_greedy_decoder(
          inputs, sequence_length, merge_repeated=merge_repeated))
  decoded = sparse_tensor.SparseTensor(decoded_ix, decoded_val, decoded_shape)
  return [decoded], log_probabilities
336
337
@tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs,
                            sequence_length,
                            beam_width=100,
                            top_paths=1,
                            merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  If `merge_repeated` is `True`, merge repeated classes in the output beams.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted.  That is, when the sequence is
  `A B B * B * B` (where '*' is the blank label), the return value is:

    * `A B` if `merge_repeated = True`.
    * `A B B B` if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).
    merge_repeated: Boolean.  Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
        The rows store: [batch, time].

      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `(batch_size x top_paths)` containing
        sequence log-probabilities.
  """
  raw_outputs = gen_ctc_ops.ctc_beam_search_decoder(
      inputs,
      sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=merge_repeated)
  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = raw_outputs

  # One SparseTensor per returned path.
  decoded = []
  for ix, val, shape in zip(decoded_ixs, decoded_vals, decoded_shapes):
    decoded.append(sparse_tensor.SparseTensor(ix, val, shape))
  return decoded, log_probabilities
399
400
@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs,
                               sequence_length,
                               beam_width=100,
                               top_paths=1):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
        The rows store: `[batch, time]`.

      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
        The vector stores the decoded classes for beam `j`.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `[batch_size, top_paths]` containing
        sequence log-probabilities.
  """
  # merge_repeated is deliberately pinned to False and not exposed: merging
  # repeated classes inside the beam is an invalid optimization that returns
  # low probability paths.
  decoder_kwargs = dict(
      sequence_length=sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=False)
  return ctc_beam_search_decoder(inputs, **decoder_kwargs)
448
449
# Register both CTC decoder ops as having no gradient.
for _non_differentiable_op in ("CTCGreedyDecoder", "CTCBeamSearchDecoder"):
  ops.NotDifferentiable(_non_differentiable_op)
del _non_differentiable_op
452
453
def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  With `L = max_seq_length`, the state space has `2 * (L + 1)` states:
    * state 0 is a start state,
    * states `1..L` are the label positions,
    * states `L + 1 + i` (i in `0..L`) are blanks; blank `i` is entered from
      label state `i` (or from the start state when i == 0) and leads to label
      state `i + 1`.
  Judging by `start_to_label = [[1, 0]]` ("start state to first label"), the
  matrix is indexed `[destination, source]`.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    float32 tensor of shape [batch_size, states, states] with a state
    transition matrix computed for each sequence of the batch. Allowed
    transitions are 1.0 and disallowed ones 0.0 (callers take the log).
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    # One extra "label" state for the start state at index 0.
    num_label_states = num_labels + 1
    # Every label state has a matching blank state.
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    # Blank i -> label i + 1.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    # Label i (including the start state 0) -> blank i.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank],
                               0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    # Build [batch, to, from] index triples for label i -> label i + 1
    # (i >= 1), first for batch 0, then tiled and offset per batch entry.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack([batch_idx, label_states[2:], label_states[1:-1]],
                              1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    # Multiplying by [1, 0, 0] only adds the batch offset to the first column.
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    # repeats[b, j] is True when label j and label j + 1 of sequence b match;
    # such a direct label -> label transition gets weight 0.
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    # Broadcast the sequence-independent transitions over the batch.
    return array_ops.expand_dims(trans, 0) + label_to_label
508
509
def ctc_state_log_probs(seq_lengths, max_seq_length):
  """Computes CTC alignment initial and final state log probabilities.

  Create the initial/final state values directly as log values to avoid
  having to take a float64 log on tpu (which does not exist).

  "Impossible" states use `log_0`, the finite float32 value of log(1e-307)
  (about -706.9), instead of -inf.

  Args:
    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
    max_seq_length: int, max sequence length possible.

  Returns:
    initial_state_log_probs: [batch_size, num_states]; only state 0 (the start
      state) has log prob 0.
    final_state_log_probs: [batch_size, num_states]; for batch entry b, label
      state `seq_lengths[b]` and blank state
      `max_seq_length + 1 + seq_lengths[b]` have log prob 0.
    Here num_states = 2 * (max_seq_length + 1).
  """

  batch_size = _get_dim(seq_lengths, 0)
  num_label_states = max_seq_length + 1
  # A label state and a blank state for each label position.
  num_duration_states = 2
  num_states = num_duration_states * num_label_states
  log_0 = math_ops.cast(
      math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32)

  # Every sequence starts in state 0.
  initial_state_log_probs = array_ops.one_hot(
      indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
      depth=num_states,
      on_value=0.0,
      off_value=log_0,
      axis=1)

  # [num_label_states, batch_size] mask marking label position seq_lengths[b].
  label_final_state_mask = array_ops.one_hot(
      seq_lengths, depth=num_label_states, axis=0)
  duration_final_state_mask = array_ops.ones(
      [num_duration_states, 1, batch_size])
  # Broadcast to [2, num_label_states, batch_size]: the same position is final
  # in both the label half and the blank half of the state space.
  final_state_mask = duration_final_state_mask * label_final_state_mask
  # Mask 1 -> log prob 0; mask 0 -> log_0.
  final_state_log_probs = (1.0 - final_state_mask) * log_0
  final_state_log_probs = array_ops.reshape(final_state_log_probs,
                                            [num_states, batch_size])

  return initial_state_log_probs, array_ops.transpose(final_state_log_probs)
548
549
def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
  """Projects per-class (ilabel) log probs onto the CTC alignment states.

  Class 0 of `ilabel_log_probs` is used as the blank.

  Args:
    labels: [batch, max_label_len] label ids.
    num_labels: number of classes (depth of the one-hot label encoding).
    ilabel_log_probs: per-class log probs, [frames, batch, num_labels] as
      produced by `ctc_loss_and_grad`.

  Returns:
    [frames, batch, 2 * max_label_len + 2] state log probs, laid out as
    `[start, label_1..label_L, blank_0..blank_L]`. The start state is log(0).
  """
  max_label_len = _get_dim(labels, 1)
  blank_log_probs = array_ops.tile(ilabel_log_probs[:, :, :1],
                                   [1, 1, max_label_len + 1])
  # [1, batch, max_label_len, num_labels] one-hot selector for each label slot.
  label_selector = array_ops.expand_dims(
      array_ops.one_hot(labels, depth=num_labels), axis=0)
  # Pick each label slot's class log prob: [frames, batch, max_label_len].
  label_log_probs = math_ops.reduce_sum(
      array_ops.expand_dims(ilabel_log_probs, axis=2) * label_selector, axis=3)
  label_and_blank = array_ops.concat([label_log_probs, blank_log_probs], axis=2)
  # Prepend the start state, which can never emit.
  return array_ops.pad(
      label_and_blank, [[0, 0], [0, 0], [1, 0]],
      constant_values=math_ops.log(0.0))
564
565
def _state_to_olabel(labels, num_labels, states):
  """Reduces per-state log probs to per-class (olabel) log probs.

  Args:
    labels: [batch, max_label_len] label ids; class 0 is the blank.
    num_labels: number of classes, including the blank.
    states: [frames, batch, num_states] state log probs, laid out as
      `[start, label_1..label_L, blank_0..blank_L]`.

  Returns:
    [frames, batch, num_labels] log probs: column 0 is the logsumexp over all
    blank states; column c > 0 is the logsumexp over label states holding c.
  """
  label_end = _get_dim(labels, 1) + 1
  label_part = states[:, :, 1:label_end]
  blank_part = states[:, :, label_end:]
  # `labels - 1` maps class c to column c - 1; the blank (0) maps to -1, which
  # yields an all-log(0) row so blanks never feed a label column.
  class_selector = array_ops.expand_dims(
      array_ops.one_hot(
          labels - 1,
          depth=(num_labels - 1),
          on_value=0.0,
          off_value=math_ops.log(0.0)),
      axis=0)
  per_class = math_ops.reduce_logsumexp(
      array_ops.expand_dims(label_part, axis=3) + class_selector, axis=2)
  blank_total = math_ops.reduce_logsumexp(blank_part, axis=2, keepdims=True)
  return array_ops.concat([blank_total, per_class], axis=-1)
582
583
584# pylint: disable=redefined-outer-name
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sums state log probs to per-class log probs using unique label indices.

  Scatter-based variant of `_state_to_olabel` that avoids a one-hot over all
  classes.

  Args:
    labels: [batch, max_label_len] label ids; class 0 is the blank.
    num_labels: number of classes, including the blank.
    states: [frames, batch, num_states] state log probs.
    unique: pair `(unique_y, unique_idx)`; presumably as produced by
      `ctc_unique_labels(labels)` -- TODO confirm against that helper.

  Returns:
    [frames, batch, num_labels] log probs; column 0 aggregates the blank
    states, classes that never occur get log(0).
  """

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  # Use _get_dim (like the rest of this file) so dynamic frame/batch dims fall
  # back to runtime values; static `states.shape[i]` may be None, which breaks
  # the reshape/scatter shapes below.
  num_frames = _get_dim(states, 0)
  batch_size = _get_dim(states, 1)
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(batch_state_major,
                                        [batch_size * num_states, num_frames])
  # Flatten (batch, class) into a single row index for the scatter.
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  # Mark which (batch, class) rows received an update; the others are log(0).
  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
  mask = array_ops.scatter_nd(
      indices=indices,
      updates=mask,
      shape=[batch_size * num_labels, num_frames])
  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      mask, scatter,
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))

  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  # Drop class 0: the blank column is computed from the blank states instead.
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
627
628
def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
  """Computes the CTC loss and gradients.

  Most users will want a higher-level CTC loss function (for example
  `ctc_loss_v2`) rather than calling this helper directly.

  This function returns the computed gradient; it does not have a gradient
  of its own defined. The state projections used here treat class 0 as the
  blank label.

  Args:
    logits: tensor of shape [frames, batch_size, num_labels]
    labels: tensor of shape [batch_size, max_label_seq_length]
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    unique: (optional) unique label indices as computed by unique(labels) If
      supplied, enables an implementation that is faster and more memory
      efficient on TPU.

  Returns:
    loss: tensor of shape [batch_size]
    gradient: tensor of shape [frames, batch_size, num_labels]
  """

  num_labels = _get_dim(logits, 2)
  max_label_seq_length = _get_dim(labels, 1)

  ilabel_log_probs = nn_ops.log_softmax(logits)
  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
  state_trans_probs = _ctc_state_trans(labels)
  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
      label_length, max_label_seq_length)
  # Disallowed transitions have probability 0, i.e. log probability -inf.
  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
      state_trans_log_probs=math_ops.log(state_trans_probs),
      initial_state_log_probs=initial_state_log_probs,
      final_state_log_probs=final_state_log_probs,
      observed_log_probs=state_log_probs,
      sequence_length=logit_length)

  # `None` is the "not supplied" sentinel; test it explicitly rather than via
  # truthiness.
  if unique is not None:
    olabel_log_probs = _state_to_olabel_unique(labels, num_labels,
                                               fwd_bwd_log_probs, unique)
  else:
    olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

  # softmax(logits) minus the per-class posterior occupancy.
  grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)

  # Applies the sequence mask for the gradient. It is enough to apply the mask
  # only for ilabel_log_probs because olabel_log_probs already consider the
  # mask. However, it is just safe and clean to apply it for the gradient.
  max_logit_length = _get_dim(logits, 0)
  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
                                       dtypes.float32)
  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
  grad *= logit_mask

  loss = -log_likelihood
  return loss, grad
688
689
def _ctc_loss_grad(op, grad_loss, _):
  """Gradient for a CTC loss op whose second output is the logits gradient.

  Scales `op.outputs[1]` by `grad_loss` reshaped to `[1, -1, 1]` (this assumes
  the gradient is laid out [frames, batch, classes] -- TODO confirm), and
  returns `None` for every remaining op input.
  """
  scaled_grad = array_ops.reshape(grad_loss, [1, -1, 1]) * op.outputs[1]
  return [scaled_grad] + [None] * (len(op.inputs) - 1)
695
696
def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
                          blank_index):
  """Computes CTC loss with the non-cuDNN kernel for an arbitrary blank index.

  That kernel treats the last class as blank, so the blank logit is moved to
  the end. Label ids above `blank_index` are shifted down by one so they stay
  aligned with the reordered classes.
  """
  classes_below = logits[:, :, :blank_index]
  classes_above = logits[:, :, blank_index + 1:]
  blank_column = logits[:, :, blank_index:blank_index + 1]
  reordered_logits = array_ops.concat(
      [classes_below, classes_above, blank_column], axis=2)
  shifted_values = array_ops.where(labels.values < blank_index, labels.values,
                                   labels.values - 1)
  shifted_labels = sparse_tensor.SparseTensor(labels.indices, shifted_values,
                                              labels.dense_shape)
  return _ctc_loss_impl(
      labels=shifted_labels,
      inputs=reordered_logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=False)
713
714
def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
                       blank_index):
  """Computes CTC loss with the cuDNN kernel for an arbitrary blank index.

  The cuDNN path (`use_cudnn=True`) requires the blank at class 0, so the blank
  logit is moved to the front. Label ids below `blank_index` are shifted up by
  one so they stay aligned with the reordered classes.
  """
  classes_below = logits[:, :, :blank_index]
  classes_above = logits[:, :, blank_index + 1:]
  blank_column = logits[:, :, blank_index:blank_index + 1]
  reordered_logits = array_ops.concat(
      [blank_column, classes_below, classes_above], axis=2)
  shifted_values = array_ops.where(labels.values < blank_index,
                                   labels.values + 1, labels.values)
  shifted_labels = sparse_tensor.SparseTensor(labels.indices, shifted_values,
                                              labels.dense_shape)
  return _ctc_loss_impl(
      labels=shifted_labels,
      inputs=reordered_logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=True)
731
732
def _ctc_loss_shape(op):
  """Returns `[shape of op.inputs[2], shape of op.inputs[0]]`.

  Presumably these are the loss (per-batch) and gradient (logits-shaped)
  output shapes of a CTC loss op -- confirm against the op definition.
  """
  op_inputs = op.inputs
  first_shape = op_inputs[2].get_shape()
  second_shape = op_inputs[0].get_shape()
  return [first_shape, second_shape]
735
736
737# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.
  - With SparseTensor labels, `blank_index` is required and `label_length`,
    `unique` and `name` are not used.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor.
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels).  If supplied, enable a faster, memory efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created. Required when
      `labels` is a SparseTensor.
    name: A name for this `Op`. Defaults to "ctc_loss_dense". Only used for
      dense labels.

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  Raises:
    ValueError: If `labels` is a SparseTensor and `blank_index` is None.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    # Resolve a negative blank_index relative to the number of classes.
    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    # The classic ctc_loss uses num_classes - 1 as the blank. For any other
    # blank_index, move that logits column to the end and shift label ids
    # at or above blank_index down by one so they still address the same class.
    if blank_index != _get_dim(logits, 2) - 1:
      logits = array_ops.concat([
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
          logits[:, :, blank_index:blank_index + 1],
      ],
                                axis=2)
      labels = sparse_tensor.SparseTensor(
          labels.indices,
          array_ops.where(labels.values < blank_index, labels.values,
                          labels.values - 1), labels.dense_shape)

    return ctc_loss(
        labels=labels,
        inputs=logits,
        sequence_length=logit_length,
        time_major=logits_time_major)

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)
831
832
@tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.
  - With SparseTensor labels, `blank_index` is required and `label_length`,
    `unique` and `name` are not used.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor.
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels).  If supplied, enable a faster, memory efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created. Required when
      `labels` is a SparseTensor.
    name: A name for this `Op`. Defaults to "ctc_loss_dense". Only used for
      dense labels.

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  Raises:
    ValueError: If `labels` is a SparseTensor and `blank_index` is None.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    # Resolve a negative blank_index relative to the number of classes.
    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    params = {
        "labels": labels,
        "logits": logits,
        "logit_length": logit_length,
        "logits_time_major": logits_time_major,
        "blank_index": blank_index
    }

    if context.executing_eagerly():
      device_type = _get_context_device_type()
      can_use_gpu = (
          # Either user specified GPU or unspecified but GPU is available.
          (device_type == _GPU_DEVICE_NAME or
           (device_type is None and context.num_gpus() > 0)))
      # Under eager context, check the device placement and prefer the
      # _ctc_loss_op_cudnn implementation when a GPU can be used; otherwise
      # fall back to _ctc_loss_op_standard.
      if can_use_gpu:
        res = _ctc_loss_op_cudnn(**params)
      else:
        res = _ctc_loss_op_standard(**params)
    else:
      # In graph mode, build a CPU-targeted and a GPU-targeted function under
      # one shared, uniquely generated api_name. Only the standard one is
      # called here; the cuDNN one is registered via function_eager.register.
      # NOTE(review): presumably this lets the runtime swap in the GPU variant
      # when the op is placed on GPU; confirm against _generate_defun_backend.
      api_name = "ctc_loss_" + str(uuid.uuid4())
      ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
                                                     _ctc_loss_op_standard)
      ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                                  _ctc_loss_op_cudnn)
      res = ctc_loss_op_standard(**params)
      function_eager.register(ctc_loss_op_cudnn, **params)
    return res

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)
937
938
def ctc_loss_dense(labels,
                   logits,
                   label_length,
                   logit_length,
                   logits_time_major=True,
                   unique=None,
                   blank_index=0,
                   name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006),
  using the batched forward backward algorithm described in (Sim et al., 2017).

  Notes:
    Significant differences from tf.compat.v1.nn.ctc_loss:
      Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only):
        For batched operations, GPU and TPU are significantly faster than using
        ctc_loss on CPU.
        This implementation runs on CPU, but significantly slower than ctc_loss.
      Blank label is 0 rather than num_classes - 1, unless overridden by
        blank_index.
      Logits and labels are dense arrays with padding rather than SparseTensor.
      The only mode supported is the same as:
        preprocess_collapse_repeated=False, ctc_merge_repeated=True
        To collapse labels, the caller can preprocess label sequence first.

    The dense implementation supports CPU, GPU and TPU. A fast path is
    provided that significantly improves memory use for large vocabularies if
    the caller preprocesses label sequences to get unique label indices on the
    CPU (e.g. in the data input pipeline) using ctc_unique_labels and passes
    the result in the optional "unique" kwarg. This is especially useful for
    TPU and GPU but also works if used on CPU.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length]
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) A `(unique_labels, unique_indices)` pair as computed by
      ctc_unique_labels(labels). If supplied, enable a faster, memory efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
      Improving the efficiency of forward-backward algorithm using batched
      computation in TensorFlow:
        [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
        ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
  """

  with ops.name_scope(name, "ctc_loss_dense",
                      [logits, labels, label_length, logit_length]):
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    label_length = ops.convert_to_tensor(label_length, name="label_length")
    logit_length = ops.convert_to_tensor(logit_length, name="logit_length")

    # Everything below works on time-major logits.
    if not logits_time_major:
      logits = array_ops.transpose(logits, perm=[1, 0, 2])

    # Move the blank column to index 0 (the default blank position). Classes
    # that were below blank_index move up one column, so their label ids are
    # shifted up by one; ids above blank_index keep their column.
    if blank_index != 0:
      if blank_index < 0:
        blank_index += _get_dim(logits, 2)
      logits = array_ops.concat([
          logits[:, :, blank_index:blank_index + 1],
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
      ],
                                axis=2)
      labels = array_ops.where(labels < blank_index, labels + 1, labels)

    args = [logits, labels, label_length, logit_length]

    if unique:
      unique_y, unique_idx = unique
      if blank_index != 0:
        # Apply the same id shift to the unique labels, then reset to 0 every
        # position past the highest index referenced by unique_idx (the shift
        # would otherwise have turned padding 0s into 1s).
        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
                                   unique_y)
        label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
        max_label_length = _get_dim(unique_y, 1)
        label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
        unique_y = array_ops.where(label_mask, unique_y,
                                   array_ops.zeros_like(unique_y))
      args.extend([unique_y, unique_idx])

    @custom_gradient.custom_gradient
    def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
                         *unique_t):
      """Compute CTC loss."""
      # Re-apply the static shapes of the outer tensors to the function inputs.
      logits_t.set_shape(logits.shape)
      labels_t.set_shape(labels.shape)
      label_length_t.set_shape(label_length.shape)
      logit_length_t.set_shape(logit_length.shape)
      kwargs = dict(
          logits=logits_t,
          labels=labels_t,
          label_length=label_length_t,
          logit_length=logit_length_t)
      if unique_t:
        kwargs["unique"] = unique_t
      result = ctc_loss_and_grad(**kwargs)

      def grad(grad_loss):
        # Only logits get a gradient. result[1] is scaled per batch entry by
        # the incoming loss gradient, reshaped to [1, batch, 1] to broadcast
        # over axes 0 and 2. Every other input gets None.
        grads = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]]
        grads += [None] * (len(args) - len(grads))
        return grads

      return result[0], grad

    return compute_ctc_loss(*args)
1063
1064
@tf_export("nn.collapse_repeated")
@dispatch.add_dispatch_support
def collapse_repeated(labels, seq_length, name=None):
  """Collapses each run of identical consecutive labels to a single label.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
    labels collapsed and padded to max_seq_length, eg:
    `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths,
    cast to the dtype of `seq_length`.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Keep column 0 of every row, plus any position whose label differs from
    # the label immediately to its left.
    leading_col = array_ops.ones_like(labels[:, :1], dtypes.bool)
    differs_from_prev = math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    keep = array_ops.concat([leading_col, differs_from_prev], axis=1)

    # Discard positions beyond each row's original length.
    in_range = array_ops.sequence_mask(seq_length, maxlen=_get_dim(labels, 1))
    keep = math_ops.logical_and(keep, in_range)

    # The number of surviving labels per row is the new length.
    kept_per_row = math_ops.reduce_sum(
        math_ops.cast(keep, dtypes.int32), axis=1)

    # Output slots: the first kept_per_row[b] columns of each output row.
    out_width = math_ops.reduce_max(kept_per_row)
    slots = array_ops.sequence_mask(kept_per_row, maxlen=out_width)

    # Flatten source labels and masks, then scatter kept labels into slots.
    src_flat = array_ops.reshape(labels, [-1])
    keep_flat = array_ops.reshape(keep, [-1])
    slots_flat = array_ops.reshape(slots, [-1])
    slot_positions = math_ops.range(_get_dim(slots_flat, 0))

    scattered = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(slot_positions, slots_flat), axis=1),
        updates=array_ops.boolean_mask(src_flat, keep_flat),
        shape=array_ops.shape(slots_flat))

    # Fold the flat result back into [batch, out_width].
    collapsed = array_ops.reshape(scattered, [_get_dim(labels, 0), out_width])
    return collapsed, math_ops.cast(kept_per_row, seq_length.dtype)
1127
1128
def dense_labels_to_sparse(dense, length):
  """Turns padded dense labels plus per-row lengths into a SparseTensor.

  Args:
    dense: tensor of shape [batch, max_length]
    length: int tensor of shape [batch] The length of each sequence in dense.

  Returns:
    tf.sparse.SparseTensor holding only the first `length[b]` entries of each
    row `b`, with int32 values and dense_shape [batch, max(length)].
  """

  # Number every position of the flattened labels (int64 positions).
  flat_dense = array_ops.reshape(dense, [-1])
  flat_positions = math_ops.range(
      array_ops.shape(flat_dense, out_type=dtypes.int64)[0])

  # Row-wise validity mask, flattened so it lines up with flat_dense.
  valid = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
  valid_flat = array_ops.reshape(valid, [-1])

  kept_positions = array_ops.expand_dims(
      array_ops.boolean_mask(flat_positions, valid_flat), 1)
  kept_values = array_ops.boolean_mask(flat_dense, valid_flat)

  # A 1-D sparse view over the flattened labels...
  flat_sparse = sparse_tensor.SparseTensor(
      indices=kept_positions,
      values=math_ops.cast(kept_values, dtypes.int32),
      dense_shape=array_ops.shape(flat_dense, out_type=dtypes.int64))
  # ...reshaped back to the 2-D layout of `dense`.
  as_2d = sparse_ops.sparse_reshape(flat_sparse, array_ops.shape(dense))

  # Shrink the column bound to the longest actual sequence.
  longest = math_ops.reduce_max(length)
  tight_shape = [
      math_ops.cast(as_2d.dense_shape[0], dtypes.int64),
      math_ops.cast(longest, dtypes.int64)
  ]
  return sparse_tensor.SparseTensor(
      indices=as_2d.indices, values=as_2d.values, dense_shape=tight_shape)
1161
1162
@tf_export("nn.ctc_unique_labels")
@dispatch.add_dispatch_support
def ctc_unique_labels(labels, name=None):
  """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.

  For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
  used to preprocess labels in the input pipeline for better speed/memory use
  when computing the ctc loss on TPU.

  Example:
    ctc_unique_labels([[3, 4, 4, 3]]) ->
      unique labels padded with 0: [[3, 4, 0, 0]]
      indices of original labels in unique: [[0, 1, 1, 0]]

  Args:
    labels: tensor of shape [batch_size, max_label_length] padded with 0.
    name: A name for this `Op`. Defaults to "ctc_unique_labels".

  Returns:
    tuple of
      - unique labels, int64 tensor of shape `[batch_size, max_label_length]`
      - indices into unique labels, int32 tensor of shape
        `[batch_size, max_label_length]`
  """

  with ops.name_scope(name, "ctc_unique_labels", [labels]):
    labels = ops.convert_to_tensor(labels, name="labels")

    def _unique(x):
      # Unique labels of one row, zero-padded back to the row length so that
      # every row produced by map_fn has the same shape.
      u = array_ops.unique(x)
      y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
      y = math_ops.cast(y, dtypes.int64)
      return [y, u.idx]

    return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32])
1197
1198
def _sum_states(idx, states):
  """Log-sum-exp together the label states that share a unique label.

  Args:
    idx: tensor of shape [batch, label_length]. For each sequence, the index of
      every label position into that sequence's unique labels (as produced by
      unique / ctc_unique_labels).
    states: tensor of shape [frames, batch, label_length] of per-state log
      probabilities.

  Returns:
    tensor of shape [frames, batch_size, label_length] whose entry `u` is the
      logsumexp over all states `j` with `idx[b, j] == u`.
  """

  with ops.name_scope("sum_states"):
    idx = ops.convert_to_tensor(idx, name="idx")
    depth = _get_dim(states, 2)
    # [frames, batch, 1, label_length], broadcast against the selector below.
    expanded = array_ops.expand_dims(states, axis=2)
    # selector[b, u, j] is 0.0 where idx[b, j] == u and log(0) == -inf
    # elsewhere, so adding it drops every state not mapped to unique label u.
    neg_inf = math_ops.log(0.0)
    selector = array_ops.one_hot(
        idx, depth=depth, on_value=0.0, off_value=neg_inf, axis=1)
    return math_ops.reduce_logsumexp(expanded + selector, axis=-1)
1224
1225
def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
                          final_state_log_probs, observed_log_probs,
                          sequence_length):
  """Forward-backward algorithm computed in log domain.

  Each step of both passes is renormalized (log-sum over states set to 0); the
  backward pass accumulates those normalizers so the total log likelihood can
  be recovered.

  Args:
    state_trans_log_probs: tensor of shape [states, states] or if different
      transition matrix per batch [batch_size, states, states]
    initial_state_log_probs: tensor of shape [batch_size, states]
    final_state_log_probs: tensor of shape [batch_size, states]
    observed_log_probs: tensor of shape [frames, batch_size, states]
    sequence_length: tensor of shape [batch_size]

  Returns:
    forward backward log probabilities: tensor of shape [frames, batch, states],
      normalized over states per frame and -inf for frames past
      sequence_length.
    log_likelihood: tensor of shape [batch_size]

  Raises:
    ValueError: If state_trans_log_probs has unknown or incorrect rank.
  """

  # Swapping the last two axes of the transition matrix gives the transitions
  # used by the backward (reverse-time) recursion.
  if state_trans_log_probs.shape.ndims == 2:
    perm = [1, 0]
  elif state_trans_log_probs.shape.ndims == 3:
    perm = [0, 2, 1]
  else:
    raise ValueError(
        "state_trans_log_probs rank must be known and == 2 or 3, is: %s" %
        state_trans_log_probs.shape.ndims)

  bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
  batch_size = _get_dim(observed_log_probs, 1)

  # One forward step: propagate through the transitions (logsumexp over the
  # last axis), add this frame's observation log probs, then renormalize.
  def _forward(state_log_prob, obs_log_prob):
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
    state_log_prob += obs_log_prob
    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum
    return state_log_prob

  # inclusive=True: fwd has frames + 1 entries; fwd[0] is the initial state
  # and fwd[t + 1] is the result after consuming frame t.
  fwd = _scan(
      _forward, observed_log_probs, initial_state_log_probs, inclusive=True)

  def _backward(accs, elems):
    """Calculate log probs and cumulative sum masked for sequence length."""
    state_log_prob, cum_log_sum = accs
    obs_log_prob, mask = elems
    # Unlike _forward, the observation is added before the transition.
    state_log_prob += obs_log_prob
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += bwd_state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)

    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum

    # Only frames inside the sequence (mask == 1) contribute normalizers.
    # Frames past the end output final_state_log_probs instead, so the
    # recursion effectively starts from the final state at the last frame.
    # NOTE(review): these multiplications assume finite log probs, since
    # -inf * 0 yields NaN. Confirm callers use a large finite "log 0".
    cum_log_sum += array_ops.squeeze(log_prob_sum) * mask
    batched_mask = array_ops.expand_dims(mask, axis=1)
    out = state_log_prob * batched_mask
    out += final_state_log_probs * (1.0 - batched_mask)
    return out, cum_log_sum

  zero_log_sum = array_ops.zeros([batch_size])
  maxlen = _get_dim(observed_log_probs, 0)
  # Float validity mask, time-major: [frames, batch].
  mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
  mask = array_ops.transpose(mask, perm=[1, 0])

  # Reverse inclusive scan: bwd[frames] is the initial (final-state) value and
  # bwd[t] is the result after consuming frame t going backwards.
  bwd, cum_log_sum = _scan(
      _backward, (observed_log_probs, mask),
      (final_state_log_probs, zero_log_sum),
      reverse=True,
      inclusive=True)

  # Combine forward (through frame t) with backward (from frame t + 1),
  # renormalize over states, and set frames past the sequence end to
  # log(0) = -inf.
  fwd_bwd_log_probs = fwd[1:] + bwd[1:]
  fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
      fwd_bwd_log_probs, axis=2, keepdims=True)
  fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
  fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))

  # Fully propagated backward value in state 0 plus all accumulated
  # normalizers. NOTE(review): presumably state 0 is the single start state;
  # confirm against the callers' initial_state_log_probs.
  log_likelihood = bwd[0, :, 0] + cum_log_sum[0]

  return fwd_bwd_log_probs, log_likelihood
1311
1312
# TODO(tombagby): This is currently faster for the ctc implementation than using
# functional_ops.scan, but could be replaced by that or something similar if
# things change.
def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
  """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While in graph mode (a plain Python loop when
  executing eagerly), tpu friendly, no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values. The
      (possibly nested) sequence of accumulators is the same as `initial` and
      the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True enables scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful if
      final_only is True. For a forward scan the initial value is at index 0;
      for a reverse scan it is at the last index.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

  # The loop carries flat lists of tensors; pack_elems / pack rebuild the
  # caller's nested structures before each call to fn.
  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
  num_elems = array_ops.shape(flat_elems[0])[0]
  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
  accum_dtypes = [x.dtype for x in flat_initial]
  num_accums = len(flat_initial)

  # Types for counter, [outputs], [accumulators] loop arguments.
  if final_only:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
  else:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes

  # TODO(tombagby): Update to tfe.defun
  # Forward: i runs 0 .. num_elems - 1. Reverse: i runs num_elems - 1 .. 0.
  def cond(i, num_elems, *args):
    del args
    return i >= 0 if reverse else i < num_elems

  # The loop *args are [output tensors] + [accumulator tensors] which must
  # be paired. Each output corresponds to one accumulator.
  def body(i, num_elems, *args):
    """Loop body."""
    i.set_shape([])
    if final_only:
      accum = args
    else:
      out, accum = args[:num_accums], args[num_accums:]
    slices = [array_ops.gather(e, i) for e in flat_elems]
    accum = fn(pack(accum), pack_elems(slices))
    flat_accum = nest.flatten(accum)
    if final_only:
      new_out = []
    else:
      # A forward inclusive scan keeps the initial value in slot 0, so the
      # result for element i goes to slot i + 1. A reverse inclusive scan keeps
      # it in the last slot, so slot i is written directly.
      update_i = i + 1 if inclusive and not reverse else i
      new_out = [
          inplace_ops.alias_inplace_update(x, update_i, y)
          for x, y in zip(out, flat_accum)
      ]
    i = i - 1 if reverse else i + 1
    return [i, num_elems] + new_out + flat_accum

  init_i = (
      array_ops.shape(flat_elems[0])[0] -
      1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
  outputs = []
  if not final_only:
    # Preallocate one output buffer per accumulator with a leading dimension of
    # num_elems (+1 when inclusive). In the inclusive case the initial value is
    # added into slot 0 (forward) or slot num_elems (reverse). NOTE(review):
    # the add acts as a write assuming init=True zero-fills the buffer.
    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
    for initial_accum in flat_initial:
      out_shape = array_ops.concat(
          [[num_outputs], array_ops.shape(initial_accum)], 0)
      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
      if inclusive:
        out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0),
                                            initial_accum)
      outputs.append(out)
  loop_in = [init_i, num_elems] + outputs + flat_initial
  # Positions of integer-typed loop variables, passed to While as `hostmem`
  # (presumably to keep loop counters in host memory).
  hostmem = [
      i for i, x in enumerate(loop_in)
      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
  ]

  if context.executing_eagerly():
    loop_results = loop_in
    while cond(*loop_results):
      loop_results = body(*loop_results)
  else:
    # TODO(tombagby): Update to while_v2.
    cond = function.Defun(*loop_dtypes)(cond)
    body = function.Defun(*loop_dtypes)(body)
    loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
  # Skip the counter and num_elems. The next num_accums entries are the output
  # buffers, or the final accumulators when final_only is set.
  out = loop_results[2:num_accums + 2]
  return pack(out)
1427
1428
def _get_dim(tensor, i):
  """Get value of tensor shape[i] preferring static value if available.

  Args:
    tensor: A `Tensor`.
    i: Index of the dimension to query.

  Returns:
    The static size of dimension `i` as a Python int when it is known
    (including a known size of 0), otherwise a scalar tensor holding the
    dynamic size from `array_ops.shape`.
  """
  static_dim = tensor_shape.dimension_value(tensor.shape[i])
  # Compare against None explicitly: a known size of 0 is falsy and would
  # otherwise fall through to a dynamic shape lookup.
  if static_dim is not None:
    return static_dim
  return array_ops.shape(tensor)[i]
1433