1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <array>
#include <cmath>
#include <limits>

#include "tensorflow/python/lib/core/bfloat16.h"

#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
25
26 namespace tensorflow {
27 namespace {
28
29 // Workarounds for Python 2 vs 3 API differences.
30 #if PY_MAJOR_VERSION < 3
31
MakePyString(const string & s)32 PyObject* MakePyString(const string& s) {
33 return PyString_FromString(s.c_str());
34 }
35
36 typedef long HashType; // NOLINT
37
TfPyInt_Check(PyObject * object)38 bool TfPyInt_Check(PyObject* object) { return PyInt_Check(object); }
39
TfPyInt_FromLong(long x)40 PyObject* TfPyInt_FromLong(long x) { // NOLINT
41 return PyInt_FromLong(x);
42 }
43
TfPyInt_AsLong(PyObject * x)44 long TfPyInt_AsLong(PyObject* x) { // NOLINT
45 return PyInt_AsLong(x);
46 }
47
48 #else // PY_MAJOR_VERSION < 3
49
50 PyObject* MakePyString(const string& s) {
51 return PyUnicode_FromString(s.c_str());
52 }
53
54 bool TfPyInt_Check(PyObject* object) {
55 if (!PyLong_Check(object)) {
56 return 0;
57 }
58 int overflow = 0;
59 PyLong_AsLongAndOverflow(object, &overflow);
60 return (overflow == 0);
61 }
62
63 PyObject* TfPyInt_FromLong(long x) { // NOLINT
64 return PyLong_FromLong(x);
65 }
66
67 long TfPyInt_AsLong(PyObject* x) { // NOLINT
68 return PyLong_AsLong(x);
69 }
70
71 typedef Py_hash_t HashType;
72
73 #endif // PY_MAJOR_VERSION < 3
74
75 // Forward declaration.
76 extern PyTypeObject PyBfloat16_Type;
77
78 // Representation of a Python bfloat16 object.
79 struct PyBfloat16 {
80 PyObject_HEAD; // Python object header
81 bfloat16 value;
82 };
83
84 // Returns true if 'object' is a PyBfloat16.
PyBfloat16_Check(PyObject * object)85 bool PyBfloat16_Check(PyObject* object) {
86 return PyObject_IsInstance(object,
87 reinterpret_cast<PyObject*>(&PyBfloat16_Type));
88 }
89
90 // Extracts the value of a PyBfloat16 object.
PyBfloat16_Bfloat16(PyObject * object)91 bfloat16 PyBfloat16_Bfloat16(PyObject* object) {
92 return reinterpret_cast<PyBfloat16*>(object)->value;
93 }
94
95 // Constructs a PyBfloat16 object from a bfloat16.
PyBfloat16_FromBfloat16(bfloat16 x)96 Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
97 Safe_PyObjectPtr ref =
98 make_safe(PyBfloat16_Type.tp_alloc(&PyBfloat16_Type, 0));
99 PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
100 if (p) {
101 p->value = x;
102 }
103 return ref;
104 }
105
106 // Converts a Python object to a bfloat16 value. Returns true on success,
107 // returns false and reports a Python error on failure.
AsBfloat16(PyObject * arg,bfloat16 * output)108 bool AsBfloat16(PyObject* arg, bfloat16* output) {
109 if (PyBfloat16_Check(arg)) {
110 *output = PyBfloat16_Bfloat16(arg);
111 return true;
112 }
113 if (PyFloat_Check(arg)) {
114 double d = PyFloat_AsDouble(arg);
115 if (PyErr_Occurred()) {
116 return false;
117 }
118 // TODO(phawkins): check for overflow
119 *output = bfloat16(d);
120 return true;
121 }
122 if (TfPyInt_Check(arg)) {
123 long l = TfPyInt_AsLong(arg); // NOLINT
124 if (PyErr_Occurred()) {
125 return false;
126 }
127 // TODO(phawkins): check for overflow
128 *output = bfloat16(static_cast<float>(l));
129 return true;
130 }
131 if (PyArray_IsScalar(arg, Float)) {
132 float f;
133 PyArray_ScalarAsCtype(arg, &f);
134 *output = bfloat16(f);
135 return true;
136 }
137 PyErr_Format(PyExc_TypeError, "expected number, got %s",
138 arg->ob_type->tp_name);
139 return false;
140 }
141
142 // Converts a PyBfloat16 into a PyFloat.
PyBfloat16_Float(PyObject * self)143 PyObject* PyBfloat16_Float(PyObject* self) {
144 bfloat16 x = PyBfloat16_Bfloat16(self);
145 return PyFloat_FromDouble(static_cast<double>(x));
146 }
147
148 // Converts a PyBfloat16 into a PyInt.
PyBfloat16_Int(PyObject * self)149 PyObject* PyBfloat16_Int(PyObject* self) {
150 bfloat16 x = PyBfloat16_Bfloat16(self);
151 long y = static_cast<long>(x); // NOLINT
152 return TfPyInt_FromLong(y);
153 }
154
155 // Negates a PyBfloat16.
PyBfloat16_Negative(PyObject * self)156 PyObject* PyBfloat16_Negative(PyObject* self) {
157 bfloat16 x = PyBfloat16_Bfloat16(self);
158 return PyBfloat16_FromBfloat16(-x).release();
159 }
160
161 // Binary arithmetic operators on PyBfloat16 values.
162 #define BFLOAT16_BINOP(name, op) \
163 PyObject* PyBfloat16_##name(PyObject* a, PyObject* b) { \
164 bfloat16 x, y; \
165 if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr; \
166 bfloat16 z = x op y; \
167 return PyBfloat16_FromBfloat16(z).release(); \
168 }
169 BFLOAT16_BINOP(Add, +)
170 BFLOAT16_BINOP(Subtract, -)
171 BFLOAT16_BINOP(Multiply, *)
172 BFLOAT16_BINOP(Divide, /)
173 #undef BFLOAT16_BINOP
174
// Python number methods for PyBfloat16 objects.
//
// This is a positional aggregate initializer, so every entry must sit at the
// position of the matching PyNumberMethods field for the Python version being
// compiled. The PY_MAJOR_VERSION guards skip slots that exist only in
// Python 2 (nb_divide, nb_coerce, nb_oct, nb_hex, nb_inplace_divide). They
// also fill the Python 3 reserved slot that replaced nb_long.
//
// Entries left as nullptr are not implemented by this file. NOTE(review):
// PyType_Ready may inherit some of them from tp_base (set in Initialize());
// confirm which NumPy generic-scalar behaviors are actually picked up.
PyNumberMethods PyBfloat16_AsNumber = {
    PyBfloat16_Add,       // nb_add
    PyBfloat16_Subtract,  // nb_subtract
    PyBfloat16_Multiply,  // nb_multiply
#if PY_MAJOR_VERSION < 3
    PyBfloat16_Divide,  // nb_divide (Python 2 classic division)
#endif
    nullptr,              // nb_remainder
    nullptr,              // nb_divmod
    nullptr,              // nb_power
    PyBfloat16_Negative,  // nb_negative
    nullptr,              // nb_positive
    nullptr,              // nb_absolute
    nullptr,              // nb_nonzero (nb_bool in Python 3)
    nullptr,              // nb_invert
    nullptr,              // nb_lshift
    nullptr,              // nb_rshift
    nullptr,              // nb_and
    nullptr,              // nb_xor
    nullptr,              // nb_or
#if PY_MAJOR_VERSION < 3
    nullptr,  // nb_coerce
#endif
    PyBfloat16_Int,  // nb_int
#if PY_MAJOR_VERSION < 3
    PyBfloat16_Int,  // nb_long: long() behaves like int()
#else
    nullptr,  // reserved (formerly nb_long)
#endif
    PyBfloat16_Float,  // nb_float
#if PY_MAJOR_VERSION < 3
    nullptr,  // nb_oct
    nullptr,  // nb_hex
#endif

    nullptr,  // nb_inplace_add
    nullptr,  // nb_inplace_subtract
    nullptr,  // nb_inplace_multiply
#if PY_MAJOR_VERSION < 3
    nullptr,  // nb_inplace_divide
#endif
    nullptr,  // nb_inplace_remainder
    nullptr,  // nb_inplace_power
    nullptr,  // nb_inplace_lshift
    nullptr,  // nb_inplace_rshift
    nullptr,  // nb_inplace_and
    nullptr,  // nb_inplace_xor
    nullptr,  // nb_inplace_or

    nullptr,            // nb_floor_divide
    PyBfloat16_Divide,  // nb_true_divide ('/' on Python 3)
    nullptr,            // nb_inplace_floor_divide
    nullptr,            // nb_inplace_true_divide
    nullptr,            // nb_index
};
231
232 // Constructs a new PyBfloat16.
PyBfloat16_New(PyTypeObject * type,PyObject * args,PyObject * kwds)233 PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
234 if (kwds && PyDict_Size(kwds)) {
235 PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
236 return nullptr;
237 }
238 Py_ssize_t size = PyTuple_Size(args);
239 if (size != 1) {
240 PyErr_SetString(PyExc_TypeError,
241 "expected number as argument to bfloat16 constructor");
242 return nullptr;
243 }
244 PyObject* arg = PyTuple_GetItem(args, 0);
245
246 if (PyBfloat16_Check(arg)) {
247 Py_INCREF(arg);
248 return arg;
249 } else {
250 bfloat16 value;
251 if (!AsBfloat16(arg, &value)) {
252 return nullptr;
253 }
254 return PyBfloat16_FromBfloat16(value).release();
255 }
256 }
257
258 // Comparisons on PyBfloat16s.
PyBfloat16_RichCompare(PyObject * a,PyObject * b,int op)259 PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
260 bfloat16 x, y;
261 if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr;
262 bool result;
263 switch (op) {
264 case Py_LT:
265 result = x < y;
266 break;
267 case Py_LE:
268 result = x <= y;
269 break;
270 case Py_EQ:
271 result = x == y;
272 break;
273 case Py_NE:
274 result = x != y;
275 break;
276 case Py_GT:
277 result = x > y;
278 break;
279 case Py_GE:
280 result = x >= y;
281 break;
282 default:
283 LOG(FATAL) << "Invalid op type " << op;
284 }
285 return PyBool_FromLong(result);
286 }
287
288 // Implementation of repr() for PyBfloat16.
PyBfloat16_Repr(PyObject * self)289 PyObject* PyBfloat16_Repr(PyObject* self) {
290 bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
291 string v = strings::StrCat("bfloat16(", static_cast<float>(x), ")");
292 return MakePyString(v);
293 }
294
295 // Implementation of str() for PyBfloat16.
PyBfloat16_Str(PyObject * self)296 PyObject* PyBfloat16_Str(PyObject* self) {
297 bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
298 string v = strings::StrCat(static_cast<float>(x));
299 return MakePyString(v);
300 }
301
302 // Hash function for PyBfloat16. We use the identity function, which is a weak
303 // hash function.
PyBfloat16_Hash(PyObject * self)304 HashType PyBfloat16_Hash(PyObject* self) {
305 bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
306 return x.value;
307 }
308
309 // Python type for PyBfloat16 objects.
310 PyTypeObject PyBfloat16_Type = {
311 #if PY_MAJOR_VERSION < 3
312 PyObject_HEAD_INIT(nullptr) 0, // ob_size
313 #else
314 PyVarObject_HEAD_INIT(nullptr, 0)
315 #endif
316 "bfloat16", // tp_name
317 sizeof(PyBfloat16), // tp_basicsize
318 0, // tp_itemsize
319 nullptr, // tp_dealloc
320 nullptr, // tp_print
321 nullptr, // tp_getattr
322 nullptr, // tp_setattr
323 nullptr, // tp_compare / tp_reserved
324 PyBfloat16_Repr, // tp_repr
325 &PyBfloat16_AsNumber, // tp_as_number
326 nullptr, // tp_as_sequence
327 nullptr, // tp_as_mapping
328 PyBfloat16_Hash, // tp_hash
329 nullptr, // tp_call
330 PyBfloat16_Str, // tp_str
331 nullptr, // tp_getattro
332 nullptr, // tp_setattro
333 nullptr, // tp_as_buffer
334 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
335 "bfloat16 floating-point values", // tp_doc
336 nullptr, // tp_traverse
337 nullptr, // tp_clear
338 PyBfloat16_RichCompare, // tp_richcompare
339 0, // tp_weaklistoffset
340 nullptr, // tp_iter
341 nullptr, // tp_iternext
342 nullptr, // tp_methods
343 nullptr, // tp_members
344 nullptr, // tp_getset
345 nullptr, // tp_base
346 nullptr, // tp_dict
347 nullptr, // tp_descr_get
348 nullptr, // tp_descr_set
349 0, // tp_dictoffset
350 nullptr, // tp_init
351 nullptr, // tp_alloc
352 PyBfloat16_New, // tp_new
353 nullptr, // tp_free
354 nullptr, // tp_is_gc
355 nullptr, // tp_bases
356 nullptr, // tp_mro
357 nullptr, // tp_cache
358 nullptr, // tp_subclasses
359 nullptr, // tp_weaklist
360 nullptr, // tp_del
361 0, // tp_version_tag
362 };
363
// Numpy support

// ArrFuncs table for the bfloat16 dtype. It is a namespace-scope global and so
// starts zero-initialized. Initialize() resets it with PyArray_InitArrFuncs
// and then installs the getitem, setitem, copyswapn, copyswap, nonzero and fill
// callbacks defined later in this file.
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;

// NumPy descriptor for the bfloat16 dtype, registered by Initialize() via
// PyArray_RegisterDataType. This is a positional initializer, so entries must
// follow the PyArray_Descr field order of the NumPy headers in use. The
// object-header type is patched to PyArrayDescr_Type in Initialize() before
// registration.
PyArray_Descr NPyBfloat16_Descr = {
    PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type,  // typeobj
    // We must register bfloat16 with a kind other than "f", because numpy
    // considers two types with the same kind and size to be equal, but
    // float16 != bfloat16.
    'V',  // kind
    // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
    // character is unique.
    'E',  // type
    '=',  // byteorder
    // The getitem/setitem callbacks create and convert Python objects, so
    // element access needs the Python API.
    NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,  // hasobject
    0,  // type_num: NOTE(review) presumably assigned by
        // PyArray_RegisterDataType; confirm.
    sizeof(bfloat16),       // elsize
    alignof(bfloat16),      // alignment
    nullptr,                // subarray
    nullptr,                // fields
    nullptr,                // names
    &NPyBfloat16_ArrFuncs,  // f
};

// Registered numpy type ID. Global variable populated by the registration code.
// -1 means "not registered yet". RegisterNumpyBfloat16() and
// Bfloat16NumpyType() test for a non-negative value.
int npy_bfloat16_ = -1;
390
391 // Implementations of NumPy array methods.
392
NPyBfloat16_GetItem(void * data,void * arr)393 PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
394 bfloat16 x;
395 memcpy(&x, data, sizeof(bfloat16));
396 return PyBfloat16_FromBfloat16(x).release();
397 }
398
NPyBfloat16_SetItem(PyObject * item,void * data,void * arr)399 int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
400 bfloat16 x;
401 if (!AsBfloat16(item, &x)) return -1;
402 memcpy(data, &x, sizeof(bfloat16));
403 return 0;
404 }
405
// Reverses, in place, the two bytes starting at 'value'.
void ByteSwap16(void* value) {
  char* const bytes = static_cast<char*>(value);
  const char low = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = low;
}
410
NPyBfloat16_CopySwapN(void * dstv,npy_intp dstride,void * srcv,npy_intp sstride,npy_intp n,int swap,void * arr)411 void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
412 npy_intp sstride, npy_intp n, int swap, void* arr) {
413 char* dst = reinterpret_cast<char*>(dstv);
414 char* src = reinterpret_cast<char*>(srcv);
415 if (!src) {
416 return;
417 }
418 if (swap) {
419 for (npy_intp i = 0; i < n; i++) {
420 char* r = dst + dstride * i;
421 memcpy(r, src + sstride * i, sizeof(uint16_t));
422 ByteSwap16(r);
423 }
424 } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) {
425 memcpy(dst, src, n * sizeof(uint16_t));
426 } else {
427 for (npy_intp i = 0; i < n; i++) {
428 memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t));
429 }
430 }
431 }
432
NPyBfloat16_CopySwap(void * dst,void * src,int swap,void * arr)433 void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) {
434 if (!src) {
435 return;
436 }
437 memcpy(dst, src, sizeof(uint16_t));
438 if (swap) {
439 ByteSwap16(dst);
440 }
441 }
442
NPyBfloat16_NonZero(void * data,void * arr)443 npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
444 bfloat16 x;
445 memcpy(&x, data, sizeof(x));
446 return x != static_cast<bfloat16>(0);
447 }
448
NPyBfloat16_Fill(void * buffer_raw,npy_intp length,void * ignored)449 int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
450 bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
451 const float start(buffer[0]);
452 const float delta = static_cast<float>(buffer[1]) - start;
453 for (npy_intp i = 2; i < length; ++i) {
454 buffer[i] = static_cast<bfloat16>(start + i * delta);
455 }
456 return 0;
457 }
458
459 // NumPy casts
460
461 // Performs a NumPy array cast from type 'From' to 'To'.
462 template <typename From, typename To>
NPyCast(void * from_void,void * to_void,npy_intp n,void * fromarr,void * toarr)463 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
464 void* toarr) {
465 const From* from = reinterpret_cast<From*>(from_void);
466 To* to = reinterpret_cast<To*>(to_void);
467 for (npy_intp i = 0; i < n; ++i) {
468 to[i] = static_cast<To>(from[i]);
469 }
470 }
471
472 // Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
473 // type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be
474 // safely coerced to T.
475 template <typename T>
RegisterBfloat16Cast(int numpy_type,bool cast_is_safe)476 bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
477 if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16_,
478 NPyCast<T, bfloat16>) < 0) {
479 return false;
480 }
481 if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
482 NPyCast<bfloat16, T>) < 0) {
483 return false;
484 }
485 if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type,
486 NPY_NOSCALAR) < 0) {
487 return false;
488 }
489 return true;
490 }
491
492 template <typename InType, typename OutType, typename Functor>
BinaryUFunc(char ** args,npy_intp * dimensions,npy_intp * steps,void * data)493 void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
494 void* data) {
495 const char* i0 = args[0];
496 const char* i1 = args[1];
497 char* o = args[2];
498 for (npy_intp k = 0; k < *dimensions; k++) {
499 InType x = *reinterpret_cast<const InType*>(i0);
500 InType y = *reinterpret_cast<const InType*>(i1);
501 *reinterpret_cast<OutType*>(o) = Functor()(x, y);
502 i0 += steps[0];
503 i1 += steps[1];
504 o += steps[2];
505 }
506 }
507
508 template <typename Functor>
CompareUFunc(char ** args,npy_intp * dimensions,npy_intp * steps,void * data)509 void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
510 void* data) {
511 BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
512 }
513
514 struct Bfloat16EqFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16EqFunctor515 npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
516 };
517 struct Bfloat16NeFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16NeFunctor518 npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
519 };
520 struct Bfloat16LtFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16LtFunctor521 npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
522 };
523 struct Bfloat16GtFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16GtFunctor524 npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
525 };
526 struct Bfloat16LeFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16LeFunctor527 npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
528 };
529 struct Bfloat16GeFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16GeFunctor530 npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
531 };
532
533 // Initializes the module.
Initialize()534 bool Initialize() {
535 // It's critical to import umath to avoid crash in open source build.
536 import_umath1(false);
537
538 Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
539 if (!numpy_str) {
540 return false;
541 }
542 Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
543 if (!numpy) {
544 return false;
545 }
546
547 // We hit a mysterious crash if we haven't initialized numpy before this:
548 PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
549
550 if (PyType_Ready(&PyBfloat16_Type) < 0) {
551 return false;
552 }
553
554 // Initializes the NumPy descriptor.
555 PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
556 NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
557 NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
558 NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
559 NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
560 NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
561 NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
562
563 Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
564 npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
565 if (npy_bfloat16_ < 0) return false;
566
567 // Support dtype(bfloat16)
568 if (PyDict_SetItemString(PyBfloat16_Type.tp_dict, "dtype",
569 reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
570 0) {
571 return false;
572 }
573
574 // Register casts
575
576 // We lie shamelessly and say that a cast from half to bfloat16 is safe.
577 // Numpy frequently uses the smallest legal representation type for small
578 // float constants (e.g., 1.0), which is often float16. Things break if these
579 // cannot be converted transparently to bfloat16.
580 if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/true)) {
581 return false;
582 }
583
584 if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
585 return false;
586 }
587 if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) {
588 return false;
589 }
590 if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
591 return false;
592 }
593 if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
594 return false;
595 }
596 // Following the numpy convention. imag part is dropped when converting to
597 // float.
598 if (!RegisterBfloat16Cast<complex64>(NPY_COMPLEX64, /*cast_is_safe=*/true)) {
599 return false;
600 }
601 if (!RegisterBfloat16Cast<complex128>(NPY_COMPLEX128,
602 /*cast_is_safe=*/true)) {
603 return false;
604 }
605
606 // Register ufuncs
607 auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
608 const std::array<int, 3>& types) {
609 Safe_PyObjectPtr ufunc_obj =
610 make_safe(PyObject_GetAttrString(numpy.get(), name));
611 if (!ufunc_obj) {
612 return false;
613 }
614 PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
615 if (types.size() != ufunc->nargs) {
616 PyErr_Format(PyExc_AssertionError,
617 "ufunc %s takes %d arguments, loop takes %lu", name,
618 ufunc->nargs, types.size());
619 return false;
620 }
621 if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
622 const_cast<int*>(types.data()),
623 nullptr) < 0) {
624 return false;
625 }
626 return true;
627 };
628
629 // Comparisons
630 const std::array<int, 3> compare_types = {
631 {npy_bfloat16_, npy_bfloat16_, NPY_BOOL}};
632
633 if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
634 compare_types)) {
635 return false;
636 }
637 if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
638 compare_types)) {
639 return false;
640 }
641 if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
642 return false;
643 }
644 if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
645 compare_types)) {
646 return false;
647 }
648 if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
649 compare_types)) {
650 return false;
651 }
652 if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
653 compare_types)) {
654 return false;
655 }
656 return true;
657 }
658
659 } // namespace
660
RegisterNumpyBfloat16()661 void RegisterNumpyBfloat16() {
662 if (npy_bfloat16_ >= 0) {
663 // Already initialized.
664 return;
665 }
666 if (!Initialize()) {
667 if (!PyErr_Occurred()) {
668 PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
669 }
670 PyErr_Print();
671 }
672 }
673
Bfloat16PyType()674 PyObject* Bfloat16PyType() {
675 CHECK(PyBfloat16_Type.tp_base != nullptr);
676 Py_INCREF(&PyBfloat16_Type);
677 return reinterpret_cast<PyObject*>(&PyBfloat16_Type);
678 }
679
Bfloat16NumpyType()680 int Bfloat16NumpyType() {
681 CHECK_GE(npy_bfloat16_, 0);
682 return npy_bfloat16_;
683 }
684
685 } // namespace tensorflow
686