1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"
16 
17 #include <string.h>
18 
19 #include <cstdint>
20 #include <memory>
21 
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace broadcastto {
30 
// Index of the tensor to broadcast within the node's inputs.
constexpr int kInputTensor{0};
// Index of the tensor holding the requested output shape within the node's
// inputs.
constexpr int kShapeTensor{1};
// Index of the broadcast result within the node's outputs.
constexpr int kOutputTensor{0};
// Largest output rank accepted by this kernel.
constexpr int kMaxDims{8};
35 
36 struct BroadcastToContext {
BroadcastToContexttflite::ops::builtin::broadcastto::BroadcastToContext37   BroadcastToContext(TfLiteContext* context, TfLiteNode* node) {
38     input = GetInput(context, node, kInputTensor);
39     shape = GetInput(context, node, kShapeTensor);
40     output = GetOutput(context, node, kOutputTensor);
41   }
42   const TfLiteTensor* input;
43   const TfLiteTensor* shape;
44   TfLiteTensor* output;
45 };
46 
ResizeOutputTensor(TfLiteContext * context,BroadcastToContext * op_context)47 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
48                                 BroadcastToContext* op_context) {
49   // Ensures the shape is 1D tensor.
50   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1);
51 
52   // Ensure output dims is not less than input dims.
53   int input_num_dims = NumDimensions(op_context->input);
54   int output_num_dims = SizeOfDimension(op_context->shape, 0);
55   TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
56                      "Output shape must be broadcastable from input shape.");
57   TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
58                      "BroadcastTo only supports 1-8D tensor.");
59 
60   // Check if output shape is broadcastable from input shape.
61   auto get_shape_data = [op_context](int i) -> int32_t {
62     if (op_context->shape->type == kTfLiteInt32) {
63       return GetTensorData<int32_t>(op_context->shape)[i];
64     } else {
65       return GetTensorData<int64_t>(op_context->shape)[i];
66     }
67   };
68 
69   int extending_dims = output_num_dims - input_num_dims;
70   for (int idx = 0; idx < input_num_dims; ++idx) {
71     TF_LITE_ENSURE_MSG(context,
72                        (SizeOfDimension(op_context->input, idx) == 1 ||
73                         SizeOfDimension(op_context->input, idx) ==
74                             get_shape_data(extending_dims + idx)),
75                        "Output shape must be broadcastable from input shape.");
76   }
77   // Resizing the shape of the output tensor.
78   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims);
79   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
80       scoped_output_shape(output_shape, TfLiteIntArrayFree);
81   for (int idx = 0; idx < output_num_dims; ++idx) {
82     output_shape->data[idx] = get_shape_data(idx);
83   }
84 
85   return context->ResizeTensor(context, op_context->output,
86                                scoped_output_shape.release());
87 }
88 
Prepare(TfLiteContext * context,TfLiteNode * node)89 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
90   TF_LITE_ENSURE(context, NumInputs(node) == 2);
91   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
92   TF_LITE_ENSURE_MSG(context,
93                      (NumDimensions(GetInput(context, node, 0)) <= kMaxDims),
94                      "BroadcastTo only supports 1-8D tensor.");
95 
96   BroadcastToContext op_context(context, node);
97   TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 ||
98                               op_context.shape->type == kTfLiteInt64);
99   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
100 
101   // Not yet support string type due to the use of memcopy with fixed size.
102   TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString);
103 
104   if (IsConstantTensor(op_context.shape)) {
105     return ResizeOutputTensor(context, &op_context);
106   }
107 
108   SetTensorToDynamic(op_context.output);
109   return kTfLiteOk;
110 }
111 
Eval(TfLiteContext * context,TfLiteNode * node)112 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
113   BroadcastToContext op_context(context, node);
114   if (IsDynamicTensor(op_context.output)) {
115     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
116   }
117 
118   // BroadcastTo op support upto 8 dims, matching the support of Tensorflow.
119   reference_ops::BroadcastTo<kMaxDims>(
120       GetTensorShape(op_context.input), op_context.input->data.raw,
121       GetTensorShape(op_context.output), op_context.output->data.raw,
122       op_context.input->type);
123   return kTfLiteOk;
124 }
125 
126 }  // namespace broadcastto
127 
Register_BROADCAST_TO()128 TfLiteRegistration* Register_BROADCAST_TO() {
129   static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare,
130                                  broadcastto::Eval};
131   return &r;
132 }
133 
134 }  // namespace builtin
135 }  // namespace ops
136 }  // namespace tflite
137