1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/literal.h" 24 #include "tensorflow/compiler/xla/shape.h" 25 26 namespace xla { 27 namespace swig { 28 29 // Represents a reference to literals that live in a device-allocated buffer via 30 // XRT. Specifically, wraps an int64 handle produced by running the allocation 31 // graph, and an XLA shape to track the referent's shape. 32 class XrtAllocation { 33 public: 34 // Accepts a `session_target` argument, used in constructing the 35 // `tensorflow::ClientSession` instance in which allocation and deallocation 36 // graphs are run. 37 static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument, 38 const string& session_target); 39 40 XrtAllocation(int64 handle, Shape shape, const string& session_target); 41 ~XrtAllocation(); 42 StatusOr<Literal> ToLiteral() const; 43 const Shape& shape() const; 44 const int64 handle() const; 45 46 private: 47 const int64 handle_; 48 const Shape shape_; 49 const string session_target_; 50 }; 51 52 // Result of a tuple destructuring operation on an XrtAllocation. 53 class XrtAllocationTuple { 54 public: 55 // Note: any XrtAllocation elements that are not Release()'d will be 56 // deallocated in the destructor. 57 explicit XrtAllocationTuple(std::vector<XrtAllocation*> elements); 58 59 ~XrtAllocationTuple(); 60 61 // Releases the ith element to the caller. Further attempts to release the ith 62 // element will return an invalid argument error. 63 StatusOr<XrtAllocation*> Release(int i); 64 65 // Returns the number of elements in the destructured tuple. 66 int64 size() const; 67 68 private: 69 std::vector<XrtAllocation*> elements_; 70 }; 71 72 // Destructures a tuple-valued XrtAllocation into its constituent elements 73 // in XrtAllocationTuple form. 74 // 75 // Accepts a `session_target` argument, used in constructing the 76 // `tensorflow::ClientSession` instance in which the sub-tupling graph is run, 77 // and passed along in constructing each constituent XrtAllocation. 78 StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple( 79 XrtAllocation* allocation, const string& session_target); 80 81 // Represents a compiled computation that can be executed given handles to 82 // device-allocated literals. Specifically, wraps an XRT computation handle. 83 class XrtExecutable { 84 public: 85 // Accepts a `session_target` argument, used in constructing the 86 // `tensorflow::ClientSession` instance in which the compilation graph is run. 87 static StatusOr<XrtExecutable*> CompileForXrt( 88 const string& hlo_module_proto, const std::vector<Shape>& argument_shapes, 89 const Shape& result_shape, const string& session_target); 90 91 // Accepts a `session_target` argument, used in constructing the 92 // `tensorflow::ClientSession` instance in which the execution graph is run. 93 XrtExecutable(const ProgramShape& program_shape, int64 handle, 94 const string& session_target); 95 ~XrtExecutable(); 96 DeviceOrdinals()97 std::vector<int> DeviceOrdinals() const { return {0}; } 98 99 StatusOr<XrtAllocation*> Execute( 100 absl::Span<XrtAllocation* const> argument_handles); 101 102 const ProgramShape& program_shape() const; 103 int64 handle() const; 104 105 private: 106 const ProgramShape program_shape_; 107 const int64 handle_; 108 const string session_target_; 109 }; 110 111 // Functions for freeing resources from the Python side. 112 void DeleteXrtAllocation(XrtAllocation* allocation); 113 void DeleteXrtExecutable(XrtExecutable* computation); 114 115 } // namespace swig 116 } // namespace xla 117 118 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ 119