# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest

__all__ = [
    "BeamSearchDecoderOutput",
    "BeamSearchDecoderState",
    "BeamSearchDecoder",
    "FinalBeamSearchDecoderOutput",
    "tile_batch",
]


class BeamSearchDecoderState(
    collections.namedtuple("BeamSearchDecoderState",
                           ("cell_state", "log_probs", "finished", "lengths",
                            "accumulated_attention_probs"))):
  pass


class BeamSearchDecoderOutput(
    collections.namedtuple("BeamSearchDecoderOutput",
                           ("scores", "predicted_ids", "parent_ids"))):
  pass


class FinalBeamSearchDecoderOutput(
    collections.namedtuple("FinalBeamDecoderOutput",
                           ["predicted_ids", "beam_search_decoder_output"])):
  """Final outputs returned by the beam search after all decoding is finished.

  Args:
    predicted_ids: The final prediction. A tensor of shape
      `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
      `output_time_major` is True). Beams are ordered from best to worst.
    beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that
      describes the state of the beam search.
  """
  pass


def _tile_batch(t, multiplier):
  """Core single-tensor implementation of tile_batch."""
  t = ops.convert_to_tensor(t, name="t")
  shape_t = array_ops.shape(t)
  if t.shape.ndims is None or t.shape.ndims < 1:
    raise ValueError("t must have statically known rank")
  tiling = [1] * (t.shape.ndims + 1)
  tiling[1] = multiplier
  tiled_static_batch_size = (
      t.shape.dims[0].value * multiplier
      if t.shape.dims[0].value is not None else None)
  tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
  tiled = array_ops.reshape(tiled,
                            array_ops.concat(
                                ([shape_t[0] * multiplier], shape_t[1:]), 0))
  tiled.set_shape(
      tensor_shape.TensorShape([tiled_static_batch_size]).concatenate(
          t.shape[1:]))
  return tiled


def tile_batch(t, multiplier, name=None):
  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.

  For each tensor t in a (possibly nested structure) of tensors,
  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
  `multiplier` times.

  Args:
    t: `Tensor` shaped `[batch_size, ...]`.
    multiplier: Python int.
    name: Name scope for any created operations.

  Returns:
    A (possibly nested structure of) `Tensor` shaped
    `[batch_size * multiplier, ...]`.

  Raises:
    ValueError: if tensor(s) `t` do not have a statically known rank or
    the rank is < 1.
  """
  flat_t = nest.flatten(t)
  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
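
# For example (a sketch, assuming an `encoder_outputs` tensor of shape
# [batch_size, max_time, depth] and a Python int `beam_width`):
#
#   tiled = tile_batch(encoder_outputs, multiplier=beam_width)
#   # tiled has shape [batch_size * beam_width, max_time, depth], with each
#   # minibatch entry repeated beam_width times in a row.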


def gather_tree_from_array(t, parent_ids, sequence_length):
  """Calculates the full beams for `TensorArray`s.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
  max_time = parent_ids.shape.dims[0].value or array_ops.shape(parent_ids)[0]
  batch_size = parent_ids.shape.dims[1].value or array_ops.shape(parent_ids)[1]
  beam_width = parent_ids.shape.dims[2].value or array_ops.shape(parent_ids)[2]

  # Generate beam ids that will be reordered by gather_tree.
  beam_ids = array_ops.expand_dims(
      array_ops.expand_dims(math_ops.range(beam_width), 0), 0)
  beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])

  max_sequence_lengths = math_ops.cast(
      math_ops.reduce_max(sequence_length, axis=1), dtypes.int32)
  sorted_beam_ids = beam_search_ops.gather_tree(
      step_ids=beam_ids,
      parent_ids=parent_ids,
      max_sequence_lengths=max_sequence_lengths,
      end_token=beam_width + 1)

  # For out of range steps, simply copy the same beam.
  in_bound_steps = array_ops.transpose(
      array_ops.sequence_mask(sequence_length, maxlen=max_time),
      perm=[2, 0, 1])
  sorted_beam_ids = array_ops.where(
      in_bound_steps, x=sorted_beam_ids, y=beam_ids)

  # Generate indices for gather_nd.
  time_ind = array_ops.tile(array_ops.reshape(
      math_ops.range(max_time), [-1, 1, 1]), [1, batch_size, beam_width])
  batch_ind = array_ops.tile(array_ops.reshape(
      math_ops.range(batch_size), [-1, 1, 1]), [1, max_time, beam_width])
  batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2])
  indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1)

  # Gather from a tensor with collapsed additional dimensions.
  gather_from = t
  final_shape = array_ops.shape(gather_from)
  gather_from = array_ops.reshape(
      gather_from, [max_time, batch_size, beam_width, -1])
  ordered = array_ops.gather_nd(gather_from, indices)
  ordered = array_ops.reshape(ordered, final_shape)

  return ordered
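
# Parent-id backtracking in brief (a sketch, not the op itself): with
# beam_width = 2 and parent_ids over three steps
#
#   step 0: [-, -]   (no parents)
#   step 1: [0, 0]   (both beams continued beam 0)
#   step 2: [1, 0]   (beam 0 came from beam 1, beam 1 from beam 0)
#
# the history of final beam 0 is (step 2: beam 0) <- (step 1: beam 1)
# <- (step 0: beam 0); gather_tree resolves these chains and
# gather_tree_from_array reorders each stacked element accordingly.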


def _check_ndims(t):
  if t.shape.ndims is None:
    raise ValueError(
        "Expected tensor (%s) to have known rank, but ndims == None." % t)


def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
193  """Raises an exception if dimensions are known statically and can not be
194  reshaped to [batch_size, beam_size, -1].
195  """
  reshaped_shape = tensor_shape.TensorShape([batch_size, beam_width, None])
  if (batch_size is not None and shape.dims[0].value is not None
      and (shape[0] != batch_size * beam_width
           or (shape.ndims >= 2 and shape.dims[1].value is not None
               and (shape[0] != batch_size or shape[1] != beam_width)))):
    tf_logging.warn("TensorArray reordering expects elements to be "
                    "reshapable to %s which is incompatible with the "
                    "current shape %s. Consider setting "
                    "reorder_tensor_arrays to False to disable TensorArray "
                    "reordering during the beam search."
                    % (reshaped_shape, shape))
    return False
  return True


def _check_batch_beam(t, batch_size, beam_width):
  """Returns an Assert operation checking that the elements of the stacked
  TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
  the TensorArray elements have a known rank of at least 1.
  """
  error_message = ("TensorArray reordering expects elements to be "
                   "reshapable to [batch_size, beam_size, -1] which is "
                   "incompatible with the dynamic shape of %s elements. "
                   "Consider setting reorder_tensor_arrays to False to disable "
                   "TensorArray reordering during the beam search."
                   % (t if context.executing_eagerly() else t.name))
  rank = t.shape.ndims
  shape = array_ops.shape(t)
  if rank == 2:
    condition = math_ops.equal(shape[1], batch_size * beam_width)
  else:
    condition = math_ops.logical_or(
        math_ops.equal(shape[1], batch_size * beam_width),
        math_ops.logical_and(
            math_ops.equal(shape[1], batch_size),
            math_ops.equal(shape[2], beam_width)))
  return control_flow_ops.Assert(condition, [error_message])


class BeamSearchDecoderMixin(object):
236  """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.
237
238  It is expected to be used a base class for concrete BeamSearchDecoder. Since
239  this is a mixin class, it is expected to be used together with other class as
240  base.
241  """

  def __init__(self,
               cell,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True,
               **kwargs):
    """Initialize the BeamSearchDecoderMixin.

    Args:
      cell: An `RNNCell` instance.
      beam_width:  Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`.  Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
        state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.
      **kwargs: Dict, other keyword arguments for parent class.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
    if (output_layer is not None and
        not isinstance(output_layer, layers.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer
    self._reorder_tensor_arrays = reorder_tensor_arrays

    self._start_tokens = None
    self._end_token = None
    self._batch_size = None
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._coverage_penalty_weight = coverage_penalty_weight
    super(BeamSearchDecoderMixin, self).__init__(**kwargs)

  @property
  def batch_size(self):
    return self._batch_size

  def _rnn_output_size(self):
    """Get the output shape from the RNN layer."""
    size = self._cell.output_size
    if self._output_layer is None:
      return size
    else:
      # To use layer's compute_output_shape, we need to convert the
      # RNNCell's output_size entries into shapes with an unknown
      # batch size.  We then pass this through the layer's
      # compute_output_shape and read off all but the first (batch)
      # dimensions to get the output size of the rnn with the layer
      # applied to the top.
      output_shape_with_unknown_batch = nest.map_structure(
          lambda s: tensor_shape.TensorShape([None]).concatenate(s), size)
      layer_output_shape = self._output_layer.compute_output_shape(
          output_shape_with_unknown_batch)
      return nest.map_structure(lambda s: s[1:], layer_output_shape)

  @property
  def tracks_own_finished(self):
    """The BeamSearchDecoder shuffles its beams and their finished state.

    For this reason, it conflicts with the `dynamic_decode` function's
    tracking of finished states.  Setting this property to true avoids
    early stopping of decoding due to mismanagement of the finished state
    in `dynamic_decode`.

    Returns:
      `True`.
    """
    return True

  @property
  def output_size(self):
    # Return the cell output and the id
    return BeamSearchDecoderOutput(
        scores=tensor_shape.TensorShape([self._beam_width]),
        predicted_ids=tensor_shape.TensorShape([self._beam_width]),
        parent_ids=tensor_shape.TensorShape([self._beam_width]))

  def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.
        **NOTE** These are ignored; the updated sequence lengths are stored in
        `final_state.lengths`.

    Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` where the
        predicted_ids are the result of calling `gather_tree`.
      final_state: The same input instance of `BeamSearchDecoderState`.
    """
    del sequence_lengths
    # Get max_sequence_length across all beams for each batch.
    max_sequence_lengths = math_ops.cast(
        math_ops.reduce_max(final_state.lengths, axis=1), dtypes.int32)
    predicted_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=self._end_token)
    if self._reorder_tensor_arrays:
      final_state = final_state._replace(cell_state=nest.map_structure(
          lambda t: self._maybe_sort_array_beams(
              t, outputs.parent_ids, final_state.lengths),
          final_state.cell_state))
    outputs = FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
    return outputs, final_state

  def _merge_batch_beams(self, t, s=None):
    """Merges the tensor from a batch of beams into a batch by beams.

    More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
    reshape this into [batch_size*beam_width, s]

    Args:
      t: Tensor of dimension [batch_size, beam_width, s]
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size * beam_width, s].
    """
    if isinstance(s, ops.Tensor):
      s = tensor_shape.as_shape(tensor_util.constant_value(s))
    else:
      s = tensor_shape.TensorShape(s)
    t_shape = array_ops.shape(t)
    static_batch_size = tensor_util.constant_value(self._batch_size)
    batch_size_beam_width = (
        None
        if static_batch_size is None else static_batch_size * self._beam_width)
    reshaped_t = array_ops.reshape(
        t,
        array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]),
                         0))
    reshaped_t.set_shape(
        (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s)))
    return reshaped_t

  def _split_batch_beams(self, t, s=None):
    """Splits the tensor from a batch by beams into a batch of beams.

    More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
    reshape this into [batch_size, beam_width, s]

    Args:
      t: Tensor of dimension [batch_size*beam_width, s].
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size, beam_width, s].

    Raises:
      ValueError: If, after reshaping, the new tensor is not shaped
        `[batch_size, beam_width, s]` (assuming batch_size and beam_width
        are known statically).
    """
    if isinstance(s, ops.Tensor):
      s = tensor_shape.TensorShape(tensor_util.constant_value(s))
    else:
      s = tensor_shape.TensorShape(s)
    t_shape = array_ops.shape(t)
    reshaped_t = array_ops.reshape(
        t,
        array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]),
                         0))
    static_batch_size = tensor_util.constant_value(self._batch_size)
    expected_reshaped_shape = tensor_shape.TensorShape(
        [static_batch_size, self._beam_width]).concatenate(s)
    if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
      raise ValueError("Unexpected behavior when reshaping between beam width "
                       "and batch size.  The reshaped tensor has shape: %s.  "
                       "We expected it to have shape "
                       "(batch_size, beam_width, depth) == %s.  Perhaps you "
                       "forgot to create a zero_state with "
                       "batch_size=encoder_batch_size * beam_width?" %
                       (reshaped_t.shape, expected_reshaped_shape))
    reshaped_t.set_shape(expected_reshaped_shape)
    return reshaped_t
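
  # A shape round-trip sketch for the two helpers above (plain numpy, not the
  # TF code itself; shapes are illustrative):
  #
  #   import numpy as np
  #   t = np.zeros([4, 3, 5])            # [batch_size, beam_width, depth]
  #   merged = t.reshape([4 * 3, 5])     # _merge_batch_beams: [batch*beam, d]
  #   split = merged.reshape([4, 3, 5])  # _split_batch_beams: back to 3-D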

  def _maybe_split_batch_beams(self, t, s):
    """Maybe splits the tensor from a batch by beams into a batch of beams.

    We do this so that we can use nest and not run into problems with shapes.

    Args:
      t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
      s: `Tensor`, Python int, or `TensorShape`.

    Returns:
      If `t` is a matrix or higher order tensor, then the return value is
      `t` reshaped to `[batch_size, beam_width] + s`.  Otherwise `t` is
      returned unchanged.

    Raises:
      ValueError: If the rank of `t` is not statically known.
    """
    if isinstance(t, tensor_array_ops.TensorArray):
      return t
    _check_ndims(t)
    if t.shape.ndims >= 1:
      return self._split_batch_beams(t, s)
    else:
      return t

  def _maybe_merge_batch_beams(self, t, s):
467    """Splits the tensor from a batch by beams into a batch of beams.
468
469    More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`,
470    then we reshape it to `[batch_size, beam_width] + s`.
471
472    Args:
473      t: `Tensor` of dimension `[batch_size * beam_width] + s`.
474      s: `Tensor`, Python int, or `TensorShape`.
475
476    Returns:
477      A reshaped version of t with shape `[batch_size, beam_width] + s`.

    Raises:
      ValueError:  If the rank of `t` is not statically known.
    """
    if isinstance(t, tensor_array_ops.TensorArray):
      return t
    _check_ndims(t)
    if t.shape.ndims >= 2:
      return self._merge_batch_beams(t, s)
    else:
      return t

  def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
    """Maybe sorts beams within a `TensorArray`.

    Args:
      t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape
        `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where
        `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if
      it is not a `TensorArray` or does not meet shape requirements.
    """
    if not isinstance(t, tensor_array_ops.TensorArray):
      return t
    # pylint: disable=protected-access
    # This is a bad hack due to the implementation detail of eager/graph TA.
    # TODO(b/124374427): Update this to use public property of TensorArray.
    if context.executing_eagerly():
      element_shape = t._element_shape
    else:
      element_shape = t._element_shape[0]
    if (not t._infer_shape
        or not t._element_shape
        or element_shape.ndims is None
        or element_shape.ndims < 1):
      shape = (
          element_shape if t._infer_shape and t._element_shape
          else tensor_shape.TensorShape(None))
      tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
                      "sorting based on the beam search result. For a "
                      "TensorArray to be sorted, its elements shape must be "
                      "defined and have at least a rank of 1, but saw shape: %s"
                      % (t.handle.name, shape))
      return t
    # pylint: enable=protected-access
    if not _check_static_batch_beam_maybe(
        element_shape, tensor_util.constant_value(self._batch_size),
        self._beam_width):
      return t
    t = t.stack()
    with ops.control_dependencies(
        [_check_batch_beam(t, self._batch_size, self._beam_width)]):
      return gather_tree_from_array(t, parent_ids, sequence_length)

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    batch_size = self._batch_size
    beam_width = self._beam_width
    end_token = self._end_token
    length_penalty_weight = self._length_penalty_weight
    coverage_penalty_weight = self._coverage_penalty_weight

    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
      cell_state = state.cell_state
      inputs = nest.map_structure(
          lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
      cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
                                      self._cell.state_size)
      cell_outputs, next_cell_state = self._cell(inputs, cell_state)
      cell_outputs = nest.map_structure(
          lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
      next_cell_state = nest.map_structure(
          self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)

      beam_search_output, beam_search_state = _beam_search_step(
          time=time,
          logits=cell_outputs,
          next_cell_state=next_cell_state,
          beam_state=state,
          batch_size=batch_size,
          beam_width=beam_width,
          end_token=end_token,
          length_penalty_weight=length_penalty_weight,
          coverage_penalty_weight=coverage_penalty_weight)

      finished = beam_search_state.finished
      sample_ids = beam_search_output.predicted_ids
      next_inputs = control_flow_ops.cond(
          math_ops.reduce_all(finished), lambda: self._start_inputs,
          lambda: self._embedding_fn(sample_ids))

    return (beam_search_output, beam_search_state, next_inputs, finished)
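
  # Note on shape flow in `step` above: inputs and state arrive per beam as
  # [batch_size, beam_width, ...] and are merged to
  # [batch_size * beam_width, ...] so the wrapped cell sees an ordinary batch;
  # the cell output and state are then split back to per-beam shape before
  # `_beam_search_step` selects the top-k continuations.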


class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
  # Note that the inheritance hierarchy is important here. The Mixin has to be
  # the first parent class since we will use super().__init__(), and the Mixin,
  # which inherits from object, will properly invoke the __init__ method of the
  # other parent class.
593  """BeamSearch sampling decoder.
594
595    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
596    `AttentionWrapper`, then you must ensure that:
597
598    - The encoder output has been tiled to `beam_width` via
599      `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
600    - The `batch_size` argument passed to the `zero_state` method of this
601      wrapper is equal to `true_batch_size * beam_width`.
602    - The initial state created with `zero_state` above contains a
603      `cell_state` value containing properly tiled final state from the
604      encoder.
605
606    An example:
607
608    ```
609    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
610        encoder_outputs, multiplier=beam_width)
611    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
612        encoder_final_state, multiplier=beam_width)
613    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
614        sequence_length, multiplier=beam_width)
615    attention_mechanism = MyFavoriteAttentionMechanism(
616        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
    when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
    the decoder to cover all inputs.
629  """
630
631  def __init__(self,
632               cell,
633               embedding,
634               start_tokens,
635               end_token,
636               initial_state,
637               beam_width,
638               output_layer=None,
639               length_penalty_weight=0.0,
640               coverage_penalty_weight=0.0,
641               reorder_tensor_arrays=True):
642    """Initialize the BeamSearchDecoder.
643
644    Args:
645      cell: An `RNNCell` instance.
646      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
647        or the `params` argument for `embedding_lookup`.
648      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
649      end_token: `int32` scalar, the token that marks end of decoding.
650      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
651      beam_width:  Python integer, the number of beams.
652      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
653        `tf.keras.layers.Dense`.  Optional layer to apply to the RNN output
654        prior to storing the result or sampling.
655      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
656      coverage_penalty_weight: Float weight to penalize the coverage of source
657        sentence. Disabled with 0.0.
658      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
659        state will be reordered according to the beam search path. If the
660        `TensorArray` can be reordered, the stacked form will be returned.
661        Otherwise, the `TensorArray` will be returned as is. Set this flag to
662        `False` if the cell state contains `TensorArray`s that are not amenable
663        to reordering.
664
665    Raises:
666      TypeError: if `cell` is not an instance of `RNNCell`,
667        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
668      ValueError: If `start_tokens` is not a vector or
669        `end_token` is not a scalar.
670    """
671    super(BeamSearchDecoder, self).__init__(
672        cell,
673        beam_width,
674        output_layer=output_layer,
675        length_penalty_weight=length_penalty_weight,
676        coverage_penalty_weight=coverage_penalty_weight,
677        reorder_tensor_arrays=reorder_tensor_arrays)
678
679    if callable(embedding):
680      self._embedding_fn = embedding
681    else:
682      self._embedding_fn = (
683          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
684
685    self._start_tokens = ops.convert_to_tensor(
686        start_tokens, dtype=dtypes.int32, name="start_tokens")
687    if self._start_tokens.get_shape().ndims != 1:
688      raise ValueError("start_tokens must be a vector")
689    self._end_token = ops.convert_to_tensor(
690        end_token, dtype=dtypes.int32, name="end_token")
691    if self._end_token.get_shape().ndims != 0:
692      raise ValueError("end_token must be a scalar")
693
694    self._batch_size = array_ops.size(start_tokens)
695    self._initial_cell_state = nest.map_structure(
696        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
697    self._start_tokens = array_ops.tile(
698        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
699    self._start_inputs = self._embedding_fn(self._start_tokens)
700
701    self._finished = array_ops.one_hot(
702        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
703        depth=self._beam_width,
704        on_value=False,
705        off_value=True,
706        dtype=dtypes.bool)
707
708  def initialize(self, name=None):
709    """Initialize the decoder.
710
711    Args:
712      name: Name scope for any created operations.
713
714    Returns:
715      `(finished, start_inputs, initial_state)`.
716    """
717    finished, start_inputs = self._finished, self._start_inputs
718
719    dtype = nest.flatten(self._initial_cell_state)[0].dtype
720    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
721        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
722        depth=self._beam_width,
723        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
724        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
725        dtype=dtype)
726    init_attention_probs = get_attention_probs(
727        self._initial_cell_state, self._coverage_penalty_weight)
728    if init_attention_probs is None:
729      init_attention_probs = ()
730
731    initial_state = BeamSearchDecoderState(
732        cell_state=self._initial_cell_state,
733        log_probs=log_probs,
734        finished=finished,
735        lengths=array_ops.zeros(
736            [self._batch_size, self._beam_width], dtype=dtypes.int64),
737        accumulated_attention_probs=init_attention_probs)
738
739    return (finished, start_inputs, initial_state)
740
741  @property
742  def output_dtype(self):
743    # Assume the dtype of the cell is the output_size structure
744    # containing the input_state's first component's dtype.
745    # Return that structure and int32 (the id)
746    dtype = nest.flatten(self._initial_cell_state)[0].dtype
747    return BeamSearchDecoderOutput(
748        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
749        predicted_ids=dtypes.int32,
750        parent_ids=dtypes.int32)
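
# A minimal end-to-end sketch (assuming `cell`, `embedding`, `start_tokens`,
# `end_token`, `decoder_initial_state`, `beam_width` and `max_iterations` are
# defined as in the class docstring above):
#
#   bsd = BeamSearchDecoder(
#       cell=cell,
#       embedding=embedding,
#       start_tokens=start_tokens,
#       end_token=end_token,
#       initial_state=decoder_initial_state,
#       beam_width=beam_width)
#   final_outputs, final_state, final_lengths = decoder.dynamic_decode(
#       bsd, maximum_iterations=max_iterations)
#   predicted_ids = final_outputs.predicted_ids  # [batch_size, T, beam_width]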


class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder):
  # Note that the inheritance hierarchy is important here. The Mixin has to be
  # the first parent class since we will use super().__init__(), and the Mixin,
  # which inherits from object, will properly invoke the __init__ method of the
  # other parent class.
757  """BeamSearch sampling decoder.
758
759    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
760    `AttentionWrapper`, then you must ensure that:
761
762    - The encoder output has been tiled to `beam_width` via
763      `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
764    - The `batch_size` argument passed to the `zero_state` method of this
765      wrapper is equal to `true_batch_size * beam_width`.
766    - The initial state created with `zero_state` above contains a
767      `cell_state` value containing properly tiled final state from the
768      encoder.
769
770    An example:
771
772    ```
773    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
774        encoder_outputs, multiplier=beam_width)
775    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
776        encoder_final_state, multiplier=beam_width)
777    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
778        sequence_length, multiplier=beam_width)
779    attention_mechanism = MyFavoriteAttentionMechanism(
780        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
    when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
    the decoder to cover all inputs.
793  """
794
795  def __init__(self,
796               cell,
797               beam_width,
798               embedding_fn=None,
799               output_layer=None,
800               length_penalty_weight=0.0,
801               coverage_penalty_weight=0.0,
802               reorder_tensor_arrays=True,
803               **kwargs):
804    """Initialize the BeamSearchDecoderV2.
805
806    Args:
807      cell: An `RNNCell` instance.
808      beam_width:  Python integer, the number of beams.
809      embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids).
810      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
811        `tf.keras.layers.Dense`.  Optional layer to apply to the RNN output
812        prior to storing the result or sampling.
813      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
814      coverage_penalty_weight: Float weight to penalize the coverage of source
815        sentence. Disabled with 0.0.
816      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
817        state will be reordered according to the beam search path. If the
818        `TensorArray` can be reordered, the stacked form will be returned.
819        Otherwise, the `TensorArray` will be returned as is. Set this flag to
820        `False` if the cell state contains `TensorArray`s that are not amenable
821        to reordering.
822      **kwargs: Dict, other keyword arguments for initialization.
823
824    Raises:
825      TypeError: if `cell` is not an instance of `RNNCell`,
826        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
827    """
828    super(BeamSearchDecoderV2, self).__init__(
829        cell,
830        beam_width,
831        output_layer=output_layer,
832        length_penalty_weight=length_penalty_weight,
833        coverage_penalty_weight=coverage_penalty_weight,
834        reorder_tensor_arrays=reorder_tensor_arrays,
835        **kwargs)
836
837    if embedding_fn is None or callable(embedding_fn):
838      self._embedding_fn = embedding_fn
839    else:
840      raise ValueError("embedding_fn is expected to be a callable, got %s" %
841                       type(embedding_fn))
842
843  def initialize(self,
844                 embedding,
845                 start_tokens,
846                 end_token,
847                 initial_state):
848    """Initialize the decoder.
849
850    Args:
851      embedding: A tensor from the embedding layer output, which is the
852        `params` argument for `embedding_lookup`.
853      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
854      end_token: `int32` scalar, the token that marks end of decoding.
855      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
856    Returns:
857      `(finished, start_inputs, initial_state)`.
858    Raises:
859      ValueError: If `start_tokens` is not a vector or `end_token` is not a
860        scalar.
861    """
862    if embedding is not None and self._embedding_fn is not None:
863      raise ValueError(
864          "embedding and embedding_fn cannot be provided at same time")
865    elif embedding is not None:
866      self._embedding_fn = (
867          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
868
869    self._start_tokens = ops.convert_to_tensor(
870        start_tokens, dtype=dtypes.int32, name="start_tokens")
871    if self._start_tokens.get_shape().ndims != 1:
872      raise ValueError("start_tokens must be a vector")
873    self._end_token = ops.convert_to_tensor(
874        end_token, dtype=dtypes.int32, name="end_token")
875    if self._end_token.get_shape().ndims != 0:
876      raise ValueError("end_token must be a scalar")
877
878    self._batch_size = array_ops.size(start_tokens)
879    self._initial_cell_state = nest.map_structure(
880        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
881    self._start_tokens = array_ops.tile(
882        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
883    self._start_inputs = self._embedding_fn(self._start_tokens)
884
885    self._finished = array_ops.one_hot(
886        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
887        depth=self._beam_width,
888        on_value=False,
889        off_value=True,
890        dtype=dtypes.bool)
891
892    finished, start_inputs = self._finished, self._start_inputs
893
894    dtype = nest.flatten(self._initial_cell_state)[0].dtype
895    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
896        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
897        depth=self._beam_width,
898        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
899        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
900        dtype=dtype)
901    init_attention_probs = get_attention_probs(
902        self._initial_cell_state, self._coverage_penalty_weight)
903    if init_attention_probs is None:
904      init_attention_probs = ()
905
906    initial_state = BeamSearchDecoderState(
907        cell_state=self._initial_cell_state,
908        log_probs=log_probs,
909        finished=finished,
910        lengths=array_ops.zeros(
911            [self._batch_size, self._beam_width], dtype=dtypes.int64),
912        accumulated_attention_probs=init_attention_probs)
913
914    return (finished, start_inputs, initial_state)
915
916  @property
917  def output_dtype(self):
918    # Assume the dtype of the cell is the output_size structure
919    # containing the input_state's first component's dtype.
920    # Return that structure and int32 (the id)
921    dtype = nest.flatten(self._initial_cell_state)[0].dtype
922    return BeamSearchDecoderOutput(
923        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
924        predicted_ids=dtypes.int32,
925        parent_ids=dtypes.int32)
926
  def call(self, embedding, start_tokens, end_token, initial_state, **kwargs):
    init_kwargs = kwargs
    init_kwargs["start_tokens"] = start_tokens
    init_kwargs["end_token"] = end_token
    init_kwargs["initial_state"] = initial_state
    return decoder.dynamic_decode(self,
                                  output_time_major=self.output_time_major,
                                  impute_finished=self.impute_finished,
                                  maximum_iterations=self.maximum_iterations,
                                  parallel_iterations=self.parallel_iterations,
                                  swap_memory=self.swap_memory,
                                  decoder_init_input=embedding,
                                  decoder_init_kwargs=init_kwargs)


def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight,
                      coverage_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    A tuple `(output, next_state)`, where `output` is a
    `BeamSearchDecoderOutput` for this step and `next_state` is the new
    `BeamSearchDecoderState`.
966  """
967  static_batch_size = tensor_util.constant_value(batch_size)
968
969  # Calculate the current lengths of the predictions
970  prediction_lengths = beam_state.lengths
971  previously_finished = beam_state.finished
972  not_finished = math_ops.logical_not(previously_finished)
973
974  # Calculate the total log probs for the new hypotheses
975  # Final Shape: [batch_size, beam_width, vocab_size]
976  step_log_probs = nn_ops.log_softmax(logits)
977  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
978  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs
979
980  # Calculate the continuation lengths by adding to all continuing beams.
981  vocab_size = logits.shape.dims[-1].value or array_ops.shape(logits)[-1]
982  lengths_to_add = array_ops.one_hot(
983      indices=array_ops.fill([batch_size, beam_width], end_token),
984      depth=vocab_size,
985      on_value=np.int64(0),
986      off_value=np.int64(1),
987      dtype=dtypes.int64)
988  add_mask = math_ops.cast(not_finished, dtypes.int64)
989  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
990  new_prediction_lengths = (
991      lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
992
993  # Calculate the accumulated attention probabilities if coverage penalty is
994  # enabled.
995  accumulated_attention_probs = None
996  attention_probs = get_attention_probs(
997      next_cell_state, coverage_penalty_weight)
998  if attention_probs is not None:
999    attention_probs *= array_ops.expand_dims(
1000        math_ops.cast(not_finished, dtypes.float32), 2)
1001    accumulated_attention_probs = (
1002        beam_state.accumulated_attention_probs + attention_probs)
1003
1004  # Calculate the scores for each beam
1005  scores = _get_scores(
1006      log_probs=total_probs,
1007      sequence_lengths=new_prediction_lengths,
1008      length_penalty_weight=length_penalty_weight,
1009      coverage_penalty_weight=coverage_penalty_weight,
1010      finished=previously_finished,
1011      accumulated_attention_probs=accumulated_attention_probs)
1012
1013  time = ops.convert_to_tensor(time, name="time")
1014  # During the first time step we only consider the initial beam
1015  scores_flat = array_ops.reshape(scores, [batch_size, -1])
1016
1017  # Pick the next beams according to the specified successors function
1018  next_beam_size = ops.convert_to_tensor(
1019      beam_width, dtype=dtypes.int32, name="beam_width")
1020  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
1021
1022  next_beam_scores.set_shape([static_batch_size, beam_width])
1023  word_indices.set_shape([static_batch_size, beam_width])
1024
1025  # Pick out the probs, beam_ids, and states according to the chosen predictions
1026  next_beam_probs = _tensor_gather_helper(
1027      gather_indices=word_indices,
1028      gather_from=total_probs,
1029      batch_size=batch_size,
1030      range_size=beam_width * vocab_size,
1031      gather_shape=[-1],
1032      name="next_beam_probs")
1033  # Note: just doing the following
1034  #   math_ops.cast(
1035  #       word_indices % vocab_size,
1036  #       dtypes.int32,
1037  #       name="next_beam_word_ids")
1038  # would be a lot cleaner but for reasons unclear, that hides the results of
1039  # the op which prevents capturing it with tfdbg debug ops.
1040  raw_next_word_ids = math_ops.mod(
1041      word_indices, vocab_size, name="next_beam_word_ids")
1042  next_word_ids = math_ops.cast(raw_next_word_ids, dtypes.int32)
1043  next_beam_ids = math_ops.cast(
1044      word_indices / vocab_size, dtypes.int32, name="next_beam_parent_ids")
1045
1046  # Append new ids to current predictions
1047  previously_finished = _tensor_gather_helper(
1048      gather_indices=next_beam_ids,
1049      gather_from=previously_finished,
1050      batch_size=batch_size,
1051      range_size=beam_width,
1052      gather_shape=[-1])
1053  next_finished = math_ops.logical_or(
1054      previously_finished,
1055      math_ops.equal(next_word_ids, end_token),
1056      name="next_beam_finished")
1057
1058  # Calculate the length of the next predictions.
1059  # 1. Finished beams remain unchanged.
1060  # 2. Beams that are now finished (EOS predicted) have their length
1061  #    increased by 1.
1062  # 3. Beams that are not yet finished have their length increased by 1.
1063  lengths_to_add = math_ops.cast(
1064      math_ops.logical_not(previously_finished), dtypes.int64)
1065  next_prediction_len = _tensor_gather_helper(
1066      gather_indices=next_beam_ids,
1067      gather_from=beam_state.lengths,
1068      batch_size=batch_size,
1069      range_size=beam_width,
1070      gather_shape=[-1])
1071  next_prediction_len += lengths_to_add
1072  next_accumulated_attention_probs = ()
1073  if accumulated_attention_probs is not None:
1074    next_accumulated_attention_probs = _tensor_gather_helper(
1075        gather_indices=next_beam_ids,
1076        gather_from=accumulated_attention_probs,
1077        batch_size=batch_size,
1078        range_size=beam_width,
1079        gather_shape=[batch_size * beam_width, -1],
1080        name="next_accumulated_attention_probs")
1081
1082  # Pick out the cell_states according to the next_beam_ids. We use a
1083  # different gather_shape here because the cell_state tensors, i.e.
1084  # the tensors that would be gathered from, all have dimension
1085  # greater than two and we need to preserve those dimensions.
1086  # pylint: disable=g-long-lambda
1087  next_cell_state = nest.map_structure(
1088      lambda gather_from: _maybe_tensor_gather_helper(
1089          gather_indices=next_beam_ids,
1090          gather_from=gather_from,
1091          batch_size=batch_size,
1092          range_size=beam_width,
1093          gather_shape=[batch_size * beam_width, -1]),
1094      next_cell_state)
1095  # pylint: enable=g-long-lambda
1096
1097  next_state = BeamSearchDecoderState(
1098      cell_state=next_cell_state,
1099      log_probs=next_beam_probs,
1100      lengths=next_prediction_len,
1101      finished=next_finished,
1102      accumulated_attention_probs=next_accumulated_attention_probs)
1103
1104  output = BeamSearchDecoderOutput(
1105      scores=next_beam_scores,
1106      predicted_ids=next_word_ids,
1107      parent_ids=next_beam_ids)
1108
1109  return output, next_state
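
# The beam selection above, in brief (a numpy sketch, not the TF code itself):
#
#   import numpy as np
#   batch_size, beam_width, vocab_size = 1, 2, 3
#   scores = np.array([[[0.1, 0.9, 0.2],    # beam 0
#                       [0.8, 0.3, 0.7]]])  # beam 1
#   flat = scores.reshape(batch_size, -1)   # [batch_size, beam_width * vocab]
#   idx = np.argsort(-flat, axis=1)[:, :beam_width]  # top-k, like nn_ops.top_k
#   word_ids = idx % vocab_size     # -> [[1, 0]]: token chosen per survivor
#   parent_ids = idx // vocab_size  # -> [[0, 1]]: beam each survivor came from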


def get_attention_probs(next_cell_state, coverage_penalty_weight):
  """Get attention probabilities from the cell state.

  Args:
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    The attention probabilities with shape `[batch_size, beam_width, max_time]`
    if coverage penalty is enabled. Otherwise, returns None.

  Raises:
    ValueError: If no cell is attentional but coverage penalty is enabled.
  """
  if coverage_penalty_weight == 0.0:
    return None

  # Attention probabilities of each attention layer. Each with shape
  # `[batch_size, beam_width, max_time]`.
  probs_per_attn_layer = []
  if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
    probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
  elif isinstance(next_cell_state, tuple):
    for state in next_cell_state:
      if isinstance(state, attention_wrapper.AttentionWrapperState):
        probs_per_attn_layer.append(attention_probs_from_attn_state(state))

  if not probs_per_attn_layer:
    raise ValueError(
        "coverage_penalty_weight must be 0.0 if no cell is attentional.")

  if len(probs_per_attn_layer) == 1:
    attention_probs = probs_per_attn_layer[0]
  else:
    # Calculate the average attention probabilities from all attention layers.
    attention_probs = [
        array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
    attention_probs = array_ops.concat(attention_probs, -1)
    attention_probs = math_ops.reduce_mean(attention_probs, -1)

  return attention_probs


def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
                coverage_penalty_weight, finished, accumulated_attention_probs):
  """Calculates scores for beam search hypotheses.

  Args:
    log_probs: The log probabilities with shape
      `[batch_size, beam_width, vocab_size]`.
    sequence_lengths: The array of sequence lengths.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.
    finished: A boolean tensor of shape `[batch_size, beam_width]` that
      specifies which elements in the beam are finished already.
    accumulated_attention_probs: Accumulated attention probabilities up to the
      current time step, with shape `[batch_size, beam_width, max_time]` if
      coverage_penalty_weight is not 0.0.

  Returns:
    The scores normalized by the length_penalty and coverage_penalty.

  Raises:
    ValueError: accumulated_attention_probs is None when coverage penalty is
      enabled.
  """
  length_penalty_ = _length_penalty(
      sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
  length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype)
  scores = log_probs / length_penalty_

  coverage_penalty_weight = ops.convert_to_tensor(
      coverage_penalty_weight, name="coverage_penalty_weight")
  if coverage_penalty_weight.shape.ndims != 0:
    raise ValueError("coverage_penalty_weight should be a scalar, "
                     "but saw shape: %s" % coverage_penalty_weight.shape)

  if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
    return scores

  if accumulated_attention_probs is None:
    raise ValueError(
        "accumulated_attention_probs can be None only if coverage penalty is "
        "disabled.")

  # Add source sequence length mask before computing coverage penalty.
  accumulated_attention_probs = array_ops.where(
      math_ops.equal(accumulated_attention_probs, 0.0),
      array_ops.ones_like(accumulated_attention_probs),
      accumulated_attention_probs)

  # coverage penalty =
  #     sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
  coverage_penalty = math_ops.reduce_sum(
      math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
  # Apply coverage penalty to finished predictions.
  coverage_penalty *= math_ops.cast(finished, dtypes.float32)
  weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
  # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
  weighted_coverage_penalty = array_ops.expand_dims(
      weighted_coverage_penalty, 2)
  return scores + weighted_coverage_penalty
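
# Coverage penalty in brief (a numeric sketch of the formula above): with
# accumulated attention [0.5, 1.2, 0.0] over three source positions, the zero
# entry is first replaced by 1.0 (masked position), so the penalty is
# log(min(0.5, 1)) + log(min(1.2, 1)) + log(min(1.0, 1)) = log(0.5) ~= -0.693,
# which is scaled by coverage_penalty_weight and added to the scores of
# finished beams only.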


def attention_probs_from_attn_state(attention_state):
  """Calculates the average attention probabilities.

  Args:
    attention_state: An instance of `AttentionWrapperState`.

  Returns:
    The attention probabilities in the given AttentionWrapperState.
    If there are multiple attention mechanisms, returns the average value
    across all attention mechanisms.
1229  """
1230  # Attention probabilities over time steps, with shape
1231  # `[batch_size, beam_width, max_time]`.
1232  attention_probs = attention_state.alignments
1233  if isinstance(attention_probs, tuple):
1234    attention_probs = [
1235        array_ops.expand_dims(prob, -1) for prob in attention_probs]
1236    attention_probs = array_ops.concat(attention_probs, -1)
1237    attention_probs = math_ops.reduce_mean(attention_probs, -1)
1238  return attention_probs
1239
1240
1241def _length_penalty(sequence_lengths, penalty_factor):
1242  """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
1243
1244  Returns the length penalty tensor:
1245  ```
1246  [(5+sequence_lengths)/6]**penalty_factor
1247  ```
1248  where all operations are performed element-wise.
1249
1250  Args:
    sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
    penalty_factor: A scalar that weights the length penalty.

  Returns:
    If the penalty is `0`, returns the scalar `1.0`.  Otherwise returns
    the length penalty factor, a tensor with the same shape as
    `sequence_lengths`.
  """
  penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor")
  penalty_factor.set_shape(())  # penalty should be a scalar.
  static_penalty = tensor_util.constant_value(penalty_factor)
  if static_penalty is not None and static_penalty == 0:
    return 1.0
  return math_ops.div(
      (5. + math_ops.cast(sequence_lengths, dtypes.float32))**penalty_factor,
      (5. + 1.)**penalty_factor)
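
# For example, with penalty_factor = 0.6 and a hypothesis of length 7, the
# divisor is ((5. + 7.) / 6.) ** 0.6 = 2.0 ** 0.6 ~= 1.52, so log probs are
# divided by a factor that grows sublinearly with length (GNMT-style).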


def _mask_probs(probs, eos_token, finished):
  """Masks log probabilities.

  The result is that finished beams allocate all probability mass to eos and
  unfinished beams remain unchanged.

  Args:
    probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`
    eos_token: An int32 id corresponding to the EOS token to allocate
      probability to.
    finished: A boolean tensor of shape `[batch_size, beam_width]` that
      specifies which elements in the beam are finished already.

  Returns:
    A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
    beams stay unchanged and finished beams are replaced with a tensor with all
    probability on the EOS token.
  """
  vocab_size = array_ops.shape(probs)[2]
  # All finished examples are replaced with a vector that has all
  # probability on EOS
  finished_row = array_ops.one_hot(
      eos_token,
      vocab_size,
      dtype=probs.dtype,
      on_value=ops.convert_to_tensor(0., dtype=probs.dtype),
      off_value=probs.dtype.min)
  finished_probs = array_ops.tile(
      array_ops.reshape(finished_row, [1, 1, -1]),
      array_ops.concat([array_ops.shape(finished), [1]], 0))
  finished_mask = array_ops.tile(
      array_ops.expand_dims(finished, 2), [1, 1, vocab_size])

  return array_ops.where(finished_mask, finished_probs, probs)
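
# For example, with vocab_size = 3 and eos_token = 2, a finished beam's row of
# log probs becomes [dtype.min, dtype.min, 0.], so every continuation other
# than repeating EOS is effectively impossible and the beam's total log prob
# stops changing.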


def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
  """Maybe applies _tensor_gather_helper.

  This applies _tensor_gather_helper when gather_from has at least as many
  dimensions as the length of gather_shape. This is used in conjunction with
  nest so that we don't apply _tensor_gather_helper to inapplicable values
  like scalars.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
      or the original tensor if its dimensions are too small.
  """
  if isinstance(gather_from, tensor_array_ops.TensorArray):
    return gather_from
  _check_ndims(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,
        gather_from=gather_from,
        batch_size=batch_size,
        range_size=range_size,
        gather_shape=gather_shape)
  else:
    return gather_from


def _tensor_gather_helper(gather_indices,
                          gather_from,
                          batch_size,
                          range_size,
                          gather_shape,
                          name=None):
  """Helper for gathering the right indices from the tensor.

  This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
  gathering from that according to the gather_indices, which are offset by
  the right amounts in order to preserve the batch order.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The input batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.
    name: The tensor name for set of operations. By default this is
      'tensor_gather_helper'. The final output is named 'output'.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
  """
  with ops.name_scope(name, "tensor_gather_helper"):
    range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1)
    gather_indices = array_ops.reshape(gather_indices + range_, [-1])
    output = array_ops.gather(
        array_ops.reshape(gather_from, gather_shape), gather_indices)
    final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
    static_batch_size = tensor_util.constant_value(batch_size)
    final_static_shape = (
        tensor_shape.TensorShape([static_batch_size]).concatenate(
            gather_from.shape[1:1 + len(gather_shape)]))
    output = array_ops.reshape(output, final_shape, name="output")
    output.set_shape(final_static_shape)
    return output
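
# The offset trick above, in brief (a numpy sketch, not the TF code itself):
#
#   import numpy as np
#   batch_size, range_size = 2, 3
#   gather_from = np.arange(6)        # beams flattened with gather_shape=[-1]
#   gather_indices = np.array([[2, 0, 1],
#                              [1, 1, 0]])
#   range_ = (np.arange(batch_size) * range_size)[:, None]   # [[0], [3]]
#   flat = (gather_indices + range_).reshape(-1)  # [2, 0, 1, 4, 4, 3]
#   out = gather_from[flat].reshape(batch_size, range_size)
#   # -> [[2, 0, 1], [4, 4, 3]]: each batch row gathers only its own beams.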