/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace experimental {
namespace unidirectional_sequence_gru {
namespace {

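// Runs the GRU over the full sequence by applying gru_cell::GruCell to one
// timestep at a time. The weight shapes are consistent with the usual
// GRUBlockCell-style formulation, sketched here for orientation only (see
// gru_cell.h for the authoritative math):
//   [r, u] = sigmoid([x, h_prev] * gate_weight^T + gate_bias)
//   c      = tanh([x, r .* h_prev] * candidate_weight^T + candidate_bias)
//   h      = u .* h_prev + (1 - u) .* c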
void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
             const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
             const TfLiteTensor* candidate_weight,
             const TfLiteTensor* candidate_bias, TfLiteTensor* output,
             TfLiteTensor* output_state, TfLiteTensor* activation,
             TfLiteTensor* concat,
             tflite::CpuBackendContext* cpu_backend_context) {
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  const int n_output = output->dims->data[2];
  const int n_batch_input = n_batch * n_input;
  const int n_batch_output = n_batch * n_output;
  const RuntimeShape input_shape({n_batch, n_input});
  const float* input_data = GetTensorData<float>(input);
  const RuntimeShape state_shape = GetTensorShape(input_state);
  const float* input_state_data = GetTensorData<float>(input_state);
  const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
  const float* gate_weight_data = GetTensorData<float>(gate_weight);
  const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
  const float* gate_bias_data = GetTensorData<float>(gate_bias);
  const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
  const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
  const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
  const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
  const RuntimeShape activation_shape = GetTensorShape(activation);
  const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
  float* output_data = GetTensorData<float>(output);
  float* output_state_data = GetTensorData<float>(output_state);
  float* activation_data = GetTensorData<float>(activation);
  const RuntimeShape concat_shape = GetTensorShape(concat);
  float* concat_data = GetTensorData<float>(concat);
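  // The fully connected pieces of the cell run unclamped: the activation
  // bounds below span the entire float range.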
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();
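  // Step through the sequence. GruCell writes the updated state into
  // output_state_data, which is then read back as the input state of the next
  // step, while the output pointer advances by one [n_batch, n_output] slice.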
  for (int i = 0; i < n_time; ++i) {
    gru_cell::GruCell(
        input_shape, input_data, state_shape, input_state_data,
        gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
        candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
        candidate_bias_data, output_shape, output_data, output_state_data,
        activation_shape, activation_data, concat_shape, concat_data, fc_params,
        cpu_backend_context);
    input_data += n_batch_input;
    output_data += n_batch_output;
    input_state_data = output_state_data;
  }
}

}  // namespace

enum InputTensor {
  // Input tensor of size [n_time, n_batch, n_input]
  kInput = 0,
  // Input state tensor of size [n_batch, n_output]
  kInputState = 1,
  // Gate weight tensor of size [2*n_output, n_input+n_output]
  kGateWeight = 2,
  // Gate bias tensor of size [2*n_output]
  kGateBias = 3,
  // Candidate weight tensor of size [n_output, n_input+n_output]
  kCandidateWeight = 4,
  // Candidate bias tensor of size [n_output]
  kCandidateBias = 5,
  kInputNum = 6
};

enum OutputTensor {
  // Output tensor of size [n_time, n_batch, n_output]
  kOutput = 0,
  // Output state tensor of size [n_batch, n_output]
  kOutputState = 1,
  kOutputNum = 2
};

enum TemporaryTensor {
  // Scratch buffer for activation of size [n_batch, 2*n_output]
  kActivation = 0,
  // Scratch buffer for concatenation of size [n_batch, n_input+n_output]
  kConcat = 1,
  kTemporaryNum = 2
};

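// Allocates kTemporaryNum scratch tensor slots up front and returns the index
// of the first one as the node's user_data; Free releases it.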
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

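// Checks that every input matches the layout documented in the enums above,
// resizes the two outputs, and sets up the two arena-allocated scratch
// buffers (activation and concat).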
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

  // input's dim = [n_time, n_batch, n_input]
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  // input_state's dim = [n_batch, n_output]
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
  const int n_output = input_state->dims->data[1];

  // gate_weight's dim = [2 * n_output, n_input + n_output]
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

  // gate_bias's dim = [2 * n_output]
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

  // candidate_weight's dim = [n_output, n_input + n_output]
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
                    n_input + n_output);

  // candidate_bias's dim = [n_output]
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

  // output's dim = [n_time, n_batch, n_output]
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = n_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // output_state's dim = [n_batch, n_output]
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state,
                                     TfLiteIntArrayCopy(input_state->dims)));

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);

  // activation's dim = [n_batch, 2 * n_output]
  node->temporaries->data[kActivation] = *scratch_tensor_index;
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  activation->type = input->type;
  activation->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
  activation_size->data[0] = n_batch;
  activation_size->data[1] = 2 * n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, activation, activation_size));

  // concat's dim = [n_batch, n_input + n_output]
  node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  concat->type = input->type;
  concat->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
  concat_size->data[0] = n_batch;
  concat_size->data[1] = n_input + n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, concat, concat_size));

  return kTfLiteOk;
}

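// Gathers all I/O and scratch tensors and runs GruImpl over the sequence;
// only float32 weights are currently handled.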
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  auto cpu_backend_context = CpuBackendContext::GetFromContext(context);

  if (gate_weight->type == kTfLiteFloat32) {
    GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
            candidate_bias, output, output_state, activation, concat,
            cpu_backend_context);
  } else {
    context->ReportError(context,
                         "Unsupported combination of data types for GruCell");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_gru

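// Returns the registration for this kernel. As an experimental op it is
// typically wired up through a resolver's AddCustom() rather than the builtin
// op table; the op name used for that mapping is chosen by the caller.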
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
  static TfLiteRegistration r = {
      unidirectional_sequence_gru::Init, unidirectional_sequence_gru::Free,
      unidirectional_sequence_gru::Prepare, unidirectional_sequence_gru::Eval};
  return &r;
}

}  // namespace experimental
}  // namespace ops
}  // namespace tflite