# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing a linear-chain CRF.

The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
likely sequence at test time. There are two ways to do decoding. One is using
crf_decode to do decoding in TensorFlow, and the other is using viterbi_decode
in NumPy.

log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    unary_scores, gold_tags, sequence_lengths)

loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# Decoding in TensorFlow.
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
    unary_scores, transition_params, sequence_lengths)

tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
    [viterbi_sequence, viterbi_score, train_op])

# Decoding in NumPy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
    [unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
                                                 tf_sequence_lengths):
  # Remove padding.
  tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]

  # Compute the highest score and its tag sequence.
  tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
      tf_unary_scores_, tf_transition_params)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs

__all__ = [
    "crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
    "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
    "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
    "CrfDecodeBackwardRnnCell", "crf_multitag_sequence_score"
]


def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Computes the unnormalized score for a tag sequence.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """
  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of the single tag.
  def _single_seq_fn():
    batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = array_ops.reshape(
        math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = array_ops.gather_nd(
        array_ops.squeeze(inputs, [1]),
        array_ops.concat([example_inds, tag_indices], axis=1))
    sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                                      array_ops.zeros_like(sequence_scores),
                                      sequence_scores)
    return sequence_scores

  def _multi_seq_fn():
    # Compute the scores of the given tag sequence.
    unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                     transition_params)
    sequence_scores = unary_scores + binary_scores
    return sequence_scores

  return utils.smart_cond(
      pred=math_ops.equal(
          tensor_shape.dimension_value(
              inputs.shape[1]) or array_ops.shape(inputs)[1],
          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)


def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
                                transition_params):
  """Computes the unnormalized score of all tag sequences matching tag_bitmap.

  tag_bitmap enables more than one tag to be considered correct at each time
  step. This is useful when an observed output at a given time step is
  consistent with more than one tag, and thus the log likelihood of that
  observation must take into account all possible consistent tags.

  Using one-hot vectors in tag_bitmap gives results identical to
  crf_sequence_score.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
        representing all active tags at each index for which to calculate the
        unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """

  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of all active tags.
  def _single_seq_fn():
    filtered_inputs = array_ops.where(
        tag_bitmap, inputs,
        array_ops.fill(array_ops.shape(inputs), float("-inf")))
    return math_ops.reduce_logsumexp(
        filtered_inputs, axis=[1, 2], keepdims=False)

  def _multi_seq_fn():
    # Compute the logsumexp of all scores of sequences matching the given tags.
    filtered_inputs = array_ops.where(
        tag_bitmap, inputs,
        array_ops.fill(array_ops.shape(inputs), float("-inf")))
    return crf_log_norm(
        inputs=filtered_inputs,
        sequence_lengths=sequence_lengths,
        transition_params=transition_params)

  return utils.smart_cond(
      pred=math_ops.equal(
          tensor_shape.dimension_value(
              inputs.shape[1]) or array_ops.shape(inputs)[1],
          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)


def crf_log_norm(inputs, sequence_lengths, transition_params):
  """Computes the normalization for a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    log_norm: A [batch_size] vector of normalizers for a CRF.
  """
  # Split up the first and rest of the inputs in preparation for the forward
  # algorithm.
  first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
  first_input = array_ops.squeeze(first_input, [1])

  # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
  # the "initial state" (the unary potentials).
  def _single_seq_fn():
    log_norm = math_ops.reduce_logsumexp(first_input, [1])
    # Mask `log_norm` of the sequences with length <= zero.
    log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                               array_ops.zeros_like(log_norm),
                               log_norm)
    return log_norm

  def _multi_seq_fn():
    """Forward computation of alpha values."""
    rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

    # Compute the alpha values in the forward algorithm in order to get the
    # partition function.
    forward_cell = CrfForwardRnnCell(transition_params)
    # Sequence length is not allowed to be less than zero.
    sequence_lengths_less_one = math_ops.maximum(
        constant_op.constant(0, dtype=sequence_lengths.dtype),
        sequence_lengths - 1)
    _, alphas = rnn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths_less_one,
        initial_state=first_input,
        dtype=dtypes.float32)
    log_norm = math_ops.reduce_logsumexp(alphas, [1])
    # Mask `log_norm` of the sequences with length <= zero.
    log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                               array_ops.zeros_like(log_norm),
                               log_norm)
    return log_norm

  return utils.smart_cond(
      pred=math_ops.equal(
          tensor_shape.dimension_value(
              inputs.shape[1]) or array_ops.shape(inputs)[1],
          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)


def crf_log_likelihood(inputs,
                       tag_indices,
                       sequence_lengths,
                       transition_params=None):
  """Computes the log-likelihood of tag sequences in a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the log-likelihood.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix, if available.
  Returns:
    log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
        each example, given the sequence of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. This is either
        provided by the caller or created in this function.
  """
  # Get shape information.
  num_tags = tensor_shape.dimension_value(inputs.shape[2])

  # Get the transition matrix if not provided.
  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                       transition_params)
  log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

  # Normalize the scores to get the log-likelihood per example.
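  # That is, log p(tags | inputs) = sequence_score(inputs, tags) - log Z,
  # where `sequence_scores` is the unnormalized score of the gold tag sequence
  # and `log_norm` is the log partition function computed above by the forward
  # algorithm in `crf_log_norm`.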
  log_likelihood = sequence_scores - log_norm
  return log_likelihood, transition_params


def crf_unary_score(tag_indices, sequence_lengths, inputs):
  """Computes the unary scores of tag sequences.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
  Returns:
    unary_scores: A [batch_size] vector of unary scores.
  """
  batch_size = array_ops.shape(inputs)[0]
  max_seq_len = array_ops.shape(inputs)[1]
  num_tags = array_ops.shape(inputs)[2]

  flattened_inputs = array_ops.reshape(inputs, [-1])

  offsets = array_ops.expand_dims(
      math_ops.range(batch_size) * max_seq_len * num_tags, 1)
  offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
  # Use int32 or int64 based on tag_indices' dtype.
  if tag_indices.dtype == dtypes.int64:
    offsets = math_ops.cast(offsets, dtypes.int64)
  flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])

  unary_scores = array_ops.reshape(
      array_ops.gather(flattened_inputs, flattened_tag_indices),
      [batch_size, max_seq_len])

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)

  unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
  return unary_scores


def crf_binary_score(tag_indices, sequence_lengths, transition_params):
  """Computes the binary scores of tag sequences.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.
  Returns:
    binary_scores: A [batch_size] vector of binary scores.
  """
  # Get shape information.
  num_tags = transition_params.get_shape()[0]
  num_transitions = array_ops.shape(tag_indices)[1] - 1

  # Truncate by one on each side of the sequence to get the start and end
  # indices of each transition.
  start_tag_indices = array_ops.slice(tag_indices, [0, 0],
                                      [-1, num_transitions])
  end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])

  # Encode the indices in a flattened representation.
  flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
  flattened_transition_params = array_ops.reshape(transition_params, [-1])

  # Get the binary scores based on the flattened representation.
  binary_scores = array_ops.gather(flattened_transition_params,
                                   flattened_transition_indices)

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)
  truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
  binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
  return binary_scores


class CrfForwardRnnCell(rnn_cell.RNNCell):
  """Computes the alpha values in a linear-chain CRF.

  See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
  """

  def __init__(self, transition_params):
    """Initialize the CrfForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          This matrix is expanded into a [1, num_tags, num_tags] tensor in
          preparation for the broadcast summation occurring within the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = tensor_shape.dimension_value(transition_params.shape[0])

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous alpha
          values.
      scope: Unused variable scope of this cell.

    Returns:
      new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
          containing the new alpha values.
    """
    state = array_ops.expand_dims(state, 2)

    # This addition op broadcasts self._transition_params along the zeroth
    # dimension and state along the second dimension. This performs the
    # multiplication of previous alpha values and the current binary potentials
    # in log space.
    transition_scores = state + self._transition_params
    new_alphas = inputs + math_ops.reduce_logsumexp(transition_scores, [1])

    # Both the state and the output of this RNN cell contain the alpha values.
    # The output value is currently unused and simply satisfies the RNN API.
    # This could be useful in the future if we need to compute marginal
    # probabilities, which would require the accumulated alpha values at every
    # time step.
    return new_alphas, new_alphas


def viterbi_decode(score, transition_params):
  """Decode the highest scoring sequence of tags outside of TensorFlow.

  This should only be used at test time.

  Args:
    score: A [seq_len, num_tags] matrix of unary potentials.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.

  Returns:
    viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
    viterbi_score: A float containing the score for the Viterbi sequence.
  """
  trellis = np.zeros_like(score)
  backpointers = np.zeros_like(score, dtype=np.int32)
  trellis[0] = score[0]

  for t in range(1, score.shape[0]):
    v = np.expand_dims(trellis[t - 1], 1) + transition_params
    trellis[t] = score[t] + np.max(v, 0)
    backpointers[t] = np.argmax(v, 0)

  viterbi = [np.argmax(trellis[-1])]
  for bp in reversed(backpointers[1:]):
    viterbi.append(bp[viterbi[-1]])
  viterbi.reverse()

  viterbi_score = np.max(trellis[-1])
  return viterbi, viterbi_score


class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """Computes the forward decoding in a linear-chain CRF.
  """

  def __init__(self, transition_params):
    """Initialize the CrfDecodeForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
          potentials. This matrix is expanded into a
          [1, num_tags, num_tags] tensor in preparation for the broadcast
          summation occurring within the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = tensor_shape.dimension_value(transition_params.shape[0])

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
          score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: A [batch_size, num_tags] matrix of backpointers.
      new_state: A [batch_size, num_tags] matrix of new score values.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
    state = array_ops.expand_dims(state, 2)  # [B, O, 1]

    # This addition op broadcasts self._transition_params along the zeroth
    # dimension and state along the second dimension.
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    transition_scores = state + self._transition_params  # [B, O, O]
    new_state = inputs + math_ops.reduce_max(transition_scores, [1])  # [B, O]
    backpointers = math_ops.argmax(transition_scores, 1)
    backpointers = math_ops.cast(backpointers, dtype=dtypes.int32)  # [B, O]
    return backpointers, new_state


class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """Computes backward decoding in a linear-chain CRF.
  """

  def __init__(self, num_tags):
    """Initialize the CrfDecodeBackwardRnnCell.

    Args:
      num_tags: An integer. The number of tags.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of backpointers from the next
          step (in time order).
      state: A [batch_size, 1] matrix of the tag index of the next step.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: A pair of [batch_size, 1] tensors containing the
          new tag indices.
    """
    state = array_ops.squeeze(state, axis=[1])  # [B]
    batch_size = array_ops.shape(inputs)[0]
    b_indices = math_ops.range(batch_size)  # [B]
    indices = array_ops.stack([b_indices, state], axis=1)  # [B, 2]
    new_tags = array_ops.expand_dims(
        gen_array_ops.gather_nd(inputs, indices),  # [B]
        axis=-1)  # [B, 1]

    return new_tags, new_tags


def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is the in-graph, tensor-based counterpart of `viterbi_decode`.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
        Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """
  # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
  # and the max activation.
  def _single_seq_fn():
    squeezed_potentials = array_ops.squeeze(potentials, [1])
    decode_tags = array_ops.expand_dims(
        math_ops.argmax(squeezed_potentials, axis=1), 1)
    best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
    return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score

  def _multi_seq_fn():
    """Decoding of highest scoring sequence."""

    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
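
    # The decode runs in two passes: `CrfDecodeForwardRnnCell` sweeps forward
    # over the potentials and records, for every step and tag, the best
    # previous tag (backpointers), and `CrfDecodeBackwardRnnCell` then follows
    # those backpointers from the best final tag to recover the highest
    # scoring sequence. `reverse_sequence` is used so the backward pass can be
    # run as an ordinary forward `dynamic_rnn` over each padded sequence.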
    num_tags = tensor_shape.dimension_value(potentials.shape[2])

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
    inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
    # Sequence length is not allowed to be less than zero.
    sequence_length_less_one = math_ops.maximum(
        constant_op.constant(0, dtype=sequence_length.dtype),
        sequence_length - 1)
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, sequence_length_less_one, seq_dim=1)

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
    initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),  # [B]
                                  dtype=dtypes.int32)
    initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
    decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
        crf_bwd_cell,
        inputs=backpointers,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
    decode_tags = array_ops.concat([initial_state, decode_tags],  # [B, T]
                                   axis=1)
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        decode_tags, sequence_length, seq_dim=1)

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score

  return utils.smart_cond(
      pred=math_ops.equal(tensor_shape.dimension_value(potentials.shape[1]) or
                          array_ops.shape(potentials)[1], 1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)