/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"

#include <utility>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace {

// Write the tuple index buffers (arrays of pointers).
static Status PopulateResultTupleBuffers(const ShapedBuffer& result,
                                         se::Stream* stream,
                                         se::Stream* transfer_stream) {
  TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(
                                                 stream->parent()->platform()));
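  // If the result buffers are already safe to access, write the index tables
  // asynchronously, preferring the dedicated transfer stream, and make the
  // compute stream wait for that write; otherwise enqueue the write on the
  // compute stream itself so it is ordered after the pending work.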
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     result)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
        transfer_stream ? transfer_stream : stream, result));
    if (transfer_stream && transfer_stream != stream) {
      stream->ThenWaitFor(transfer_stream);
    }
    return Status::OK();
  } else {
    return transfer_manager->WriteTupleIndexTablesAsync(stream, result);
  }
}

}  // namespace

StatusOr<ExecutionOutput>
TpuExecutableInterface::AllocateOutputMemoryWithInputReuse(
    const Shape& host_shape, const HloInputOutputAliasConfig& alias_config,
    se::DeviceMemoryAllocator* allocator,
    std::vector<ExecutionInput>* arguments, se::Stream* stream,
    se::Stream* transfer_stream) {
  auto stream_exec = stream->parent();
  auto device_ordinal = stream_exec->device_ordinal();
  VLOG(3) << "AllocateOutputMemoryWithInputReuse, device = " << device_ordinal
          << " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
  Shape device_shape = HostShapeToDeviceShape(host_shape);

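  // Verify that every output marked must-alias at compile time actually
  // received a donated (owned) input buffer from the caller at runtime.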
  TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          absl::optional<HloInputOutputAliasConfig::Alias> alias) {
        if (alias && alias->must_alias()) {
          VLOG(1) << alias->ToString();
          const MaybeOwningDeviceMemory& original_input =
              (*arguments)[alias->parameter_number].Buffers().element(
                  alias->parameter_index);
          if (!original_input.HasOwnership()) {
            return InvalidArgument(
                "An input was configured to be must-alias at "
                "compile time but not donated at runtime: %s",
                alias->ToString());
          }
        }
        return Status::OK();
      }));

  if (VLOG_IS_ON(3)) {
    VLOG(3) << "AllocateOutputMemoryWithInputReuse, device = " << device_ordinal
            << " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
    if (!Shape::Equal().MinorToMajorOnlyInLayout()(host_shape, device_shape)) {
      VLOG(3) << "Rewrote host_shape to device_shape: "
              << ShapeUtil::HumanStringWithLayout(host_shape) << " -> "
              << ShapeUtil::HumanStringWithLayout(device_shape);
    }
  }

  ExecutionOutput result(std::move(device_shape), allocator, device_ordinal);
  // Iterate through and allocate a buffer for each shape index, checking for
  // possible input buffer reuse.
  int64 reused_buffer_bytes = 0;
  int64 total_result_buffer_bytes = 0;
  for (auto& pair : result.MutableResult()->buffers()) {
    const ShapeIndex& result_index = pair.first;
    se::DeviceMemoryBase& result_buffer = pair.second;
    int64 allocation_bytes = ShapeSize(ShapeUtil::GetSubshape(
        result.Result().on_device_shape(), result_index));
    total_result_buffer_bytes += allocation_bytes;

    // Return an InternalError if result_index is invalid. This avoids failing
    // the CHECK when calling GetAliasedParameter.
    if (!ShapeUtil::IndexIsValid(alias_config.shape(), result_index)) {
      return InternalError("result_index is invalid: %s",
                           result_index.ToString());
    }

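    // Check whether this output position may alias an input parameter; if the
    // caller donated that parameter's buffer, reuse it as the output buffer.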
    absl::optional<HloInputOutputAliasConfig::Alias> alias =
        alias_config.GetAliasedParameter(result_index);
    if (alias) {
      TF_RET_CHECK(alias->parameter_number < arguments->size());
      ExecutionInput& input = (*arguments)[alias->parameter_number];
      MaybeOwningDeviceMemory* device_memory =
          input.MutableBuffer(alias->parameter_index);
      if (auto owning = device_memory->Release()) {
        // If the caller passes ownership of the device memory, reuse it as
        // the output buffer. It is up to the caller whether or not to donate
        // a buffer; the aliasing information describes which buffers may
        // alias, not which buffers must alias.
        se::DeviceMemoryBase device_memory_base = owning->Release();
        *device_memory = device_memory_base;
        result_buffer = device_memory_base;
        reused_buffer_bytes += allocation_bytes;
        // The caller is giving us the input buffer, but if the execute call
        // fails we should not release it, as it still contains valid data
        // (for example, a parameter the user wants us to alias in a gradient
        // update computation). So we store the index into the aliased vector,
        // which is fed to the ExecutionOutput; if the ExecutionOutput is not
        // committed, it uses these indices to drop the addresses from its own
        // ScopedShapedBuffer result.
        result.AddAliasedIndex(result_index);
      }
    }

    // We need to allocate a new output buffer in two cases:
    // - There is no alias between this output and any input.
    // - There is an alias, but XLA doesn't own the input memory, so it cannot
    //   donate the buffer to the computation.
    if (result_buffer.is_null()) {
      const Shape& on_device_shape = result.Result().on_device_shape();
      const Shape& on_device_subshape =
          ShapeUtil::GetSubshape(on_device_shape, result_index);
      TF_ASSIGN_OR_RETURN(
          auto allocated_buffer,
          allocator->Allocate(device_ordinal, allocation_bytes,
                              /*retry_on_failure=*/true,
                              on_device_subshape.layout().memory_space()));
      // Store the allocated buffer in our ScopedShapedBuffer, which takes
      // ownership.
      result_buffer = allocated_buffer.Release();
    }
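    // Every non-empty output slot must now hold a valid address, either
    // reused from a donated input or freshly allocated above.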
    TF_RET_CHECK(allocation_bytes == 0 || result_buffer != nullptr);
  }

  VLOG(1) << "Reused " << reused_buffer_bytes
          << " parameter buffers (total result buffer size: "
          << total_result_buffer_bytes << ")";

  TF_RETURN_IF_ERROR(
      PopulateResultTupleBuffers(result.Result(), stream, transfer_stream));
  return std::move(result);
}

StatusOr<ExecutionOutput> TpuExecutableInterface::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* /*hlo_execution_profile*/) {
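  // Gather the raw device addresses of the arguments' root buffers; they are
  // passed to LoadProgramAndEnqueueToStream when the program is launched.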
  std::vector<se::DeviceMemoryBase> memory_bases;
  memory_bases.reserve(arguments.size());
  for (auto& argument : arguments) {
    memory_bases.push_back(argument.Buffer({}).AsDeviceMemoryBase());
  }
  se::Stream* stream = run_options->stream();

  CHECK_NE(run_options->allocator(), nullptr);
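  // When no HLO module is attached, fall back to a nil result shape and an
  // empty input/output aliasing config.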
  const Shape& shape =
      hlo_module_ == nullptr ? ShapeUtil::MakeNil() : result_shape();
  const HloInputOutputAliasConfig& alias_config =
      hlo_module_ == nullptr ? HloInputOutputAliasConfig()
                             : hlo_module_->input_output_alias_config();
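  // Allocate output buffers, reusing donated input buffers wherever the
  // aliasing config permits.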
  TF_ASSIGN_OR_RETURN(
      ExecutionOutput result,
      AllocateOutputMemoryWithInputReuse(
          shape, alias_config, run_options->allocator(), &arguments, stream,
          run_options->run_options().host_to_device_stream()));

  // Address of the buffer in TPU memory that is being speculatively
  // prefetched.
  absl::optional<se::DeviceMemoryBase> cross_program_prefetch_addr;
  if (hlo_module_) {
    for (const auto& prefetch : hlo_module_->CrossProgramPrefetches()) {
      const auto& parameter = prefetch.first;
      const auto& index = prefetch.second;
      CHECK_LT(parameter, arguments.size());
      // Ensure the cross-program prefetched buffer doesn't alias with any
      // program outputs. If an input and an output alias, the buffer could be
      // invalidated during program execution and the program could read stale
      // data from fast memory instead of fresh data in large memory.
      auto it = arguments[parameter].MutableBuffers()->find({index});
      CHECK(it != arguments[parameter].MutableBuffers()->end());
      CHECK(!it->second.AsDeviceMemoryBase().is_null());
      if (absl::c_none_of(result.Result().buffers(), [&](auto index_addr_pair) {
            return index_addr_pair.second.IsSameAs(
                it->second.AsDeviceMemoryBase());
          })) {
        // Only one cross-program prefetch address is supported.
        cross_program_prefetch_addr = it->second.AsDeviceMemoryBase();
      }
    }
  }

  // MarkToBeReleasedArguments may std::move some elements of arguments, so it
  // must run after the cross-program prefetch address is calculated from the
  // arguments.
  MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);

  TF_RETURN_IF_ERROR(LoadProgramAndEnqueueToStream(
      *run_options, memory_bases, result.Result().root_buffer(),
      cross_program_prefetch_addr));
  return std::move(result);
}

}  // namespace xla