1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for contrib.seq2seq.python.seq2seq.beam_search_decoder."""
16# pylint: disable=unused-import,g-bad-import-order
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20# pylint: enable=unused-import
21
22import numpy as np
23
24from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
25from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
26from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
27from tensorflow.contrib.seq2seq.python.ops import decoder
28from tensorflow.python.eager import context
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import test_util
34from tensorflow.python.keras import layers
35from tensorflow.python.layers import core as layers_core
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import nn_ops
38from tensorflow.python.ops import rnn_cell
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import test
41
# pylint: enable=g-bad-import-order
43
44
class TestGatherTree(test.TestCase):
  """Tests `gather_tree` and `gather_tree_from_array`."""

  def test_gather_tree(self):
    """Compares `beam_search_ops.gather_tree` with a hand-written result."""
    # Dimensions: max_time = 3, batch_size = 2, beam_width = 3.  Every literal
    # below is laid out as (batch_size, max_time, beam_width) and is then
    # transposed to (max_time, batch_size, beam_width).
    to_time_major = [1, 0, 2]
    step_ids = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
         [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
        dtype=np.int32).transpose(to_time_major)
    step_parents = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
         [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
        dtype=np.int32).transpose(to_time_major)
    expected = np.array(
        [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
         [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose(to_time_major)

    # One maximum length per batch entry, i.e. shaped (batch_size = 2).  The
    # end token (11) does not occur anywhere in `step_ids`.
    gathered = beam_search_ops.gather_tree(
        step_ids,
        step_parents,
        max_sequence_lengths=[3, 3],
        end_token=11)

    with self.cached_session() as sess:
      gathered_ = sess.run(gathered)

    self.assertAllEqual(expected, gathered_)

  def _test_gather_tree_from_array(self,
                                   depth_ndims=0,
                                   merged_batch_beam=False):
    """Checks `gather_tree_from_array` on a fixed 4-step, 2-batch, 3-beam case.

    Args:
      depth_ndims: number of extra trailing dimensions appended to the input
        and to the expected output; each one stacks `x` and `x + 1`.
      merged_batch_beam: if True, the batch and beam dimensions of the input
        and of the expected output are collapsed into a single dimension
        before the call.
    """
    to_time_major = [1, 0, 2]
    values = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0]],
         [[2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 0]]]
    ).transpose(to_time_major)
    parents = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]],
         [[0, 0, 0], [1, 1, 0], [2, 0, 1], [0, 1, 0]]]
    ).transpose(to_time_major)
    expected = np.array(
        [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [0, 0, 0]],
         [[2, 3, 2], [7, 5, 7], [8, 9, 8], [11, 12, 0]]]
    ).transpose(to_time_major)
    lengths = [[3, 3, 3], [4, 4, 3]]

    values = ops.convert_to_tensor(values, dtype=dtypes.float32)
    parents = ops.convert_to_tensor(parents, dtype=dtypes.int32)
    expected = ops.convert_to_tensor(expected, dtype=dtypes.float32)

    # Dynamic (max_time, batch_size, beam_width) of the input.
    dims = array_ops.shape(values)
    max_time = dims[0]
    batch_beam = dims[1] * dims[2]

    if merged_batch_beam:
      merged_shape = [max_time, batch_beam]
      values = array_ops.reshape(values, merged_shape)
      expected = array_ops.reshape(expected, merged_shape)

    # Raise the rank by stacking each tensor with itself plus one, once per
    # requested depth dimension (a no-op when depth_ndims is 0).
    for _ in range(depth_ndims):
      values = array_ops.stack([values, values + 1], -1)
      expected = array_ops.stack([expected, expected + 1], -1)

    gathered = beam_search_decoder.gather_tree_from_array(
        values, parents, lengths)

    with self.cached_session() as sess:
      gathered_, expected_ = sess.run([gathered, expected])
      self.assertAllEqual(expected_, gathered_)

  def test_gather_tree_from_array_scalar(self):
    """Runs the shared check with no extra depth dimensions."""
    self._test_gather_tree_from_array()

  def test_gather_tree_from_array_1d(self):
    """Runs the shared check with one extra depth dimension."""
    self._test_gather_tree_from_array(depth_ndims=1)

  def test_gather_tree_from_array_1d_with_merged_batch_beam(self):
    """Runs the shared check with one depth dimension and merged batch/beam."""
    self._test_gather_tree_from_array(depth_ndims=1, merged_batch_beam=True)

  def test_gather_tree_from_array_2d(self):
    """Runs the shared check with two extra depth dimensions."""
    self._test_gather_tree_from_array(depth_ndims=2)

  def test_gather_tree_from_array_complex_trajectory(self):
    """Checks `gather_tree_from_array` when beams have differing lengths."""
    # Max. time = 7, batch = 1, beam = 5; a trailing unit dimension is added
    # to the values and to the expected result.
    values = np.expand_dims(
        np.array([
            [[25, 12, 114, 89, 97]],
            [[9, 91, 64, 11, 162]],
            [[34, 34, 34, 34, 34]],
            [[2, 4, 2, 2, 4]],
            [[2, 3, 6, 2, 2]],
            [[2, 2, 2, 3, 2]],
            [[2, 2, 2, 2, 2]],
        ]), -1)
    parents = np.array([
        [[0, 0, 0, 0, 0]],
        [[0, 0, 0, 0, 0]],
        [[0, 1, 2, 3, 4]],
        [[0, 0, 1, 2, 1]],
        [[0, 1, 1, 2, 3]],
        [[0, 1, 3, 1, 2]],
        [[0, 1, 2, 3, 4]],
    ])
    expected = np.expand_dims(
        np.array([
            [[25, 25, 25, 25, 25]],
            [[9, 9, 91, 9, 9]],
            [[34, 34, 34, 34, 34]],
            [[2, 4, 2, 4, 4]],
            [[2, 3, 6, 3, 6]],
            [[2, 2, 2, 3, 2]],
            [[2, 2, 2, 2, 2]],
        ]), -1)
    lengths = [[4, 6, 4, 7, 6]]

    values = ops.convert_to_tensor(values, dtype=dtypes.float32)
    parents = ops.convert_to_tensor(parents, dtype=dtypes.int32)
    expected = ops.convert_to_tensor(expected, dtype=dtypes.float32)

    gathered = beam_search_decoder.gather_tree_from_array(
        values, parents, lengths)

    with self.cached_session() as sess:
      gathered_, expected_ = sess.run([gathered, expected])
      self.assertAllEqual(expected_, gathered_)
179
180
class TestArrayShapeChecks(test.TestCase):
  """Tests the runtime shape validation done by `_check_batch_beam`."""

  def _test_array_shape_dynamic_checks(self, static_shape, dynamic_shape,
                                       batch_size, beam_width, is_valid=True):
    """Runs `_check_batch_beam` on a tensor with a partially unknown shape.

    Args:
      static_shape: shape of the random float32 data used as default value.
      dynamic_shape: shape declared on the placeholder; may contain `None`.
      batch_size: Python integer, wrapped in a constant tensor for the check.
      beam_width: beam width forwarded to the check unchanged.
      is_valid: when False, the check must raise
        `errors.InvalidArgumentError`.
    """
    data = np.random.randn(*static_shape).astype(np.float32)
    tensor = array_ops.placeholder_with_default(data, shape=dynamic_shape)
    batch_size_t = array_ops.constant(batch_size)

    def run_check():
      # pylint: disable=protected-access
      if not context.executing_eagerly():
        # Graph mode: the check only fires once the returned op is evaluated.
        with self.cached_session():
          self.evaluate(
              beam_search_decoder._check_batch_beam(
                  tensor, batch_size_t, beam_width))
        return
      beam_search_decoder._check_batch_beam(tensor, batch_size_t, beam_width)
      # pylint: enable=protected-access

    if not is_valid:
      with self.assertRaises(errors.InvalidArgumentError):
        run_check()
      return
    run_check()

  def test_array_shape_dynamic_checks(self):
    """Checks accepted and rejected shapes for batch_size=4, beam_width=5."""
    # Each entry: (static shape, declared dynamic shape, expected validity).
    cases = (
        ((8, 4, 5, 10), (None, None, 5, 10), True),
        ((8, 20, 10), (None, None, 10), True),
        ((8, 21, 10), (None, None, 10), False),
        ((8, 4, 6, 10), (None, None, None, 10), False),
        ((8, 4), (None, None), False),
    )
    for static_shape, dynamic_shape, is_valid in cases:
      self._test_array_shape_dynamic_checks(
          static_shape, dynamic_shape, 4, 5, is_valid=is_valid)
219
220
class TestEosMasking(test.TestCase):
  """Tests EOS masking used in beam search."""

  def test_eos_masking(self):
    """Checks `_mask_probs` on a (batch=2, beam=3, vocab=5) tensor."""
    eos_token = 0
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])
    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)

    masked = beam_search_decoder._mask_probs(probs, eos_token,
                                             previously_finished)

    # (batch, beam) positions grouped by their `previously_finished` flag.
    unfinished = ((0, 0), (0, 2), (1, 0))
    finished = ((0, 1), (1, 1), (1, 2))

    with self.cached_session() as sess:
      probs_, masked_ = sess.run([probs, masked])

      # Beams that were not finished keep their values.
      for batch, beam in unfinished:
        self.assertAllEqual(probs_[batch][beam], masked_[batch][beam])

      # Finished beams: the EOS entry is 0 ...
      for batch, beam in finished:
        self.assertEqual(masked_[batch][beam][eos_token], 0)

      # ... and every other vocabulary entry is the smallest float32.
      float32_min = np.finfo('float32').min
      for token in range(1, 5):
        for batch, beam in finished:
          self.assertAllClose(masked_[batch][beam][token], float32_min)
251
252
class TestBeamStep(test.TestCase):
  """Tests a single step of beam search."""

  def setUp(self):
    super(TestBeamStep, self).setUp()
    self.batch_size = 2
    self.beam_width = 3
    self.vocab_size = 5
    self.end_token = 0
    self.length_penalty_weight = 0.6
    self.coverage_penalty_weight = 0.0

  def _make_logits(self, overrides):
    """Builds a (batch, beam, vocab) logits array.

    Args:
      overrides: iterable of `((batch, beam, token), value)` pairs.

    Returns:
      A numpy array filled with 0.0001 except at the overridden positions.
    """
    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    for index, value in overrides:
      logits_[index] = value
    return logits_

  def _run_step(self, beam_state, cell_state, logits_):
    """Runs `_beam_search_step` at time=2 and evaluates the results.

    Args:
      beam_state: the `BeamSearchDecoderState` to step from.
      cell_state: tensor passed as `next_cell_state`.
      logits_: numpy logits, converted to a float32 tensor.

    Returns:
      Evaluated `[outputs, next_beam_state, beam_state, log_probs]`, where
      `log_probs` is the log-softmax of the logits.
    """
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
      return sess.run([outputs, next_beam_state, beam_state, log_probs])

  def test_step(self):
    """Steps from a state in which no beam has finished."""
    state_shape = [self.batch_size, self.beam_width]
    dummy_cell_state = array_ops.zeros(state_shape)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(array_ops.ones(state_shape)),
        lengths=constant_op.constant(2, shape=state_shape, dtype=dtypes.int64),
        finished=array_ops.zeros(state_shape, dtype=dtypes.bool),
        accumulated_attention_probs=())

    logits_ = self._make_logits((
        ((0, 0, 2), 1.9),
        ((0, 0, 3), 2.1),
        ((0, 1, 3), 3.1),
        ((0, 1, 4), 0.9),
        ((1, 0, 1), 0.5),
        ((1, 1, 2), 2.7),
        ((1, 2, 2), 10.0),
        ((1, 2, 3), 0.2),
    ))

    outputs_, next_state_, state_, log_probs_ = self._run_step(
        beam_state, dummy_cell_state, logits_)

    parent_ids = [[1, 0, 0], [2, 1, 0]]
    predicted_ids = [[3, 3, 2], [2, 2, 1]]

    self.assertAllEqual(outputs_.predicted_ids, predicted_ids)
    self.assertAllEqual(outputs_.parent_ids, parent_ids)
    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[False, False, False], [False, False, False]])

    # Each new beam's log prob is its parent's log prob plus the log prob of
    # the token chosen from that parent.  Fancy indexing returns copies, so
    # the in-place additions do not touch `state_`.
    expected_log_probs = [
        state_.log_probs[batch][parent_ids[batch]]
        for batch in range(self.batch_size)
    ]
    for batch in range(self.batch_size):
      for beam in range(self.beam_width):
        parent = parent_ids[batch][beam]
        token = predicted_ids[batch][beam]
        expected_log_probs[batch][beam] += log_probs_[batch, parent, token]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)

  def test_step_with_eos(self):
    """Steps from a state in which one beam per batch entry has finished."""
    state_shape = [self.batch_size, self.beam_width]
    dummy_cell_state = array_ops.zeros(state_shape)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(array_ops.ones(state_shape)),
        lengths=ops.convert_to_tensor(
            [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
        finished=ops.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=dtypes.bool),
        accumulated_attention_probs=())

    # NOTE(review): the original author observed that using 2.7 for entry
    # (1, 1, 2), as test_step does, breaks the expectations below.  Presumably
    # the length-penalized score of that beam then drops below the score of
    # the already finished beam, which changes the ranking -- confirm.
    logits_ = self._make_logits((
        ((0, 0, 2), 1.9),
        ((0, 0, 3), 2.1),
        ((0, 1, 3), 3.1),
        ((0, 1, 4), 0.9),
        ((1, 0, 1), 0.5),
        ((1, 1, 2), 5.7),
        ((1, 2, 2), 1.0),
        ((1, 2, 3), 0.2),
    ))

    outputs_, next_state_, state_, log_probs_ = self._run_step(
        beam_state, dummy_cell_state, logits_)

    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
    self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
    self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[True, False, False], [False, True, False]])

    expected_log_probs = [
        state_.log_probs[0][[1, 0, 0]],
        state_.log_probs[1][[1, 2, 0]],
    ]
    # (batch, new beam, parent beam, token) for every beam that is still
    # running; the beams descending from a finished parent keep the parent's
    # log prob unchanged.
    running = (
        (0, 1, 0, 3),
        (0, 2, 0, 2),
        (1, 0, 1, 2),
        (1, 2, 0, 1),
    )
    for batch, beam, parent, token in running:
      expected_log_probs[batch][beam] += log_probs_[batch, parent, token]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
376
377
class TestLargeBeamStep(test.TestCase):
  """Tests large beam step.

  Tests a single step of beam search in such case that beam size is larger than
  vocabulary size.
  """

  def setUp(self):
    super(TestLargeBeamStep, self).setUp()
    self.batch_size = 2
    self.beam_width = 8
    self.vocab_size = 5
    self.end_token = 0
    self.length_penalty_weight = 0.6
    self.coverage_penalty_weight = 0.0

  def test_step(self):
    """Steps from a state in which only beam 0 of each batch entry is live.

    Beam 0 starts with log prob 0, length 2 and `finished=False`; every other
    beam starts with log prob -inf, length 0 and `finished=True`.
    """

    def get_probs():
      """This simulates the initialize method in BeamSearchDecoder."""
      log_prob_mask = array_ops.one_hot(
          array_ops.zeros([self.batch_size], dtype=dtypes.int32),
          depth=self.beam_width,
          on_value=True,
          off_value=False,
          dtype=dtypes.bool)

      log_prob_zeros = array_ops.zeros(
          [self.batch_size, self.beam_width], dtype=dtypes.float32)
      # `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
      log_prob_neg_inf = array_ops.ones(
          [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.inf

      # 0 for beam 0, -inf for all other beams.
      return array_ops.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf)

    initial_log_probs = get_probs()
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

    # Inverse of the mask above: False for beam 0, True everywhere else.
    finished = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    lengths_ = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    lengths_[:, 0] = 2
    lengths = constant_op.constant(lengths_, dtype=dtypes.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=initial_log_probs,
        lengths=lengths,
        finished=finished,
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = constant_op.constant(logits_, dtype=dtypes.float32)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
      outputs_, next_state_ = sess.run([outputs, next_beam_state])

    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)

    # The trailing beam_width - vocab_size = 3 beams must stay at -inf with
    # length 0, while the leading 5 beams carry finite log probs and non-zero
    # lengths.
    neg_inf = -np.inf
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
    self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
473
474
# NOTE(review): newer TensorFlow releases declare this decorator as
# `run_v1_only(reason, func=None)`; confirm that the bare form is valid for
# the TensorFlow version this file targets.
@test_util.run_v1_only
class BeamSearchDecoderTest(test.TestCase):
  """Smoke tests for `BeamSearchDecoder` driven by `decoder.dynamic_decode`."""

  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    """Builds a beam search decoder, runs it and checks the output shapes.

    Args:
      time_major: forwarded to `dynamic_decode` as `output_time_major`; also
        selects the layout of the expected shapes.
      has_attention: if True, the LSTM cell is wrapped in an
        `AttentionWrapper` with Bahdanau attention and a coverage penalty
        weight of 0.2 is used instead of 0.0.
      with_alignment_history: forwarded to the `AttentionWrapper` as
        `alignment_history`; only used when `has_attention` is True.
    """
    # Problem sizes.
    batch_size = 5
    beam_width = 3
    vocab_size = 20
    embedding_dim = 50
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    decoder_max_time = 4
    start_token = 0
    end_token = vocab_size - 1

    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    # Decoding is capped at the longest decoder sequence length.
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)

    with self.cached_session() as sess:
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.2 if has_attention else 0.0

      if has_attention:
        memory = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        # Memory, its lengths and the initial state are each repeated
        # `beam_width` times along the batch dimension.
        tiled_memory = beam_search_decoder.tile_batch(
            memory, multiplier=beam_width)
        tiled_memory_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_memory,
            memory_sequence_length=tiled_memory_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)

      decoder_initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        decoder_initial_state = decoder_initial_state.clone(
            cell_state=initial_state)

      bsd = beam_search_decoder.BeamSearchDecoder(
          cell=cell,
          embedding=embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=decoder_initial_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight)

      decoded = decoder.dynamic_decode(
          bsd, output_time_major=time_major, maximum_iterations=max_out)
      final_outputs, final_state, final_sequence_lengths = decoded

      def output_layout(batch_major_shape):
        """Swaps the two leading dimensions when outputs are time major."""
        if not time_major:
          return batch_major_shape
        batch, time = batch_major_shape[:2]
        return (time, batch) + batch_major_shape[2:]

      self.assertIsInstance(
          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(
          final_state, beam_search_decoder.BeamSearchDecoderState)

      # Static shapes: the time dimension is unknown before running.
      static_shape = output_layout((batch_size, None, beam_width))
      step_outputs = final_outputs.beam_search_decoder_output
      self.assertEqual(
          static_shape, tuple(step_outputs.scores.get_shape().as_list()))
      self.assertEqual(
          static_shape,
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'final_sequence_lengths': final_sequence_lengths
      })

      # A smoke test: the evaluated arrays span the longest decoded sequence.
      max_sequence_length = np.max(sess_results['final_sequence_lengths'])
      evaluated_shape = output_layout(
          (batch_size, max_sequence_length, beam_width))
      evaluated = sess_results['final_outputs'].beam_search_decoder_output
      self.assertEqual(evaluated_shape, evaluated.scores.shape)
      self.assertEqual(evaluated_shape, evaluated.predicted_ids.shape)

  def testDynamicDecodeRNNBatchMajorNoAttention(self):
    """Batch-major outputs with a plain LSTM cell."""
    self._testDynamicDecodeRNN(time_major=False, has_attention=False)

  def testDynamicDecodeRNNBatchMajorYesAttention(self):
    """Batch-major outputs with an attention-wrapped cell."""
    self._testDynamicDecodeRNN(time_major=False, has_attention=True)

  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
    """Batch-major outputs with attention and alignment history enabled."""
    self._testDynamicDecodeRNN(
        time_major=False,
        has_attention=True,
        with_alignment_history=True)
587
588
@test_util.run_all_in_graph_and_eager_modes
class BeamSearchDecoderV2Test(test.TestCase):
  """Smoke tests for `BeamSearchDecoderV2` in graph and eager modes."""

  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    """Builds a V2 beam search decoder, calls it and checks output shapes.

    Args:
      time_major: passed to the decoder as `output_time_major`; also selects
        the layout of the expected shapes.
      has_attention: if True, the LSTM cell is wrapped in an
        `AttentionWrapper` with Bahdanau attention and a coverage penalty
        weight of 0.2 is used instead of 0.0.
      with_alignment_history: forwarded to the `AttentionWrapper` as
        `alignment_history`; only used when `has_attention` is True.
    """
    # Problem sizes.
    batch_size = 5
    beam_width = 3
    vocab_size = 20
    embedding_dim = 50
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    decoder_max_time = 4
    start_token = 0
    end_token = vocab_size - 1

    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    # Decoding is capped at the longest decoder sequence length (3 here).
    max_out = max(decoder_sequence_length)
    output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)

    with self.cached_session():
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.2 if has_attention else 0.0

      if has_attention:
        memory = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        # Memory, its lengths and the initial state are each repeated
        # `beam_width` times along the batch dimension.
        tiled_memory = beam_search_decoder.tile_batch(
            memory, multiplier=beam_width)
        tiled_memory_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_memory,
            memory_sequence_length=tiled_memory_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)

      decoder_initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        decoder_initial_state = decoder_initial_state.clone(
            cell_state=initial_state)

      bsd = beam_search_decoder.BeamSearchDecoderV2(
          cell=cell,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight,
          output_time_major=time_major,
          maximum_iterations=max_out)

      decoded = bsd(
          embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=decoder_initial_state)
      final_outputs, final_state, final_sequence_lengths = decoded

      def output_layout(batch_major_shape):
        """Swaps the two leading dimensions when outputs are time major."""
        if not time_major:
          return batch_major_shape
        batch, time = batch_major_shape[:2]
        return (time, batch) + batch_major_shape[2:]

      self.assertIsInstance(
          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(
          final_state, beam_search_decoder.BeamSearchDecoderState)

      # In eager mode the time dimension is already concrete (3); in graph
      # mode it is unknown before evaluation.
      if context.executing_eagerly():
        expected_seq_length = 3
      else:
        expected_seq_length = None
      static_shape = output_layout(
          (batch_size, expected_seq_length, beam_width))
      step_outputs = final_outputs.beam_search_decoder_output
      self.assertEqual(
          static_shape, tuple(step_outputs.scores.get_shape().as_list()))
      self.assertEqual(
          static_shape,
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      self.evaluate(variables.global_variables_initializer())
      eval_results = self.evaluate({
          'final_outputs': final_outputs,
          'final_sequence_lengths': final_sequence_lengths
      })

      # A smoke test: the evaluated arrays span the longest decoded sequence.
      max_sequence_length = np.max(eval_results['final_sequence_lengths'])
      evaluated_shape = output_layout(
          (batch_size, max_sequence_length, beam_width))
      evaluated = eval_results['final_outputs'].beam_search_decoder_output
      self.assertEqual(evaluated_shape, evaluated.scores.shape)
      self.assertEqual(evaluated_shape, evaluated.predicted_ids.shape)

  def testDynamicDecodeRNNBatchMajorNoAttention(self):
    """Batch-major outputs with a plain LSTM cell."""
    self._testDynamicDecodeRNN(time_major=False, has_attention=False)

  def testDynamicDecodeRNNBatchMajorYesAttention(self):
    """Batch-major outputs with an attention-wrapped cell."""
    self._testDynamicDecodeRNN(time_major=False, has_attention=True)

  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
    """Batch-major outputs with attention and alignment history enabled."""
    self._testDynamicDecodeRNN(
        time_major=False,
        has_attention=True,
        with_alignment_history=True)
701
702
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
705