1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/hexagon/builders/concat_builder.h"
16
17 #include <stdint.h>
18
19 #include <limits>
20
21 #include "tensorflow/lite/c/builtin_op_data.h"
22 #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24
25 namespace tflite {
26 namespace delegates {
27 namespace hexagon {
PopulateSubGraph(const TfLiteIntArray * inputs,const TfLiteIntArray * outputs,TfLiteContext * context)28 TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
29 const TfLiteIntArray* outputs,
30 TfLiteContext* context) {
31 const TfLiteConcatenationParams* concat_params =
32 reinterpret_cast<const TfLiteConcatenationParams*>(builtin_data_);
33 int concat_axis = concat_params->axis;
34 const int output_dim_size = context->tensors[outputs->data[0]].dims->size;
35 // Axis value is incremented if tensor dims are < 4 and/or axis < 0.
36 concat_axis =
37 concat_axis < 0 ? concat_axis + 4 : concat_axis + 4 - output_dim_size;
38 auto* axis_const = graph_builder_->AddConstNodeWithData(
39 kScalarShape, reinterpret_cast<char*>(&concat_axis), sizeof(concat_axis));
40 AddInput(TensorID(axis_const->GetID(), 0));
41
42 int tensor_id;
43
44 // Input data tensors.
45 // input_bound_minimum & input_bound_maximum track the minimum & maximum
46 // min/max bounds across all inputs.
47 float input_bound_minimum = std::numeric_limits<float>::max();
48 float input_bound_maximum = std::numeric_limits<float>::min();
49 input_minima_.reserve(inputs->size);
50 input_maxima_.reserve(inputs->size);
51 for (int i = 0; i < inputs->size; ++i) {
52 tensor_id = inputs->data[i];
53 float data_min, data_max;
54 const auto& data_tensor = context->tensors[tensor_id];
55 AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
56 TF_LITE_ENSURE_STATUS(
57 ComputeMinAndMaxQuantValues(data_tensor, &data_min, &data_max));
58 input_minima_.push_back(data_min);
59 input_maxima_.push_back(data_max);
60 if (data_min < input_bound_minimum) input_bound_minimum = data_min;
61 if (data_max > input_bound_maximum) input_bound_maximum = data_max;
62 }
63
64 // Minima tensors.
65 for (int i = 0; i < input_minima_.size(); ++i) {
66 auto* data_min_const = graph_builder_->AddConstNodeWithData(
67 kScalarShape, reinterpret_cast<char*>(&input_minima_[i]),
68 sizeof(input_minima_[i]));
69 AddInput(TensorID(data_min_const->GetID(), 0));
70 }
71
72 // Maxima tensors.
73 for (int i = 0; i < input_minima_.size(); ++i) {
74 auto* data_max_const = graph_builder_->AddConstNodeWithData(
75 kScalarShape, reinterpret_cast<char*>(&input_maxima_[i]),
76 sizeof(input_maxima_[i]));
77 AddInput(TensorID(data_max_const->GetID(), 0));
78 }
79
80 // Hexagon outputs for this node.
81 int output_batch_size, output_height_size, output_width_size,
82 output_depth_size;
83 GetDims(&output_batch_size, &output_height_size, &output_width_size,
84 &output_depth_size, context->tensors[outputs->data[0]].dims);
85
86 // We requantize the output from concat to the range expected by TFLite.
87 // Otherwise, we see accuracy issues for cases where the inputs have different
88 // min/max bounds.
89 TensorID concat_out = AddOutput(sizeof(uint8_t), 4,
90 {output_batch_size, output_height_size,
91 output_width_size, output_depth_size});
92 const auto& concat_out_min = AddOutput(sizeof(float), 4, {1, 1, 1, 1});
93 const auto& concat_out_max = AddOutput(sizeof(float), 4, {1, 1, 1, 1});
94
95 // Output min/max for requantization.
96 TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
97 context->tensors[outputs->data[0]], &output_min_, &output_max_));
98 auto* output_min_const = graph_builder_->AddConstNodeWithData(
99 kScalarShape, (char*)&output_min_, sizeof(output_min_));
100 auto* output_max_const = graph_builder_->AddConstNodeWithData(
101 kScalarShape, (char*)&output_max_, sizeof(output_max_));
102
103 if (output_min_ == input_bound_minimum &&
104 output_max_ == input_bound_maximum) {
105 // If the input min/max (across all tensors) is same as the output min/max,
106 // Hexagon's Requantize causes errors in InceptionV3.
107 // TODO(b/150137234): Figure out why this is.
108 node_output_ = concat_out;
109 } else {
110 auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID());
111 requantize_op->SetOpType(OP_Requantize_8to8);
112 requantize_op->AddInput(concat_out);
113 requantize_op->AddInput(concat_out_min);
114 requantize_op->AddInput(concat_out_max);
115 requantize_op->AddInput(TensorID(output_min_const->GetID(), 0));
116 requantize_op->AddInput(TensorID(output_max_const->GetID(), 0));
117 node_output_ =
118 requantize_op->AddOutput(sizeof(uint8_t), 4,
119 {output_batch_size, output_height_size,
120 output_width_size, output_depth_size});
121 requantize_op->AddOutput(sizeof(float), 4, kScalarShape);
122 requantize_op->AddOutput(sizeof(float), 4, kScalarShape);
123 }
124
125 return kTfLiteOk;
126 }
127
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)128 TfLiteStatus ConcatOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
129 TfLiteContext* context) {
130 // Should be only 1 output.
131 graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
132 node_output_.second);
133 return kTfLiteOk;
134 }
135
~ConcatOpBuilder()136 ConcatOpBuilder::~ConcatOpBuilder() {}
137
CreateConcatBuilder(GraphBuilder * graph_builder,int op_type)138 OpBuilder* CreateConcatBuilder(GraphBuilder* graph_builder, int op_type) {
139 return new ConcatOpBuilder(graph_builder, op_type);
140 }
141
142 } // namespace hexagon
143 } // namespace delegates
144 } // namespace tflite
145