• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for constructing a linear-chain CRF.
16
17The following snippet is an example of a CRF layer on top of a batched sequence
18of unary scores (logits for every word). This example also decodes the most
likely sequence at test time. There are two ways to decode: one uses
crf_decode to decode inside TensorFlow, and the other uses viterbi_decode in
NumPy.
22
23log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
24    unary_scores, gold_tags, sequence_lengths)
25
26loss = tf.reduce_mean(-log_likelihood)
27train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
28
29# Decoding in Tensorflow.
30viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
31    unary_scores, transition_params, sequence_lengths)
32
33tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
34    [viterbi_sequence, viterbi_score, train_op])
35
36# Decoding in Numpy.
37tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
38    [unary_scores, sequence_lengths, transition_params, train_op])
39for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
40                                                 tf_sequence_lengths):
41    # Remove padding.
42    tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
43
44    # Compute the highest score and its tag sequence.
45    tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
46        tf_unary_scores_, tf_transition_params)
47"""
48
49from __future__ import absolute_import
50from __future__ import division
51from __future__ import print_function
52
53import numpy as np
54
55from tensorflow.python.framework import constant_op
56from tensorflow.python.framework import dtypes
57from tensorflow.python.framework import tensor_shape
58from tensorflow.python.layers import utils
59from tensorflow.python.ops import array_ops
60from tensorflow.python.ops import gen_array_ops
61from tensorflow.python.ops import math_ops
62from tensorflow.python.ops import rnn
63from tensorflow.python.ops import rnn_cell
64from tensorflow.python.ops import variable_scope as vs
65
# Names exported by this module when it is star-imported.
__all__ = [
    "crf_sequence_score",
    "crf_log_norm",
    "crf_log_likelihood",
    "crf_unary_score",
    "crf_binary_score",
    "CrfForwardRnnCell",
    "viterbi_decode",
    "crf_decode",
    "CrfDecodeForwardRnnCell",
    "CrfDecodeBackwardRnnCell",
    "crf_multitag_sequence_score",
]
72
73
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Scores the given tag sequences under the CRF, without normalization.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        fed to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix holding the tag indices
        whose unnormalized score is computed.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """

  def _score_single_step():
    # With max_seq_len == 1 there is nothing to chain: the score is simply
    # the unary potential of the one tag of every example.
    index_dtype = tag_indices.dtype
    num_examples = array_ops.shape(inputs, out_type=index_dtype)[0]
    row_ids = array_ops.reshape(
        math_ops.range(num_examples, dtype=index_dtype), [-1, 1])
    single_step_inputs = array_ops.squeeze(inputs, [1])
    gather_indices = array_ops.concat([row_ids, tag_indices], axis=1)
    scores = array_ops.gather_nd(single_step_inputs, gather_indices)
    # Examples whose length is <= 0 get a score of zero.
    is_empty = math_ops.less_equal(sequence_lengths, 0)
    return array_ops.where(is_empty, array_ops.zeros_like(scores), scores)

  def _score_multi_step():
    # The score of a tag sequence is the sum of its unary scores and its
    # binary (transition) scores.
    unary = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary = crf_binary_score(tag_indices, sequence_lengths,
                              transition_params)
    return unary + binary

  # Use the static time dimension when it is known, the dynamic one otherwise.
  static_len = tensor_shape.dimension_value(inputs.shape[1])
  max_seq_len = static_len or array_ops.shape(inputs)[1]
  return utils.smart_cond(
      pred=math_ops.equal(max_seq_len, 1),
      true_fn=_score_single_step,
      false_fn=_score_multi_step)
117
118
def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
                                transition_params):
  """Computes the unnormalized score of all tag sequences matching tag_bitmap.

  tag_bitmap enables more than one tag to be considered correct at each time
  step. This is useful when an observed output at a given time step is
  consistent with more than one tag, and thus the log likelihood of that
  observation must take into account all possible consistent tags.

  Using one-hot vectors in tag_bitmap gives results identical to
  crf_sequence_score. In particular, sequences whose length is <= 0 receive a
  score of zero in both the single-step and the multi-step case.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
        representing all active tags at each index for which to calculate the
        unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """

  def _filter_inactive_tags():
    # Potentials of tags that are not active in `tag_bitmap` become -inf so
    # that they contribute nothing to a subsequent logsumexp.
    return array_ops.where(
        tag_bitmap, inputs,
        array_ops.fill(array_ops.shape(inputs), float("-inf")))

  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of all active tags.
  def _single_seq_fn():
    sequence_scores = math_ops.reduce_logsumexp(
        _filter_inactive_tags(), axis=[1, 2], keepdims=False)
    # Mask the scores of the sequences with length <= zero, mirroring the
    # single-step branches of crf_sequence_score and crf_log_norm.
    return array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                           array_ops.zeros_like(sequence_scores),
                           sequence_scores)

  def _multi_seq_fn():
    # Compute the logsumexp of all scores of sequences matching the given tags.
    # crf_log_norm already masks sequences with length <= zero.
    return crf_log_norm(
        inputs=_filter_inactive_tags(),
        sequence_lengths=sequence_lengths,
        transition_params=transition_params)

  return utils.smart_cond(
      pred=math_ops.equal(
          tensor_shape.dimension_value(
              inputs.shape[1]) or array_ops.shape(inputs)[1],
          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)
169
170
def crf_log_norm(inputs, sequence_lengths, transition_params):
  """Computes the log of the CRF partition function for each example.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        fed to the CRF layer.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    log_norm: A [batch_size] vector of normalizers for a CRF.
  """
  # The unary potentials of the first time step act as the initial alphas of
  # the forward algorithm.
  initial_alphas = array_ops.squeeze(
      array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])

  def _zero_for_empty_sequences(values):
    # Sequences with length <= zero get a normalizer of zero.
    is_empty = math_ops.less_equal(sequence_lengths, 0)
    return array_ops.where(is_empty, array_ops.zeros_like(values), values)

  def _normalize_single_step():
    # With max_seq_len == 1 the forward algorithm is not needed: the
    # normalizer is the logsumexp over the initial alphas.
    return _zero_for_empty_sequences(
        math_ops.reduce_logsumexp(initial_alphas, [1]))

  def _normalize_multi_step():
    """Runs the forward algorithm over the remaining time steps."""
    remaining_inputs = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

    cell = CrfForwardRnnCell(transition_params)
    # The first step is consumed by the initial state, so the RNN runs for
    # one step less; the count is clamped so it never drops below zero.
    zero = constant_op.constant(0, dtype=sequence_lengths.dtype)
    steps_to_run = math_ops.maximum(zero, sequence_lengths - 1)
    _, final_alphas = rnn.dynamic_rnn(
        cell=cell,
        inputs=remaining_inputs,
        sequence_length=steps_to_run,
        initial_state=initial_alphas,
        dtype=dtypes.float32)
    return _zero_for_empty_sequences(
        math_ops.reduce_logsumexp(final_alphas, [1]))

  # Use the static time dimension when it is known, the dynamic one otherwise.
  static_len = tensor_shape.dimension_value(inputs.shape[1])
  max_seq_len = static_len or array_ops.shape(inputs)[1]
  return utils.smart_cond(
      pred=math_ops.equal(max_seq_len, 1),
      true_fn=_normalize_single_step,
      false_fn=_normalize_multi_step)
228
229
def crf_log_likelihood(inputs, tag_indices, sequence_lengths,
                       transition_params=None):
  """Returns the per-example log-likelihood of tag sequences under a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        fed to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of the tag indices whose
        log-likelihood is computed.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: An optional [num_tags, num_tags] transition matrix.
  Returns:
    log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
      each example, given the sequence of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. Either the
        one passed in, or the "transitions" variable obtained here.
  """
  num_tags = tensor_shape.dimension_value(inputs.shape[2])

  if transition_params is None:
    # No matrix supplied: fetch the "transitions" variable from the current
    # variable scope.
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  # log p(tags | inputs) = score(tags) - log Z.
  unnormalized = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                    transition_params)
  log_partition = crf_log_norm(inputs, sequence_lengths, transition_params)
  return unnormalized - log_partition, transition_params
263
264
def crf_unary_score(tag_indices, sequence_lengths, inputs):
  """Sums the unary potentials selected by each tag sequence.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
  Returns:
    unary_scores: A [batch_size] vector of unary scores.
  """
  num_examples = array_ops.shape(inputs)[0]
  num_steps = array_ops.shape(inputs)[1]
  tag_count = array_ops.shape(inputs)[2]

  flat_potentials = array_ops.reshape(inputs, [-1])

  # Index of tag 0 of every (example, step) pair inside the flattened
  # potentials: example * max_seq_len * num_tags + step * num_tags.
  example_offsets = array_ops.expand_dims(
      math_ops.range(num_examples) * num_steps * tag_count, 1)
  step_offsets = array_ops.expand_dims(
      math_ops.range(num_steps) * tag_count, 0)
  base_offsets = example_offsets + step_offsets
  # The offsets must have the same integer dtype as tag_indices to be added.
  if tag_indices.dtype == dtypes.int64:
    base_offsets = math_ops.cast(base_offsets, dtypes.int64)
  flat_indices = array_ops.reshape(base_offsets + tag_indices, [-1])

  gathered = array_ops.gather(flat_potentials, flat_indices)
  per_step_scores = array_ops.reshape(gathered, [num_examples, num_steps])

  # Steps beyond the true length of a sequence are zeroed out before summing.
  valid_steps = array_ops.sequence_mask(
      sequence_lengths,
      maxlen=array_ops.shape(tag_indices)[1],
      dtype=dtypes.float32)
  return math_ops.reduce_sum(per_step_scores * valid_steps, 1)
299
300
def crf_binary_score(tag_indices, sequence_lengths, transition_params):
  """Sums the transition potentials along each tag sequence.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.
  Returns:
    binary_scores: A [batch_size] vector of binary scores.
  """
  num_tags = transition_params.get_shape()[0]
  # A sequence of T tags contains T - 1 transitions.
  transition_count = array_ops.shape(tag_indices)[1] - 1

  # Dropping the last tag gives the source of every transition, dropping the
  # first tag gives its destination.
  from_tags = array_ops.slice(tag_indices, [0, 0], [-1, transition_count])
  to_tags = array_ops.slice(tag_indices, [0, 1], [-1, transition_count])

  # Look the transitions up in the row-major flattened transition matrix.
  flat_indices = from_tags * num_tags + to_tags
  flat_params = array_ops.reshape(transition_params, [-1])
  per_transition_scores = array_ops.gather(flat_params, flat_indices)

  # A transition is valid when its destination step lies inside the true
  # sequence, hence the step mask without its first column.
  step_mask = array_ops.sequence_mask(
      sequence_lengths,
      maxlen=array_ops.shape(tag_indices)[1],
      dtype=dtypes.float32)
  transition_mask = array_ops.slice(step_mask, [0, 1], [-1, -1])
  return math_ops.reduce_sum(per_transition_scores * transition_mask, 1)
335
336
class CrfForwardRnnCell(rnn_cell.RNNCell):
  """RNN cell performing one step of the linear-chain CRF forward recursion.

  See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
  """

  def __init__(self, transition_params):
    """Creates the cell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          It is stored with an extra leading axis, i.e. as
          [1, num_tags, num_tags], so that it broadcasts over the batch inside
          the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = tensor_shape.dimension_value(transition_params.shape[0])

  @property
  def state_size(self):
    """Size of the state: one alpha value per tag."""
    return self._num_tags

  @property
  def output_size(self):
    """Size of the output: one alpha value per tag."""
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Advances the alpha values by one time step.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous alpha
          values.
      scope: Unused variable scope of this cell.

    Returns:
      new_alphas, new_alphas: The new [batch_size, num_tags] alpha values,
          returned both as the cell output and as the new state.
    """
    previous_alphas = array_ops.expand_dims(state, 2)

    # Broadcasting adds self._transition_params (leading axis of size 1) to
    # the previous alphas (trailing axis of size 1). In log space this is the
    # product of the previous alphas and the binary potentials.
    scores = previous_alphas + self._transition_params
    marginalized = math_ops.reduce_logsumexp(scores, [1])
    new_alphas = inputs + marginalized

    # The RNN API expects an (output, state) pair. Both carry the alphas; the
    # output is currently unused but would be needed to compute marginal
    # probabilities from the alphas of every time step.
    return new_alphas, new_alphas
390
391
def viterbi_decode(score, transition_params):
  """Decode the highest scoring sequence of tags outside of TensorFlow.

  This should only be used at test time.

  Args:
    score: A [seq_len, num_tags] matrix of unary potentials. Any array-like
        is accepted; seq_len must be at least 1.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.

  Returns:
    viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
    viterbi_score: A float containing the score for the Viterbi sequence.
  """
  score = np.asarray(score)
  transition_params = np.asarray(transition_params)

  # Floating-point scores keep their own dtype, so results for such inputs are
  # unchanged. Any other dtype (e.g. integer scores) is promoted together with
  # the transitions; otherwise fractional transition scores would be silently
  # truncated when written into the trellis.
  if np.issubdtype(score.dtype, np.floating):
    trellis_dtype = score.dtype
  else:
    trellis_dtype = np.result_type(score.dtype, transition_params.dtype)

  trellis = np.zeros_like(score, dtype=trellis_dtype)
  backpointers = np.zeros_like(score, dtype=np.int32)
  trellis[0] = score[0]

  for t in range(1, score.shape[0]):
    # v[i, j] is the best score of a path ending in tag i at step t - 1 that
    # then moves to tag j.
    v = np.expand_dims(trellis[t - 1], 1) + transition_params
    trellis[t] = score[t] + np.max(v, 0)
    backpointers[t] = np.argmax(v, 0)

  # Follow the backpointers from the best final tag back to the first step.
  viterbi = [np.argmax(trellis[-1])]
  for bp in reversed(backpointers[1:]):
    viterbi.append(bp[viterbi[-1]])
  viterbi.reverse()

  viterbi_score = np.max(trellis[-1])
  return viterbi, viterbi_score
422
423
class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """RNN cell performing one forward (max) step of linear-chain CRF decoding.
  """

  def __init__(self, transition_params):
    """Creates the cell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
        potentials. It is stored with an extra leading axis, i.e. as
        [1, num_tags, num_tags], so that it broadcasts over the batch
        inside the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = tensor_shape.dimension_value(transition_params.shape[0])

  @property
  def state_size(self):
    """Size of the state: one score per tag."""
    return self._num_tags

  @property
  def output_size(self):
    """Size of the output: one backpointer per tag."""
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Advances the best-path scores by one time step.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
            score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: A [batch_size, num_tags] matrix of backpointers.
      new_state: A [batch_size, num_tags] matrix of new score values.
    """
    # Shape comments use B = batch_size and O = num_tags.
    previous_scores = array_ops.expand_dims(state, 2)  # [B, O, 1]

    # Broadcasting adds self._transition_params (leading axis of size 1) to
    # the previous scores (trailing axis of size 1):
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    candidates = previous_scores + self._transition_params
    best_previous = math_ops.reduce_max(candidates, [1])  # [B, O]
    new_state = inputs + best_previous  # [B, O]
    # For every current tag, remember which previous tag gave the maximum.
    best_previous_tags = math_ops.argmax(candidates, 1)
    backpointers = math_ops.cast(
        best_previous_tags, dtype=dtypes.int32)  # [B, O]
    return backpointers, new_state
473
474
class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """RNN cell performing one backward step of linear-chain CRF decoding.
  """

  def __init__(self, num_tags):
    """Creates the cell.

    Args:
      num_tags: An integer. The number of tags.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    """Size of the state: a single tag index."""
    return 1

  @property
  def output_size(self):
    """Size of the output: a single tag index."""
    return 1

  def __call__(self, inputs, state, scope=None):
    """Follows the backpointers by one step.

    Args:
      inputs: A [batch_size, num_tags] matrix of
            backpointer of next step (in time order).
      state: A [batch_size, 1] matrix of tag index of next step.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: The new [batch_size, 1] tag indices, returned both
        as the cell output and as the new state.
    """
    # Shape comments use B = batch_size.
    next_tags = array_ops.squeeze(state, axis=[1])  # [B]
    row_ids = math_ops.range(array_ops.shape(inputs)[0])  # [B]
    # Pair every batch row with the tag chosen for the next step ...
    lookup = array_ops.stack([row_ids, next_tags], axis=1)  # [B, 2]
    # ... and read the backpointer stored for that tag.
    selected = gen_array_ops.gather_nd(inputs, lookup)  # [B]
    new_tags = array_ops.expand_dims(selected, axis=-1)  # [B, 1]

    return new_tags, new_tags
517
518
def crf_decode(potentials, transition_params, sequence_length):
  """Finds the highest scoring tag sequences with TensorFlow ops.

  This is the in-graph (tensor) counterpart of viterbi_decode.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
              unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
              binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
                Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """

  def _decode_single_step():
    # With max_seq_len == 1 the best tag is the argmax of the unary potentials
    # and its score is their maximum.
    unary = array_ops.squeeze(potentials, [1])
    best_tags = array_ops.expand_dims(math_ops.argmax(unary, axis=1), 1)
    top_score = math_ops.reduce_max(unary, axis=1)
    return math_ops.cast(best_tags, dtype=dtypes.int32), top_score

  def _decode_multi_step():
    """Runs the forward pass, then follows the backpointers."""
    # Shape comments use B = batch_size, T = max_seq_len and O = num_tags.
    tag_count = tensor_shape.dimension_value(potentials.shape[2])

    # Forward pass: collect the backpointers and the scores of the last step.
    forward_cell = CrfDecodeForwardRnnCell(transition_params)
    first_scores = array_ops.squeeze(
        array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]),
        axis=[1])  # [B, O]
    later_potentials = array_ops.slice(
        potentials, [0, 1, 0], [-1, -1, -1])  # [B, T - 1, O]
    # The first step is consumed by the initial state, so the RNNs run for
    # one step less; the count is clamped so it never drops below zero.
    zero = constant_op.constant(0, dtype=sequence_length.dtype)
    steps_to_run = math_ops.maximum(zero, sequence_length - 1)
    forward_pointers, final_scores = rnn.dynamic_rnn(
        forward_cell,
        inputs=later_potentials,
        sequence_length=steps_to_run,
        initial_state=first_scores,
        time_major=False,
        dtype=dtypes.int32)  # [B, T - 1, O], [B, O]
    # Reverse in time so the backward pass starts from the last real step.
    backward_inputs = gen_array_ops.reverse_sequence(
        forward_pointers, steps_to_run, seq_dim=1)  # [B, T - 1, O]

    # Backward pass: start from the best final tag and follow the pointers.
    backward_cell = CrfDecodeBackwardRnnCell(tag_count)
    last_tags = array_ops.expand_dims(
        math_ops.cast(math_ops.argmax(final_scores, axis=1),
                      dtype=dtypes.int32),
        axis=-1)  # [B, 1]
    earlier_tags, _ = rnn.dynamic_rnn(
        backward_cell,
        inputs=backward_inputs,
        sequence_length=steps_to_run,
        initial_state=last_tags,
        time_major=False,
        dtype=dtypes.int32)  # [B, T - 1, 1]
    earlier_tags = array_ops.squeeze(earlier_tags, axis=[2])  # [B, T - 1]
    reversed_tags = array_ops.concat(
        [last_tags, earlier_tags], axis=1)  # [B, T]
    # The tags were produced last-to-first; put them back in time order.
    decode_tags = gen_array_ops.reverse_sequence(
        reversed_tags, sequence_length, seq_dim=1)  # [B, T]

    best_score = math_ops.reduce_max(final_scores, axis=1)  # [B]
    return decode_tags, best_score

  # Use the static time dimension when it is known, the dynamic one otherwise.
  static_len = tensor_shape.dimension_value(potentials.shape[1])
  max_seq_len = static_len or array_ops.shape(potentials)[1]
  return utils.smart_cond(
      pred=math_ops.equal(max_seq_len, 1),
      true_fn=_decode_single_step,
      false_fn=_decode_multi_step)
597