1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
17 #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
18 
#include <memory>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
28 
29 namespace tensorflow {
30 
31 // The implementation of a Tensor for an XlaDevice. All device tensors are
32 // actually one of these.
33 //
34 // To distinguish between "normal" device tensors and XlaTensors, the raw
35 // pointer data stored in the TensorBuffer is a tagged pointer.
36 class XlaTensor {
37  public:
38   // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
39   // fails.
40   static XlaTensor* FromTensor(const Tensor* tensor);
41 
42   static bool RefCountIsOne(const Tensor& tensor);
43 
44   // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
45   // which case the returned value is shaped_buffer()->root_buffer(), or a
46   // normal Tensor in which case the returned value is
47   // {tensor.tensor_data().data(), tensor.tensor_data().size}.
48   static se::DeviceMemoryBase DeviceMemoryFromTensor(const Tensor& tensor);
49 
50   // Assign the internal ShapedBuffer to new memory for the given dtype and
51   // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it
52   // is replaced and the managed memory deallocated.
53   Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape,
54                               xla::LocalClient* client, int device_ordinal);
55 
56   // Some Tensors can have complex on-device shapes, including tuple shapes. To
57   // manage the memory for these tensors a ShapedBuffer may be required.
58 
59   // Return true if this XlaTensor contains a ShapedBuffer.
has_shaped_buffer()60   bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
61   // Return the contained ShapedBuffer.
62   // REQUIRES: has_shaped_buffer()
shaped_buffer()63   const xla::ShapedBuffer& shaped_buffer() const {
64     CHECK(has_shaped_buffer());
65     return *shaped_buffer_;
66   }
shaped_buffer()67   xla::ShapedBuffer& shaped_buffer() {
68     CHECK(has_shaped_buffer());
69     return *shaped_buffer_;
70   }
71   // Mutates the XlaTensor to set the ShapedBuffer.
set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer)72   void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
73     shaped_buffer_ =
74         absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
75   }
76 
77   // Some tensors on the device may have known values on the host. We use these
78   // in on-demand mode to avoid re-copying values from the device if we know the
79   // host value already.
80 
81   // Return true if this XlaTensor contains a host tensor.
has_host_tensor()82   bool has_host_tensor() const { return host_tensor_ != nullptr; }
83   // Return the contained host tensor.
84   // REQUIRES: has_host_tensor()
host_tensor()85   const Tensor& host_tensor() const { return *host_tensor_; }
86   // Sets the contained host tensor.
set_host_tensor(const Tensor & tensor)87   void set_host_tensor(const Tensor& tensor) {
88     host_tensor_.reset(new Tensor(tensor));
89   }
90 
91   // Adds synchronization events to 'stream' that wait for this tensor to be
92   // defined on 'stream'. Does nothing if the tensor is already defined on that
93   // stream.
94   void WaitForDefinitionEventOnStream(se::Stream* stream);
95 
96   // (Re)sets the definition event of the tensor to 'event', and promises that
97   // the tensor has already been defined on stream. Removes any previous
98   // definition event or any previous promises about the tensor being defined on
99   // streams.
100   // It is legal to reset the definition event of a tensor when overwriting the
101   // tensor's value (at which point, it is effectively a new tensor once again.)
102   void ResetDefinitionEvent(std::shared_ptr<se::Event> event,
103                             se::Stream* stream);
104 
105   // Convert from a raw pointer to an XlaTensor, removing the pointer tag.
106   static XlaTensor* FromOpaquePointer(void* ptr);
107   // Convert to a raw pointer from an XlaTensor, adding the pointer tag.
108   static void* ToOpaquePointer(XlaTensor* tensor);
109 
110  private:
111   // The optional contained ShapedBuffer.
112   std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
113   // An optional host tensor value.
114   std::unique_ptr<Tensor> host_tensor_;
115   // An optional event that is triggered when the tensor's content has been
116   // defined. If this event is nullptr, it is assumed that the tensor's content
117   // is always defined.
118   std::shared_ptr<se::Event> definition_event_;
119   // A list of all streams for which the tensor's content is defined for any
120   // newly enqueued command.
121   absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
122   mutex mu_;
123 };
124 
125 }  // namespace tensorflow
126 
127 #endif  // TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
128