1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>

#include <algorithm>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace matrix_diag {
28 
// Position of the diagonal-values tensor among the node's inputs.
constexpr int kInputTensor{0};
// Position of the batched diagonal-matrix tensor among the node's outputs.
constexpr int kOutputTensor{0};
31 
Prepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
34   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35   const TfLiteTensor* input;
36   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
37   TfLiteIntArray* input_dims = input->dims;
38   int input_dims_size = input_dims->size;
39   TF_LITE_ENSURE(context, input_dims_size >= 1);
40 
41   TfLiteTensor* output;
42   TF_LITE_ENSURE_OK(context,
43                     GetOutputSafe(context, node, kOutputTensor, &output));
44   // Resize the output tensor.
45   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
46   for (int i = 0; i < input_dims_size; i++) {
47     output_shape->data[i] = input_dims->data[i];
48   }
49   // Last dimension in the output is the same as the last dimension in the
50   // input.
51   output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
52   output->type = input->type;
53   TF_LITE_ENSURE_OK(context,
54                     context->ResizeTensor(context, output, output_shape));
55 
56   return kTfLiteOk;
57 }
58 
// Fill the tensor to make a diagonal matrix in each batch. Each output matrix
// is zeroed, and then the next input value is written wherever the row index
// equals the column index.
// TODO(b/128636574) Move to reference_ops.
//
// Params:
//   in:         source values, consumed sequentially. Each batch reads
//               min(row_size, col_size) values.
//   out:        destination holding batch_size consecutive row-major
//               row_size x col_size matrices. It must not alias `in`.
//   batch_size: number of matrices to write.
//   row_size:   rows per output matrix.
//   col_size:   columns per output matrix.
template <typename T>
void FillDiagImpl(const T* in, T* out, const int batch_size, const int row_size,
                  const int col_size) {
  // 64-bit element counts and offsets, so row_size * col_size cannot overflow
  // int for large matrices.
  const int64_t matrix_size = static_cast<int64_t>(row_size) * col_size;
  const int diag_size = std::min(row_size, col_size);
  for (int b = 0; b < batch_size; ++b) {
    // Zero the whole matrix in one pass instead of testing i == j for every
    // element, then place the diagonal entries directly.
    std::fill_n(out, matrix_size, T(0));
    for (int d = 0; d < diag_size; ++d) {
      out[static_cast<int64_t>(d) * col_size + d] = *in++;
    }
    out += matrix_size;
  }
}
82 
83 template <typename T>
FillDiag(const TfLiteTensor * input,TfLiteTensor * output,const int batch_size,const int row_size,const int col_size)84 void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
85               const int batch_size, const int row_size, const int col_size) {
86   FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output), batch_size,
87                   row_size, col_size);
88 }
89 
90 // Fill a tensor with given input on the diagonal, zero elsewhere
FillDiagHelper(const TfLiteTensor * input,TfLiteTensor * output)91 void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
92   const int num_output_dims = output->dims->size;
93   int batch_size = 1;
94   for (int i = 0; i < num_output_dims - 2; ++i) {
95     batch_size *= output->dims->data[i];
96   }
97 
98   const int row_size = output->dims->data[num_output_dims - 2];
99   const int col_size = output->dims->data[num_output_dims - 1];
100   switch (output->type) {
101     case kTfLiteInt64: {
102       return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
103     }
104     case kTfLiteInt32: {
105       return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
106     }
107     case kTfLiteInt16: {
108       return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
109     }
110     case kTfLiteInt8: {
111       return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
112     }
113     case kTfLiteUInt8: {
114       return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
115     }
116     default:
117       return FillDiag<float>(input, output, batch_size, row_size, col_size);
118   }
119 }
120 
Eval(TfLiteContext * context,TfLiteNode * node)121 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
122   TfLiteTensor* output;
123   TF_LITE_ENSURE_OK(context,
124                     GetOutputSafe(context, node, kOutputTensor, &output));
125   const TfLiteTensor* input;
126   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
127   FillDiagHelper(input, output);
128   return kTfLiteOk;
129 }
130 
131 }  // namespace matrix_diag
132 
Register_MATRIX_DIAG()133 TfLiteRegistration* Register_MATRIX_DIAG() {
134   static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
135                                  matrix_diag::Eval};
136   return &r;
137 }
138 
139 }  // namespace builtin
140 }  // namespace ops
141 }  // namespace tflite
142