1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
17 
#include <algorithm>
#include <iterator>
#include <memory>
#include <utility>
20 
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
23 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
24 #include "tensorflow/compiler/xla/types.h"
25 #include "tensorflow/stream_executor/device_memory.h"
26 #include "tensorflow/stream_executor/device_memory_allocator.h"
27 #include "tensorflow/stream_executor/event.h"
28 #include "tensorflow/stream_executor/stream.h"
29 
30 namespace xla {
31 
SetSequencingEvent(EventPool::Handle event,se::Stream * stream)32 void BufferSequencingEvent::SetSequencingEvent(EventPool::Handle event,
33                                                se::Stream* stream) {
34   absl::MutexLock lock(&mu_);
35   CHECK(!event_.event());
36   event_ = std::move(event);
37   CHECK(streams_defined_on_.empty());
38   streams_defined_on_.push_back(stream);
39 }
40 
EventHasBeenRecorded() const41 bool BufferSequencingEvent::EventHasBeenRecorded() const {
42   return event_.event() != nullptr;
43 }
44 
sequence_number() const45 uint64 BufferSequencingEvent::sequence_number() const {
46   absl::MutexLock lock(&mu_);
47   CHECK(EventHasBeenRecorded());
48   return event_.sequence_number();
49 }
50 
WaitForEventOnStream(se::Stream * stream)51 void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) {
52   absl::MutexLock lock(&mu_);
53 
54   // We cannot wait for an event until ThenRecordEvent has been called; on GPU
55   // newly created events are deemed to have already happened past.
56   mu_.Await(
57       absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));
58 
59   // The set of defined streams is expected to be very small indeed (usually
60   // 1-2), so a simple linear scan should be fast enough.
61   if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
62                 stream) != streams_defined_on_.end()) {
63     // stream is in streams_defined_on_; it doesn't need to be waited on.
64     return;
65   }
66 
67   stream->ThenWaitFor(event_.event());
68   streams_defined_on_.push_back(stream);
69 }
70 
DefinedOn(se::Stream * stream)71 bool BufferSequencingEvent::DefinedOn(se::Stream* stream) {
72   absl::MutexLock lock(&mu_);
73 
74   // We cannot wait for an event until ThenRecordEvent has been called; on GPU
75   // newly created events are deemed to have already happened past.
76   mu_.Await(
77       absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));
78 
79   // The set of defined streams is expected to be very small indeed (usually
80   // 1-2), so a simple linear scan should be fast enough.
81   return std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
82                    stream) != streams_defined_on_.end();
83 }
84 
IsComplete()85 bool BufferSequencingEvent::IsComplete() {
86   absl::MutexLock lock(&mu_);
87 
88   // We cannot wait for an event until ThenRecordEvent has been called; on
89   // GPU newly created events are deemed to have already happened past.
90   mu_.Await(
91       absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));
92 
93   return event_.event()->PollForStatus() == se::Event::Status::kComplete;
94 }
95 
96 /* static */ std::shared_ptr<TrackedDeviceBuffer>
FromScopedShapedBuffer(ScopedShapedBuffer * shaped_buffer,absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events)97 TrackedDeviceBuffer::FromScopedShapedBuffer(
98     ScopedShapedBuffer* shaped_buffer,
99     absl::Span<const std::shared_ptr<BufferSequencingEvent>>
100         definition_events) {
101   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
102       shaped_buffer->buffers().begin();
103   std::vector<se::DeviceMemoryBase> buffers;
104   buffers.reserve(1);
105 
106   ShapeUtil::ForEachSubshape(
107       shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) {
108         CHECK(iterator != shaped_buffer->buffers().end());
109         buffers.push_back(iterator->second);
110         iterator->second = se::DeviceMemoryBase();
111         ++iterator;
112       });
113   CHECK(iterator == shaped_buffer->buffers().end());
114   return std::make_shared<TrackedDeviceBuffer>(
115       shaped_buffer->memory_allocator(), shaped_buffer->device_ordinal(),
116       absl::Span<se::DeviceMemoryBase>(buffers), definition_events,
117       /*on_delete_callback=*/nullptr);
118 }
119 
AsShapedBuffer(const Shape & on_device_shape) const120 ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
121     const Shape& on_device_shape) const {
122   ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
123   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
124       shaped_buffer.buffers().begin();
125   for (const se::DeviceMemoryBase& buf : device_memory_) {
126     CHECK(iterator != shaped_buffer.buffers().end());
127     iterator->second = buf;
128     ++iterator;
129   }
130   CHECK(iterator == shaped_buffer.buffers().end());
131   return shaped_buffer;
132 }
133 
134 // See comment on ExecutionInput in xla/service/executable.h to understand
135 // the meaning of owned/unowned in that class.
136 
AddToInputAsImmutable(ShapeTree<MaybeOwningDeviceMemory>::iterator * iterator,const ShapeTree<MaybeOwningDeviceMemory>::iterator & end) const137 void TrackedDeviceBuffer::AddToInputAsImmutable(
138     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
139     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const {
140   for (const se::DeviceMemoryBase& buf : device_memory_) {
141     CHECK(*iterator != end);
142     // Set buffers to be case (1) in the comment on ExecutionInput.
143     (*iterator)->second = MaybeOwningDeviceMemory(buf);
144     ++(*iterator);
145   }
146 }
147 
AddToInputAsDonated(ShapeTree<MaybeOwningDeviceMemory>::iterator * iterator,const ShapeTree<MaybeOwningDeviceMemory>::iterator & end,ExecutionInput * execution_input,se::DeviceMemoryAllocator * allocator) const148 void TrackedDeviceBuffer::AddToInputAsDonated(
149     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
150     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
151     ExecutionInput* execution_input,
152     se::DeviceMemoryAllocator* allocator) const {
153   for (const se::DeviceMemoryBase& buf : device_memory_) {
154     CHECK(*iterator != end);
155     // Set buffers to be case (2) in the comment on ExecutionInput.
156     (*iterator)->second = MaybeOwningDeviceMemory(
157         se::OwningDeviceMemory(buf, device_ordinal_, allocator));
158     execution_input->SetUnownedIndex((*iterator)->first);
159     ++(*iterator);
160   }
161 }
162 
TrackedDeviceBuffer(se::DeviceMemoryAllocator * allocator,int device_ordinal,absl::Span<se::DeviceMemoryBase const> device_memory,absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events,std::function<void ()> on_delete_callback)163 TrackedDeviceBuffer::TrackedDeviceBuffer(
164     se::DeviceMemoryAllocator* allocator, int device_ordinal,
165     absl::Span<se::DeviceMemoryBase const> device_memory,
166     absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events,
167     std::function<void()> on_delete_callback)
168     : allocator_(allocator),
169       device_ordinal_(device_ordinal),
170       device_memory_(device_memory.begin(), device_memory.end()),
171       definition_events_(std::make_move_iterator(definition_events.begin()),
172                          std::make_move_iterator(definition_events.end())),
173       in_use_(true),
174       on_delete_callback_(std::move(on_delete_callback)) {}
175 
~TrackedDeviceBuffer()176 TrackedDeviceBuffer::~TrackedDeviceBuffer() {
177   if (allocator_) {
178     for (const se::DeviceMemoryBase& buffer : device_memory_) {
179       Status status = allocator_->Deallocate(device_ordinal_, buffer);
180       if (!status.ok()) {
181         LOG(ERROR) << "Buffer deallocation failed: " << status;
182       }
183     }
184   }
185   if (on_delete_callback_) {
186     on_delete_callback_();
187   }
188 }
189 
AddUsageEvent(se::Stream * usage_stream,std::shared_ptr<BufferSequencingEvent> event,bool reference_held)190 void TrackedDeviceBuffer::AddUsageEvent(
191     se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
192     bool reference_held) {
193   CHECK(in_use_);
194 
195   for (auto& existing : usage_events_) {
196     if (existing.stream == usage_stream) {
197       if (*existing.event < *event) {
198         existing.event = event;
199         existing.reference_held = reference_held;
200       }
201       return;
202     }
203   }
204   usage_events_.push_back({usage_stream, event, reference_held});
205 }
206 
207 TrackedDeviceBuffer::StreamAndEventContainer
LockUseAndTransferUsageEvents()208 TrackedDeviceBuffer::LockUseAndTransferUsageEvents() {
209   CHECK(in_use_);
210   in_use_ = false;
211   return std::move(usage_events_);
212 }
213 
GetDeviceBufferEvents(const TrackedDeviceBuffer & buffer,bool get_usage_events,absl::flat_hash_set<BufferSequencingEvent * > * events)214 void GetDeviceBufferEvents(
215     const TrackedDeviceBuffer& buffer, bool get_usage_events,
216     absl::flat_hash_set<BufferSequencingEvent*>* events) {
217   if (get_usage_events) {
218     for (const auto& e : buffer.usage_events()) {
219       events->insert(e.event.get());
220     }
221   } else {
222     for (const auto& e : buffer.definition_events()) {
223       events->insert(e.get());
224     }
225   }
226 }
227 
WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer & buffer,se::Stream * stream)228 void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
229                                            se::Stream* stream) {
230   absl::flat_hash_set<BufferSequencingEvent*> events;
231   GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
232   for (BufferSequencingEvent* event : events) {
233     event->WaitForEventOnStream(stream);
234   }
235 }
236 
237 }  // namespace xla
238