1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests and benchmarks for the LSTM Block Cell ops."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21from absl.testing import parameterized
22import numpy as np
24from tensorflow.contrib.rnn.python.kernel_tests import benchmarking
25from tensorflow.contrib.rnn.python.ops import lstm_ops
26from tensorflow.python.client import session
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_array_ops
32from tensorflow.python.ops import gen_bitwise_ops
33from tensorflow.python.ops import gradients_impl
34from tensorflow.python.ops import init_ops
35from tensorflow.python.ops import rnn
36from tensorflow.python.ops import rnn_cell
37from tensorflow.python.ops import variable_scope
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import test
# Module-level alias of the private `_block_lstm` wrapper so that
# blocks_match() can call the raw block LSTM op directly.
block_lstm = lstm_ops._block_lstm  # pylint: disable=protected-access
class _MaskedRandomUniformInitializer(init_ops.RandomUniform):
  """Initializer for uniform dist tensors with trailing bits zeroed-out.

  Allow returning tensors with last few mantissa bits set to 0. This potentially
  helps avoid getting into precision issues when testing low precision (float16)
  computation.
  """

  # Number of explicit mantissa (fraction) bits in an IEEE 754 half float.
  _FLOAT16_MANTISSA_BITS = 10

  def __init__(self,
               minval=0,
               maxval=None,
               seed=None,
               dtype=dtypes.float16,
               num_valid_mantissa_bits=4):
    """Constructor.

    Args:
      minval: A python scalar or a scalar tensor. Lower bound of the range of
        random values to generate.
      maxval: A python scalar or a scalar tensor. Upper bound of the range of
        random values to generate.  Defaults to 1 for float types.
      seed: A Python integer. Used to create random seeds. See
        `tf.set_random_seed` for behavior.
      dtype: The data type. Only supports tf.float16 for now.
      num_valid_mantissa_bits: number of leading mantissa bits that are kept;
        the remaining trailing mantissa bits are forced to 0. Must lie in
        `[0, 10]`, defaults to 4.

    Raises:
      ValueError: An error if `dtype` is not tf.float16, or if
        `num_valid_mantissa_bits` is outside `[0, 10]`.
    """
    if dtype not in (dtypes.float16,):
      raise ValueError("dtype: %s not supported" % dtype.name)
    # Reject widths that would give a negative shift (> 10) or a mask that
    # reaches into the exponent bits (< 0).
    if not 0 <= num_valid_mantissa_bits <= self._FLOAT16_MANTISSA_BITS:
      raise ValueError(
          "num_valid_mantissa_bits: %s not in [0, %d]" %
          (num_valid_mantissa_bits, self._FLOAT16_MANTISSA_BITS))

    super(_MaskedRandomUniformInitializer, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)
    self._num_mantissa_bits = self._FLOAT16_MANTISSA_BITS
    self._num_valid_mantissa_bits = num_valid_mantissa_bits

  def __call__(self, shape, dtype=dtypes.float16, partition_info=None):
    """Returns a random uniform tensor whose trailing mantissa bits are 0.

    Args:
      shape: Shape of the tensor to generate.
      dtype: Requested data type; `None` is treated as tf.float16.
      partition_info: Forwarded unchanged to the parent initializer.

    Returns:
      A tf.float16 tensor of the given shape.

    Raises:
      ValueError: An error if `dtype` is given and is not tf.float16.
    """
    if dtype is None:
      # float16 is the only dtype the constructor accepts, so fall back to it
      # rather than handing `None` to the final bitcast below.
      dtype = dtypes.float16
    elif dtype != dtypes.float16:
      raise ValueError("dtype: %s not supported" % dtype.name)
    res = super(_MaskedRandomUniformInitializer, self).__call__(
        shape, dtype, partition_info)
    # get uint16 view of the underlying buffer.
    res = gen_array_ops.bitcast(res, dtypes.uint16)

    # mask the last `shift` mantissa bits.
    shift = self._num_mantissa_bits - self._num_valid_mantissa_bits
    mask = (0xffff >> shift) << shift
    res = gen_bitwise_ops.bitwise_and(res, mask)

    # restore float16 view.
    return gen_array_ops.bitcast(res, dtype)
def _get_initializer(init_bound, dtype, seed):
  """Returns a seeded uniform initializer over `[-init_bound, init_bound]`.

  float16 gets the masked initializer defined in this file; every other dtype
  gets the stock random uniform initializer.

  Args:
    init_bound: Magnitude of the symmetric sampling range.
    dtype: Data type of the values to generate.
    seed: Seed forwarded to the selected initializer.

  Returns:
    The constructed initializer object.
  """
  if dtype == dtypes.float16:
    factory = _MaskedRandomUniformInitializer
  else:
    factory = init_ops.random_uniform_initializer
  return factory(-init_bound, init_bound, dtype=dtype, seed=seed)
def blocks_match(sess, use_peephole, dtype=dtypes.float32, cell_clip=None):
  """Evaluates one LSTM problem with three implementations.

  Builds `rnn_cell.LSTMCell` (unrolled with `static_rnn`), the raw `block_lstm`
  op and `lstm_ops.LSTMBlockFusedCell` on the same variables and the same
  random inputs, then evaluates outputs and gradients for each of them.

  Args:
    sess: Session used to initialize the variables and evaluate the tensors.
    use_peephole: Whether the peephole weights are created and used.
    dtype: Data type of the inputs and variables.
    cell_clip: Clipping value forwarded to all three implementations.

  Returns:
    A tuple `(basic_state, fused_state, basic_outputs, block_outputs,
    fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads,
    block_wgrads, fused_wgrads)`. The `*_state` entries are element 0 of the
    respective final state, `*_grads` are gradients w.r.t. the inputs and
    `*_wgrads` are gradients w.r.t. the weights.
  """
  batch_size = 2
  input_size = 3
  cell_size = 4
  sequence_length = 4

  inputs = [
      ops.convert_to_tensor(
          np.random.randn(batch_size, input_size), dtype=dtype)
      for _ in range(sequence_length)
  ]
  stacked_inputs = array_ops.stack(inputs)

  init_bound = 1e-1 if dtype == dtypes.float16 else 1e-2
  initializer = _get_initializer(init_bound, dtype=dtype, seed=19890212)

  with variable_scope.variable_scope("test", initializer=initializer):
    # magic naming so that the cells pick up these variables and reuse them
    peephole_vars = []
    block_kwargs = {"cell_clip": cell_clip}
    if use_peephole:
      wci = variable_scope.get_variable(
          "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtype)
      wcf = variable_scope.get_variable(
          "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtype)
      wco = variable_scope.get_variable(
          "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtype)
      peephole_vars = [wci, wcf, wco]
      block_kwargs.update(wci=wci, wcf=wcf, wco=wco, use_peephole=True)

    w = variable_scope.get_variable(
        "rnn/lstm_cell/kernel",
        shape=[input_size + cell_size, cell_size * 4],
        dtype=dtype)
    b = variable_scope.get_variable(
        "rnn/lstm_cell/bias",
        shape=[cell_size * 4],
        dtype=dtype,
        initializer=init_ops.zeros_initializer())

    basic_cell = rnn_cell.LSTMCell(
        cell_size,
        use_peepholes=use_peephole,
        cell_clip=cell_clip,
        dtype=dtype,
        state_is_tuple=True,
        reuse=True)
    basic_outputs_op, basic_state_op = rnn.static_rnn(
        basic_cell, inputs, dtype=dtype)

    # Only the last of the seven results returned by block_lstm is compared.
    _, _, _, _, _, _, block_outputs_op = block_lstm(
        ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
        inputs,
        w,
        b,
        **block_kwargs)

    fused_cell = lstm_ops.LSTMBlockFusedCell(
        cell_size,
        cell_clip=cell_clip,
        use_peephole=use_peephole,
        reuse=True,
        name="rnn/lstm_cell")
    fused_outputs_op, fused_state_op = fused_cell(stacked_inputs, dtype=dtype)

    sess.run([variables.global_variables_initializer()])

    # Weights every implementation is differentiated against.
    xs = [w, b] + peephole_vars

    basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]])
    basic_grads = sess.run(gradients_impl.gradients(basic_outputs_op, inputs))
    basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs))

    block_outputs = sess.run(block_outputs_op)
    block_grads = sess.run(gradients_impl.gradients(block_outputs_op, inputs))
    block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs))

    fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]])
    fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
    fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs))

    return (basic_state, fused_state, basic_outputs, block_outputs,
            fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads,
            block_wgrads, fused_wgrads)
class LSTMBlockCellTest(test.TestCase, parameterized.TestCase):
  """Compares the block / fused LSTM ops with the reference LSTM cells."""

  # Parameter sets for the dtype-parameterized tests below; the float16 case
  # uses looser tolerances than the float32 one.
  TEST_CASES = ({
      "testcase_name": "Fp32",
      "dtype": dtypes.float32,
      "rtol": 1e-6,
      "atol": 1e-6
  }, {
      "testcase_name": "Fp16",
      "dtype": dtypes.float16,
      "rtol": 8e-3,
      "atol": 8e-4
  })

  def testNoneDimsWithDynamicRNN(self):
    """Runs LSTMBlockCell in dynamic_rnn with unknown time and batch dims."""
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      batch_size = 4
      num_steps = 5
      input_dim = 6
      cell_size = 7

      cell = lstm_ops.LSTMBlockCell(cell_size)
      x = array_ops.placeholder(dtypes.float32, shape=(None, None, input_dim))

      output, _ = rnn.dynamic_rnn(
          cell, x, time_major=True, dtype=dtypes.float32)
      sess.run(variables.global_variables_initializer())
      feed = {x: np.random.randn(num_steps, batch_size, input_dim)}
      sess.run(output, feed)

  def testLSTMBlockCell(self):
    """Checks one step of two stacked LSTMBlockCells against fixed values."""
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [lstm_ops.LSTMBlockCell(2)
             for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: np.array([[1., 1.]]),
            m0.name: 0.1 * np.ones([1, 2]),
            m1.name: 0.1 * np.ones([1, 2]),
            m2.name: 0.1 * np.ones([1, 2]),
            m3.name: 0.1 * np.ones([1, 2])
        })
        self.assertEqual(len(res), 5)
        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
        # These numbers are from testBasicLSTMCell and only test c/h.
        self.assertAllClose(res[1], [[0.68967271, 0.68967271]])
        self.assertAllClose(res[2], [[0.44848421, 0.44848421]])
        self.assertAllClose(res[3], [[0.39897051, 0.39897051]])
        self.assertAllClose(res[4], [[0.24024698, 0.24024698]])

  def testCompatibleNames(self):
    """Checks that basic, block and fused cells create the same variables."""
    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = rnn_cell.LSTMCell(10)
      pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
      inputs = [array_ops.zeros([4, 5])] * 6
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
      rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
      basic_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = lstm_ops.LSTMBlockCell(10)
      pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
      inputs = [array_ops.zeros([4, 5])] * 6
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
      rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
      block_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = lstm_ops.LSTMBlockFusedCell(10)
      pcell = lstm_ops.LSTMBlockFusedCell(10, use_peephole=True)
      inputs = array_ops.stack([array_ops.zeros([4, 5])] * 6)
      cell(inputs, dtype=dtypes.float32, scope="basic/lstm_cell")
      pcell(inputs, dtype=dtypes.float32, scope="peephole/lstm_cell")
      fused_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    self.assertEqual(basic_names, block_names)
    self.assertEqual(basic_names, fused_names)

  def _RunTwoLayerCell(self, sess, scope, initializer, make_cell, x, x_values,
                       state_values):
    """Builds a two-layer MultiRNNCell under `scope` and runs one step.

    Args:
      sess: Session used for variable initialization and evaluation.
      scope: Name of the variable scope the cells are built in.
      initializer: Default initializer of that variable scope.
      make_cell: Zero-argument callable returning a new cell for one layer.
      x: Input tensor of the step; it is fed by name with `x_values`.
      x_values: Value fed for `x`.
      state_values: Four values fed, in order, for the state tensors
        `m0, m1, m2, m3`.

    Returns:
      The evaluated list `[g, out_m0, out_m1, out_m2, out_m3]`.
    """
    with variable_scope.variable_scope(scope, initializer=initializer):
      states = [array_ops.zeros([1, 2]) for _ in range(4)]
      m0, m1, m2, m3 = states
      stacked_cell = rnn_cell.MultiRNNCell(
          [make_cell() for _ in range(2)], state_is_tuple=True)
      g, ((out_m0, out_m1), (out_m2, out_m3)) = stacked_cell(
          x, ((m0, m1), (m2, m3)))
      sess.run([variables.global_variables_initializer()])
      feed = dict(zip([state.name for state in states], state_values))
      feed[x.name] = x_values
      return sess.run([g, out_m0, out_m1, out_m2, out_m3], feed)

  def _AssertTwoLayerCellsMatch(self, make_basic_cell, make_block_cell):
    """Asserts that both cell factories give the same single-step results."""
    with self.session(use_gpu=True) as sess:
      x = array_ops.zeros([1, 2])
      x_values = np.random.randn(1, 2)

      state_values = [
          scale * np.ones([1, 2]) for scale in (0.1, -0.1, -0.2, 0.2)
      ]

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      basic_res = self._RunTwoLayerCell(
          sess, "basic", initializer, make_basic_cell, x, x_values,
          state_values)
      block_res = self._RunTwoLayerCell(
          sess, "block", initializer, make_block_cell, x, x_values,
          state_values)

      self.assertEqual(len(basic_res), len(block_res))
      for basic, block in zip(basic_res, block_res):
        self.assertAllClose(basic, block)

  def testLSTMBasicToBlockCell(self):
    """Compares stacked BasicLSTMCells with stacked LSTMBlockCells."""
    self._AssertTwoLayerCellsMatch(
        lambda: rnn_cell.BasicLSTMCell(2, state_is_tuple=True),
        lambda: lstm_ops.LSTMBlockCell(2))

  def testLSTMBasicToBlockCellPeeping(self):
    """Compares stacked peephole LSTMCells with peephole LSTMBlockCells."""
    self._AssertTwoLayerCellsMatch(
        lambda: rnn_cell.LSTMCell(2, use_peepholes=True, state_is_tuple=True),
        lambda: lstm_ops.LSTMBlockCell(2, use_peephole=True))

  def LSTMBasicToBlockTestHelper(self,
                                 dtype=dtypes.float32,
                                 use_peephole=False,
                                 cell_clip=None,
                                 rtol=1e-6,
                                 atol=1e-6):
    """Asserts that blocks_match() results agree across implementations.

    Args:
      dtype: Data type forwarded to blocks_match().
      use_peephole: Whether peephole connections are enabled.
      cell_clip: Clipping value forwarded to blocks_match().
      rtol: Relative tolerance of every comparison.
      atol: Absolute tolerance of every comparison.
    """
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs,
       basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads,
       fused_wgrads) = blocks_match(
           sess, use_peephole=use_peephole, dtype=dtype, cell_clip=cell_clip)

      self.assertAllClose(basic_outputs, block_outputs, rtol=rtol, atol=atol)
      self.assertAllClose(basic_grads, block_grads, rtol=rtol, atol=atol)
      for basic, block in zip(basic_wgrads, block_wgrads):
        self.assertAllClose(basic, block, rtol=rtol, atol=atol)

      self.assertAllClose(basic_outputs, fused_outputs, rtol=rtol, atol=atol)
      self.assertAllClose(basic_state, fused_state, rtol=rtol, atol=atol)
      self.assertAllClose(basic_grads, fused_grads, rtol=rtol, atol=atol)
      for basic, fused in zip(basic_wgrads, fused_wgrads):
        self.assertAllClose(basic, fused, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlock(self, dtype, rtol, atol):
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=False, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlockPeeping(self, dtype, rtol, atol):
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=True, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlockCellClip(self, dtype, rtol, atol):
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=True, cell_clip=0.5, rtol=rtol, atol=atol)

  def testLSTMFusedSequenceLengths(self):
    """Verify proper support for sequence lengths in LSTMBlockFusedCell."""
    with self.session(use_gpu=True) as sess:
      batch_size = 3
      input_size = 4
      cell_size = 5
      max_sequence_length = 6

      inputs = []
      for _ in range(max_sequence_length):
        inp = ops.convert_to_tensor(
            np.random.randn(batch_size, input_size), dtype=dtypes.float32)
        inputs.append(inp)
      seq_lengths = constant_op.constant([3, 4, 5])
      cell_inputs = array_ops.stack(inputs)

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890213)

      with variable_scope.variable_scope("lstm_cell", initializer=initializer):
        # magic naming so that the cells pick up these variables and reuse them
        variable_scope.get_variable(
            "kernel",
            shape=[input_size + cell_size, cell_size * 4],
            dtype=dtypes.float32)

        variable_scope.get_variable(
            "bias",
            shape=[cell_size * 4],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer())

      cell = lstm_ops.LSTMBlockFusedCell(
          cell_size, cell_clip=0, use_peephole=False, reuse=True,
          name="lstm_cell")

      fused_outputs_op, fused_state_op = cell(
          cell_inputs, dtype=dtypes.float32, sequence_length=seq_lengths)

      cell_vars = [
          v for v in variables.trainable_variables()
          if v.name.endswith("kernel") or v.name.endswith("bias")
      ]

      # Verify that state propagation works if we turn our sequence into
      # tiny (single-time) subsequences, i.e. unfuse the cell
      unfused_outputs_op = []
      state = None
      # The lengths are constant, so evaluate them once instead of once per
      # time step.
      seq_length_values = sess.run(seq_lengths)
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True):
        for i, inp in enumerate(inputs):
          # Per batch entry: 1 while step `i` is within its sequence, else 0.
          lengths = [int(i < length) for length in seq_length_values]
          output, state = cell(
              array_ops.expand_dims(inp, 0),
              initial_state=state,
              dtype=dtypes.float32,
              sequence_length=lengths)
          unfused_outputs_op.append(output[0])
      unfused_outputs_op = array_ops.stack(unfused_outputs_op)

      sess.run([variables.global_variables_initializer()])
      unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]])
      unfused_grads = sess.run(
          gradients_impl.gradients(unfused_outputs_op, inputs))
      unfused_wgrads = sess.run(
          gradients_impl.gradients(unfused_outputs_op, cell_vars))

      fused_outputs, fused_state = sess.run(
          [fused_outputs_op, fused_state_op[0]])
      fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
      fused_wgrads = sess.run(
          gradients_impl.gradients(fused_outputs_op, cell_vars))

      self.assertAllClose(fused_outputs, unfused_outputs)
      self.assertAllClose(fused_state, unfused_state)
      self.assertAllClose(fused_grads, unfused_grads)
      for fused, unfused in zip(fused_wgrads, unfused_wgrads):
        self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6)
520#### Benchmarking.
class BenchmarkLSTMBlock(test.Benchmark):
  """Wall-time benchmarks of LSTMBlockCell driven through dynamic_rnn()."""

  # Column names matching, in order, the CSV rows printed by `_ReportResult`.
  _CSV_HEADER = ("dtype,batch_size,cell_size,input_size,time_steps,use_gpu,"
                 "wall_time")

  def _Configs(self):
    """Returns the configuration grid shared by both benchmarks."""
    return benchmarking.dict_product({
        "batch_size": [1, 8, 13, 32, 67, 128],
        "cell_size": [128, 250, 512, 650, 1024, 1350],
        "time_steps": [40],
        "use_gpu": [True, False],
        "dtype": ["float32", "float16"],
    })

  def _PrintHeader(self, title):
    """Prints the benchmark banner followed by the CSV column names."""
    print(title)
    print("--------------------------------------------------------------")
    print("LSTMBlockCell Seconds per inference.")
    print(self._CSV_HEADER)

  def _ReportResult(self, name_prefix, config, iters, wall_time):
    """Prints one CSV row for `config` and reports it as a benchmark."""
    # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
    # is set, this will produce a copy-paste-able CSV file.
    # The input size always equals the cell size in these benchmarks, hence
    # `cell_size` appears twice.
    print(",".join(
        map(str, [
            config["dtype"], config["batch_size"], config["cell_size"],
            config["cell_size"], config["time_steps"], config["use_gpu"],
            wall_time
        ])))
    benchmark_name_template = "_".join([
        name_prefix, "DT_%(dtype)s", "BS%(batch_size)i", "CS%(cell_size)i",
        "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s"
    ])

    self.report_benchmark(
        name=benchmark_name_template % config,
        iters=iters,
        wall_time=wall_time,
        extras=config)

  def benchmarkLSTMBlockCellFpropWithDynamicRNN(self):
    """Times the forward pass for every configuration in the grid."""
    self._PrintHeader("BlockLSTMCell forward propagation via dynamic_rnn().")
    iters = 10
    for config in self._Configs():
      dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          inputs = variable_scope.get_variable(
              "x",
              dtype=dtype,
              shape=[
                  config["time_steps"], config["batch_size"],
                  config["cell_size"]
              ])
          cell = lstm_ops.LSTMBlockCell(config["cell_size"], dtype=dtype)
          outputs = rnn.dynamic_rnn(cell, inputs, time_major=True, dtype=dtype)
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(outputs, sess, iters)

        self._ReportResult("LSTMBlockCell_fprop", config, iters, wall_time)

  def benchmarkLSTMBlockCellBpropWithDynamicRNN(self):
    """Times the gradient computation for every configuration in the grid."""
    self._PrintHeader("BlockLSTMCell backward propagation via dynamic_rnn().")
    iters = 10
    for config in self._Configs():
      dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          time_steps = config["time_steps"]
          batch_size = config["batch_size"]
          cell_size = input_size = config["cell_size"]
          inputs = variable_scope.get_variable(
              "x", [time_steps, batch_size, cell_size],
              trainable=False,
              dtype=dtype)
          with variable_scope.variable_scope(
              "rnn", reuse=variable_scope.AUTO_REUSE):
            w = variable_scope.get_variable(
                "rnn/lstm_cell/kernel",
                shape=[input_size + cell_size, cell_size * 4],
                dtype=dtype)
            b = variable_scope.get_variable(
                "rnn/lstm_cell/bias",
                shape=[cell_size * 4],
                dtype=dtype,
                initializer=init_ops.zeros_initializer())
            cell = lstm_ops.LSTMBlockCell(cell_size, dtype=dtype)
            outputs = rnn.dynamic_rnn(
                cell, inputs, time_major=True, dtype=dtype)
          grads = gradients_impl.gradients(outputs, [inputs, w, b])
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(grads, sess, iters)

        self._ReportResult("LSTMBlockCell_bprop", config, iters, wall_time)
if __name__ == "__main__":
  # When executed as a script, hand control to the `main` entry point of
  # tensorflow.python.platform.test.
  test.main()