/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stddef.h>

#include <cstring>
#include <memory>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace if_kernel {
struct OpData {
  int then_subgraph_index;
  int else_subgraph_index;
};

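// Parses the builtin `TfLiteIfParams` out of `buffer` and stashes the
// `then` / `else` branch subgraph indices in `OpData` for `Prepare` / `Eval`.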
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const auto* params = reinterpret_cast<const TfLiteIfParams*>(buffer);
  op_data->then_subgraph_index = params->then_subgraph_index;
  op_data->else_subgraph_index = params->else_subgraph_index;
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE(context, node->inputs->size > 0);

  // The first input is the condition.
  const TfLiteTensor* cond;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
  // Currently only bool is supported.
  // TODO(ycling): Support other types, since TensorFlow also supports
  // non-bool condition types.
  TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool);
  TF_LITE_ENSURE_EQ(context, NumElements(cond), 1);

  // The first input of the node is the condition. The remaining inputs are
  // passed to the branch subgraphs, so the number of subgraph inputs is the
  // number of node inputs minus one.
  int num_inputs = node->inputs->size - 1;
  int num_outputs = node->outputs->size;

  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->then_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->else_subgraph_index < subgraphs->size());

  Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get();
  Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get();

  for (auto* subgraph : {then_subgraph, else_subgraph}) {
    TF_LITE_ENSURE_EQ(context, num_inputs, subgraph->inputs().size());
    TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size());
  }

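  // Propagate the node's input shapes and types into both branch subgraphs
  // and allocate them, so that their output shapes become known. If either
  // branch contains dynamic tensors, the IF op's outputs must be dynamic too.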
  bool has_dynamic_output_tensors = false;
  for (auto* subgraph : {then_subgraph, else_subgraph}) {
    for (int i = 0; i < num_inputs; ++i) {
      // The first input of the node is the condition. The indices of the
      // inputs passed to the subgraphs are therefore offset by 1.
      const TfLiteTensor* input;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
      std::vector<int> dims(input->dims->data,
                            input->dims->data + input->dims->size);
      TF_LITE_ENSURE_OK(context, subgraph->ResizeInputTensor(i, dims));
      TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type);
    }
    // Note: `Prepare` is responsible for running `AllocateTensors` on both
    // subgraphs. We intentionally don't break out of the loop after finding a
    // dynamic output tensor, so that both subgraphs still get allocated.
    TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors());
    has_dynamic_output_tensors |= subgraph->HasDynamicTensors();
  }

  if (!has_dynamic_output_tensors) {
    for (int i = 0; i < num_outputs; ++i) {
      TfLiteTensor* then_output =
          then_subgraph->tensor(then_subgraph->outputs()[i]);
      TfLiteTensor* else_output =
          else_subgraph->tensor(else_subgraph->outputs()[i]);
      // If the two subgraphs have static but different output shapes, the
      // output tensors of the IF op have dynamic sizes.
      if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) {
        has_dynamic_output_tensors = true;
        break;
      }
    }
  }

  for (int i = 0; i < num_outputs; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      // When there are no dynamic output tensors, the two subgraphs have
      // outputs with identical static shapes.
      TfLiteTensor* then_output =
          then_subgraph->tensor(then_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }

  return kTfLiteOk;
}

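// Copies the node inputs into the branch subgraph selected by the condition
// value, invokes that subgraph, and copies its outputs back to the node
// outputs.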
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* cond;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
  bool cond_value = cond->data.b[0];

  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();

  // Currently we copy the inputs / outputs between the subgraphs. This isn't
  // optimized yet.
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  int active_branch_subgraph_index =
      cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index;
  Subgraph& active_branch_subgraph =
      *(*subgraphs)[active_branch_subgraph_index];
  for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
    const TfLiteTensor* input;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
    TfLiteTensor* subgraph_input =
        active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);

    if (IsDynamicTensor(subgraph_input)) {
      TF_LITE_ENSURE_OK(context,
                        TfLiteTensorRealloc(input->bytes, subgraph_input));
    }

    TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
    memcpy(subgraph_input->data.raw, input->data.raw, input->bytes);
  }

  // Note: `AllocateTensors` is guaranteed to have been called on the
  // subgraphs in `Prepare`, so we don't need to call it again here.
  TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke());

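  // Make sure the branch's outputs are readable before copying them out
  // (e.g. a delegate may keep tensor data in its own buffers).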
  for (int tensor_index : active_branch_subgraph.outputs()) {
    active_branch_subgraph.EnsureTensorDataIsReadable(tensor_index);
  }

  bool has_dynamic_output_tensors = false;
  for (int i = 0; i < node->outputs->size; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (IsDynamicTensor(output)) {
      has_dynamic_output_tensors = true;
      break;
    }
  }

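  // For dynamic outputs, the final shapes are only known after invoking the
  // branch, so resize the node outputs to match the active branch's outputs.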
  if (has_dynamic_output_tensors) {
    for (int i = 0; i < node->outputs->size; ++i) {
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
      TfLiteTensor* subgraph_output =
          active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }

  for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
    const TfLiteTensor* subgraph_output =
        active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));

    if (IsDynamicTensor(output)) {
      TF_LITE_ENSURE_OK(context,
                        TfLiteTensorRealloc(subgraph_output->bytes, output));
    }

    TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
    memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
  }
  return kTfLiteOk;
}

}  // namespace if_kernel

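// Note: a builtin op resolver typically wires this kernel up along the lines
// of `AddBuiltin(BuiltinOperator_IF, Register_IF())` (a sketch; see
// register.cc for the actual registration).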
TfLiteRegistration* Register_IF() {
  static TfLiteRegistration r = {if_kernel::Init, if_kernel::Free,
                                 if_kernel::Prepare, if_kernel::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite