1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include <Python.h> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/client/executable_build_options.h" 26 #include "tensorflow/compiler/xla/client/local_client.h" 27 #include "tensorflow/compiler/xla/client/xla_builder.h" 28 #include "tensorflow/compiler/xla/client/xla_computation.h" 29 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 32 namespace xla { 33 namespace swig { 34 35 // Registers a 'fn_capsule' as a CPU custom call target. 36 // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name 37 // "xla._CPU_CUSTOM_CALL_TARGET". 38 Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); 39 40 // Wrapper around an xla::LocalClient. 41 class LocalClient { 42 public: 43 // Initializes a local XLA client for `platform_name`. Returns an error if no 44 /// such platform exists, or if the platform has no visible devices. 45 static StatusOr<LocalClient> Get(const string& platform_name); 46 47 // Copyable and moveable; the class is just a wrapper around a 48 // xla::LocalClient pointer for convenient SWIG wrapping. 49 50 // Returns the number of devices known to the XLA client. 51 int DeviceCount() const; 52 53 // Wraps the local client's infeed-transfer function. 54 // 55 // The default device ordinal (0) is used. 56 Status TransferToInfeed(const Literal& literal, int device_ordinal); 57 58 // Transfers a literal of the given shape from the outfeed of the given 59 // replica. 60 StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal); 61 client()62 xla::LocalClient* client() const { return client_; } 63 64 private: 65 LocalClient(xla::LocalClient* client); 66 67 xla::LocalClient* client_; 68 }; 69 70 class LocalShapedBufferTuple; 71 72 // Represents a reference to literals that live in a device-allocated buffer via 73 // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a 74 // literal to device via the local client. 75 class LocalShapedBuffer { 76 public: 77 static StatusOr<LocalShapedBuffer*> FromLiteral( 78 const Literal& argument, const absl::optional<Shape>& shape_with_layout, 79 const LocalClient& client, int device_ordinal); 80 81 LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); 82 StatusOr<Literal> ToLiteral() const; 83 const Shape& shape() const; 84 const ScopedShapedBuffer* shaped_buffer() const; 85 86 // Transfers ownership of the encapsulated ShapedBuffer to the caller, 87 // analogous to std::unique_ptr::release(). 88 ShapedBuffer Release(); 89 90 // Destructures a tuple-valued LocalShapedBuffer into its constituent 91 // elements in LocalShapedBufferTuple form. 92 StatusOr<LocalShapedBufferTuple*> DestructureTuple(); 93 94 private: 95 ScopedShapedBuffer shaped_buffer_; 96 xla::LocalClient* client_; 97 }; 98 99 // Result of a tuple destructuring operation on a LocalShapedBuffer -- this 100 // appears to be a simpler mechanism for the time being than an alternative like 101 // using SWIG to transform std::vectors into Python lists of SWIG objects 102 // directly. 103 class LocalShapedBufferTuple { 104 public: 105 // Note: any LocalShapedBuffer elements that are not Release()'d will be 106 // deallocated in the destructor. 107 explicit LocalShapedBufferTuple(std::vector<LocalShapedBuffer*> elements); 108 109 ~LocalShapedBufferTuple(); 110 111 // Releases the ith element to the caller. Further attempts to release the ith 112 // element will return an invalid argument error. 113 StatusOr<LocalShapedBuffer*> Release(int i); 114 115 // Returns the number of elements in the destructured tuple. 116 int64 size() const; 117 118 private: 119 std::vector<LocalShapedBuffer*> elements_; 120 }; 121 122 // Represents a compiled computation that can be executed given handles to 123 // device-allocated literals. Specifically, wraps an XLA LocalExecutable. 124 class LocalExecutable { 125 public: 126 LocalExecutable(std::unique_ptr<xla::LocalExecutable> executable, 127 xla::DeviceAssignment device_assignment, 128 xla::LocalClient* client); 129 num_replicas()130 int num_replicas() const { 131 return executable_->build_options().num_replicas(); 132 } 133 134 // Returns the device ordinals to which each replica is assigned. 135 std::vector<int> DeviceOrdinals() const; 136 137 StatusOr<LocalShapedBuffer*> Execute( 138 absl::Span<LocalShapedBuffer* const> argument_handles); 139 140 // Execute on many replicas. Takes a sequence of argument lists (one argument 141 // list per replica) and returns a tuple of results (one result per replica). 142 // The number of argument lists must be equal to the replica count. 143 StatusOr<LocalShapedBufferTuple*> ExecutePerReplica( 144 absl::Span<const std::vector<LocalShapedBuffer*> > argument_handles); 145 146 private: 147 const std::unique_ptr<xla::LocalExecutable> executable_; 148 const xla::DeviceAssignment device_assignment_; 149 xla::LocalClient* const client_; 150 }; 151 152 // Wraps a XlaComputation produced by a ComputationBuilder. The 153 // Compile method compiles the computation to a (local) executable via 154 // the client library's local client. This class is intended to be 155 // made available to Python via SWIG. 156 class Computation { 157 public: 158 Computation(XlaComputation computation); 159 160 StatusOr<LocalExecutable*> Compile( 161 const std::vector<Shape>& argument_shapes, 162 const ExecutableBuildOptions* build_options, const LocalClient& client); 163 164 const XlaComputation& computation() const; 165 166 // Returns the HloModuleProto contained in the XlaComputation in the 167 // serialized binary format. Logs an internal error and returns an empty 168 // string on failure. 169 string GetSerializedProto() const; 170 171 // Returns the computation in human-readable HLO text format. 172 StatusOr<string> GetHloText() const; 173 174 // Returns the computation in graphviz dot format. 175 StatusOr<string> GetHloDotGraph() const; 176 177 // Returns the program shape for this computation. 178 StatusOr<ProgramShape> GetProgramShape() const; 179 180 // Returns the return-value shape for this computation. 181 StatusOr<Shape> GetReturnValueShape() const; 182 183 private: 184 XlaComputation computation_; 185 }; 186 187 // Wraps a XlaOp produced by a ComputationBuilder. This class is intended 188 // to be made available to Python via SWIG. 189 class LocalOp { 190 public: 191 LocalOp(const XlaOp& op); 192 193 const XlaOp& op() const; 194 195 private: 196 XlaOp op_; 197 }; 198 199 // Wraps the ComputationBuilder API in order to: 200 // - Support consumption by SWIG in order to be made available to 201 // Python. 202 // - Set up the underlying builder to use the client library's 203 // LocalClient. 204 // - Wrap Computations in Computations for Python access. 205 // - Correspondingly unwrap incoming Computations. 206 class ComputationBuilder { 207 public: 208 ComputationBuilder(const string& computation_name); 209 210 void SetOpMetadata(const OpMetadata& metadata); 211 void ClearOpMetadata(); 212 213 // Returns an owned Computation to the caller on success. 214 StatusOr<Computation*> Build(); 215 216 // Returns an owned Computation to the caller on success with given root. 217 StatusOr<Computation*> BuildWithRoot(const LocalOp& root); 218 219 LocalOp Parameter(int64 parameter_number, const Shape& shape, 220 const string& name); 221 222 StatusOr<Shape> GetShape(const LocalOp& operand); 223 224 // Returns the shape of the current return value for the computation. 225 StatusOr<Shape> GetReturnValueShape(); 226 227 LocalOp ReplicaId(); 228 229 LocalOp Infeed(const Shape& shape); 230 231 void Outfeed(const LocalOp& operand, const Shape& shape, 232 const string& outfeed_config); 233 234 LocalOp ConstantLiteral(const Literal& literal); 235 236 LocalOp Iota(PrimitiveType element_type, int64 size); 237 238 LocalOp BroadcastedIota(const Shape& shape, int64 dimension); 239 240 LocalOp Broadcast(const LocalOp& operand, 241 absl::Span<const int64> broadcast_sizes); 242 243 LocalOp BroadcastInDim(const LocalOp& operand, 244 absl::Span<const int64> out_dim_sizes, 245 absl::Span<const int64> broadcast_dimensions); 246 247 LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, 248 const PaddingConfig& padding_config); 249 250 LocalOp Reshape(const LocalOp& operand, absl::Span<const int64> dimensions, 251 absl::Span<const int64> new_sizes); 252 253 LocalOp Collapse(const LocalOp& operand, absl::Span<const int64> dimensions); 254 255 LocalOp AllToAll(const LocalOp& operand, int64 split_dimension, 256 int64 concat_dimension, int64 split_count, 257 absl::Span<const ReplicaGroup> replica_groups); 258 259 LocalOp CrossReplicaSum(const LocalOp& operand, 260 absl::Span<const ReplicaGroup> replica_groups); 261 262 LocalOp Slice(const LocalOp& operand, absl::Span<const int64> start_indices, 263 absl::Span<const int64> limit_indices, 264 absl::Span<const int64> strides); 265 266 LocalOp SliceInDim(const LocalOp& operand, int64 start_index, 267 int64 limit_index, int64 stride, int64 dimno); 268 269 LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, 270 absl::Span<const int64> slice_sizes); 271 272 LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, 273 const LocalOp& start_indices); 274 275 LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension); 276 277 LocalOp SelectAndScatterWithGeneralPadding( 278 const LocalOp& operand, const Computation& select, 279 absl::Span<const int64> window_dimensions, 280 absl::Span<const int64> window_strides, 281 absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source, 282 const LocalOp& init_value, const Computation& scatter); 283 284 LocalOp Tuple(absl::Span<const LocalOp> elements); 285 286 LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); 287 288 LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); 289 290 LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, 291 const DotDimensionNumbers& dimension_numbers); 292 293 LocalOp ConvGeneralDilated( 294 const LocalOp& lhs, const LocalOp& rhs, 295 absl::Span<const int64> window_strides, 296 absl::Span<const std::pair<int64, int64> > padding, 297 absl::Span<const int64> lhs_dilation, 298 absl::Span<const int64> rhs_dilation, 299 const ConvolutionDimensionNumbers& dimension_numbers, 300 int64 feature_group_count); 301 302 LocalOp ConvertElementType(const LocalOp& operand, 303 PrimitiveType new_element_type); 304 305 LocalOp BitcastConvertType(const LocalOp& operand, 306 PrimitiveType new_element_type); 307 308 LocalOp Call(const Computation& local_computation, 309 absl::Span<const LocalOp> operands); 310 311 LocalOp CustomCall(const string& call_target_name, 312 absl::Span<const LocalOp> operands, 313 const Shape& shape_with_layout, 314 const std::vector<Shape>& operand_shapes_with_layout, 315 const string& opaque); 316 317 LocalOp Transpose(const LocalOp& operand, 318 absl::Span<const int64> permutation); 319 320 LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions); 321 322 LocalOp Map(absl::Span<const LocalOp> operands, 323 const Computation& local_computation, 324 absl::Span<const int64> dimensions); 325 326 LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, 327 const Computation& local_computation, 328 absl::Span<const int64> dimensions_to_reduce); 329 330 LocalOp ReduceWindowWithGeneralPadding( 331 const LocalOp& operand, const LocalOp& init_value, 332 const Computation& local_computation, 333 absl::Span<const int64> window_dimensions, 334 absl::Span<const int64> window_strides, 335 absl::Span<const int64> base_dilations, 336 absl::Span<const int64> window_dilations, 337 absl::Span<const std::pair<int64, int64> > padding); 338 339 LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, 340 const Shape& shape); 341 342 LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); 343 344 LocalOp While(const Computation& condition, const Computation& body, 345 const LocalOp& init); 346 347 LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, 348 const Computation& true_computation, 349 const LocalOp& false_operand, 350 const Computation& false_computation); 351 352 StatusOr<bool> IsConstant(const LocalOp& operand); 353 354 LocalOp Sort(const LocalOp& operand, int64 dimension); 355 356 LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, 357 int64 dimension); 358 359 LocalOp QR(const LocalOp& a, bool full_matrices); 360 361 LocalOp Cholesky(const LocalOp& a, bool lower); 362 363 LocalOp Eigh(const LocalOp& a, bool lower); 364 365 LocalOp SVD(const LocalOp& a); 366 367 // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose 368 // enum. We use an integer here so we don't have to teach SWIG about the 369 // enum. 370 LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, 371 bool lower, bool unit_diagonal, int transpose_a); 372 373 LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, 374 const GatherDimensionNumbers& dimension_numbers, 375 absl::Span<const int64> slice_sizes); 376 377 LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, 378 const LocalOp& updates, const Computation& update_computation, 379 const ScatterDimensionNumbers& dimension_numbers); 380 381 StatusOr<Computation*> BuildConstantSubGraph(const LocalOp& operand); 382 383 #define _FORWARD(method_name, return_sig, args_sig) \ 384 return_sig method_name args_sig; 385 386 #define _FORWARD_UNOP(method_name) \ 387 _FORWARD(method_name, LocalOp, (const LocalOp& operand)) 388 389 #define _FORWARD_BINOP(method_name) \ 390 _FORWARD(method_name, LocalOp, \ 391 (const LocalOp& lhs, const LocalOp& rhs, \ 392 absl::Span<const int64> broadcast_dimensions)) 393 394 #define _FORWARD_TRIOP(method_name) \ 395 _FORWARD(method_name, LocalOp, \ 396 (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) 397 398 _FORWARD_TRIOP(Select) 399 _FORWARD_TRIOP(Clamp) 400 _FORWARD_BINOP(Eq) 401 _FORWARD_BINOP(Ne) 402 _FORWARD_BINOP(Ge) 403 _FORWARD_BINOP(Gt) 404 _FORWARD_BINOP(Lt) 405 _FORWARD_BINOP(Le) 406 _FORWARD_BINOP(Add) 407 _FORWARD_BINOP(Sub) 408 _FORWARD_BINOP(Mul) 409 _FORWARD_BINOP(Div) 410 _FORWARD_BINOP(Rem) 411 _FORWARD_BINOP(Max) 412 _FORWARD_BINOP(Min) 413 _FORWARD_BINOP(And) 414 _FORWARD_BINOP(Or) 415 _FORWARD_BINOP(Xor) 416 _FORWARD_BINOP(ShiftLeft) 417 _FORWARD_BINOP(ShiftRightArithmetic) 418 _FORWARD_BINOP(ShiftRightLogical) 419 _FORWARD_BINOP(Atan2) 420 _FORWARD_BINOP(Pow) 421 _FORWARD_BINOP(Complex) 422 _FORWARD_UNOP(Not) 423 _FORWARD_UNOP(Clz) 424 _FORWARD_UNOP(Abs) 425 _FORWARD_UNOP(Exp) 426 _FORWARD_UNOP(Expm1) 427 _FORWARD_UNOP(Floor) 428 _FORWARD_UNOP(Ceil) 429 _FORWARD_UNOP(Round) 430 _FORWARD_UNOP(Log) 431 _FORWARD_UNOP(Log1p) 432 _FORWARD_UNOP(Sign) 433 _FORWARD_UNOP(Cos) 434 _FORWARD_UNOP(Sin) 435 _FORWARD_UNOP(Tanh) 436 _FORWARD_UNOP(IsFinite) 437 _FORWARD_UNOP(Neg) 438 _FORWARD_UNOP(Sqrt) 439 _FORWARD_UNOP(Rsqrt) 440 _FORWARD_UNOP(Square) 441 _FORWARD_UNOP(Reciprocal) 442 _FORWARD_UNOP(Erfc) 443 _FORWARD_UNOP(Erf) 444 _FORWARD_UNOP(ErfInv) 445 _FORWARD_UNOP(Lgamma) 446 _FORWARD_UNOP(Digamma) 447 _FORWARD_UNOP(Acos) 448 _FORWARD_UNOP(Asin) 449 _FORWARD_UNOP(Atan) 450 _FORWARD_UNOP(Tan) 451 _FORWARD_UNOP(Acosh) 452 _FORWARD_UNOP(Asinh) 453 _FORWARD_UNOP(Atanh) 454 _FORWARD_UNOP(Cosh) 455 _FORWARD_UNOP(Sinh) 456 _FORWARD_UNOP(Real) 457 _FORWARD_UNOP(Imag) 458 _FORWARD_UNOP(Conj) 459 460 #undef _FORWARD 461 #undef _FORWARD_UNOP 462 #undef _FORWARD_BINOP 463 #undef _FORWARD_TRIOP 464 465 private: 466 XlaBuilder builder_; 467 }; 468 469 // Functions for freeing resources from the Python side. 470 void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); 471 void DeleteLocalExecutable(LocalExecutable* computation); 472 void DeleteComputation(Computation* computation); 473 474 } // namespace swig 475 } // namespace xla 476 477 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ 478