# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TfLite BasicRnnCell wrapper.
16
17TODO(renjieliu): Find a better home for this one.
18"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from tensorflow.lite.python import op_hint
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


@tf_export("lite.experimental.nn.TfLiteRNNCell")
class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
  """The most basic RNN cell.

  This is used only for TfLite; it provides hints and it also makes the
  variables in the desired format for the tflite ops.
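
  Example (a usage sketch, not from the original documentation; it assumes the
  companion `tf.lite.experimental.nn.dynamic_rnn` wrapper exported from this
  package and a TF 1.x graph):

  ```python
  import tensorflow as tf

  cell = tf.lite.experimental.nn.TfLiteRNNCell(num_units=64)
  # Time-major [time, batch, features] input.
  inputs = tf.placeholder(tf.float32, shape=[28, None, 28])
  outputs, final_state = tf.lite.experimental.nn.dynamic_rnn(
      cell, inputs, dtype=tf.float32, time_major=True)
  ```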
45  """
46
47  def __init__(self,
48               num_units,
49               activation=None,
50               reuse=None,
51               name=None,
52               dtype=None,
53               **kwargs):
54    """Initializes the parameters for an RNN cell.
55
56    Args:
57      num_units: int, The number of units in the RNN cell.
58      activation: Nonlinearity to use.  Default: `tanh`. It could also be string
59        that is within Keras activation function names.
60      reuse: (optional) Python boolean describing whether to reuse variables in
61        an existing scope. Raises an error if not `True` and the existing scope
62        already has the given variables.
63      name: String, the name of the layer. Layers with the same name will share
64        weights, but to avoid mistakes we require reuse=True in such cases.
65      dtype: Default dtype of the layer (default of `None` means use the type of
66        the first input). Required when `build` is called before `call`.
67      **kwargs: Dict, keyword named properties for common layer attributes, like
68        `trainable` etc when constructing the cell from configs of get_config().
69
70    Raises:
71      ValueError: If the existing scope already has the given variables.
72    """
    super(TfLiteRNNCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)

    # Inputs must be Rank-2.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._tflite_wrapper = op_hint.OpHint("UnidirectionalSequenceRnn")
    self._num_units = num_units
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Builds the RNN cell.

    Args:
      inputs_shape: RNN input tensor shape.

    Raises:
      ValueError: If the last dimension of the input shape is not known.
    """
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       (inputs_shape,))

    input_depth = inputs_shape[-1]

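    # The index_override values below are meant to line up with the input
    # tensor ordering of the TFLite UnidirectionalSequenceRnn op that this
    # OpHint is later stubbed into (0: input, 1: input weights, 2: recurrent
    # weights, 3: bias, 4: hidden state).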
    def add_variable_wrapped(name, shape, initializer, index):
      var = self.add_weight(name, shape=shape, initializer=initializer)
      return self._tflite_wrapper.add_input(
          var, name=name, index_override=index)

    self._input_weights = add_variable_wrapped(
        "input_weights", [self._num_units, input_depth], None, 1)
    self._recurrent_weights = add_variable_wrapped(
        "recurrent_weights", [self._num_units, self._num_units], None, 2)
    self._bias = add_variable_wrapped(
        "bias",
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype),
        index=3)

    self.built = True

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    inputs = self._tflite_wrapper.add_input(
        inputs, tag="input", name="input", aggregate="stack", index_override=0)
    state = self._tflite_wrapper.add_input(
        state,
        tag="hidden_state",
        name="hidden_state",
        aggregate="first",
        index_override=4)
    weights = array_ops.transpose(
        array_ops.concat([self._input_weights, self._recurrent_weights], 1))
    gate_inputs = math_ops.matmul(array_ops.concat([inputs, state], 1), weights)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    output = self._tflite_wrapper.add_output(
        output,
        tag="output",
        name="output",
        index_override=1,
        aggregate="stack")
    return output, output

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(TfLiteRNNCell, self).get_config()
    return dict(itertools.chain(base_config.items(), config.items()))


@tf_export("lite.experimental.nn.TFLiteLSTMCell")
class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  This is used only for TfLite; it provides hints and it also makes the
  variables in the desired format for the tflite ops (transposed and
  separated).

  The default non-peephole implementation is based on:

    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf

  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The class uses optional peephole connections, optional cell clipping, and
  an optional projection layer.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
  better performance on CPU.
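
  Example (a usage sketch, not from the original documentation; it assumes the
  companion `tf.lite.experimental.nn.dynamic_rnn` wrapper and TF 1.x
  session-based conversion via `tf.lite.TFLiteConverter.from_session`):

  ```python
  import tensorflow as tf

  cell = tf.lite.experimental.nn.TFLiteLSTMCell(num_units=128)
  # Time-major [time, batch, features] input.
  inputs = tf.placeholder(tf.float32, shape=[28, None, 28], name="input")
  outputs, final_state = tf.lite.experimental.nn.dynamic_rnn(
      cell, inputs, dtype=tf.float32, time_major=True)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The OpHints emitted by this cell are intended to let the conversion
    # pipeline stub the unrolled loop out as a single fused LSTM op.
    converter = tf.lite.TFLiteConverter.from_session(sess, [inputs], [outputs])
    tflite_model = converter.convert()
  ```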
188  """
189
190  def __init__(self,
191               num_units,
192               use_peepholes=False,
193               cell_clip=None,
194               initializer=None,
195               num_proj=None,
196               proj_clip=None,
197               num_unit_shards=None,
198               num_proj_shards=None,
199               forget_bias=1.0,
200               state_is_tuple=True,
201               activation=None,
202               reuse=None,
203               name=None,
204               dtype=None):
205    """Initialize the parameters for an LSTM cell.
206
207    Args:
208      num_units: int, The number of units in the LSTM cell.
209      use_peepholes: bool, set True to enable diagonal/peephole connections.
210      cell_clip: (optional) A float value, if provided the cell state is clipped
211        by this value prior to the cell output activation.
212      initializer: (optional) The initializer to use for the weight and
213        projection matrices.
214      num_proj: (optional) int, The output dimensionality for the projection
215        matrices.  If None, no projection is performed.
216      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
217        provided, then the projected values are clipped elementwise to within
218        `[-proj_clip, proj_clip]`.
219      num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
220        variable_scope partitioner instead.
221      num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
222        variable_scope partitioner instead.
223      forget_bias: Biases of the forget gate are initialized by default to 1 in
224        order to reduce the scale of forgetting at the beginning of the
225        training. Must set it manually to `0.0` when restoring from CudnnLSTM
226        trained checkpoints.
227      state_is_tuple: If True, accepted and returned states are 2-tuples of the
228        `c_state` and `m_state`.  If False, they are concatenated along the
229        column axis.  This latter behavior will soon be deprecated.
230      activation: Activation function of the inner states.  Default: `tanh`.
231      reuse: (optional) Python boolean describing whether to reuse variables in
232        an existing scope.  If not `True`, and the existing scope already has
233        the given variables, an error is raised.
234      name: String, the name of the layer. Layers with the same name will share
235        weights, but to avoid mistakes we require reuse=True in such cases.
236      dtype: Default dtype of the layer (default of `None` means use the type of
237        the first input). Required when `build` is called before `call`.  When
238        restoring from CudnnLSTM-trained checkpoints, use
239        `CudnnCompatibleLSTMCell` instead.
240    """
    super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
    # TODO(raziel): decide if we want to just support tuples (yes please!).
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warn(
          "%s: The num_unit_shards and num_proj_shards parameters are "
          "deprecated and will be removed in Jan 2017.  "
          "Use a variable scope with a partitioner instead.", self)

    # Inputs must be 2-dimensional.
    # TODO(raziel): layers stuff -- chop if un-layerizing Op.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._tflite_wrapper = op_hint.OpHint("UnidirectionalSequenceLstm")

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

    self._output_size = num_proj if num_proj else num_units
    self._state_size = (
        rnn_cell_impl.LSTMStateTuple(num_units, self._output_size)
        if state_is_tuple else num_units + self._output_size)

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def build(self, inputs_shape):
    """Build TfLite LSTM cell graph.

    Args:
      inputs_shape: The inputs_shape must be known, and is of the form
        [batch_size, input_size].

    Raises:
      ValueError: if the inputs_shape is invalid.
    """
    if len(inputs_shape) != 2:
      raise ValueError(
          "inputs_shape must be 2-dimensional, saw shape: %s" % inputs_shape)
    input_depth = (
        inputs_shape[1]
        if isinstance(inputs_shape[1], int) else inputs_shape[1].value)
    if input_depth is None:
      raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape)

    maybe_partitioner = (
        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
        if self._num_unit_shards is not None else None)
    input_weight_shape = [self._num_units, input_depth]
    cell_weight_shape = [self._num_units, self._output_size]
    bias_shape = [self._num_units]

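    # The index_override values used below are meant to line up with the input
    # tensor ordering of the TFLite UnidirectionalSequenceLstm op that the
    # OpHint is later stubbed into: 1-4 input-to-gate weights, 5-8 recurrent
    # weights, 9-11 peephole weights, 12-15 gate biases, 16 projection kernel,
    # and 18/19 the activation and cell states (see `call`).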
    def add_variable_wrapped(name, shape, initializer, index, partitioner):
      var = self.add_weight(
          name, shape=shape, initializer=initializer, partitioner=partitioner)
      return self._tflite_wrapper.add_input(
          var, name=name, index_override=index)

    weight_initializer = self._initializer
    if self.dtype is None:
      bias_initializer = init_ops.zeros_initializer
    else:
      bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)

    forget_bias_initializer = init_ops.constant_initializer(self._forget_bias)

    self.input_to_input_w = add_variable_wrapped(
        "input_to_input_w", input_weight_shape, weight_initializer, 1,
        maybe_partitioner)
    self.input_to_forget_w = add_variable_wrapped(
        "input_to_forget_w", input_weight_shape, weight_initializer, 2,
        maybe_partitioner)
    self.input_to_cell_w = add_variable_wrapped(
        "input_to_cell_w", input_weight_shape, weight_initializer, 3,
        maybe_partitioner)
    self.input_to_output_w = add_variable_wrapped(
        "input_to_output_w", input_weight_shape, weight_initializer, 4,
        maybe_partitioner)
    self.cell_to_input_w = add_variable_wrapped(
        "cell_to_input_w", cell_weight_shape, weight_initializer, 5,
        maybe_partitioner)
    self.cell_to_forget_w = add_variable_wrapped(
        "cell_to_forget_w", cell_weight_shape, weight_initializer, 6,
        maybe_partitioner)
    self.cell_to_cell_w = add_variable_wrapped(
        "cell_to_cell_w", cell_weight_shape, weight_initializer, 7,
        maybe_partitioner)
    self.cell_to_output_w = add_variable_wrapped(
        "cell_to_output_w", cell_weight_shape, weight_initializer, 8,
        maybe_partitioner)

    self.input_bias = add_variable_wrapped(
        "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
    self.forget_bias = add_variable_wrapped("forget_bias", bias_shape,
                                            forget_bias_initializer, 13,
                                            maybe_partitioner)
    self.cell_bias = add_variable_wrapped(
        "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
    self.output_bias = add_variable_wrapped(
        "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner)

    # index 9, 10, 11.
    # f stands for forget, i stands for input and o stands for output.
    if self._use_peepholes:
      self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units],
                                            self._initializer, 10,
                                            maybe_partitioner)
      self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units],
                                            self._initializer, 9,
                                            maybe_partitioner)
      self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units],
                                            self._initializer, 11,
                                            maybe_partitioner)

    # index 16 for proj kernel.
    if self._num_proj is not None:
      maybe_proj_partitioner = (
          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
          if self._num_proj_shards is not None else None)
      self._proj_kernel = add_variable_wrapped(
          "projection/kernel", [self._num_proj, self._num_units],
          self._initializer,
          16,
          partitioner=maybe_proj_partitioner)

    self.built = True

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, `[batch, num_units]`.
      state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
        [batch, state_size]`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs`
        when the previous state was `state`.  Same type and shape(s) as
        `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    inputs = self._tflite_wrapper.add_input(
        inputs, tag="input", name="input", aggregate="stack", index_override=0)

    # Make sure inputs and bias_initializer have the same type.
    assert inputs.dtype == self.input_to_input_w.dtype

    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    # Note: For TfLite, the cell state is at index 19 while the activation
    # state is at index 18.
    c_prev = self._tflite_wrapper.add_input(
        c_prev,
        tag="c_prev",
        name="c_prev",
        aggregate="first",
        index_override=19)
    m_prev = self._tflite_wrapper.add_input(
        m_prev,
        tag="m_prev",
        name="m_prev",
        aggregate="first",
        index_override=18)

    input_size = inputs.shape.with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.shape[-1]")

    inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1)

    # i: input gate pre-activation.
    # f: forget gate pre-activation.
    # o: output gate pre-activation.
    # j: new cell candidate (block input).
    # c: new cell state.
    # m: output (hidden state).
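    # In equation form (non-peephole case), the update computed below is
    # roughly:
    #   c = sigmoid(f) * c_prev + sigmoid(i) * act(j)
    #   m = sigmoid(o) * act(c)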
    i = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_input_w, self.cell_to_input_w],
                             axis=1),
            transpose_b=True), self.input_bias)
    f = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_forget_w, self.cell_to_forget_w],
                             axis=1),
            transpose_b=True), self.forget_bias)
    o = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_output_w, self.cell_to_output_w],
                             axis=1),
            transpose_b=True), self.output_bias)
    j = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_cell_w, self.cell_to_cell_w],
                             axis=1),
            transpose_b=True), self.cell_bias)

    # Diagonal connections
    if self._use_peepholes:
      c = (
          sigmoid(f + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      transposed_proj_kernel = array_ops.transpose(self._proj_kernel)
      m = math_ops.matmul(m, transposed_proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    c = self._tflite_wrapper.add_output(
        c, tag="c", name="c", aggregate="last", index_override=1)
    m = self._tflite_wrapper.add_output(
        m, tag="m", name="m", index_override=2, aggregate="stack")

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "use_peepholes": self._use_peepholes,
        "cell_clip": self._cell_clip,
        "initializer": initializers.serialize(self._initializer),
        "num_proj": self._num_proj,
        "proj_clip": self._proj_clip,
        "num_unit_shards": self._num_unit_shards,
        "num_proj_shards": self._num_proj_shards,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(TFLiteLSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))