/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SWIG typemaps for building, compiling, and executing XLA computations.
//
// The typemaps below implement/assert the following correspondences
// (with elaborations below):
//
//    C++                                    Python
// -------------------------------------+---------------------------------------
//  Span<int64>                          <-  sequence of int
//  vector<int>                          ->  sequence of int
//  Span<LocalOp>                        <-  sequence of LocalOp
//  Literal                             <->  (nested tuple of) numpy ndarray
//  std::vector<Literal>                 <-  sequence of (nested tuple of) ndarray
//  Shape                                ->  pair holding (dtype, dimensions)
//                                       <-  object duck-typed as xla_client.Shape
//  absl::optional<Shape>                <-  xla_client.Shape-like object or None
//  ProgramShape                         ->  pair of ([arg_shapes], ret_shape)
//  std::vector<Shape>                   <-  sequence of xla_client.Shape objects
//  std::vector<absl::optional<Shape>>   <-  sequence of xla_client.Shape or None
//  PrimitiveType                        <-  int
//  Span<pair<int64, int64>>             <-  sequence of int pairs
//  PaddingConfig proto                  <-  duck-typed Python proto
//  ConvolutionDimensionNumbers proto    <-  duck-typed Python proto
//  DotDimensionNumbers proto            <-  duck-typed Python proto
//  GatherDimensionNumbers proto         <-  duck-typed Python proto
//  ScatterDimensionNumbers proto        <-  duck-typed Python proto
//  Span<ReplicaGroup proto>             <-  sequence of ReplicaGroup Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally. Typemaps for
// Status and StatusOr<T> results raise a Python RuntimeError carrying the
// status message when the status is not OK.
//
// The Python objects corresponding to C++ Literals have the type:
//
//   T = ndarray | (T, ...)
//
// where a terminal numpy ndarray translates to a Literal with a
// non-tuple Shape and an XLA primitive element type corresponding to
// the ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
// to a tuple-shaped Literal whose tuple components are translated
// recursively. For example, if x is a numpy ndarray in Python, with
// shape (2, 3) and dtype of dtype('float32'), then x translates to a
// Literal with rank 2, dimensions 2 and 3, and XLA primitive type
// F32. Meanwhile,
//
//   (x, (x, x), (x,)),
//
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// Shapes output by C++ become Python objects with the type:
//
//   T            = (dtype, S)
//   S            = DIMENSIONS | TUPLE_SHAPES
//   DIMENSIONS   = (int, ...)
//   TUPLE_SHAPES = (T, ...)
//
// In the pair described by the T rule, the terminal dtype determines
// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
// dtype('O'), numpy's object dtype, the structure represents a tuple
// shape and the expansion of the non-terminal S is
// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
// and S expands into DIMENSIONS giving dimension sizes. For example:
//
//   (dtype('float32'), (3, 5, 7))
//
// describes a 3x5x7 array of F32s, and
//
//   (dtype('O'), ((dtype('float32'), (2, 3)),
//                 (dtype('float64'), (4, 5))))
//
// describes a tuple shape with two subshapes: the first a 2x3 F32,
// and the other a 4x5 F64.
//
// The Python int corresponding to a PrimitiveType enum must be valid
// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
//
// The SWIG object wrappers generated by this file are not intended
// for end use, but rather for internal use in the Python XLA client,
// xla_client.py.
//
// One central reason for the Python-side indirection is that the
// Python-side objects produced by the typemaps in this file are
// further packaged up by xla_client before being passed on. For
// instance, the Python pair produced for a C++ Shape is further
// wrapped in a Python class (xla_client.Shape) so as not to expose
// the raw pair externally.
//
// Other SWIG object wrappers (e.g. of Computation) are further
// wrapped by xla_client in order to set up a custom destructor that
// triggers memory deallocation on the C++ side.
//


%module(threads="1") xla_data

// Keep the GIL except where explicitly specified.
%nothread;

%include "tensorflow/python/platform/base.i"

%{
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"

#include <limits>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"

using namespace xla;
using namespace xla::swig;

// Converts the Python number `o` to an int64 and stores it in `*result`.
//
// Returns true on success. On failure returns false with a Python exception
// set: a TypeError carrying `type_error_msg` when numpy::PyNumberToPyInt
// cannot convert `o`, or whatever exception numpy::PyIntOrPyLongToLong raised.
// Does not steal a reference to `o`; `*result` is left untouched on failure.
static bool XlaDataPyNumberToInt64(PyObject* o, int64* result,
                                   const char* type_error_msg) {
  // PyNumberToPyInt hands back a new reference; it is released on every path.
  PyObject* py_int = numpy::PyNumberToPyInt(o);
  if (!py_int) {
    PyErr_SetString(PyExc_TypeError, type_error_msg);
    return false;
  }
  const int64 value = numpy::PyIntOrPyLongToLong(py_int);
  Py_DECREF(py_int);
  // -1 is also a legitimate value, so only PyErr_Occurred() signals failure.
  if (value == -1 && PyErr_Occurred()) {
    return false;
  }
  *result = value;
  return true;
}

// Fetches item `index` of the Python sequence `seq` and converts it with
// XlaDataPyNumberToInt64. Returns false with a Python exception set if the
// item cannot be fetched (e.g. IndexError) or cannot be converted.
static bool XlaDataPySequenceItemToInt64(PyObject* seq, Py_ssize_t index,
                                         int64* result,
                                         const char* type_error_msg) {
  PyObject* item = PySequence_GetItem(seq, index);  // New reference.
  if (!item) {
    return false;
  }
  const bool ok = XlaDataPyNumberToInt64(item, result, type_error_msg);
  Py_DECREF(item);
  return ok;
}
%}

// Basic types

// Returns a std::vector<int> to Python as a list of ints.
%typemap(out) std::vector<int> {
  PyObject* py_list = PyList_New($1.size());
  if (!py_list) {
    SWIG_fail;
  }
  for (size_t i = 0; i < $1.size(); ++i) {
    PyObject* py_item = PyInt_FromLong($1[i]);
    if (!py_item) {
      // Unset slots of a fresh list are NULL, which list deallocation
      // tolerates, so the partially filled list can simply be released.
      Py_DECREF(py_list);
      SWIG_fail;
    }
    // PyList_SET_ITEM steals the reference to py_item.
    PyList_SET_ITEM(py_list, i, py_item);
  }
  $result = py_list;
}

// Returns the value as a Python bool, or raises RuntimeError with the status.
%typemap(out) StatusOr<bool> {
  if ($1.ok()) {
    $result = PyBool_FromLong($1.ConsumeValueOrDie());
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

// Returns the value as a Python string, or raises RuntimeError with the
// status.
%typemap(out) StatusOr<string> {
  if ($1.ok()) {
    $result = PyString_FromString($1.ConsumeValueOrDie().c_str());
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

// Returns None on success, or raises RuntimeError with the status.
%typemap(out) Status {
  if (!$1.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError, $1.ToString().c_str());
    SWIG_fail;
  }
  Py_INCREF(Py_None);
  $result = Py_None;
}

// Converts a Python sequence of int-convertible numbers into a temporary
// std::vector<int64> and passes a Span over it.
%typemap(in) absl::Span<const int64>
    (std::vector<int64> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size already set the Python exception.
    SWIG_fail;
  }
  temps.resize(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    if (!XlaDataPySequenceItemToInt64(
            $input, i, &temps[i],
            "Argument sequence element cannot be converted to int")) {
      SWIG_fail;
    }
  }
  $1 = temps;
}

// Literal

// Converts a (nested tuple of) ndarray with numpy::XlaLiteralFromPyObject.
// The converted Literal lives in literal_status for the duration of the call.
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
  literal_status = numpy::XlaLiteralFromPyObject($input);
  if (!literal_status.ok()) {
    PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
    SWIG_fail;
  }
  $1 = &literal_status.ValueOrDie();
}

// Returns a Literal to Python via numpy::PyObjectFromXlaLiteral.
%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
  obj_status = numpy::PyObjectFromXlaLiteral(*$1);
  if (!obj_status.ok()) {
    PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
    SWIG_fail;
  }
  $result = obj_status.ValueOrDie().release();
}

// Returns a StatusOr<Literal> to Python via numpy::PyObjectFromXlaLiteral,
// raising RuntimeError if either the status or the conversion fails.
%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
  if (!obj_status.ok()) {
    PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
    SWIG_fail;
  }
  $result = obj_status.ValueOrDie().release();
}

// Converts a Python sequence of (nested tuples of) ndarrays into a
// temporary std::vector<Literal>.
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
    if (!literal_status.ok()) {
      PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
      Py_DECREF(o);
      SWIG_fail;
    }
    temps.push_back(literal_status.ConsumeValueOrDie());
    Py_DECREF(o);
  }
  $1 = &temps;
}

// OpMetadata

// Converts the argument with numpy::OpMetadataFromPyObject.
%typemap(in) const OpMetadata& (OpMetadata temp) {
  StatusOr<OpMetadata> statusor = numpy::OpMetadataFromPyObject($input);
  if (!statusor.ok()) {
    PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    SWIG_fail;
  }
  temp = std::move(statusor).ValueOrDie();
  $1 = &temp;
}

// Shape

// Returns a Shape to Python as the (dtype, S) structure described above.
%typemap(out) const Shape& {
  $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
}

%typemap(out) StatusOr<Shape> {
  if ($1.ok()) {
    $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release();
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}


// Returns a ProgramShape to Python as ([arg_shapes], ret_shape).
%typemap(out) StatusOr<ProgramShape> {
  if ($1.ok()) {
    $result = numpy::PyProgramShapeInfoFromXlaProgramShape(
        $1.ConsumeValueOrDie()).release();
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}


// Converts a duck-typed xla_client.Shape with numpy::XlaShapeFromPyShape.
%typemap(in) const Shape& (Shape temp) {
  StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
  if (!statusor.ok()) {
    PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    SWIG_fail;
  }
  temp = std::move(statusor).ValueOrDie();
  $1 = &temp;
}

// Like const Shape&, but Python None maps to absl::nullopt.
%typemap(in) const absl::optional<Shape>& (
    absl::optional<Shape> temp) {
  if ($input == Py_None) {
    temp = absl::nullopt;
    $1 = &temp;
  } else {
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temp = std::move(statusor).ValueOrDie();
    $1 = &temp;
  }
}

%typemap(out) std::unique_ptr<Shape> {
  $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
}

// Converts a Python sequence of duck-typed xla_client.Shape objects.
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}

// Converts a Python sequence whose items are duck-typed xla_client.Shape
// objects or None (mapped to absl::nullopt).
%typemap(in) const std::vector<absl::optional<Shape> >& (
    std::vector<absl::optional<Shape> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    if (o == Py_None) {
      // PySequence_GetItem returned a new reference to None too.
      Py_DECREF(o);
      temps.push_back(absl::nullopt);
      continue;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}

// PrimitiveType

// Converts a Python int to a PrimitiveType, raising TypeError unless the
// value is a valid PrimitiveType enum value.
%typemap(in) PrimitiveType {
  int64 value = 0;
  if (!XlaDataPyNumberToInt64($input, &value,
                              "Argument cannot be converted to int")) {
    SWIG_fail;
  }
  // PrimitiveType_IsValid takes an int, so reject values outside the int
  // range first instead of letting them wrap onto a valid enum value.
  if (value < std::numeric_limits<int>::min() ||
      value > std::numeric_limits<int>::max() ||
      !PrimitiveType_IsValid(static_cast<int>(value))) {
    PyErr_SetString(
        PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    SWIG_fail;
  }
  $1 = static_cast<PrimitiveType>(value);
}

// Span<pair<int64, int64>>

// Converts a Python sequence of pairs (tuples or any other sequences; only
// the first two items of each are read) into a temporary vector of
// (int64, int64) pairs and passes a Span over it.
%typemap(in) absl::Span<const std::pair<int64, int64> >
    (std::vector<std::pair<int64, int64> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    int64 first_value = 0;
    int64 second_value = 0;
    const bool ok =
        XlaDataPySequenceItemToInt64(
            o, 0, &first_value,
            "First pair item cannot be converted to int") &&
        XlaDataPySequenceItemToInt64(
            o, 1, &second_value,
            "Second pair item cannot be converted to int");
    Py_DECREF(o);
    if (!ok) {
      SWIG_fail;
    }
    temps.emplace_back(first_value, second_value);
  }
  $1 = temps;
}

// DotDimensionNumbers

// Builds a DotDimensionNumbers proto from the repeated int64 attributes of a
// duck-typed Python proto.
%typemap(in) const DotDimensionNumbers&
    (DotDimensionNumbers dimension_numbers) {
  if (!HandleRepeatedInt64Attribute(
        $input, "lhs_contracting_dimensions",
        dimension_numbers.mutable_lhs_contracting_dimensions())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "rhs_contracting_dimensions",
        dimension_numbers.mutable_rhs_contracting_dimensions())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "lhs_batch_dimensions",
        dimension_numbers.mutable_lhs_batch_dimensions())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "rhs_batch_dimensions",
        dimension_numbers.mutable_rhs_batch_dimensions())) {
    SWIG_fail;
  }

  $1 = &dimension_numbers;
}

// PaddingConfig

// Builds a PaddingConfig proto from a duck-typed Python proto whose
// "dimensions" attribute is a sequence of objects carrying
// edge_padding_low, edge_padding_high and interior_padding attributes.
%typemap(in) const PaddingConfig&
    (PaddingConfig padding_config) {
  PyObject* dimensions = PyObject_GetAttrString($input, "dimensions");
  if (!dimensions) {
    SWIG_fail;
  }

  const Py_ssize_t length = PySequence_Size(dimensions);
  if (length == -1) {
    Py_DECREF(dimensions);
    SWIG_fail;
  }

  for (Py_ssize_t i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(dimensions, i);
    if (!item) {
      Py_DECREF(dimensions);
      SWIG_fail;
    }
    int64 edge_padding_low, edge_padding_high, interior_padding;
    if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low)
        || !GetIntAttr(item, "edge_padding_high", &edge_padding_high)
        || !GetIntAttr(item, "interior_padding", &interior_padding)) {
      Py_DECREF(item);
      Py_DECREF(dimensions);
      SWIG_fail;
    }
    Py_DECREF(item);

    PaddingConfig::PaddingConfigDimension* dimension =
        padding_config.add_dimensions();
    dimension->set_edge_padding_low(edge_padding_low);
    dimension->set_edge_padding_high(edge_padding_high);
    dimension->set_interior_padding(interior_padding);
  }
  Py_DECREF(dimensions);

  $1 = &padding_config;
}

// ConvolutionDimensionNumbers

// Builds a ConvolutionDimensionNumbers proto from the int and repeated int64
// attributes of a duck-typed Python proto.
%typemap(in) const ConvolutionDimensionNumbers&
    (ConvolutionDimensionNumbers dimension_numbers) {
  int64 value;

  if (!GetIntAttr($input, "input_batch_dimension", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_input_batch_dimension(value);

  if (!GetIntAttr($input, "input_feature_dimension", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_input_feature_dimension(value);

  if (!GetIntAttr($input, "output_batch_dimension", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_output_batch_dimension(value);

  if (!GetIntAttr($input, "output_feature_dimension", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_output_feature_dimension(value);

  if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_kernel_output_feature_dimension(value);

  if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_kernel_input_feature_dimension(value);

  if (!HandleRepeatedInt64Attribute(
        $input, "input_spatial_dimensions",
        dimension_numbers.mutable_input_spatial_dimensions())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "kernel_spatial_dimensions",
        dimension_numbers.mutable_kernel_spatial_dimensions())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "output_spatial_dimensions",
        dimension_numbers.mutable_output_spatial_dimensions())) {
    SWIG_fail;
  }

  $1 = &dimension_numbers;
}

// GatherDimensionNumbers

// Builds a GatherDimensionNumbers proto from a duck-typed Python proto.
%typemap(in) const GatherDimensionNumbers&
    (GatherDimensionNumbers dimension_numbers) {
  if (!HandleRepeatedInt64Attribute(
        $input, "offset_dims",
        dimension_numbers.mutable_offset_dims())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "collapsed_slice_dims",
        dimension_numbers.mutable_collapsed_slice_dims())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "start_index_map",
        dimension_numbers.mutable_start_index_map())) {
    SWIG_fail;
  }

  int64 value;
  if (!GetIntAttr($input, "index_vector_dim", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_index_vector_dim(value);

  $1 = &dimension_numbers;
}

// ScatterDimensionNumbers

// Builds a ScatterDimensionNumbers proto from a duck-typed Python proto.
%typemap(in) const ScatterDimensionNumbers&
    (ScatterDimensionNumbers dimension_numbers) {
  if (!HandleRepeatedInt64Attribute(
        $input, "update_window_dims",
        dimension_numbers.mutable_update_window_dims())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "inserted_window_dims",
        dimension_numbers.mutable_inserted_window_dims())) {
    SWIG_fail;
  }
  if (!HandleRepeatedInt64Attribute(
        $input, "scatter_dims_to_operand_dims",
        dimension_numbers.mutable_scatter_dims_to_operand_dims())) {
    SWIG_fail;
  }

  int64 value;
  if (!GetIntAttr($input, "index_vector_dim", &value)) {
    SWIG_fail;
  }
  dimension_numbers.set_index_vector_dim(value);

  $1 = &dimension_numbers;
}

// Span<const ReplicaGroup>

// Converts a Python sequence of duck-typed ReplicaGroup protos (each read via
// its "replica_ids" attribute) into a temporary vector and passes a Span
// over it.
%typemap(in) absl::Span<const ReplicaGroup >
    (std::vector<ReplicaGroup > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    ReplicaGroup rgrp;
    const bool ok = HandleRepeatedInt64Attribute(
        o, "replica_ids",
        rgrp.mutable_replica_ids());
    // Release the item on both the success and the failure path.
    Py_DECREF(o);
    if (!ok) {
      SWIG_fail;
    }
    temps.push_back(std::move(rgrp));
  }
  $1 = temps;
}