1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/xrt.h"
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/cc/client/client_session.h"
24 #include "tensorflow/cc/framework/ops.h"
25 #include "tensorflow/cc/framework/scope.h"
26 #include "tensorflow/cc/ops/standard_ops.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/service/hlo.pb.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
35 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
36 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
37 #include "tensorflow/compiler/xrt/xrt.pb.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/gtl/array_slice.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42 #include "tensorflow/core/platform/types.h"
43 
44 namespace xla {
45 namespace swig {
46 
XrtAllocation(int64 handle,Shape shape,const string & session_target)47 XrtAllocation::XrtAllocation(int64 handle, Shape shape,
48                              const string& session_target)
49     : handle_(handle), shape_(shape), session_target_(session_target) {}
50 
~XrtAllocation()51 XrtAllocation::~XrtAllocation() {
52   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
53   auto allocation_handle =
54       tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
55   auto release =
56       tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle);
57   if (!root.status().ok()) {
58     LOG(ERROR) << root.status();
59     return;
60   }
61 
62   tensorflow::ClientSession session(root, session_target_);
63   tensorflow::ClientSession::FeedType inputs;
64   inputs.insert({allocation_handle, handle()});
65   std::vector<tensorflow::Tensor> outputs;
66   auto status = session.Run(inputs, {}, {release}, &outputs);
67   if (!status.ok()) {
68     LOG(ERROR) << status;
69     return;
70   }
71 }
72 
73 /* static */
FromLiteral(const Literal & argument,const string & session_target)74 StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(
75     const Literal& argument, const string& session_target) {
76   xrt::XLAAllocation alloc;
77   *alloc.mutable_value() = argument.ToProto();
78 
79   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
80   auto literal_string =
81       tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
82   auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string);
83   TF_RETURN_IF_ERROR(root.status());
84 
85   tensorflow::ClientSession session(root, session_target);
86   tensorflow::ClientSession::FeedType inputs;
87   inputs.insert({literal_string, alloc.SerializeAsString()});
88   std::vector<tensorflow::Tensor> outputs;
89   TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs));
90 
91   int64 handle = outputs[0].scalar<int64>()();
92   return new XrtAllocation(handle, argument.shape(), session_target);
93 }
94 
handle() const95 const int64 XrtAllocation::handle() const { return handle_; }
96 
shape() const97 const Shape& XrtAllocation::shape() const { return shape_; }
98 
ToLiteral() const99 StatusOr<Literal> XrtAllocation::ToLiteral() const {
100   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
101   auto allocation_handle =
102       tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
103   auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle);
104   TF_RETURN_IF_ERROR(root.status());
105 
106   tensorflow::ClientSession session(root, session_target_);
107   tensorflow::ClientSession::FeedType inputs;
108   inputs.insert({allocation_handle, handle()});
109   std::vector<tensorflow::Tensor> outputs;
110   TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs));
111 
112   xla::LiteralProto response;
113   TF_RET_CHECK(response.ParseFromString(outputs[0].scalar<string>()()));
114   return Literal::CreateFromProto(response);
115 }
116 
XrtAllocationTuple(std::vector<XrtAllocation * > elements)117 XrtAllocationTuple::XrtAllocationTuple(std::vector<XrtAllocation*> elements)
118     : elements_(std::move(elements)) {
119   for (auto* element : elements_) {
120     CHECK(element != nullptr);
121   }
122 }
123 
~XrtAllocationTuple()124 XrtAllocationTuple::~XrtAllocationTuple() {
125   for (XrtAllocation* element : elements_) {
126     if (element != nullptr) {
127       delete element;
128     }
129   }
130 }
131 
Release(int i)132 StatusOr<XrtAllocation*> XrtAllocationTuple::Release(int i) {
133   XrtAllocation* element = elements_[i];
134   if (element == nullptr) {
135     return InvalidArgument("Attempted to release already-released element %d.",
136                            i);
137   }
138   elements_[i] = nullptr;
139   return element;
140 }
141 
size() const142 int64 XrtAllocationTuple::size() const { return elements_.size(); }
143 
CompileForXrt(const string & hlo_module_proto,const std::vector<Shape> & argument_shapes,const Shape & result_shape,const string & session_target)144 StatusOr<XrtExecutable*> XrtExecutable::CompileForXrt(
145     const string& hlo_module_proto, const std::vector<Shape>& argument_shapes,
146     const Shape& result_shape, const string& session_target) {
147   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
148   auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
149   auto compile = tensorflow::ops::XRTCompile(root, program);
150   TF_RETURN_IF_ERROR(root.status());
151 
152   xrt::XLAComputation c;
153   auto config = c.mutable_config();
154   ProgramShape program_shape;
155   for (auto& shape : argument_shapes) {
156     *program_shape.add_parameters() = shape;
157   }
158   *program_shape.mutable_result() = result_shape;
159 
160   LayoutUtil::SetToDefaultLayout(&program_shape);
161   *config->mutable_program_shape() = program_shape.ToProto();
162   c.mutable_hlo_snapshot()
163       ->mutable_hlo()
164       ->mutable_hlo_module()
165       ->ParsePartialFromString(hlo_module_proto);
166 
167   tensorflow::ClientSession session(root, session_target);
168   tensorflow::ClientSession::FeedType inputs;
169   inputs.insert({program, c.SerializeAsString()});
170   std::vector<tensorflow::Tensor> outputs;
171   TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs));
172 
173   int64 handle = outputs[0].scalar<int64>()();
174   return new XrtExecutable(program_shape, handle, session_target);
175 }
176 
XrtExecutable(const ProgramShape & program_shape,int64 handle,const string & session_target)177 XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle,
178                              const string& session_target)
179     : program_shape_(program_shape),
180       handle_(handle),
181       session_target_(session_target) {}
182 
~XrtExecutable()183 XrtExecutable::~XrtExecutable() {
184   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
185   auto computation_handle =
186       tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
187   auto release =
188       tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle);
189   if (!root.status().ok()) {
190     LOG(ERROR) << root.status();
191     return;
192   }
193 
194   tensorflow::ClientSession session(root, session_target_);
195   tensorflow::ClientSession::FeedType inputs;
196   inputs.insert({computation_handle, handle()});
197   std::vector<tensorflow::Tensor> outputs;
198   auto status = session.Run(inputs, {}, {release}, &outputs);
199   if (!status.ok()) {
200     LOG(ERROR) << status;
201     return;
202   }
203 }
204 
Execute(absl::Span<XrtAllocation * const> argument_handles)205 StatusOr<XrtAllocation*> XrtExecutable::Execute(
206     absl::Span<XrtAllocation* const> argument_handles) {
207   const int num_expected_arguments = program_shape().parameters().size();
208 
209   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
210   std::vector<tensorflow::Output> arguments;
211   arguments.reserve(num_expected_arguments);
212   for (int i = 0; i < num_expected_arguments; ++i) {
213     arguments.push_back(
214         tensorflow::ops::Placeholder(root, tensorflow::DT_INT64));
215   }
216   auto computation_handle =
217       tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
218   auto execution_config =
219       tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
220   auto execute = tensorflow::ops::XRTExecute(root, computation_handle,
221                                              execution_config, arguments);
222   TF_RETURN_IF_ERROR(root.status());
223 
224   TF_RET_CHECK(argument_handles.size() == arguments.size());
225 
226   xrt::XRTExecutionConfig e;
227   e.set_release_input_handles(false);
228   e.set_release_compilation_handle(false);
229 
230   tensorflow::ClientSession session(root, session_target_);
231   tensorflow::ClientSession::FeedType inputs;
232   for (int i = 0; i < arguments.size(); ++i) {
233     inputs.insert({arguments[i], argument_handles[i]->handle()});
234   }
235   inputs.insert({computation_handle, handle()});
236   inputs.insert({execution_config, e.SerializeAsString()});
237   std::vector<tensorflow::Tensor> outputs;
238   TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs));
239 
240   int64 output = outputs[0].scalar<int64>()();
241   return new XrtAllocation(output, program_shape().result(), session_target_);
242 }
243 
program_shape() const244 const ProgramShape& XrtExecutable::program_shape() const {
245   return program_shape_;
246 }
247 
handle() const248 int64 XrtExecutable::handle() const { return handle_; }
249 
DeleteXrtAllocation(XrtAllocation * allocation)250 void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; }
251 
DeleteXrtExecutable(XrtExecutable * computation)252 void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; }
253 
DestructureXrtAllocationTuple(XrtAllocation * allocation,const string & session_target)254 StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
255     XrtAllocation* allocation, const string& session_target) {
256   const Shape& tuple_shape = allocation->shape();
257 
258   if (!tuple_shape.IsTuple()) {
259     return InvalidArgument(
260         "Attemped to destructure a LocalShapedBuffer that did not have a tuple "
261         "shape; shape: %s",
262         ShapeUtil::HumanString(tuple_shape));
263   }
264 
265   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
266   auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
267   auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32);
268   auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index);
269   TF_RETURN_IF_ERROR(root.status());
270 
271   tensorflow::ClientSession session(root, session_target);
272   tensorflow::ClientSession::FeedType inputs;
273   std::vector<XrtAllocation*> results;
274   for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
275     inputs.clear();
276     inputs.insert({base_handle, allocation->handle()});
277     inputs.insert({shape_index, {i}});
278     std::vector<tensorflow::Tensor> outputs;
279     auto status = session.Run(inputs, {subtuple}, &outputs);
280     if (!status.ok()) {
281       // Clean up before returning non-ok status.
282       for (int j = 0; j < results.size(); ++j) {
283         delete results[j];
284       }
285       return status;
286     }
287     const int64 subtuple_handle = outputs[0].scalar<int64>()();
288     const Shape& subtuple_shape =
289         ShapeUtil::GetTupleElementShape(tuple_shape, i);
290     results.push_back(
291         new XrtAllocation(subtuple_handle, subtuple_shape, session_target));
292   }
293   return new XrtAllocationTuple(std::move(results));
294 }
295 
296 }  // namespace swig
297 }  // namespace xla
298