/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  std::string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation,
      int device_ordinal);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(DeviceOptions device_options);

  // Returns the kind of the platform that this StreamExecutor is acting upon.
  ABSL_DEPRECATED("Use platform() instead.")
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    stream_executor::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given Kernel
  //    instantiation should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, an error status is returned.
  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Loads a module for the platform this StreamExecutor is acting upon.
  //
  // `spec` describes the module to be loaded.  On success, writes the handle
  // for the loaded module to `module_handle` and returns Status::OK.
  // Otherwise, returns the error that occurred.
  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
                          ModuleHandle *module_handle);

  // Unloads the module with handle `module_handle`.
  bool UnloadModule(ModuleHandle module_handle);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count, int64 memory_space = 0);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }
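
  // A minimal usage sketch for the allocation helpers above. This is
  // illustrative only; the `executor` pointer and the element count are
  // hypothetical.
  //
  //   DeviceMemory<float> raw = executor->AllocateArray<float>(1024);
  //   // ... use raw, then release it explicitly:
  //   executor->Deallocate(&raw);
  //
  //   // Or let ScopedDeviceMemory release the allocation on scope exit:
  //   ScopedDeviceMemory<float> owned =
  //       executor->AllocateOwnedArray<float>(1024);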

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }

  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> GetSubBuffer(DeviceMemory<T> *parent, uint64 element_offset,
                               uint64 element_count);
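
  // Illustrative sub-buffer sketch (the parent size and offsets are
  // hypothetical); per the warning above, the parent buffer should outlive
  // the sub-buffer.
  //
  //   DeviceMemory<float> parent = executor->AllocateArray<float>(1024);
  //   DeviceMemory<float> tail = executor->GetSubBuffer(
  //       &parent, /*element_offset=*/512, /*element_count=*/512);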

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched for in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The user must ensure that the
  // type of the symbol and T match.
  // - Note: symbol_name should include its namespace as well. For example,
  //         pass "nms0::symbol" if referring to nms0::symbol.
  //
  // If `module_handle` is set then searches only within the module
  // corresponding to `module_handle`.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name,
                                            ModuleHandle module_handle = {});
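
  // Illustrative symbol-lookup sketch (the symbol name and module handle are
  // hypothetical; error handling is reduced to a status check):
  //
  //   port::StatusOr<DeviceMemory<int>> sym =
  //       executor->GetSymbol<int>("nms0::symbol", module_handle);
  //   if (sym.ok()) {
  //     DeviceMemory<int> mem = sym.ValueOrDie();
  //     // mem now refers to the device storage backing nms0::symbol.
  //   }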

  // An untyped version of GetSymbol.
  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
      const std::string &symbol_name, ModuleHandle module_handle = {});

  // Deallocates the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  void *UnifiedMemoryAllocate(uint64 bytes);

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  void UnifiedMemoryDeallocate(void *location);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 size);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
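
  // Sketch of registering externally allocated host memory for asynchronous
  // memcpy use (the staging buffer is hypothetical; register/unregister calls
  // must be paired, and the region must be unregistered before it is freed):
  //
  //   std::vector<char> staging(4096);
  //   if (executor->HostMemoryRegister(staging.data(), staging.size())) {
  //     // staging.data() may now be used with Stream::ThenMemcpy.
  //     CHECK(executor->HostMemoryUnregister(staging.data()));
  //   }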

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  port::Status SynchronousMemZero(DeviceMemoryBase *location,
                                  uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
                                 uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  ABSL_DEPRECATED(
      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64 size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the device
  // source size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
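
  // Round-trip sketch for the synchronous copy helpers above (the host
  // vectors and sizes are hypothetical; the returned statuses should be
  // checked by the caller):
  //
  //   std::vector<float> host_in(64, 1.0f), host_out(64);
  //   DeviceMemory<float> dev = executor->AllocateArray<float>(64);
  //   TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D(
  //       host_in.data(), host_in.size() * sizeof(float), &dev));
  //   TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H(
  //       dev, host_out.size() * sizeof(float), host_out.data()));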

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
                       uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be a 32-bit quantity
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
                        uint32 pattern, uint64 size);
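
  // Enqueue sketch for Memset32 (the stream and buffer are hypothetical; the
  // size passed must be a multiple of 4 bytes):
  //
  //   TF_RETURN_IF_ERROR(executor->Memset32(&stream, &dev_buffer,
  //                                         /*pattern=*/0, dev_buffer.size()));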

  // Enables peer access from this StreamExecutor to memory allocated by other,
  // such that launched device code, memcpies, etc. may access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);
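
  // Peer-access sketch (executor_a and executor_b are hypothetical and must
  // be backed by the same platform; EnablePeerAccessTo may still fail even
  // when CanEnablePeerAccessTo returns true):
  //
  //   if (executor_a->CanEnablePeerAccessTo(executor_b)) {
  //     TF_RETURN_IF_ERROR(executor_a->EnablePeerAccessTo(executor_b));
  //   }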

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns a device-specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Returns the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the list of supported MIOpen algorithms (with profiling results)
  // for the specified convolution operation.
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
      const dnn::BatchDescriptor &input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      ScratchAllocator *scratch_allocator,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for the RNN operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on data.
  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates a backend-specific plan object for a blaslt matmul operation, which
  // can then be passed to DoBlasLtMatmul(). When possible, plans should be
  // created once and reused for multiple calls to DoBlasLtMatmul().
  // Returns a null pointer on failure.
  port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
  CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params);

  // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
  // returned in the order of increasing estimated compute time according to an
  // internal heuristic. The first returned algorithm can be used as the default
  // algorithm if no autotuning is to be performed.
  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                            size_t max_workspace_size, int max_algorithm_count);

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
      float dropout, uint64 seed, ScratchAllocator *state_allocator,
      bool use_padded_io);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains the ownership of the returned
  // descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int> &seq_lengths,
                                    bool time_major, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains the ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface *implementation();

  // Creates a kernel which can be launched with stream.ThenLaunch, such that
  // the types of the arguments provided for launch must match the types of the
  // arguments provided at creation time.
  //
  // The kernel has a name kernel_name, and is based on the provided PTX in ptx
  // and (optionally) compiled PTX in cubin_data.
  // The canonical storage for both ptx and cubin_data should outlive the
  // lifetime of the kernel.
  template <typename... Args>
  port::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
      absl::string_view kernel_name, absl::string_view ptx,
      absl::Span<const uint8> cubin_data);
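
  // Sketch of loading a typed kernel from PTX (the kernel name, PTX string,
  // and argument types are hypothetical; launching is then done through
  // Stream::ThenLaunch, declared in stream.h):
  //
  //   port::StatusOr<std::unique_ptr<TypedKernel<DeviceMemory<float>, int>>>
  //       add_kernel = executor->CreateTypedKernel<DeviceMemory<float>, int>(
  //           "add_one", ptx_string, /*cubin_data=*/{});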

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
                      const BlockDim &block_dims, const KernelBase &kernel,
                      const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener *listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener *listener);

  // Returns allocator statistics.
  absl::optional<AllocatorStats> GetAllocatorStats();

  // Returns an allocator which delegates to this stream executor for memory
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Synchronously allocates size bytes on the underlying platform and returns
  // a DeviceMemoryBase representing that allocation. In the case of failure,
  // nullptr is returned.
  DeviceMemoryBase Allocate(uint64 size, int64 memory_space);

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing semantics are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Without blocking the device, retrieves the current stream status.
  port::Status GetStatus(Stream *stream);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle,
                 void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  // This is the preferred form for a callback that may return an error.
  bool HostCallback(Stream *stream, std::function<port::Status()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Waits for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform and initializes its
  // internals.
  bool AllocateStream(Stream *stream);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *stream);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform and initializes its
  // internals.
  bool AllocateTimer(Timer *timer);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *timer);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to caller.
  std::unique_ptr<DeviceDescription> CreateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 bytes);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT &&...args);

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable absl::Mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ TF_GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ TF_GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ TF_GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ TF_GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried via GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      TF_GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static constexpr int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener *> listeners_ TF_GUARDED_BY(mu_);

  // Allocated memory in bytes.
  int64 mem_alloc_bytes_;

  // Memory limit in bytes. A value less than or equal to 0 indicates there is
  // no limit.
  int64 memory_limit_bytes_;

  StreamExecutorMemoryAllocator allocator_;

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
class ScopedModuleHandle {
 public:
  explicit ScopedModuleHandle(StreamExecutor *executor,
                              ModuleHandle module_handle)
      : executor_(executor), module_handle_(module_handle) {}

  ScopedModuleHandle(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
  }

  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
    executor_ = other.executor_;
    module_handle_ = other.module_handle_;
    other.executor_ = nullptr;
    other.module_handle_ = ModuleHandle();
    return *this;
  }

  ~ScopedModuleHandle() {
    if (static_cast<bool>(module_handle_)) {
      CHECK(executor_->UnloadModule(module_handle_));
    }
  }

 private:
  StreamExecutor *executor_;
  ModuleHandle module_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
};
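
// Sketch of pairing LoadModule with ScopedModuleHandle so the module is
// unloaded automatically (the loader spec and executor are hypothetical):
//
//   ModuleHandle handle;
//   TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &handle));
//   ScopedModuleHandle scoped(executor, handle);
//   // The module is unloaded via UnloadModule() when `scoped` goes out of
//   // scope.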

////////////
// Inlines

template <typename... Args>
inline port::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
                                  absl::string_view ptx,
                                  absl::Span<const uint8> cubin_data) {
  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
  }

  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
                                                     int64 memory_space) {
  uint64 bytes = sizeof(T) * element_count;
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const std::string &symbol_name, ModuleHandle module_handle) {
  port::StatusOr<DeviceMemoryBase> untyped_symbol =
      GetUntypedSymbol(symbol_name, module_handle);
  if (!untyped_symbol.ok()) {
    return untyped_symbol.status();
  }
  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value),
      device_ordinal_(parent->device_ordinal()),
      allocator_(parent->GetAllocator()) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      TF_CHECK_OK(Free());
    }
  }
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  DeviceMemoryBase buf = Allocate(sizeof(T), /*memory_space=*/0);
  if (buf.is_null()) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result(buf);
  bool ok = SynchronousMemZero(&result, sizeof(T)).ok();
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::GetSubBuffer(DeviceMemory<T> *parent,
                                             uint64 element_offset,
                                             uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->GetSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count));
}

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_