1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Classes for keeping track of on-device state.
17 
18 #ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
19 #define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
20 
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/service/backend.h"
28 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
29 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/framework/resource_mgr.h"
33 #include "tensorflow/core/lib/core/refcount.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/lib/gtl/array_slice.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/stream_executor/stream_executor.h"
38 
39 namespace tensorflow {
40 
41 // TODO(misard) make this a Tensor if and when that makes sense.
42 // A reference-counted wrapper around a buffer allocation. This maps an XLA
43 // tuple index or a non-tuple XLA shape to a region of device memory. The device
44 // memory buffer is freed when the reference count drops to zero.
45 class XRTBufferAllocation : public core::RefCounted {
46  public:
47   XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
48                       int device_ordinal,
49                       xla::DeviceMemoryAllocator* allocator);
50   ~XRTBufferAllocation() override;
51 
52   // The region of device memory being wrapped.
53   const se::DeviceMemoryBase& allocation();
54 
55   // Sets the DeviceMemoryBase to be null. DiscardAllocation should be called
56   // when ownership of the underlying buffer has been transferred, e.g., to an
57   // output buffer when input and output buffers are aliased during
58   // execution. The call to DiscardAllocation prevents any device buffer being
59   // freed when the reference count drops to zero.
60   void DiscardAllocation();
61 
62   // Returns the expected size of the allocation. Since DiscardAllocation() will
63   // set allocation_ to {null,0}, and since later we might want to replace the
64   // discarded buffer with a new one, we need to be able to verify the size
65   // compatibility.
size()66   uint64 size() const { return size_; }
67 
68  private:
69   uint64 size_ = 0;
70   se::DeviceMemoryBase allocation_;
71   int device_ordinal_;
72   xla::DeviceMemoryAllocator* allocator_;
73 };
74 
75 // Entry in the resource manager corresponding to an allocation handle returned
76 // to a client. The handle identifies an immutable tuple of data in device
77 // memory. New handles can be created in three ways: by passing a literal in
78 // which case device memory is allocated and the literal is transferred to that
79 // memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
80 // aliasing a vector of existing handles to create a new tuple. The underlying
81 // storage is reference-counted. When a handle is released, the reference count
82 // of each storage buffer is decremented, and buffers with no outstanding
83 // references are freed.
84 class XRTTupleAllocation : public ResourceBase {
85  public:
86   ~XRTTupleAllocation() override;
87 
88   // Allocates new device memory buffers sufficient to store literal, transfers
89   // literal to that memory, and returns a XRTTupleAllocation handle to the
90   // allocated buffers.
91   static Status CreateAndTransfer(const xla::LiteralBase& literal,
92                                   xla::Backend* backend, int device_ordinal,
93                                   XRTTupleAllocation** allocation);
94 
95   // Wraps an existing ShapeBuffer in a new XRTTupleAllocation handle.
96   static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
97                                  xla::Backend* backend, int device_ordinal,
98                                  XRTTupleAllocation** allocation);
99 
100   // Aliases a sub-shape of parent and returns a XRTTupleAllocation handle
101   // to the sub-shape. If alias_base_allocation is true, the buffers in the
102   // sub-shape will be shared between parent and the returned allocation,
103   // otherwise the overlapping buffers in parent will be replaced by
104   // nullptr.
105   static Status MakeSubBuffer(XRTTupleAllocation* parent,
106                               const xla::ShapeIndex& subshape,
107                               XRTTupleAllocation** allocation,
108                               bool alias_parent_allocation);
109 
110   // A structure describing a leaf of a tree of tuples to expand. Each leaf
111   // contains an allocation and indicates whether or not the allocation's handle
112   // should be freed after incorporating its buffers into the expanded tree.
113   struct ExpandedTupleInput {
114     XRTTupleAllocation* allocation;
115     bool release_allocation_after_use;
116   };
117 
118   // Returns a handle to a new tuple where the subtree of the new tuple at an
119   // index corresponding to a leaf of 'elements' is constructed from the
120   // allocation (i.e., a tuple or array) pointed to by that leaf. If
121   // release_allocation_after_use is false at a leaf, the new tuple will alias
122   // the input allocation at that leaf, otherwise the input allocation will be
123   // released. Input allocations may be repeated (appear in more than one leaf)
124   // in which case the corresponding buffers in the output tuple will alias. If
125   // an input is repeated, release_input_handle must be false for every leaf
126   // where that input appears. The latter property is not validated by MakeTuple
127   // and must be enforced by the caller.
128   static Status MakeTuple(xla::Backend* backend, int device_ordinal,
129                           const xla::ShapeTree<ExpandedTupleInput>& elements,
130                           XRTTupleAllocation** allocation);
131 
132   // Retrieves the allocation interned under key from rm. The caller owns a
133   // reference to allocation after looking it up.
134   static Status Lookup(ResourceMgr* rm, int64 key,
135                        XRTTupleAllocation** allocation);
136 
137   // Deletes the reference in the rm to an allocation interned under key.
138   static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key);
139 
140   // Releases all the device memory allocated by XRT within the resource
141   // manager.
142   static Status ReleaseAllAllocations(ResourceMgr* rm);
143 
144   // Adds the allocation to a ResourceMgr and returns the key that will be used
145   // to retrieve it. Transfers a reference on *this to rm.
146   Status Intern(ResourceMgr* rm, int64* key);
147 
148   // Copies the allocation from device to host and returns it in literal.
149   Status ToLiteral(xla::Backend* backend, int device_ordinal,
150                    xla::MutableLiteralBase* literal);
151 
152   // Write a new literal value to the allocation.
153   Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
154 
155   // True if none of the buffers in the allocation are aliased by any other live
156   // handle.
157   bool IsExclusiveOwner();
158 
159   // The ordinal of the device holding this tuple.
160   int device_ordinal();
161 
162   // Returns the shape of the tuple as seen by the host.
163   const xla::Shape& on_host_shape();
164 
165   // Returns the shape of the tuple as stored on the device.
166   const xla::Shape& on_device_shape();
167 
168   // Returns the buffer pointed to by the root of the tuple.
169   const se::DeviceMemoryBase& root_allocation();
170 
171   // Stops managing the storage for the allocation at buffer_index, e.g.,
172   // because it has been aliased to the output buffer of a computation.
173   void DiscardAllocation(const xla::ShapeIndex& buffer_index);
174 
175   // Returns the tree of allocations as a ShapedBuffer. This tree may not have
176   // the same shape as on_host_shape.
177   xla::ShapedBuffer ToShapedBuffer();
178 
179   // Aliases the source buffer at source_index into the current tuple allocation
180   // dest_index.
181   Status AliasBufferFrom(const XRTTupleAllocation& source,
182                          const xla::ShapeIndex& source_index,
183                          const xla::ShapeIndex& dest_index);
184 
185   // Returns the device memory tree of this allocation. If the release_checker
186   // function returns true for a given index, the ownership of the device memory
187   // at that index is transferred to the result. Every attempt to read the value
188   // at that index will fail.
189   xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree(
190       const std::function<bool(const xla::ShapeIndex&)>& release_checker);
191 
DebugString()192   string DebugString() const override { return "XLA allocation handle"; }
193 
194  private:
195   // Creates a new handle with (tuple) shape.
196   XRTTupleAllocation(int device_ordinal, xla::DeviceMemoryAllocator* allocator,
197                      const xla::Shape& on_host_shape,
198                      const xla::Shape& on_device_shape);
199 
200   // Inherits the allocations represented in buffer, which must have the same
201   // shape as buffers_.
202   void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
203                                   xla::DeviceMemoryAllocator* allocator,
204                                   int device_ordinal);
205 
206   // Takes a tree 'elements' where each leaf is an allocation, validates that
207   // they are all on device_ordinal managed by allocator, and returns in
208   // host_shape and device_shape the host/device shapes of the expanded tree,
209   // where at each leaf of elements the shape of the allocation at elements is
210   // grafted on.
211   static Status ExpandTreeOfTuples(
212       const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
213       xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
214       xla::Shape* device_shape);
215 
216   // Location of the memory that is being managed.
217   int device_ordinal_;
218   xla::DeviceMemoryAllocator* allocator_;
219 
220   // The shape that the caller thinks the tuple has.
221   const xla::Shape on_host_shape_;
222   // The shape that the tuple has on device. Store this explicitly instead of
223   // using a shape stored in ShapeTree because ShapeTree discards the layout.
224   const xla::Shape on_device_shape_;
225   // The tree of reference-counted buffers, which uses on_device_shape_ as its
226   // shape.
227   xla::ShapeTree<XRTBufferAllocation*> buffers_;
228 };
229 
230 }  // namespace tensorflow
231 
232 #endif  // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
233