1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
17 
18 #include <Availability.h>
19 
20 #include <map>
21 #include <string>
22 #include <tuple>
23 
24 #include "absl/strings/match.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
27 #include "tensorflow/lite/delegates/gpu/common/shape.h"
28 #include "tensorflow/lite/delegates/gpu/common/status.h"
29 #include "tensorflow/lite/delegates/gpu/common/types.h"
30 #include "tensorflow/lite/delegates/gpu/common/util.h"
31 #include "tensorflow/lite/delegates/gpu/metal/common.h"
32 
33 namespace tflite {
34 namespace gpu {
35 namespace metal {
36 namespace {
// Returns the number of work groups to dispatch along each axis so that
// `grid_size` is covered by groups of `work_group_size` (rounded up).
//
// - grid_dimension == 1: only x is computed; y and z are 1 and the launch
//   order is not consulted.
// - grid_dimension == 2: x and y are computed, permuted through the first two
//   entries of `work_group_launch_order`; z is 1.
// - any other value is handled as 3D: all three axes are computed and
//   permuted through `work_group_launch_order`.
int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
                        const int3& work_group_size,
                        const int3& work_group_launch_order) {
  int3 result;
  switch (grid_dimension) {
    case 1: {
      result.x = DivideRoundUp(grid_size.x, work_group_size.x);
      result.y = 1;
      result.z = 1;
      break;
    }
    case 2: {
      // Only x and y are assigned here; z of `groups_per_axis` is never
      // written in this branch, so the first two launch-order entries are
      // expected to select index 0 or 1.
      int3 groups_per_axis;
      groups_per_axis.x = DivideRoundUp(grid_size.x, work_group_size.x);
      groups_per_axis.y = DivideRoundUp(grid_size.y, work_group_size.y);
      result.x = groups_per_axis[work_group_launch_order[0]];
      result.y = groups_per_axis[work_group_launch_order[1]];
      result.z = 1;
      break;
    }
    default: {  // treated as grid_dimension == 3
      int3 groups_per_axis;
      groups_per_axis.x = DivideRoundUp(grid_size.x, work_group_size.x);
      groups_per_axis.y = DivideRoundUp(grid_size.y, work_group_size.y);
      groups_per_axis.z = DivideRoundUp(grid_size.z, work_group_size.z);
      result.x = groups_per_axis[work_group_launch_order[0]];
      result.y = groups_per_axis[work_group_launch_order[1]];
      result.z = groups_per_axis[work_group_launch_order[2]];
      break;
    }
  }
  return result;
}
63 }  // namespace
64 
// Takes ownership of `operation` and stores it as the operation this task
// wraps. Any previously held operation is replaced.
void ComputeTask::Init(std::unique_ptr<GPUOperation>&& operation) {
  operation_ = std::move(operation);
}
68 
// Returns the definition of the wrapped operation. The reference points into
// the operation owned by this task.
const OperationDef& ComputeTask::GetDefinition() const {
  const GPUOperation& wrapped = *operation_;
  return wrapped.definition_;
}
72 
// Reports whether the wrapped operation is linkable, as answered by
// GPUOperation::IsLinkable().
bool ComputeTask::IsLinkable() const {
  const bool linkable = operation_->IsLinkable();
  return linkable;
}
74 
// Adds the operation wrapped by `task` to this task's operation via
// GPUOperation::AddOperation() and returns its status. `task` must be
// non-null; it is dereferenced unconditionally.
absl::Status ComputeTask::AddTask(ComputeTask* task) {
  GPUOperation* other_operation = task->operation_.get();
  return operation_->AddOperation(other_operation);
}
78 
// Assembles the operation's kernel source, initializes the Metal argument
// bindings for it, and compiles the resulting program.
//
// Returns the first error produced by argument initialization or by
// CompileProgram(); OK otherwise.
absl::Status ComputeTask::Compile(MetalDevice* device) {
  operation_->AssembleCode(device->GetInfo());

  // The elementwise code is registered under the name of the first
  // destination tensor.
  std::map<std::string, std::string> linkables;
  linkables.emplace(operation_->dst_tensors_names_[0],
                    operation_->elementwise_code_);

  const absl::Status args_status = metal_args_.Init(
      linkables, device, &operation_->args_, &operation_->code_);
  if (!args_status.ok()) {
    return args_status;
  }

  // Released only after metal_args_ has been initialized from args_.
  operation_->args_.ReleaseCPURepresentation();

  const CalculationsPrecision precision = operation_->definition_.precision;
  return CompileProgram(device, precision, operation_->code_);
}
91 
// Compiles `kernel_code` into a Metal compute pipeline state and stores it in
// `program_`.
//
// A dictionary of macro substitutions is built from the requested `precision`
// and the device's shading-language version, then handed to
// CreateComputeProgram() together with the source and the entry point name
// "ComputeFunction".
//
// Args:
//   device: supplies the MTLDevice and the language-version query.
//   precision: selects storage / accumulator types (float vs. half).
//   kernel_code: kernel source text; must be valid UTF-8.
//
// Returns:
//   OK on success; InternalError if the source cannot be converted to an
//   NSString or if no program is produced; otherwise the error from
//   CreateComputeProgram().
absl::Status ComputeTask::CompileProgram(MetalDevice* device,
                                         CalculationsPrecision precision,
                                         const std::string& kernel_code) {
  NSString* barrier;
  // simdgroup_barrier is supported since Metal shading language version 2.0
  if (device->IsLanguageVersion2orHigher()) {
    barrier = @"simdgroup_barrier";
  } else {
    barrier = @"threadgroup_barrier";
  }
  NSString* storageType;
  NSString* accumulatorType;
  NSString* toAccumulatorType4 = @"";
  if (precision == CalculationsPrecision::F32) {
    storageType = @"float";
    accumulatorType = @"float";
  } else {
    // FP16 storage; the accumulator stays float only for F32_F16.
    storageType = @"half";
    if (precision == CalculationsPrecision::F32_F16) {
      accumulatorType = @"float";
      toAccumulatorType4 = @"float4";
    } else {
      accumulatorType = @"half";
    }
  }
  // Keys and values are passed through to CreateComputeProgram() verbatim;
  // the embedded quotes on multi-argument macros are part of the key/value.
  NSDictionary<NSString*, NSString*>* macros = @{
    @"float16" : @"float4x4",
    @"half16" : @"half4x4",
    @"FLT16_0123(V)" : @"V[0]",
    @"FLT16_4567(V)" : @"V[1]",
    @"FLT16_89ab(V)" : @"V[2]",
    @"FLT16_cdef(V)" : @"V[3]",
    @"FLT" : storageType,
    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"ACCUM_FLT" : accumulatorType,
    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
    @"INIT_ACCUM_FLT4(value)" :
        [NSString stringWithFormat:@"%@4(value)", accumulatorType],
    @"TO_ACCUM_TYPE" : toAccumulatorType4,
    @"TO_ACCUM_FLT" : accumulatorType,
    @"TO_FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"SIMDGROUP_BARRIER" : barrier,
    @"SIMD_LOCAL_MEM_BARRIER" : barrier,
    @"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"",
    @"GLOBAL_ID_0" : @"static_cast<int>(reserved_gid.x)",
    @"GLOBAL_ID_1" : @"static_cast<int>(reserved_gid.y)",
    @"GLOBAL_ID_2" : @"static_cast<int>(reserved_gid.z)",
    @"LOCAL_ID_0" : @"static_cast<int>(reserved_lid.x)",
    @"LOCAL_ID_1" : @"static_cast<int>(reserved_lid.y)",
    @"LOCAL_ID_2" : @"static_cast<int>(reserved_lid.z)",
    @"GROUP_ID_0" : @"static_cast<int>(reserved_group_id.x)",
    @"GROUP_ID_1" : @"static_cast<int>(reserved_group_id.y)",
    @"GROUP_ID_2" : @"static_cast<int>(reserved_group_id.z)",
    @"GROUP_SIZE_0" : @"static_cast<int>(reserved_group_size.x)",
    @"GROUP_SIZE_1" : @"static_cast<int>(reserved_group_size.y)",
    @"GROUP_SIZE_2" : @"static_cast<int>(reserved_group_size.z)",
    @"SUB_GROUP_LOCAL_ID" : @"static_cast<int>(reserved_simd_id)",
    @"\"SUB_GROUP_BROADCAST(V, ID)\"" : @"\"simd_broadcast(V, ID)\"",
    @"__local" : @"threadgroup",
    @"__global" : @"device",
    @"__constant" : @"constant",
    @"LOCAL_MEM_BARRIER" : @"threadgroup_barrier(mem_flags::mem_threadgroup)",
    @"INIT_FLT(value)" : [NSString stringWithFormat:@"%@(value)", storageType],
    @"INIT_FLT4(value)" :
        [NSString stringWithFormat:@"%@4(value)", storageType],
    @"\"INIT_FLT4v4(v0, v1, v2, v3)\"" :
        [NSString stringWithFormat:@"\"%@4(v0, v1, v2, v3)\"", storageType],
    @"INIT_FLOAT(value)" : @"float(value)",
    @"INIT_FLOAT2(value)" : @"float2(value)",
    @"\"INIT_FLOAT2v2(v0, v1)\"" : @"\"float2(v0, v1)\"",
    @"INIT_FLOAT3(value)" : @"float3(value)",
    @"\"INIT_FLOAT3v3(v0, v1, v2)\"" : @"\"float3(v0, v1, v2)\"",
    @"INIT_FLOAT4(value)" : @"float4(value)",
    @"\"INIT_FLOAT4v4(v0, v1, v2, v3)\"" : @"\"float4(v0, v1, v2, v3)\"",
    @"INIT_INT(value)" : @"int(value)",
    @"\"INIT_INT2v2(v0, v1)\"" : @"\"int2(v0, v1)\"",
    @"\"INIT_INT4v4(v0, v1, v2, v3)\"" : @"\"int4(v0, v1, v2, v3)\"",
    @"CONVERT_TO_INT4(value)" : @"int4(value)",
  };

  // Use an explicit, platform-independent encoding: the default C-string
  // encoding depends on the system configuration. The conversion yields nil
  // when the bytes are not valid in the chosen encoding, so check before use.
  NSString* code = [NSString stringWithCString:kernel_code.c_str()
                                      encoding:NSUTF8StringEncoding];
  if (!code) {
    return absl::InternalError(
        "Kernel source is not valid UTF-8 and cannot be compiled");
  }
  id<MTLComputePipelineState> program = nil;
  RETURN_IF_ERROR(CreateComputeProgram(device->device(), code,
                                       @"ComputeFunction", macros, &program));
  if (!program) {
    return absl::InternalError("Unknown shader compilation error");
  }
  program_ = program;
  return absl::OkStatus();
}
189 
// Re-binds every source and destination tensor to the Metal arguments, lets
// the operation bind its remaining arguments, and then refreshes the grid
// size and work-group counts.
//
// Returns InvalidArgumentError if any bound tensor is not a
// MetalSpatialTensor, or the first error reported while binding.
absl::Status ComputeTask::UpdateParams() {
  // Binds tensors[k] under names[k] for every k in names; shared by the
  // source and destination passes below.
  const auto bind_all = [this](const auto& names,
                               const auto& tensors) -> absl::Status {
    for (size_t k = 0; k < names.size(); ++k) {
      const auto* metal_tensor =
          dynamic_cast<const MetalSpatialTensor*>(tensors[k]);
      if (metal_tensor == nullptr) {
        return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
      }
      RETURN_IF_ERROR(metal_args_.SetObjectRef(names[k], *metal_tensor));
    }
    return absl::OkStatus();
  };

  // Sources first, then destinations.
  RETURN_IF_ERROR(bind_all(operation_->src_tensors_names_, operation_->src_));
  RETURN_IF_ERROR(bind_all(operation_->dst_tensors_names_, operation_->dst_));

  RETURN_IF_ERROR(operation_->BindArguments(&metal_args_));

  // The grid size is recomputed before the work-group counts, which are
  // derived from it.
  operation_->grid_size_ = operation_->GetGridSize();
  operation_->work_groups_count_ = GetWorkGroupsCount(
      operation_->grid_dimension_, operation_->grid_size_,
      operation_->work_group_size_, operation_->work_group_launch_order_);
  return absl::OkStatus();
}
216 
// Encodes this task into `encoder`: sets the compiled pipeline state, encodes
// the bound arguments starting at offset 0, and dispatches the stored number
// of threadgroups with the stored threadgroup size.
void ComputeTask::Encode(id<MTLComputeCommandEncoder> encoder) {
  [encoder setComputePipelineState:program_];
  metal_args_.Encode(encoder, 0);

  const auto& counts = operation_->work_groups_count_;
  const auto& group = operation_->work_group_size_;
  const MTLSize threadgroups = MTLSizeMake(counts.x, counts.y, counts.z);
  const MTLSize threadsPerGroup = MTLSizeMake(group.x, group.y, group.z);
  [encoder dispatchThreadgroups:threadgroups
          threadsPerThreadgroup:threadsPerGroup];
}
229 
// Sets `tensor` as source number `index` of the operation and binds it to the
// Metal arguments under the matching source tensor name. `tensor` must be
// non-null; it is dereferenced unconditionally.
void ComputeTask::SetSrcTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetSrc(tensor, index);
  const auto& tensor_name = operation_->src_tensors_names_[index];
  // This method returns void, so the binding status cannot be reported to the
  // caller; it is intentionally dropped here.
  const auto binding_status = metal_args_.SetObjectRef(tensor_name, *tensor);
  (void)binding_status;
}
235 
// Sets `tensor` as destination number `index` of the operation and binds it
// to the Metal arguments under the matching destination tensor name. `tensor`
// must be non-null; it is dereferenced unconditionally.
void ComputeTask::SetDstTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetDst(tensor, index);
  const auto& tensor_name = operation_->dst_tensors_names_[index];
  // This method returns void, so the binding status cannot be reported to the
  // caller; it is intentionally dropped here.
  const auto binding_status = metal_args_.SetObjectRef(tensor_name, *tensor);
  (void)binding_status;
}
241 
// Picks a work-group size for the compiled kernel and recomputes the
// work-group counts for the current grid size.
//
// The operation is asked for candidate work-group sizes given `tuning_type`,
// the device info, and the pipeline's maximum threads per threadgroup; the
// first candidate is used.
//
// Returns NotFoundError when the operation offers no candidate.
absl::Status ComputeTask::Tune(TuningType tuning_type, MetalDevice* device) {
  KernelInfo kernel_info;
  kernel_info.max_work_group_size = [program_ maxTotalThreadsPerThreadgroup];
  kernel_info.private_memory_size = 0;

  std::vector<int3> candidates;
  operation_->GetPossibleKernelWorkGroups(tuning_type, device->GetInfo(),
                                          kernel_info, &candidates);
  if (candidates.empty()) {
    return absl::NotFoundError(
        "Can not found work_group size to launch kernel");
  }

  operation_->work_group_size_ = candidates.front();
  operation_->work_groups_count_ = GetWorkGroupsCount(
      operation_->grid_dimension_, operation_->grid_size_,
      operation_->work_group_size_, operation_->work_group_launch_order_);
  return absl::OkStatus();
}
259 
260 }  // namespace metal
261 }  // namespace gpu
262 }  // namespace tflite
263