1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/lib/core/bfloat16.h"
17 
18 #include <array>
19 #include <locale>
20 // Place `<locale>` before <Python.h> to avoid a build failure in macOS.
21 #include <Python.h>
22 
23 #include "absl/strings/str_cat.h"
24 #include "third_party/eigen3/Eigen/Core"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/python/lib/core/numpy.h"
27 
28 namespace tensorflow {
29 namespace {
30 
31 using bfloat16 = Eigen::bfloat16;
32 
33 struct PyDecrefDeleter {
operator ()tensorflow::__anonbd6dfcf70111::PyDecrefDeleter34   void operator()(PyObject* p) const { Py_DECREF(p); }
35 };
36 
37 // Safe container for an owned PyObject. On destruction, the reference count of
38 // the contained object will be decremented.
39 using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
make_safe(PyObject * object)40 Safe_PyObjectPtr make_safe(PyObject* object) {
41   return Safe_PyObjectPtr(object);
42 }
43 
PyLong_CheckNoOverflow(PyObject * object)44 bool PyLong_CheckNoOverflow(PyObject* object) {
45   if (!PyLong_Check(object)) {
46     return false;
47   }
48   int overflow = 0;
49   PyLong_AsLongAndOverflow(object, &overflow);
50   return (overflow == 0);
51 }
52 
// Registered numpy type ID. Global variable populated by the registration code;
// it holds NPY_NOTYPE until then. Protected by the GIL.
int npy_bfloat16 = NPY_NOTYPE;

// Forward declaration; the type object itself is defined later in this file.
extern PyTypeObject bfloat16_type;

// Pointer to the bfloat16 type object we are using. This is either a pointer
// to bfloat16_type, if we choose to register it, or to the bfloat16 type
// registered by another system into NumPy. Null until it is assigned.
PyTypeObject* bfloat16_type_ptr = nullptr;
64 
65 // Representation of a Python bfloat16 object.
66 struct PyBfloat16 {
67   PyObject_HEAD;  // Python object header
68   bfloat16 value;
69 };
70 
71 // Returns true if 'object' is a PyBfloat16.
PyBfloat16_Check(PyObject * object)72 bool PyBfloat16_Check(PyObject* object) {
73   return PyObject_IsInstance(object,
74                              reinterpret_cast<PyObject*>(&bfloat16_type));
75 }
76 
77 // Extracts the value of a PyBfloat16 object.
PyBfloat16_Bfloat16(PyObject * object)78 bfloat16 PyBfloat16_Bfloat16(PyObject* object) {
79   return reinterpret_cast<PyBfloat16*>(object)->value;
80 }
81 
82 // Constructs a PyBfloat16 object from a bfloat16.
PyBfloat16_FromBfloat16(bfloat16 x)83 Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
84   Safe_PyObjectPtr ref = make_safe(bfloat16_type.tp_alloc(&bfloat16_type, 0));
85   PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
86   if (p) {
87     p->value = x;
88   }
89   return ref;
90 }
91 
92 // Converts a Python object to a bfloat16 value. Returns true on success,
93 // returns false and reports a Python error on failure.
CastToBfloat16(PyObject * arg,bfloat16 * output)94 bool CastToBfloat16(PyObject* arg, bfloat16* output) {
95   if (PyBfloat16_Check(arg)) {
96     *output = PyBfloat16_Bfloat16(arg);
97     return true;
98   }
99   if (PyFloat_Check(arg)) {
100     double d = PyFloat_AsDouble(arg);
101     if (PyErr_Occurred()) {
102       return false;
103     }
104     // TODO(phawkins): check for overflow
105     *output = bfloat16(d);
106     return true;
107   }
108   if (PyLong_CheckNoOverflow(arg)) {
109     long l = PyLong_AsLong(arg);  // NOLINT
110     if (PyErr_Occurred()) {
111       return false;
112     }
113     // TODO(phawkins): check for overflow
114     *output = bfloat16(static_cast<float>(l));
115     return true;
116   }
117   if (PyArray_IsScalar(arg, Half)) {
118     Eigen::half f;
119     PyArray_ScalarAsCtype(arg, &f);
120     *output = bfloat16(f);
121     return true;
122   }
123   if (PyArray_IsScalar(arg, Float)) {
124     float f;
125     PyArray_ScalarAsCtype(arg, &f);
126     *output = bfloat16(f);
127     return true;
128   }
129   if (PyArray_IsScalar(arg, Double)) {
130     double f;
131     PyArray_ScalarAsCtype(arg, &f);
132     *output = bfloat16(f);
133     return true;
134   }
135   if (PyArray_IsZeroDim(arg)) {
136     Safe_PyObjectPtr ref;
137     PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
138     if (PyArray_TYPE(arr) != npy_bfloat16) {
139       ref = make_safe(PyArray_Cast(arr, npy_bfloat16));
140       if (PyErr_Occurred()) {
141         return false;
142       }
143       arg = ref.get();
144       arr = reinterpret_cast<PyArrayObject*>(arg);
145     }
146     *output = *reinterpret_cast<bfloat16*>(PyArray_DATA(arr));
147     return true;
148   }
149   return false;
150 }
151 
SafeCastToBfloat16(PyObject * arg,bfloat16 * output)152 bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) {
153   if (PyBfloat16_Check(arg)) {
154     *output = PyBfloat16_Bfloat16(arg);
155     return true;
156   }
157   return false;
158 }
159 
160 // Converts a PyBfloat16 into a PyFloat.
PyBfloat16_Float(PyObject * self)161 PyObject* PyBfloat16_Float(PyObject* self) {
162   bfloat16 x = PyBfloat16_Bfloat16(self);
163   return PyFloat_FromDouble(static_cast<double>(x));
164 }
165 
166 // Converts a PyBfloat16 into a PyInt.
PyBfloat16_Int(PyObject * self)167 PyObject* PyBfloat16_Int(PyObject* self) {
168   bfloat16 x = PyBfloat16_Bfloat16(self);
169   long y = static_cast<long>(x);  // NOLINT
170   return PyLong_FromLong(y);
171 }
172 
173 // Negates a PyBfloat16.
PyBfloat16_Negative(PyObject * self)174 PyObject* PyBfloat16_Negative(PyObject* self) {
175   bfloat16 x = PyBfloat16_Bfloat16(self);
176   return PyBfloat16_FromBfloat16(-x).release();
177 }
178 
PyBfloat16_Add(PyObject * a,PyObject * b)179 PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) {
180   bfloat16 x, y;
181   if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
182     return PyBfloat16_FromBfloat16(x + y).release();
183   }
184   return PyArray_Type.tp_as_number->nb_add(a, b);
185 }
186 
PyBfloat16_Subtract(PyObject * a,PyObject * b)187 PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) {
188   bfloat16 x, y;
189   if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
190     return PyBfloat16_FromBfloat16(x - y).release();
191   }
192   return PyArray_Type.tp_as_number->nb_subtract(a, b);
193 }
194 
PyBfloat16_Multiply(PyObject * a,PyObject * b)195 PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) {
196   bfloat16 x, y;
197   if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
198     return PyBfloat16_FromBfloat16(x * y).release();
199   }
200   return PyArray_Type.tp_as_number->nb_multiply(a, b);
201 }
202 
PyBfloat16_TrueDivide(PyObject * a,PyObject * b)203 PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
204   bfloat16 x, y;
205   if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
206     return PyBfloat16_FromBfloat16(x / y).release();
207   }
208   return PyArray_Type.tp_as_number->nb_true_divide(a, b);
209 }
210 
// Python number methods for PyBfloat16 objects.
//
// The initializer is positional, so the order of entries must follow the field
// order of PyNumberMethods exactly. Only add, subtract, multiply, negative,
// int, float and true-divide are implemented here; every other listed slot is
// null.
PyNumberMethods PyBfloat16_AsNumber = {
    PyBfloat16_Add,       // nb_add
    PyBfloat16_Subtract,  // nb_subtract
    PyBfloat16_Multiply,  // nb_multiply
    nullptr,              // nb_remainder
    nullptr,              // nb_divmod
    nullptr,              // nb_power
    PyBfloat16_Negative,  // nb_negative
    nullptr,              // nb_positive
    nullptr,              // nb_absolute
    nullptr,              // nb_nonzero (named nb_bool in Python 3)
    nullptr,              // nb_invert
    nullptr,              // nb_lshift
    nullptr,              // nb_rshift
    nullptr,              // nb_and
    nullptr,              // nb_xor
    nullptr,              // nb_or
    PyBfloat16_Int,       // nb_int
    nullptr,              // reserved
    PyBfloat16_Float,     // nb_float

    // In-place variants: none are provided.
    nullptr,  // nb_inplace_add
    nullptr,  // nb_inplace_subtract
    nullptr,  // nb_inplace_multiply
    nullptr,  // nb_inplace_remainder
    nullptr,  // nb_inplace_power
    nullptr,  // nb_inplace_lshift
    nullptr,  // nb_inplace_rshift
    nullptr,  // nb_inplace_and
    nullptr,  // nb_inplace_xor
    nullptr,  // nb_inplace_or

    nullptr,                // nb_floor_divide
    PyBfloat16_TrueDivide,  // nb_true_divide
    nullptr,                // nb_inplace_floor_divide
    nullptr,                // nb_inplace_true_divide
    nullptr,                // nb_index
};
250 
251 // Constructs a new PyBfloat16.
PyBfloat16_New(PyTypeObject * type,PyObject * args,PyObject * kwds)252 PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
253   if (kwds && PyDict_Size(kwds)) {
254     PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
255     return nullptr;
256   }
257   Py_ssize_t size = PyTuple_Size(args);
258   if (size != 1) {
259     PyErr_SetString(PyExc_TypeError,
260                     "expected number as argument to bfloat16 constructor");
261     return nullptr;
262   }
263   PyObject* arg = PyTuple_GetItem(args, 0);
264 
265   bfloat16 value;
266   if (PyBfloat16_Check(arg)) {
267     Py_INCREF(arg);
268     return arg;
269   } else if (CastToBfloat16(arg, &value)) {
270     return PyBfloat16_FromBfloat16(value).release();
271   } else if (PyArray_Check(arg)) {
272     PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
273     if (PyArray_TYPE(arr) != npy_bfloat16) {
274       return PyArray_Cast(arr, npy_bfloat16);
275     } else {
276       Py_INCREF(arg);
277       return arg;
278     }
279   }
280   PyErr_Format(PyExc_TypeError, "expected number, got %s",
281                arg->ob_type->tp_name);
282   return nullptr;
283 }
284 
285 // Comparisons on PyBfloat16s.
PyBfloat16_RichCompare(PyObject * a,PyObject * b,int op)286 PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
287   bfloat16 x, y;
288   if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) {
289     return PyGenericArrType_Type.tp_richcompare(a, b, op);
290   }
291   bool result;
292   switch (op) {
293     case Py_LT:
294       result = x < y;
295       break;
296     case Py_LE:
297       result = x <= y;
298       break;
299     case Py_EQ:
300       result = x == y;
301       break;
302     case Py_NE:
303       result = x != y;
304       break;
305     case Py_GT:
306       result = x > y;
307       break;
308     case Py_GE:
309       result = x >= y;
310       break;
311     default:
312       LOG(FATAL) << "Invalid op type " << op;
313   }
314   return PyBool_FromLong(result);
315 }
316 
317 // Implementation of repr() for PyBfloat16.
PyBfloat16_Repr(PyObject * self)318 PyObject* PyBfloat16_Repr(PyObject* self) {
319   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
320   std::string v = absl::StrCat(static_cast<float>(x));
321   return PyUnicode_FromString(v.c_str());
322 }
323 
324 // Implementation of str() for PyBfloat16.
PyBfloat16_Str(PyObject * self)325 PyObject* PyBfloat16_Str(PyObject* self) {
326   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
327   std::string v = absl::StrCat(static_cast<float>(x));
328   return PyUnicode_FromString(v.c_str());
329 }
330 
331 // Hash function for PyBfloat16. We use the identity function, which is a weak
332 // hash function.
PyBfloat16_Hash(PyObject * self)333 Py_hash_t PyBfloat16_Hash(PyObject* self) {
334   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
335   return x.value;
336 }
337 
// Python type for PyBfloat16 objects.
//
// The initializer is positional and must follow the field order of
// PyTypeObject; the preprocessor branch covers the slot that changed from
// tp_print to tp_vectorcall_offset in Python 3.8. The slots set here are repr,
// str, hash, the number methods, rich comparison and tp_new. tp_base,
// tp_alloc, tp_dealloc and tp_free are null in this initializer.
// NOTE(review): PyBfloat16_FromBfloat16 calls bfloat16_type.tp_alloc, so that
// slot is presumably filled in before first use (e.g. by the registration
// code) -- confirm.
PyTypeObject bfloat16_type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16",  // tp_name
    sizeof(PyBfloat16),                            // tp_basicsize
    0,                                             // tp_itemsize
    nullptr,                                       // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
    nullptr,  // tp_print
#else
    0,  // tp_vectorcall_offset
#endif
    nullptr,               // tp_getattr
    nullptr,               // tp_setattr
    nullptr,               // tp_compare / tp_reserved
    PyBfloat16_Repr,       // tp_repr
    &PyBfloat16_AsNumber,  // tp_as_number
    nullptr,               // tp_as_sequence
    nullptr,               // tp_as_mapping
    PyBfloat16_Hash,       // tp_hash
    nullptr,               // tp_call
    PyBfloat16_Str,        // tp_str
    nullptr,               // tp_getattro
    nullptr,               // tp_setattro
    nullptr,               // tp_as_buffer
                           // tp_flags
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    "bfloat16 floating-point values",  // tp_doc
    nullptr,                           // tp_traverse
    nullptr,                           // tp_clear
    PyBfloat16_RichCompare,            // tp_richcompare
    0,                                 // tp_weaklistoffset
    nullptr,                           // tp_iter
    nullptr,                           // tp_iternext
    nullptr,                           // tp_methods
    nullptr,                           // tp_members
    nullptr,                           // tp_getset
    nullptr,                           // tp_base
    nullptr,                           // tp_dict
    nullptr,                           // tp_descr_get
    nullptr,                           // tp_descr_set
    0,                                 // tp_dictoffset
    nullptr,                           // tp_init
    nullptr,                           // tp_alloc
    PyBfloat16_New,                    // tp_new
    nullptr,                           // tp_free
    nullptr,                           // tp_is_gc
    nullptr,                           // tp_bases
    nullptr,                           // tp_mro
    nullptr,                           // tp_cache
    nullptr,                           // tp_subclasses
    nullptr,                           // tp_weaklist
    nullptr,                           // tp_del
    0,                                 // tp_version_tag
};
392 
393 // Numpy support
394 
// Table of NumPy array functions for bfloat16, referenced by the 'f' field of
// NPyBfloat16_Descr below. It is not filled in here.
// NOTE(review): presumably populated by the registration code -- confirm.
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;

// NumPy dtype descriptor for bfloat16. The initializer is positional; the
// /*name=*/ comments give the PyArray_Descr field each value is meant for.
PyArray_Descr NPyBfloat16_Descr = {
    PyObject_HEAD_INIT(nullptr)  //
                                 /*typeobj=*/
    (&bfloat16_type),
    // We must register bfloat16 with a kind other than "f", because numpy
    // considers two types with the same kind and size to be equal, but
    // float16 != bfloat16.
    // The downside of this is that NumPy scalar promotion does not work with
    // bfloat16 values.
    /*kind=*/'V',
    // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
    // character is unique.
    /*type=*/'E',
    /*byteorder=*/'=',
    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
    // NOTE(review): 0 looks like a placeholder; the registered type number is
    // kept in npy_bfloat16 -- confirm against the registration code.
    /*type_num=*/0,
    /*elsize=*/sizeof(bfloat16),
    /*alignment=*/alignof(bfloat16),
    /*subarray=*/nullptr,
    /*fields=*/nullptr,
    /*names=*/nullptr,
    /*f=*/&NPyBfloat16_ArrFuncs,
    /*metadata=*/nullptr,
    /*c_metadata=*/nullptr,
    /*hash=*/-1,  // -1 means "not computed yet".
};
423 
424 // Implementations of NumPy array methods.
425 
NPyBfloat16_GetItem(void * data,void * arr)426 PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
427   bfloat16 x;
428   memcpy(&x, data, sizeof(bfloat16));
429   return PyBfloat16_FromBfloat16(x).release();
430 }
431 
NPyBfloat16_SetItem(PyObject * item,void * data,void * arr)432 int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
433   bfloat16 x;
434   if (!CastToBfloat16(item, &x)) {
435     PyErr_Format(PyExc_TypeError, "expected number, got %s",
436                  item->ob_type->tp_name);
437     return -1;
438   }
439   memcpy(data, &x, sizeof(bfloat16));
440   return 0;
441 }
442 
// Exchanges, in place, the two bytes of the 16-bit quantity stored at 'value'.
void ByteSwap16(void* value) {
  char* bytes = static_cast<char*>(value);
  const char first = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = first;
}
447 
NPyBfloat16_Compare(const void * a,const void * b,void * arr)448 int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
449   bfloat16 x;
450   memcpy(&x, a, sizeof(bfloat16));
451 
452   bfloat16 y;
453   memcpy(&y, b, sizeof(bfloat16));
454 
455   if (x < y) {
456     return -1;
457   }
458   if (y < x) {
459     return 1;
460   }
461   // NaNs sort to the end.
462   if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
463     return -1;
464   }
465   if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
466     return 1;
467   }
468   return 0;
469 }
470 
NPyBfloat16_CopySwapN(void * dstv,npy_intp dstride,void * srcv,npy_intp sstride,npy_intp n,int swap,void * arr)471 void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
472                            npy_intp sstride, npy_intp n, int swap, void* arr) {
473   char* dst = reinterpret_cast<char*>(dstv);
474   char* src = reinterpret_cast<char*>(srcv);
475   if (!src) {
476     return;
477   }
478   if (swap) {
479     for (npy_intp i = 0; i < n; i++) {
480       char* r = dst + dstride * i;
481       memcpy(r, src + sstride * i, sizeof(uint16_t));
482       ByteSwap16(r);
483     }
484   } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) {
485     memcpy(dst, src, n * sizeof(uint16_t));
486   } else {
487     for (npy_intp i = 0; i < n; i++) {
488       memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t));
489     }
490   }
491 }
492 
// Copies one 16-bit element from 'src' to 'dst' and, if 'swap' is non-zero,
// exchanges the two bytes of the destination. 'arr' is not used.
//
// Following the NumPy copyswap contract, a null 'src' means that no copy is
// performed, but the byte swap is still applied in place to 'dst'.
void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) {
  if (src) {
    memcpy(dst, src, sizeof(uint16_t));
  }
  if (swap) {
    char* bytes = reinterpret_cast<char*>(dst);
    std::swap(bytes[0], bytes[1]);
  }
}
502 
NPyBfloat16_NonZero(void * data,void * arr)503 npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
504   bfloat16 x;
505   memcpy(&x, data, sizeof(x));
506   return x != static_cast<bfloat16>(0);
507 }
508 
NPyBfloat16_Fill(void * buffer_raw,npy_intp length,void * ignored)509 int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
510   bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
511   const float start(buffer[0]);
512   const float delta = static_cast<float>(buffer[1]) - start;
513   for (npy_intp i = 2; i < length; ++i) {
514     buffer[i] = static_cast<bfloat16>(start + i * delta);
515   }
516   return 0;
517 }
518 
NPyBfloat16_DotFunc(void * ip1,npy_intp is1,void * ip2,npy_intp is2,void * op,npy_intp n,void * arr)519 void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
520                          void* op, npy_intp n, void* arr) {
521   char* c1 = reinterpret_cast<char*>(ip1);
522   char* c2 = reinterpret_cast<char*>(ip2);
523   float acc = 0.0f;
524   for (npy_intp i = 0; i < n; ++i) {
525     bfloat16* const b1 = reinterpret_cast<bfloat16*>(c1);
526     bfloat16* const b2 = reinterpret_cast<bfloat16*>(c2);
527     acc += static_cast<float>(*b1) * static_cast<float>(*b2);
528     c1 += is1;
529     c2 += is2;
530   }
531   bfloat16* out = reinterpret_cast<bfloat16*>(op);
532   *out = static_cast<bfloat16>(acc);
533 }
534 
NPyBfloat16_CompareFunc(const void * v1,const void * v2,void * arr)535 int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) {
536   bfloat16 b1 = *reinterpret_cast<const bfloat16*>(v1);
537   bfloat16 b2 = *reinterpret_cast<const bfloat16*>(v2);
538   if (b1 < b2) {
539     return -1;
540   }
541   if (b1 > b2) {
542     return 1;
543   }
544   return 0;
545 }
546 
NPyBfloat16_ArgMaxFunc(void * data,npy_intp n,npy_intp * max_ind,void * arr)547 int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
548                            void* arr) {
549   const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
550   float max_val = -std::numeric_limits<float>::infinity();
551   for (npy_intp i = 0; i < n; ++i) {
552     if (static_cast<float>(bdata[i]) > max_val) {
553       max_val = static_cast<float>(bdata[i]);
554       *max_ind = i;
555     }
556   }
557   return 0;
558 }
559 
NPyBfloat16_ArgMinFunc(void * data,npy_intp n,npy_intp * min_ind,void * arr)560 int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
561                            void* arr) {
562   const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
563   float min_val = std::numeric_limits<float>::infinity();
564   for (npy_intp i = 0; i < n; ++i) {
565     if (static_cast<float>(bdata[i]) < min_val) {
566       min_val = static_cast<float>(bdata[i]);
567       *min_ind = i;
568     }
569   }
570   return 0;
571 }
572 
573 // NumPy casts
574 
575 template <typename T, typename Enable = void>
576 struct TypeDescriptor {
577   // typedef ... T;  // Representation type in memory for NumPy values of type
578   // static int Dtype() { return NPY_...; }  // Numpy type number for T.
579 };
580 
581 template <>
582 struct TypeDescriptor<bfloat16> {
583   typedef bfloat16 T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor584   static int Dtype() { return npy_bfloat16; }
585 };
586 
587 template <>
588 struct TypeDescriptor<uint8> {
589   typedef uint8 T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor590   static int Dtype() { return NPY_UINT8; }
591 };
592 
593 template <>
594 struct TypeDescriptor<uint16> {
595   typedef uint16 T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor596   static int Dtype() { return NPY_UINT16; }
597 };
598 
599 // We register "int", "long", and "long long" types for portability across
600 // Linux, where "int" and "long" are the same type, and Windows, where "long"
601 // and "longlong" are the same type.
602 template <>
603 struct TypeDescriptor<unsigned int> {
604   typedef unsigned int T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor605   static int Dtype() { return NPY_UINT; }
606 };
607 
608 template <>
609 struct TypeDescriptor<unsigned long> {  // NOLINT
610   typedef unsigned long T;              // NOLINT
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor611   static int Dtype() { return NPY_ULONG; }
612 };
613 
614 template <>
615 struct TypeDescriptor<unsigned long long> {  // NOLINT
616   typedef unsigned long long T;              // NOLINT
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor617   static int Dtype() { return NPY_ULONGLONG; }
618 };
619 
620 template <>
621 struct TypeDescriptor<int8> {
622   typedef int8 T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor623   static int Dtype() { return NPY_INT8; }
624 };
625 
626 template <>
627 struct TypeDescriptor<int16> {
628   typedef int16 T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor629   static int Dtype() { return NPY_INT16; }
630 };
631 
632 template <>
633 struct TypeDescriptor<int> {
634   typedef int T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor635   static int Dtype() { return NPY_INT; }
636 };
637 
638 template <>
639 struct TypeDescriptor<long> {  // NOLINT
640   typedef long T;              // NOLINT
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor641   static int Dtype() { return NPY_LONG; }
642 };
643 
644 template <>
645 struct TypeDescriptor<long long> {  // NOLINT
646   typedef long long T;              // NOLINT
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor647   static int Dtype() { return NPY_LONGLONG; }
648 };
649 
650 template <>
651 struct TypeDescriptor<bool> {
652   typedef int8 T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor653   static int Dtype() { return NPY_BOOL; }
654 };
655 
656 template <>
657 struct TypeDescriptor<Eigen::half> {
658   typedef Eigen::half T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor659   static int Dtype() { return NPY_HALF; }
660 };
661 
662 template <>
663 struct TypeDescriptor<float> {
664   typedef float T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor665   static int Dtype() { return NPY_FLOAT; }
666 };
667 
668 template <>
669 struct TypeDescriptor<double> {
670   typedef double T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor671   static int Dtype() { return NPY_DOUBLE; }
672 };
673 
674 template <>
675 struct TypeDescriptor<std::complex<float>> {
676   typedef std::complex<float> T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor677   static int Dtype() { return NPY_COMPLEX64; }
678 };
679 
680 template <>
681 struct TypeDescriptor<std::complex<double>> {
682   typedef std::complex<double> T;
Dtypetensorflow::__anonbd6dfcf70111::TypeDescriptor683   static int Dtype() { return NPY_COMPLEX128; }
684 };
685 
686 // Performs a NumPy array cast from type 'From' to 'To'.
687 template <typename From, typename To>
NPyCast(void * from_void,void * to_void,npy_intp n,void * fromarr,void * toarr)688 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
689              void* toarr) {
690   const auto* from =
691       reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
692   auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
693   for (npy_intp i = 0; i < n; ++i) {
694     to[i] =
695         static_cast<typename TypeDescriptor<To>::T>(static_cast<To>(from[i]));
696   }
697 }
698 
699 // Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
700 // type corresponding to 'T'.
701 template <typename T>
RegisterBfloat16Cast(int numpy_type)702 bool RegisterBfloat16Cast(int numpy_type) {
703   PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
704   if (PyArray_RegisterCastFunc(descr, npy_bfloat16, NPyCast<T, bfloat16>) < 0) {
705     return false;
706   }
707   if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
708                                NPyCast<bfloat16, T>) < 0) {
709     return false;
710   }
711   return true;
712 }
713 
714 template <typename InType, typename OutType, typename Functor>
715 struct UnaryUFunc {
Typestensorflow::__anonbd6dfcf70111::UnaryUFunc716   static std::vector<int> Types() {
717     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype()};
718   }
Calltensorflow::__anonbd6dfcf70111::UnaryUFunc719   static void Call(char** args, const npy_intp* dimensions,
720                    const npy_intp* steps, void* data) {
721     const char* i0 = args[0];
722     char* o = args[1];
723     for (npy_intp k = 0; k < *dimensions; k++) {
724       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
725       *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) = Functor()(x);
726       i0 += steps[0];
727       o += steps[1];
728     }
729   }
730 };
731 
732 template <typename InType, typename OutType, typename OutType2,
733           typename Functor>
734 struct UnaryUFunc2 {
Typestensorflow::__anonbd6dfcf70111::UnaryUFunc2735   static std::vector<int> Types() {
736     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype(),
737             TypeDescriptor<OutType2>::Dtype()};
738   }
Calltensorflow::__anonbd6dfcf70111::UnaryUFunc2739   static void Call(char** args, const npy_intp* dimensions,
740                    const npy_intp* steps, void* data) {
741     const char* i0 = args[0];
742     char* o0 = args[1];
743     char* o1 = args[2];
744     for (npy_intp k = 0; k < *dimensions; k++) {
745       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
746       std::tie(*reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o0),
747                *reinterpret_cast<typename TypeDescriptor<OutType2>::T*>(o1)) =
748           Functor()(x);
749       i0 += steps[0];
750       o0 += steps[1];
751       o1 += steps[2];
752     }
753   }
754 };
755 
756 template <typename InType, typename OutType, typename Functor>
757 struct BinaryUFunc {
Typestensorflow::__anonbd6dfcf70111::BinaryUFunc758   static std::vector<int> Types() {
759     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType>::Dtype(),
760             TypeDescriptor<OutType>::Dtype()};
761   }
Calltensorflow::__anonbd6dfcf70111::BinaryUFunc762   static void Call(char** args, const npy_intp* dimensions,
763                    const npy_intp* steps, void* data) {
764     const char* i0 = args[0];
765     const char* i1 = args[1];
766     char* o = args[2];
767     for (npy_intp k = 0; k < *dimensions; k++) {
768       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
769       auto y = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i1);
770       *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
771           Functor()(x, y);
772       i0 += steps[0];
773       i1 += steps[1];
774       o += steps[2];
775     }
776   }
777 };
778 
779 template <typename InType, typename InType2, typename OutType, typename Functor>
780 struct BinaryUFunc2 {
Typestensorflow::__anonbd6dfcf70111::BinaryUFunc2781   static std::vector<int> Types() {
782     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType2>::Dtype(),
783             TypeDescriptor<OutType>::Dtype()};
784   }
Calltensorflow::__anonbd6dfcf70111::BinaryUFunc2785   static void Call(char** args, const npy_intp* dimensions,
786                    const npy_intp* steps, void* data) {
787     const char* i0 = args[0];
788     const char* i1 = args[1];
789     char* o = args[2];
790     for (npy_intp k = 0; k < *dimensions; k++) {
791       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
792       auto y =
793           *reinterpret_cast<const typename TypeDescriptor<InType2>::T*>(i1);
794       *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
795           Functor()(x, y);
796       i0 += steps[0];
797       i1 += steps[1];
798       o += steps[2];
799     }
800   }
801 };
802 
803 template <typename UFunc>
RegisterUFunc(PyObject * numpy,const char * name)804 bool RegisterUFunc(PyObject* numpy, const char* name) {
805   std::vector<int> types = UFunc::Types();
806   PyUFuncGenericFunction fn =
807       reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
808   Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
809   if (!ufunc_obj) {
810     return false;
811   }
812   PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
813   if (static_cast<int>(types.size()) != ufunc->nargs) {
814     PyErr_Format(PyExc_AssertionError,
815                  "ufunc %s takes %d arguments, loop takes %lu", name,
816                  ufunc->nargs, types.size());
817     return false;
818   }
819   if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn,
820                                   const_cast<int*>(types.data()),
821                                   nullptr) < 0) {
822     return false;
823   }
824   return true;
825 }
826 
827 namespace ufuncs {
828 
829 struct Add {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Add830   bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; }
831 };
832 struct Subtract {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Subtract833   bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; }
834 };
835 struct Multiply {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Multiply836   bfloat16 operator()(bfloat16 a, bfloat16 b) { return a * b; }
837 };
838 struct TrueDivide {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::TrueDivide839   bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
840 };
841 
// Computes the floored quotient and the remainder of `a / b`.
//
// Returns {floor_quotient, remainder}. A non-zero remainder carries the sign
// of `b`; a zero remainder is a zero signed like `b`, and a zero quotient is a
// zero signed like `a / b`. When `b` is zero both results are NaN.
std::pair<float, float> divmod(float a, float b) {
  if (b == 0.0f) {
    const float not_a_number = std::numeric_limits<float>::quiet_NaN();
    return std::make_pair(not_a_number, not_a_number);
  }

  float remainder = std::fmod(a, b);
  // `a - remainder` should be very nearly an integer multiple of `b`.
  float quotient = (a - remainder) / b;

  if (remainder == 0.0f) {
    // Give an exact-zero remainder the sign of the divisor.
    remainder = std::copysign(0.0f, b);
  } else if ((remainder < 0.0f) != (b < 0.0f)) {
    // fmod truncates toward zero; shift to the floored convention.
    remainder += b;
    quotient -= 1.0f;
  }

  float snapped;
  if (quotient == 0.0f) {
    // Give a zero quotient the sign of the true quotient.
    snapped = std::copysign(0.0f, a / b);
  } else {
    // Snap the quotient to the nearest integral value.
    snapped = std::floor(quotient);
    if (quotient - snapped > 0.5f) {
      snapped += 1.0f;
    }
  }
  return {snapped, remainder};
}
869 
870 struct FloorDivide {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::FloorDivide871   bfloat16 operator()(bfloat16 a, bfloat16 b) {
872     return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
873   }
874 };
875 struct Remainder {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Remainder876   bfloat16 operator()(bfloat16 a, bfloat16 b) {
877     return bfloat16(
878         divmod(static_cast<float>(a), static_cast<float>(b)).second);
879   }
880 };
881 struct DivmodUFunc {
Typestensorflow::__anonbd6dfcf70111::ufuncs::DivmodUFunc882   static std::vector<int> Types() {
883     return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16};
884   }
Calltensorflow::__anonbd6dfcf70111::ufuncs::DivmodUFunc885   static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
886                    void* data) {
887     const char* i0 = args[0];
888     const char* i1 = args[1];
889     char* o0 = args[2];
890     char* o1 = args[3];
891     for (npy_intp k = 0; k < *dimensions; k++) {
892       bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
893       bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
894       float floordiv, mod;
895       std::tie(floordiv, mod) =
896           divmod(static_cast<float>(x), static_cast<float>(y));
897       *reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
898       *reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
899       i0 += steps[0];
900       i1 += steps[1];
901       o0 += steps[2];
902       o1 += steps[3];
903     }
904   }
905 };
906 struct Fmod {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Fmod907   bfloat16 operator()(bfloat16 a, bfloat16 b) {
908     return bfloat16(std::fmod(static_cast<float>(a), static_cast<float>(b)));
909   }
910 };
911 struct Negative {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Negative912   bfloat16 operator()(bfloat16 a) { return -a; }
913 };
914 struct Positive {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Positive915   bfloat16 operator()(bfloat16 a) { return a; }
916 };
917 struct Power {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Power918   bfloat16 operator()(bfloat16 a, bfloat16 b) {
919     return bfloat16(std::pow(static_cast<float>(a), static_cast<float>(b)));
920   }
921 };
922 struct Abs {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Abs923   bfloat16 operator()(bfloat16 a) {
924     return bfloat16(std::abs(static_cast<float>(a)));
925   }
926 };
927 struct Cbrt {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Cbrt928   bfloat16 operator()(bfloat16 a) {
929     return bfloat16(std::cbrt(static_cast<float>(a)));
930   }
931 };
932 struct Ceil {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Ceil933   bfloat16 operator()(bfloat16 a) {
934     return bfloat16(std::ceil(static_cast<float>(a)));
935   }
936 };
937 struct CopySign {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::CopySign938   bfloat16 operator()(bfloat16 a, bfloat16 b) {
939     return bfloat16(
940         std::copysign(static_cast<float>(a), static_cast<float>(b)));
941   }
942 };
943 struct Exp {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Exp944   bfloat16 operator()(bfloat16 a) {
945     return bfloat16(std::exp(static_cast<float>(a)));
946   }
947 };
948 struct Exp2 {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Exp2949   bfloat16 operator()(bfloat16 a) {
950     return bfloat16(std::exp2(static_cast<float>(a)));
951   }
952 };
953 struct Expm1 {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Expm1954   bfloat16 operator()(bfloat16 a) {
955     return bfloat16(std::expm1(static_cast<float>(a)));
956   }
957 };
958 struct Floor {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Floor959   bfloat16 operator()(bfloat16 a) {
960     return bfloat16(std::floor(static_cast<float>(a)));
961   }
962 };
963 struct Frexp {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Frexp964   std::pair<bfloat16, int> operator()(bfloat16 a) {
965     int exp;
966     float f = std::frexp(static_cast<float>(a), &exp);
967     return {bfloat16(f), exp};
968   }
969 };
970 struct Heaviside {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Heaviside971   bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
972     float x = static_cast<float>(bx);
973     if (Eigen::numext::isnan(x)) {
974       return bx;
975     }
976     if (x < 0) {
977       return bfloat16(0.0f);
978     }
979     if (x > 0) {
980       return bfloat16(1.0f);
981     }
982     return h0;  // x == 0
983   }
984 };
985 struct Conjugate {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Conjugate986   bfloat16 operator()(bfloat16 a) { return a; }
987 };
988 struct IsFinite {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::IsFinite989   bool operator()(bfloat16 a) { return std::isfinite(static_cast<float>(a)); }
990 };
991 struct IsInf {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::IsInf992   bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
993 };
994 struct IsNan {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::IsNan995   bool operator()(bfloat16 a) {
996     return Eigen::numext::isnan(static_cast<float>(a));
997   }
998 };
999 struct Ldexp {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Ldexp1000   bfloat16 operator()(bfloat16 a, int exp) {
1001     return bfloat16(std::ldexp(static_cast<float>(a), exp));
1002   }
1003 };
1004 struct Log {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Log1005   bfloat16 operator()(bfloat16 a) {
1006     return bfloat16(std::log(static_cast<float>(a)));
1007   }
1008 };
1009 struct Log2 {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Log21010   bfloat16 operator()(bfloat16 a) {
1011     return bfloat16(std::log2(static_cast<float>(a)));
1012   }
1013 };
1014 struct Log10 {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Log101015   bfloat16 operator()(bfloat16 a) {
1016     return bfloat16(std::log10(static_cast<float>(a)));
1017   }
1018 };
1019 struct Log1p {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Log1p1020   bfloat16 operator()(bfloat16 a) {
1021     return bfloat16(std::log1p(static_cast<float>(a)));
1022   }
1023 };
1024 struct LogAddExp {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::LogAddExp1025   bfloat16 operator()(bfloat16 bx, bfloat16 by) {
1026     float x = static_cast<float>(bx);
1027     float y = static_cast<float>(by);
1028     if (x == y) {
1029       // Handles infinities of the same sign.
1030       return bfloat16(x + std::log(2.0f));
1031     }
1032     float out = std::numeric_limits<float>::quiet_NaN();
1033     if (x > y) {
1034       out = x + std::log1p(std::exp(y - x));
1035     } else if (x < y) {
1036       out = y + std::log1p(std::exp(x - y));
1037     }
1038     return bfloat16(out);
1039   }
1040 };
1041 struct LogAddExp2 {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::LogAddExp21042   bfloat16 operator()(bfloat16 bx, bfloat16 by) {
1043     float x = static_cast<float>(bx);
1044     float y = static_cast<float>(by);
1045     if (x == y) {
1046       // Handles infinities of the same sign.
1047       return bfloat16(x + 1.0f);
1048     }
1049     float out = std::numeric_limits<float>::quiet_NaN();
1050     if (x > y) {
1051       out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
1052     } else if (x < y) {
1053       out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
1054     }
1055     return bfloat16(out);
1056   }
1057 };
1058 struct Modf {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Modf1059   std::pair<bfloat16, bfloat16> operator()(bfloat16 a) {
1060     float integral;
1061     float f = std::modf(static_cast<float>(a), &integral);
1062     return {bfloat16(f), bfloat16(integral)};
1063   }
1064 };
1065 
1066 struct Reciprocal {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Reciprocal1067   bfloat16 operator()(bfloat16 a) {
1068     return bfloat16(1.f / static_cast<float>(a));
1069   }
1070 };
1071 struct Rint {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Rint1072   bfloat16 operator()(bfloat16 a) {
1073     return bfloat16(std::rint(static_cast<float>(a)));
1074   }
1075 };
1076 struct Sign {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Sign1077   bfloat16 operator()(bfloat16 a) {
1078     float f(a);
1079     if (f < 0) {
1080       return bfloat16(-1);
1081     }
1082     if (f > 0) {
1083       return bfloat16(1);
1084     }
1085     return a;
1086   }
1087 };
1088 struct SignBit {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::SignBit1089   bool operator()(bfloat16 a) { return std::signbit(static_cast<float>(a)); }
1090 };
1091 struct Sqrt {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Sqrt1092   bfloat16 operator()(bfloat16 a) {
1093     return bfloat16(std::sqrt(static_cast<float>(a)));
1094   }
1095 };
1096 struct Square {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Square1097   bfloat16 operator()(bfloat16 a) {
1098     float f(a);
1099     return bfloat16(f * f);
1100   }
1101 };
1102 struct Trunc {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Trunc1103   bfloat16 operator()(bfloat16 a) {
1104     return bfloat16(std::trunc(static_cast<float>(a)));
1105   }
1106 };
1107 
1108 // Trigonometric functions
1109 struct Sin {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Sin1110   bfloat16 operator()(bfloat16 a) {
1111     return bfloat16(std::sin(static_cast<float>(a)));
1112   }
1113 };
1114 struct Cos {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Cos1115   bfloat16 operator()(bfloat16 a) {
1116     return bfloat16(std::cos(static_cast<float>(a)));
1117   }
1118 };
1119 struct Tan {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Tan1120   bfloat16 operator()(bfloat16 a) {
1121     return bfloat16(std::tan(static_cast<float>(a)));
1122   }
1123 };
1124 struct Arcsin {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arcsin1125   bfloat16 operator()(bfloat16 a) {
1126     return bfloat16(std::asin(static_cast<float>(a)));
1127   }
1128 };
1129 struct Arccos {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arccos1130   bfloat16 operator()(bfloat16 a) {
1131     return bfloat16(std::acos(static_cast<float>(a)));
1132   }
1133 };
1134 struct Arctan {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arctan1135   bfloat16 operator()(bfloat16 a) {
1136     return bfloat16(std::atan(static_cast<float>(a)));
1137   }
1138 };
1139 struct Arctan2 {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arctan21140   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1141     return bfloat16(std::atan2(static_cast<float>(a), static_cast<float>(b)));
1142   }
1143 };
1144 struct Hypot {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Hypot1145   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1146     return bfloat16(std::hypot(static_cast<float>(a), static_cast<float>(b)));
1147   }
1148 };
1149 struct Sinh {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Sinh1150   bfloat16 operator()(bfloat16 a) {
1151     return bfloat16(std::sinh(static_cast<float>(a)));
1152   }
1153 };
1154 struct Cosh {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Cosh1155   bfloat16 operator()(bfloat16 a) {
1156     return bfloat16(std::cosh(static_cast<float>(a)));
1157   }
1158 };
1159 struct Tanh {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Tanh1160   bfloat16 operator()(bfloat16 a) {
1161     return bfloat16(std::tanh(static_cast<float>(a)));
1162   }
1163 };
1164 struct Arcsinh {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arcsinh1165   bfloat16 operator()(bfloat16 a) {
1166     return bfloat16(std::asinh(static_cast<float>(a)));
1167   }
1168 };
1169 struct Arccosh {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arccosh1170   bfloat16 operator()(bfloat16 a) {
1171     return bfloat16(std::acosh(static_cast<float>(a)));
1172   }
1173 };
1174 struct Arctanh {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Arctanh1175   bfloat16 operator()(bfloat16 a) {
1176     return bfloat16(std::atanh(static_cast<float>(a)));
1177   }
1178 };
1179 struct Deg2rad {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Deg2rad1180   bfloat16 operator()(bfloat16 a) {
1181     static constexpr float radians_per_degree = M_PI / 180.0f;
1182     return bfloat16(static_cast<float>(a) * radians_per_degree);
1183   }
1184 };
1185 struct Rad2deg {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Rad2deg1186   bfloat16 operator()(bfloat16 a) {
1187     static constexpr float degrees_per_radian = 180.0f / M_PI;
1188     return bfloat16(static_cast<float>(a) * degrees_per_radian);
1189   }
1190 };
1191 
1192 struct Eq {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Eq1193   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
1194 };
1195 struct Ne {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Ne1196   npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
1197 };
1198 struct Lt {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Lt1199   npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
1200 };
1201 struct Gt {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Gt1202   npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
1203 };
1204 struct Le {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Le1205   npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
1206 };
1207 struct Ge {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Ge1208   npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
1209 };
1210 struct Maximum {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Maximum1211   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1212     float fa(a), fb(b);
1213     return Eigen::numext::isnan(fa) || fa > fb ? a : b;
1214   }
1215 };
1216 struct Minimum {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Minimum1217   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1218     float fa(a), fb(b);
1219     return Eigen::numext::isnan(fa) || fa < fb ? a : b;
1220   }
1221 };
1222 struct Fmax {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Fmax1223   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1224     float fa(a), fb(b);
1225     return Eigen::numext::isnan(fb) || fa > fb ? a : b;
1226   }
1227 };
1228 struct Fmin {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::Fmin1229   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1230     float fa(a), fb(b);
1231     return Eigen::numext::isnan(fb) || fa < fb ? a : b;
1232   }
1233 };
1234 
1235 struct LogicalNot {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::LogicalNot1236   npy_bool operator()(bfloat16 a) { return !a; }
1237 };
1238 struct LogicalAnd {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::LogicalAnd1239   npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; }
1240 };
1241 struct LogicalOr {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::LogicalOr1242   npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; }
1243 };
1244 struct LogicalXor {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::LogicalXor1245   npy_bool operator()(bfloat16 a, bfloat16 b) {
1246     return static_cast<bool>(a) ^ static_cast<bool>(b);
1247   }
1248 };
1249 
1250 struct NextAfter {
operator ()tensorflow::__anonbd6dfcf70111::ufuncs::NextAfter1251   bfloat16 operator()(bfloat16 from, bfloat16 to) {
1252     uint16_t from_as_int, to_as_int;
1253     const uint16_t sign_mask = 1 << 15;
1254     float from_as_float(from), to_as_float(to);
1255     memcpy(&from_as_int, &from, sizeof(bfloat16));
1256     memcpy(&to_as_int, &to, sizeof(bfloat16));
1257     if (Eigen::numext::isnan(from_as_float) ||
1258         Eigen::numext::isnan(to_as_float)) {
1259       return bfloat16(std::numeric_limits<float>::quiet_NaN());
1260     }
1261     if (from_as_int == to_as_int) {
1262       return to;
1263     }
1264     if (from_as_float == 0) {
1265       if (to_as_float == 0) {
1266         return to;
1267       } else {
1268         // Smallest subnormal signed like `to`.
1269         uint16_t out_int = (to_as_int & sign_mask) | 1;
1270         bfloat16 out;
1271         memcpy(&out, &out_int, sizeof(bfloat16));
1272         return out;
1273       }
1274     }
1275     uint16_t from_sign = from_as_int & sign_mask;
1276     uint16_t to_sign = to_as_int & sign_mask;
1277     uint16_t from_abs = from_as_int & ~sign_mask;
1278     uint16_t to_abs = to_as_int & ~sign_mask;
1279     uint16_t magnitude_adjustment =
1280         (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
1281     uint16_t out_int = from_as_int + magnitude_adjustment;
1282     bfloat16 out;
1283     memcpy(&out, &out_int, sizeof(bfloat16));
1284     return out;
1285   }
1286 };
1287 
1288 // TODO(phawkins): implement spacing
1289 
1290 }  // namespace ufuncs
1291 
1292 }  // namespace
1293 
1294 // Initializes the module.
Initialize()1295 bool Initialize() {
1296   ImportNumpy();
1297   import_umath1(false);
1298 
1299   Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
1300   if (!numpy_str) {
1301     return false;
1302   }
1303   Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
1304   if (!numpy) {
1305     return false;
1306   }
1307 
1308   // If another module (presumably either TF or JAX) has registered a bfloat16
1309   // type, use it. We don't want two bfloat16 types if we can avoid it since it
1310   // leads to confusion if we have two different types with the same name. This
1311   // assumes that the other module has a sufficiently complete bfloat16
1312   // implementation. The only known NumPy bfloat16 extension at the time of
1313   // writing is this one (distributed in TF and JAX).
1314   // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
1315   // so we can unambiguously refer to a single canonical definition of bfloat16.
1316   int typenum = PyArray_TypeNumFromName(const_cast<char*>("bfloat16"));
1317   if (typenum != NPY_NOTYPE) {
1318     PyArray_Descr* descr = PyArray_DescrFromType(typenum);
1319     // The test for an argmax function here is to verify that the
1320     // bfloat16 implementation is sufficiently new, and, say, not from
1321     // an older version of TF or JAX.
1322     if (descr && descr->f && descr->f->argmax) {
1323       npy_bfloat16 = typenum;
1324       bfloat16_type_ptr = descr->typeobj;
1325       return true;
1326     }
1327   }
1328 
1329   bfloat16_type.tp_base = &PyGenericArrType_Type;
1330 
1331   if (PyType_Ready(&bfloat16_type) < 0) {
1332     return false;
1333   }
1334 
1335   // Initializes the NumPy descriptor.
1336   PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
1337   NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
1338   NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
1339   NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
1340   NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
1341   NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
1342   NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
1343   NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
1344   NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc;
1345   NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc;
1346   NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc;
1347   NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc;
1348 
1349   Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
1350   npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr);
1351   bfloat16_type_ptr = &bfloat16_type;
1352   if (npy_bfloat16 < 0) {
1353     return false;
1354   }
1355 
1356   // Support dtype(bfloat16)
1357   if (PyDict_SetItemString(bfloat16_type.tp_dict, "dtype",
1358                            reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
1359       0) {
1360     return false;
1361   }
1362 
1363   // Register casts
1364   if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF)) {
1365     return false;
1366   }
1367 
1368   if (!RegisterBfloat16Cast<float>(NPY_FLOAT)) {
1369     return false;
1370   }
1371   if (!RegisterBfloat16Cast<double>(NPY_DOUBLE)) {
1372     return false;
1373   }
1374   if (!RegisterBfloat16Cast<bool>(NPY_BOOL)) {
1375     return false;
1376   }
1377   if (!RegisterBfloat16Cast<uint8>(NPY_UINT8)) {
1378     return false;
1379   }
1380   if (!RegisterBfloat16Cast<uint16>(NPY_UINT16)) {
1381     return false;
1382   }
1383   if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT)) {
1384     return false;
1385   }
1386   if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG)) {  // NOLINT
1387     return false;
1388   }
1389   if (!RegisterBfloat16Cast<unsigned long long>(NPY_ULONGLONG)) {  // NOLINT
1390     return false;
1391   }
1392   if (!RegisterBfloat16Cast<uint64>(NPY_UINT64)) {
1393     return false;
1394   }
1395   if (!RegisterBfloat16Cast<int8>(NPY_INT8)) {
1396     return false;
1397   }
1398   if (!RegisterBfloat16Cast<int16>(NPY_INT16)) {
1399     return false;
1400   }
1401   if (!RegisterBfloat16Cast<int>(NPY_INT)) {
1402     return false;
1403   }
1404   if (!RegisterBfloat16Cast<long>(NPY_LONG)) {  // NOLINT
1405     return false;
1406   }
1407   if (!RegisterBfloat16Cast<long long>(NPY_LONGLONG)) {  // NOLINT
1408     return false;
1409   }
1410   // Following the numpy convention. imag part is dropped when converting to
1411   // float.
1412   if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64)) {
1413     return false;
1414   }
1415   if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128)) {
1416     return false;
1417   }
1418 
1419   // Safe casts from bfloat16 to other types
1420   if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_FLOAT, NPY_NOSCALAR) <
1421       0) {
1422     return false;
1423   }
1424   if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_DOUBLE, NPY_NOSCALAR) <
1425       0) {
1426     return false;
1427   }
1428   if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX64, NPY_NOSCALAR) <
1429       0) {
1430     return false;
1431   }
1432   if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX128,
1433                               NPY_NOSCALAR) < 0) {
1434     return false;
1435   }
1436 
1437   // Safe casts to bfloat16 from other types
1438   if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), npy_bfloat16,
1439                               NPY_NOSCALAR) < 0) {
1440     return false;
1441   }
1442   if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UINT8), npy_bfloat16,
1443                               NPY_NOSCALAR) < 0) {
1444     return false;
1445   }
1446   if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_INT8), npy_bfloat16,
1447                               NPY_NOSCALAR) < 0) {
1448     return false;
1449   }
1450 
1451   bool ok =
1452       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Add>>(numpy.get(),
1453                                                                   "add") &&
1454       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Subtract>>(
1455           numpy.get(), "subtract") &&
1456       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Multiply>>(
1457           numpy.get(), "multiply") &&
1458       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
1459           numpy.get(), "divide") &&
1460       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp>>(
1461           numpy.get(), "logaddexp") &&
1462       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp2>>(
1463           numpy.get(), "logaddexp2") &&
1464       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Negative>>(
1465           numpy.get(), "negative") &&
1466       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Positive>>(
1467           numpy.get(), "positive") &&
1468       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
1469           numpy.get(), "true_divide") &&
1470       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::FloorDivide>>(
1471           numpy.get(), "floor_divide") &&
1472       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Power>>(numpy.get(),
1473                                                                     "power") &&
1474       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
1475           numpy.get(), "remainder") &&
1476       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
1477           numpy.get(), "mod") &&
1478       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmod>>(numpy.get(),
1479                                                                    "fmod") &&
1480       RegisterUFunc<ufuncs::DivmodUFunc>(numpy.get(), "divmod") &&
1481       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
1482                                                                  "absolute") &&
1483       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
1484                                                                  "fabs") &&
1485       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rint>>(numpy.get(),
1486                                                                   "rint") &&
1487       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sign>>(numpy.get(),
1488                                                                   "sign") &&
1489       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Heaviside>>(
1490           numpy.get(), "heaviside") &&
1491       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Conjugate>>(
1492           numpy.get(), "conjugate") &&
1493       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp>>(numpy.get(),
1494                                                                  "exp") &&
1495       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp2>>(numpy.get(),
1496                                                                   "exp2") &&
1497       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Expm1>>(numpy.get(),
1498                                                                    "expm1") &&
1499       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log>>(numpy.get(),
1500                                                                  "log") &&
1501       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log2>>(numpy.get(),
1502                                                                   "log2") &&
1503       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log10>>(numpy.get(),
1504                                                                    "log10") &&
1505       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log1p>>(numpy.get(),
1506                                                                    "log1p") &&
1507       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sqrt>>(numpy.get(),
1508                                                                   "sqrt") &&
1509       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Square>>(numpy.get(),
1510                                                                     "square") &&
1511       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cbrt>>(numpy.get(),
1512                                                                   "cbrt") &&
1513       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Reciprocal>>(
1514           numpy.get(), "reciprocal") &&
1515 
1516       // Trigonometric functions
1517       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sin>>(numpy.get(),
1518                                                                  "sin") &&
1519       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cos>>(numpy.get(),
1520                                                                  "cos") &&
1521       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tan>>(numpy.get(),
1522                                                                  "tan") &&
1523       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsin>>(numpy.get(),
1524                                                                     "arcsin") &&
1525       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccos>>(numpy.get(),
1526                                                                     "arccos") &&
1527       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctan>>(numpy.get(),
1528                                                                     "arctan") &&
1529       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Arctan2>>(
1530           numpy.get(), "arctan2") &&
1531       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Hypot>>(numpy.get(),
1532                                                                     "hypot") &&
1533       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sinh>>(numpy.get(),
1534                                                                   "sinh") &&
1535       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cosh>>(numpy.get(),
1536                                                                   "cosh") &&
1537       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tanh>>(numpy.get(),
1538                                                                   "tanh") &&
1539       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsinh>>(
1540           numpy.get(), "arcsinh") &&
1541       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccosh>>(
1542           numpy.get(), "arccosh") &&
1543       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctanh>>(
1544           numpy.get(), "arctanh") &&
1545       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Deg2rad>>(
1546           numpy.get(), "deg2rad") &&
1547       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rad2deg>>(
1548           numpy.get(), "rad2deg") &&
1549 
1550       // Comparison functions
1551       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Eq>>(numpy.get(),
1552                                                              "equal") &&
1553       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ne>>(numpy.get(),
1554                                                              "not_equal") &&
1555       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Lt>>(numpy.get(),
1556                                                              "less") &&
1557       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Gt>>(numpy.get(),
1558                                                              "greater") &&
1559       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Le>>(numpy.get(),
1560                                                              "less_equal") &&
1561       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ge>>(numpy.get(),
1562                                                              "greater_equal") &&
1563       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Maximum>>(
1564           numpy.get(), "maximum") &&
1565       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Minimum>>(
1566           numpy.get(), "minimum") &&
1567       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmax>>(numpy.get(),
1568                                                                    "fmax") &&
1569       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmin>>(numpy.get(),
1570                                                                    "fmin") &&
1571       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalAnd>>(
1572           numpy.get(), "logical_and") &&
1573       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalOr>>(
1574           numpy.get(), "logical_or") &&
1575       RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalXor>>(
1576           numpy.get(), "logical_xor") &&
1577       RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::LogicalNot>>(
1578           numpy.get(), "logical_not") &&
1579 
1580       // Floating point functions
1581       RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsFinite>>(numpy.get(),
1582                                                                   "isfinite") &&
1583       RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsInf>>(numpy.get(),
1584                                                                "isinf") &&
1585       RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsNan>>(numpy.get(),
1586                                                                "isnan") &&
1587       RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::SignBit>>(numpy.get(),
1588                                                                  "signbit") &&
1589       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::CopySign>>(
1590           numpy.get(), "copysign") &&
1591       RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, bfloat16, ufuncs::Modf>>(
1592           numpy.get(), "modf") &&
1593       RegisterUFunc<BinaryUFunc2<bfloat16, int, bfloat16, ufuncs::Ldexp>>(
1594           numpy.get(), "ldexp") &&
1595       RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, int, ufuncs::Frexp>>(
1596           numpy.get(), "frexp") &&
1597       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Floor>>(numpy.get(),
1598                                                                    "floor") &&
1599       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
1600                                                                   "ceil") &&
1601       RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
1602                                                                    "trunc") &&
1603       RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
1604           numpy.get(), "nextafter");
1605 
1606   return ok;
1607 }
1608 
RegisterNumpyBfloat16()1609 bool RegisterNumpyBfloat16() {
1610   if (npy_bfloat16 != NPY_NOTYPE) {
1611     // Already initialized.
1612     return true;
1613   }
1614   if (!Initialize()) {
1615     if (!PyErr_Occurred()) {
1616       PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
1617     }
1618     PyErr_Print();
1619     return false;
1620   }
1621   return true;
1622 }
1623 
Bfloat16Dtype()1624 PyObject* Bfloat16Dtype() {
1625   return reinterpret_cast<PyObject*>(bfloat16_type_ptr);
1626 }
1627 
Bfloat16NumpyType()1628 int Bfloat16NumpyType() { return npy_bfloat16; }
1629 
1630 }  // namespace tensorflow
1631