/* Copyright 2022 The ml_dtypes Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef ML_DTYPES_CUSTOM_FLOAT_H_ #define ML_DTYPES_CUSTOM_FLOAT_H_ // Must be included first // clang-format off #include "_src/numpy.h" // NOLINT // clang-format on // Support utilities for adding custom floating-point dtypes to TensorFlow, // such as bfloat16, and float8_*. #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT // Place `` before to avoid a build failure in macOS. #include #include "Eigen/Core" #include "_src/common.h" // NOLINT #include "_src/ufuncs.h" // NOLINT #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build // Possible this has to do with numpy.h being included before // system headers and in bfloat16.{cc,h}? namespace ml_dtypes { template struct CustomFloatType { static int Dtype() { return npy_type; } // Registered numpy type ID. Global variable populated by the registration // code. Protected by the GIL. static int npy_type; // Pointer to the python type object we are using. This is either a pointer // to type, if we choose to register it, or to the python type // registered by another system into NumPy. static PyObject* type_ptr; static PyNumberMethods number_methods; static PyArray_ArrFuncs arr_funcs; static PyArray_Descr npy_descr; }; template int CustomFloatType::npy_type = NPY_NOTYPE; template PyObject* CustomFloatType::type_ptr = nullptr; // Representation of a Python custom float object. template struct PyCustomFloat { PyObject_HEAD; // Python object header T value; }; // Returns true if 'object' is a PyCustomFloat. template bool PyCustomFloat_Check(PyObject* object) { return PyObject_IsInstance(object, TypeDescriptor::type_ptr); } // Extracts the value of a PyCustomFloat object. template T PyCustomFloat_CustomFloat(PyObject* object) { return reinterpret_cast*>(object)->value; } // Constructs a PyCustomFloat object from PyCustomFloat::T. template Safe_PyObjectPtr PyCustomFloat_FromT(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); PyCustomFloat* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; } return ref; } // Converts a Python object to a reduced float value. Returns true on success, // returns false and reports a Python error on failure. template bool CastToCustomFloat(PyObject* arg, T* output) { if (PyCustomFloat_Check(arg)) { *output = PyCustomFloat_CustomFloat(arg); return true; } if (PyFloat_Check(arg)) { double d = PyFloat_AsDouble(arg); if (PyErr_Occurred()) { return false; } // TODO(phawkins): check for overflow *output = T(d); return true; } if (PyLong_Check(arg)) { long l = PyLong_AsLong(arg); // NOLINT if (PyErr_Occurred()) { return false; } // TODO(phawkins): check for overflow *output = T(static_cast(l)); return true; } if (PyArray_IsScalar(arg, Half)) { Eigen::half f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, Float)) { float f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, Double)) { double f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsScalar(arg, LongDouble)) { long double f; PyArray_ScalarAsCtype(arg, &f); *output = T(f); return true; } if (PyArray_IsZeroDim(arg)) { Safe_PyObjectPtr ref; PyArrayObject* arr = reinterpret_cast(arg); if (PyArray_TYPE(arr) != TypeDescriptor::Dtype()) { ref = make_safe(PyArray_Cast(arr, TypeDescriptor::Dtype())); if (PyErr_Occurred()) { return false; } arg = ref.get(); arr = reinterpret_cast(arg); } *output = *reinterpret_cast(PyArray_DATA(arr)); return true; } return false; } template bool SafeCastToCustomFloat(PyObject* arg, T* output) { if (PyCustomFloat_Check(arg)) { *output = PyCustomFloat_CustomFloat(arg); return true; } return false; } // Converts a PyReduceFloat into a PyFloat. template PyObject* PyCustomFloat_Float(PyObject* self) { T x = PyCustomFloat_CustomFloat(self); return PyFloat_FromDouble(static_cast(static_cast(x))); } // Converts a PyReduceFloat into a PyInt. template PyObject* PyCustomFloat_Int(PyObject* self) { T x = PyCustomFloat_CustomFloat(self); long y = static_cast(static_cast(x)); // NOLINT return PyLong_FromLong(y); } // Negates a PyCustomFloat. template PyObject* PyCustomFloat_Negative(PyObject* self) { T x = PyCustomFloat_CustomFloat(self); return PyCustomFloat_FromT(-x).release(); } template PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x + y).release(); } return PyArray_Type.tp_as_number->nb_add(a, b); } template PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x - y).release(); } return PyArray_Type.tp_as_number->nb_subtract(a, b); } template PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x * y).release(); } return PyArray_Type.tp_as_number->nb_multiply(a, b); } template PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) { T x, y; if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x / y).release(); } return PyArray_Type.tp_as_number->nb_true_divide(a, b); } // Python number methods for PyCustomFloat objects. template PyNumberMethods CustomFloatType::number_methods = { PyCustomFloat_Add, // nb_add PyCustomFloat_Subtract, // nb_subtract PyCustomFloat_Multiply, // nb_multiply nullptr, // nb_remainder nullptr, // nb_divmod nullptr, // nb_power PyCustomFloat_Negative, // nb_negative nullptr, // nb_positive nullptr, // nb_absolute nullptr, // nb_nonzero nullptr, // nb_invert nullptr, // nb_lshift nullptr, // nb_rshift nullptr, // nb_and nullptr, // nb_xor nullptr, // nb_or PyCustomFloat_Int, // nb_int nullptr, // reserved PyCustomFloat_Float, // nb_float nullptr, // nb_inplace_add nullptr, // nb_inplace_subtract nullptr, // nb_inplace_multiply nullptr, // nb_inplace_remainder nullptr, // nb_inplace_power nullptr, // nb_inplace_lshift nullptr, // nb_inplace_rshift nullptr, // nb_inplace_and nullptr, // nb_inplace_xor nullptr, // nb_inplace_or nullptr, // nb_floor_divide PyCustomFloat_TrueDivide, // nb_true_divide nullptr, // nb_inplace_floor_divide nullptr, // nb_inplace_true_divide nullptr, // nb_index }; // Constructs a new PyCustomFloat. template PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args, PyObject* kwds) { if (kwds && PyDict_Size(kwds)) { PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments"); return nullptr; } Py_ssize_t size = PyTuple_Size(args); if (size != 1) { PyErr_Format(PyExc_TypeError, "expected number as argument to %s constructor", TypeDescriptor::kTypeName); return nullptr; } PyObject* arg = PyTuple_GetItem(args, 0); T value; if (PyCustomFloat_Check(arg)) { Py_INCREF(arg); return arg; } else if (CastToCustomFloat(arg, &value)) { return PyCustomFloat_FromT(value).release(); } else if (PyArray_Check(arg)) { PyArrayObject* arr = reinterpret_cast(arg); if (PyArray_TYPE(arr) != TypeDescriptor::Dtype()) { return PyArray_Cast(arr, TypeDescriptor::Dtype()); } else { Py_INCREF(arg); return arg; } } else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) { // Parse float from string, then cast to T. PyObject* f = PyFloat_FromString(arg); if (CastToCustomFloat(f, &value)) { return PyCustomFloat_FromT(value).release(); } } PyErr_Format(PyExc_TypeError, "expected number, got %s", Py_TYPE(arg)->tp_name); return nullptr; } // Comparisons on PyCustomFloats. template PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!SafeCastToCustomFloat(a, &x) || !SafeCastToCustomFloat(b, &y)) { return PyGenericArrType_Type.tp_richcompare(a, b, op); } bool result; switch (op) { case Py_LT: result = x < y; break; case Py_LE: result = x <= y; break; case Py_EQ: result = x == y; break; case Py_NE: result = x != y; break; case Py_GT: result = x > y; break; case Py_GE: result = x >= y; break; default: PyErr_SetString(PyExc_ValueError, "Invalid op type"); return nullptr; } return PyBool_FromLong(result); } // Implementation of repr() for PyCustomFloat. template PyObject* PyCustomFloat_Repr(PyObject* self) { T x = reinterpret_cast*>(self)->value; float f = static_cast(x); std::ostringstream s; s << (std::isnan(f) ? std::abs(f) : f); return PyUnicode_FromString(s.str().c_str()); } // Implementation of str() for PyCustomFloat. template PyObject* PyCustomFloat_Str(PyObject* self) { T x = reinterpret_cast*>(self)->value; float f = static_cast(x); std::ostringstream s; s << (std::isnan(f) ? std::abs(f) : f); return PyUnicode_FromString(s.str().c_str()); } // _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to // handle the two possibilities. // NOLINTNEXTLINE(clang-diagnostic-unused-function) inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double), PyObject* self, double value) { return hash_double(self, value); } // NOLINTNEXTLINE(clang-diagnostic-unused-function) inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self, double value) { return hash_double(value); } // Hash function for PyCustomFloat. template Py_hash_t PyCustomFloat_Hash(PyObject* self) { T x = reinterpret_cast*>(self)->value; return HashImpl(&_Py_HashDouble, self, static_cast(x)); } // Numpy support template PyArray_ArrFuncs CustomFloatType::arr_funcs; template PyArray_Descr CustomFloatType::npy_descr = { PyObject_HEAD_INIT(nullptr) /*typeobj=*/nullptr, // Filled in later /*kind=*/TypeDescriptor::kNpyDescrKind, /*type=*/TypeDescriptor::kNpyDescrType, /*byteorder=*/TypeDescriptor::kNpyDescrByteorder, /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM, /*type_num=*/0, /*elsize=*/sizeof(T), /*alignment=*/alignof(T), /*subarray=*/nullptr, /*fields=*/nullptr, /*names=*/nullptr, /*f=*/&CustomFloatType::arr_funcs, /*metadata=*/nullptr, /*c_metadata=*/nullptr, /*hash=*/-1, // -1 means "not computed yet". }; // Implementations of NumPy array methods. template PyObject* NPyCustomFloat_GetItem(void* data, void* arr) { T x; memcpy(&x, data, sizeof(T)); return PyFloat_FromDouble(static_cast(x)); } template int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToCustomFloat(item, &x)) { PyErr_Format(PyExc_TypeError, "expected number, got %s", Py_TYPE(item)->tp_name); return -1; } memcpy(data, &x, sizeof(T)); return 0; } inline void ByteSwap16(void* value) { char* p = reinterpret_cast(value); std::swap(p[0], p[1]); } template int NPyCustomFloat_Compare(const void* a, const void* b, void* arr) { T x; memcpy(&x, a, sizeof(T)); T y; memcpy(&y, b, sizeof(T)); float fy(y); float fx(x); if (fx < fy) { return -1; } if (fy < fx) { return 1; } // NaNs sort to the end. if (!Eigen::numext::isnan(fx) && Eigen::numext::isnan(fy)) { return -1; } if (Eigen::numext::isnan(fx) && !Eigen::numext::isnan(fy)) { return 1; } return 0; } template void NPyCustomFloat_CopySwapN(void* dstv, npy_intp dstride, void* srcv, npy_intp sstride, npy_intp n, int swap, void* arr) { static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Not supported"); char* dst = reinterpret_cast(dstv); char* src = reinterpret_cast(srcv); if (!src) { return; } if (swap && sizeof(T) == sizeof(int16_t)) { for (npy_intp i = 0; i < n; i++) { char* r = dst + dstride * i; memcpy(r, src + sstride * i, sizeof(T)); ByteSwap16(r); } } else if (dstride == sizeof(T) && sstride == sizeof(T)) { memcpy(dst, src, n * sizeof(T)); } else { for (npy_intp i = 0; i < n; i++) { memcpy(dst + dstride * i, src + sstride * i, sizeof(T)); } } } template void NPyCustomFloat_CopySwap(void* dst, void* src, int swap, void* arr) { if (!src) { return; } memcpy(dst, src, sizeof(T)); static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Not supported"); if (swap && sizeof(T) == sizeof(int16_t)) { ByteSwap16(dst); } } template npy_bool NPyCustomFloat_NonZero(void* data, void* arr) { T x; memcpy(&x, data, sizeof(x)); return x != static_cast(0); } template int NPyCustomFloat_Fill(void* buffer_raw, npy_intp length, void* ignored) { T* const buffer = reinterpret_cast(buffer_raw); const float start(buffer[0]); const float delta = static_cast(buffer[1]) - start; for (npy_intp i = 2; i < length; ++i) { buffer[i] = static_cast(start + i * delta); } return 0; } template void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2, void* op, npy_intp n, void* arr) { char* c1 = reinterpret_cast(ip1); char* c2 = reinterpret_cast(ip2); float acc = 0.0f; for (npy_intp i = 0; i < n; ++i) { T* const b1 = reinterpret_cast(c1); T* const b2 = reinterpret_cast(c2); acc += static_cast(*b1) * static_cast(*b2); c1 += is1; c2 += is2; } T* out = reinterpret_cast(op); *out = static_cast(acc); } template int NPyCustomFloat_CompareFunc(const void* v1, const void* v2, void* arr) { T b1 = *reinterpret_cast(v1); T b2 = *reinterpret_cast(v2); if (b1 < b2) { return -1; } if (b1 > b2) { return 1; } return 0; } template int NPyCustomFloat_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) { const T* bdata = reinterpret_cast(data); // Start with a max_val of NaN, this results in the first iteration preferring // bdata[0]. float max_val = std::numeric_limits::quiet_NaN(); for (npy_intp i = 0; i < n; ++i) { // This condition is chosen so that NaNs are always considered "max". if (!(static_cast(bdata[i]) <= max_val)) { max_val = static_cast(bdata[i]); *max_ind = i; // NumPy stops at the first NaN. if (Eigen::numext::isnan(max_val)) { break; } } } return 0; } template int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) { const T* bdata = reinterpret_cast(data); float min_val = std::numeric_limits::quiet_NaN(); // Start with a min_val of NaN, this results in the first iteration preferring // bdata[0]. for (npy_intp i = 0; i < n; ++i) { // This condition is chosen so that NaNs are always considered "min". if (!(static_cast(bdata[i]) >= min_val)) { min_val = static_cast(bdata[i]); *min_ind = i; // NumPy stops at the first NaN. if (Eigen::numext::isnan(min_val)) { break; } } } return 0; } template float CastToFloat(T value) { return static_cast(value); } template float CastToFloat(std::complex value) { return CastToFloat(value.real()); } // Performs a NumPy array cast from type 'From' to 'To'. template void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr, void* toarr) { const auto* from = reinterpret_cast::T*>(from_void); auto* to = reinterpret_cast::T*>(to_void); for (npy_intp i = 0; i < n; ++i) { to[i] = static_cast::T>( static_cast(CastToFloat(from[i]))); } } // Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type' // is the NumPy type corresponding to 'OtherT'. template bool RegisterCustomFloatCast(int numpy_type = TypeDescriptor::Dtype()) { PyArray_Descr* descr = PyArray_DescrFromType(numpy_type); if (PyArray_RegisterCastFunc(descr, TypeDescriptor::Dtype(), NPyCast) < 0) { return false; } if (PyArray_RegisterCastFunc(&CustomFloatType::npy_descr, numpy_type, NPyCast) < 0) { return false; } return true; } template bool RegisterFloatCasts() { if (!RegisterCustomFloatCast(NPY_HALF)) { return false; } if (!RegisterCustomFloatCast(NPY_FLOAT)) { return false; } if (!RegisterCustomFloatCast(NPY_DOUBLE)) { return false; } if (!RegisterCustomFloatCast(NPY_LONGDOUBLE)) { return false; } if (!RegisterCustomFloatCast(NPY_BOOL)) { return false; } if (!RegisterCustomFloatCast(NPY_UBYTE)) { return false; } if (!RegisterCustomFloatCast(NPY_USHORT)) { // NOLINT return false; } if (!RegisterCustomFloatCast(NPY_UINT)) { return false; } if (!RegisterCustomFloatCast(NPY_ULONG)) { // NOLINT return false; } if (!RegisterCustomFloatCast( // NOLINT NPY_ULONGLONG)) { return false; } if (!RegisterCustomFloatCast(NPY_BYTE)) { return false; } if (!RegisterCustomFloatCast(NPY_SHORT)) { // NOLINT return false; } if (!RegisterCustomFloatCast(NPY_INT)) { return false; } if (!RegisterCustomFloatCast(NPY_LONG)) { // NOLINT return false; } if (!RegisterCustomFloatCast(NPY_LONGLONG)) { // NOLINT return false; } // Following the numpy convention. imag part is dropped when converting to // float. if (!RegisterCustomFloatCast>(NPY_CFLOAT)) { return false; } if (!RegisterCustomFloatCast>(NPY_CDOUBLE)) { return false; } if (!RegisterCustomFloatCast>(NPY_CLONGDOUBLE)) { return false; } // Safe casts from T to other types if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_FLOAT, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_DOUBLE, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_LONGDOUBLE, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_CFLOAT, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_CDOUBLE, NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_CLONGDOUBLE, NPY_NOSCALAR) < 0) { return false; } // Safe casts to T from other types if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE), TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { return false; } if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE), TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { return false; } return true; } template bool RegisterFloatUFuncs(PyObject* numpy) { bool ok = RegisterUFunc>, T>(numpy, "add") && RegisterUFunc>, T>(numpy, "subtract") && RegisterUFunc>, T>(numpy, "multiply") && RegisterUFunc>, T>(numpy, "divide") && RegisterUFunc>, T>(numpy, "logaddexp") && RegisterUFunc>, T>( numpy, "logaddexp2") && RegisterUFunc>, T>(numpy, "negative") && RegisterUFunc>, T>(numpy, "positive") && RegisterUFunc>, T>( numpy, "true_divide") && RegisterUFunc>, T>( numpy, "floor_divide") && RegisterUFunc>, T>(numpy, "power") && RegisterUFunc>, T>(numpy, "remainder") && RegisterUFunc>, T>(numpy, "mod") && RegisterUFunc>, T>(numpy, "fmod") && RegisterUFunc, T>(numpy, "divmod") && RegisterUFunc>, T>(numpy, "absolute") && RegisterUFunc>, T>(numpy, "fabs") && RegisterUFunc>, T>(numpy, "rint") && RegisterUFunc>, T>(numpy, "sign") && RegisterUFunc>, T>(numpy, "heaviside") && RegisterUFunc>, T>(numpy, "conjugate") && RegisterUFunc>, T>(numpy, "exp") && RegisterUFunc>, T>(numpy, "exp2") && RegisterUFunc>, T>(numpy, "expm1") && RegisterUFunc>, T>(numpy, "log") && RegisterUFunc>, T>(numpy, "log2") && RegisterUFunc>, T>(numpy, "log10") && RegisterUFunc>, T>(numpy, "log1p") && RegisterUFunc>, T>(numpy, "sqrt") && RegisterUFunc>, T>(numpy, "square") && RegisterUFunc>, T>(numpy, "cbrt") && RegisterUFunc>, T>(numpy, "reciprocal") && // Trigonometric functions RegisterUFunc>, T>(numpy, "sin") && RegisterUFunc>, T>(numpy, "cos") && RegisterUFunc>, T>(numpy, "tan") && RegisterUFunc>, T>(numpy, "arcsin") && RegisterUFunc>, T>(numpy, "arccos") && RegisterUFunc>, T>(numpy, "arctan") && RegisterUFunc>, T>(numpy, "arctan2") && RegisterUFunc>, T>(numpy, "hypot") && RegisterUFunc>, T>(numpy, "sinh") && RegisterUFunc>, T>(numpy, "cosh") && RegisterUFunc>, T>(numpy, "tanh") && RegisterUFunc>, T>(numpy, "arcsinh") && RegisterUFunc>, T>(numpy, "arccosh") && RegisterUFunc>, T>(numpy, "arctanh") && RegisterUFunc>, T>(numpy, "deg2rad") && RegisterUFunc>, T>(numpy, "rad2deg") && // Comparison functions RegisterUFunc>, T>(numpy, "equal") && RegisterUFunc>, T>(numpy, "not_equal") && RegisterUFunc>, T>(numpy, "less") && RegisterUFunc>, T>(numpy, "greater") && RegisterUFunc>, T>(numpy, "less_equal") && RegisterUFunc>, T>(numpy, "greater_equal") && RegisterUFunc>, T>(numpy, "maximum") && RegisterUFunc>, T>(numpy, "minimum") && RegisterUFunc>, T>(numpy, "fmax") && RegisterUFunc>, T>(numpy, "fmin") && RegisterUFunc>, T>( numpy, "logical_and") && RegisterUFunc>, T>( numpy, "logical_or") && RegisterUFunc>, T>( numpy, "logical_xor") && RegisterUFunc>, T>( numpy, "logical_not") && // Floating point functions RegisterUFunc>, T>(numpy, "isfinite") && RegisterUFunc>, T>(numpy, "isinf") && RegisterUFunc>, T>(numpy, "isnan") && RegisterUFunc>, T>(numpy, "signbit") && RegisterUFunc>, T>(numpy, "copysign") && RegisterUFunc>, T>(numpy, "modf") && RegisterUFunc>, T>(numpy, "ldexp") && RegisterUFunc>, T>(numpy, "frexp") && RegisterUFunc>, T>(numpy, "floor") && RegisterUFunc>, T>(numpy, "ceil") && RegisterUFunc>, T>(numpy, "trunc") && RegisterUFunc>, T>(numpy, "nextafter") && RegisterUFunc>, T>(numpy, "spacing"); return ok; } // Returns true if the numpy type for T is successfully registered, including if // it was already registered (e.g. by a different library). If // `already_registered` is non-null, it's set to true if the type was already // registered and false otherwise. template bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) { if (already_registered != nullptr) { *already_registered = false; } // If another module (presumably either TF or JAX) has registered a bfloat16 // type, use it. We don't want two bfloat16 types if we can avoid it since it // leads to confusion if we have two different types with the same name. This // assumes that the other module has a sufficiently complete bfloat16 // implementation. The only known NumPy bfloat16 extension at the time of // writing is this one (distributed in TF and JAX). // TODO(phawkins): distribute the bfloat16 extension as its own pip package, // so we can unambiguously refer to a single canonical definition of bfloat16. int typenum = PyArray_TypeNumFromName(const_cast(TypeDescriptor::kTypeName)); if (typenum != NPY_NOTYPE) { PyArray_Descr* descr = PyArray_DescrFromType(typenum); // The test for an argmax function here is to verify that the // bfloat16 implementation is sufficiently new, and, say, not from // an older version of TF or JAX. if (descr && descr->f && descr->f->argmax) { TypeDescriptor::npy_type = typenum; TypeDescriptor::type_ptr = reinterpret_cast(descr->typeobj); if (already_registered != nullptr) { *already_registered = true; } return true; } } // It's important that we heap-allocate our type. This is because tp_name // is not a fully-qualified name for a heap-allocated type, and // PyArray_TypeNumFromName() (above) looks at the tp_name field to find // types. Existing implementations in JAX and TensorFlow look for "bfloat16", // not "ml_dtypes.bfloat16" when searching for an implementation. Safe_PyObjectPtr name = make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); Safe_PyObjectPtr qualname = make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); PyHeapTypeObject* heap_type = reinterpret_cast( PyType_Type.tp_alloc(&PyType_Type, 0)); if (!heap_type) { return false; } // Caution: we must not call any functions that might invoke the GC until // PyType_Ready() is called. Otherwise the GC might see a half-constructed // type object. heap_type->ht_name = name.release(); heap_type->ht_qualname = qualname.release(); PyTypeObject* type = &heap_type->ht_type; type->tp_name = TypeDescriptor::kTypeName; type->tp_basicsize = sizeof(PyCustomFloat); type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; type->tp_base = &PyGenericArrType_Type; type->tp_new = PyCustomFloat_New; type->tp_repr = PyCustomFloat_Repr; type->tp_hash = PyCustomFloat_Hash; type->tp_str = PyCustomFloat_Str; type->tp_doc = const_cast(TypeDescriptor::kTpDoc); type->tp_richcompare = PyCustomFloat_RichCompare; type->tp_as_number = &CustomFloatType::number_methods; if (PyType_Ready(type) < 0) { return false; } TypeDescriptor::type_ptr = reinterpret_cast(type); Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes")); if (!module) { return false; } if (PyObject_SetAttrString(TypeDescriptor::type_ptr, "__module__", module.get()) < 0) { return false; } // Initializes the NumPy descriptor. PyArray_ArrFuncs& arr_funcs = CustomFloatType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomFloat_GetItem; arr_funcs.setitem = NPyCustomFloat_SetItem; arr_funcs.compare = NPyCustomFloat_Compare; arr_funcs.copyswapn = NPyCustomFloat_CopySwapN; arr_funcs.copyswap = NPyCustomFloat_CopySwap; arr_funcs.nonzero = NPyCustomFloat_NonZero; arr_funcs.fill = NPyCustomFloat_Fill; arr_funcs.dotfunc = NPyCustomFloat_DotFunc; arr_funcs.compare = NPyCustomFloat_CompareFunc; arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc; arr_funcs.argmin = NPyCustomFloat_ArgMinFunc; #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) Py_TYPE(&CustomFloatType::npy_descr) = &PyArrayDescr_Type; #else Py_SET_TYPE(&CustomFloatType::npy_descr, &PyArrayDescr_Type); #endif TypeDescriptor::npy_descr.typeobj = type; TypeDescriptor::npy_type = PyArray_RegisterDataType(&CustomFloatType::npy_descr); if (TypeDescriptor::Dtype() < 0) { return false; } Safe_PyObjectPtr typeDict_obj = make_safe(PyObject_GetAttrString(numpy, "sctypeDict")); if (!typeDict_obj) return false; // Add the type object to `numpy.typeDict`: that makes // `numpy.dtype(type_name)` work. if (PyDict_SetItemString(typeDict_obj.get(), TypeDescriptor::kTypeName, TypeDescriptor::type_ptr) < 0) { return false; } // Support dtype(type_name) if (PyObject_SetAttrString(TypeDescriptor::type_ptr, "dtype", reinterpret_cast( &CustomFloatType::npy_descr)) < 0) { return false; } return RegisterFloatCasts() && RegisterFloatUFuncs(numpy); } } // namespace ml_dtypes #endif // ML_DTYPES_CUSTOM_FLOAT_H_