/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"

#include <string.h>

#include <cstdint>
#include <memory>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace broadcastto {

constexpr int kInputTensor = 0;
constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
constexpr int kMaxDims = 8;

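// Bundles the input, shape, and output tensors of a BroadcastTo node so they
// can be fetched once and shared between Prepare and Eval.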
struct BroadcastToContext {
  BroadcastToContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, kInputTensor);
    shape = GetInput(context, node, kShapeTensor);
    output = GetOutput(context, node, kOutputTensor);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* shape;
  TfLiteTensor* output;
};

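// Validates that the requested shape is broadcast-compatible with the input
// and resizes the output tensor to that shape.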
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                BroadcastToContext* op_context) {
  // Ensure the shape is a 1D tensor.
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1);

  // Ensure the output has at least as many dims as the input.
  int input_num_dims = NumDimensions(op_context->input);
  int output_num_dims = SizeOfDimension(op_context->shape, 0);
  TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
                     "Output shape must be broadcastable from input shape.");
  TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
                     "BroadcastTo only supports 1-8D tensor.");

  // Check if the output shape is broadcastable from the input shape.
  auto get_shape_data = [op_context](int i) -> int32_t {
    if (op_context->shape->type == kTfLiteInt32) {
      return GetTensorData<int32_t>(op_context->shape)[i];
    } else {
      return GetTensorData<int64_t>(op_context->shape)[i];
    }
  };

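  // Broadcasting aligns the input dims with the trailing output dims; each
  // input dim must either be 1 or match the corresponding output dim.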
  int extending_dims = output_num_dims - input_num_dims;
  for (int idx = 0; idx < input_num_dims; ++idx) {
    TF_LITE_ENSURE_MSG(context,
                       (SizeOfDimension(op_context->input, idx) == 1 ||
                        SizeOfDimension(op_context->input, idx) ==
                            get_shape_data(extending_dims + idx)),
                       "Output shape must be broadcastable from input shape.");
  }
  // Resize the output tensor to the requested shape.
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims);
  std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
      scoped_output_shape(output_shape, TfLiteIntArrayFree);
  for (int idx = 0; idx < output_num_dims; ++idx) {
    output_shape->data[idx] = get_shape_data(idx);
  }

  return context->ResizeTensor(context, op_context->output,
                               scoped_output_shape.release());
}

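// Checks tensor counts and types. If the shape tensor is constant, the output
// is resized here; otherwise it is marked dynamic and resized during Eval.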
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, NumInputs(node) == 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  TF_LITE_ENSURE_MSG(context,
                     (NumDimensions(GetInput(context, node, 0)) <= kMaxDims),
                     "BroadcastTo only supports 1-8D tensor.");

  BroadcastToContext op_context(context, node);
  TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 ||
                              op_context.shape->type == kTfLiteInt64);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

  // String type is not yet supported due to the use of fixed-size memcpy.
  TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString);

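  // With a constant shape tensor the output shape is known now; otherwise it
  // has to be resolved at eval time.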
  if (IsConstantTensor(op_context.shape)) {
    return ResizeOutputTensor(context, &op_context);
  }

  SetTensorToDynamic(op_context.output);
  return kTfLiteOk;
}

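// Resizes the output if it was deferred in Prepare, then broadcasts the input
// data into the output buffer using the reference kernel.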
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  BroadcastToContext op_context(context, node);
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

  // The BroadcastTo op supports up to 8 dims, matching TensorFlow's support.
  reference_ops::BroadcastTo<kMaxDims>(
      GetTensorShape(op_context.input), op_context.input->data.raw,
      GetTensorShape(op_context.output), op_context.output->data.raw,
      op_context.input->type);
  return kTfLiteOk;
}

}  // namespace broadcastto

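// Registration for the builtin BroadcastTo kernel; no per-node state is
// needed, so init and free are nullptr.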
TfLiteRegistration* Register_BROADCAST_TO() {
  static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare,
                                 broadcastto::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite