1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ 16 #define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ 17 18 #include <iostream> 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/lite/toco/model.h" 23 #include "tensorflow/lite/toco/tooling_util.h" 24 25 namespace toco { 26 27 // For consistency with the parameters defined in extended LstmCell's kernel 28 // (tensorflow/lite/kernels/lstm.cc), 29 // use lowercase for these constants. 30 31 enum ExtendedLstmCellInputs { 32 kInputTensor = 0, 33 kInputToInputWeightsTensor = 1, // Optional 34 kInputToForgetWeightsTensor = 2, 35 kInputToCellWeightsTensor = 3, 36 kInputToOutputWeightsTensor = 4, 37 kRecurrentToInputWeightsTensor = 5, // Optional 38 kRecurrentToForgetWeightsTensor = 6, 39 kRecurrentToCellWeightsTensor = 7, 40 kRecurrentToOutputWeightsTensor = 8, 41 kCellToInputWeightsTensor = 9, // Optional 42 kCellToForgetWeightsTensor = 10, // Optional 43 kCellToOutputWeightsTensor = 11, // Optional 44 kInputGateBiasTensor = 12, // Optional 45 kForgetGateBiasTensor = 13, 46 kCellGateBiasTensor = 14, 47 kOutputGateBiasTensor = 15, 48 kProjectionWeightsTensor = 16, // Optional 49 kProjectionBiasTensor = 17, // Optional 50 kInputActivationStateTensor = 18, 51 // The op can handle 18 inputs or 20 inputs. 52 kInputCellStateTensor = 19, 53 kExtendedLstmInputCount = 20, 54 }; 55 56 enum ExtendedLstmCellOutputs { 57 // TODO(ycling): Make the 2 output state tensors optional. 58 kOutputStateTensor = 0, 59 kCellStateTensor = 1, 60 kOutputTensor = 2, 61 kExtendedLstmOutputCount = 3 62 }; 63 64 // Create optional array used for optional tensor in ExtendedLstmCell inputs. 65 void CreateOptionalArray(Model* model, string* input_array_buffer, 66 const string& array_name); 67 68 // Create float array and get its buffer. 69 Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model, 70 string* array_name, 71 const Shape& shape); 72 73 // Copy data from one array to the other one (supports 1D and 2D array), 74 // for 1D array, the 2nd dim's size is 1. 75 // Arguments: 76 // src_buffer: the source buffer 77 // src_stride: the stride of source buffer, i.e., 2nd dim's size 78 // src_start_idx1: the 1st dim index of start point in src matrix 79 // src_start_idx2: the 2nd dim index of start point in src matrix 80 // dst_buffer: the destination buffer 81 // dst_stride: the stride of destination buffer, i.e., 2nd dim's size 82 // dst_start_idx1: the 1st dim index of start point in dst matrix 83 // dst_start_idx2: the 2nd dim index of start point in dst matrix 84 // dim1_copy_size: 1st dim size of copy data 85 // dim2_copy_size: 2nd dim size of copy data 86 void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer, 87 int src_stride, int src_start_idx1, int src_start_idx2, 88 Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride, 89 int dst_start_idx1, int dst_start_idx2, int dim1_copy_size, 90 int dim2_copy_size); 91 92 // Copy a subset of array data and create a smaller array, 93 // mostly used for spliting weights and bias for Lstm cell. 94 void CopySubArrayToArray(Model* model, string* array_name, 95 const string& tensor_name, int dim1_size, 96 int dim2_size, const Array& original_array, 97 int start_idx1, int start_idx2); 98 99 // Copy array data to a large array's submatrix, 100 // mostly used for merging weights and bias for Lstm cell. 101 void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer, 102 int tensor_stride, const Array& sub_array, 103 int start_idx1, int start_idx2); 104 105 // Get mating rnn array inputs using rnn_states flag. 106 bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, 107 string* rnn_array); 108 109 } // namespace toco 110 111 #endif // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ 112