/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"

#include <map>
#include <random>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/kernel_spec.h"

namespace xla {
namespace gpu {

using se::dnn::DataLayout;
using se::dnn::DataLayoutString;
using se::dnn::FilterLayout;
using se::dnn::FilterLayoutString;

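// Returns true if the GPU backing `stream_executor` is Volta (compute
// capability 7.0) or newer.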
bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
  int major, minor;
  CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
                                                                       &minor));
  return major >= 7;
}

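// Translates a StreamExecutor (input, filter, output) convolution layout
// triple into the corresponding XLA Layouts, using `dnums` to map each
// logical dimension to its physical position.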
StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
                                      DataLayout input, FilterLayout filter,
                                      DataLayout output) {
  std::vector<int64> input_layout;
  switch (input) {
    case DataLayout::kBatchDepthYX:
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      break;
    case DataLayout::kBatchYXDepth:
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(dnums.input_feature_dimension());
      break;
    default:
      return InternalError("Invalid input layout %s for conv with dnums %s",
                           DataLayoutString(input),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> filter_layout;
  switch (filter) {
    case FilterLayout::kOutputInputYX:
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      break;
    case FilterLayout::kOutputYXInput:
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      break;
    default:
      return InternalError("Invalid filter layout %s for conv with dnums %s",
                           FilterLayoutString(filter),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> output_layout;
  switch (output) {
    case DataLayout::kBatchDepthYX:
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      break;
    case DataLayout::kBatchYXDepth:
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      output_layout.push_back(dnums.output_feature_dimension());
      break;
    default:
      return InternalError("Invalid output layout %s for conv with dnums %s",
                           DataLayoutString(output),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
}

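// The inverse of StreamExecutorConvLayoutsToXlaLayouts: builds the candidate
// NCHW and NHWC XLA layouts for `dnums` and matches the actual layouts
// against them.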
StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
                                      const Layout& input, const Layout& filter,
                                      const Layout& output) {
  Layout nchw_input, nchw_filter, nchw_output;
  std::tie(nchw_input, nchw_filter, nchw_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
                                            FilterLayout::kOutputInputYX,
                                            DataLayout::kBatchDepthYX)
          .ConsumeValueOrDie();

  Layout nhwc_input, nhwc_filter, nhwc_output;
  std::tie(nhwc_input, nhwc_filter, nhwc_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
                                            FilterLayout::kOutputYXInput,
                                            DataLayout::kBatchYXDepth)
          .ConsumeValueOrDie();

  DataLayout input_layout;
  if (LayoutUtil::Equal(input, nchw_input)) {
    input_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(input, nhwc_input)) {
    input_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid input layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(input),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  FilterLayout filter_layout;
  if (LayoutUtil::Equal(filter, nchw_filter)) {
    filter_layout = FilterLayout::kOutputInputYX;
  } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
    filter_layout = FilterLayout::kOutputYXInput;
  } else {
    return InternalError("Invalid filter layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(filter),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  DataLayout output_layout;
  if (LayoutUtil::Equal(output, nchw_output)) {
    output_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(output, nhwc_output)) {
    output_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid output layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(output),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(input_layout, filter_layout, output_layout);
}

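// Returns a lock on a mutex that is unique per (platform, device-ordinal)
// pair, so at most one caller at a time holds the lock for any given GPU.
//
// Usage sketch (`stream_exec` is a hypothetical se::StreamExecutor*):
//   tensorflow::mutex_lock lock = LockGpu(stream_exec);
//   // ... code that needs exclusive access to the device ...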
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  // se::Platform*s are global singletons guaranteed to live forever.
  static auto* mutexes =
      new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
                   tensorflow::mutex>();

  tensorflow::mutex_lock global_lock(mu);
  auto it = mutexes
                ->emplace(std::piecewise_construct,
                          std::make_tuple(stream_exec->platform(),
                                          stream_exec->device_ordinal()),
                          std::make_tuple())
                .first;
  return tensorflow::mutex_lock{it->second};
}

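// Creates a kernel with the given name, loaded from `ptx` (and from
// `cubin_data`, when non-empty) on `stream_exec`.
//
// Usage sketch; the kernel name and argument count are hypothetical:
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<se::KernelBase> kernel,
//       CreateKernel("add_vectors", /*num_args=*/3, ptx, cubin, stream_exec));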
StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
    absl::string_view kernel_name, uint64 num_args, absl::string_view ptx,
    absl::Span<const uint8> cubin_data, se::StreamExecutor* stream_exec) {
  se::MultiKernelLoaderSpec loader_spec(num_args);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  auto kernel_base = absl::make_unique<se::KernelBase>(stream_exec);
  TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

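// Launches `kernel` on `stream` with the given launch dimensions, passing
// each buffer in `args` as a device-memory kernel argument.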
Status ExecuteKernelOnStream(const se::KernelBase& kernel,
                             absl::Span<const se::DeviceMemoryBase> args,
                             const LaunchDimensions& dims, se::Stream* stream) {
  static constexpr int kKernelArgsLimit = 1024;
  auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
  for (const se::DeviceMemoryBase& buf : args) {
    kernel_args->add_device_memory_argument(buf);
  }
  LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dims.block_counts();
  return stream->parent()->Launch(
      stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
      *kernel_args);
}

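// Builds GPU assembler options from the module's debug options. Extra
// assembler flags arrive as a single comma-separated string in
// --xla_gpu_asm_extra_flags and are split into individual flags here.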
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {
  string extra_string =
      hlo_module_config.debug_options().xla_gpu_asm_extra_flags();
  std::vector<std::string> extra_flags =
      absl::StrSplit(extra_string, ',', absl::SkipEmpty());
  return se::GpuAsmOpts(
      hlo_module_config.debug_options().xla_gpu_disable_gpuasm_optimizations(),
      hlo_module_config.debug_options().xla_gpu_cuda_data_dir(), extra_flags);
}

// Deleted overload: uniform random generation for integral types is not
// implemented yet.
template <typename T, typename Generator>
typename std::enable_if<std::is_integral<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) =
    delete;

// Returns a value drawn uniformly at random from [lhs, rhs).
template <typename T, typename Generator>
typename std::enable_if<std::is_floating_point<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) {
  return std::uniform_real_distribution<T>(lhs, rhs)(*gen);
}

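// Fills `buffer` by repeatedly copying from a lazily-initialized host buffer
// of pseudorandom values. `*rng_state` is the offset into the host buffer at
// which the next copy starts, so successive calls produce different data.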
template <typename T>
static void InitializeTypedBuffer(se::Stream* stream,
                                  se::DeviceMemoryBase buffer,
                                  int64* rng_state) {
  // Accesses to static variables are not locked, since the caller is already
  // in a critical section.
  static std::vector<T>* host_buffer = [] {
    // Use a large prime size so that repeated copies into device buffers of
    // different sizes do not line up periodically.
    auto* ret = new std::vector<T>(10069);
    // Default-seeded random numbers.
    std::mt19937 gen;
    for (auto& element : *ret) {
      // Only double draws random values in double precision; every other data
      // type draws random values in float and then casts to the target type.
      using RandomFloatingPointType =
          typename std::conditional<std::is_same<T, Eigen::half>::value, float,
                                    T>::type;
      using RandomType =
          typename std::conditional<std::is_integral<T>::value, float,
                                    RandomFloatingPointType>::type;
      // Scale down the values for fp16 to reduce the chance of overflow.
      auto upper_bound =
          RandomType(std::is_same<T, Eigen::half>::value ? 0.1 : 1.0);
      auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen);
      // For float and double, values lie in [0, 1).
      // For fp16, values lie in [0, 0.1).
      // For integer types, each element is either 0 or 1, to reduce the
      // chance of overflow, especially for int8.
      element = T(std::is_integral<T>::value ? rand_val + 0.5 : rand_val);
    }
    return ret;
  }();

  int64& host_index = *rng_state;

  char* current_addr = static_cast<char*>(buffer.opaque());
  CHECK_EQ(0, buffer.size() % sizeof(T));
  int64 elements_left = buffer.size() / sizeof(T);
  while (elements_left > 0) {
    CHECK_LE(host_index, host_buffer->size());
    // Wrap around to the start of the host buffer once it is exhausted.
    if (host_buffer->size() == host_index) {
      host_index = 0;
    }
    int64 elements_copied =
        std::min<int64>(host_buffer->size() - host_index, elements_left);
    se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T));
    stream->ThenMemcpy(&mem, host_buffer->data() + host_index,
                       elements_copied * sizeof(T));
    current_addr += elements_copied * sizeof(T);
    elements_left -= elements_copied;
    host_index += elements_copied;
  }
}

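// Initializes `buffer` with pseudorandom data appropriate for `buffer_type`.
// Complex buffers are filled as twice as many real values of the component
// type (C64 as floats, C128 as doubles).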
void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
                      int64* rng_state, se::DeviceMemoryBase buffer) {
  switch (buffer_type) {
    case xla::F16:
      return InitializeTypedBuffer<Eigen::half>(stream, buffer, rng_state);
    case xla::F32:
    case xla::C64:
      return InitializeTypedBuffer<float>(stream, buffer, rng_state);
    case xla::F64:
    case xla::C128:
      return InitializeTypedBuffer<double>(stream, buffer, rng_state);
    case xla::S8:
      return InitializeTypedBuffer<int8>(stream, buffer, rng_state);
    default:
      LOG(FATAL) << "Unexpected type: " << PrimitiveType_Name(buffer_type);
  }
}

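// Maps an XLA CudnnConvKind to the corresponding StreamExecutor DNN
// convolution kind; kinds without a direct equivalent return an error.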
StatusOr<se::dnn::ConvolutionKind> GetDNNConvKindFromCudnnConvKind(
    CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kBackwardFilter:
      return se::dnn::BACKWARD_FILTER;
    case CudnnConvKind::kBackwardInput:
      return se::dnn::BACKWARD_DATA;
    case CudnnConvKind::kForward:
      return se::dnn::FORWARD;
    default:
      break;
  }
  return InternalError("Unexpected convolution kind");
}

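// Maps an XLA primitive type to the corresponding StreamExecutor DNN data
// type; unsupported types return an error.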
StatusOr<se::dnn::DataType> GetDNNDataTypeFromPrimitiveType(
    PrimitiveType type) {
  switch (type) {
    case F16:
      return se::dnn::ToDataType<Eigen::half>::value;
    case F32:
      return se::dnn::ToDataType<float>::value;
    case F64:
      return se::dnn::ToDataType<double>::value;
    default:
      break;
  }
  return InternalError("Unsupported convolution datatype");
}

}  // namespace gpu
}  // namespace xla