/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up items from a sparse tensor in an embedding matrix.
// The sparse lookup tensor is represented by three individual tensors: lookup,
// indices, and dense_shape. The representation assumes that the corresponding
// dense tensor would satisfy:
//   * dense.shape = dense_shape
//   * dense[tuple(indices[i])] = lookup[i]
//
// By convention, indices should be sorted row-major. The implementation below
// relies on this: it streams through the entries and finalizes an aggregation
// bucket as soon as the next entry belongs to a different bucket.
//
// Options:
//   combiner: The reduction op (SUM, MEAN, SQRTN).
//     * SUM computes the weighted sum of the embedding results.
//     * MEAN is the weighted sum divided by the total weight.
//     * SQRTN is the weighted sum divided by the square root of the sum of the
//       squares of the weights.
//
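// As a concrete illustration of the three combiners, with weights {2.0, 0.5}
// applied to embedding rows e1 and e2:
//   SUM:   2.0 * e1 + 0.5 * e2
//   MEAN:  (2.0 * e1 + 0.5 * e2) / (2.0 + 0.5)
//   SQRTN: (2.0 * e1 + 0.5 * e2) / sqrt(2.0 * 2.0 + 0.5 * 0.5)
//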
// Input:
//     Tensor[0]: Ids to lookup, dim.size == 1, int32.
//     Tensor[1]: Indices, int32.
//     Tensor[2]: Dense shape, int32.
//     Tensor[3]: Weights to use for aggregation, float.
//     Tensor[4]: Params, a matrix of multi-dimensional items,
//                dim.size >= 2, float.
//
// Output:
//   A (dense) tensor representing the combined embeddings for the sparse ids.
//   For each row in the sparse tensor represented by (lookup, indices, shape),
//   the op looks up the embeddings for all ids in that row, multiplies them by
//   the corresponding weights, and combines these weighted embeddings over the
//   last dimension using the specified combiner.
//
//   Output.dim = [l0, ... , ln-1, e1, ..., em]
//   where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em]
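//
//   For example, if dense_shape is [3, 4] and params has shape [10, 20], the
//   output has shape [3, 20]: the trailing sparse dimension (size 4) is
//   reduced away and replaced by the embedding dimension (size 20).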
//
//   For instance, if params is a 10x20 matrix and ids, weights are:
//
//   [0, 0]: id 1, weight 2.0
//   [0, 1]: id 3, weight 0.5
//   [1, 0]: id 0, weight 1.0
//   [2, 3]: id 1, weight 3.0
//
//   with combiner=MEAN, then the output will be a (3, 20) tensor where:
//
//   output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
//   output[1, :] = (params[0, :] * 1.0) / 1.0
//   output[2, :] = (params[1, :] * 3.0) / 3.0
//
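//   In terms of the input tensors above, this example corresponds to (taking
//   dense_shape = [3, 4], consistent with the indices shown):
//
//   lookup      = [1, 3, 0, 1]
//   indices     = [[0, 0], [0, 1], [1, 0], [2, 3]]
//   dense_shape = [3, 4]
//   weights     = [2.0, 0.5, 1.0, 3.0]
//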
//   When a lookup id is out of bounds of the params matrix, the op fails and
//   returns an error.

#include <stdint.h>

#include <algorithm>
#include <cmath>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {

namespace {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* ids;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
  TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
  TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);

  const TfLiteTensor* indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
  TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
  TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);

  const TfLiteTensor* shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape));
  TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
  TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);

  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
  TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);

  // Each sparse entry must come with one index row and one weight.
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
                    SizeOfDimension(ids, 0));
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
                    SizeOfDimension(weights, 0));

  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

  // The output shape depends on the runtime values of the dense_shape tensor,
  // so defer allocation by marking the output as a dynamic tensor.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
  output->allocation_type = kTfLiteDynamic;

  return kTfLiteOk;
}

void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
                         float current_total_weight,
                         float current_squares_weight, int embedding_size,
                         float* output) {
  // SUM needs no finalization. For MEAN and SQRTN, normalize the accumulated
  // weighted sum in `output`.
  if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
    float multiplier = 1.0;
    switch (combiner) {
      case kTfLiteCombinerTypeMean:
        // Divide by the total weight.
        multiplier = current_total_weight;
        break;
      case kTfLiteCombinerTypeSqrtn:
        // Divide by the square root of the sum of squared weights.
        multiplier = std::sqrt(current_squares_weight);
        break;
      default:
        break;
    }
    for (int k = 0; k < embedding_size; k++) {
      output[k] /= multiplier;
    }
  }
}
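
// For example, after accumulating params[1, :] * 2.0 + params[3, :] * 0.5 into
// an output bucket (as in the header example), FinalizeAggregation with MEAN
// divides each of the embedding_size entries by current_total_weight = 2.5.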

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  const TfLiteTensor* ids;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
  const TfLiteTensor* indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
  const TfLiteTensor* dense_shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape));
  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));

  const int lookup_rank = SizeOfDimension(indices, 1);
  const int embedding_rank = NumDimensions(value);
  const int num_lookups = SizeOfDimension(ids, 0);
  const int num_rows = SizeOfDimension(value, 0);

  // The last dimension gets replaced by the embedding.
  const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);

  // Make sure that the actual dense shape of the sparse tensor represented by
  // (lookup, indices, dense_shape) is consistent.
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);

  // Resize the output tensor: the leading (lookup_rank - 1) dimensions come
  // from dense_shape, and the trailing (embedding_rank - 1) dimensions come
  // from the params tensor.
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
  int k = 0;
  int embedding_size = 1;
  int lookup_size = 1;
  for (int i = 0; i < lookup_rank - 1; i++, k++) {
    const int dim = dense_shape->data.i32[i];
    lookup_size *= dim;
    output_shape->data[k] = dim;
  }
  for (int i = 1; i < embedding_rank; i++, k++) {
    const int dim = SizeOfDimension(value, i);
    embedding_size *= dim;
    output_shape->data[k] = dim;
  }
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
  const int output_size = lookup_size * embedding_size;
  TfLiteTensorRealloc(output_size * sizeof(float), output);

  float* output_ptr = GetTensorData<float>(output);
  const float* weights_ptr = GetTensorData<float>(weights);
  const float* value_ptr = GetTensorData<float>(value);

  std::fill_n(output_ptr, output_size, 0.0f);

  // Keep track of the current bucket for aggregation/combination.
  int current_output_offset = 0;
  float current_total_weight = 0.0;
  float current_squares_weight = 0.0;
  int num_elements = 0;

  for (int i = 0; i < num_lookups; i++) {
    int idx = ids->data.i32[i];
    if (idx >= num_rows || idx < 0) {
      context->ReportError(context,
                           "Embedding Lookup Sparse: index out of bounds. "
                           "Got %d, and bounds are [0, %d]",
                           idx, num_rows - 1);
      return kTfLiteError;
    }

    // Determine which output bucket this entry belongs to by flattening all
    // but the last sparse coordinate in row-major order. The last dimension
    // is the one reduced away by the combiner.
    const int example_indices_offset = i * lookup_rank;
    int output_bucket = 0;
    int stride = 1;
    for (int k = (lookup_rank - 1) - 1; k >= 0; k--) {
      output_bucket += indices->data.i32[example_indices_offset + k] * stride;
      stride *= dense_shape->data.i32[k];
    }
    const int output_offset = output_bucket * embedding_size;

    // If we are in a new aggregation bucket and the combiner is not the sum,
    // go back and finalize the result of the previous bucket.
    if (output_offset != current_output_offset) {
      FinalizeAggregation(params->combiner, num_elements, current_total_weight,
                          current_squares_weight, embedding_size,
                          &output_ptr[current_output_offset]);

      // Reset the accumulators and start tracking the next bucket.
      num_elements = 0;
      current_total_weight = 0.0;
      current_squares_weight = 0.0;
      current_output_offset = output_offset;
    }

    // Add element to aggregation.
    ++num_elements;
    const int example_embedding_offset = idx * embedding_size;
    const float w = weights_ptr[i];
    current_squares_weight += w * w;
    current_total_weight += w;
    for (int k = 0; k < embedding_size; k++) {
      output_ptr[current_output_offset + k] +=
          value_ptr[example_embedding_offset + k] * w;
    }
  }

  // Finalize last bucket.
  FinalizeAggregation(params->combiner, num_elements, current_total_weight,
                      current_squares_weight, embedding_size,
                      &GetTensorData<float>(output)[current_output_offset]);

  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
  return &r;
}
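
// A minimal usage sketch (assuming the standard TFLite C++ client API; not
// part of this file): models containing BuiltinOperator_EMBEDDING_LOOKUP_SPARSE
// resolve to the registration above via the builtin op resolver, e.g.:
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);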

}  // namespace builtin
}  // namespace ops
}  // namespace tflite