/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
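// When true, device allocations are recorded and any memory still allocated
// at StreamExecutor destruction time is logged (see CreateAllocRecord and
// EraseAllocRecord below).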
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

std::string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

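// RAII helper that, when tracing is enabled on the owning StreamExecutor,
// notifies every registered TraceListener at construction (the "begin"
// callback) and again at destruction (the "complete" callback), tagging both
// with a shared correlation id.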
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

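// Constructs a ScopedTracer, deducing the template arguments from the call so
// call sites do not have to spell them out.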
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

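// Traces the enclosing call: invokes &LOC##Begin with the given arguments on
// entry, and &LOC##Complete with the result when the tracer goes out of scope.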
#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  std::string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (const auto &it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor &output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    ScratchAllocator *scratch_allocator,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, element_type, stream, input_descriptor, input_data,
      filter_descriptor, filter_data, output_descriptor, output_data,
      convolution_descriptor, scratch_allocator, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
    const blas::BlasLtMatmulPlanParams &params) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the blas implementation.");
  }
  return blas_support->CreateBlasLtMatmulPlan(params);
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                          size_t max_workspace_size,
                                          int max_algorithm_count) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the blas implementation.");
  }
  return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
                                                 max_algorithm_count);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

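// Allocates device memory, declining the request (and returning a null
// DeviceMemoryBase) when it would push the tracked allocation total past the
// TF_PER_DEVICE_MEMORY_LIMIT_MB limit, if one is set.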
DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
  if (memory_limit_bytes_ > 0 &&
      static_cast<int64>(mem_alloc_bytes_ + size) > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
  // be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const std::string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                     uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

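// Allocates a platform-specific stream and bumps the live stream count so that
// leaked streams can be reported at destruction; the count is rolled back if
// the platform fails to allocate the stream.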
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  auto desc_status = implementation_->CreateDeviceDescription();
  return desc_status.ConsumeValueOrDie();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

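// Forwards a trace call to every registered TraceListener while holding mu_
// for reading; a no-op unless tracing has been enabled.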
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&...args) {
  if (tracing_enabled_) {
    {
      // instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64 memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return port::Status::OK();
}

port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor *se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

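// Returns the stream for the given device ordinal, lazily creating and
// initializing it on first use and caching it for subsequent calls.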
port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor