1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace matrix_diag {
28
// Positions of the node's only input tensor and only output tensor, as passed
// to GetInputSafe / GetOutputSafe.
constexpr int kInputTensor{0};
constexpr int kOutputTensor{0};
31
Prepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
34 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35 const TfLiteTensor* input;
36 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
37 TfLiteIntArray* input_dims = input->dims;
38 int input_dims_size = input_dims->size;
39 TF_LITE_ENSURE(context, input_dims_size >= 1);
40
41 TfLiteTensor* output;
42 TF_LITE_ENSURE_OK(context,
43 GetOutputSafe(context, node, kOutputTensor, &output));
44 // Resize the output tensor.
45 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
46 for (int i = 0; i < input_dims_size; i++) {
47 output_shape->data[i] = input_dims->data[i];
48 }
49 // Last dimension in the output is the same as the last dimension in the
50 // input.
51 output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
52 output->type = input->type;
53 TF_LITE_ENSURE_OK(context,
54 context->ResizeTensor(context, output, output_shape));
55
56 return kTfLiteOk;
57 }
58
// Fill the tensor to make a diagonal matrix in each batch, i.e., when
// row index and column index are the same, fill with the next input value.
// All other entries get zero.
// TODO(b/128636574) Move to reference_ops.
//
// `in` supplies batch_size * min(row_size, col_size) values, consumed in order.
// `out` receives batch_size row-major matrices of row_size x col_size, stored
// back to back. Offsets are computed in std::size_t so that large matrices do
// not overflow int arithmetic. Non-positive sizes write nothing.
template <typename T>
void FillDiagImpl(const T* in, T* out, const int batch_size, const int row_size,
                  const int col_size) {
  if (batch_size <= 0 || row_size <= 0 || col_size <= 0) return;
  const std::size_t rows = static_cast<std::size_t>(row_size);
  const std::size_t cols = static_cast<std::size_t>(col_size);
  const std::size_t matrix_size = rows * cols;
  const std::size_t diag_size = std::min(rows, cols);
  for (int b = 0; b < batch_size; ++b) {
    // Zero the whole matrix, then place the next input values on the
    // diagonal; this avoids testing i == j for every element.
    std::fill(out, out + matrix_size, T(0));
    for (std::size_t d = 0; d < diag_size; ++d) {
      out[d * cols + d] = *in++;
    }
    out += matrix_size;
  }
}
82
83 template <typename T>
FillDiag(const TfLiteTensor * input,TfLiteTensor * output,const int batch_size,const int row_size,const int col_size)84 void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
85 const int batch_size, const int row_size, const int col_size) {
86 FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output), batch_size,
87 row_size, col_size);
88 }
89
90 // Fill a tensor with given input on the diagonal, zero elsewhere
FillDiagHelper(const TfLiteTensor * input,TfLiteTensor * output)91 void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
92 const int num_output_dims = output->dims->size;
93 int batch_size = 1;
94 for (int i = 0; i < num_output_dims - 2; ++i) {
95 batch_size *= output->dims->data[i];
96 }
97
98 const int row_size = output->dims->data[num_output_dims - 2];
99 const int col_size = output->dims->data[num_output_dims - 1];
100 switch (output->type) {
101 case kTfLiteInt64: {
102 return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
103 }
104 case kTfLiteInt32: {
105 return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
106 }
107 case kTfLiteInt16: {
108 return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
109 }
110 case kTfLiteInt8: {
111 return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
112 }
113 case kTfLiteUInt8: {
114 return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
115 }
116 default:
117 return FillDiag<float>(input, output, batch_size, row_size, col_size);
118 }
119 }
120
Eval(TfLiteContext * context,TfLiteNode * node)121 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
122 TfLiteTensor* output;
123 TF_LITE_ENSURE_OK(context,
124 GetOutputSafe(context, node, kOutputTensor, &output));
125 const TfLiteTensor* input;
126 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
127 FillDiagHelper(input, output);
128 return kTfLiteOk;
129 }
130
131 } // namespace matrix_diag
132
Register_MATRIX_DIAG()133 TfLiteRegistration* Register_MATRIX_DIAG() {
134 static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
135 matrix_diag::Eval};
136 return &r;
137 }
138
139 } // namespace builtin
140 } // namespace ops
141 } // namespace tflite
142