1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // SWIG typemaps and declarations for building, compiling, and 17 // executing XLA computations, wrapping most of what is declared in 18 // local_computation_builder.h. 19 20 %module(threads="1") local_computation_builder 21 22 // Keep the GIL except where explicitly specified. 23 %nothread; 24 25 %include "tensorflow/python/platform/base.i" 26 %include "tensorflow/compiler/xla/python/xla_data.i" 27 28 %{ 29 // Must be included first 30 #include "tensorflow/python/lib/core/numpy.h" 31 32 #include "absl/strings/str_cat.h" 33 #include "absl/strings/str_format.h" 34 #include "tensorflow/compiler/xla/literal.h" 35 #include "tensorflow/compiler/xla/shape_util.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "absl/types/span.h" 38 #include "tensorflow/compiler/xla/python/numpy_bridge.h" 39 #include "tensorflow/compiler/xla/python/local_computation_builder.h" 40 41 using namespace xla; 42 using namespace xla::swig; 43 44 %} 45 46 // Required to use PyArray_* functions. 
// Import the NumPy C API when the extension module is initialized; the
// PyArray_* functions cannot be used before this has run.
%init %{
  tensorflow::ImportNumpy();
%}

// Computation builder types

// Input typemap: Python sequence of wrapped LocalOp objects ->
// absl::Span<const LocalOp>.
//
// Each element is copied into the typemap-local vector `temps`, which backs
// the span handed to the wrapped function. Fails with a Python exception if
// the argument is not a sequence or if any element is not a (non-None)
// LocalOp.
%typemap(in) absl::Span<const xla::swig::LocalOp>(
      std::vector<LocalOp> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size reports failure with -1 and a pending exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // PySequence_GetItem returns a new reference (or NULL with an exception
    // set); the reference has to be released on every path below.
    PyObject* o = PySequence_GetItem($input, i);
    if (o == nullptr) {
      SWIG_fail;
    }
    LocalOp* op = nullptr;
    // A successful conversion can still produce a null pointer (SWIG maps
    // None to NULL), and the element is dereferenced below, so reject that
    // case together with genuine conversion failures.
    if (!SWIG_IsOK(SWIG_ConvertPtr(o, (void**)&op,
                                   $descriptor(xla::swig::LocalOp*),
                                   SWIG_POINTER_EXCEPTION)) ||
        op == nullptr) {
      // Depending on the SWIG version the conversion may not have raised
      // anything itself; make sure the caller sees a real exception.
      if (!PyErr_Occurred()) {
        PyErr_SetString(PyExc_TypeError, "Sequence element is not a LocalOp");
      }
      Py_DECREF(o);
      SWIG_fail;
    }
    temps.push_back(*op);
    Py_DECREF(o);
  }
  $1 = temps;
}

// Computation and buffer/allocation types

// Output typemaps for StatusOr<T>.
//
// Each of the typemaps below follows the same pattern: when the StatusOr holds
// a value, the value is unwrapped and handed to the regular output typemap for
// T (the inner scope re-declares the result variable so that the nested
// typemap expansion operates on the unwrapped value). When it holds an error,
// the status string is raised as a Python RuntimeError.

%typemap(out) StatusOr<xla::swig::LocalClient> {
  if ($1.ok()) {
    xla::swig::LocalClient value = $1.ValueOrDie();
    {
      auto $1 = value;
      $typemap(out, xla::swig::LocalClient)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

%typemap(out) StatusOr<xla::swig::LocalExecutable*> {
  if ($1.ok()) {
    auto* value = $1.ValueOrDie();
    {
      auto* $1 = value;
      $typemap(out, xla::swig::LocalExecutable*)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

%typemap(out) StatusOr<xla::swig::LocalShapedBuffer*> {
  if ($1.ok()) {
    auto* value = $1.ValueOrDie();
    {
      auto* $1 = value;
      $typemap(out, xla::swig::LocalShapedBuffer*)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

%typemap(out) StatusOr<xla::swig::LocalShapedBufferTuple*> {
  if ($1.ok()) {
    auto* value = $1.ValueOrDie();
    {
      auto* $1 = value;
      $typemap(out, xla::swig::LocalShapedBufferTuple*)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

%typemap(out) StatusOr<xla::swig::Computation*> {
  if ($1.ok()) {
    auto* value = $1.ValueOrDie();
    {
      auto* $1 = value;
      $typemap(out, xla::swig::Computation*)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
}

// Input typemap: Python sequence of wrapped LocalShapedBuffer objects ->
// absl::Span<LocalShapedBuffer* const>.
//
// Only the raw pointers are collected (into the typemap-local `temps`); the
// pointed-to buffers are not copied. Fails with a Python exception if the
// argument is not a sequence or an element cannot be converted.
%typemap(in) absl::Span<xla::swig::LocalShapedBuffer* const>
    (std::vector<LocalShapedBuffer*> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size reports failure with -1 and a pending exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference (or NULL on error); released on every path below.
    PyObject* o = PySequence_GetItem($input, i);
    if (o == nullptr) {
      SWIG_fail;
    }
    LocalShapedBuffer* lsbp = nullptr;
    if (!SWIG_IsOK(SWIG_ConvertPtr(o, (void**)&lsbp,
                                   $descriptor(xla::swig::LocalShapedBuffer*),
                                   SWIG_POINTER_EXCEPTION))) {
      // Guarantee that a failure is reported with an exception set.
      if (!PyErr_Occurred()) {
        PyErr_SetString(PyExc_TypeError,
                        "Sequence element is not a LocalShapedBuffer");
      }
      Py_DECREF(o);
      SWIG_fail;
    }
    temps.push_back(lsbp);
    Py_DECREF(o);
  }
  $1 = temps;
}

// Input typemap: Python sequence of sequences of wrapped LocalShapedBuffer
// objects -> absl::Span<const std::vector<LocalShapedBuffer*> >.
//
// One inner vector of raw pointers is built per outer element. Fails with a
// Python exception if the argument is not a sequence, if an outer element does
// not support the sequence protocol, or if an inner element cannot be
// converted.
%typemap(in) absl::Span<const std::vector<xla::swig::LocalShapedBuffer*> >
    (std::vector<std::vector<LocalShapedBuffer*> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size reports failure with -1 and a pending exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference (or NULL on error); released on every path below.
    PyObject* o = PySequence_GetItem($input, i);
    if (o == nullptr) {
      SWIG_fail;
    }
    const Py_ssize_t vec_size = PySequence_Size(o);
    if (vec_size < 0) {
      // The outer element is not a usable sequence. PySequence_Size has
      // already set an exception, so stop here instead of silently treating
      // the element as an empty list with an error still pending.
      Py_DECREF(o);
      SWIG_fail;
    }
    std::vector<LocalShapedBuffer*> vec;
    vec.reserve(vec_size);
    for (Py_ssize_t j = 0; j < vec_size; ++j) {
      PyObject* vec_elt = PySequence_GetItem(o, j);
      if (vec_elt == nullptr) {
        Py_DECREF(o);
        SWIG_fail;
      }
      LocalShapedBuffer* lsbp = nullptr;
      if (!SWIG_IsOK(SWIG_ConvertPtr(vec_elt, (void**)&lsbp,
                                     $descriptor(xla::swig::LocalShapedBuffer*),
                                     SWIG_POINTER_EXCEPTION))) {
        // Guarantee that a failure is reported with an exception set.
        if (!PyErr_Occurred()) {
          PyErr_SetString(PyExc_TypeError,
                          "Sequence element is not a LocalShapedBuffer");
        }
        Py_DECREF(vec_elt);
        Py_DECREF(o);
        SWIG_fail;
      }
      vec.push_back(lsbp);
      Py_DECREF(vec_elt);
    }
    temps.push_back(vec);
    Py_DECREF(o);
  }
  $1 = temps;
}

// ExecutableBuildOptions

// Input typemap: Python options object (or None) -> const
// ExecutableBuildOptions*.
//
// None maps to a null pointer. Otherwise the attributes "dump_to",
// "dump_hlo_pass_re", "dump_hlo_module_re", "dump_hlo_as_text",
// "dump_hlo_as_proto", "hlo_profile", "result_shape" and "num_replicas" are
// read from the Python object and applied to the typemap-local
// `build_options`, whose address is then passed to the wrapped function.
// Every failure path leaves through SWIG_fail so that the cleanup code of the
// generated wrapper runs.
%typemap(in) const ExecutableBuildOptions*
    (ExecutableBuildOptions build_options) {
  if ($input == Py_None) {
    $1 = nullptr;
  } else {
    if (!HandleStringAttribute($input, "dump_to", [&](string s) {
      build_options.mutable_debug_options()->set_xla_dump_to(std::move(s));
    })) {
      SWIG_fail;
    }
    if (!HandleStringAttribute($input, "dump_hlo_pass_re", [&](string s) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_pass_re(
          std::move(s));
    })) {
      SWIG_fail;
    }
    if (!HandleStringAttribute($input, "dump_hlo_module_re", [&](string s) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_module_re(
          std::move(s));
    })) {
      SWIG_fail;
    }
    if (!HandleBoolAttribute($input, "dump_hlo_as_text", [&](bool b) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_as_text(b);
    })) {
      SWIG_fail;
    }
    if (!HandleBoolAttribute($input, "dump_hlo_as_proto", [&](bool b) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_as_proto(b);
    })) {
      SWIG_fail;
    }
    if (!HandleBoolAttribute($input, "hlo_profile", [&](bool b) {
      build_options.mutable_debug_options()->set_xla_hlo_profile(b);
    })) {
      SWIG_fail;
    }

    // PyObject_GetAttrString returns a new reference, or NULL with an
    // exception set when the attribute lookup fails.
    PyObject* o = PyObject_GetAttrString($input, "result_shape");
    if (o == nullptr) {
      SWIG_fail;
    }
    if (o != Py_None) {
      StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
      if (!statusor.ok()) {
        PyErr_SetString(
            PyExc_TypeError,
            absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ",
                         statusor.status().ToString()).c_str());
        Py_DECREF(o);
        SWIG_fail;
      }
      build_options.set_result_layout(statusor.ValueOrDie());
    }
    Py_DECREF(o);

    int64 num_replicas;
    if (!GetIntAttr($input, "num_replicas", &num_replicas)) {
      SWIG_fail;
    }
    build_options.set_num_replicas(num_replicas);

    $1 = &build_options;
  }
}

// Hide everything by default, then re-expose only the names listed in the
// %unignore directives that follow.
%ignoreall
%unignore xla;
%unignore xla::swig;
%unignore xla::swig::RegisterCpuCustomCallTarget;
// ---- LocalClient ----
%unignore xla::swig::LocalClient;
%unignore xla::swig::LocalClient::Get;
%unignore xla::swig::LocalClient::DeviceCount;
%unignore xla::swig::LocalClient::TransferToInfeed;
%unignore xla::swig::LocalClient::TransferFromOutfeed;

// ---- LocalShapedBuffer / LocalShapedBufferTuple ----
%unignore xla::swig::LocalShapedBuffer;
%unignore xla::swig::LocalShapedBuffer::FromLiteral;
%unignore xla::swig::LocalShapedBuffer::ToLiteral;
%unignore xla::swig::LocalShapedBuffer::shape;
%unignore xla::swig::LocalShapedBuffer::DestructureTuple;
%unignore xla::swig::LocalShapedBufferTuple;
%unignore xla::swig::LocalShapedBufferTuple::Release;
%unignore xla::swig::LocalShapedBufferTuple::size;

// ---- LocalExecutable ----
%unignore xla::swig::LocalExecutable;
%unignore xla::swig::LocalExecutable::DeviceOrdinals;
%unignore xla::swig::LocalExecutable::Execute;
%unignore xla::swig::LocalExecutable::ExecutePerReplica;

// ---- Computation ----
%unignore xla::swig::Computation;
%unignore xla::swig::Computation::Compile;
%unignore xla::swig::Computation::GetProgramShape;
%unignore xla::swig::Computation::GetReturnValueShape;
%unignore xla::swig::Computation::GetSerializedProto;
%unignore xla::swig::Computation::GetHloText;
%unignore xla::swig::Computation::GetHloDotGraph;

// ---- LocalOp ----
%unignore xla::swig::LocalOp;

// ---- ComputationBuilder ----
%unignore xla::swig::ComputationBuilder;
%unignore xla::swig::ComputationBuilder::ComputationBuilder;
%unignore xla::swig::ComputationBuilder::Build;
%unignore xla::swig::ComputationBuilder::BuildWithRoot;
%unignore xla::swig::ComputationBuilder::SetOpMetadata;
%unignore xla::swig::ComputationBuilder::ClearOpMetadata;
%unignore xla::swig::ComputationBuilder::Parameter;
%unignore xla::swig::ComputationBuilder::GetShape;
%unignore xla::swig::ComputationBuilder::GetReturnValueShape;
%unignore xla::swig::ComputationBuilder::ReplicaId;
%unignore xla::swig::ComputationBuilder::Infeed;
%unignore xla::swig::ComputationBuilder::Outfeed;
%unignore xla::swig::ComputationBuilder::ConstantLiteral;
%unignore xla::swig::ComputationBuilder::ConstantR0;
%unignore xla::swig::ComputationBuilder::Iota;
%unignore xla::swig::ComputationBuilder::BroadcastedIota;
%unignore xla::swig::ComputationBuilder::Broadcast;
%unignore xla::swig::ComputationBuilder::BroadcastInDim;
%unignore xla::swig::ComputationBuilder::Pad;
%unignore xla::swig::ComputationBuilder::Reshape;
%unignore xla::swig::ComputationBuilder::Collapse;
%unignore xla::swig::ComputationBuilder::AllToAll;
%unignore xla::swig::ComputationBuilder::CrossReplicaSum;
%unignore xla::swig::ComputationBuilder::Slice;
%unignore xla::swig::ComputationBuilder::SliceInDim;
%unignore xla::swig::ComputationBuilder::DynamicSlice;
%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice;
%unignore xla::swig::ComputationBuilder::ConcatInDim;
%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding;
%unignore xla::swig::ComputationBuilder::Select;
%unignore xla::swig::ComputationBuilder::Tuple;
%unignore xla::swig::ComputationBuilder::GetTupleElement;
%unignore xla::swig::ComputationBuilder::ConvertElementType;
%unignore xla::swig::ComputationBuilder::BitcastConvertType;
%unignore xla::swig::ComputationBuilder::Call;
%unignore xla::swig::ComputationBuilder::Transpose;
%unignore xla::swig::ComputationBuilder::Rev;
%unignore xla::swig::ComputationBuilder::Clamp;
%unignore xla::swig::ComputationBuilder::Map;
%unignore xla::swig::ComputationBuilder::Reduce;
%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding;
%unignore xla::swig::ComputationBuilder::RngNormal;
%unignore xla::swig::ComputationBuilder::RngUniform;
%unignore xla::swig::ComputationBuilder::RngBernoulli;
%unignore xla::swig::ComputationBuilder::While;
%unignore xla::swig::ComputationBuilder::Conditional;
%unignore xla::swig::ComputationBuilder::IsConstant;
%unignore xla::swig::ComputationBuilder::Eq;
%unignore xla::swig::ComputationBuilder::Ne;
%unignore xla::swig::ComputationBuilder::Ge;
%unignore xla::swig::ComputationBuilder::Gt;
%unignore xla::swig::ComputationBuilder::Lt;
%unignore xla::swig::ComputationBuilder::Le;
%unignore xla::swig::ComputationBuilder::Dot;
%unignore xla::swig::ComputationBuilder::DotGeneral;
%unignore xla::swig::ComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::ComputationBuilder::Add;
%unignore xla::swig::ComputationBuilder::Sub;
%unignore xla::swig::ComputationBuilder::Mul;
%unignore xla::swig::ComputationBuilder::Div;
%unignore xla::swig::ComputationBuilder::Rem;
%unignore xla::swig::ComputationBuilder::Max;
%unignore xla::swig::ComputationBuilder::Min;
%unignore xla::swig::ComputationBuilder::And;
%unignore xla::swig::ComputationBuilder::Or;
%unignore xla::swig::ComputationBuilder::Xor;
%unignore xla::swig::ComputationBuilder::ShiftLeft;
%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic;
%unignore xla::swig::ComputationBuilder::ShiftRightLogical;
%unignore xla::swig::ComputationBuilder::Not;
%unignore xla::swig::ComputationBuilder::Clz;
%unignore xla::swig::ComputationBuilder::Abs;
%unignore xla::swig::ComputationBuilder::Exp;
%unignore xla::swig::ComputationBuilder::Expm1;
%unignore xla::swig::ComputationBuilder::Floor;
%unignore xla::swig::ComputationBuilder::Ceil;
%unignore xla::swig::ComputationBuilder::Round;
%unignore xla::swig::ComputationBuilder::Log;
%unignore xla::swig::ComputationBuilder::Log1p;
%unignore xla::swig::ComputationBuilder::Sign;
%unignore xla::swig::ComputationBuilder::Cos;
%unignore xla::swig::ComputationBuilder::Sin;
%unignore xla::swig::ComputationBuilder::Tanh;
%unignore xla::swig::ComputationBuilder::Atan2;
%unignore xla::swig::ComputationBuilder::IsFinite;
%unignore xla::swig::ComputationBuilder::Pow;
%unignore xla::swig::ComputationBuilder::Neg;
%unignore xla::swig::ComputationBuilder::Sort;
%unignore xla::swig::ComputationBuilder::SortKeyVal;
%unignore xla::swig::ComputationBuilder::Sqrt;
%unignore xla::swig::ComputationBuilder::Rsqrt;
%unignore xla::swig::ComputationBuilder::Square;
%unignore xla::swig::ComputationBuilder::Reciprocal;
%unignore xla::swig::ComputationBuilder::Erfc;
%unignore xla::swig::ComputationBuilder::Erf;
%unignore xla::swig::ComputationBuilder::ErfInv;
%unignore xla::swig::ComputationBuilder::Lgamma;
%unignore xla::swig::ComputationBuilder::Digamma;
%unignore xla::swig::ComputationBuilder::Acos;
%unignore xla::swig::ComputationBuilder::Asin;
%unignore xla::swig::ComputationBuilder::Atan;
%unignore xla::swig::ComputationBuilder::Tan;
%unignore xla::swig::ComputationBuilder::Acosh;
%unignore xla::swig::ComputationBuilder::Asinh;
%unignore xla::swig::ComputationBuilder::Atanh;
%unignore xla::swig::ComputationBuilder::Cosh;
%unignore xla::swig::ComputationBuilder::Sinh;
%unignore xla::swig::ComputationBuilder::Real;
%unignore xla::swig::ComputationBuilder::Imag;
%unignore xla::swig::ComputationBuilder::Conj;
%unignore xla::swig::ComputationBuilder::Complex;
%unignore xla::swig::ComputationBuilder::Cholesky;
%unignore xla::swig::ComputationBuilder::QR;
%unignore xla::swig::ComputationBuilder::Eigh;
%unignore xla::swig::ComputationBuilder::SVD;
%unignore xla::swig::ComputationBuilder::TriangularSolve;
%unignore xla::swig::ComputationBuilder::CustomCall;
%unignore xla::swig::ComputationBuilder::Gather;
%unignore xla::swig::ComputationBuilder::Scatter;

// ---- Free functions that delete wrapped objects ----
%unignore xla::swig::DeleteComputation;
%unignore xla::swig::DeleteLocalShapedBuffer;
%unignore xla::swig::DeleteLocalExecutable;

// Generate wrappers for the declarations in the header. Threading support is
// switched on (%thread) only while the header is processed and switched back
// off (%nothread) immediately afterwards.
%thread;
%include "tensorflow/compiler/xla/python/local_computation_builder.h"
%nothread;

// End of the ignore/unignore section opened by %ignoreall.
%unignoreall