1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/python/numpy_bridge.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
22 
23 namespace xla {
24 
25 namespace swig {
26 
27 namespace numpy {
28 
make_safe(PyObject * object)29 Safe_PyObjectPtr make_safe(PyObject* object) {
30   return Safe_PyObjectPtr(object);
31 }
32 
PrimitiveTypeToNumpyType(PrimitiveType primitive_type)33 int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
34   switch (primitive_type) {
35     case PRED:
36       return NPY_BOOL;
37     case S8:
38       return NPY_INT8;
39     case S16:
40       return NPY_INT16;
41     case S32:
42       return NPY_INT32;
43     case S64:
44       return NPY_INT64;
45     case U8:
46       return NPY_UINT8;
47     case U16:
48       return NPY_UINT16;
49     case U32:
50       return NPY_UINT32;
51     case U64:
52       return NPY_UINT64;
53     case F16:
54       return NPY_FLOAT16;
55     case F32:
56       return NPY_FLOAT32;
57     case F64:
58       return NPY_FLOAT64;
59     case C64:
60       return NPY_COMPLEX64;
61     case C128:
62       return NPY_COMPLEX128;
63     case TUPLE:
64       return NPY_OBJECT;
65     default:
66       LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
67   }
68 }
69 
NumpyTypeToPrimitiveType(int np_type)70 PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
71   switch (np_type) {
72     case NPY_BOOL:
73       return PRED;
74     case NPY_INT8:
75       return S8;
76     case NPY_INT16:
77       return S16;
78     case NPY_INT32:
79       return S32;
80     case NPY_INT64:
81       return S64;
82     case NPY_UINT8:
83       return U8;
84     case NPY_UINT16:
85       return U16;
86     case NPY_UINT32:
87       return U32;
88     case NPY_UINT64:
89       return U64;
90     case NPY_FLOAT16:
91       return F16;
92     case NPY_FLOAT32:
93       return F32;
94     case NPY_FLOAT64:
95       return F64;
96     case NPY_COMPLEX64:
97       return C64;
98     case NPY_COMPLEX128:
99       return C128;
100     case NPY_OBJECT:
101       return TUPLE;
102     default:
103       LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
104   }
105 }
106 
NumpyTypeIsValid(int np_type)107 bool NumpyTypeIsValid(int np_type) {
108   switch (np_type) {
109     case NPY_BOOL:
110     case NPY_INT8:
111     case NPY_INT16:
112     case NPY_INT32:
113     case NPY_INT64:
114     case NPY_UINT8:
115     case NPY_UINT16:
116     case NPY_UINT32:
117     case NPY_UINT64:
118     case NPY_FLOAT16:
119     case NPY_FLOAT32:
120     case NPY_FLOAT64:
121     case NPY_COMPLEX64:
122     case NPY_COMPLEX128:
123     case NPY_OBJECT:
124       return true;
125     default:
126       return false;
127   }
128 }
129 
PyShapeInfoFromXlaShape(const Shape & shape)130 Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) {
131   int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
132   PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
133 
134   Safe_PyObjectPtr dimensions;
135   if (shape.IsTuple()) {
136     int num_elements = ShapeUtil::TupleElementCount(shape);
137     dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape)));
138     for (int i = 0; i < num_elements; ++i) {
139       PyTuple_SET_ITEM(
140           dimensions.get(), i,
141           PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))
142               .release());
143     }
144   } else {
145     int rank = shape.rank();
146     dimensions = make_safe(PyTuple_New(rank));
147     for (int i = 0; i < rank; ++i) {
148       PyTuple_SET_ITEM(dimensions.get(), i,
149                        LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
150     }
151   }
152   return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release()));
153 }
154 
PyProgramShapeInfoFromXlaProgramShape(const ProgramShape & shape)155 Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
156     const ProgramShape& shape) {
157   Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size()));
158   for (int i = 0; i < shape.parameters_size(); ++i) {
159     PyTuple_SET_ITEM(arg_shapes.get(), i,
160                      PyShapeInfoFromXlaShape(shape.parameters(i)).release());
161   }
162 
163   Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result());
164   return make_safe(
165       PyTuple_Pack(2, arg_shapes.release(), result_shape.release()));
166 }
167 
168 // Precondition: o->ob_type == &PyArrayDescr_Type
NumpyTypenum(PyObject * o)169 static int NumpyTypenum(PyObject* o) {
170   return reinterpret_cast<PyArray_Descr*>(o)->type_num;
171 }
172 
173 // Extracts the string held inside r and returns it as a C++ string.
174 //
175 // NOTE: this is an internal helper for conversion to a C++, and so decrefs r.
ExtractStringAndDecref(PyObject * r)176 static string ExtractStringAndDecref(PyObject* r) {
177   auto error = [r] { return absl::StrFormat("<failed conversion of %p>", r); };
178   if (r == nullptr) {
179     return error();
180   }
181 #if PY_MAJOR_VERSION < 3
182   string result = PyString_AsString(r);
183 #else
184   PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0);
185   if (bytes == nullptr) {
186     return error();
187   }
188   CHECK(PyBytes_Check(bytes));
189   string result = PyBytes_AsString(bytes);
190   Py_DECREF(bytes);
191 #endif
192   Py_DECREF(r);
193   return result;
194 }
195 
196 // Safely returns a str of the given Python object o as a C++ string.
PyObjectCppStr(PyObject * o)197 static string PyObjectCppStr(PyObject* o) {
198   PyObject* s = PyObject_Str(o);
199   return ExtractStringAndDecref(s);
200 }
201 
PyObjectCppRepr(PyObject * o)202 string PyObjectCppRepr(PyObject* o) {
203   PyObject* r = PyObject_Repr(o);
204   return ExtractStringAndDecref(r);
205 }
206 
XlaShapeFromPyShape(PyObject * o)207 StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
208   auto error = [o](const string& prefix) {
209     return InvalidArgument("%s; got %s", prefix.c_str(),
210                            PyObjectCppRepr(o).c_str());
211   };
212 
213   auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
214     PyObject* result =
215         PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
216     if (result == nullptr) {
217       return error(
218           absl::StrCat("Failed to call method of shape object:", method));
219     }
220     return result;
221   };
222 
223   PyObject* np_type;
224   TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype"));
225   if (np_type->ob_type != &PyArrayDescr_Type) {
226     return error(
227         "Return value of shape method numpy_dtype "
228         "is not an integer numpy dtype");
229   }
230   if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
231     return error(
232         "Return value of shape method numpy_dtype "
233         "is not a valid integer numpy dtype");
234   }
235   const PrimitiveType element_type =
236       NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
237   Py_DECREF(np_type);
238 
239   if (element_type == TUPLE) {
240     PyObject* py_subshapes;
241     TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
242     if (!PyTuple_Check(py_subshapes)) {
243       return error(
244           "Return value of Shape method tuple_shapes() is not a tuple");
245     }
246     const int length = PyTuple_Size(py_subshapes);
247     std::vector<Shape> subshapes;
248     subshapes.reserve(length);
249     for (int i = 0; i < length; i++) {
250       TF_ASSIGN_OR_RETURN(
251           const Shape& subshape,
252           XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
253       subshapes.push_back(subshape);
254     }
255     Py_DECREF(py_subshapes);
256     return ShapeUtil::MakeTupleShape(subshapes);
257   } else {
258     PyObject* py_dimensions;
259     PyObject* py_minor_to_major;
260     TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
261     TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
262     if (!PyTuple_Check(py_dimensions)) {
263       return error("Return value of Shape method dimensions() is not a tuple");
264     }
265     if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
266       return error(
267           "Return value of Shape method minor_to_major() is neither a tuple "
268           "nor None");
269     }
270     const int length = PyTuple_Size(py_dimensions);
271     if (py_minor_to_major != Py_None &&
272         length != PyTuple_Size(py_minor_to_major)) {
273       return error(
274           "Shape methods dimensions() and minor_to_major() return "
275           "different-length tuples");
276     }
277     std::vector<int64> dimensions(length);
278     std::vector<int64> minor_to_major(length);
279     for (int i = 0; i < length; i++) {
280       dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
281       if (dimensions[i] == -1 && PyErr_Occurred()) {
282         return error("Dimension is not an int");
283       }
284 
285       if (py_minor_to_major != Py_None) {
286         minor_to_major[i] =
287             PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
288         if (minor_to_major[i] == -1 && PyErr_Occurred()) {
289           return error("Minor-to-major value is not an int");
290         }
291       }
292     }
293     bool with_layout = py_minor_to_major != Py_None;
294     Py_DECREF(py_dimensions);
295     Py_DECREF(py_minor_to_major);
296     if (with_layout) {
297       return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
298                                             minor_to_major);
299     } else {
300       return ShapeUtil::MakeShape(element_type, dimensions);
301     }
302   }
303 }
304 
305 // Helper that retrieves the member with attr_name, stringifies it if is not
306 // None, and returns it as a C++ string.
GetAttrAsString(PyObject * o,const string & attr_name)307 static absl::optional<string> GetAttrAsString(PyObject* o,
308                                               const string& attr_name) {
309   if (!PyObject_HasAttrString(o, attr_name.c_str())) {
310     return absl::nullopt;
311   }
312   PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
313   if (attr == Py_None) {
314     Py_DECREF(attr);
315     return absl::nullopt;
316   }
317   string result = PyObjectCppStr(attr);
318   Py_DECREF(attr);
319   return result;
320 }
321 
322 // Helper that retrieves the member with attr_name, checks that it is an integer
323 // if it is not None, and returns it as an int32 value.
GetAttrAsInt32(PyObject * o,const string & attr_name)324 static absl::optional<int32> GetAttrAsInt32(PyObject* o,
325                                             const string& attr_name) {
326   if (!PyObject_HasAttrString(o, attr_name.c_str())) {
327     return absl::nullopt;
328   }
329   PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
330   if (attr == Py_None) {
331     Py_DECREF(attr);
332     return absl::nullopt;
333   }
334   if (!CheckPyIntOrLong(attr)) {
335     Py_DECREF(attr);
336     return absl::nullopt;
337   }
338   long value = PyIntOrPyLongToLong(attr);  // NOLINT
339   Py_DECREF(attr);
340   if (value == -1 && PyErr_Occurred() != nullptr) {
341     return absl::nullopt;
342   }
343   if (static_cast<int32>(value) != value) {
344     return absl::nullopt;
345   }
346   return value;
347 }
348 
OpMetadataFromPyObject(PyObject * o)349 StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
350   OpMetadata result;
351   absl::optional<string> op_type = GetAttrAsString(o, "op_type");
352   if (op_type.has_value()) {
353     result.set_op_type(op_type.value());
354   }
355   absl::optional<string> op_name = GetAttrAsString(o, "op_name");
356   if (op_name.has_value()) {
357     result.set_op_name(op_name.value());
358   }
359   absl::optional<string> source_file = GetAttrAsString(o, "source_file");
360   if (source_file.has_value()) {
361     result.set_source_file(source_file.value());
362   }
363   absl::optional<int32> source_line = GetAttrAsInt32(o, "source_line");
364   if (source_line.has_value()) {
365     result.set_source_line(source_line.value());
366   }
367   return result;
368 }
369 
PyObjectFromXlaLiteral(const LiteralSlice & literal)370 StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
371   if (literal.shape().IsTuple()) {
372     int num_elements = ShapeUtil::TupleElementCount(literal.shape());
373     std::vector<Safe_PyObjectPtr> elems(num_elements);
374     for (int i = 0; i < num_elements; i++) {
375       TF_ASSIGN_OR_RETURN(elems[i],
376                           PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
377     }
378     Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
379     for (int i = 0; i < num_elements; i++) {
380       PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
381     }
382     return tuple;
383   } else {
384     int rank = literal.shape().rank();
385     std::vector<long> dimensions(rank);  // NOLINT - PyArray requires a long*
386     for (int i = 0; i < rank; i++) {
387       dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
388     }
389     int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
390     Safe_PyObjectPtr array = make_safe(
391         PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
392     TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
393         np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
394     return array;
395   }
396 }
397 
XlaLiteralFromPyObject(PyObject * o)398 StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
399   if (PyTuple_Check(o)) {
400     int num_elements = PyTuple_Size(o);
401     std::vector<Literal> elements;
402     elements.reserve(num_elements);
403     for (int i = 0; i < num_elements; i++) {
404       PyObject* element = PyTuple_GetItem(o, i);
405       TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element));
406       elements.push_back(std::move(literal));
407     }
408     return LiteralUtil::MakeTupleOwned(std::move(elements));
409   } else if (PyArray_Check(o)) {
410     PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
411     int rank = PyArray_NDIM(py_array);
412     std::vector<int64> dimensions(rank);
413     for (int i = 0; i < rank; i++) {
414       dimensions[i] = PyArray_DIM(py_array, i);
415     }
416     int np_type = PyArray_TYPE(py_array);
417     auto literal = LiteralUtil::CreateFromDimensions(
418         NumpyTypeToPrimitiveType(np_type), dimensions);
419     TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
420     return std::move(literal);
421   } else {
422     return InvalidArgument(
423         "Non-tuple or Numpy array encountered in conversion to XLA literal.");
424   }
425 }
426 
CopyNumpyArrayToLiteral(int np_type,PyArrayObject * py_array,Literal * literal)427 Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
428                                Literal* literal) {
429   switch (np_type) {
430     case NPY_BOOL:
431       CopyNumpyArrayToLiteral<bool>(py_array, literal);
432       break;
433     case NPY_INT8:
434       CopyNumpyArrayToLiteral<int8>(py_array, literal);
435       break;
436     case NPY_INT16:
437       CopyNumpyArrayToLiteral<int16>(py_array, literal);
438       break;
439     case NPY_INT32:
440       CopyNumpyArrayToLiteral<int32>(py_array, literal);
441       break;
442     case NPY_INT64:
443       CopyNumpyArrayToLiteral<int64>(py_array, literal);
444       break;
445     case NPY_UINT8:
446       CopyNumpyArrayToLiteral<uint8>(py_array, literal);
447       break;
448     case NPY_UINT16:
449       CopyNumpyArrayToLiteral<uint16>(py_array, literal);
450       break;
451     case NPY_UINT32:
452       CopyNumpyArrayToLiteral<uint32>(py_array, literal);
453       break;
454     case NPY_UINT64:
455       CopyNumpyArrayToLiteral<uint64>(py_array, literal);
456       break;
457     case NPY_FLOAT16:
458       CopyNumpyArrayToLiteral<half>(py_array, literal);
459       break;
460     case NPY_FLOAT32:
461       CopyNumpyArrayToLiteral<float>(py_array, literal);
462       break;
463     case NPY_FLOAT64:
464       CopyNumpyArrayToLiteral<double>(py_array, literal);
465       break;
466     case NPY_COMPLEX64:
467       CopyNumpyArrayToLiteral<complex64>(py_array, literal);
468       break;
469     case NPY_COMPLEX128:
470       CopyNumpyArrayToLiteral<complex128>(py_array, literal);
471       break;
472     default:
473       return InvalidArgument(
474           "No XLA literal container for Numpy type number: %d", np_type);
475   }
476   return Status::OK();
477 }
478 
CopyLiteralToNumpyArray(int np_type,const LiteralSlice & literal,PyArrayObject * py_array)479 Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
480                                PyArrayObject* py_array) {
481   switch (np_type) {
482     case NPY_BOOL:
483       CopyLiteralToNumpyArray<bool>(literal, py_array);
484       break;
485     case NPY_INT8:
486       CopyLiteralToNumpyArray<int8>(literal, py_array);
487       break;
488     case NPY_INT16:
489       CopyLiteralToNumpyArray<int16>(literal, py_array);
490       break;
491     case NPY_INT32:
492       CopyLiteralToNumpyArray<int32>(literal, py_array);
493       break;
494     case NPY_INT64:
495       CopyLiteralToNumpyArray<int64>(literal, py_array);
496       break;
497     case NPY_UINT8:
498       CopyLiteralToNumpyArray<uint8>(literal, py_array);
499       break;
500     case NPY_UINT16:
501       CopyLiteralToNumpyArray<uint16>(literal, py_array);
502       break;
503     case NPY_UINT32:
504       CopyLiteralToNumpyArray<uint32>(literal, py_array);
505       break;
506     case NPY_UINT64:
507       CopyLiteralToNumpyArray<uint64>(literal, py_array);
508       break;
509     case NPY_FLOAT16:
510       CopyLiteralToNumpyArray<half>(literal, py_array);
511       break;
512     case NPY_FLOAT32:
513       CopyLiteralToNumpyArray<float>(literal, py_array);
514       break;
515     case NPY_FLOAT64:
516       CopyLiteralToNumpyArray<double>(literal, py_array);
517       break;
518     case NPY_COMPLEX64:
519       CopyLiteralToNumpyArray<complex64>(literal, py_array);
520       break;
521     case NPY_COMPLEX128:
522       CopyLiteralToNumpyArray<complex128>(literal, py_array);
523       break;
524     default:
525       return InvalidArgument(
526           "No XLA literal container for Numpy type number: %d", np_type);
527   }
528   return Status::OK();
529 }
530 
LongToPyIntOrPyLong(long x)531 PyObject* LongToPyIntOrPyLong(long x) {  // NOLINT
532 #if PY_MAJOR_VERSION < 3
533   return PyInt_FromLong(x);
534 #else
535   return PyLong_FromLong(x);
536 #endif
537 }
538 
PyIntOrPyLongToLong(PyObject * o)539 long PyIntOrPyLongToLong(PyObject* o) {  // NOLINT
540 #if PY_MAJOR_VERSION < 3
541   return PyInt_AsLong(o);
542 #else
543   return PyLong_AsLong(o);
544 #endif
545 }
546 
CheckPyIntOrLong(PyObject * o)547 bool CheckPyIntOrLong(PyObject* o) {
548 #if PY_MAJOR_VERSION < 3
549   return PyInt_Check(o);
550 #else
551   if (!PyLong_Check(o)) {
552     return false;
553   }
554   int overflow = 0;
555   PyLong_AsLongAndOverflow(o, &overflow);
556   return (overflow == 0);
557 #endif
558 }
559 
PyNumberToPyInt(PyObject * o)560 PyObject* PyNumberToPyInt(PyObject* o) {
561 #if PY_MAJOR_VERSION < 3
562   return PyNumber_Int(o);
563 #else
564   return PyNumber_Long(o);
565 #endif
566 }
567 
568 }  // namespace numpy
569 
GetIntAttr(PyObject * o,const char * field,int64 * result)570 bool GetIntAttr(PyObject* o, const char* field, int64* result) {
571   PyObject* fo = PyObject_GetAttrString(o, field);
572   if (!fo) {
573     return false;
574   }
575   const int64 value = numpy::PyIntOrPyLongToLong(fo);
576   if (value == -1 && PyErr_Occurred()) {
577     Py_DECREF(fo);
578     return false;
579   }
580   Py_DECREF(fo);
581   *result = value;
582   return true;
583 }
584 
585 // Returns "ok"; true if there is no error, false if there was an error.
HandleStringAttribute(PyObject * o,const char * attr_name,std::function<void (string s)> f)586 bool HandleStringAttribute(PyObject* o, const char* attr_name,
587                            std::function<void(string s)> f) {
588   if (!PyObject_HasAttrString(o, attr_name)) {
589     return true;  // It's ok for the object to not have the attribute.
590   }
591   PyObject* attr = PyObject_GetAttrString(o, attr_name);
592   if (attr == nullptr) {
593     return false;  // An error occurred getting the attribute.
594   }
595   if (attr == Py_None) {
596     Py_DECREF(attr);
597     return true;  // The attribute is None, which we consider ok.
598   }
599 #if PY_MAJOR_VERSION < 3
600   if (!PyString_Check(attr)) {
601     string message = absl::StrFormat("%s must be a string or none; got %s",
602                                      attr_name, numpy::PyObjectCppRepr(attr));
603     PyErr_SetString(PyExc_TypeError, message.c_str());
604     Py_DECREF(attr);
605     return false;  // Type error, not ok.
606   }
607   f(PyString_AsString(attr));
608 #else
609   if (!PyBytes_Check(attr)) {
610     string message = absl::StrFormat("%s must be a string or none; got %s",
611                                      attr_name, numpy::PyObjectCppRepr(attr));
612     PyErr_SetString(PyExc_TypeError, message.c_str());
613     Py_DECREF(attr);
614     return false;  // Type error, not ok.
615   }
616   f(PyBytes_AsString(attr));
617 #endif
618 
619   Py_DECREF(attr);
620   return true;  // Handled string attribute, ok!
621 }
622 
623 // Returns "ok"; true if there is no error, false if there was an error.
HandleBoolAttribute(PyObject * o,const char * attr_name,std::function<void (bool b)> f)624 bool HandleBoolAttribute(PyObject* o, const char* attr_name,
625                          std::function<void(bool b)> f) {
626   if (!PyObject_HasAttrString(o, attr_name)) {
627     return true;  // It's ok for the object to not have the attribute.
628   }
629   PyObject* attr = PyObject_GetAttrString(o, attr_name);
630   if (attr == nullptr) {
631     return false;  // An error occurred getting the attribute.
632   }
633   if (attr == Py_None) {
634     Py_DECREF(attr);
635     return true;  // The attribute is None, which we consider ok.
636   }
637   if (!PyBool_Check(attr)) {
638     string message = absl::StrFormat("%s must be a boolean or none; got %s",
639                                      attr_name, numpy::PyObjectCppRepr(attr));
640     PyErr_SetString(PyExc_TypeError, message.c_str());
641     Py_DECREF(attr);
642     return false;  // Type error, not ok.
643   }
644   f(PyObject_IsTrue(attr));
645   Py_DECREF(attr);
646   return true;  // Handled boolean attribute, ok!
647 }
648 
HandleRepeatedInt64Attribute(PyObject * o,const char * attr_name,tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64> * field)649 bool HandleRepeatedInt64Attribute(
650     PyObject* o, const char* attr_name,
651     tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>* field) {
652   PyObject* seq = PyObject_GetAttrString(o, attr_name);
653   if (!seq) {
654     return false;
655   }
656 
657   int length = PySequence_Size(seq);
658   if (length == -1) {
659     Py_DECREF(seq);
660     return false;
661   }
662 
663   for (int i = 0; i < length; ++i) {
664     PyObject* item = PySequence_GetItem(seq, i);
665     if (!item) {
666       Py_DECREF(seq);
667       return false;
668     }
669     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
670     if (dimension == -1 && PyErr_Occurred()) {
671       Py_DECREF(item);
672       Py_DECREF(seq);
673       return false;
674     }
675     *field->Add() = dimension;
676     Py_DECREF(item);
677   }
678   Py_DECREF(seq);
679   return true;
680 }
681 
682 }  // namespace swig
683 
684 }  // namespace xla
685