1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TfLite BasicRnnCell wrapper. 16 17TODO(renjieliu): Find a better home for this one. 18""" 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22import itertools 23 24import tensorflow.lite.python.op_hint as op_hint 25from tensorflow.python.keras import activations 26from tensorflow.python.keras import initializers 27from tensorflow.python.layers import base as base_layer 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import clip_ops 30from tensorflow.python.ops import init_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import nn_ops 33from tensorflow.python.ops import partitioned_variables 34from tensorflow.python.ops import rnn_cell_impl 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.util.tf_export import tf_export 37 38 39@tf_export("lite.experimental.nn.TfLiteRNNCell") 40class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): 41 """The most basic RNN cell. 42 43 This is used only for TfLite, it provides hints and it also makes the 44 variables in the desired for the tflite ops. 45 """ 46 47 def __init__(self, 48 num_units, 49 activation=None, 50 reuse=None, 51 name=None, 52 dtype=None, 53 **kwargs): 54 """Initializes the parameters for an RNN cell. 55 56 Args: 57 num_units: int, The number of units in the RNN cell. 58 activation: Nonlinearity to use. Default: `tanh`. It could also be string 59 that is within Keras activation function names. 60 reuse: (optional) Python boolean describing whether to reuse variables in 61 an existing scope. Raises an error if not `True` and the existing scope 62 already has the given variables. 63 name: String, the name of the layer. Layers with the same name will share 64 weights, but to avoid mistakes we require reuse=True in such cases. 65 dtype: Default dtype of the layer (default of `None` means use the type of 66 the first input). Required when `build` is called before `call`. 67 **kwargs: Dict, keyword named properties for common layer attributes, like 68 `trainable` etc when constructing the cell from configs of get_config(). 69 70 Raises: 71 ValueError: If the existing scope already has the given variables. 72 """ 73 super(TfLiteRNNCell, self).__init__( 74 _reuse=reuse, name=name, dtype=dtype, **kwargs) 75 76 # Inputs must be Rank-2. 77 self.input_spec = base_layer.InputSpec(ndim=2) 78 79 self._tflite_wrapper = op_hint.OpHint("UnidirectionalSequenceRnn") 80 self._num_units = num_units 81 if activation: 82 self._activation = activations.get(activation) 83 else: 84 self._activation = math_ops.tanh 85 86 @property 87 def state_size(self): 88 return self._num_units 89 90 @property 91 def output_size(self): 92 return self._num_units 93 94 def build(self, inputs_shape): 95 """Builds the RNN cell. 96 97 Args: 98 inputs_shape: Rnn input tensor shape. 99 100 Raises: 101 ValueError: If last dimension of the input shape is not known. 102 """ 103 if inputs_shape[-1] is None: 104 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % 105 (inputs_shape,)) 106 107 input_depth = inputs_shape[-1] 108 109 def add_variable_wrapped(name, shape, initializer, index): 110 var = self.add_weight(name, shape=shape, initializer=initializer) 111 return self._tflite_wrapper.add_input( 112 var, name=name, index_override=index) 113 114 self._input_weights = add_variable_wrapped( 115 "input_weights", [self._num_units, input_depth], None, 1) 116 self._recurrent_weights = add_variable_wrapped( 117 "recurrent_weights", [self._num_units, self._num_units], None, 2) 118 self._bias = add_variable_wrapped( 119 "bias", 120 shape=[self._num_units], 121 initializer=init_ops.zeros_initializer(dtype=self.dtype), 122 index=3) 123 124 self.built = True 125 126 def call(self, inputs, state): 127 """Most basic RNN: output = new_state = act(W * input + U * state + B).""" 128 inputs = self._tflite_wrapper.add_input( 129 inputs, tag="input", name="input", aggregate="stack", index_override=0) 130 state = self._tflite_wrapper.add_input( 131 state, 132 tag="hidden_state", 133 name="hidden_state", 134 aggregate="first", 135 index_override=4) 136 weights = array_ops.transpose( 137 array_ops.concat([self._input_weights, self._recurrent_weights], 1)) 138 gate_inputs = math_ops.matmul(array_ops.concat([inputs, state], 1), weights) 139 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 140 output = self._activation(gate_inputs) 141 output = self._tflite_wrapper.add_output( 142 output, 143 tag="output", 144 name="output", 145 index_override=1, 146 aggregate="stack") 147 return output, output 148 149 def get_config(self): 150 config = { 151 "num_units": self._num_units, 152 "activation": activations.serialize(self._activation), 153 "reuse": self._reuse, 154 } 155 base_config = super(TfLiteRNNCell, self).get_config() 156 return dict(itertools.chain(base_config.items(), config.items())) 157 158 159@tf_export("lite.experimental.nn.TFLiteLSTMCell") 160class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): 161 """Long short-term memory unit (LSTM) recurrent network cell. 162 163 This is used only for TfLite, it provides hints and it also makes the 164 variables in the desired for the tflite ops (transposed and seaparated). 165 166 The default non-peephole implementation is based on: 167 168 https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf 169 170 Felix Gers, Jurgen Schmidhuber, and Fred Cummins. 171 "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. 172 173 The peephole implementation is based on: 174 175 https://research.google.com/pubs/archive/43905.pdf 176 177 Hasim Sak, Andrew Senior, and Francoise Beaufays. 178 "Long short-term memory recurrent neural network architectures for 179 large scale acoustic modeling." INTERSPEECH, 2014. 180 181 The class uses optional peep-hole connections, optional cell clipping, and 182 an optional projection layer. 183 184 Note that this cell is not optimized for performance. Please use 185 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or 186 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for 187 better performance on CPU. 188 """ 189 190 def __init__(self, 191 num_units, 192 use_peepholes=False, 193 cell_clip=None, 194 initializer=None, 195 num_proj=None, 196 proj_clip=None, 197 num_unit_shards=None, 198 num_proj_shards=None, 199 forget_bias=1.0, 200 state_is_tuple=True, 201 activation=None, 202 reuse=None, 203 name=None, 204 dtype=None): 205 """Initialize the parameters for an LSTM cell. 206 207 Args: 208 num_units: int, The number of units in the LSTM cell. 209 use_peepholes: bool, set True to enable diagonal/peephole connections. 210 cell_clip: (optional) A float value, if provided the cell state is clipped 211 by this value prior to the cell output activation. 212 initializer: (optional) The initializer to use for the weight and 213 projection matrices. 214 num_proj: (optional) int, The output dimensionality for the projection 215 matrices. If None, no projection is performed. 216 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 217 provided, then the projected values are clipped elementwise to within 218 `[-proj_clip, proj_clip]`. 219 num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a 220 variable_scope partitioner instead. 221 num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a 222 variable_scope partitioner instead. 223 forget_bias: Biases of the forget gate are initialized by default to 1 in 224 order to reduce the scale of forgetting at the beginning of the 225 training. Must set it manually to `0.0` when restoring from CudnnLSTM 226 trained checkpoints. 227 state_is_tuple: If True, accepted and returned states are 2-tuples of the 228 `c_state` and `m_state`. If False, they are concatenated along the 229 column axis. This latter behavior will soon be deprecated. 230 activation: Activation function of the inner states. Default: `tanh`. 231 reuse: (optional) Python boolean describing whether to reuse variables in 232 an existing scope. If not `True`, and the existing scope already has 233 the given variables, an error is raised. 234 name: String, the name of the layer. Layers with the same name will share 235 weights, but to avoid mistakes we require reuse=True in such cases. 236 dtype: Default dtype of the layer (default of `None` means use the type of 237 the first input). Required when `build` is called before `call`. When 238 restoring from CudnnLSTM-trained checkpoints, use 239 `CudnnCompatibleLSTMCell` instead. 240 """ 241 super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) 242 # TODO(raziel): decide if we want to just support tuples (yes please!). 243 if not state_is_tuple: 244 logging.warn( 245 "%s: Using a concatenated state is slower and will soon be " 246 "deprecated. Use state_is_tuple=True.", self) 247 if num_unit_shards is not None or num_proj_shards is not None: 248 logging.warn( 249 "%s: The num_unit_shards and proj_unit_shards parameters are " 250 "deprecated and will be removed in Jan 2017. " 251 "Use a variable scope with a partitioner instead.", self) 252 253 # Inputs must be 2-dimensional. 254 # TODO(raziel): layers stuff -- chop if un-layerizing Op. 255 self.input_spec = base_layer.InputSpec(ndim=2) 256 257 self._tflite_wrapper = op_hint.OpHint("UnidirectionalSequenceLstm") 258 259 self._num_units = num_units 260 self._use_peepholes = use_peepholes 261 self._cell_clip = cell_clip 262 self._initializer = initializer 263 self._num_proj = num_proj 264 self._proj_clip = proj_clip 265 self._num_unit_shards = num_unit_shards 266 self._num_proj_shards = num_proj_shards 267 self._forget_bias = forget_bias 268 self._state_is_tuple = state_is_tuple 269 self._activation = activation or math_ops.tanh 270 271 self._output_size = num_proj if num_proj else num_units 272 self._state_size = ( 273 rnn_cell_impl.LSTMStateTuple(num_units, self._output_size) 274 if state_is_tuple else num_units + self._output_size) 275 276 @property 277 def state_size(self): 278 return self._state_size 279 280 @property 281 def output_size(self): 282 return self._output_size 283 284 def build(self, inputs_shape): 285 """Build TfLite LSTM cell graph. 286 287 Args: 288 inputs_shape: The inputs_shape must be known, and is [batch_size, 289 input_size] shape. 290 291 Raises: 292 ValueError: if the inputs_shape is invalid. 293 """ 294 if len(inputs_shape) != 2: 295 raise ValueError( 296 "inputs_shape must be 2-dimensional, saw shape: %s" % inputs_shape) 297 input_depth = ( 298 inputs_shape[1] 299 if isinstance(inputs_shape[1], int) else inputs_shape[1].value) 300 if input_depth is None: 301 raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape) 302 303 maybe_partitioner = ( 304 partitioned_variables.fixed_size_partitioner(self._num_unit_shards) 305 if self._num_unit_shards is not None else None) 306 input_weight_shape = [self._num_units, input_depth] 307 cell_weight_shape = [self._num_units, self._output_size] 308 bias_shape = [self._num_units] 309 310 def add_variable_wrapped(name, shape, initializer, index, partitioner): 311 var = self.add_weight( 312 name, shape=shape, initializer=initializer, partitioner=partitioner) 313 return self._tflite_wrapper.add_input( 314 var, name=name, index_override=index) 315 316 weight_initializer = self._initializer 317 if self.dtype is None: 318 bias_initializer = init_ops.zeros_initializer 319 else: 320 bias_initializer = init_ops.zeros_initializer(dtype=self.dtype) 321 322 forget_bias_initializer = init_ops.constant_initializer(self._forget_bias) 323 324 self.input_to_input_w = add_variable_wrapped( 325 "input_to_input_w", input_weight_shape, weight_initializer, 1, 326 maybe_partitioner) 327 self.input_to_forget_w = add_variable_wrapped( 328 "input_to_forget_w", input_weight_shape, weight_initializer, 2, 329 maybe_partitioner) 330 self.input_to_cell_w = add_variable_wrapped( 331 "input_to_cell_w", input_weight_shape, weight_initializer, 3, 332 maybe_partitioner) 333 self.input_to_output_w = add_variable_wrapped( 334 "input_to_output_w", input_weight_shape, weight_initializer, 4, 335 maybe_partitioner) 336 self.cell_to_input_w = add_variable_wrapped( 337 "cell_to_input_w", cell_weight_shape, weight_initializer, 5, 338 maybe_partitioner) 339 self.cell_to_forget_w = add_variable_wrapped( 340 "cell_to_forget_w", cell_weight_shape, weight_initializer, 6, 341 maybe_partitioner) 342 self.cell_to_cell_w = add_variable_wrapped( 343 "cell_to_cell_w", cell_weight_shape, weight_initializer, 7, 344 maybe_partitioner) 345 self.cell_to_output_w = add_variable_wrapped( 346 "cell_to_output_w", cell_weight_shape, weight_initializer, 8, 347 maybe_partitioner) 348 349 self.input_bias = add_variable_wrapped( 350 "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner) 351 self.forget_bias = add_variable_wrapped("forget_bias", bias_shape, 352 forget_bias_initializer, 13, 353 maybe_partitioner) 354 self.cell_bias = add_variable_wrapped( 355 "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner) 356 self.output_bias = add_variable_wrapped( 357 "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner) 358 359 # index 9, 10, 11. 360 # f stands for forget, i stands for input and o stands for output. 361 if self._use_peepholes: 362 self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units], 363 self._initializer, 10, 364 maybe_partitioner) 365 self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units], 366 self._initializer, 9, 367 maybe_partitioner) 368 self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units], 369 self._initializer, 11, 370 maybe_partitioner) 371 372 # index 16 for proj kernel. 373 if self._num_proj is not None: 374 maybe_proj_partitioner = ( 375 partitioned_variables.fixed_size_partitioner(self._num_proj_shards) 376 if self._num_proj_shards is not None else None) 377 self._proj_kernel = add_variable_wrapped( 378 "projection/kernel", [self._num_proj, self._num_units], 379 self._initializer, 380 16, 381 partitioner=maybe_proj_partitioner) 382 383 self.built = True 384 385 def call(self, inputs, state): 386 """Run one step of LSTM. 387 388 Args: 389 inputs: input Tensor, 2D, `[batch, num_units]`. 390 state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, 391 [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple 392 of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. 393 394 Returns: 395 A tuple containing: 396 397 - A `2-D, [batch, output_dim]`, Tensor representing the output of the 398 LSTM after reading `inputs` when previous state was `state`. 399 Here output_dim is: 400 num_proj if num_proj was set, 401 num_units otherwise. 402 - Tensor(s) representing the new state of LSTM after reading `inputs` when 403 the previous state was `state`. Same type and shape(s) as `state`. 404 405 Raises: 406 ValueError: If input size cannot be inferred from inputs via 407 static shape inference. 408 """ 409 inputs = self._tflite_wrapper.add_input( 410 inputs, tag="input", name="input", aggregate="stack", index_override=0) 411 412 # Make sure inputs and bias_initializer has the same type. 413 assert inputs.dtype == self.input_to_input_w.dtype 414 415 num_proj = self._num_units if self._num_proj is None else self._num_proj 416 sigmoid = math_ops.sigmoid 417 418 if self._state_is_tuple: 419 (c_prev, m_prev) = state 420 else: 421 c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 422 m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 423 424 # Note: For TfLite, cell_state is at index 19 while activation state at 425 # index 18. 426 c_prev = self._tflite_wrapper.add_input( 427 c_prev, 428 tag="c_prev", 429 name="c_prev", 430 aggregate="first", 431 index_override=19) 432 m_prev = self._tflite_wrapper.add_input( 433 m_prev, 434 tag="m_prev", 435 name="m_prev", 436 aggregate="first", 437 index_override=18) 438 439 input_size = inputs.shape.with_rank(2)[1] 440 if input_size.value is None: 441 raise ValueError("Could not infer input size from inputs.shape[-1]") 442 443 inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1) 444 445 # i stands for input gate. 446 # f stands for forget gate activation. 447 # o outputs. 448 # j output of LSTM unit. 449 # c is the final state. 450 # m is the output. 451 i = nn_ops.bias_add( 452 math_ops.matmul( 453 inputs_and_m_prev, 454 array_ops.concat([self.input_to_input_w, self.cell_to_input_w], 455 axis=1), 456 transpose_b=True), self.input_bias) 457 f = nn_ops.bias_add( 458 math_ops.matmul( 459 inputs_and_m_prev, 460 array_ops.concat([self.input_to_forget_w, self.cell_to_forget_w], 461 axis=1), 462 transpose_b=True), self.forget_bias) 463 o = nn_ops.bias_add( 464 math_ops.matmul( 465 inputs_and_m_prev, 466 array_ops.concat([self.input_to_output_w, self.cell_to_output_w], 467 axis=1), 468 transpose_b=True), self.output_bias) 469 j = nn_ops.bias_add( 470 math_ops.matmul( 471 inputs_and_m_prev, 472 array_ops.concat([self.input_to_cell_w, self.cell_to_cell_w], 473 axis=1), 474 transpose_b=True), self.cell_bias) 475 476 # Diagonal connections 477 if self._use_peepholes: 478 c = ( 479 sigmoid(f + self._w_f_diag * c_prev) * c_prev + 480 sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) 481 else: 482 c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j)) 483 484 if self._cell_clip is not None: 485 # pylint: disable=invalid-unary-operand-type 486 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 487 # pylint: enable=invalid-unary-operand-type 488 if self._use_peepholes: 489 m = sigmoid(o + self._w_o_diag * c) * self._activation(c) 490 else: 491 m = sigmoid(o) * self._activation(c) 492 493 if self._num_proj is not None: 494 transposed_proj_kernel = array_ops.transpose(self._proj_kernel) 495 m = math_ops.matmul(m, transposed_proj_kernel) 496 497 if self._proj_clip is not None: 498 # pylint: disable=invalid-unary-operand-type 499 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 500 # pylint: enable=invalid-unary-operand-type 501 502 c = self._tflite_wrapper.add_output( 503 c, tag="c", name="c", aggregate="last", index_override=1) 504 m = self._tflite_wrapper.add_output( 505 m, tag="m", name="m", index_override=2, aggregate="stack") 506 507 new_state = ( 508 rnn_cell_impl.LSTMStateTuple(c, m) 509 if self._state_is_tuple else array_ops.concat([c, m], 1)) 510 return m, new_state 511 512 def get_config(self): 513 config = { 514 "num_units": self._num_units, 515 "use_peepholes": self._use_peepholes, 516 "cell_clip": self._cell_clip, 517 "initializer": initializers.serialize(self._initializer), 518 "num_proj": self._num_proj, 519 "proj_clip": self._proj_clip, 520 "num_unit_shards": self._num_unit_shards, 521 "num_proj_shards": self._num_proj_shards, 522 "forget_bias": self._forget_bias, 523 "state_is_tuple": self._state_is_tuple, 524 "activation": activations.serialize(self._activation), 525 "reuse": self._reuse, 526 } 527 base_config = super(TFLiteLSTMCell, self).get_config() 528 return dict(list(base_config.items()) + list(config.items())) 529