/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/xrt.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace swig {

XrtAllocation::XrtAllocation(int64 handle, Shape shape,
                             const string& session_target)
    : handle_(handle), shape_(shape), session_target_(session_target) {}

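// Releases the remote allocation when the wrapper is destroyed: builds a
// small graph that feeds the stored handle into XRTReleaseAllocationHandle
// and runs it against the configured session target. Failures are only
// logged, since a destructor cannot return a Status.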
XrtAllocation::~XrtAllocation() {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto allocation_handle =
      tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
  auto release =
      tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle);
  if (!root.status().ok()) {
    LOG(ERROR) << root.status();
    return;
  }

  tensorflow::ClientSession session(root, session_target_);
  tensorflow::ClientSession::FeedType inputs;
  inputs.insert({allocation_handle, handle()});
  std::vector<tensorflow::Tensor> outputs;
  auto status = session.Run(inputs, {}, {release}, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << status;
    return;
  }
}

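// Transfers `argument` to the XRT server: serializes the literal into an
// xrt::XLAAllocation proto, feeds it to an XRTAllocate op, and wraps the
// returned device handle in a new XrtAllocation owned by the caller.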
/* static */
StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(
    const Literal& argument, const string& session_target) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = argument.ToProto();

  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto literal_string =
      tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
  auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string);
  TF_RETURN_IF_ERROR(root.status());

  tensorflow::ClientSession session(root, session_target);
  tensorflow::ClientSession::FeedType inputs;
  inputs.insert({literal_string, alloc.SerializeAsString()});
  std::vector<tensorflow::Tensor> outputs;
  TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs));

  int64 handle = outputs[0].scalar<int64>()();
  return new XrtAllocation(handle, argument.shape(), session_target);
}

const int64 XrtAllocation::handle() const { return handle_; }

const Shape& XrtAllocation::shape() const { return shape_; }

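// Copies the allocation back to the client by running XRTReadLiteral on the
// stored handle and deserializing the resulting LiteralProto.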
StatusOr<Literal> XrtAllocation::ToLiteral() const {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto allocation_handle =
      tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
  auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle);
  TF_RETURN_IF_ERROR(root.status());

  tensorflow::ClientSession session(root, session_target_);
  tensorflow::ClientSession::FeedType inputs;
  inputs.insert({allocation_handle, handle()});
  std::vector<tensorflow::Tensor> outputs;
  TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs));

  xla::LiteralProto response;
  TF_RET_CHECK(response.ParseFromString(outputs[0].scalar<string>()()));
  return Literal::CreateFromProto(response);
}

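// An XrtAllocationTuple owns its element allocations; they are deleted in the
// destructor unless transferred to the caller via Release().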
XrtAllocationTuple::XrtAllocationTuple(std::vector<XrtAllocation*> elements)
    : elements_(std::move(elements)) {
  for (auto* element : elements_) {
    CHECK(element != nullptr);
  }
}

XrtAllocationTuple::~XrtAllocationTuple() {
  for (XrtAllocation* element : elements_) {
    if (element != nullptr) {
      delete element;
    }
  }
}

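// Transfers ownership of element `i` to the caller. Returns an error if the
// element was already released; a released slot is set to nullptr so the
// tuple's destructor will not touch it again.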
StatusOr<XrtAllocation*> XrtAllocationTuple::Release(int i) {
  XrtAllocation* element = elements_[i];
  if (element == nullptr) {
    return InvalidArgument("Attempted to release already-released element %d.",
                           i);
  }
  elements_[i] = nullptr;
  return element;
}

int64 XrtAllocationTuple::size() const { return elements_.size(); }

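// Compiles a serialized HloModuleProto for XRT execution. The argument and
// result shapes are assembled into a ProgramShape (with default layouts), the
// computation is packaged as an xrt::XLAComputation, and an XRTCompile op is
// run to obtain the compilation handle wrapped by the returned XrtExecutable.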
StatusOr<XrtExecutable*> XrtExecutable::CompileForXrt(
    const string& hlo_module_proto, const std::vector<Shape>& argument_shapes,
    const Shape& result_shape, const string& session_target) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
  auto compile = tensorflow::ops::XRTCompile(root, program);
  TF_RETURN_IF_ERROR(root.status());

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  ProgramShape program_shape;
  for (auto& shape : argument_shapes) {
    *program_shape.add_parameters() = shape;
  }
  *program_shape.mutable_result() = result_shape;

  LayoutUtil::SetToDefaultLayout(&program_shape);
  *config->mutable_program_shape() = program_shape.ToProto();
  c.mutable_hlo_snapshot()
      ->mutable_hlo()
      ->mutable_hlo_module()
      ->ParsePartialFromString(hlo_module_proto);

  tensorflow::ClientSession session(root, session_target);
  tensorflow::ClientSession::FeedType inputs;
  inputs.insert({program, c.SerializeAsString()});
  std::vector<tensorflow::Tensor> outputs;
  TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs));

  int64 handle = outputs[0].scalar<int64>()();
  return new XrtExecutable(program_shape, handle, session_target);
}

XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle,
                             const string& session_target)
    : program_shape_(program_shape),
      handle_(handle),
      session_target_(session_target) {}

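// Releases the remote compilation handle via XRTReleaseCompilationHandle,
// mirroring XrtAllocation's destructor; errors are logged rather than
// returned.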
XrtExecutable::~XrtExecutable() {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto computation_handle =
      tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
  auto release =
      tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle);
  if (!root.status().ok()) {
    LOG(ERROR) << root.status();
    return;
  }

  tensorflow::ClientSession session(root, session_target_);
  tensorflow::ClientSession::FeedType inputs;
  inputs.insert({computation_handle, handle()});
  std::vector<tensorflow::Tensor> outputs;
  auto status = session.Run(inputs, {}, {release}, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << status;
    return;
  }
}

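// Executes the compiled computation on arguments already resident on the XRT
// server and returns a new XrtAllocation owning the result handle. Neither
// the input handles nor the compilation handle are released by the run. A
// rough usage sketch (the literal, proto, result_shape, and target variables
// below are assumed to be provided by the caller):
//
//   TF_ASSIGN_OR_RETURN(XrtAllocation* arg,
//                       XrtAllocation::FromLiteral(literal, target));
//   TF_ASSIGN_OR_RETURN(XrtExecutable* exe,
//                       XrtExecutable::CompileForXrt(
//                           proto, {arg->shape()}, result_shape, target));
//   TF_ASSIGN_OR_RETURN(XrtAllocation* result, exe->Execute({arg}));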
StatusOr<XrtAllocation*> XrtExecutable::Execute(
    absl::Span<XrtAllocation* const> argument_handles) {
  const int num_expected_arguments = program_shape().parameters().size();

  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  std::vector<tensorflow::Output> arguments;
  arguments.reserve(num_expected_arguments);
  for (int i = 0; i < num_expected_arguments; ++i) {
    arguments.push_back(
        tensorflow::ops::Placeholder(root, tensorflow::DT_INT64));
  }
  auto computation_handle =
      tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
  auto execution_config =
      tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
  auto execute = tensorflow::ops::XRTExecute(root, computation_handle,
                                             execution_config, arguments);
  TF_RETURN_IF_ERROR(root.status());

  TF_RET_CHECK(argument_handles.size() == arguments.size());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);

  tensorflow::ClientSession session(root, session_target_);
  tensorflow::ClientSession::FeedType inputs;
  for (int i = 0; i < arguments.size(); ++i) {
    inputs.insert({arguments[i], argument_handles[i]->handle()});
  }
  inputs.insert({computation_handle, handle()});
  inputs.insert({execution_config, e.SerializeAsString()});
  std::vector<tensorflow::Tensor> outputs;
  TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs));

  int64 output = outputs[0].scalar<int64>()();
  return new XrtAllocation(output, program_shape().result(), session_target_);
}

const ProgramShape& XrtExecutable::program_shape() const {
  return program_shape_;
}

int64 XrtExecutable::handle() const { return handle_; }

void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; }

void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; }

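// Splits a tuple-shaped allocation into per-element allocations by running
// XRTSubTuple once per tuple element. On any failure the already-created
// element allocations are deleted before the error status is returned.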
StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
    XrtAllocation* allocation, const string& session_target) {
  const Shape& tuple_shape = allocation->shape();

  if (!tuple_shape.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure an XrtAllocation that did not have a tuple "
        "shape; shape: %s",
        ShapeUtil::HumanString(tuple_shape));
  }

  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
  auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32);
  auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index);
  TF_RETURN_IF_ERROR(root.status());

  tensorflow::ClientSession session(root, session_target);
  tensorflow::ClientSession::FeedType inputs;
  std::vector<XrtAllocation*> results;
  for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
    inputs.clear();
    inputs.insert({base_handle, allocation->handle()});
    inputs.insert({shape_index, {i}});
    std::vector<tensorflow::Tensor> outputs;
    auto status = session.Run(inputs, {subtuple}, &outputs);
    if (!status.ok()) {
      // Clean up before returning non-ok status.
      for (int j = 0; j < results.size(); ++j) {
        delete results[j];
      }
      return status;
    }
    const int64 subtuple_handle = outputs[0].scalar<int64>()();
    const Shape& subtuple_shape =
        ShapeUtil::GetTupleElementShape(tuple_shape, i);
    results.push_back(
        new XrtAllocation(subtuple_handle, subtuple_shape, session_target));
  }
  return new XrtAllocationTuple(std::move(results));
}

}  // namespace swig
}  // namespace xla