1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
17 
18 #include <string>
19 
20 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
21 
22 namespace tflite {
23 namespace gpu {
24 
Split(const OperationDef & definition,const SplitAttributes & attr)25 Split::Split(const OperationDef& definition, const SplitAttributes& attr)
26     : GPUOperation(definition), attr_(attr) {
27   work_group_size_ = int3(8, 4, 1);
28   code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode() : GetSplitCode();
29 }
30 
GetSplitCode()31 std::string Split::GetSplitCode() {
32   AddSrcTensor("src_tensor", definition_.src_tensors[0]);
33   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
34     AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
35   }
36   const std::string task_width =
37       attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
38   const std::string task_height =
39       attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
40   const std::string task_depth =
41       attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
42   const std::string task_batch =
43       attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
44   const std::string task_slices =
45       attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
46 
47   std::string c;
48   c += "MAIN_FUNCTION($0) {\n";
49   c += "  int task_width = "
50        ";\n";
51   if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
52     c += "  int linear_id = GLOBAL_ID_0;\n";
53     c += "  int X = linear_id / " + task_batch + ";\n";
54     c += "  int B = linear_id % " + task_batch + ";\n";
55   } else {
56     c += "  int X = GLOBAL_ID_0;\n";
57   }
58   if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
59     c += "  int linear_id = GLOBAL_ID_1;\n";
60     c += "  int Y = linear_id % " + task_height + ";\n";
61     c += "  int B = linear_id / " + task_height + ";\n";
62   } else {
63     c += "  int Y = GLOBAL_ID_1;\n";
64   }
65   c += "  int S = GLOBAL_ID_2;\n";
66   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
67        "S >= args.dst_tensor.Slices()) { \n";
68   c += "    return; \n";
69   c += "  } \n";
70   c += "  int src_counter = 0;\n";
71   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
72     const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
73     c += "  for (int i = 0; i < " + dst_name +
74          ".Slices(); ++i, src_counter++) {\n";
75     c += "    FLT4 result = args.src_tensor.Read(s_x, s_y, src_counter);\n";
76     c += "    " + dst_name + ".Write(result, X, Y, i);\n";
77     c += "  }\n";
78   }
79   c += "}\n";
80   return c;
81 }
82 
GetSplitChannelsCode()83 std::string Split::GetSplitChannelsCode() {
84   AddSrcTensor("src_tensor", definition_.src_tensors[0]);
85   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
86     AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
87   }
88 
89   const std::string batch_coord =
90       definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
91   std::string coords = "X, Y";
92   std::string c;
93   c += "MAIN_FUNCTION($0) {\n";
94   if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
95     c += "  int linear_id = GLOBAL_ID_0;\n";
96     c += "  int X = linear_id / args.src_tensor.Batch();\n";
97     c += "  int B = linear_id % args.src_tensor.Batch();\n";
98     c += "  if (X >= args.src_tensor.Width()) return;\n";
99   } else {
100     c += "  int X = GLOBAL_ID_0;\n";
101     c += "  if (X >= args.src_tensor.Width()) return;\n";
102   }
103   if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
104     c += "  int linear_id = GLOBAL_ID_1;\n";
105     c += "  int Y = linear_id % args.src_tensor.Height();\n";
106     c += "  int Z = linear_id / args.src_tensor.Height();\n";
107     c += "  if (Z >= args.src_tensor.Depth()) return;\n";
108     coords += ", Z";
109   } else {
110     c += "  int Y = GLOBAL_ID_1;\n";
111     c += "  if (Y >= args.src_tensor.Height()) return;\n";
112   }
113   c += "  int src_channel = 0;\n";
114   const std::string postfixes[] = {"x", "y", "z", "w"};
115   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
116     const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
117     c += "  for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
118     c += "    FLT4 result = INIT_FLT4(0.0f);\n";
119     for (int j = 0; j < 4; ++j) {
120       c += "    if (i * 4 + " + std::to_string(j) + " < " + dst_name +
121            ".Channels()) {\n";
122       c += "      int src_slice = src_channel >> 2;\n";
123       c += "      int src_sub_ch = src_channel & 3;\n";
124       c += "      FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
125            batch_coord + ");\n";
126       c += "      FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
127       c += "      result." + postfixes[j] + " = t_ar[src_sub_ch];\n";
128       c += "      src_channel++;\n";
129       c += "    }\n";
130     }
131     c += "    " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
132          ");\n";
133     c += "  }\n";
134   }
135   c += "}\n";
136   return c;
137 }
138 
GetGridSize() const139 int3 Split::GetGridSize() const {
140   const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
141   const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
142   const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
143   const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
144   const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
145   const int grid_x = width * batch;
146   const int grid_y = height * depth;
147   const int grid_z = slices;
148   return int3(grid_x, grid_y, grid_z);
149 }
150 
CreateSplit(const OperationDef & definition,const SplitAttributes & attr)151 Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
152   return Split(definition, attr);
153 }
154 
155 }  // namespace gpu
156 }  // namespace tflite
157