1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/numpy_bridge.h"
17 #include "absl/strings/str_cat.h"
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace xla {
24
25 namespace swig {
26
27 namespace numpy {
28
make_safe(PyObject * object)29 Safe_PyObjectPtr make_safe(PyObject* object) {
30 return Safe_PyObjectPtr(object);
31 }
32
PrimitiveTypeToNumpyType(PrimitiveType primitive_type)33 int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
34 switch (primitive_type) {
35 case PRED:
36 return NPY_BOOL;
37 case S8:
38 return NPY_INT8;
39 case S16:
40 return NPY_INT16;
41 case S32:
42 return NPY_INT32;
43 case S64:
44 return NPY_INT64;
45 case U8:
46 return NPY_UINT8;
47 case U16:
48 return NPY_UINT16;
49 case U32:
50 return NPY_UINT32;
51 case U64:
52 return NPY_UINT64;
53 case F16:
54 return NPY_FLOAT16;
55 case F32:
56 return NPY_FLOAT32;
57 case F64:
58 return NPY_FLOAT64;
59 case C64:
60 return NPY_COMPLEX64;
61 case C128:
62 return NPY_COMPLEX128;
63 case TUPLE:
64 return NPY_OBJECT;
65 default:
66 LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
67 }
68 }
69
NumpyTypeToPrimitiveType(int np_type)70 PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
71 switch (np_type) {
72 case NPY_BOOL:
73 return PRED;
74 case NPY_INT8:
75 return S8;
76 case NPY_INT16:
77 return S16;
78 case NPY_INT32:
79 return S32;
80 case NPY_INT64:
81 return S64;
82 case NPY_UINT8:
83 return U8;
84 case NPY_UINT16:
85 return U16;
86 case NPY_UINT32:
87 return U32;
88 case NPY_UINT64:
89 return U64;
90 case NPY_FLOAT16:
91 return F16;
92 case NPY_FLOAT32:
93 return F32;
94 case NPY_FLOAT64:
95 return F64;
96 case NPY_COMPLEX64:
97 return C64;
98 case NPY_COMPLEX128:
99 return C128;
100 case NPY_OBJECT:
101 return TUPLE;
102 default:
103 LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
104 }
105 }
106
NumpyTypeIsValid(int np_type)107 bool NumpyTypeIsValid(int np_type) {
108 switch (np_type) {
109 case NPY_BOOL:
110 case NPY_INT8:
111 case NPY_INT16:
112 case NPY_INT32:
113 case NPY_INT64:
114 case NPY_UINT8:
115 case NPY_UINT16:
116 case NPY_UINT32:
117 case NPY_UINT64:
118 case NPY_FLOAT16:
119 case NPY_FLOAT32:
120 case NPY_FLOAT64:
121 case NPY_COMPLEX64:
122 case NPY_COMPLEX128:
123 case NPY_OBJECT:
124 return true;
125 default:
126 return false;
127 }
128 }
129
PyShapeInfoFromXlaShape(const Shape & shape)130 Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) {
131 int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
132 PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
133
134 Safe_PyObjectPtr dimensions;
135 if (shape.IsTuple()) {
136 int num_elements = ShapeUtil::TupleElementCount(shape);
137 dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape)));
138 for (int i = 0; i < num_elements; ++i) {
139 PyTuple_SET_ITEM(
140 dimensions.get(), i,
141 PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))
142 .release());
143 }
144 } else {
145 int rank = shape.rank();
146 dimensions = make_safe(PyTuple_New(rank));
147 for (int i = 0; i < rank; ++i) {
148 PyTuple_SET_ITEM(dimensions.get(), i,
149 LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
150 }
151 }
152 return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release()));
153 }
154
PyProgramShapeInfoFromXlaProgramShape(const ProgramShape & shape)155 Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
156 const ProgramShape& shape) {
157 Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size()));
158 for (int i = 0; i < shape.parameters_size(); ++i) {
159 PyTuple_SET_ITEM(arg_shapes.get(), i,
160 PyShapeInfoFromXlaShape(shape.parameters(i)).release());
161 }
162
163 Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result());
164 return make_safe(
165 PyTuple_Pack(2, arg_shapes.release(), result_shape.release()));
166 }
167
168 // Precondition: o->ob_type == &PyArrayDescr_Type
NumpyTypenum(PyObject * o)169 static int NumpyTypenum(PyObject* o) {
170 return reinterpret_cast<PyArray_Descr*>(o)->type_num;
171 }
172
173 // Extracts the string held inside r and returns it as a C++ string.
174 //
175 // NOTE: this is an internal helper for conversion to a C++, and so decrefs r.
ExtractStringAndDecref(PyObject * r)176 static string ExtractStringAndDecref(PyObject* r) {
177 auto error = [r] { return absl::StrFormat("<failed conversion of %p>", r); };
178 if (r == nullptr) {
179 return error();
180 }
181 #if PY_MAJOR_VERSION < 3
182 string result = PyString_AsString(r);
183 #else
184 PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0);
185 if (bytes == nullptr) {
186 return error();
187 }
188 CHECK(PyBytes_Check(bytes));
189 string result = PyBytes_AsString(bytes);
190 Py_DECREF(bytes);
191 #endif
192 Py_DECREF(r);
193 return result;
194 }
195
196 // Safely returns a str of the given Python object o as a C++ string.
PyObjectCppStr(PyObject * o)197 static string PyObjectCppStr(PyObject* o) {
198 PyObject* s = PyObject_Str(o);
199 return ExtractStringAndDecref(s);
200 }
201
PyObjectCppRepr(PyObject * o)202 string PyObjectCppRepr(PyObject* o) {
203 PyObject* r = PyObject_Repr(o);
204 return ExtractStringAndDecref(r);
205 }
206
XlaShapeFromPyShape(PyObject * o)207 StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
208 auto error = [o](const string& prefix) {
209 return InvalidArgument("%s; got %s", prefix.c_str(),
210 PyObjectCppRepr(o).c_str());
211 };
212
213 auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
214 PyObject* result =
215 PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
216 if (result == nullptr) {
217 return error(
218 absl::StrCat("Failed to call method of shape object:", method));
219 }
220 return result;
221 };
222
223 PyObject* np_type;
224 TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype"));
225 if (np_type->ob_type != &PyArrayDescr_Type) {
226 return error(
227 "Return value of shape method numpy_dtype "
228 "is not an integer numpy dtype");
229 }
230 if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
231 return error(
232 "Return value of shape method numpy_dtype "
233 "is not a valid integer numpy dtype");
234 }
235 const PrimitiveType element_type =
236 NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
237 Py_DECREF(np_type);
238
239 if (element_type == TUPLE) {
240 PyObject* py_subshapes;
241 TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
242 if (!PyTuple_Check(py_subshapes)) {
243 return error(
244 "Return value of Shape method tuple_shapes() is not a tuple");
245 }
246 const int length = PyTuple_Size(py_subshapes);
247 std::vector<Shape> subshapes;
248 subshapes.reserve(length);
249 for (int i = 0; i < length; i++) {
250 TF_ASSIGN_OR_RETURN(
251 const Shape& subshape,
252 XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
253 subshapes.push_back(subshape);
254 }
255 Py_DECREF(py_subshapes);
256 return ShapeUtil::MakeTupleShape(subshapes);
257 } else {
258 PyObject* py_dimensions;
259 PyObject* py_minor_to_major;
260 TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
261 TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
262 if (!PyTuple_Check(py_dimensions)) {
263 return error("Return value of Shape method dimensions() is not a tuple");
264 }
265 if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
266 return error(
267 "Return value of Shape method minor_to_major() is neither a tuple "
268 "nor None");
269 }
270 const int length = PyTuple_Size(py_dimensions);
271 if (py_minor_to_major != Py_None &&
272 length != PyTuple_Size(py_minor_to_major)) {
273 return error(
274 "Shape methods dimensions() and minor_to_major() return "
275 "different-length tuples");
276 }
277 std::vector<int64> dimensions(length);
278 std::vector<int64> minor_to_major(length);
279 for (int i = 0; i < length; i++) {
280 dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
281 if (dimensions[i] == -1 && PyErr_Occurred()) {
282 return error("Dimension is not an int");
283 }
284
285 if (py_minor_to_major != Py_None) {
286 minor_to_major[i] =
287 PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
288 if (minor_to_major[i] == -1 && PyErr_Occurred()) {
289 return error("Minor-to-major value is not an int");
290 }
291 }
292 }
293 bool with_layout = py_minor_to_major != Py_None;
294 Py_DECREF(py_dimensions);
295 Py_DECREF(py_minor_to_major);
296 if (with_layout) {
297 return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
298 minor_to_major);
299 } else {
300 return ShapeUtil::MakeShape(element_type, dimensions);
301 }
302 }
303 }
304
305 // Helper that retrieves the member with attr_name, stringifies it if is not
306 // None, and returns it as a C++ string.
GetAttrAsString(PyObject * o,const string & attr_name)307 static absl::optional<string> GetAttrAsString(PyObject* o,
308 const string& attr_name) {
309 if (!PyObject_HasAttrString(o, attr_name.c_str())) {
310 return absl::nullopt;
311 }
312 PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
313 if (attr == Py_None) {
314 Py_DECREF(attr);
315 return absl::nullopt;
316 }
317 string result = PyObjectCppStr(attr);
318 Py_DECREF(attr);
319 return result;
320 }
321
322 // Helper that retrieves the member with attr_name, checks that it is an integer
323 // if it is not None, and returns it as an int32 value.
GetAttrAsInt32(PyObject * o,const string & attr_name)324 static absl::optional<int32> GetAttrAsInt32(PyObject* o,
325 const string& attr_name) {
326 if (!PyObject_HasAttrString(o, attr_name.c_str())) {
327 return absl::nullopt;
328 }
329 PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
330 if (attr == Py_None) {
331 Py_DECREF(attr);
332 return absl::nullopt;
333 }
334 if (!CheckPyIntOrLong(attr)) {
335 Py_DECREF(attr);
336 return absl::nullopt;
337 }
338 long value = PyIntOrPyLongToLong(attr); // NOLINT
339 Py_DECREF(attr);
340 if (value == -1 && PyErr_Occurred() != nullptr) {
341 return absl::nullopt;
342 }
343 if (static_cast<int32>(value) != value) {
344 return absl::nullopt;
345 }
346 return value;
347 }
348
OpMetadataFromPyObject(PyObject * o)349 StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
350 OpMetadata result;
351 absl::optional<string> op_type = GetAttrAsString(o, "op_type");
352 if (op_type.has_value()) {
353 result.set_op_type(op_type.value());
354 }
355 absl::optional<string> op_name = GetAttrAsString(o, "op_name");
356 if (op_name.has_value()) {
357 result.set_op_name(op_name.value());
358 }
359 absl::optional<string> source_file = GetAttrAsString(o, "source_file");
360 if (source_file.has_value()) {
361 result.set_source_file(source_file.value());
362 }
363 absl::optional<int32> source_line = GetAttrAsInt32(o, "source_line");
364 if (source_line.has_value()) {
365 result.set_source_line(source_line.value());
366 }
367 return result;
368 }
369
PyObjectFromXlaLiteral(const LiteralSlice & literal)370 StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
371 if (literal.shape().IsTuple()) {
372 int num_elements = ShapeUtil::TupleElementCount(literal.shape());
373 std::vector<Safe_PyObjectPtr> elems(num_elements);
374 for (int i = 0; i < num_elements; i++) {
375 TF_ASSIGN_OR_RETURN(elems[i],
376 PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
377 }
378 Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
379 for (int i = 0; i < num_elements; i++) {
380 PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
381 }
382 return tuple;
383 } else {
384 int rank = literal.shape().rank();
385 std::vector<long> dimensions(rank); // NOLINT - PyArray requires a long*
386 for (int i = 0; i < rank; i++) {
387 dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
388 }
389 int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
390 Safe_PyObjectPtr array = make_safe(
391 PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
392 TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
393 np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
394 return array;
395 }
396 }
397
XlaLiteralFromPyObject(PyObject * o)398 StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
399 if (PyTuple_Check(o)) {
400 int num_elements = PyTuple_Size(o);
401 std::vector<Literal> elements;
402 elements.reserve(num_elements);
403 for (int i = 0; i < num_elements; i++) {
404 PyObject* element = PyTuple_GetItem(o, i);
405 TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element));
406 elements.push_back(std::move(literal));
407 }
408 return LiteralUtil::MakeTupleOwned(std::move(elements));
409 } else if (PyArray_Check(o)) {
410 PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
411 int rank = PyArray_NDIM(py_array);
412 std::vector<int64> dimensions(rank);
413 for (int i = 0; i < rank; i++) {
414 dimensions[i] = PyArray_DIM(py_array, i);
415 }
416 int np_type = PyArray_TYPE(py_array);
417 auto literal = LiteralUtil::CreateFromDimensions(
418 NumpyTypeToPrimitiveType(np_type), dimensions);
419 TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
420 return std::move(literal);
421 } else {
422 return InvalidArgument(
423 "Non-tuple or Numpy array encountered in conversion to XLA literal.");
424 }
425 }
426
CopyNumpyArrayToLiteral(int np_type,PyArrayObject * py_array,Literal * literal)427 Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
428 Literal* literal) {
429 switch (np_type) {
430 case NPY_BOOL:
431 CopyNumpyArrayToLiteral<bool>(py_array, literal);
432 break;
433 case NPY_INT8:
434 CopyNumpyArrayToLiteral<int8>(py_array, literal);
435 break;
436 case NPY_INT16:
437 CopyNumpyArrayToLiteral<int16>(py_array, literal);
438 break;
439 case NPY_INT32:
440 CopyNumpyArrayToLiteral<int32>(py_array, literal);
441 break;
442 case NPY_INT64:
443 CopyNumpyArrayToLiteral<int64>(py_array, literal);
444 break;
445 case NPY_UINT8:
446 CopyNumpyArrayToLiteral<uint8>(py_array, literal);
447 break;
448 case NPY_UINT16:
449 CopyNumpyArrayToLiteral<uint16>(py_array, literal);
450 break;
451 case NPY_UINT32:
452 CopyNumpyArrayToLiteral<uint32>(py_array, literal);
453 break;
454 case NPY_UINT64:
455 CopyNumpyArrayToLiteral<uint64>(py_array, literal);
456 break;
457 case NPY_FLOAT16:
458 CopyNumpyArrayToLiteral<half>(py_array, literal);
459 break;
460 case NPY_FLOAT32:
461 CopyNumpyArrayToLiteral<float>(py_array, literal);
462 break;
463 case NPY_FLOAT64:
464 CopyNumpyArrayToLiteral<double>(py_array, literal);
465 break;
466 case NPY_COMPLEX64:
467 CopyNumpyArrayToLiteral<complex64>(py_array, literal);
468 break;
469 case NPY_COMPLEX128:
470 CopyNumpyArrayToLiteral<complex128>(py_array, literal);
471 break;
472 default:
473 return InvalidArgument(
474 "No XLA literal container for Numpy type number: %d", np_type);
475 }
476 return Status::OK();
477 }
478
CopyLiteralToNumpyArray(int np_type,const LiteralSlice & literal,PyArrayObject * py_array)479 Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
480 PyArrayObject* py_array) {
481 switch (np_type) {
482 case NPY_BOOL:
483 CopyLiteralToNumpyArray<bool>(literal, py_array);
484 break;
485 case NPY_INT8:
486 CopyLiteralToNumpyArray<int8>(literal, py_array);
487 break;
488 case NPY_INT16:
489 CopyLiteralToNumpyArray<int16>(literal, py_array);
490 break;
491 case NPY_INT32:
492 CopyLiteralToNumpyArray<int32>(literal, py_array);
493 break;
494 case NPY_INT64:
495 CopyLiteralToNumpyArray<int64>(literal, py_array);
496 break;
497 case NPY_UINT8:
498 CopyLiteralToNumpyArray<uint8>(literal, py_array);
499 break;
500 case NPY_UINT16:
501 CopyLiteralToNumpyArray<uint16>(literal, py_array);
502 break;
503 case NPY_UINT32:
504 CopyLiteralToNumpyArray<uint32>(literal, py_array);
505 break;
506 case NPY_UINT64:
507 CopyLiteralToNumpyArray<uint64>(literal, py_array);
508 break;
509 case NPY_FLOAT16:
510 CopyLiteralToNumpyArray<half>(literal, py_array);
511 break;
512 case NPY_FLOAT32:
513 CopyLiteralToNumpyArray<float>(literal, py_array);
514 break;
515 case NPY_FLOAT64:
516 CopyLiteralToNumpyArray<double>(literal, py_array);
517 break;
518 case NPY_COMPLEX64:
519 CopyLiteralToNumpyArray<complex64>(literal, py_array);
520 break;
521 case NPY_COMPLEX128:
522 CopyLiteralToNumpyArray<complex128>(literal, py_array);
523 break;
524 default:
525 return InvalidArgument(
526 "No XLA literal container for Numpy type number: %d", np_type);
527 }
528 return Status::OK();
529 }
530
LongToPyIntOrPyLong(long x)531 PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
532 #if PY_MAJOR_VERSION < 3
533 return PyInt_FromLong(x);
534 #else
535 return PyLong_FromLong(x);
536 #endif
537 }
538
PyIntOrPyLongToLong(PyObject * o)539 long PyIntOrPyLongToLong(PyObject* o) { // NOLINT
540 #if PY_MAJOR_VERSION < 3
541 return PyInt_AsLong(o);
542 #else
543 return PyLong_AsLong(o);
544 #endif
545 }
546
CheckPyIntOrLong(PyObject * o)547 bool CheckPyIntOrLong(PyObject* o) {
548 #if PY_MAJOR_VERSION < 3
549 return PyInt_Check(o);
550 #else
551 if (!PyLong_Check(o)) {
552 return false;
553 }
554 int overflow = 0;
555 PyLong_AsLongAndOverflow(o, &overflow);
556 return (overflow == 0);
557 #endif
558 }
559
PyNumberToPyInt(PyObject * o)560 PyObject* PyNumberToPyInt(PyObject* o) {
561 #if PY_MAJOR_VERSION < 3
562 return PyNumber_Int(o);
563 #else
564 return PyNumber_Long(o);
565 #endif
566 }
567
568 } // namespace numpy
569
GetIntAttr(PyObject * o,const char * field,int64 * result)570 bool GetIntAttr(PyObject* o, const char* field, int64* result) {
571 PyObject* fo = PyObject_GetAttrString(o, field);
572 if (!fo) {
573 return false;
574 }
575 const int64 value = numpy::PyIntOrPyLongToLong(fo);
576 if (value == -1 && PyErr_Occurred()) {
577 Py_DECREF(fo);
578 return false;
579 }
580 Py_DECREF(fo);
581 *result = value;
582 return true;
583 }
584
585 // Returns "ok"; true if there is no error, false if there was an error.
HandleStringAttribute(PyObject * o,const char * attr_name,std::function<void (string s)> f)586 bool HandleStringAttribute(PyObject* o, const char* attr_name,
587 std::function<void(string s)> f) {
588 if (!PyObject_HasAttrString(o, attr_name)) {
589 return true; // It's ok for the object to not have the attribute.
590 }
591 PyObject* attr = PyObject_GetAttrString(o, attr_name);
592 if (attr == nullptr) {
593 return false; // An error occurred getting the attribute.
594 }
595 if (attr == Py_None) {
596 Py_DECREF(attr);
597 return true; // The attribute is None, which we consider ok.
598 }
599 #if PY_MAJOR_VERSION < 3
600 if (!PyString_Check(attr)) {
601 string message = absl::StrFormat("%s must be a string or none; got %s",
602 attr_name, numpy::PyObjectCppRepr(attr));
603 PyErr_SetString(PyExc_TypeError, message.c_str());
604 Py_DECREF(attr);
605 return false; // Type error, not ok.
606 }
607 f(PyString_AsString(attr));
608 #else
609 if (!PyBytes_Check(attr)) {
610 string message = absl::StrFormat("%s must be a string or none; got %s",
611 attr_name, numpy::PyObjectCppRepr(attr));
612 PyErr_SetString(PyExc_TypeError, message.c_str());
613 Py_DECREF(attr);
614 return false; // Type error, not ok.
615 }
616 f(PyBytes_AsString(attr));
617 #endif
618
619 Py_DECREF(attr);
620 return true; // Handled string attribute, ok!
621 }
622
623 // Returns "ok"; true if there is no error, false if there was an error.
HandleBoolAttribute(PyObject * o,const char * attr_name,std::function<void (bool b)> f)624 bool HandleBoolAttribute(PyObject* o, const char* attr_name,
625 std::function<void(bool b)> f) {
626 if (!PyObject_HasAttrString(o, attr_name)) {
627 return true; // It's ok for the object to not have the attribute.
628 }
629 PyObject* attr = PyObject_GetAttrString(o, attr_name);
630 if (attr == nullptr) {
631 return false; // An error occurred getting the attribute.
632 }
633 if (attr == Py_None) {
634 Py_DECREF(attr);
635 return true; // The attribute is None, which we consider ok.
636 }
637 if (!PyBool_Check(attr)) {
638 string message = absl::StrFormat("%s must be a boolean or none; got %s",
639 attr_name, numpy::PyObjectCppRepr(attr));
640 PyErr_SetString(PyExc_TypeError, message.c_str());
641 Py_DECREF(attr);
642 return false; // Type error, not ok.
643 }
644 f(PyObject_IsTrue(attr));
645 Py_DECREF(attr);
646 return true; // Handled boolean attribute, ok!
647 }
648
HandleRepeatedInt64Attribute(PyObject * o,const char * attr_name,tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64> * field)649 bool HandleRepeatedInt64Attribute(
650 PyObject* o, const char* attr_name,
651 tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>* field) {
652 PyObject* seq = PyObject_GetAttrString(o, attr_name);
653 if (!seq) {
654 return false;
655 }
656
657 int length = PySequence_Size(seq);
658 if (length == -1) {
659 Py_DECREF(seq);
660 return false;
661 }
662
663 for (int i = 0; i < length; ++i) {
664 PyObject* item = PySequence_GetItem(seq, i);
665 if (!item) {
666 Py_DECREF(seq);
667 return false;
668 }
669 const int64 dimension = numpy::PyIntOrPyLongToLong(item);
670 if (dimension == -1 && PyErr_Occurred()) {
671 Py_DECREF(item);
672 Py_DECREF(seq);
673 return false;
674 }
675 *field->Add() = dimension;
676 Py_DECREF(item);
677 }
678 Py_DECREF(seq);
679 return true;
680 }
681
682 } // namespace swig
683
684 } // namespace xla
685