1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <vector>

#include "tensorflow/lite/delegates/hexagon/builders/op_builder.h"
18 
19 namespace tflite {
20 namespace delegates {
21 namespace hexagon {
22 
23 // Adds Rsqrt op to the Hexagon graph by constructing
24 // 1/Sqrt(input).
25 class RsqrtOpBuilder : public OpBuilder {
26  public:
RsqrtOpBuilder(GraphBuilder * graph_builder,int op_type)27   explicit RsqrtOpBuilder(GraphBuilder* graph_builder, int op_type)
28       : OpBuilder(graph_builder, op_type) {}
29   TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
30                                 const TfLiteIntArray* outputs,
31                                 TfLiteContext* context) override;
32 
33   TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
34                                TfLiteContext* context) override;
35 
36  private:
37   void AddNumerator();
38 
39   TensorID node_output_;
40   TensorID numerator_out_;
41   TensorID numerator_min_;
42   TensorID numerator_max_;
43   // Total number of elements in the input tensor.
44   int num_elements_;
45 };
46 
AddNumerator()47 void RsqrtOpBuilder::AddNumerator() {
48   // Numerator is a constant with value 1. We add it as float and quantize it.
49   std::vector<uint8_t> numerator;
50   // Hexagon NN Div implementation assumes output to be of shape as first
51   // input, so it doesn't broadcast.
52   // So here we create the constant numerator with value 1 to be of same
53   // flattened shape as the denominator.
54   numerator.resize(num_elements_);
55   int flat_shape[] = {1, 1, 1, num_elements_};
56   std::fill(numerator.begin(), numerator.end(), 0);
57   float kNumeratorMin = 1.0, kNumeratorMax = 1.0;
58   auto* const_node = graph_builder_->AddConstNodeWithData(
59       flat_shape, reinterpret_cast<char*>(numerator.data()),
60       sizeof(numerator[0]) * numerator.size());
61   auto* numerator_min_const = graph_builder_->AddConstNodeWithData(
62       kScalarShape, reinterpret_cast<char*>(&kNumeratorMin),
63       sizeof(kNumeratorMin));
64   auto* numerator_max_const = graph_builder_->AddConstNodeWithData(
65       kScalarShape, reinterpret_cast<char*>(&kNumeratorMax),
66       sizeof(kNumeratorMax));
67   numerator_out_ = TensorID(const_node->GetID(), 0);
68   numerator_min_ = TensorID(numerator_min_const->GetID(), 0);
69   numerator_max_ = TensorID(numerator_max_const->GetID(), 0);
70 }
71 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)72 TfLiteStatus RsqrtOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
73                                              TfLiteContext* context) {
74   graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
75                                   node_output_.second);
76   return kTfLiteOk;
77 }
78 
PopulateSubGraph(const TfLiteIntArray * inputs,const TfLiteIntArray * outputs,TfLiteContext * context)79 TfLiteStatus RsqrtOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
80                                               const TfLiteIntArray* outputs,
81                                               TfLiteContext* context) {
82   const int tensor_id = inputs->data[0];
83   const auto& tensor = context->tensors[tensor_id];
84   float min_value = 0;
85   float max_value = 0;
86   int batch_size, height_size, width_size, depth_size;
87   GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims);
88   TF_LITE_ENSURE_STATUS(
89       ComputeMinAndMaxQuantValues(tensor, &min_value, &max_value));
90   num_elements_ = batch_size * height_size * width_size * depth_size;
91   int flat_shape[] = {1, 1, 1, num_elements_};
92 
93   auto* min_const = graph_builder_->AddConstNodeWithData(
94       kScalarShape, reinterpret_cast<char*>(&min_value), sizeof(min_value));
95   auto* max_const = graph_builder_->AddConstNodeWithData(
96       kScalarShape, reinterpret_cast<char*>(&max_value), sizeof(max_value));
97   // Create SQRT op as denominator.
98   AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
99   AddInput(TensorID(min_const->GetID(), 0));
100   AddInput(TensorID(max_const->GetID(), 0));
101   auto sqrt_output = AddOutput(
102       sizeof(uint8_t), 4, {batch_size, height_size, width_size, depth_size});
103   auto sqrt_output_min = AddOutput(sizeof(float), 4, kScalarShape);
104   auto sqrt_output_max = AddOutput(sizeof(float), 4, kScalarShape);
105 
106   // Reshape result of Sqrt to be [1,1,1,NumElements] since Hexagon Div
107   // has limitation on the shape of the tensor.
108   const int reshape_shape[] = {1, 1, 1, 4};
109   auto* target_shape_node = graph_builder_->AddConstNodeWithData(
110       reshape_shape, reinterpret_cast<char*>(flat_shape),
111       sizeof(flat_shape[0]) * 4);
112   auto* reshape_op = graph_builder_->AddNode(GetTFLiteNodeID());
113   reshape_op->SetOpType(OP_Reshape);
114   reshape_op->AddInput(sqrt_output);
115   reshape_op->AddInput(TensorID(target_shape_node->GetID(), 0));
116   auto reshape_out = reshape_op->AddOutput(sizeof(uint8_t), 4, flat_shape);
117 
118   // Create the numerator and add to the graph.
119   AddNumerator();
120 
121   // Fetch output details
122   float output_min = -1, output_max = 1;
123   // Output details.
124   TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
125       context->tensors[outputs->data[0]], &output_min, &output_max));
126   auto* output_min_const = graph_builder_->AddConstNodeWithData(
127       kScalarShape, reinterpret_cast<char*>(&output_min), sizeof(output_min));
128   auto* output_max_const = graph_builder_->AddConstNodeWithData(
129       kScalarShape, reinterpret_cast<char*>(&output_max), sizeof(output_max));
130   int output_batch_size, output_height_size, output_width_size,
131       output_depth_size;
132   GetDims(&output_batch_size, &output_height_size, &output_width_size,
133           &output_depth_size, context->tensors[outputs->data[0]].dims);
134 
135   // Add Div op to compute 1/Sqrt
136   auto* div_op = graph_builder_->AddNode(GetTFLiteNodeID());
137   div_op->SetOpType(OP_QuantizedDiv_8);
138   div_op->AddInput(numerator_out_);
139   div_op->AddInput(reshape_out);
140   div_op->AddInput(numerator_min_);
141   div_op->AddInput(numerator_max_);
142   div_op->AddInput(sqrt_output_min);
143   div_op->AddInput(sqrt_output_max);
144   div_op->AddInput(TensorID(output_min_const->GetID(), 0));
145   div_op->AddInput(TensorID(output_max_const->GetID(), 0));
146 
147   auto div_output = div_op->AddOutput(sizeof(uint8_t), 4, flat_shape);
148   div_op->AddOutput(sizeof(float), 4, kScalarShape);
149   div_op->AddOutput(sizeof(float), 4, kScalarShape);
150 
151   // Reshape output back to the expected shape.
152   int output_shape[] = {output_batch_size, output_height_size,
153                         output_width_size, output_depth_size};
154   target_shape_node = graph_builder_->AddConstNodeWithData(
155       reshape_shape, reinterpret_cast<char*>(output_shape),
156       sizeof(output_shape[0]) * 4);
157 
158   reshape_op = graph_builder_->AddNode(GetTFLiteNodeID());
159   reshape_op->SetOpType(OP_Reshape);
160   reshape_op->AddInput(div_output);
161   reshape_op->AddInput(TensorID(target_shape_node->GetID(), 0));
162   node_output_ = reshape_op->AddOutput(sizeof(uint8_t), 4, output_shape);
163   return kTfLiteOk;
164 }
165 
CreateRSqrtOpBuilder(GraphBuilder * graph_builder,int op_type)166 OpBuilder* CreateRSqrtOpBuilder(GraphBuilder* graph_builder, int op_type) {
167   return new RsqrtOpBuilder(graph_builder, op_type);
168 }
169 
170 }  // namespace hexagon
171 }  // namespace delegates
172 }  // namespace tflite
173