1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/shape.h"
25 
26 namespace xla {
27 namespace swig {
28 
29 // Represents a reference to literals that live in a device-allocated buffer via
30 // XRT. Specifically, wraps an int64 handle produced by running the allocation
31 // graph, and an XLA shape to track the referent's shape.
32 class XrtAllocation {
33  public:
34   // Accepts a `session_target` argument, used in constructing the
35   // `tensorflow::ClientSession` instance in which allocation and deallocation
36   // graphs are run.
37   static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument,
38                                               const string& session_target);
39 
40   XrtAllocation(int64 handle, Shape shape, const string& session_target);
41   ~XrtAllocation();
42   StatusOr<Literal> ToLiteral() const;
43   const Shape& shape() const;
44   const int64 handle() const;
45 
46  private:
47   const int64 handle_;
48   const Shape shape_;
49   const string session_target_;
50 };
51 
52 // Result of a tuple destructuring operation on an XrtAllocation.
53 class XrtAllocationTuple {
54  public:
55   // Note: any XrtAllocation elements that are not Release()'d will be
56   // deallocated in the destructor.
57   explicit XrtAllocationTuple(std::vector<XrtAllocation*> elements);
58 
59   ~XrtAllocationTuple();
60 
61   // Releases the ith element to the caller. Further attempts to release the ith
62   // element will return an invalid argument error.
63   StatusOr<XrtAllocation*> Release(int i);
64 
65   // Returns the number of elements in the destructured tuple.
66   int64 size() const;
67 
68  private:
69   std::vector<XrtAllocation*> elements_;
70 };
71 
72 // Destructures a tuple-valued XrtAllocation into its constituent elements
73 // in XrtAllocationTuple form.
74 //
75 // Accepts a `session_target` argument, used in constructing the
76 // `tensorflow::ClientSession` instance in which the sub-tupling graph is run,
77 // and passed along in constructing each constituent XrtAllocation.
78 StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
79     XrtAllocation* allocation, const string& session_target);
80 
81 // Represents a compiled computation that can be executed given handles to
82 // device-allocated literals. Specifically, wraps an XRT computation handle.
83 class XrtExecutable {
84  public:
85   // Accepts a `session_target` argument, used in constructing the
86   // `tensorflow::ClientSession` instance in which the compilation graph is run.
87   static StatusOr<XrtExecutable*> CompileForXrt(
88       const string& hlo_module_proto, const std::vector<Shape>& argument_shapes,
89       const Shape& result_shape, const string& session_target);
90 
91   // Accepts a `session_target` argument, used in constructing the
92   // `tensorflow::ClientSession` instance in which the execution graph is run.
93   XrtExecutable(const ProgramShape& program_shape, int64 handle,
94                 const string& session_target);
95   ~XrtExecutable();
96 
DeviceOrdinals()97   std::vector<int> DeviceOrdinals() const { return {0}; }
98 
99   StatusOr<XrtAllocation*> Execute(
100       absl::Span<XrtAllocation* const> argument_handles);
101 
102   const ProgramShape& program_shape() const;
103   int64 handle() const;
104 
105  private:
106   const ProgramShape program_shape_;
107   const int64 handle_;
108   const string session_target_;
109 };
110 
111 // Functions for freeing resources from the Python side.
112 void DeleteXrtAllocation(XrtAllocation* allocation);
113 void DeleteXrtExecutable(XrtExecutable* computation);
114 
115 }  // namespace swig
116 }  // namespace xla
117 
118 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_
119