1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <array>
#include <cmath>
#include <limits>

#include "tensorflow/python/lib/core/bfloat16.h"

#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
25 
26 namespace tensorflow {
27 namespace {
28 
29 // Workarounds for Python 2 vs 3 API differences.
30 #if PY_MAJOR_VERSION < 3
31 
MakePyString(const string & s)32 PyObject* MakePyString(const string& s) {
33   return PyString_FromString(s.c_str());
34 }
35 
36 typedef long HashType;  // NOLINT
37 
TfPyInt_Check(PyObject * object)38 bool TfPyInt_Check(PyObject* object) { return PyInt_Check(object); }
39 
TfPyInt_FromLong(long x)40 PyObject* TfPyInt_FromLong(long x) {  // NOLINT
41   return PyInt_FromLong(x);
42 }
43 
TfPyInt_AsLong(PyObject * x)44 long TfPyInt_AsLong(PyObject* x) {  // NOLINT
45   return PyInt_AsLong(x);
46 }
47 
48 #else  // PY_MAJOR_VERSION < 3
49 
50 PyObject* MakePyString(const string& s) {
51   return PyUnicode_FromString(s.c_str());
52 }
53 
54 bool TfPyInt_Check(PyObject* object) {
55   if (!PyLong_Check(object)) {
56     return 0;
57   }
58   int overflow = 0;
59   PyLong_AsLongAndOverflow(object, &overflow);
60   return (overflow == 0);
61 }
62 
63 PyObject* TfPyInt_FromLong(long x) {  // NOLINT
64   return PyLong_FromLong(x);
65 }
66 
67 long TfPyInt_AsLong(PyObject* x) {  // NOLINT
68   return PyLong_AsLong(x);
69 }
70 
71 typedef Py_hash_t HashType;
72 
73 #endif  // PY_MAJOR_VERSION < 3
74 
75 // Forward declaration.
76 extern PyTypeObject PyBfloat16_Type;
77 
78 // Representation of a Python bfloat16 object.
79 struct PyBfloat16 {
80   PyObject_HEAD;  // Python object header
81   bfloat16 value;
82 };
83 
84 // Returns true if 'object' is a PyBfloat16.
PyBfloat16_Check(PyObject * object)85 bool PyBfloat16_Check(PyObject* object) {
86   return PyObject_IsInstance(object,
87                              reinterpret_cast<PyObject*>(&PyBfloat16_Type));
88 }
89 
90 // Extracts the value of a PyBfloat16 object.
PyBfloat16_Bfloat16(PyObject * object)91 bfloat16 PyBfloat16_Bfloat16(PyObject* object) {
92   return reinterpret_cast<PyBfloat16*>(object)->value;
93 }
94 
95 // Constructs a PyBfloat16 object from a bfloat16.
PyBfloat16_FromBfloat16(bfloat16 x)96 Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
97   Safe_PyObjectPtr ref =
98       make_safe(PyBfloat16_Type.tp_alloc(&PyBfloat16_Type, 0));
99   PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
100   if (p) {
101     p->value = x;
102   }
103   return ref;
104 }
105 
106 // Converts a Python object to a bfloat16 value. Returns true on success,
107 // returns false and reports a Python error on failure.
AsBfloat16(PyObject * arg,bfloat16 * output)108 bool AsBfloat16(PyObject* arg, bfloat16* output) {
109   if (PyBfloat16_Check(arg)) {
110     *output = PyBfloat16_Bfloat16(arg);
111     return true;
112   }
113   if (PyFloat_Check(arg)) {
114     double d = PyFloat_AsDouble(arg);
115     if (PyErr_Occurred()) {
116       return false;
117     }
118     // TODO(phawkins): check for overflow
119     *output = bfloat16(d);
120     return true;
121   }
122   if (TfPyInt_Check(arg)) {
123     long l = TfPyInt_AsLong(arg);  // NOLINT
124     if (PyErr_Occurred()) {
125       return false;
126     }
127     // TODO(phawkins): check for overflow
128     *output = bfloat16(static_cast<float>(l));
129     return true;
130   }
131   if (PyArray_IsScalar(arg, Float)) {
132     float f;
133     PyArray_ScalarAsCtype(arg, &f);
134     *output = bfloat16(f);
135     return true;
136   }
137   PyErr_Format(PyExc_TypeError, "expected number, got %s",
138                arg->ob_type->tp_name);
139   return false;
140 }
141 
142 // Converts a PyBfloat16 into a PyFloat.
PyBfloat16_Float(PyObject * self)143 PyObject* PyBfloat16_Float(PyObject* self) {
144   bfloat16 x = PyBfloat16_Bfloat16(self);
145   return PyFloat_FromDouble(static_cast<double>(x));
146 }
147 
148 // Converts a PyBfloat16 into a PyInt.
PyBfloat16_Int(PyObject * self)149 PyObject* PyBfloat16_Int(PyObject* self) {
150   bfloat16 x = PyBfloat16_Bfloat16(self);
151   long y = static_cast<long>(x);  // NOLINT
152   return TfPyInt_FromLong(y);
153 }
154 
155 // Negates a PyBfloat16.
PyBfloat16_Negative(PyObject * self)156 PyObject* PyBfloat16_Negative(PyObject* self) {
157   bfloat16 x = PyBfloat16_Bfloat16(self);
158   return PyBfloat16_FromBfloat16(-x).release();
159 }
160 
161 // Binary arithmetic operators on PyBfloat16 values.
162 #define BFLOAT16_BINOP(name, op)                                  \
163   PyObject* PyBfloat16_##name(PyObject* a, PyObject* b) {         \
164     bfloat16 x, y;                                                \
165     if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr; \
166     bfloat16 z = x op y;                                          \
167     return PyBfloat16_FromBfloat16(z).release();                  \
168   }
169 BFLOAT16_BINOP(Add, +)
170 BFLOAT16_BINOP(Subtract, -)
171 BFLOAT16_BINOP(Multiply, *)
172 BFLOAT16_BINOP(Divide, /)
173 #undef BFLOAT16_BINOP
174 
// Python number methods for PyBfloat16 objects.
//
// This is a positional initializer, so the entries must appear in exactly
// the order of the PyNumberMethods fields for the Python version being
// compiled. The PY_MAJOR_VERSION guards add the slots that exist only in
// Python 2 (nb_divide, nb_coerce, nb_oct, nb_hex, nb_inplace_divide).
// Operations bfloat16 does not implement are left as nullptr. Any trailing
// fields not listed here are value-initialized (null) by aggregate
// initialization.
PyNumberMethods PyBfloat16_AsNumber = {
    PyBfloat16_Add,       // nb_add
    PyBfloat16_Subtract,  // nb_subtract
    PyBfloat16_Multiply,  // nb_multiply
#if PY_MAJOR_VERSION < 3
    PyBfloat16_Divide,  // nb_divide (Python 2 only)
#endif
    nullptr,              // nb_remainder
    nullptr,              // nb_divmod
    nullptr,              // nb_power
    PyBfloat16_Negative,  // nb_negative
    nullptr,              // nb_positive
    nullptr,              // nb_absolute
    nullptr,              // nb_nonzero (named nb_bool in Python 3)
    nullptr,              // nb_invert
    nullptr,              // nb_lshift
    nullptr,              // nb_rshift
    nullptr,              // nb_and
    nullptr,              // nb_xor
    nullptr,              // nb_or
#if PY_MAJOR_VERSION < 3
    nullptr,  // nb_coerce
#endif
    PyBfloat16_Int,  // nb_int
#if PY_MAJOR_VERSION < 3
    PyBfloat16_Int,  // nb_long
#else
    nullptr,  // reserved
#endif
    PyBfloat16_Float,  // nb_float
#if PY_MAJOR_VERSION < 3
    nullptr,  // nb_oct
    nullptr,  // nb_hex
#endif

    nullptr,  // nb_inplace_add
    nullptr,  // nb_inplace_subtract
    nullptr,  // nb_inplace_multiply
#if PY_MAJOR_VERSION < 3
    nullptr,  // nb_inplace_divide
#endif
    nullptr,  // nb_inplace_remainder
    nullptr,  // nb_inplace_power
    nullptr,  // nb_inplace_lshift
    nullptr,  // nb_inplace_rshift
    nullptr,  // nb_inplace_and
    nullptr,  // nb_inplace_xor
    nullptr,  // nb_inplace_or

    nullptr,            // nb_floor_divide
    PyBfloat16_Divide,  // nb_true_divide (same implementation as nb_divide)
    nullptr,            // nb_inplace_floor_divide
    nullptr,            // nb_inplace_true_divide
    nullptr,            // nb_index
};
231 
232 // Constructs a new PyBfloat16.
PyBfloat16_New(PyTypeObject * type,PyObject * args,PyObject * kwds)233 PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
234   if (kwds && PyDict_Size(kwds)) {
235     PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
236     return nullptr;
237   }
238   Py_ssize_t size = PyTuple_Size(args);
239   if (size != 1) {
240     PyErr_SetString(PyExc_TypeError,
241                     "expected number as argument to bfloat16 constructor");
242     return nullptr;
243   }
244   PyObject* arg = PyTuple_GetItem(args, 0);
245 
246   if (PyBfloat16_Check(arg)) {
247     Py_INCREF(arg);
248     return arg;
249   } else {
250     bfloat16 value;
251     if (!AsBfloat16(arg, &value)) {
252       return nullptr;
253     }
254     return PyBfloat16_FromBfloat16(value).release();
255   }
256 }
257 
258 // Comparisons on PyBfloat16s.
PyBfloat16_RichCompare(PyObject * a,PyObject * b,int op)259 PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
260   bfloat16 x, y;
261   if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr;
262   bool result;
263   switch (op) {
264     case Py_LT:
265       result = x < y;
266       break;
267     case Py_LE:
268       result = x <= y;
269       break;
270     case Py_EQ:
271       result = x == y;
272       break;
273     case Py_NE:
274       result = x != y;
275       break;
276     case Py_GT:
277       result = x > y;
278       break;
279     case Py_GE:
280       result = x >= y;
281       break;
282     default:
283       LOG(FATAL) << "Invalid op type " << op;
284   }
285   return PyBool_FromLong(result);
286 }
287 
288 // Implementation of repr() for PyBfloat16.
PyBfloat16_Repr(PyObject * self)289 PyObject* PyBfloat16_Repr(PyObject* self) {
290   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
291   string v = strings::StrCat("bfloat16(", static_cast<float>(x), ")");
292   return MakePyString(v);
293 }
294 
295 // Implementation of str() for PyBfloat16.
PyBfloat16_Str(PyObject * self)296 PyObject* PyBfloat16_Str(PyObject* self) {
297   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
298   string v = strings::StrCat(static_cast<float>(x));
299   return MakePyString(v);
300 }
301 
302 // Hash function for PyBfloat16. We use the identity function, which is a weak
303 // hash function.
PyBfloat16_Hash(PyObject * self)304 HashType PyBfloat16_Hash(PyObject* self) {
305   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
306   return x.value;
307 }
308 
// Python type for PyBfloat16 objects.
//
// This is a positional initializer; entries must follow the PyTypeObject
// field order. tp_base is deliberately left null here. Initialize() sets it
// to numpy's generic scalar type before calling PyType_Ready. The type is
// subclassable (Py_TPFLAGS_BASETYPE).
PyTypeObject PyBfloat16_Type = {
#if PY_MAJOR_VERSION < 3
    PyObject_HEAD_INIT(nullptr) 0,  // ob_size
#else
    PyVarObject_HEAD_INIT(nullptr, 0)
#endif
    "bfloat16",                                // tp_name
    sizeof(PyBfloat16),                        // tp_basicsize
    0,                                         // tp_itemsize
    nullptr,                                   // tp_dealloc
    nullptr,                                   // tp_print
    nullptr,                                   // tp_getattr
    nullptr,                                   // tp_setattr
    nullptr,                                   // tp_compare / tp_reserved
    PyBfloat16_Repr,                           // tp_repr
    &PyBfloat16_AsNumber,                      // tp_as_number
    nullptr,                                   // tp_as_sequence
    nullptr,                                   // tp_as_mapping
    PyBfloat16_Hash,                           // tp_hash
    nullptr,                                   // tp_call
    PyBfloat16_Str,                            // tp_str
    nullptr,                                   // tp_getattro
    nullptr,                                   // tp_setattro
    nullptr,                                   // tp_as_buffer
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  // tp_flags
    "bfloat16 floating-point values",          // tp_doc
    nullptr,                                   // tp_traverse
    nullptr,                                   // tp_clear
    PyBfloat16_RichCompare,                    // tp_richcompare
    0,                                         // tp_weaklistoffset
    nullptr,                                   // tp_iter
    nullptr,                                   // tp_iternext
    nullptr,                                   // tp_methods
    nullptr,                                   // tp_members
    nullptr,                                   // tp_getset
    nullptr,                                   // tp_base (set in Initialize)
    nullptr,                                   // tp_dict
    nullptr,                                   // tp_descr_get
    nullptr,                                   // tp_descr_set
    0,                                         // tp_dictoffset
    nullptr,                                   // tp_init
    nullptr,                                   // tp_alloc
    PyBfloat16_New,                            // tp_new
    nullptr,                                   // tp_free
    nullptr,                                   // tp_is_gc
    nullptr,                                   // tp_bases
    nullptr,                                   // tp_mro
    nullptr,                                   // tp_cache
    nullptr,                                   // tp_subclasses
    nullptr,                                   // tp_weaklist
    nullptr,                                   // tp_del
    0,                                         // tp_version_tag
};
363 
// Numpy support

// Array function table for the bfloat16 dtype. It is zero-initialized here;
// Initialize() resets it with PyArray_InitArrFuncs and fills in the getitem,
// setitem, copyswapn, copyswap, nonzero and fill entries.
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;

// Statically allocated NumPy descriptor for bfloat16. Its object type is
// patched to PyArrayDescr_Type in Initialize() before registration.
// PyObject_HEAD_INIT supplies the object header together with a trailing
// comma, so the `& PyBfloat16_Type` that follows initializes `typeobj`.
PyArray_Descr NPyBfloat16_Descr = {
    PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type,  // typeobj
    // We must register bfloat16 with a kind other than "f", because numpy
    // considers two types with the same kind and size to be equal, but
    // float16 != bfloat16.
    'V',  // kind
    // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
    // character is unique.
    'E',                                                  // type
    '=',                                                  // byteorder
    NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,  // hasobject
    0,                                                    // type_num
    sizeof(bfloat16),                                     // elsize
    alignof(bfloat16),                                    // alignment
    nullptr,                                              // subarray
    nullptr,                                              // fields
    nullptr,                                              // names
    &NPyBfloat16_ArrFuncs,                                // f
};

// Registered numpy type ID. Global variable populated by the registration code.
// It stays -1 until PyArray_RegisterDataType succeeds in Initialize();
// RegisterNumpyBfloat16 uses a non-negative value to skip re-initialization.
int npy_bfloat16_ = -1;
390 
391 // Implementations of NumPy array methods.
392 
NPyBfloat16_GetItem(void * data,void * arr)393 PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
394   bfloat16 x;
395   memcpy(&x, data, sizeof(bfloat16));
396   return PyBfloat16_FromBfloat16(x).release();
397 }
398 
NPyBfloat16_SetItem(PyObject * item,void * data,void * arr)399 int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
400   bfloat16 x;
401   if (!AsBfloat16(item, &x)) return -1;
402   memcpy(data, &x, sizeof(bfloat16));
403   return 0;
404 }
405 
// Exchanges the two bytes at `value` in place, i.e. reverses the byte order
// of one 16-bit element.
void ByteSwap16(void* value) {
  char* const bytes = static_cast<char*>(value);
  const char low = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = low;
}
410 
NPyBfloat16_CopySwapN(void * dstv,npy_intp dstride,void * srcv,npy_intp sstride,npy_intp n,int swap,void * arr)411 void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
412                            npy_intp sstride, npy_intp n, int swap, void* arr) {
413   char* dst = reinterpret_cast<char*>(dstv);
414   char* src = reinterpret_cast<char*>(srcv);
415   if (!src) {
416     return;
417   }
418   if (swap) {
419     for (npy_intp i = 0; i < n; i++) {
420       char* r = dst + dstride * i;
421       memcpy(r, src + sstride * i, sizeof(uint16_t));
422       ByteSwap16(r);
423     }
424   } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) {
425     memcpy(dst, src, n * sizeof(uint16_t));
426   } else {
427     for (npy_intp i = 0; i < n; i++) {
428       memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t));
429     }
430   }
431 }
432 
NPyBfloat16_CopySwap(void * dst,void * src,int swap,void * arr)433 void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) {
434   if (!src) {
435     return;
436   }
437   memcpy(dst, src, sizeof(uint16_t));
438   if (swap) {
439     ByteSwap16(dst);
440   }
441 }
442 
NPyBfloat16_NonZero(void * data,void * arr)443 npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
444   bfloat16 x;
445   memcpy(&x, data, sizeof(x));
446   return x != static_cast<bfloat16>(0);
447 }
448 
NPyBfloat16_Fill(void * buffer_raw,npy_intp length,void * ignored)449 int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
450   bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
451   const float start(buffer[0]);
452   const float delta = static_cast<float>(buffer[1]) - start;
453   for (npy_intp i = 2; i < length; ++i) {
454     buffer[i] = static_cast<bfloat16>(start + i * delta);
455   }
456   return 0;
457 }
458 
459 // NumPy casts
460 
461 // Performs a NumPy array cast from type 'From' to 'To'.
462 template <typename From, typename To>
NPyCast(void * from_void,void * to_void,npy_intp n,void * fromarr,void * toarr)463 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
464              void* toarr) {
465   const From* from = reinterpret_cast<From*>(from_void);
466   To* to = reinterpret_cast<To*>(to_void);
467   for (npy_intp i = 0; i < n; ++i) {
468     to[i] = static_cast<To>(from[i]);
469   }
470 }
471 
472 // Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
473 // type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be
474 // safely coerced to T.
475 template <typename T>
RegisterBfloat16Cast(int numpy_type,bool cast_is_safe)476 bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
477   if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16_,
478                                NPyCast<T, bfloat16>) < 0) {
479     return false;
480   }
481   if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
482                                NPyCast<bfloat16, T>) < 0) {
483     return false;
484   }
485   if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type,
486                                               NPY_NOSCALAR) < 0) {
487     return false;
488   }
489   return true;
490 }
491 
492 template <typename InType, typename OutType, typename Functor>
BinaryUFunc(char ** args,npy_intp * dimensions,npy_intp * steps,void * data)493 void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
494                  void* data) {
495   const char* i0 = args[0];
496   const char* i1 = args[1];
497   char* o = args[2];
498   for (npy_intp k = 0; k < *dimensions; k++) {
499     InType x = *reinterpret_cast<const InType*>(i0);
500     InType y = *reinterpret_cast<const InType*>(i1);
501     *reinterpret_cast<OutType*>(o) = Functor()(x, y);
502     i0 += steps[0];
503     i1 += steps[1];
504     o += steps[2];
505   }
506 }
507 
508 template <typename Functor>
CompareUFunc(char ** args,npy_intp * dimensions,npy_intp * steps,void * data)509 void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
510                   void* data) {
511   BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
512 }
513 
514 struct Bfloat16EqFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16EqFunctor515   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
516 };
517 struct Bfloat16NeFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16NeFunctor518   npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
519 };
520 struct Bfloat16LtFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16LtFunctor521   npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
522 };
523 struct Bfloat16GtFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16GtFunctor524   npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
525 };
526 struct Bfloat16LeFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16LeFunctor527   npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
528 };
529 struct Bfloat16GeFunctor {
operator ()tensorflow::__anonbd6dfcf70111::Bfloat16GeFunctor530   npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
531 };
532 
533 // Initializes the module.
Initialize()534 bool Initialize() {
535   // It's critical to import umath to avoid crash in open source build.
536   import_umath1(false);
537 
538   Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
539   if (!numpy_str) {
540     return false;
541   }
542   Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
543   if (!numpy) {
544     return false;
545   }
546 
547   // We hit a mysterious crash if we haven't initialized numpy before this:
548   PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
549 
550   if (PyType_Ready(&PyBfloat16_Type) < 0) {
551     return false;
552   }
553 
554   // Initializes the NumPy descriptor.
555   PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
556   NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
557   NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
558   NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
559   NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
560   NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
561   NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
562 
563   Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
564   npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
565   if (npy_bfloat16_ < 0) return false;
566 
567   // Support dtype(bfloat16)
568   if (PyDict_SetItemString(PyBfloat16_Type.tp_dict, "dtype",
569                            reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
570       0) {
571     return false;
572   }
573 
574   // Register casts
575 
576   // We lie shamelessly and say that a cast from half to bfloat16 is safe.
577   // Numpy frequently uses the smallest legal representation type for small
578   // float constants (e.g., 1.0), which is often float16. Things break if these
579   // cannot be converted transparently to bfloat16.
580   if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/true)) {
581     return false;
582   }
583 
584   if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
585     return false;
586   }
587   if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) {
588     return false;
589   }
590   if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
591     return false;
592   }
593   if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
594     return false;
595   }
596   // Following the numpy convention. imag part is dropped when converting to
597   // float.
598   if (!RegisterBfloat16Cast<complex64>(NPY_COMPLEX64, /*cast_is_safe=*/true)) {
599     return false;
600   }
601   if (!RegisterBfloat16Cast<complex128>(NPY_COMPLEX128,
602                                         /*cast_is_safe=*/true)) {
603     return false;
604   }
605 
606   // Register ufuncs
607   auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
608                             const std::array<int, 3>& types) {
609     Safe_PyObjectPtr ufunc_obj =
610         make_safe(PyObject_GetAttrString(numpy.get(), name));
611     if (!ufunc_obj) {
612       return false;
613     }
614     PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
615     if (types.size() != ufunc->nargs) {
616       PyErr_Format(PyExc_AssertionError,
617                    "ufunc %s takes %d arguments, loop takes %lu", name,
618                    ufunc->nargs, types.size());
619       return false;
620     }
621     if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
622                                     const_cast<int*>(types.data()),
623                                     nullptr) < 0) {
624       return false;
625     }
626     return true;
627   };
628 
629   // Comparisons
630   const std::array<int, 3> compare_types = {
631       {npy_bfloat16_, npy_bfloat16_, NPY_BOOL}};
632 
633   if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
634                       compare_types)) {
635     return false;
636   }
637   if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
638                       compare_types)) {
639     return false;
640   }
641   if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
642     return false;
643   }
644   if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
645                       compare_types)) {
646     return false;
647   }
648   if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
649                       compare_types)) {
650     return false;
651   }
652   if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
653                       compare_types)) {
654     return false;
655   }
656   return true;
657 }
658 
659 }  // namespace
660 
RegisterNumpyBfloat16()661 void RegisterNumpyBfloat16() {
662   if (npy_bfloat16_ >= 0) {
663     // Already initialized.
664     return;
665   }
666   if (!Initialize()) {
667     if (!PyErr_Occurred()) {
668       PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
669     }
670     PyErr_Print();
671   }
672 }
673 
Bfloat16PyType()674 PyObject* Bfloat16PyType() {
675   CHECK(PyBfloat16_Type.tp_base != nullptr);
676   Py_INCREF(&PyBfloat16_Type);
677   return reinterpret_cast<PyObject*>(&PyBfloat16_Type);
678 }
679 
Bfloat16NumpyType()680 int Bfloat16NumpyType() {
681   CHECK_GE(npy_bfloat16_, 0);
682   return npy_bfloat16_;
683 }
684 
685 }  // namespace tensorflow
686