/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_

#include <map>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform.h"

namespace stream_executor {

class DeviceMemoryAllocator;

35 // Owning pointer for memory on a device.
36 //
37 // ScopedDeviceMemory is an owning pointer like std::unique_ptr, but it can
38 // point to memory that resides on a "device" (e.g. a GPU).  When a
39 // ScopedDeviceMemory goes out of scope, it frees the memory it owns.
40 //
41 // We say that an instance of ScopedDeviceMemory is "active" if it currently
42 // owns a (possibly empty) slice of memory on the device.  Moving,
43 // Release()'ing, Free()'ing, and other actions can deactive an active object.
44 template <typename ElemT>
45 class ScopedDeviceMemory {
46  public:
47   // Default construction initializes the internal state to nullptr.  This
48   // mirrors the std::unique_ptr<> functionality, where default construction
49   // produces a nullptr unique_ptr, which can be assigned later.
ScopedDeviceMemory()50   ScopedDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
51 
52   // Construct a ScopedDeviceMemory from a custom allocator.
53   //
54   // Parameters:
55   //  mem: Already-allocated device memory value for this scoped mechanism to
56   //       deallocate. This memory must have been allocated by parent.
57   //  device_ordinal: Device on which the memory was allocated.
58   //  allocator: Allocator used to deallocate memory when this instance goes
59   //             out of scope.
ScopedDeviceMemory(DeviceMemoryBase mem,int device_ordinal,DeviceMemoryAllocator * allocator)60   ScopedDeviceMemory(DeviceMemoryBase mem, int device_ordinal,
61                      DeviceMemoryAllocator *allocator)
62       : wrapped_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {
63     DCHECK_GE(device_ordinal_, 0);
64   }
65 
66   // A helper constructor to generate a scoped device memory given an already
67   // allocated memory and a stream executor.
68   //
69   // Precondition: memory was allocated by the stream executor `parent`.
70   ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);
71 
72   // Constructor overload that places a literal array into device memory.
73   //
74   // Relies on the allocation function exposed by the stream executor `parent`,
75   // which will be also used for deallocating the memory
76   ScopedDeviceMemory(StreamExecutor *parent,
77                      std::initializer_list<ElemT> values);
78 
79   // Moves ownership of the memory from other to the constructed
80   // object.
81   //
82   // Postcondition: other == nullptr.
ScopedDeviceMemory(ScopedDeviceMemory && other)83   ScopedDeviceMemory(ScopedDeviceMemory &&other)
84       : wrapped_(other.Release()),
85         device_ordinal_(other.device_ordinal_),
86         allocator_(other.allocator_) {}
87 
88   // Releases the memory that was provided in the constructor, through the
89   // "parent" StreamExecutor.
~ScopedDeviceMemory()90   ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); }
91 
92   // Moves ownership of the memory from other to this object.
93   //
94   // Postcondition: other == nullptr.
95   ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) {
96     TF_CHECK_OK(Free());
97     wrapped_ = other.Release();
98     allocator_ = other.allocator_;
99     device_ordinal_ = other.device_ordinal_;
100     return *this;
101   }
102 
103   // Returns the memory that backs this scoped allocation converted to
104   // DeviceMemory<T> apparent type. This is useful for cases where the
105   // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory doesn't
106   // allow copying, for scoped-object-lifetime reasons.
cref()107   const DeviceMemory<ElemT> &cref() const { return wrapped_; }
108 
109   // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
110   // operations. The value returned should not be used outside the scope of this
111   // ScopedDeviceMemory object's lifetime.
ptr()112   DeviceMemory<ElemT> *ptr() { return &wrapped_; }
ptr()113   const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }
114 
115   // Smart-pointer-like operators for the wrapped DeviceMemory.
116   // This reference must not be used outside the lifetime of this
117   // ScopedDeviceMemory.
118   const DeviceMemory<ElemT> &operator*() const { return cref(); }
119   DeviceMemory<ElemT> *operator->() { return ptr(); }
120   const DeviceMemory<ElemT> *operator->() const { return ptr(); }
121 
is_null()122   bool is_null() const { return wrapped_.is_null(); }
123   bool operator==(std::nullptr_t other) const { return is_null(); }
124   bool operator!=(std::nullptr_t other) const { return !is_null(); }
125 
126   // Analogous to std::unique_ptr::release, releases ownership of the held
127   // memory and transfers it to the caller.
128   //
129   // Postcondition: *this == nullptr
Release()130   DeviceMemory<ElemT> Release() {
131     DeviceMemory<ElemT> tmp = wrapped_;
132     wrapped_ = DeviceMemory<ElemT>{};
133     return tmp;
134   }
135 
136   // The returned allocator is nonnull iff this object is active.
allocator()137   DeviceMemoryAllocator *allocator() const { return allocator_; }
138 
device_ordinal()139   int device_ordinal() const { return device_ordinal_; }
140 
141   // Frees the existing memory, resets the wrapped memory to null.
142   port::Status Free();
143 
144  private:
145   DeviceMemory<ElemT> wrapped_;       // Value we wrap with scoped-release.
146   int device_ordinal_;                // Negative one for inactive object.
147   DeviceMemoryAllocator *allocator_;  // Null if this object is inactive.
148 
149   SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
150 };
151 
152 // Type alias for compatibility with the previous managed memory implementation.
153 using OwningDeviceMemory = ScopedDeviceMemory<uint8>;
154 
155 // Memory allocator interface for the device.
156 //
157 // Intended usage is through Allocate() functions which return an owning smart
158 // pointer.
159 class DeviceMemoryAllocator {
160  public:
161   // Parameter platform indicates which platform the allocator allocates memory
162   // on. Must be non-null.
DeviceMemoryAllocator(const Platform * platform)163   explicit DeviceMemoryAllocator(const Platform* platform)
164       : platform_(platform) {}
~DeviceMemoryAllocator()165   virtual ~DeviceMemoryAllocator() {}
166 
167   // Allocates memory on the device.
168   //
169   // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
170   // must not be null.  If size == 0, must return a null OwningDeviceMemory.
171   //
172   // 'retry_on_failure': If false, and the first attempt to allocate the memory
173   // fails, the allocation should return immediately without retrying.  An
174   // example use case is optional scratch spaces where a failure has only
175   // performance impact.
176   virtual port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
177                                                       uint64 size,
178                                                       bool retry_on_failure,
179                                                       int64 memory_space) = 0;
180 
181   // Two-arg version of Allocate(), which sets retry-on-failure to true and
182   // memory_space to default (0).
183   //
184   // (We don't simply use a default argument on the virtual Allocate function
185   // because default args on virtual functions are disallowed by the Google
186   // style guide.)
Allocate(int device_ordinal,uint64 size)187   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size) {
188     return Allocate(device_ordinal, size, /*retry_on_failure=*/true,
189                     /*memory_space=*/0);
190   }
191 
192   // Three-arg version of Allocate(), which sets memory_space to default (0).
Allocate(int device_ordinal,uint64 size,bool retry_on_failure)193   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
194                                               bool retry_on_failure) {
195     return Allocate(device_ordinal, size, retry_on_failure,
196                     /*memory_space=*/0);
197   }
198 
199   // Typed version of the allocation, returning typed memory.
200   template <typename ElemT>
201   port::StatusOr<ScopedDeviceMemory<ElemT>> Allocate(
202       int device_ordinal, uint64 size, bool retry_on_failure = true,
203       int64 memory_space = 0) {
204     return Allocate(device_ordinal, size, retry_on_failure, memory_space);
205   }
206 
207   // Must be a nop for null pointers. Should not be used.
208   //
209   // TODO(cheshire): Add deprecation notice.
210   virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;
211 
212   // Return the platform that the allocator allocates memory on.
platform()213   const Platform* platform() const { return platform_; }
214 
215   // Can we call Deallocate() as soon as a computation has been scheduled on
216   // a stream, or do we have to wait for the computation to complete first?
AllowsAsynchronousDeallocation()217   virtual bool AllowsAsynchronousDeallocation() const { return false; }
218 
219   // Returns a stream pointer on which it is always safe to access memory
220   // allocated by this allocator. It is not necessary to use the returned stream
221   // though, as clients may have additional information letting them safely use
222   // a different stream.
223   virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
224 
225  protected:
226   const Platform* platform_;
227 };
228 
229 // Default memory allocator for a platform which uses
230 // StreamExecutor::Allocate/Deallocate.
231 class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
232  public:
233   // Create an allocator supporting a single device, corresponding to the passed
234   // executor.
235   explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);
236 
237   // Create an allocator supporting multiple stream executors.
238   //
239   // Precondition: all stream_executors have different device ordinals.
240   StreamExecutorMemoryAllocator(
241       const Platform *platform,
242       absl::Span<StreamExecutor *const> stream_executors);
243 
244   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
245                                               bool retry_on_failure,
246                                               int64 memory_space) override;
247 
248   // Pull in two-arg overload that sets retry_on_failure to true.
249   using DeviceMemoryAllocator::Allocate;
250 
251   port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override;
252 
253   bool AllowsAsynchronousDeallocation() const override;
254 
255   // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
256   // stream executor.
257   port::StatusOr<Stream *> GetStream(int device_ordinal) override;
258 
259   // Gets the stream executor for given device ordinal.
260   port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
261 
262  private:
263   // Available stream executors. Each stream executor has a different device
264   // ordinal.
265   std::vector<StreamExecutor *> stream_executors_;
266 
267   absl::Mutex mutex_;
268 
269   // Cache of streams for GetStream.
270   std::map<int, Stream> streams_ TF_GUARDED_BY(mutex_);
271 };
272 
273 template <typename ElemT>
Free()274 port::Status ScopedDeviceMemory<ElemT>::Free() {
275   if (!wrapped_.is_null()) {
276     CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state";
277     TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_));
278   }
279   wrapped_ = DeviceMemory<ElemT>{};
280   return port::Status::OK();
281 }
282 
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_