1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 18 19 #include <map> 20 #include <set> 21 #include <vector> 22 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/compiler/xla/service/executable.h" 26 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/platform/mutex.h" 31 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 32 #include "tensorflow/core/platform/thread_annotations.h" 33 #include "tensorflow/core/platform/types.h" 34 #include "tensorflow/stream_executor/device_memory.h" 35 36 namespace xla { 37 38 // The TransferManager interface lets backends provide platform-specific 39 // mechanisms for constructing literals from given device memory handles. 40 // This lets each platform customize how literals are transferred to/from the 41 // device in terms of padding, leading dimension, etc. 42 class TransferManager { 43 public: ~TransferManager()44 virtual ~TransferManager() {} 45 46 // Returns the ID of the platform that this transfer manager acts on. 47 virtual se::Platform::Id PlatformId() const = 0; 48 49 // Returns the shape of the on-device representation for the given shape on 50 // the host. This is intended for use with ShapedBuffer where buffers are 51 // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user 52 // needing to consider device-specific behaviors. HostShapeToDeviceShape(const Shape & host_shape)53 virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { 54 // Strips off any preexisting tiling or memory space information. 55 // TODO(phawkins): fix clients not to including tiling or memory space 56 // information in shapes passed to this function and turn this into an 57 // assertion. 58 return ShapeUtil::DeviceShapeToHostShape(host_shape); 59 } 60 61 // Base class for specifying platform specific transfer metadata that can be 62 // used to tell the underlying implementation to perform specific optimization 63 // to a transfer. Actual metadata passed to supported transfer methods should 64 // subclass this class. 65 class TransferMetadata { 66 public: 67 virtual ~TransferMetadata() = 0; 68 }; 69 // Returns a literal containing the data held in the given ShapedBuffer 70 // using the provided executor. This operation is performed synchronously 71 // without waiting for any other operation on a stream to complete. 72 // 73 // This function should be avoided in favor of the asynchronous version below. 74 // 75 // Optionally caller can specify platform-specific transfer metadata that 76 // tells the actual implementation to do something special. 77 virtual StatusOr<Literal> TransferLiteralFromDevice( 78 se::Stream* stream, const ShapedBuffer& device_buffer, 79 const TransferMetadata* transfer_metadata); TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer)80 StatusOr<Literal> TransferLiteralFromDevice( 81 se::Stream* stream, const ShapedBuffer& device_buffer) { 82 return TransferLiteralFromDevice(stream, device_buffer, nullptr); 83 } 84 virtual Status TransferLiteralFromDevice( 85 se::Stream* stream, const ShapedBuffer& device_buffer, 86 const MutableBorrowingLiteral& literal, 87 const TransferMetadata* transfer_metadata); TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal)88 Status TransferLiteralFromDevice(se::Stream* stream, 89 const ShapedBuffer& device_buffer, 90 const MutableBorrowingLiteral& literal) { 91 return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr); 92 } 93 94 // Begins transferring a literal containing the data held in the given 95 // ShapedBuffer using the provided executor. 96 // 97 // This operation is performed asynchronously on the given stream. It returns 98 // once the transfer is enqueued. 'done' is invoked with the result when 99 // complete. 100 // 101 // device_buffer is copied by reference and must live at least until done() is 102 // invoked. 103 // 104 // Optionally caller can specify platform-specific transfer metadata that 105 // tells the actual implementation to do something special. 106 virtual void TransferLiteralFromDevice( 107 se::Stream* stream, const ShapedBuffer& device_buffer, 108 MutableBorrowingLiteral literal, std::function<void(Status)> done, 109 const TransferMetadata* transfer_metadata) = 0; TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,MutableBorrowingLiteral literal,std::function<void (Status)> done)110 void TransferLiteralFromDevice(se::Stream* stream, 111 const ShapedBuffer& device_buffer, 112 MutableBorrowingLiteral literal, 113 std::function<void(Status)> done) { 114 return TransferLiteralFromDevice(stream, device_buffer, literal, done, 115 nullptr); 116 } 117 118 // Transfers the given literal into the previously allocated device memory 119 // represented by the given ShapedBuffer using the given executor. The shape 120 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 121 // but need not have the same layout. 122 // 123 // This operation is performed synchronously without waiting for any other 124 // operation on a stream to complete. This function should be avoided in favor 125 // of the asynchronous version below. 126 // 127 // Optionally caller can specify platform-specific transfer metadata that 128 // tells the actual implementation to do something special. 129 virtual Status TransferLiteralToDevice( 130 se::Stream* stream, const LiteralSlice& literal, 131 const ShapedBuffer& device_buffer, 132 const TransferMetadata* transfer_metadata); TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)133 Status TransferLiteralToDevice(se::Stream* stream, 134 const LiteralSlice& literal, 135 const ShapedBuffer& device_buffer) { 136 return TransferLiteralToDevice(stream, literal, device_buffer, nullptr); 137 } 138 139 // Transfers the given literal into the previously allocated device memory 140 // represented by the given ShapedBuffer using the given executor. The shape 141 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 142 // but need not have the same layout. 143 // 144 // This operation is performed asynchronously on the given stream. It returns 145 // once the transfer is enqueued, and may return before the transfer has 146 // completed. 147 // 148 // The caller may free the data structures 'literal' and 'device_buffer' 149 // immediately after this function returns, however their constituent buffers 150 // on both host and device must remain valid until the enqueued transfer has 151 // completed on 'stream'. 152 // 153 // Optionally caller can specify platform-specific transfer metadata that 154 // tells the actual implementation to do something special. 155 virtual Status TransferLiteralToDeviceAsync( 156 se::Stream* stream, const LiteralSlice& literal, 157 const ShapedBuffer& device_buffer, 158 const TransferMetadata* transfer_metadata) = 0; TransferLiteralToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)159 Status TransferLiteralToDeviceAsync(se::Stream* stream, 160 const LiteralSlice& literal, 161 const ShapedBuffer& device_buffer) { 162 return TransferLiteralToDeviceAsync(stream, literal, device_buffer, 163 nullptr); 164 } 165 166 // Convenience methods for transferring an array to or from the device at a 167 // known address. This avoids having to construct a ShapedBuffer just to 168 // transfer an array at a known address. 169 // 170 // Optionally caller can specify platform-specific transfer metadata that 171 // tells the actual implementation to do something special. 172 Status TransferArrayToDevice( 173 se::Stream* stream, const LiteralSlice& literal, 174 const se::DeviceMemoryBase& dest, 175 const TransferMetadata* transfer_metadata = nullptr); 176 void TransferArrayFromDevice( 177 se::Stream* stream, const Shape& shape, 178 const se::DeviceMemoryBase& source, 179 const MutableBorrowingLiteral& literal, std::function<void(Status)> done, 180 const TransferMetadata* transfer_metadata = nullptr); 181 182 Status TransferArrayToDeviceAsync( 183 se::Stream* stream, const LiteralSlice& literal, 184 const se::DeviceMemoryBase& dest, 185 const TransferMetadata* transfer_metadata = nullptr); 186 StatusOr<Literal> TransferArrayFromDevice( 187 se::Stream* stream, const Shape& shape, 188 const se::DeviceMemoryBase& source, 189 const TransferMetadata* transfer_metadata = nullptr); 190 191 // Read from a device buffer and update the dynamic dimension sizes of 192 // `host_shape` and `device_shape`. The function takes in bounded dynamic 193 // shapes, and returns static shapes with dynamic shapes updated. 194 // The shape of the buffer also have to be compatible with the host shape and 195 // device shape. 196 virtual Status ReadDynamicShapes(se::Stream* stream, 197 ShapedBuffer* device_buffer, 198 Shape* device_shape); 199 200 // Transfers the given literal into the Infeed interface of the device, 201 // using the given executor. 202 virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, 203 const LiteralSlice& literal) = 0; 204 205 // Transfers the given literal from the Outfeed interface of the device, 206 // using the given executor. The shape and layout are determined by the 207 // shape and layout of `literal`. 208 virtual Status TransferLiteralFromOutfeed( 209 se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0; 210 211 // Resets the devices associated with this transfer manager. 212 virtual Status ResetDevices( 213 absl::Span<se::StreamExecutor* const> executor) = 0; 214 215 // Given an allocated ShapedBuffer, constructs the tuple index table(s) in 216 // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the 217 // ShapedBuffer is array-shaped this method does nothing. 218 Status WriteTupleIndexTables(se::Stream* stream, 219 const ShapedBuffer& device_buffer); 220 Status WriteTupleIndexTablesAsync(se::Stream* stream, 221 const ShapedBuffer& device_buffer); 222 223 // Writes a tuple index buffer for the root of 'device_buffer', which must 224 // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer, 225 // rather than writing all subbuffers. This method is always asynchronous. 226 Status WriteRootTupleIndexTable(se::Stream* stream, 227 const ShapedBuffer& device_buffer); 228 Status WriteRootTupleIndexTable( 229 se::Stream* stream, 230 const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree); 231 232 // Determines the byte size requirement for the given shape on the underlying 233 // architecture. This will be used to allocate an appropriately sized memory 234 // region for a host-to-device transfer. 235 virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; 236 237 // Chooses a compact layout for 'shape', ignoring any existing layout on 238 // 'shape'. What "reasonable" means is left up to the backend. The 239 // intended use case is to choose a layout that avoids excessive padding on 240 // devices that have tiled memory architectures. 241 // The default implementation always picks a default (major-to-minor) layout. 242 // Fails if 'shape' cannot be represented by the device. 243 virtual StatusOr<Shape> ChooseCompactLayoutForShape( 244 const Shape& host_shape) const; 245 246 // Allocates a ScopedShapedBuffer which can hold data with the given on-host 247 // shape. The on-device shape may be different as indicated by 248 // HostShapeToDeviceShape. 249 StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer( 250 const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator, 251 int device_ordinal); 252 253 // The given ShapedBuffer holds a handle to allocated memory, but it is not 254 // in the general case legal to immediately copy or access that allocated 255 // memory because queued operations on the device may alias that memory. 256 // Memory ordering is enforced by the Stream's happens-before relationship 257 // which allows eager deallocation and reallocation of buffers host-side even 258 // if the device hasn't finished with them. 259 // 260 // In certain cases, it can be known that a ShapedBuffer does not have any 261 // conflicting accesses on the device and thus is eligible to be accessed at 262 // any time from the host. 263 // 264 // This function returns true if device_buffer can be accessed immediately 265 // without waiting for the Stream's previously enqueued items. This only 266 // returns true if all subbuffers in device_buffer can be accessed 267 // immediately. CanShapedBufferBeAccessedNow(se::StreamExecutor * executor,const ShapedBuffer & device_buffer)268 virtual bool CanShapedBufferBeAccessedNow( 269 se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const { 270 return false; 271 } 272 273 // Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer. CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer)274 virtual bool CanBufferBeAccessedNow( 275 se::StreamExecutor* executor, 276 const se::DeviceMemoryBase& device_buffer) const { 277 return false; 278 } 279 280 ///// 281 // The TransferManager class also serves as a point to register objects for 282 // the various platforms. 283 284 // Registers the TransferManager singleton for the platform kind. This is 285 // assumed to be a singleton, so no ownership is transferred. 286 // 287 // Precondition: a platform kind must not be registered more than once. 288 typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)(); 289 static void RegisterTransferManager( 290 se::Platform::Id platform_id, 291 TransferManagerCreationFunction transfer_manager); 292 293 // Returns the transfer manager singleton pointer if it is available for the 294 // given platform, or an error status if it is not. 295 static StatusOr<TransferManager*> GetForPlatform( 296 const se::Platform* platform); 297 298 // Writes the given device-memory pointers in 'elements' to the given region 299 // to construct a tuple index table in the platform-specific tuple 300 // representation. 301 virtual Status WriteSingleTupleIndexTable( 302 se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements, 303 const Shape& shape, se::DeviceMemoryBase* region) = 0; 304 305 protected: 306 // Transfer a memory block of the given size from the device source into the 307 // 'destination' buffer. 308 // 309 // size is the size to transfer to destination in bytes. 310 virtual Status TransferBufferFromDevice(se::Stream* stream, 311 const se::DeviceMemoryBase& source, 312 int64 size, void* destination); 313 314 // Transfer a memory block of the given size from 'source' buffer to the given 315 // destination of the device. 316 // 317 // size is the size to transfer from source in bytes. 318 virtual Status TransferBufferToDevice(se::Stream* stream, int64 size, 319 const void* source, 320 se::DeviceMemoryBase* destination); 321 322 private: 323 // The mutex that guards the platform-to-transfer manager map. 324 static tensorflow::mutex platform_transfer_manager_mutex_; 325 326 // State kept for each kind of TransferManager. Registration functions 327 // set up creation_function, and then we use that to lazily create 328 // "manager" the first time GetForPlatform is invoked for a particular id. 329 struct State { 330 std::unique_ptr<TransferManager> manager; 331 TransferManagerCreationFunction creation_function = nullptr; 332 }; 333 334 // Map from platform kind to transfer manager singleton. 335 static std::map<se::Platform::Id, State>* GetPlatformTransferManagers(); 336 }; 337 338 } // namespace xla 339 340 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 341