1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
18 
19 #include <memory>
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/client.h"
23 #include "tensorflow/compiler/xla/client/executable_build_options.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/compiler.h"
27 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
28 #include "tensorflow/compiler/xla/service/executable.h"
29 #include "tensorflow/compiler/xla/service/hlo.pb.h"
30 #include "tensorflow/compiler/xla/service/local_service.h"
31 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
35 
36 namespace xla {
37 
38 class LocalExecutable {
39  public:
40   // Run the compiled computation with the given arguments and options and
41   // return the result.
42   StatusOr<ScopedShapedBuffer> Run(
43       const absl::Span<const ShapedBuffer* const> arguments,
44       ExecutableRunOptions run_options);
45 
46   // Return the options used to build the executable.
build_options()47   const ExecutableBuildOptions& build_options() const { return build_options_; }
48 
49   // Return the built executable.
executable()50   Executable* executable() const { return executable_.get(); }
51 
52  private:
53   // Only a local client can construct these objects.
54   friend class LocalClient;
55 
56   // Constructor invoked by LocalClient.
57   LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
58                   ExecutableBuildOptions build_options);
59 
60   // Validates that the given arguments and options satisfy various constraints
61   // of the computation.
62   //
63   // The given ExecutableRunOptions override any values from TF_XLA_FLAGS
64   // environment variable.
65   Status ValidateExecutionOptions(
66       const absl::Span<const ShapedBuffer* const> arguments,
67       const ExecutableRunOptions& run_options, const Backend& backend);
68 
69   // Records the computation in a SessionModule proto with the arguments used to
70   // invoke it, and the result. Enabled by flag: --xla_dump_hlo_snapshots.
71   //
72   // The given ServiceExecutableRunOptions override any values from the
73   // XLA_FLAGS environment variable.
74   StatusOr<ScopedShapedBuffer> ExecuteAndDump(
75       const ServiceExecutableRunOptions* run_options,
76       const absl::Span<const ShapedBuffer* const> arguments);
77 
78   // Records the arguments used to invoke the computation in a SessionModule
79   // proto.
80   Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
81                          HloSnapshot* hlo_snapshot);
82 
83   // Records the result of the computation in a SessionModule proto.
84   Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
85 
86   // Returns a literal containing the contents of the given ShapedBuffer.
87   StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
88 
89   // The ordinal of the device which this executable was compiled for. The
90   // executable can run on all equivalent devices (as determined by
91   // Backend::devices_equivalent).
build_device_ordinal()92   int build_device_ordinal() const { return build_options_.device_ordinal(); }
93 
94   // Compiled computation.
95   std::unique_ptr<Executable> executable_;
96 
97   // Execution backend.
98   Backend* backend_ = nullptr;
99 
100   // Options used to build the executable.
101   const ExecutableBuildOptions build_options_;
102 };
103 
104 // An XLA Client specialization for use when the client and service run in
105 // the same process.
106 class LocalClient : public Client {
107  public:
LocalClient(LocalService * service)108   explicit LocalClient(LocalService* service)
109       : Client(service), local_service_(service) {}
110 
111   LocalClient(const LocalClient&) = delete;
112   void operator=(const LocalClient&) = delete;
113 
114   // Build and return a LocalExecutable object. The executable is compiled using
115   // the given XlaComputation, argument layouts and options.
116   //
117   // The given ExecutableBuildOptions overrides any values from XLA_FLAGS
118   // environment variable.
119   StatusOr<std::unique_ptr<LocalExecutable>> Compile(
120       const XlaComputation& computation,
121       const absl::Span<const Shape* const> argument_layouts,
122       const ExecutableBuildOptions& options);
123 
124   // Copy the literal data to the device with the given ordinal and return as a
125   // ScopedShapedBuffer. If non-null the given memory allocator is used for
126   // device memory allocation. If null, the default memory allocator for the
127   // device is used.
128   StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
129       const Literal& literal, int device_ordinal,
130       DeviceMemoryAllocator* allocator = nullptr);
131 
132   // Transfer the BorrowingLiteral to the device with the given ordinal.
133   StatusOr<TransferToServerResponse> TransferToLocalServer(
134       const ::xla::BorrowingLiteral& literal, int device_oridinal);
135 
136   // Copy the data from the device contained in the given ShapedBuffer and
137   // return as a Literal.
138   StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
139 
140   // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
141   // as long as the handle is valid.
142   StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
143       const GlobalDataHandle& data, int replica_number);
144 
145   // Transfer the given literal to the infeed queue of the given device.
146   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
147   // not inherit from Client and there is no possibility of confusion with
148   // Client::TransferToInfeed.
149   Status TransferToInfeedLocal(const Literal& literal, int device_ordinal);
150 
151   // Transfer and return a value of the given shape from the outfeed of the
152   // given device.
153   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
154   // not inherit from Client and there is no possibility of confusion with
155   // Client::TransferFromOutfeed.
156   StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
157                                              int device_ordinal);
158 
159   // Returns the device ordinal that corresponds to the given replica number.
160   //
161   // This returns an error if there is not a one-to-one correspondence of
162   // replicas to device ordinals, but is useful as a short term mechanism for
163   // the "easy" case where a single replica is a single device.
164   StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
165 
166   // Returns the platform that the underlying service targets.
167   se::Platform* platform() const;
168 
169   // Returns the number of devices on the system of the service platform
170   // type. Not all devices may be supported by the service (see
171   // device_ordinal_supported method).
172   int device_count() const;
173 
174   // Returns the default device ordinal that the service will run computations
175   // on if no device ordinal is specified in execute options.
176   int default_device_ordinal() const;
177 
178   // Returns whether the device with the given ordinal can be used by the
179   // service to execute computations. Not all devices of a particular platform
180   // may be usable by the service (eg, a GPU with insufficient CUDA compute
181   // capability).
182   bool device_ordinal_supported(int device_ordinal) const;
183 
184   // Returns the backend used to execute computations.
185   const Backend& backend() const;
186   Backend* mutable_backend();
187 
188  private:
189   LocalService* local_service_;
190 };
191 
192 }  // namespace xla
193 
194 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
195