1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
35 
36 namespace xla {
37 
38 // The TransferManager interface lets backends provide platform-specific
39 // mechanisms for constructing literals from given device memory handles.
40 // This lets each platform customize how literals are transferred to/from the
41 // device in terms of padding, leading dimension, etc.
42 class TransferManager {
43  public:
~TransferManager()44   virtual ~TransferManager() {}
45 
46   // Returns the ID of the platform that this transfer manager acts on.
47   virtual se::Platform::Id PlatformId() const = 0;
48 
49   // Returns the shape of the on-device representation for the given shape on
50   // the host. This is intended for use with ShapedBuffer where buffers are
51   // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
52   // needing to consider device-specific behaviors.
HostShapeToDeviceShape(const Shape & host_shape)53   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
54     // Strips off any preexisting tiling or memory space information.
55     // TODO(phawkins): fix clients not to including tiling or memory space
56     // information in shapes passed to this function and turn this into an
57     // assertion.
58     return ShapeUtil::DeviceShapeToHostShape(host_shape);
59   }
60 
61   // Base class for specifying platform specific transfer metadata that can be
62   // used to tell the underlying implementation to perform specific optimization
63   // to a transfer. Actual metadata passed to supported transfer methods should
64   // subclass this class.
65   class TransferMetadata {
66    public:
67     virtual ~TransferMetadata() = 0;
68   };
69   // Returns a literal containing the data held in the given ShapedBuffer
70   // using the provided executor. This operation is performed synchronously
71   // without waiting for any other operation on a stream to complete.
72   //
73   // This function should be avoided in favor of the asynchronous version below.
74   //
75   // Optionally caller can specify platform-specific transfer metadata that
76   // tells the actual implementation to do something special.
77   virtual StatusOr<Literal> TransferLiteralFromDevice(
78       se::Stream* stream, const ShapedBuffer& device_buffer,
79       const TransferMetadata* transfer_metadata);
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer)80   StatusOr<Literal> TransferLiteralFromDevice(
81       se::Stream* stream, const ShapedBuffer& device_buffer) {
82     return TransferLiteralFromDevice(stream, device_buffer, nullptr);
83   }
84   virtual Status TransferLiteralFromDevice(
85       se::Stream* stream, const ShapedBuffer& device_buffer,
86       const MutableBorrowingLiteral& literal,
87       const TransferMetadata* transfer_metadata);
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal)88   Status TransferLiteralFromDevice(se::Stream* stream,
89                                    const ShapedBuffer& device_buffer,
90                                    const MutableBorrowingLiteral& literal) {
91     return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr);
92   }
93 
94   // Begins transferring a literal containing the data held in the given
95   // ShapedBuffer using the provided executor.
96   //
97   // This operation is performed asynchronously on the given stream. It returns
98   // once the transfer is enqueued. 'done' is invoked with the result when
99   // complete.
100   //
101   // device_buffer is copied by reference and must live at least until done() is
102   // invoked.
103   //
104   // Optionally caller can specify platform-specific transfer metadata that
105   // tells the actual implementation to do something special.
106   virtual void TransferLiteralFromDevice(
107       se::Stream* stream, const ShapedBuffer& device_buffer,
108       MutableBorrowingLiteral literal, std::function<void(Status)> done,
109       const TransferMetadata* transfer_metadata) = 0;
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,MutableBorrowingLiteral literal,std::function<void (Status)> done)110   void TransferLiteralFromDevice(se::Stream* stream,
111                                  const ShapedBuffer& device_buffer,
112                                  MutableBorrowingLiteral literal,
113                                  std::function<void(Status)> done) {
114     return TransferLiteralFromDevice(stream, device_buffer, literal, done,
115                                      nullptr);
116   }
117 
118   // Transfers the given literal into the previously allocated device memory
119   // represented by the given ShapedBuffer using the given executor. The shape
120   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
121   // but need not have the same layout.
122   //
123   // This operation is performed synchronously without waiting for any other
124   // operation on a stream to complete. This function should be avoided in favor
125   // of the asynchronous version below.
126   //
127   // Optionally caller can specify platform-specific transfer metadata that
128   // tells the actual implementation to do something special.
129   virtual Status TransferLiteralToDevice(
130       se::Stream* stream, const LiteralSlice& literal,
131       const ShapedBuffer& device_buffer,
132       const TransferMetadata* transfer_metadata);
TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)133   Status TransferLiteralToDevice(se::Stream* stream,
134                                  const LiteralSlice& literal,
135                                  const ShapedBuffer& device_buffer) {
136     return TransferLiteralToDevice(stream, literal, device_buffer, nullptr);
137   }
138 
139   // Transfers the given literal into the previously allocated device memory
140   // represented by the given ShapedBuffer using the given executor. The shape
141   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
142   // but need not have the same layout.
143   //
144   // This operation is performed asynchronously on the given stream. It returns
145   // once the transfer is enqueued, and may return before the transfer has
146   // completed.
147   //
148   // The caller may free the data structures 'literal' and 'device_buffer'
149   // immediately after this function returns, however their constituent buffers
150   // on both host and device must remain valid until the enqueued transfer has
151   // completed on 'stream'.
152   //
153   // Optionally caller can specify platform-specific transfer metadata that
154   // tells the actual implementation to do something special.
155   virtual Status TransferLiteralToDeviceAsync(
156       se::Stream* stream, const LiteralSlice& literal,
157       const ShapedBuffer& device_buffer,
158       const TransferMetadata* transfer_metadata) = 0;
TransferLiteralToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)159   Status TransferLiteralToDeviceAsync(se::Stream* stream,
160                                       const LiteralSlice& literal,
161                                       const ShapedBuffer& device_buffer) {
162     return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
163                                         nullptr);
164   }
165 
166   // Convenience methods for transferring an array to or from the device at a
167   // known address. This avoids having to construct a ShapedBuffer just to
168   // transfer an array at a known address.
169   //
170   // Optionally caller can specify platform-specific transfer metadata that
171   // tells the actual implementation to do something special.
172   Status TransferArrayToDevice(
173       se::Stream* stream, const LiteralSlice& literal,
174       const se::DeviceMemoryBase& dest,
175       const TransferMetadata* transfer_metadata = nullptr);
176   void TransferArrayFromDevice(
177       se::Stream* stream, const Shape& shape,
178       const se::DeviceMemoryBase& source,
179       const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
180       const TransferMetadata* transfer_metadata = nullptr);
181 
182   Status TransferArrayToDeviceAsync(
183       se::Stream* stream, const LiteralSlice& literal,
184       const se::DeviceMemoryBase& dest,
185       const TransferMetadata* transfer_metadata = nullptr);
186   StatusOr<Literal> TransferArrayFromDevice(
187       se::Stream* stream, const Shape& shape,
188       const se::DeviceMemoryBase& source,
189       const TransferMetadata* transfer_metadata = nullptr);
190 
191   // Read from a device buffer and update the dynamic dimension sizes of
192   // `host_shape` and `device_shape`. The function takes in bounded dynamic
193   // shapes, and returns static shapes with dynamic shapes updated.
194   // The shape of the buffer also have to be compatible with the host shape and
195   // device shape.
196   virtual Status ReadDynamicShapes(se::Stream* stream,
197                                    ShapedBuffer* device_buffer,
198                                    Shape* device_shape);
199 
200   // Transfers the given literal into the Infeed interface of the device,
201   // using the given executor.
202   virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
203                                          const LiteralSlice& literal) = 0;
204 
205   // Transfers the given literal from the Outfeed interface of the device,
206   // using the given executor. The shape and layout are determined by the
207   // shape and layout of `literal`.
208   virtual Status TransferLiteralFromOutfeed(
209       se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;
210 
211   // Resets the devices associated with this transfer manager.
212   virtual Status ResetDevices(
213       absl::Span<se::StreamExecutor* const> executor) = 0;
214 
215   // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
216   // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
217   // ShapedBuffer is array-shaped this method does nothing.
218   Status WriteTupleIndexTables(se::Stream* stream,
219                                const ShapedBuffer& device_buffer);
220   Status WriteTupleIndexTablesAsync(se::Stream* stream,
221                                     const ShapedBuffer& device_buffer);
222 
223   // Writes a tuple index buffer for the root of 'device_buffer', which must
224   // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer,
225   // rather than writing all subbuffers. This method is always asynchronous.
226   Status WriteRootTupleIndexTable(se::Stream* stream,
227                                   const ShapedBuffer& device_buffer);
228   Status WriteRootTupleIndexTable(
229       se::Stream* stream,
230       const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree);
231 
232   // Determines the byte size requirement for the given shape on the underlying
233   // architecture. This will be used to allocate an appropriately sized memory
234   // region for a host-to-device transfer.
235   virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
236 
237   // Chooses a compact layout for 'shape', ignoring any existing layout on
238   // 'shape'. What "reasonable" means is left up to the backend. The
239   // intended use case is to choose a layout that avoids excessive padding on
240   // devices that have tiled memory architectures.
241   // The default implementation always picks a default (major-to-minor) layout.
242   // Fails if 'shape' cannot be represented by the device.
243   virtual StatusOr<Shape> ChooseCompactLayoutForShape(
244       const Shape& host_shape) const;
245 
246   // Allocates a ScopedShapedBuffer which can hold data with the given on-host
247   // shape. The on-device shape may be different as indicated by
248   // HostShapeToDeviceShape.
249   StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
250       const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
251       int device_ordinal);
252 
253   // The given ShapedBuffer holds a handle to allocated memory, but it is not
254   // in the general case legal to immediately copy or access that allocated
255   // memory because queued operations on the device may alias that memory.
256   // Memory ordering is enforced by the Stream's happens-before relationship
257   // which allows eager deallocation and reallocation of buffers host-side even
258   // if the device hasn't finished with them.
259   //
260   // In certain cases, it can be known that a ShapedBuffer does not have any
261   // conflicting accesses on the device and thus is eligible to be accessed at
262   // any time from the host.
263   //
264   // This function returns true if device_buffer can be accessed immediately
265   // without waiting for the Stream's previously enqueued items. This only
266   // returns true if all subbuffers in device_buffer can be accessed
267   // immediately.
CanShapedBufferBeAccessedNow(se::StreamExecutor * executor,const ShapedBuffer & device_buffer)268   virtual bool CanShapedBufferBeAccessedNow(
269       se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
270     return false;
271   }
272 
273   // Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer.
CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer)274   virtual bool CanBufferBeAccessedNow(
275       se::StreamExecutor* executor,
276       const se::DeviceMemoryBase& device_buffer) const {
277     return false;
278   }
279 
280   /////
281   // The TransferManager class also serves as a point to register objects for
282   // the various platforms.
283 
284   // Registers the TransferManager singleton for the platform kind. This is
285   // assumed to be a singleton, so no ownership is transferred.
286   //
287   // Precondition: a platform kind must not be registered more than once.
288   typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
289   static void RegisterTransferManager(
290       se::Platform::Id platform_id,
291       TransferManagerCreationFunction transfer_manager);
292 
293   // Returns the transfer manager singleton pointer if it is available for the
294   // given platform, or an error status if it is not.
295   static StatusOr<TransferManager*> GetForPlatform(
296       const se::Platform* platform);
297 
298   // Writes the given device-memory pointers in 'elements' to the given region
299   // to construct a tuple index table in the platform-specific tuple
300   // representation.
301   virtual Status WriteSingleTupleIndexTable(
302       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
303       const Shape& shape, se::DeviceMemoryBase* region) = 0;
304 
305  protected:
306   // Transfer a memory block of the given size from the device source into the
307   // 'destination' buffer.
308   //
309   // size is the size to transfer to destination in bytes.
310   virtual Status TransferBufferFromDevice(se::Stream* stream,
311                                           const se::DeviceMemoryBase& source,
312                                           int64 size, void* destination);
313 
314   // Transfer a memory block of the given size from 'source' buffer to the given
315   // destination of the device.
316   //
317   // size is the size to transfer from source in bytes.
318   virtual Status TransferBufferToDevice(se::Stream* stream, int64 size,
319                                         const void* source,
320                                         se::DeviceMemoryBase* destination);
321 
322  private:
323   // The mutex that guards the platform-to-transfer manager map.
324   static tensorflow::mutex platform_transfer_manager_mutex_;
325 
326   // State kept for each kind of TransferManager.  Registration functions
327   // set up creation_function, and then we use that to lazily create
328   // "manager" the first time GetForPlatform is invoked for a particular id.
329   struct State {
330     std::unique_ptr<TransferManager> manager;
331     TransferManagerCreationFunction creation_function = nullptr;
332   };
333 
334   // Map from platform kind to transfer manager singleton.
335   static std::map<se::Platform::Id, State>* GetPlatformTransferManagers();
336 };
337 
338 }  // namespace xla
339 
340 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
341