1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
20 #include "tensorflow/lite/delegates/gpu/cl/util.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 
23 namespace tflite {
24 namespace gpu {
25 namespace cl {
26 namespace {
27 
GetKernelMaxWorkGroupSize(cl_kernel kernel,cl_device_id device_id,int * result)28 absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id,
29                                        int* result) {
30   size_t max_work_group_size;
31   cl_int error_code =
32       clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE,
33                                sizeof(size_t), &max_work_group_size, nullptr);
34   if (error_code != CL_SUCCESS) {
35     return absl::UnknownError(
36         absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ",
37                      CLErrorCodeToString(error_code)));
38   }
39   *result = static_cast<int>(max_work_group_size);
40   return absl::OkStatus();
41 }
42 
GetKernelPrivateMemorySize(cl_kernel kernel,cl_device_id device_id,int * result)43 absl::Status GetKernelPrivateMemorySize(cl_kernel kernel,
44                                         cl_device_id device_id, int* result) {
45   cl_ulong private_mem_size;
46   cl_int error_code =
47       clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE,
48                                sizeof(cl_ulong), &private_mem_size, nullptr);
49   if (error_code != CL_SUCCESS) {
50     return absl::UnknownError(
51         absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ",
52                      CLErrorCodeToString(error_code)));
53   }
54   *result = static_cast<int>(private_mem_size);
55   return absl::OkStatus();
56 }
57 
58 }  // namespace
59 
CLKernel(CLKernel && kernel)60 CLKernel::CLKernel(CLKernel&& kernel)
61     : info_(kernel.info_),
62       binding_counter_(kernel.binding_counter_),
63       function_name_(std::move(kernel.function_name_)),
64       program_(kernel.program_),
65       kernel_(kernel.kernel_) {
66   kernel.kernel_ = nullptr;
67 }
68 
operator =(CLKernel && kernel)69 CLKernel& CLKernel::operator=(CLKernel&& kernel) {
70   if (this != &kernel) {
71     Release();
72     std::swap(info_, kernel.info_);
73     std::swap(binding_counter_, kernel.binding_counter_);
74     function_name_ = std::move(kernel.function_name_);
75     std::swap(program_, kernel.program_);
76     std::swap(kernel_, kernel.kernel_);
77   }
78   return *this;
79 }
80 
~CLKernel()81 CLKernel::~CLKernel() { Release(); }
82 
ReInit() const83 absl::Status CLKernel::ReInit() const {
84   clReleaseKernel(kernel_);
85   cl_kernel* kern_ptr = const_cast<cl_kernel*>(&kernel_);
86   int error_code;
87   *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code);
88   if (!kernel_ || error_code != CL_SUCCESS) {
89     *kern_ptr = nullptr;
90     return absl::UnknownError(absl::StrCat("Failed to create ", function_name_,
91                                            CLErrorCodeToString(error_code)));
92   }
93   return absl::OkStatus();
94 }
95 
Release()96 void CLKernel::Release() {
97   if (kernel_) {
98     clReleaseKernel(kernel_);
99     clReleaseProgram(program_);
100     kernel_ = nullptr;
101   }
102 }
103 
CreateFromProgram(const CLProgram & program,const std::string & function_name)104 absl::Status CLKernel::CreateFromProgram(const CLProgram& program,
105                                          const std::string& function_name) {
106   int error_code;
107   function_name_ = function_name;
108   kernel_ =
109       clCreateKernel(program.program(), function_name.c_str(), &error_code);
110   if (!kernel_ || error_code != CL_SUCCESS) {
111     kernel_ = nullptr;
112     return absl::UnknownError(absl::StrCat("Failed to create ", function_name,
113                                            CLErrorCodeToString(error_code)));
114   }
115 
116   program_ = program.program();
117   clRetainProgram(program_);
118 
119   RETURN_IF_ERROR(GetKernelPrivateMemorySize(kernel_, program.GetDeviceId(),
120                                              &info_.private_memory_size));
121   RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(),
122                                             &info_.max_work_group_size));
123   return absl::OkStatus();
124 }
125 
SetMemory(int index,cl_mem memory)126 absl::Status CLKernel::SetMemory(int index, cl_mem memory) {
127   return SetBytes(index, &memory, sizeof(cl_mem));
128 }
129 
SetMemoryAuto(cl_mem memory)130 absl::Status CLKernel::SetMemoryAuto(cl_mem memory) {
131   return SetBytesAuto(&memory, sizeof(cl_mem));
132 }
133 
SetBytes(int index,const void * ptr,int length) const134 absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const {
135   const int error_code = clSetKernelArg(kernel_, index, length, ptr);
136   if (error_code != CL_SUCCESS) {
137     return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
138                                            CLErrorCodeToString(error_code)));
139   }
140   return absl::OkStatus();
141 }
142 
SetBytesAuto(const void * ptr,int length)143 absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) {
144   const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr);
145   if (error_code != CL_SUCCESS) {
146     return absl::UnknownError(absl::StrCat(
147         "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
148         "(at index - ", binding_counter_, ")"));
149   }
150   binding_counter_++;
151   return absl::OkStatus();
152 }
153 
154 }  // namespace cl
155 }  // namespace gpu
156 }  // namespace tflite
157