1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include <Python.h>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/client/executable_build_options.h"
26 #include "tensorflow/compiler/xla/client/local_client.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/client/xla_computation.h"
29 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 namespace swig {
34 
35 // Registers a 'fn_capsule' as a CPU custom call target.
36 // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
37 // "xla._CPU_CUSTOM_CALL_TARGET".
38 Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule);
39 
40 // Wrapper around an xla::LocalClient.
41 class LocalClient {
42  public:
43   // Initializes a local XLA client for `platform_name`. Returns an error if no
44   /// such platform exists, or if the platform has no visible devices.
45   static StatusOr<LocalClient> Get(const string& platform_name);
46 
47   // Copyable and moveable; the class is just a wrapper around a
48   // xla::LocalClient pointer for convenient SWIG wrapping.
49 
50   // Returns the number of devices known to the XLA client.
51   int DeviceCount() const;
52 
53   // Wraps the local client's infeed-transfer function.
54   //
55   // The default device ordinal (0) is used.
56   Status TransferToInfeed(const Literal& literal, int device_ordinal);
57 
58   // Transfers a literal of the given shape from the outfeed of the given
59   // replica.
60   StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
61 
client()62   xla::LocalClient* client() const { return client_; }
63 
64  private:
65   LocalClient(xla::LocalClient* client);
66 
67   xla::LocalClient* client_;
68 };
69 
70 class LocalShapedBufferTuple;
71 
72 // Represents a reference to literals that live in a device-allocated buffer via
73 // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a
74 // literal to device via the local client.
75 class LocalShapedBuffer {
76  public:
77   static StatusOr<LocalShapedBuffer*> FromLiteral(
78       const Literal& argument, const absl::optional<Shape>& shape_with_layout,
79       const LocalClient& client, int device_ordinal);
80 
81   LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client);
82   StatusOr<Literal> ToLiteral() const;
83   const Shape& shape() const;
84   const ScopedShapedBuffer* shaped_buffer() const;
85 
86   // Transfers ownership of the encapsulated ShapedBuffer to the caller,
87   // analogous to std::unique_ptr::release().
88   ShapedBuffer Release();
89 
90   // Destructures a tuple-valued LocalShapedBuffer into its constituent
91   // elements in LocalShapedBufferTuple form.
92   StatusOr<LocalShapedBufferTuple*> DestructureTuple();
93 
94  private:
95   ScopedShapedBuffer shaped_buffer_;
96   xla::LocalClient* client_;
97 };
98 
99 // Result of a tuple destructuring operation on a LocalShapedBuffer -- this
100 // appears to be a simpler mechanism for the time being than an alternative like
101 // using SWIG to transform std::vectors into Python lists of SWIG objects
102 // directly.
103 class LocalShapedBufferTuple {
104  public:
105   // Note: any LocalShapedBuffer elements that are not Release()'d will be
106   // deallocated in the destructor.
107   explicit LocalShapedBufferTuple(std::vector<LocalShapedBuffer*> elements);
108 
109   ~LocalShapedBufferTuple();
110 
111   // Releases the ith element to the caller. Further attempts to release the ith
112   // element will return an invalid argument error.
113   StatusOr<LocalShapedBuffer*> Release(int i);
114 
115   // Returns the number of elements in the destructured tuple.
116   int64 size() const;
117 
118  private:
119   std::vector<LocalShapedBuffer*> elements_;
120 };
121 
122 // Represents a compiled computation that can be executed given handles to
123 // device-allocated literals. Specifically, wraps an XLA LocalExecutable.
124 class LocalExecutable {
125  public:
126   LocalExecutable(std::unique_ptr<xla::LocalExecutable> executable,
127                   xla::DeviceAssignment device_assignment,
128                   xla::LocalClient* client);
129 
num_replicas()130   int num_replicas() const {
131     return executable_->build_options().num_replicas();
132   }
133 
134   // Returns the device ordinals to which each replica is assigned.
135   std::vector<int> DeviceOrdinals() const;
136 
137   StatusOr<LocalShapedBuffer*> Execute(
138       absl::Span<LocalShapedBuffer* const> argument_handles);
139 
140   // Execute on many replicas. Takes a sequence of argument lists (one argument
141   // list per replica) and returns a tuple of results (one result per replica).
142   // The number of argument lists must be equal to the replica count.
143   StatusOr<LocalShapedBufferTuple*> ExecutePerReplica(
144       absl::Span<const std::vector<LocalShapedBuffer*> > argument_handles);
145 
146  private:
147   const std::unique_ptr<xla::LocalExecutable> executable_;
148   const xla::DeviceAssignment device_assignment_;
149   xla::LocalClient* const client_;
150 };
151 
152 // Wraps a XlaComputation produced by a ComputationBuilder. The
153 // Compile method compiles the computation to a (local) executable via
154 // the client library's local client. This class is intended to be
155 // made available to Python via SWIG.
156 class Computation {
157  public:
158   Computation(XlaComputation computation);
159 
160   StatusOr<LocalExecutable*> Compile(
161       const std::vector<Shape>& argument_shapes,
162       const ExecutableBuildOptions* build_options, const LocalClient& client);
163 
164   const XlaComputation& computation() const;
165 
166   // Returns the HloModuleProto contained in the XlaComputation in the
167   // serialized binary format. Logs an internal error and returns an empty
168   // string on failure.
169   string GetSerializedProto() const;
170 
171   // Returns the computation in human-readable HLO text format.
172   StatusOr<string> GetHloText() const;
173 
174   // Returns the computation in graphviz dot format.
175   StatusOr<string> GetHloDotGraph() const;
176 
177   // Returns the program shape for this computation.
178   StatusOr<ProgramShape> GetProgramShape() const;
179 
180   // Returns the return-value shape for this computation.
181   StatusOr<Shape> GetReturnValueShape() const;
182 
183  private:
184   XlaComputation computation_;
185 };
186 
187 // Wraps a XlaOp produced by a ComputationBuilder. This class is intended
188 // to be made available to Python via SWIG.
189 class LocalOp {
190  public:
191   LocalOp(const XlaOp& op);
192 
193   const XlaOp& op() const;
194 
195  private:
196   XlaOp op_;
197 };
198 
199 // Wraps the ComputationBuilder API in order to:
200 // - Support consumption by SWIG in order to be made available to
201 //   Python.
202 // - Set up the underlying builder to use the client library's
203 //   LocalClient.
204 // - Wrap Computations in Computations for Python access.
205 // - Correspondingly unwrap incoming Computations.
206 class ComputationBuilder {
207  public:
208   ComputationBuilder(const string& computation_name);
209 
210   void SetOpMetadata(const OpMetadata& metadata);
211   void ClearOpMetadata();
212 
213   // Returns an owned Computation to the caller on success.
214   StatusOr<Computation*> Build();
215 
216   // Returns an owned Computation to the caller on success with given root.
217   StatusOr<Computation*> BuildWithRoot(const LocalOp& root);
218 
219   LocalOp Parameter(int64 parameter_number, const Shape& shape,
220                     const string& name);
221 
222   StatusOr<Shape> GetShape(const LocalOp& operand);
223 
224   // Returns the shape of the current return value for the computation.
225   StatusOr<Shape> GetReturnValueShape();
226 
227   LocalOp ReplicaId();
228 
229   LocalOp Infeed(const Shape& shape);
230 
231   void Outfeed(const LocalOp& operand, const Shape& shape,
232                const string& outfeed_config);
233 
234   LocalOp ConstantLiteral(const Literal& literal);
235 
236   LocalOp Iota(PrimitiveType element_type, int64 size);
237 
238   LocalOp BroadcastedIota(const Shape& shape, int64 dimension);
239 
240   LocalOp Broadcast(const LocalOp& operand,
241                     absl::Span<const int64> broadcast_sizes);
242 
243   LocalOp BroadcastInDim(const LocalOp& operand,
244                          absl::Span<const int64> out_dim_sizes,
245                          absl::Span<const int64> broadcast_dimensions);
246 
247   LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value,
248               const PaddingConfig& padding_config);
249 
250   LocalOp Reshape(const LocalOp& operand, absl::Span<const int64> dimensions,
251                   absl::Span<const int64> new_sizes);
252 
253   LocalOp Collapse(const LocalOp& operand, absl::Span<const int64> dimensions);
254 
255   LocalOp AllToAll(const LocalOp& operand, int64 split_dimension,
256                    int64 concat_dimension, int64 split_count,
257                    absl::Span<const ReplicaGroup> replica_groups);
258 
259   LocalOp CrossReplicaSum(const LocalOp& operand,
260                           absl::Span<const ReplicaGroup> replica_groups);
261 
262   LocalOp Slice(const LocalOp& operand, absl::Span<const int64> start_indices,
263                 absl::Span<const int64> limit_indices,
264                 absl::Span<const int64> strides);
265 
266   LocalOp SliceInDim(const LocalOp& operand, int64 start_index,
267                      int64 limit_index, int64 stride, int64 dimno);
268 
269   LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices,
270                        absl::Span<const int64> slice_sizes);
271 
272   LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update,
273                              const LocalOp& start_indices);
274 
275   LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);
276 
277   LocalOp SelectAndScatterWithGeneralPadding(
278       const LocalOp& operand, const Computation& select,
279       absl::Span<const int64> window_dimensions,
280       absl::Span<const int64> window_strides,
281       absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
282       const LocalOp& init_value, const Computation& scatter);
283 
284   LocalOp Tuple(absl::Span<const LocalOp> elements);
285 
286   LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index);
287 
288   LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs);
289 
290   LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs,
291                      const DotDimensionNumbers& dimension_numbers);
292 
293   LocalOp ConvGeneralDilated(
294       const LocalOp& lhs, const LocalOp& rhs,
295       absl::Span<const int64> window_strides,
296       absl::Span<const std::pair<int64, int64> > padding,
297       absl::Span<const int64> lhs_dilation,
298       absl::Span<const int64> rhs_dilation,
299       const ConvolutionDimensionNumbers& dimension_numbers,
300       int64 feature_group_count);
301 
302   LocalOp ConvertElementType(const LocalOp& operand,
303                              PrimitiveType new_element_type);
304 
305   LocalOp BitcastConvertType(const LocalOp& operand,
306                              PrimitiveType new_element_type);
307 
308   LocalOp Call(const Computation& local_computation,
309                absl::Span<const LocalOp> operands);
310 
311   LocalOp CustomCall(const string& call_target_name,
312                      absl::Span<const LocalOp> operands,
313                      const Shape& shape_with_layout,
314                      const std::vector<Shape>& operand_shapes_with_layout,
315                      const string& opaque);
316 
317   LocalOp Transpose(const LocalOp& operand,
318                     absl::Span<const int64> permutation);
319 
320   LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);
321 
322   LocalOp Map(absl::Span<const LocalOp> operands,
323               const Computation& local_computation,
324               absl::Span<const int64> dimensions);
325 
326   LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
327                  const Computation& local_computation,
328                  absl::Span<const int64> dimensions_to_reduce);
329 
330   LocalOp ReduceWindowWithGeneralPadding(
331       const LocalOp& operand, const LocalOp& init_value,
332       const Computation& local_computation,
333       absl::Span<const int64> window_dimensions,
334       absl::Span<const int64> window_strides,
335       absl::Span<const int64> base_dilations,
336       absl::Span<const int64> window_dilations,
337       absl::Span<const std::pair<int64, int64> > padding);
338 
339   LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
340                     const Shape& shape);
341 
342   LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape);
343 
344   LocalOp While(const Computation& condition, const Computation& body,
345                 const LocalOp& init);
346 
347   LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand,
348                       const Computation& true_computation,
349                       const LocalOp& false_operand,
350                       const Computation& false_computation);
351 
352   StatusOr<bool> IsConstant(const LocalOp& operand);
353 
354   LocalOp Sort(const LocalOp& operand, int64 dimension);
355 
356   LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
357                      int64 dimension);
358 
359   LocalOp QR(const LocalOp& a, bool full_matrices);
360 
361   LocalOp Cholesky(const LocalOp& a, bool lower);
362 
363   LocalOp Eigh(const LocalOp& a, bool lower);
364 
365   LocalOp SVD(const LocalOp& a);
366 
367   // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose
368   // enum. We use an integer here so we don't have to teach SWIG about the
369   // enum.
370   LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side,
371                           bool lower, bool unit_diagonal, int transpose_a);
372 
373   LocalOp Gather(const LocalOp& input, const LocalOp& start_indices,
374                  const GatherDimensionNumbers& dimension_numbers,
375                  absl::Span<const int64> slice_sizes);
376 
377   LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices,
378                   const LocalOp& updates, const Computation& update_computation,
379                   const ScatterDimensionNumbers& dimension_numbers);
380 
381   StatusOr<Computation*> BuildConstantSubGraph(const LocalOp& operand);
382 
383 #define _FORWARD(method_name, return_sig, args_sig) \
384   return_sig method_name args_sig;
385 
386 #define _FORWARD_UNOP(method_name) \
387   _FORWARD(method_name, LocalOp, (const LocalOp& operand))
388 
389 #define _FORWARD_BINOP(method_name)                 \
390   _FORWARD(method_name, LocalOp,                    \
391            (const LocalOp& lhs, const LocalOp& rhs, \
392             absl::Span<const int64> broadcast_dimensions))
393 
394 #define _FORWARD_TRIOP(method_name) \
395   _FORWARD(method_name, LocalOp,    \
396            (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs))
397 
398   _FORWARD_TRIOP(Select)
399   _FORWARD_TRIOP(Clamp)
400   _FORWARD_BINOP(Eq)
401   _FORWARD_BINOP(Ne)
402   _FORWARD_BINOP(Ge)
403   _FORWARD_BINOP(Gt)
404   _FORWARD_BINOP(Lt)
405   _FORWARD_BINOP(Le)
406   _FORWARD_BINOP(Add)
407   _FORWARD_BINOP(Sub)
408   _FORWARD_BINOP(Mul)
409   _FORWARD_BINOP(Div)
410   _FORWARD_BINOP(Rem)
411   _FORWARD_BINOP(Max)
412   _FORWARD_BINOP(Min)
413   _FORWARD_BINOP(And)
414   _FORWARD_BINOP(Or)
415   _FORWARD_BINOP(Xor)
416   _FORWARD_BINOP(ShiftLeft)
417   _FORWARD_BINOP(ShiftRightArithmetic)
418   _FORWARD_BINOP(ShiftRightLogical)
419   _FORWARD_BINOP(Atan2)
420   _FORWARD_BINOP(Pow)
421   _FORWARD_BINOP(Complex)
422   _FORWARD_UNOP(Not)
423   _FORWARD_UNOP(Clz)
424   _FORWARD_UNOP(Abs)
425   _FORWARD_UNOP(Exp)
426   _FORWARD_UNOP(Expm1)
427   _FORWARD_UNOP(Floor)
428   _FORWARD_UNOP(Ceil)
429   _FORWARD_UNOP(Round)
430   _FORWARD_UNOP(Log)
431   _FORWARD_UNOP(Log1p)
432   _FORWARD_UNOP(Sign)
433   _FORWARD_UNOP(Cos)
434   _FORWARD_UNOP(Sin)
435   _FORWARD_UNOP(Tanh)
436   _FORWARD_UNOP(IsFinite)
437   _FORWARD_UNOP(Neg)
438   _FORWARD_UNOP(Sqrt)
439   _FORWARD_UNOP(Rsqrt)
440   _FORWARD_UNOP(Square)
441   _FORWARD_UNOP(Reciprocal)
442   _FORWARD_UNOP(Erfc)
443   _FORWARD_UNOP(Erf)
444   _FORWARD_UNOP(ErfInv)
445   _FORWARD_UNOP(Lgamma)
446   _FORWARD_UNOP(Digamma)
447   _FORWARD_UNOP(Acos)
448   _FORWARD_UNOP(Asin)
449   _FORWARD_UNOP(Atan)
450   _FORWARD_UNOP(Tan)
451   _FORWARD_UNOP(Acosh)
452   _FORWARD_UNOP(Asinh)
453   _FORWARD_UNOP(Atanh)
454   _FORWARD_UNOP(Cosh)
455   _FORWARD_UNOP(Sinh)
456   _FORWARD_UNOP(Real)
457   _FORWARD_UNOP(Imag)
458   _FORWARD_UNOP(Conj)
459 
460 #undef _FORWARD
461 #undef _FORWARD_UNOP
462 #undef _FORWARD_BINOP
463 #undef _FORWARD_TRIOP
464 
465  private:
466   XlaBuilder builder_;
467 };
468 
469 // Functions for freeing resources from the Python side.
470 void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer);
471 void DeleteLocalExecutable(LocalExecutable* computation);
472 void DeleteComputation(Computation* computation);
473 
474 }  // namespace swig
475 }  // namespace xla
476 
477 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
478