# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.ops.beam_search_decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import

import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

# pylint: enable=g-import-not-at-top


class TestGatherTree(test.TestCase):
  """Tests the gather_tree function."""

  def test_gather_tree(self):
    # (max_time = 3, batch_size = 2, beam_width = 3)

    # create (batch_size, max_time, beam_width) matrix and transpose it
    predicted_ids = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
        dtype=np.int32).transpose([1, 0, 2])
    parent_ids = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
        dtype=np.int32).transpose([1, 0, 2])

    # sequence_lengths is shaped (batch_size = 2)
    max_sequence_lengths = [3, 3]

    expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
                                [[2, 4, 4], [7, 6, 6],
                                 [8, 9, 10]]]).transpose([1, 0, 2])

    res = beam_search_ops.gather_tree(
        predicted_ids,
        parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=11)

    with self.cached_session() as sess:
      res_ = sess.run(res)

    self.assertAllEqual(expected_result, res_)

  def _test_gather_tree_from_array(self,
                                   depth_ndims=0,
                                   merged_batch_beam=False):
    """Checks gather_tree_from_array on a fixed (time=4, batch=2, beam=3) case.

    Args:
      depth_ndims: number of extra trailing dimensions to add to the gathered
        array (each one stacks the array with array + 1).
      merged_batch_beam: if True, the batch and beam dimensions of the
        gathered array are merged into a single batch * beam dimension.
    """
    # Built as (batch, time, beam) and transposed to (time, batch, beam).
    array = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0]],
         [[2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 0]]]).transpose([1, 0, 2])
    parent_ids = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]],
         [[0, 0, 0], [1, 1, 0], [2, 0, 1], [0, 1, 0]]]).transpose([1, 0, 2])
    expected_array = np.array(
        [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [0, 0, 0]],
         [[2, 3, 2], [7, 5, 7], [8, 9, 8], [11, 12, 0]]]).transpose([1, 0, 2])
    # One length per (batch, beam) entry.
    sequence_length = [[3, 3, 3], [4, 4, 3]]

    array = ops.convert_to_tensor(
        array, dtype=dtypes.float32)
    parent_ids = ops.convert_to_tensor(
        parent_ids, dtype=dtypes.int32)
    expected_array = ops.convert_to_tensor(
        expected_array, dtype=dtypes.float32)

    max_time = array_ops.shape(array)[0]
    batch_size = array_ops.shape(array)[1]
    beam_width = array_ops.shape(array)[2]

    def _tile_in_depth(tensor):
      # Generate higher rank tensors by stacking tensor and tensor + 1 along a
      # new last axis, once per requested depth dimension.
      for _ in range(depth_ndims):
        tensor = array_ops.stack([tensor, tensor + 1], -1)
      return tensor

    if merged_batch_beam:
      # parent_ids deliberately keeps its (time, batch, beam) shape; only the
      # gathered array uses the merged layout.
      array = array_ops.reshape(
          array, [max_time, batch_size * beam_width])
      expected_array = array_ops.reshape(
          expected_array, [max_time, batch_size * beam_width])

    if depth_ndims > 0:
      array = _tile_in_depth(array)
      expected_array = _tile_in_depth(expected_array)

    sorted_array = beam_search_decoder.gather_tree_from_array(
        array, parent_ids, sequence_length)

    with self.cached_session() as sess:
      sorted_array, expected_array = sess.run([sorted_array, expected_array])
      self.assertAllEqual(expected_array, sorted_array)

  def test_gather_tree_from_array_scalar(self):
    self._test_gather_tree_from_array()

  def test_gather_tree_from_array_1d(self):
    self._test_gather_tree_from_array(depth_ndims=1)

  def test_gather_tree_from_array_1d_with_merged_batch_beam(self):
    self._test_gather_tree_from_array(depth_ndims=1, merged_batch_beam=True)

  def test_gather_tree_from_array_2d(self):
    self._test_gather_tree_from_array(depth_ndims=2)

  def test_gather_tree_from_array_complex_trajectory(self):
    # Max. time = 7, batch = 1, beam = 5.
    array = np.expand_dims(np.array(
        [[[25, 12, 114, 89, 97]],
         [[9, 91, 64, 11, 162]],
         [[34, 34, 34, 34, 34]],
         [[2, 4, 2, 2, 4]],
         [[2, 3, 6, 2, 2]],
         [[2, 2, 2, 3, 2]],
         [[2, 2, 2, 2, 2]]]), -1)
    parent_ids = np.array(
        [[[0, 0, 0, 0, 0]],
         [[0, 0, 0, 0, 0]],
         [[0, 1, 2, 3, 4]],
         [[0, 0, 1, 2, 1]],
         [[0, 1, 1, 2, 3]],
         [[0, 1, 3, 1, 2]],
         [[0, 1, 2, 3, 4]]])
    expected_array = np.expand_dims(np.array(
        [[[25, 25, 25, 25, 25]],
         [[9, 9, 91, 9, 9]],
         [[34, 34, 34, 34, 34]],
         [[2, 4, 2, 4, 4]],
         [[2, 3, 6, 3, 6]],
         [[2, 2, 2, 3, 2]],
         [[2, 2, 2, 2, 2]]]), -1)
    sequence_length = [[4, 6, 4, 7, 6]]

    array = ops.convert_to_tensor(
        array, dtype=dtypes.float32)
    parent_ids = ops.convert_to_tensor(
        parent_ids, dtype=dtypes.int32)
    expected_array = ops.convert_to_tensor(
        expected_array, dtype=dtypes.float32)

    sorted_array = beam_search_decoder.gather_tree_from_array(
        array, parent_ids, sequence_length)

    with self.cached_session() as sess:
      sorted_array, expected_array = sess.run([sorted_array, expected_array])
      self.assertAllEqual(expected_array, sorted_array)


class TestArrayShapeChecks(test.TestCase):
  """Tests the batch/beam shape validation done by `_check_batch_beam`."""

  def _test_array_shape_dynamic_checks(self, static_shape, dynamic_shape,
                                       batch_size, beam_width, is_valid=True):
    """Runs `_check_batch_beam` on a tensor whose static shape is partial.

    Args:
      static_shape: actual shape of the random default value.
      dynamic_shape: (partially unknown) shape declared for the placeholder.
      batch_size: batch size passed to the check (as a tensor).
      beam_width: beam width passed to the check.
      is_valid: whether the check is expected to pass; if False, an
        InvalidArgumentError is expected.
    """
    t = array_ops.placeholder_with_default(
        np.random.randn(*static_shape).astype(np.float32),
        shape=dynamic_shape)

    batch_size = array_ops.constant(batch_size)

    def _test_body():
      # pylint: disable=protected-access
      if context.executing_eagerly():
        beam_search_decoder._check_batch_beam(t, batch_size, beam_width)
      else:
        with self.cached_session():
          check_op = beam_search_decoder._check_batch_beam(
              t, batch_size, beam_width)
          self.evaluate(check_op)
      # pylint: enable=protected-access

    if is_valid:
      _test_body()
    else:
      with self.assertRaises(errors.InvalidArgumentError):
        _test_body()

  def test_array_shape_dynamic_checks(self):
    # Separate batch (4) and beam (5) dimensions.
    self._test_array_shape_dynamic_checks(
        (8, 4, 5, 10), (None, None, 5, 10), 4, 5, is_valid=True)
    # Merged batch * beam dimension: 20 == 4 * 5.
    self._test_array_shape_dynamic_checks(
        (8, 20, 10), (None, None, 10), 4, 5, is_valid=True)
    # Merged dimension that is not batch * beam.
    self._test_array_shape_dynamic_checks(
        (8, 21, 10), (None, None, 10), 4, 5, is_valid=False)
    # Beam dimension (6) does not match beam_width (5).
    self._test_array_shape_dynamic_checks(
        (8, 4, 6, 10), (None, None, None, 10), 4, 5, is_valid=False)
    self._test_array_shape_dynamic_checks(
        (8, 4), (None, None), 4, 5, is_valid=False)


class TestEosMasking(test.TestCase):
  """Tests EOS masking used in beam search."""

  def test_eos_masking(self):
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])

    eos_token = 0
    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    # pylint: disable=protected-access
    masked = beam_search_decoder._mask_probs(probs, eos_token,
                                             previously_finished)
    # pylint: enable=protected-access

    with self.cached_session() as sess:
      probs, masked = sess.run([probs, masked])

      # Beams that were not finished are left untouched.
      self.assertAllEqual(probs[0][0], masked[0][0])
      self.assertAllEqual(probs[0][2], masked[0][2])
      self.assertAllEqual(probs[1][0], masked[1][0])

      # Finished beams: the EOS entry becomes 0 ...
      self.assertEqual(masked[0][1][0], 0)
      self.assertEqual(masked[1][1][0], 0)
      self.assertEqual(masked[1][2][0], 0)

      # ... and every other token is pushed to the float32 minimum.
      for i in range(1, 5):
        self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
        self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
        self.assertAllClose(masked[1][2][i], np.finfo('float32').min)


class TestBeamStep(test.TestCase):
  """Tests a single step of beam search."""

  def setUp(self):
    super(TestBeamStep, self).setUp()
    self.batch_size = 2
    self.beam_width = 3
    self.vocab_size = 5
    self.end_token = 0
    self.length_penalty_weight = 0.6
    self.coverage_penalty_weight = 0.0

  def test_step(self):
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(
            array_ops.ones([self.batch_size, self.beam_width])),
        lengths=constant_op.constant(
            2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
        finished=array_ops.zeros(
            [self.batch_size, self.beam_width], dtype=dtypes.bool),
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    # pylint: disable=protected-access
    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)
    # pylint: enable=protected-access

    with self.cached_session() as sess:
      outputs_, next_state_, state_, log_probs_ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[False, False, False], [False, False, False]])

    # Start from the previous log-prob of each selected parent beam (same
    # order as parent_ids above), then add the log-prob of the token that was
    # chosen from that parent.
    expected_log_probs = []
    expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
    expected_log_probs.append(state_.log_probs[1][[2, 1, 0]])
    expected_log_probs[0][0] += log_probs_[0, 1, 3]
    expected_log_probs[0][1] += log_probs_[0, 0, 3]
    expected_log_probs[0][2] += log_probs_[0, 0, 2]
    expected_log_probs[1][0] += log_probs_[1, 2, 2]
    expected_log_probs[1][1] += log_probs_[1, 1, 2]
    expected_log_probs[1][2] += log_probs_[1, 0, 1]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)

  def test_step_with_eos(self):
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(
            array_ops.ones([self.batch_size, self.beam_width])),
        lengths=ops.convert_to_tensor(
            [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
        finished=ops.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=dtypes.bool),
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    # NOTE(review): test_step uses 2.7 here, which appears to be too small for
    # this case: by a rough estimate, the length-penalized score of extending
    # beam 1 with token 2 would then fall just below that of the
    # already-finished beam 2 (whose length does not grow), swapping the top
    # two candidates of batch 1 and breaking the assertions below. Confirm
    # against the length-penalty scoring in beam_search_decoder.
    logits_[1, 1, 2] = 5.7
    logits_[1, 2, 2] = 1.0
    logits_[1, 2, 3] = 0.2
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    # pylint: disable=protected-access
    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)
    # pylint: enable=protected-access

    with self.cached_session() as sess:
      outputs_, next_state_, state_, log_probs_ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
    self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
    self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[True, False, False], [False, True, False]])

    # Entries whose parent beam was already finished ([0][1] -> [0][0] and
    # [1][2] -> [1][1]) emit the end token and keep their previous log-prob,
    # so only the other entries get a token log-prob added.
    expected_log_probs = []
    expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
    expected_log_probs.append(state_.log_probs[1][[1, 2, 0]])
    expected_log_probs[0][1] += log_probs_[0, 0, 3]
    expected_log_probs[0][2] += log_probs_[0, 0, 2]
    expected_log_probs[1][0] += log_probs_[1, 1, 2]
    expected_log_probs[1][2] += log_probs_[1, 0, 1]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)


class TestLargeBeamStep(test.TestCase):
  """Tests large beam step.

  Tests a single step of beam search in such case that beam size is larger than
  vocabulary size.
  """

  def setUp(self):
    super(TestLargeBeamStep, self).setUp()
    self.batch_size = 2
    self.beam_width = 8
    self.vocab_size = 5
    self.end_token = 0
    self.length_penalty_weight = 0.6
    self.coverage_penalty_weight = 0.0

  def test_step(self):

    def get_probs():
      """Simulates the `initialize` method in BeamSearchDecoder.

      Returns:
        A [batch_size, beam_width] float32 tensor that is 0 for beam 0 and
        -inf for every other beam.
      """
      log_prob_mask = array_ops.one_hot(
          array_ops.zeros([self.batch_size], dtype=dtypes.int32),
          depth=self.beam_width,
          on_value=True,
          off_value=False,
          dtype=dtypes.bool)

      log_prob_zeros = array_ops.zeros(
          [self.batch_size, self.beam_width], dtype=dtypes.float32)
      # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
      log_prob_neg_inf = array_ops.ones(
          [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.inf

      log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
                                  log_prob_neg_inf)
      return log_probs

    log_probs = get_probs()
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

    # pylint: disable=invalid-name
    # Only beam 0 is alive (not finished, length 2); all others are finished.
    _finished = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished,
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = constant_op.constant(logits_, dtype=dtypes.float32)

    # pylint: disable=protected-access
    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)
    # pylint: enable=protected-access

    with self.cached_session() as sess:
      outputs_, next_state_ = sess.run([outputs, next_beam_state])

    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
    neg_inf = -np.inf
    # With only vocab_size (5) real candidates from the single live beam, the
    # last beam_width - vocab_size (3) slots stay at -inf with length 0.
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
    self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])


@test_util.run_v1_only
class BeamSearchDecoderTest(test.TestCase):
  """Smoke tests for BeamSearchDecoder driven by decoder.dynamic_decode."""

  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    """Decodes with an (optionally attention-wrapped) LSTM and checks shapes.

    Args:
      time_major: value passed as `output_time_major` to dynamic_decode.
      has_attention: if True, wraps the LSTM cell in a Bahdanau
        AttentionWrapper and enables a non-zero coverage penalty.
      with_alignment_history: forwarded to the AttentionWrapper.
    """
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    with self.cached_session() as sess:
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.0
      if has_attention:
        coverage_penalty_weight = 0.2
        inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        # Memory, lengths and initial state are tiled beam_width times.
        tiled_inputs = beam_search_decoder.tile_batch(
            inputs, multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)
      cell_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
      bsd = beam_search_decoder.BeamSearchDecoder(
          cell=cell,
          embedding=embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=cell_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight)

      final_outputs, final_state, final_sequence_lengths = (
          decoder.dynamic_decode(
              bsd, output_time_major=time_major, maximum_iterations=max_out))

      def _t(shape):
        # Swaps the first two dimensions when outputs are time-major.
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertIsInstance(
          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(
          final_state, beam_search_decoder.BeamSearchDecoderState)

      beam_search_decoder_output = final_outputs.beam_search_decoder_output
      self.assertEqual(
          _t((batch_size, None, beam_width)),
          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None, beam_width)),
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'final_sequence_lengths': final_sequence_lengths
      })

      max_sequence_length = np.max(sess_results['final_sequence_lengths'])

      # A smoke test
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)),
          sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)), sess_results[
              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)

  def testDynamicDecodeRNNBatchMajorNoAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=False)

  def testDynamicDecodeRNNBatchMajorYesAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=True)

  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
    self._testDynamicDecodeRNN(
        time_major=False,
        has_attention=True,
        with_alignment_history=True)


@test_util.run_all_in_graph_and_eager_modes
class BeamSearchDecoderV2Test(test.TestCase):
  """Smoke tests for BeamSearchDecoderV2, called directly as a layer."""

  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    """Decodes with an (optionally attention-wrapped) LSTM and checks shapes.

    Args:
      time_major: value passed as `output_time_major` to the decoder.
      has_attention: if True, wraps the LSTM cell in a Bahdanau
        AttentionWrapper and enables a non-zero coverage penalty.
      with_alignment_history: forwarded to the AttentionWrapper.
    """
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    with self.cached_session():
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.0
      if has_attention:
        coverage_penalty_weight = 0.2
        inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        # Memory, lengths and initial state are tiled beam_width times.
        tiled_inputs = beam_search_decoder.tile_batch(
            inputs, multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)
      cell_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
      bsd = beam_search_decoder.BeamSearchDecoderV2(
          cell=cell,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight,
          output_time_major=time_major,
          maximum_iterations=max_out)

      final_outputs, final_state, final_sequence_lengths = bsd(
          embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=cell_state)

      def _t(shape):
        # Swaps the first two dimensions when outputs are time-major.
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertIsInstance(
          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(
          final_state, beam_search_decoder.BeamSearchDecoderState)

      beam_search_decoder_output = final_outputs.beam_search_decoder_output
      # In eager mode the time dimension is concrete; 3 equals max_out here
      # (presumably decoding runs for all maximum_iterations steps — confirm
      # if the inputs above change).
      expected_seq_length = 3 if context.executing_eagerly() else None
      self.assertEqual(
          _t((batch_size, expected_seq_length, beam_width)),
          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, expected_seq_length, beam_width)),
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      self.evaluate(variables.global_variables_initializer())
      eval_results = self.evaluate({
          'final_outputs': final_outputs,
          'final_sequence_lengths': final_sequence_lengths
      })

      max_sequence_length = np.max(eval_results['final_sequence_lengths'])

      # A smoke test
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)),
          eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)), eval_results[
              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)

  def testDynamicDecodeRNNBatchMajorNoAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=False)

  def testDynamicDecodeRNNBatchMajorYesAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=True)

  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
    self._testDynamicDecodeRNN(
        time_major=False,
        has_attention=True,
        with_alignment_history=True)


if __name__ == '__main__':
  test.main()