1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
17 
18 #include <cstring>
19 #include <string>
20 
21 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23 
24 namespace tflite {
25 namespace gpu {
26 
ConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)27 ConverterToConvWeights::ConverterToConvWeights(
28     const OperationDef& definition, const WeightsDescription& weights_desc)
29     : GPUOperation(definition), weights_desc_(weights_desc) {
30   code_ = GetConverterToConvWeightsCode(definition_, weights_desc_);
31 }
32 
ConverterToConvWeights(ConverterToConvWeights && operation)33 ConverterToConvWeights::ConverterToConvWeights(
34     ConverterToConvWeights&& operation)
35     : GPUOperation(std::move(operation)),
36       weights_desc_(std::move(operation.weights_desc_)) {}
37 
operator =(ConverterToConvWeights && operation)38 ConverterToConvWeights& ConverterToConvWeights::operator=(
39     ConverterToConvWeights&& operation) {
40   if (this != &operation) {
41     weights_desc_ = std::move(operation.weights_desc_);
42     GPUOperation::operator=(std::move(operation));
43   }
44   return *this;
45 }
46 
GetConverterToConvWeightsCode(const OperationDef & op_def,const WeightsDescription & conv_weights_desc)47 std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
48     const OperationDef& op_def, const WeightsDescription& conv_weights_desc) {
49   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
50   args_.AddFloat("mask_x");
51   args_.AddFloat("mask_y");
52   args_.AddFloat("mask_z");
53   args_.AddFloat("mask_w");
54 
55   if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
56       conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
57     std::vector<int32_t> remap(conv_weights_desc.spatial_remap.size());
58     for (int i = 0; i < remap.size(); ++i) {
59       remap[i] = conv_weights_desc.spatial_remap[i];
60     }
61     BufferDescriptor desc;
62     desc.element_type = DataType::INT32;
63     desc.element_size = 1;
64     desc.memory_type = MemoryType::GLOBAL;
65     desc.size = remap.size() * sizeof(int32_t);
66     desc.data.resize(desc.size);
67     std::memcpy(desc.data.data(), remap.data(), desc.size);
68     args_.AddObject("spatial_remap",
69                     absl::make_unique<BufferDescriptor>(std::move(desc)));
70   }
71 
72   std::string c;
73   c += "MAIN_FUNCTION($0) {\n";
74   c += "  int O = GLOBAL_ID_0 * 4;\n";
75   c += "  int I = GLOBAL_ID_1;\n";
76   c += "  int Z = GLOBAL_ID_2;\n";
77   c += "  int W = Z % args.src_tensor.Width();\n";
78   c += "  int H = Z / args.src_tensor.Width();\n";
79   c += "  if (O >= args.src_tensor.Batch() || I >= args.src_tensor.Slices() || "
80        "H >= args.src_tensor.Height()) return;\n";
81   std::string x_kern = "W";
82   std::string y_kern = "H";
83   if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
84       conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
85     c += "  int spatial_linear = H * args.src_tensor.Width() + W;\n";
86     c += "  int linear_remap = args.spatial_remap.Read(spatial_linear);\n";
87     c += "  int w_remap = linear_remap % args.src_tensor.Width();\n";
88     c += "  int h_remap = linear_remap / args.src_tensor.Width();\n";
89     x_kern = "w_remap";
90     y_kern = "h_remap";
91   }
92   const std::string coords = x_kern + ", " + y_kern;
93   c += "  FLT4 v0 = args.src_tensor.Read(" + coords + ", I, O + 0);\n";
94   c += "  FLT4 v1 = INIT_FLT4(0.0f);\n";
95   c += "  FLT4 v2 = INIT_FLT4(0.0f);\n";
96   c += "  FLT4 v3 = INIT_FLT4(0.0f);\n";
97   c += "  if (O + 1 < args.src_tensor.Batch()) {\n";
98   c += "    v1 = args.src_tensor.Read(" + coords + ", I, O + 1);\n";
99   c += "  }\n";
100   c += "  if (O + 2 < args.src_tensor.Batch()) {\n";
101   c += "    v2 = args.src_tensor.Read(" + coords + ", I, O + 2);\n";
102   c += "  }\n";
103   c += "  if (O + 3 < args.src_tensor.Batch()) {\n";
104   c += "    v3 = args.src_tensor.Read(" + coords + ", I, O + 3);\n";
105   c += "  }\n";
106   c += "  if (I == args.src_tensor.Slices() - 1) {\n";
107   c += "    FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
108        "args.mask_w);\n";
109   c += "    v0 *= mask;\n";
110   c += "    v1 *= mask;\n";
111   c += "    v2 *= mask;\n";
112   c += "    v3 *= mask;\n";
113   c += "  }\n";
114   if (conv_weights_desc.IsI4O4()) {
115     c += "  FLT4 r0 = INIT_FLT4v4(v0.x, v1.x, v2.x, v3.x);\n";
116     c += "  FLT4 r1 = INIT_FLT4v4(v0.y, v1.y, v2.y, v3.y);\n";
117     c += "  FLT4 r2 = INIT_FLT4v4(v0.z, v1.z, v2.z, v3.z);\n";
118     c += "  FLT4 r3 = INIT_FLT4v4(v0.w, v1.w, v2.w, v3.w);\n";
119   } else if (conv_weights_desc.IsO4I4()) {
120     c += "  FLT4 r0 = v0;\n";
121     c += "  FLT4 r1 = v1;\n";
122     c += "  FLT4 r2 = v2;\n";
123     c += "  FLT4 r3 = v3;\n";
124   }
125   if (conv_weights_desc.layout == WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4 ||
126       conv_weights_desc.layout == WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4) {
127     // Writing to 4X Textures 2D
128     AddDstTensor("dst_tensor0", op_def.dst_tensors[0]);
129     AddDstTensor("dst_tensor1", op_def.dst_tensors[1]);
130     AddDstTensor("dst_tensor2", op_def.dst_tensors[2]);
131     AddDstTensor("dst_tensor3", op_def.dst_tensors[3]);
132     c += "  int yc = (H * args.src_tensor.Width() + W) * "
133          "args.src_tensor.Slices() + I\n;";
134     c += "  args.dst_tensor0.Write2D(r0, O / 4, yc)\n;";
135     c += "  args.dst_tensor1.Write2D(r1, O / 4, yc)\n;";
136     c += "  args.dst_tensor2.Write2D(r2, O / 4, yc)\n;";
137     c += "  args.dst_tensor3.Write2D(r3, O / 4, yc)\n;";
138     c += "}\n";
139   } else {
140     // Writing to linear buffer
141     AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
142     c += "  int GROUP_SIZE = " +
143          std::to_string(conv_weights_desc.GetOutputGroupSize()) + ";\n";
144     c += "  int d_index = O / (GROUP_SIZE * 4);\n";
145     c += "  int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
146     std::string index;
147     if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
148         conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
149       index =
150           "((d_index * args.src_tensor.Slices() + I) * "
151           "args.src_tensor.Height() "
152           "+ H) * args.src_tensor.Width() + W";
153     } else if (conv_weights_desc.layout == WeightsLayout::kOHWIOGroupI4O4 ||
154                conv_weights_desc.layout == WeightsLayout::kOHWIOGroupO4I4) {
155       index =
156           "((d_index * args.src_tensor.Height() + H) * args.src_tensor.Width() "
157           "+ "
158           "W) * args.src_tensor.Slices() + I";
159     }
160     c += "  int dst_offset = (" + index + ") * GROUP_SIZE + k_index;\n";
161     c += "  args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0)\n;";
162     c += "  args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1)\n;";
163     c += "  args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2)\n;";
164     c += "  args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3)\n;";
165     c += "}\n";
166   }
167   return c;
168 }
169 
BindArguments(ArgumentsBinder * args)170 absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
171   float4 mask = GetMaskForLastPlane(src_[0]->Channels());
172   RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
173   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
174   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
175   return args->SetFloat("mask_w", mask.w);
176 }
177 
GetGridSize() const178 int3 ConverterToConvWeights::GetGridSize() const {
179   const int out_group_size = weights_desc_.GetOutputGroupSize();
180   const int grid_x =
181       DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * out_group_size), 4);
182   const int grid_y = src_[0]->Slices();
183   const int grid_z = src_[0]->Width() * src_[0]->Height();
184   return int3(grid_x, grid_y, grid_z);
185 }
186 
CreateConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)187 ConverterToConvWeights CreateConverterToConvWeights(
188     const OperationDef& definition, const WeightsDescription& weights_desc) {
189   return ConverterToConvWeights(definition, weights_desc);
190 }
191 
192 }  // namespace gpu
193 }  // namespace tflite
194