# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest

__all__ = [
    "BeamSearchDecoderOutput",
    "BeamSearchDecoderState",
    "BeamSearchDecoder",
    "FinalBeamSearchDecoderOutput",
    "tile_batch",
]


class BeamSearchDecoderState(
    collections.namedtuple("BeamSearchDecoderState",
                           ("cell_state", "log_probs", "finished", "lengths",
                            "accumulated_attention_probs"))):
  pass


class BeamSearchDecoderOutput(
    collections.namedtuple("BeamSearchDecoderOutput",
                           ("scores", "predicted_ids", "parent_ids"))):
  pass


class FinalBeamSearchDecoderOutput(
    collections.namedtuple("FinalBeamDecoderOutput",
                           ["predicted_ids", "beam_search_decoder_output"])):
  """Final outputs returned by the beam search after all decoding is finished.

  Args:
    predicted_ids: The final prediction. A tensor of shape
      `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
      `output_time_major` is True). Beams are ordered from best to worst.
    beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that
      describes the state of the beam search.
  """
  pass
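

# Illustrative helper (a sketch, not part of the public API): because beams in
# `predicted_ids` are ordered from best to worst along the last axis, the top
# hypothesis for each batch entry can be recovered with a single slice. The
# helper name is hypothetical.
def _example_best_hypothesis(final_outputs):
  """Returns the best beam, shaped `[batch_size, T]`, from a batch-major
  `FinalBeamSearchDecoderOutput`."""
  return final_outputs.predicted_ids[:, :, 0]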
76 """ 77 pass 78 79 80def _tile_batch(t, multiplier): 81 """Core single-tensor implementation of tile_batch.""" 82 t = ops.convert_to_tensor(t, name="t") 83 shape_t = array_ops.shape(t) 84 if t.shape.ndims is None or t.shape.ndims < 1: 85 raise ValueError("t must have statically known rank") 86 tiling = [1] * (t.shape.ndims + 1) 87 tiling[1] = multiplier 88 tiled_static_batch_size = ( 89 t.shape.dims[0].value * multiplier 90 if t.shape.dims[0].value is not None else None) 91 tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) 92 tiled = array_ops.reshape(tiled, 93 array_ops.concat( 94 ([shape_t[0] * multiplier], shape_t[1:]), 0)) 95 tiled.set_shape( 96 tensor_shape.TensorShape([tiled_static_batch_size]).concatenate( 97 t.shape[1:])) 98 return tiled 99 100 101def tile_batch(t, multiplier, name=None): 102 """Tile the batch dimension of a (possibly nested structure of) tensor(s) t. 103 104 For each tensor t in a (possibly nested structure) of tensors, 105 this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of 106 minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape 107 `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries 108 `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated 109 `multiplier` times. 110 111 Args: 112 t: `Tensor` shaped `[batch_size, ...]`. 113 multiplier: Python int. 114 name: Name scope for any created operations. 115 116 Returns: 117 A (possibly nested structure of) `Tensor` shaped 118 `[batch_size * multiplier, ...]`. 119 120 Raises: 121 ValueError: if tensor(s) `t` do not have a statically known rank or 122 the rank is < 1. 123 """ 124 flat_t = nest.flatten(t) 125 with ops.name_scope(name, "tile_batch", flat_t + [multiplier]): 126 return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) 127 128 129def gather_tree_from_array(t, parent_ids, sequence_length): 130 """Calculates the full beams for `TensorArray`s. 131 132 Args: 133 t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of 134 shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` 135 where `s` is the depth shape. 136 parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`. 137 sequence_length: The sequence length of shape `[batch_size, beam_width]`. 138 139 Returns: 140 A `Tensor` which is a stacked `TensorArray` of the same size and type as 141 `t` and where beams are sorted in each `Tensor` according to `parent_ids`. 142 """ 143 max_time = parent_ids.shape.dims[0].value or array_ops.shape(parent_ids)[0] 144 batch_size = parent_ids.shape.dims[1].value or array_ops.shape(parent_ids)[1] 145 beam_width = parent_ids.shape.dims[2].value or array_ops.shape(parent_ids)[2] 146 147 # Generate beam ids that will be reordered by gather_tree. 148 beam_ids = array_ops.expand_dims( 149 array_ops.expand_dims(math_ops.range(beam_width), 0), 0) 150 beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) 151 152 max_sequence_lengths = math_ops.cast( 153 math_ops.reduce_max(sequence_length, axis=1), dtypes.int32) 154 sorted_beam_ids = beam_search_ops.gather_tree( 155 step_ids=beam_ids, 156 parent_ids=parent_ids, 157 max_sequence_lengths=max_sequence_lengths, 158 end_token=beam_width + 1) 159 160 # For out of range steps, simply copy the same beam. 


def gather_tree_from_array(t, parent_ids, sequence_length):
  """Calculates the full beams for `TensorArray`s.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
  max_time = parent_ids.shape.dims[0].value or array_ops.shape(parent_ids)[0]
  batch_size = parent_ids.shape.dims[1].value or array_ops.shape(parent_ids)[1]
  beam_width = parent_ids.shape.dims[2].value or array_ops.shape(parent_ids)[2]

  # Generate beam ids that will be reordered by gather_tree.
  beam_ids = array_ops.expand_dims(
      array_ops.expand_dims(math_ops.range(beam_width), 0), 0)
  beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])

  max_sequence_lengths = math_ops.cast(
      math_ops.reduce_max(sequence_length, axis=1), dtypes.int32)
  sorted_beam_ids = beam_search_ops.gather_tree(
      step_ids=beam_ids,
      parent_ids=parent_ids,
      max_sequence_lengths=max_sequence_lengths,
      end_token=beam_width + 1)

  # For out of range steps, simply copy the same beam.
  in_bound_steps = array_ops.transpose(
      array_ops.sequence_mask(sequence_length, maxlen=max_time),
      perm=[2, 0, 1])
  sorted_beam_ids = array_ops.where(
      in_bound_steps, x=sorted_beam_ids, y=beam_ids)

  # Generate indices for gather_nd.
  time_ind = array_ops.tile(array_ops.reshape(
      math_ops.range(max_time), [-1, 1, 1]), [1, batch_size, beam_width])
  batch_ind = array_ops.tile(array_ops.reshape(
      math_ops.range(batch_size), [-1, 1, 1]), [1, max_time, beam_width])
  batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2])
  indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1)

  # Gather from a tensor with collapsed additional dimensions.
  gather_from = t
  final_shape = array_ops.shape(gather_from)
  gather_from = array_ops.reshape(
      gather_from, [max_time, batch_size, beam_width, -1])
  ordered = array_ops.gather_nd(gather_from, indices)
  ordered = array_ops.reshape(ordered, final_shape)

  return ordered


def _check_ndims(t):
  if t.shape.ndims is None:
    raise ValueError(
        "Expected tensor (%s) to have known rank, but ndims == None." % t)


def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
  """Logs a warning and returns False if dimensions are known statically and
  cannot be reshaped to [batch_size, beam_size, -1].
  """
  reshaped_shape = tensor_shape.TensorShape([batch_size, beam_width, None])
  if (batch_size is not None and shape.dims[0].value is not None
      and (shape[0] != batch_size * beam_width
           or (shape.ndims >= 2 and shape.dims[1].value is not None
               and (shape[0] != batch_size or shape[1] != beam_width)))):
    tf_logging.warn("TensorArray reordering expects elements to be "
                    "reshapable to %s which is incompatible with the "
                    "current shape %s. Consider setting "
                    "reorder_tensor_arrays to False to disable TensorArray "
                    "reordering during the beam search."
                    % (reshaped_shape, shape))
    return False
  return True


def _check_batch_beam(t, batch_size, beam_width):
  """Returns an Assert operation checking that the elements of the stacked
  TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
  the TensorArray elements have a known rank of at least 1.
  """
  error_message = ("TensorArray reordering expects elements to be "
                   "reshapable to [batch_size, beam_size, -1] which is "
                   "incompatible with the dynamic shape of %s elements. "
                   "Consider setting reorder_tensor_arrays to False to disable "
                   "TensorArray reordering during the beam search."
                   % (t if context.executing_eagerly() else t.name))
  rank = t.shape.ndims
  shape = array_ops.shape(t)
  if rank == 2:
    condition = math_ops.equal(shape[1], batch_size * beam_width)
  else:
    condition = math_ops.logical_or(
        math_ops.equal(shape[1], batch_size * beam_width),
        math_ops.logical_and(
            math_ops.equal(shape[1], batch_size),
            math_ops.equal(shape[2], beam_width)))
  return control_flow_ops.Assert(condition, [error_message])
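

# Illustrative sketch of the backtracking that `gather_tree` performs: given
# the per-step ids and the beam each step extended (`parent_ids`), the full
# hypotheses are recovered by walking backwards from the last step. This NumPy
# toy handles a single batch entry and ignores sequence lengths and end
# tokens, which the real op also accounts for.
def _example_gather_tree_numpy(step_ids, parent_ids):
  """step_ids, parent_ids: int arrays shaped `[max_time, beam_width]`."""
  max_time, beam_width = step_ids.shape
  beams = np.empty_like(step_ids)
  parents = np.arange(beam_width)
  for time in reversed(range(max_time)):
    beams[time] = step_ids[time, parents]
    parents = parent_ids[time, parents]
  return beams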
241 """ 242 243 def __init__(self, 244 cell, 245 beam_width, 246 output_layer=None, 247 length_penalty_weight=0.0, 248 coverage_penalty_weight=0.0, 249 reorder_tensor_arrays=True, 250 **kwargs): 251 """Initialize the BeamSearchDecoderMixin. 252 253 Args: 254 cell: An `RNNCell` instance. 255 beam_width: Python integer, the number of beams. 256 output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., 257 `tf.keras.layers.Dense`. Optional layer to apply to the RNN output 258 prior to storing the result or sampling. 259 length_penalty_weight: Float weight to penalize length. Disabled with 0.0. 260 coverage_penalty_weight: Float weight to penalize the coverage of source 261 sentence. Disabled with 0.0. 262 reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell 263 state will be reordered according to the beam search path. If the 264 `TensorArray` can be reordered, the stacked form will be returned. 265 Otherwise, the `TensorArray` will be returned as is. Set this flag to 266 `False` if the cell state contains `TensorArray`s that are not amenable 267 to reordering. 268 **kwargs: Dict, other keyword arguments for parent class. 269 270 Raises: 271 TypeError: if `cell` is not an instance of `RNNCell`, 272 or `output_layer` is not an instance of `tf.keras.layers.Layer`. 273 """ 274 rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access 275 if (output_layer is not None and 276 not isinstance(output_layer, layers.Layer)): 277 raise TypeError( 278 "output_layer must be a Layer, received: %s" % type(output_layer)) 279 self._cell = cell 280 self._output_layer = output_layer 281 self._reorder_tensor_arrays = reorder_tensor_arrays 282 283 self._start_tokens = None 284 self._end_token = None 285 self._batch_size = None 286 self._beam_width = beam_width 287 self._length_penalty_weight = length_penalty_weight 288 self._coverage_penalty_weight = coverage_penalty_weight 289 super(BeamSearchDecoderMixin, self).__init__(**kwargs) 290 291 @property 292 def batch_size(self): 293 return self._batch_size 294 295 def _rnn_output_size(self): 296 """Get the output shape from the RNN layer.""" 297 size = self._cell.output_size 298 if self._output_layer is None: 299 return size 300 else: 301 # To use layer's compute_output_shape, we need to convert the 302 # RNNCell's output_size entries into shapes with an unknown 303 # batch size. We then pass this through the layer's 304 # compute_output_shape and read off all but the first (batch) 305 # dimensions to get the output size of the rnn with the layer 306 # applied to the top. 307 output_shape_with_unknown_batch = nest.map_structure( 308 lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) 309 layer_output_shape = self._output_layer.compute_output_shape( 310 output_shape_with_unknown_batch) 311 return nest.map_structure(lambda s: s[1:], layer_output_shape) 312 313 @property 314 def tracks_own_finished(self): 315 """The BeamSearchDecoder shuffles its beams and their finished state. 316 317 For this reason, it conflicts with the `dynamic_decode` function's 318 tracking of finished states. Setting this property to true avoids 319 early stopping of decoding due to mismanagement of the finished state 320 in `dynamic_decode`. 321 322 Returns: 323 `True`. 
324 """ 325 return True 326 327 @property 328 def output_size(self): 329 # Return the cell output and the id 330 return BeamSearchDecoderOutput( 331 scores=tensor_shape.TensorShape([self._beam_width]), 332 predicted_ids=tensor_shape.TensorShape([self._beam_width]), 333 parent_ids=tensor_shape.TensorShape([self._beam_width])) 334 335 def finalize(self, outputs, final_state, sequence_lengths): 336 """Finalize and return the predicted_ids. 337 338 Args: 339 outputs: An instance of BeamSearchDecoderOutput. 340 final_state: An instance of BeamSearchDecoderState. Passed through to the 341 output. 342 sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. 343 The sequence lengths determined for each beam during decode. 344 **NOTE** These are ignored; the updated sequence lengths are stored in 345 `final_state.lengths`. 346 347 Returns: 348 outputs: An instance of `FinalBeamSearchDecoderOutput` where the 349 predicted_ids are the result of calling _gather_tree. 350 final_state: The same input instance of `BeamSearchDecoderState`. 351 """ 352 del sequence_lengths 353 # Get max_sequence_length across all beams for each batch. 354 max_sequence_lengths = math_ops.cast( 355 math_ops.reduce_max(final_state.lengths, axis=1), dtypes.int32) 356 predicted_ids = beam_search_ops.gather_tree( 357 outputs.predicted_ids, 358 outputs.parent_ids, 359 max_sequence_lengths=max_sequence_lengths, 360 end_token=self._end_token) 361 if self._reorder_tensor_arrays: 362 final_state = final_state._replace(cell_state=nest.map_structure( 363 lambda t: self._maybe_sort_array_beams( 364 t, outputs.parent_ids, final_state.lengths), 365 final_state.cell_state)) 366 outputs = FinalBeamSearchDecoderOutput( 367 beam_search_decoder_output=outputs, predicted_ids=predicted_ids) 368 return outputs, final_state 369 370 def _merge_batch_beams(self, t, s=None): 371 """Merges the tensor from a batch of beams into a batch by beams. 372 373 More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We 374 reshape this into [batch_size*beam_width, s] 375 376 Args: 377 t: Tensor of dimension [batch_size, beam_width, s] 378 s: (Possibly known) depth shape. 379 380 Returns: 381 A reshaped version of t with dimension [batch_size * beam_width, s]. 382 """ 383 if isinstance(s, ops.Tensor): 384 s = tensor_shape.as_shape(tensor_util.constant_value(s)) 385 else: 386 s = tensor_shape.TensorShape(s) 387 t_shape = array_ops.shape(t) 388 static_batch_size = tensor_util.constant_value(self._batch_size) 389 batch_size_beam_width = ( 390 None 391 if static_batch_size is None else static_batch_size * self._beam_width) 392 reshaped_t = array_ops.reshape( 393 t, 394 array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]), 395 0)) 396 reshaped_t.set_shape( 397 (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s))) 398 return reshaped_t 399 400 def _split_batch_beams(self, t, s=None): 401 """Splits the tensor from a batch by beams into a batch of beams. 402 403 More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We 404 reshape this into [batch_size, beam_width, s] 405 406 Args: 407 t: Tensor of dimension [batch_size*beam_width, s]. 408 s: (Possibly known) depth shape. 409 410 Returns: 411 A reshaped version of t with dimension [batch_size, beam_width, s]. 412 413 Raises: 414 ValueError: If, after reshaping, the new tensor is not shaped 415 `[batch_size, beam_width, s]` (assuming batch_size and beam_width 416 are known statically). 
417 """ 418 if isinstance(s, ops.Tensor): 419 s = tensor_shape.TensorShape(tensor_util.constant_value(s)) 420 else: 421 s = tensor_shape.TensorShape(s) 422 t_shape = array_ops.shape(t) 423 reshaped_t = array_ops.reshape( 424 t, 425 array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]), 426 0)) 427 static_batch_size = tensor_util.constant_value(self._batch_size) 428 expected_reshaped_shape = tensor_shape.TensorShape( 429 [static_batch_size, self._beam_width]).concatenate(s) 430 if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape): 431 raise ValueError("Unexpected behavior when reshaping between beam width " 432 "and batch size. The reshaped tensor has shape: %s. " 433 "We expected it to have shape " 434 "(batch_size, beam_width, depth) == %s. Perhaps you " 435 "forgot to create a zero_state with " 436 "batch_size=encoder_batch_size * beam_width?" % 437 (reshaped_t.shape, expected_reshaped_shape)) 438 reshaped_t.set_shape(expected_reshaped_shape) 439 return reshaped_t 440 441 def _maybe_split_batch_beams(self, t, s): 442 """Maybe splits the tensor from a batch by beams into a batch of beams. 443 444 We do this so that we can use nest and not run into problems with shapes. 445 446 Args: 447 t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`. 448 s: `Tensor`, Python int, or `TensorShape`. 449 450 Returns: 451 If `t` is a matrix or higher order tensor, then the return value is 452 `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is 453 returned unchanged. 454 455 Raises: 456 ValueError: If the rank of `t` is not statically known. 457 """ 458 if isinstance(t, tensor_array_ops.TensorArray): 459 return t 460 _check_ndims(t) 461 if t.shape.ndims >= 1: 462 return self._split_batch_beams(t, s) 463 else: 464 return t 465 466 def _maybe_merge_batch_beams(self, t, s): 467 """Splits the tensor from a batch by beams into a batch of beams. 468 469 More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`, 470 then we reshape it to `[batch_size, beam_width] + s`. 471 472 Args: 473 t: `Tensor` of dimension `[batch_size * beam_width] + s`. 474 s: `Tensor`, Python int, or `TensorShape`. 475 476 Returns: 477 A reshaped version of t with shape `[batch_size, beam_width] + s`. 478 479 Raises: 480 ValueError: If the rank of `t` is not statically known. 481 """ 482 if isinstance(t, tensor_array_ops.TensorArray): 483 return t 484 _check_ndims(t) 485 if t.shape.ndims >= 2: 486 return self._merge_batch_beams(t, s) 487 else: 488 return t 489 490 def _maybe_sort_array_beams(self, t, parent_ids, sequence_length): 491 """Maybe sorts beams within a `TensorArray`. 492 493 Args: 494 t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape 495 `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where 496 `s` is the depth shape. 497 parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`. 498 sequence_length: The sequence length of shape `[batch_size, beam_width]`. 499 500 Returns: 501 A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if 502 it is not a `TensorArray` or does not meet shape requirements. 503 """ 504 if not isinstance(t, tensor_array_ops.TensorArray): 505 return t 506 # pylint: disable=protected-access 507 # This is a bad hack due to the implementation detail of eager/graph TA. 508 # TODO(b/124374427): Update this to use public property of TensorArray. 

  def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
    """Maybe sorts beams within a `TensorArray`.

    Args:
      t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape
        `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where
        `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if
      it is not a `TensorArray` or does not meet shape requirements.
    """
    if not isinstance(t, tensor_array_ops.TensorArray):
      return t
    # pylint: disable=protected-access
    # This is a bad hack due to the implementation detail of eager/graph TA.
    # TODO(b/124374427): Update this to use public property of TensorArray.
    if context.executing_eagerly():
      element_shape = t._element_shape
    else:
      element_shape = t._element_shape[0]
    if (not t._infer_shape
        or not t._element_shape
        or element_shape.ndims is None
        or element_shape.ndims < 1):
      shape = (
          element_shape if t._infer_shape and t._element_shape
          else tensor_shape.TensorShape(None))
      tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
                      "sorting based on the beam search result. For a "
                      "TensorArray to be sorted, its elements shape must be "
                      "defined and have at least a rank of 1, but saw shape: %s"
                      % (t.handle.name, shape))
      return t
    # pylint: enable=protected-access
    if not _check_static_batch_beam_maybe(
        element_shape, tensor_util.constant_value(self._batch_size),
        self._beam_width):
      return t
    t = t.stack()
    with ops.control_dependencies(
        [_check_batch_beam(t, self._batch_size, self._beam_width)]):
      return gather_tree_from_array(t, parent_ids, sequence_length)

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    batch_size = self._batch_size
    beam_width = self._beam_width
    end_token = self._end_token
    length_penalty_weight = self._length_penalty_weight
    coverage_penalty_weight = self._coverage_penalty_weight

    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
      cell_state = state.cell_state
      inputs = nest.map_structure(
          lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
      cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
                                      self._cell.state_size)
      cell_outputs, next_cell_state = self._cell(inputs, cell_state)
      cell_outputs = nest.map_structure(
          lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
      next_cell_state = nest.map_structure(
          self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)

      beam_search_output, beam_search_state = _beam_search_step(
          time=time,
          logits=cell_outputs,
          next_cell_state=next_cell_state,
          beam_state=state,
          batch_size=batch_size,
          beam_width=beam_width,
          end_token=end_token,
          length_penalty_weight=length_penalty_weight,
          coverage_penalty_weight=coverage_penalty_weight)

      finished = beam_search_state.finished
      sample_ids = beam_search_output.predicted_ids
      next_inputs = control_flow_ops.cond(
          math_ops.reduce_all(finished), lambda: self._start_inputs,
          lambda: self._embedding_fn(sample_ids))

    return (beam_search_output, beam_search_state, next_inputs, finished)
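

# The `_merge_batch_beams`/`_split_batch_beams` pair above is plain reshaping;
# a minimal NumPy sketch of the round-trip invariant (illustrative only):
def _example_merge_split_numpy(t, batch_size, beam_width):
  """Shows that merging `[batch_size, beam_width, ...]` into
  `[batch_size * beam_width, ...]` and splitting back is lossless."""
  merged = t.reshape((batch_size * beam_width,) + t.shape[2:])
  split = merged.reshape((batch_size, beam_width) + merged.shape[1:])
  assert np.array_equal(split, t)
  return merged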


class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
  # Note that the inheritance hierarchy is important here. The mixin has to be
  # the first parent class since we rely on super().__init__(); the mixin,
  # whose own base is `object`, then properly invokes the __init__ method of
  # the other parent class.
  """BeamSearch sampling decoder.

  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
  `AttentionWrapper`, then you must ensure that:

  - The encoder output has been tiled to `beam_width` via
    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
  - The `batch_size` argument passed to the `zero_state` method of this
    wrapper is equal to `true_batch_size * beam_width`.
  - The initial state created with `zero_state` above contains a
    `cell_state` value containing the properly tiled final state from the
    encoder.

  An example:

  ```
  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs, multiplier=beam_width)
  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
      encoder_final_state, multiplier=beam_width)
  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
      sequence_length, multiplier=beam_width)
  attention_mechanism = MyFavoriteAttentionMechanism(
      num_units=attention_depth,
      memory=tiled_encoder_outputs,
      memory_sequence_length=tiled_sequence_length)
  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
  decoder_initial_state = attention_cell.zero_state(
      dtype, batch_size=true_batch_size * beam_width)
  decoder_initial_state = decoder_initial_state.clone(
      cell_state=tiled_encoder_final_state)
  ```

  Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
  the decoder to cover all inputs.
  """

  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True):
    """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width: Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
        cell state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
670 """ 671 super(BeamSearchDecoder, self).__init__( 672 cell, 673 beam_width, 674 output_layer=output_layer, 675 length_penalty_weight=length_penalty_weight, 676 coverage_penalty_weight=coverage_penalty_weight, 677 reorder_tensor_arrays=reorder_tensor_arrays) 678 679 if callable(embedding): 680 self._embedding_fn = embedding 681 else: 682 self._embedding_fn = ( 683 lambda ids: embedding_ops.embedding_lookup(embedding, ids)) 684 685 self._start_tokens = ops.convert_to_tensor( 686 start_tokens, dtype=dtypes.int32, name="start_tokens") 687 if self._start_tokens.get_shape().ndims != 1: 688 raise ValueError("start_tokens must be a vector") 689 self._end_token = ops.convert_to_tensor( 690 end_token, dtype=dtypes.int32, name="end_token") 691 if self._end_token.get_shape().ndims != 0: 692 raise ValueError("end_token must be a scalar") 693 694 self._batch_size = array_ops.size(start_tokens) 695 self._initial_cell_state = nest.map_structure( 696 self._maybe_split_batch_beams, initial_state, self._cell.state_size) 697 self._start_tokens = array_ops.tile( 698 array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) 699 self._start_inputs = self._embedding_fn(self._start_tokens) 700 701 self._finished = array_ops.one_hot( 702 array_ops.zeros([self._batch_size], dtype=dtypes.int32), 703 depth=self._beam_width, 704 on_value=False, 705 off_value=True, 706 dtype=dtypes.bool) 707 708 def initialize(self, name=None): 709 """Initialize the decoder. 710 711 Args: 712 name: Name scope for any created operations. 713 714 Returns: 715 `(finished, start_inputs, initial_state)`. 716 """ 717 finished, start_inputs = self._finished, self._start_inputs 718 719 dtype = nest.flatten(self._initial_cell_state)[0].dtype 720 log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) 721 array_ops.zeros([self._batch_size], dtype=dtypes.int32), 722 depth=self._beam_width, 723 on_value=ops.convert_to_tensor(0.0, dtype=dtype), 724 off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), 725 dtype=dtype) 726 init_attention_probs = get_attention_probs( 727 self._initial_cell_state, self._coverage_penalty_weight) 728 if init_attention_probs is None: 729 init_attention_probs = () 730 731 initial_state = BeamSearchDecoderState( 732 cell_state=self._initial_cell_state, 733 log_probs=log_probs, 734 finished=finished, 735 lengths=array_ops.zeros( 736 [self._batch_size, self._beam_width], dtype=dtypes.int64), 737 accumulated_attention_probs=init_attention_probs) 738 739 return (finished, start_inputs, initial_state) 740 741 @property 742 def output_dtype(self): 743 # Assume the dtype of the cell is the output_size structure 744 # containing the input_state's first component's dtype. 745 # Return that structure and int32 (the id) 746 dtype = nest.flatten(self._initial_cell_state)[0].dtype 747 return BeamSearchDecoderOutput( 748 scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), 749 predicted_ids=dtypes.int32, 750 parent_ids=dtypes.int32) 751 752 753class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder): 754 # Note that the inheritance hierarchy is important here. The Mixin has to be 755 # the first parent class since we will use super().__init__(), and Mixin which 756 # is a object will properly invoke the __init__ method of other parent class. 757 """BeamSearch sampling decoder. 


class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder):
  # Note that the inheritance hierarchy is important here. The mixin has to be
  # the first parent class since we rely on super().__init__(); the mixin,
  # whose own base is `object`, then properly invokes the __init__ method of
  # the other parent class.
  """BeamSearch sampling decoder.

  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
  `AttentionWrapper`, then you must ensure that:

  - The encoder output has been tiled to `beam_width` via
    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
  - The `batch_size` argument passed to the `zero_state` method of this
    wrapper is equal to `true_batch_size * beam_width`.
  - The initial state created with `zero_state` above contains a
    `cell_state` value containing the properly tiled final state from the
    encoder.

  An example:

  ```
  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs, multiplier=beam_width)
  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
      encoder_final_state, multiplier=beam_width)
  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
      sequence_length, multiplier=beam_width)
  attention_mechanism = MyFavoriteAttentionMechanism(
      num_units=attention_depth,
      memory=tiled_encoder_outputs,
      memory_sequence_length=tiled_sequence_length)
  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
  decoder_initial_state = attention_cell.zero_state(
      dtype, batch_size=true_batch_size * beam_width)
  decoder_initial_state = decoder_initial_state.clone(
      cell_state=tiled_encoder_final_state)
  ```

  Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
  the decoder to cover all inputs.
  """

  def __init__(self,
               cell,
               beam_width,
               embedding_fn=None,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True,
               **kwargs):
    """Initialize the BeamSearchDecoderV2.

    Args:
      cell: An `RNNCell` instance.
      beam_width: Python integer, the number of beams.
      embedding_fn: A callable that takes a vector tensor of `ids`
        (argmax ids).
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
        cell state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.
      **kwargs: Dict, other keyword arguments for initialization.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
    """
    super(BeamSearchDecoderV2, self).__init__(
        cell,
        beam_width,
        output_layer=output_layer,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        reorder_tensor_arrays=reorder_tensor_arrays,
        **kwargs)

    if embedding_fn is None or callable(embedding_fn):
      self._embedding_fn = embedding_fn
    else:
      raise ValueError("embedding_fn is expected to be a callable, got %s" %
                       type(embedding_fn))

  def initialize(self,
                 embedding,
                 start_tokens,
                 end_token,
                 initial_state):
    """Initialize the decoder.

    Args:
      embedding: A tensor from the embedding layer output, which is the
        `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.

    Returns:
      `(finished, start_inputs, initial_state)`.

    Raises:
      ValueError: If `start_tokens` is not a vector, `end_token` is not a
        scalar, or both `embedding` and `embedding_fn` are provided.
    """
    if embedding is not None and self._embedding_fn is not None:
      raise ValueError(
          "embedding and embedding_fn cannot be provided at same time")
    elif embedding is not None:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    self._finished = array_ops.one_hot(
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)

    finished, start_inputs = self._finished, self._start_inputs

    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
        dtype=dtype)
    init_attention_probs = get_attention_probs(
        self._initial_cell_state, self._coverage_penalty_weight)
    if init_attention_probs is None:
      init_attention_probs = ()

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=log_probs,
        finished=finished,
        lengths=array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int64),
        accumulated_attention_probs=init_attention_probs)

    return (finished, start_inputs, initial_state)

  @property
  def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and int32 (the id)
    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    return BeamSearchDecoderOutput(
        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        predicted_ids=dtypes.int32,
        parent_ids=dtypes.int32)

  def call(self, embedding, start_tokens, end_token, initial_state, **kwargs):
    init_kwargs = kwargs
    init_kwargs["start_tokens"] = start_tokens
    init_kwargs["end_token"] = end_token
    init_kwargs["initial_state"] = initial_state
    return decoder.dynamic_decode(self,
                                  output_time_major=self.output_time_major,
                                  impute_finished=self.impute_finished,
                                  maximum_iterations=self.maximum_iterations,
                                  parallel_iterations=self.parallel_iterations,
                                  swap_memory=self.swap_memory,
                                  decoder_init_input=embedding,
                                  decoder_init_kwargs=init_kwargs)
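

# Usage sketch for `BeamSearchDecoderV2` (illustrative, not part of the public
# API). Unlike `BeamSearchDecoder`, construction takes only the cell/beam
# configuration; the embedding, tokens, and initial state are supplied when
# the decoder is called, which in turn runs `dynamic_decode`:
def _example_beam_search_v2_inference(cell, embedding_matrix, start_token_id,
                                      end_token_id, tiled_initial_state,
                                      batch_size, beam_width):
  """Builds a V2 beam search graph; returns `(final_outputs, final_state,
  final_sequence_lengths)`."""
  bsd = BeamSearchDecoderV2(cell=cell, beam_width=beam_width)
  return bsd(
      embedding_matrix,
      start_tokens=array_ops.fill([batch_size], start_token_id),
      end_token=end_token_id,
      initial_state=tiled_initial_state)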


def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight,
                      coverage_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`.
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int. The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    A tuple `(output, next_state)`, where `output` is a
    `BeamSearchDecoderOutput` for this step and `next_state` is a new
    `BeamSearchDecoderState`.
  """
  static_batch_size = tensor_util.constant_value(batch_size)

  # Calculate the current lengths of the predictions.
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished
  not_finished = math_ops.logical_not(previously_finished)

  # Calculate the total log probs for the new hypotheses.
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = nn_ops.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # Calculate the continuation lengths by adding to all continuing beams.
  vocab_size = logits.shape.dims[-1].value or array_ops.shape(logits)[-1]
  lengths_to_add = array_ops.one_hot(
      indices=array_ops.fill([batch_size, beam_width], end_token),
      depth=vocab_size,
      on_value=np.int64(0),
      off_value=np.int64(1),
      dtype=dtypes.int64)
  add_mask = math_ops.cast(not_finished, dtypes.int64)
  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
  new_prediction_lengths = (
      lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))

  # Calculate the accumulated attention probabilities if coverage penalty is
  # enabled.
  accumulated_attention_probs = None
  attention_probs = get_attention_probs(
      next_cell_state, coverage_penalty_weight)
  if attention_probs is not None:
    attention_probs *= array_ops.expand_dims(
        math_ops.cast(not_finished, dtypes.float32), 2)
    accumulated_attention_probs = (
        beam_state.accumulated_attention_probs + attention_probs)

  # Calculate the scores for each beam.
  scores = _get_scores(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      length_penalty_weight=length_penalty_weight,
      coverage_penalty_weight=coverage_penalty_weight,
      finished=previously_finished,
      accumulated_attention_probs=accumulated_attention_probs)

  time = ops.convert_to_tensor(time, name="time")
  # During the first time step we only consider the initial beam.
  scores_flat = array_ops.reshape(scores, [batch_size, -1])

  # Pick the next beams according to the specified successors function.
  next_beam_size = ops.convert_to_tensor(
      beam_width, dtype=dtypes.int32, name="beam_width")
  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)

  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen
  # predictions.
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      batch_size=batch_size,
      range_size=beam_width * vocab_size,
      gather_shape=[-1],
      name="next_beam_probs")
  # Note: just doing the following
  #   math_ops.cast(
  #       word_indices % vocab_size,
  #       dtypes.int32,
  #       name="next_beam_word_ids")
  # would be a lot cleaner but for reasons unclear, that hides the results of
  # the op which prevents capturing it with tfdbg debug ops.
  raw_next_word_ids = math_ops.mod(
      word_indices, vocab_size, name="next_beam_word_ids")
  next_word_ids = math_ops.cast(raw_next_word_ids, dtypes.int32)
  next_beam_ids = math_ops.cast(
      word_indices / vocab_size, dtypes.int32, name="next_beam_parent_ids")

  # Append new ids to current predictions.
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_finished = math_ops.logical_or(
      previously_finished,
      math_ops.equal(next_word_ids, end_token),
      name="next_beam_finished")

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged.
  # 2. Beams that are now finished (EOS predicted) have their length
  #    increased by 1.
  # 3. Beams that are not yet finished have their length increased by 1.
  lengths_to_add = math_ops.cast(
      math_ops.logical_not(previously_finished), dtypes.int64)
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add
  next_accumulated_attention_probs = ()
  if accumulated_attention_probs is not None:
    next_accumulated_attention_probs = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=accumulated_attention_probs,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[batch_size * beam_width, -1],
        name="next_accumulated_attention_probs")

  # Pick out the cell_states according to the next_beam_ids. We use a
  # different gather_shape here because the cell_state tensors, i.e.
  # the tensors that would be gathered from, all have dimension
  # greater than two and we need to preserve those dimensions.
  # pylint: disable=g-long-lambda
  next_cell_state = nest.map_structure(
      lambda gather_from: _maybe_tensor_gather_helper(
          gather_indices=next_beam_ids,
          gather_from=gather_from,
          batch_size=batch_size,
          range_size=beam_width,
          gather_shape=[batch_size * beam_width, -1]),
      next_cell_state)
  # pylint: enable=g-long-lambda

  next_state = BeamSearchDecoderState(
      cell_state=next_cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished,
      accumulated_attention_probs=next_accumulated_attention_probs)

  output = BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state
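

# A NumPy toy of the index arithmetic at the heart of `_beam_search_step`:
# top-k is taken over the flattened `[beam_width * vocab_size]` scores of one
# batch entry, and each flat index encodes both the chosen word
# (`index % vocab_size`) and the beam it extends (`index // vocab_size`).
def _example_top_k_decomposition(scores, beam_width):
  """scores: float array shaped `[beam_width, vocab_size]`."""
  vocab_size = scores.shape[1]
  flat = scores.reshape(-1)
  top_indices = np.argsort(flat)[::-1][:beam_width]
  next_word_ids = top_indices % vocab_size
  next_parent_ids = top_indices // vocab_size
  return next_word_ids, next_parent_ids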


def get_attention_probs(next_cell_state, coverage_penalty_weight):
  """Get attention probabilities from the cell state.

  Args:
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    The attention probabilities with shape `[batch_size, beam_width, max_time]`
    if coverage penalty is enabled. Otherwise, returns None.

  Raises:
    ValueError: If no cell is attentional but coverage penalty is enabled.
  """
  if coverage_penalty_weight == 0.0:
    return None

  # Attention probabilities of each attention layer. Each with shape
  # `[batch_size, beam_width, max_time]`.
  probs_per_attn_layer = []
  if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
    probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
  elif isinstance(next_cell_state, tuple):
    for state in next_cell_state:
      if isinstance(state, attention_wrapper.AttentionWrapperState):
        probs_per_attn_layer.append(attention_probs_from_attn_state(state))

  if not probs_per_attn_layer:
    raise ValueError(
        "coverage_penalty_weight must be 0.0 if no cell is attentional.")

  if len(probs_per_attn_layer) == 1:
    attention_probs = probs_per_attn_layer[0]
  else:
    # Calculate the average attention probabilities from all attention layers.
    attention_probs = [
        array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
    attention_probs = array_ops.concat(attention_probs, -1)
    attention_probs = math_ops.reduce_mean(attention_probs, -1)

  return attention_probs


def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
                coverage_penalty_weight, finished, accumulated_attention_probs):
  """Calculates scores for beam search hypotheses.

  Args:
    log_probs: The log probabilities with shape
      `[batch_size, beam_width, vocab_size]`.
    sequence_lengths: The array of sequence lengths.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.
    finished: A boolean tensor of shape `[batch_size, beam_width]` that
      specifies which elements in the beam are finished already.
    accumulated_attention_probs: Accumulated attention probabilities up to the
      current time step, with shape `[batch_size, beam_width, max_time]` if
      coverage_penalty_weight is not 0.0.

  Returns:
    The scores normalized by the length_penalty and coverage_penalty.

  Raises:
    ValueError: accumulated_attention_probs is None when coverage penalty is
      enabled.
  """
  length_penalty_ = _length_penalty(
      sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
  length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype)
  scores = log_probs / length_penalty_

  coverage_penalty_weight = ops.convert_to_tensor(
      coverage_penalty_weight, name="coverage_penalty_weight")
  if coverage_penalty_weight.shape.ndims != 0:
    raise ValueError("coverage_penalty_weight should be a scalar, "
                     "but saw shape: %s" % coverage_penalty_weight.shape)

  if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
    return scores

  if accumulated_attention_probs is None:
    raise ValueError(
        "accumulated_attention_probs can be None only if coverage penalty is "
        "disabled.")

  # Add source sequence length mask before computing coverage penalty.
  accumulated_attention_probs = array_ops.where(
      math_ops.equal(accumulated_attention_probs, 0.0),
      array_ops.ones_like(accumulated_attention_probs),
      accumulated_attention_probs)

  # coverage penalty =
  #     sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
  coverage_penalty = math_ops.reduce_sum(
      math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
  # Apply coverage penalty to finished predictions.
  coverage_penalty *= math_ops.cast(finished, dtypes.float32)
  weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
  # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1].
  weighted_coverage_penalty = array_ops.expand_dims(
      weighted_coverage_penalty, 2)
  return scores + weighted_coverage_penalty
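

# Worked example of the normalization in `_get_scores` when coverage is
# disabled (a sketch of the GNMT length penalty, not a library API): with
# log probability -6.0, length 10, and penalty_factor 0.6,
# ((5 + 10) / 6) ** 0.6 ~= 1.733, so the score is -6.0 / 1.733 ~= -3.46.
def _example_length_normalized_score(log_prob, length, penalty_factor):
  """Plain-Python version of the score used when coverage is disabled."""
  length_penalty = ((5. + length) / 6.) ** penalty_factor
  return log_prob / length_penalty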
1229 """ 1230 # Attention probabilities over time steps, with shape 1231 # `[batch_size, beam_width, max_time]`. 1232 attention_probs = attention_state.alignments 1233 if isinstance(attention_probs, tuple): 1234 attention_probs = [ 1235 array_ops.expand_dims(prob, -1) for prob in attention_probs] 1236 attention_probs = array_ops.concat(attention_probs, -1) 1237 attention_probs = math_ops.reduce_mean(attention_probs, -1) 1238 return attention_probs 1239 1240 1241def _length_penalty(sequence_lengths, penalty_factor): 1242 """Calculates the length penalty. See https://arxiv.org/abs/1609.08144. 1243 1244 Returns the length penalty tensor: 1245 ``` 1246 [(5+sequence_lengths)/6]**penalty_factor 1247 ``` 1248 where all operations are performed element-wise. 1249 1250 Args: 1251 sequence_lengths: `Tensor`, the sequence lengths of each hypotheses. 1252 penalty_factor: A scalar that weights the length penalty. 1253 1254 Returns: 1255 If the penalty is `0`, returns the scalar `1.0`. Otherwise returns 1256 the length penalty factor, a tensor with the same shape as 1257 `sequence_lengths`. 1258 """ 1259 penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor") 1260 penalty_factor.set_shape(()) # penalty should be a scalar. 1261 static_penalty = tensor_util.constant_value(penalty_factor) 1262 if static_penalty is not None and static_penalty == 0: 1263 return 1.0 1264 return math_ops.div( 1265 (5. + math_ops.cast(sequence_lengths, dtypes.float32))**penalty_factor, 1266 (5. + 1.)**penalty_factor) 1267 1268 1269def _mask_probs(probs, eos_token, finished): 1270 """Masks log probabilities. 1271 1272 The result is that finished beams allocate all probability mass to eos and 1273 unfinished beams remain unchanged. 1274 1275 Args: 1276 probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]` 1277 eos_token: An int32 id corresponding to the EOS token to allocate 1278 probability to. 1279 finished: A boolean tensor of shape `[batch_size, beam_width]` that 1280 specifies which elements in the beam are finished already. 1281 1282 Returns: 1283 A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished 1284 beams stay unchanged and finished beams are replaced with a tensor with all 1285 probability on the EOS token. 1286 """ 1287 vocab_size = array_ops.shape(probs)[2] 1288 # All finished examples are replaced with a vector that has all 1289 # probability on EOS 1290 finished_row = array_ops.one_hot( 1291 eos_token, 1292 vocab_size, 1293 dtype=probs.dtype, 1294 on_value=ops.convert_to_tensor(0., dtype=probs.dtype), 1295 off_value=probs.dtype.min) 1296 finished_probs = array_ops.tile( 1297 array_ops.reshape(finished_row, [1, 1, -1]), 1298 array_ops.concat([array_ops.shape(finished), [1]], 0)) 1299 finished_mask = array_ops.tile( 1300 array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) 1301 1302 return array_ops.where(finished_mask, finished_probs, probs) 1303 1304 1305def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, 1306 range_size, gather_shape): 1307 """Maybe applies _tensor_gather_helper. 1308 1309 This applies _tensor_gather_helper when the gather_from dims is at least as 1310 big as the length of gather_shape. This is used in conjunction with nest so 1311 that we don't apply _tensor_gather_helper to inapplicable values like scalars. 1312 1313 Args: 1314 gather_indices: The tensor indices that we use to gather. 1315 gather_from: The tensor that we are gathering from. 1316 batch_size: The batch size. 


def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
  """Maybe applies _tensor_gather_helper.

  This applies _tensor_gather_helper when the gather_from dims is at least as
  big as the length of gather_shape. This is used in conjunction with nest so
  that we don't apply _tensor_gather_helper to inapplicable values like
  scalars.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve
      the correct values. An example is when gather_from is the attention from
      an AttentionWrapperState with shape
      [batch_size, beam_width, attention_size]. There, we want to preserve the
      attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.

  Returns:
    output: Gathered tensor of shape
      tf.shape(gather_from)[:1 + len(gather_shape)] or the original tensor if
      its dimensions are too small.
  """
  if isinstance(gather_from, tensor_array_ops.TensorArray):
    return gather_from
  _check_ndims(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,
        gather_from=gather_from,
        batch_size=batch_size,
        range_size=range_size,
        gather_shape=gather_shape)
  else:
    return gather_from


def _tensor_gather_helper(gather_indices,
                          gather_from,
                          batch_size,
                          range_size,
                          gather_shape,
                          name=None):
  """Helper for gathering the right indices from the tensor.

  This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
  gathering from that according to the gather_indices, which are offset by
  the right amounts in order to preserve the batch order.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The input batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve
      the correct values. An example is when gather_from is the attention from
      an AttentionWrapperState with shape
      [batch_size, beam_width, attention_size]. There, we want to preserve the
      attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.
    name: The tensor name for set of operations. By default this is
      'tensor_gather_helper'. The final output is named 'output'.

  Returns:
    output: Gathered tensor of shape
      tf.shape(gather_from)[:1 + len(gather_shape)].
  """
  with ops.name_scope(name, "tensor_gather_helper"):
    range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1)
    gather_indices = array_ops.reshape(gather_indices + range_, [-1])
    output = array_ops.gather(
        array_ops.reshape(gather_from, gather_shape), gather_indices)
    final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
    static_batch_size = tensor_util.constant_value(batch_size)
    final_static_shape = (
        tensor_shape.TensorShape([static_batch_size]).concatenate(
            gather_from.shape[1:1 + len(gather_shape)]))
    output = array_ops.reshape(output, final_shape, name="output")
    output.set_shape(final_static_shape)
    return output
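

# NumPy sketch of the offset trick in `_tensor_gather_helper`: per-batch
# indices into `[batch_size, range_size]` become flat indices by adding
# `batch * range_size`, so a single flat gather preserves batch order.
def _example_tensor_gather_numpy(gather_indices, gather_from):
  """gather_indices: int `[batch, k]`; gather_from: `[batch, range_size]`."""
  batch_size, range_size = gather_from.shape
  offsets = np.arange(batch_size)[:, np.newaxis] * range_size
  flat_indices = (gather_indices + offsets).reshape(-1)
  return gather_from.reshape(-1)[flat_indices].reshape(gather_indices.shape)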