Intelegentny_Pszczelarz/.venv/Lib/site-packages/ml_dtypes/_src/custom_float.h
2023-06-19 00:49:18 +02:00

998 lines
35 KiB
C++

/* Copyright 2022 The ml_dtypes Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef ML_DTYPES_CUSTOM_FLOAT_H_
#define ML_DTYPES_CUSTOM_FLOAT_H_
// Must be included first
// clang-format off
#include "_src/numpy.h" // NOLINT
// clang-format on
// Support utilities for adding custom floating-point dtypes to TensorFlow,
// such as bfloat16, and float8_*.
#include <array> // NOLINT
#include <cmath> // NOLINT
#include <limits> // NOLINT
#include <locale> // NOLINT
#include <memory> // NOLINT
#include <sstream> // NOLINT
#include <vector> // NOLINT
// Place `<locale>` before <Python.h> to avoid a build failure in macOS.
#include <Python.h>
#include "Eigen/Core"
#include "_src/common.h" // NOLINT
#include "_src/ufuncs.h" // NOLINT
#undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build
// Possible this has to do with numpy.h being included before
// system headers and in bfloat16.{cc,h}?
namespace ml_dtypes {
// Per-type static state shared by the Python scalar type and the NumPy dtype
// of a custom floating-point type `T`. All members are static, so one copy
// exists per instantiated `T`.
template <typename T>
struct CustomFloatType {
  // NumPy type number for `T`. Starts as NPY_NOTYPE and is filled in by the
  // registration code. Protected by the GIL.
  static int npy_type;
  // The Python type object used for scalars of `T`: either the type created
  // by this module, or one that another system already registered with NumPy.
  static PyObject* type_ptr;
  // Slot table installed as `tp_as_number` on the Python scalar type.
  static PyNumberMethods number_methods;
  // NumPy per-dtype function table referenced by `npy_descr`.
  static PyArray_ArrFuncs arr_funcs;
  // NumPy descriptor object describing `T`.
  static PyArray_Descr npy_descr;

  // Returns the NumPy type number currently stored in `npy_type`.
  static int Dtype() { return npy_type; }
};

template <typename T>
PyObject* CustomFloatType<T>::type_ptr = nullptr;

template <typename T>
int CustomFloatType<T>::npy_type = NPY_NOTYPE;
// In-memory representation of a Python scalar object wrapping a custom float.
// The object header must come first so that a PyObject* can be reinterpreted
// as a PyCustomFloat<T>* (as the accessor functions in this file do).
template <typename T>
struct PyCustomFloat {
PyObject_HEAD; // Python object header
T value;  // The wrapped custom floating-point value.
};
// Returns true if 'object' is an instance of the Python type registered for T
// (TypeDescriptor<T>::type_ptr).
//
// PyObject_IsInstance() returns -1 when the check itself fails. That value
// must not be converted to `true`: callers reinterpret the object as a
// PyCustomFloat<T> on a positive answer. A failed check is therefore reported
// as "not an instance", and the pending exception is cleared so that it does
// not leak into unrelated C-API calls made by the caller afterwards.
template <typename T>
bool PyCustomFloat_Check(PyObject* object) {
  const int is_instance =
      PyObject_IsInstance(object, TypeDescriptor<T>::type_ptr);
  if (is_instance < 0) {
    PyErr_Clear();
    return false;
  }
  return is_instance != 0;
}
// Reads the wrapped T out of 'object'. No type check is performed here: the
// pointer is reinterpreted as a PyCustomFloat<T>, so the caller must already
// know that 'object' has that layout.
template <typename T>
T PyCustomFloat_CustomFloat(PyObject* object) {
  const PyCustomFloat<T>* const wrapper =
      reinterpret_cast<PyCustomFloat<T>*>(object);
  return wrapper->value;
}
// Allocates a new Python scalar object of the type registered for T and
// stores 'x' in it. The returned smart pointer is empty if tp_alloc failed.
template <typename T>
Safe_PyObjectPtr PyCustomFloat_FromT(T x) {
  PyTypeObject* const scalar_type =
      reinterpret_cast<PyTypeObject*>(TypeDescriptor<T>::type_ptr);
  Safe_PyObjectPtr holder = make_safe(scalar_type->tp_alloc(scalar_type, 0));
  if (!holder) {
    // Allocation failed; hand the empty pointer back to the caller.
    return holder;
  }
  reinterpret_cast<PyCustomFloat<T>*>(holder.get())->value = x;
  return holder;
}
// Converts a Python object to a custom float value. Returns true and writes
// '*output' on success. Returns false when 'arg' is of no supported kind or
// when a conversion call failed; a Python error is pending only in the latter
// case, so callers that need an exception must set one themselves.
//
// Supported inputs, tried in this order: an instance of T's Python type, a
// Python float, a Python int, NumPy half/float/double/longdouble scalars, and
// zero-dimensional NumPy arrays (cast to T's dtype first if necessary).
template <typename T>
bool CastToCustomFloat(PyObject* arg, T* output) {
  if (PyCustomFloat_Check<T>(arg)) {
    *output = PyCustomFloat_CustomFloat<T>(arg);
    return true;
  }

  if (PyFloat_Check(arg)) {
    const double as_double = PyFloat_AsDouble(arg);
    if (PyErr_Occurred()) {
      return false;
    }
    // TODO(phawkins): check for overflow
    *output = T(as_double);
    return true;
  }

  if (PyLong_Check(arg)) {
    const long as_long = PyLong_AsLong(arg);  // NOLINT
    if (PyErr_Occurred()) {
      return false;
    }
    // TODO(phawkins): check for overflow
    *output = T(static_cast<float>(as_long));
    return true;
  }

  // NumPy scalar types: copy the C value out, then convert.
  if (PyArray_IsScalar(arg, Half)) {
    Eigen::half half_value;
    PyArray_ScalarAsCtype(arg, &half_value);
    *output = T(half_value);
    return true;
  }
  if (PyArray_IsScalar(arg, Float)) {
    float float_value;
    PyArray_ScalarAsCtype(arg, &float_value);
    *output = T(float_value);
    return true;
  }
  if (PyArray_IsScalar(arg, Double)) {
    double double_value;
    PyArray_ScalarAsCtype(arg, &double_value);
    *output = T(double_value);
    return true;
  }
  if (PyArray_IsScalar(arg, LongDouble)) {
    long double long_double_value;
    PyArray_ScalarAsCtype(arg, &long_double_value);
    *output = T(long_double_value);
    return true;
  }

  if (PyArray_IsZeroDim(arg)) {
    PyArrayObject* array = reinterpret_cast<PyArrayObject*>(arg);
    // Owns the converted array (if any) until the element has been read.
    Safe_PyObjectPtr converted;
    if (PyArray_TYPE(array) != TypeDescriptor<T>::Dtype()) {
      converted = make_safe(PyArray_Cast(array, TypeDescriptor<T>::Dtype()));
      if (PyErr_Occurred()) {
        return false;
      }
      array = reinterpret_cast<PyArrayObject*>(converted.get());
    }
    *output = *reinterpret_cast<T*>(PyArray_DATA(array));
    return true;
  }

  return false;
}
// Strict variant of CastToCustomFloat: succeeds only when 'arg' is already an
// instance of T's Python type. No numeric conversions are attempted.
template <typename T>
bool SafeCastToCustomFloat(PyObject* arg, T* output) {
  if (!PyCustomFloat_Check<T>(arg)) {
    return false;
  }
  *output = PyCustomFloat_CustomFloat<T>(arg);
  return true;
}
// Converts a PyCustomFloat into a Python float (nb_float slot). The value is
// widened to float first and then to double.
template <typename T>
PyObject* PyCustomFloat_Float(PyObject* self) {
  const float as_float = static_cast<float>(PyCustomFloat_CustomFloat<T>(self));
  return PyFloat_FromDouble(static_cast<double>(as_float));
}
// Converts a PyCustomFloat into a Python int (nb_int slot), truncating toward
// zero.
//
// The conversion goes through PyLong_FromDouble() rather than a C++ cast to
// `long`: casting NaN, +/-infinity or a value outside the range of `long` to
// an integer type is undefined behaviour and used to produce a garbage
// integer silently. PyLong_FromDouble() instead returns nullptr with
// ValueError (NaN) or OverflowError (infinity) set, matching Python's
// int(float), and produces an exact arbitrary-precision result for large
// finite values.
template <typename T>
PyObject* PyCustomFloat_Int(PyObject* self) {
  T x = PyCustomFloat_CustomFloat<T>(self);
  return PyLong_FromDouble(static_cast<double>(static_cast<float>(x)));
}
// Unary minus for PyCustomFloat (nb_negative slot). Returns a new scalar
// object, or nullptr if allocation failed.
template <typename T>
PyObject* PyCustomFloat_Negative(PyObject* self) {
  const T operand = PyCustomFloat_CustomFloat<T>(self);
  Safe_PyObjectPtr negated = PyCustomFloat_FromT<T>(-operand);
  return negated.release();
}
// Addition (nb_add slot). When both operands are scalars of T the sum is
// computed directly in T; any other operand combination is forwarded to the
// NumPy ndarray addition slot.
template <typename T>
PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) {
  T lhs, rhs;
  const bool both_custom =
      SafeCastToCustomFloat<T>(a, &lhs) && SafeCastToCustomFloat<T>(b, &rhs);
  if (!both_custom) {
    return PyArray_Type.tp_as_number->nb_add(a, b);
  }
  return PyCustomFloat_FromT<T>(lhs + rhs).release();
}
// Subtraction (nb_subtract slot). When both operands are scalars of T the
// difference is computed directly in T; any other operand combination is
// forwarded to the NumPy ndarray subtraction slot.
template <typename T>
PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) {
  T lhs, rhs;
  const bool both_custom =
      SafeCastToCustomFloat<T>(a, &lhs) && SafeCastToCustomFloat<T>(b, &rhs);
  if (!both_custom) {
    return PyArray_Type.tp_as_number->nb_subtract(a, b);
  }
  return PyCustomFloat_FromT<T>(lhs - rhs).release();
}
// Multiplication (nb_multiply slot). When both operands are scalars of T the
// product is computed directly in T; any other operand combination is
// forwarded to the NumPy ndarray multiplication slot.
template <typename T>
PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) {
  T lhs, rhs;
  const bool both_custom =
      SafeCastToCustomFloat<T>(a, &lhs) && SafeCastToCustomFloat<T>(b, &rhs);
  if (!both_custom) {
    return PyArray_Type.tp_as_number->nb_multiply(a, b);
  }
  return PyCustomFloat_FromT<T>(lhs * rhs).release();
}
// True division (nb_true_divide slot). When both operands are scalars of T
// the quotient is computed directly in T; any other operand combination is
// forwarded to the NumPy ndarray true-division slot.
template <typename T>
PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
  T lhs, rhs;
  const bool both_custom =
      SafeCastToCustomFloat<T>(a, &lhs) && SafeCastToCustomFloat<T>(b, &rhs);
  if (!both_custom) {
    return PyArray_Type.tp_as_number->nb_true_divide(a, b);
  }
  return PyCustomFloat_FromT<T>(lhs / rhs).release();
}
// Python number methods for PyCustomFloat objects.
//
// The initializer is positional: each entry fills the PyNumberMethods slot
// named in its trailing comment, so entries must not be reordered. Only
// add/subtract/multiply/true-divide, unary negation and the int/float
// conversions are implemented here; every other listed slot is left null.
template <typename T>
PyNumberMethods CustomFloatType<T>::number_methods = {
PyCustomFloat_Add<T>, // nb_add
PyCustomFloat_Subtract<T>, // nb_subtract
PyCustomFloat_Multiply<T>, // nb_multiply
nullptr, // nb_remainder
nullptr, // nb_divmod
nullptr, // nb_power
PyCustomFloat_Negative<T>, // nb_negative
nullptr, // nb_positive
nullptr, // nb_absolute
nullptr, // nb_nonzero (named nb_bool in Python 3)
nullptr, // nb_invert
nullptr, // nb_lshift
nullptr, // nb_rshift
nullptr, // nb_and
nullptr, // nb_xor
nullptr, // nb_or
PyCustomFloat_Int<T>, // nb_int
nullptr, // reserved
PyCustomFloat_Float<T>, // nb_float
nullptr, // nb_inplace_add
nullptr, // nb_inplace_subtract
nullptr, // nb_inplace_multiply
nullptr, // nb_inplace_remainder
nullptr, // nb_inplace_power
nullptr, // nb_inplace_lshift
nullptr, // nb_inplace_rshift
nullptr, // nb_inplace_and
nullptr, // nb_inplace_xor
nullptr, // nb_inplace_or
nullptr, // nb_floor_divide
PyCustomFloat_TrueDivide<T>, // nb_true_divide
nullptr, // nb_inplace_floor_divide
nullptr, // nb_inplace_true_divide
nullptr, // nb_index
};
// Constructs a new PyCustomFloat (tp_new slot).
//
// Accepts exactly one positional argument and no keyword arguments. The
// argument is handled as follows, in this order:
//   * an instance of T's Python type: returned as-is (new reference);
//   * anything CastToCustomFloat() accepts: converted to a new scalar;
//   * a NumPy array: returned as-is if it already has T's dtype, otherwise a
//     cast copy is returned;
//   * str / bytes: parsed with PyFloat_FromString() and then converted.
// Returns nullptr with a Python exception set on failure: the parser's own
// error (ValueError) for a malformed string, TypeError for everything else.
template <typename T>
PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
                            PyObject* kwds) {
  if (kwds && PyDict_Size(kwds)) {
    PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
    return nullptr;
  }
  Py_ssize_t size = PyTuple_Size(args);
  if (size != 1) {
    PyErr_Format(PyExc_TypeError,
                 "expected number as argument to %s constructor",
                 TypeDescriptor<T>::kTypeName);
    return nullptr;
  }
  PyObject* arg = PyTuple_GetItem(args, 0);  // Borrowed reference.

  T value;
  if (PyCustomFloat_Check<T>(arg)) {
    Py_INCREF(arg);
    return arg;
  } else if (CastToCustomFloat<T>(arg, &value)) {
    return PyCustomFloat_FromT<T>(value).release();
  } else if (PyArray_Check(arg)) {
    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
    if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
      return PyArray_Cast(arr, TypeDescriptor<T>::Dtype());
    } else {
      Py_INCREF(arg);
      return arg;
    }
  } else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) {
    // Parse float from string, then cast to T. The temporary float object is
    // owned by `parsed` so it is released on every path.
    Safe_PyObjectPtr parsed = make_safe(PyFloat_FromString(arg));
    if (!parsed) {
      // Parsing failed and PyFloat_FromString() already set an exception.
      // The null pointer must not be handed to CastToCustomFloat(), which
      // would pass it on to PyObject_IsInstance().
      return nullptr;
    }
    if (CastToCustomFloat<T>(parsed.get(), &value)) {
      return PyCustomFloat_FromT<T>(value).release();
    }
  }
  PyErr_Format(PyExc_TypeError, "expected number, got %s",
               Py_TYPE(arg)->tp_name);
  return nullptr;
}
// Rich comparison for PyCustomFloat (tp_richcompare slot). If either operand
// is not a scalar of T the comparison is delegated to the NumPy generic
// scalar type; otherwise the matching C++ operator on T is evaluated and
// returned as a Python bool. An unknown 'op' raises ValueError.
template <typename T>
PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) {
  T lhs, rhs;
  const bool both_custom =
      SafeCastToCustomFloat<T>(a, &lhs) && SafeCastToCustomFloat<T>(b, &rhs);
  if (!both_custom) {
    return PyGenericArrType_Type.tp_richcompare(a, b, op);
  }
  switch (op) {
    case Py_LT:
      return PyBool_FromLong(lhs < rhs);
    case Py_LE:
      return PyBool_FromLong(lhs <= rhs);
    case Py_EQ:
      return PyBool_FromLong(lhs == rhs);
    case Py_NE:
      return PyBool_FromLong(lhs != rhs);
    case Py_GT:
      return PyBool_FromLong(lhs > rhs);
    case Py_GE:
      return PyBool_FromLong(lhs >= rhs);
    default:
      break;
  }
  PyErr_SetString(PyExc_ValueError, "Invalid op type");
  return nullptr;
}
// Implementation of repr() for PyCustomFloat: the value is widened to float
// and formatted with a default-configured std::ostringstream. NaN goes
// through std::abs() first so that its sign bit is dropped before printing.
template <typename T>
PyObject* PyCustomFloat_Repr(PyObject* self) {
  const float as_float =
      static_cast<float>(reinterpret_cast<PyCustomFloat<T>*>(self)->value);
  std::ostringstream text;
  if (std::isnan(as_float)) {
    text << std::abs(as_float);
  } else {
    text << as_float;
  }
  return PyUnicode_FromString(text.str().c_str());
}
// Implementation of str() for PyCustomFloat: the value is widened to float
// and formatted with a default-configured std::ostringstream. NaN goes
// through std::abs() first so that its sign bit is dropped before printing.
template <typename T>
PyObject* PyCustomFloat_Str(PyObject* self) {
  const float as_float =
      static_cast<float>(reinterpret_cast<PyCustomFloat<T>*>(self)->value);
  std::ostringstream text;
  if (std::isnan(as_float)) {
    text << std::abs(as_float);
  } else {
    text << as_float;
  }
  return PyUnicode_FromString(text.str().c_str());
}
// _Py_HashDouble changed its prototype for Python 3.10, so a pair of
// overloads is used and overload resolution picks whichever one matches the
// function pointer that is passed in.
//
// Two-argument form: the hash function receives both the object and the
// value.
// NOLINTNEXTLINE(clang-diagnostic-unused-function)
inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double),
                          PyObject* self, double value) {
  return (*hash_double)(self, value);
}
// One-argument form: the hash function receives the value only; 'self' is
// accepted for signature symmetry and ignored.
// NOLINTNEXTLINE(clang-diagnostic-unused-function)
inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self,
                          double value) {
  static_cast<void>(self);
  return (*hash_double)(value);
}
// Hash function for PyCustomFloat (tp_hash slot): hashes the value converted
// to double using CPython's _Py_HashDouble.
template <typename T>
Py_hash_t PyCustomFloat_Hash(PyObject* self) {
  const PyCustomFloat<T>* const wrapper =
      reinterpret_cast<PyCustomFloat<T>*>(self);
  const double as_double = static_cast<double>(wrapper->value);
  return HashImpl(&_Py_HashDouble, self, as_double);
}
// Numpy support
// Storage for the per-type NumPy function table. Its slots are populated by
// RegisterFloatDtype() before the descriptor is registered.
template <typename T>
PyArray_ArrFuncs CustomFloatType<T>::arr_funcs;
// NumPy descriptor for T. The initializer is positional (each field is named
// in the comment in front of it), so entries must not be reordered. The type
// object and the descriptor's Python type are filled in by
// RegisterFloatDtype().
template <typename T>
PyArray_Descr CustomFloatType<T>::npy_descr = {
PyObject_HEAD_INIT(nullptr)
/*typeobj=*/nullptr, // Filled in later
/*kind=*/TypeDescriptor<T>::kNpyDescrKind,
/*type=*/TypeDescriptor<T>::kNpyDescrType,
/*byteorder=*/TypeDescriptor<T>::kNpyDescrByteorder,
/*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
/*type_num=*/0,
/*elsize=*/sizeof(T),
/*alignment=*/alignof(T),
/*subarray=*/nullptr,
/*fields=*/nullptr,
/*names=*/nullptr,
/*f=*/&CustomFloatType<T>::arr_funcs,
/*metadata=*/nullptr,
/*c_metadata=*/nullptr,
/*hash=*/-1, // -1 means "not computed yet".
};
// Implementations of NumPy array methods.

// getitem: reads one T from 'data' (via memcpy, so no alignment is assumed)
// and returns it as a Python float.
template <typename T>
PyObject* NPyCustomFloat_GetItem(void* data, void* arr) {
  T element;
  memcpy(&element, data, sizeof(T));
  const float as_float = static_cast<float>(element);
  return PyFloat_FromDouble(as_float);
}
// setitem: converts 'item' to T with CastToCustomFloat() and stores it at
// 'data' (via memcpy, so no alignment is assumed). Returns 0 on success; on
// conversion failure raises TypeError and returns -1.
template <typename T>
int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) {
  T converted;
  if (CastToCustomFloat<T>(item, &converted)) {
    memcpy(data, &converted, sizeof(T));
    return 0;
  }
  PyErr_Format(PyExc_TypeError, "expected number, got %s",
               Py_TYPE(item)->tp_name);
  return -1;
}
// Exchanges the two bytes that 'value' points at, in place.
inline void ByteSwap16(void* value) {
  char* const bytes = static_cast<char*>(value);
  const char first = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = first;
}
// Three-way comparison of the two T values at 'a' and 'b', performed in
// float. Returns -1 / 1 for strictly less / greater. When neither ordering
// holds, a single NaN operand sorts after the non-NaN one (NaNs sort to the
// end); otherwise the values compare equal and 0 is returned.
template <typename T>
int NPyCustomFloat_Compare(const void* a, const void* b, void* arr) {
  T lhs_raw;
  T rhs_raw;
  memcpy(&lhs_raw, a, sizeof(T));
  memcpy(&rhs_raw, b, sizeof(T));
  const float lhs = static_cast<float>(lhs_raw);
  const float rhs = static_cast<float>(rhs_raw);
  if (lhs < rhs) {
    return -1;
  }
  if (rhs < lhs) {
    return 1;
  }
  // Unordered or equal: only a lone NaN breaks the tie.
  const bool lhs_is_nan = Eigen::numext::isnan(lhs);
  const bool rhs_is_nan = Eigen::numext::isnan(rhs);
  if (lhs_is_nan != rhs_is_nan) {
    return lhs_is_nan ? 1 : -1;
  }
  return 0;
}
template <typename T>
void NPyCustomFloat_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
npy_intp sstride, npy_intp n, int swap,
void* arr) {
static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
"Not supported");
char* dst = reinterpret_cast<char*>(dstv);
char* src = reinterpret_cast<char*>(srcv);
if (!src) {
return;
}
if (swap && sizeof(T) == sizeof(int16_t)) {
for (npy_intp i = 0; i < n; i++) {
char* r = dst + dstride * i;
memcpy(r, src + sstride * i, sizeof(T));
ByteSwap16(r);
}
} else if (dstride == sizeof(T) && sstride == sizeof(T)) {
memcpy(dst, src, n * sizeof(T));
} else {
for (npy_intp i = 0; i < n; i++) {
memcpy(dst + dstride * i, src + sstride * i, sizeof(T));
}
}
}
// copyswap: copies a single T from 'src' to 'dst' and, when 'swap' is
// non-zero and T is two bytes wide, byte-swaps the destination. A null source
// makes the call a no-op. Only 1- and 2-byte element types are supported.
template <typename T>
void NPyCustomFloat_CopySwap(void* dst, void* src, int swap, void* arr) {
  static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
                "Not supported");
  if (src == nullptr) {
    return;
  }
  memcpy(dst, src, sizeof(T));
  const bool needs_swap = swap && sizeof(T) == sizeof(int16_t);
  if (needs_swap) {
    ByteSwap16(dst);
  }
}
// nonzero: returns whether the T stored at 'data' compares unequal to zero.
// The element is read via memcpy, so no alignment is assumed.
template <typename T>
npy_bool NPyCustomFloat_NonZero(void* data, void* arr) {
  T element;
  memcpy(&element, data, sizeof(element));
  const T zero = static_cast<T>(0);
  return element != zero;
}
template <typename T>
int NPyCustomFloat_Fill(void* buffer_raw, npy_intp length, void* ignored) {
T* const buffer = reinterpret_cast<T*>(buffer_raw);
const float start(buffer[0]);
const float delta = static_cast<float>(buffer[1]) - start;
for (npy_intp i = 2; i < length; ++i) {
buffer[i] = static_cast<T>(start + i * delta);
}
return 0;
}
template <typename T>
void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
void* op, npy_intp n, void* arr) {
char* c1 = reinterpret_cast<char*>(ip1);
char* c2 = reinterpret_cast<char*>(ip2);
float acc = 0.0f;
for (npy_intp i = 0; i < n; ++i) {
T* const b1 = reinterpret_cast<T*>(c1);
T* const b2 = reinterpret_cast<T*>(c2);
acc += static_cast<float>(*b1) * static_cast<float>(*b2);
c1 += is1;
c2 += is2;
}
T* out = reinterpret_cast<T*>(op);
*out = static_cast<T>(acc);
}
// Three-way comparison of the T values at 'v1' and 'v2' using T's own
// operators: -1 if *v1 < *v2, 1 if *v1 > *v2, otherwise 0. There is no special
// handling for unordered values, which therefore yield 0.
template <typename T>
int NPyCustomFloat_CompareFunc(const void* v1, const void* v2, void* arr) {
  const T lhs = *reinterpret_cast<const T*>(v1);
  const T rhs = *reinterpret_cast<const T*>(v2);
  if (lhs < rhs) {
    return -1;
  }
  return (lhs > rhs) ? 1 : 0;
}
template <typename T>
int NPyCustomFloat_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
void* arr) {
const T* bdata = reinterpret_cast<const T*>(data);
// Start with a max_val of NaN, this results in the first iteration preferring
// bdata[0].
float max_val = std::numeric_limits<float>::quiet_NaN();
for (npy_intp i = 0; i < n; ++i) {
// This condition is chosen so that NaNs are always considered "max".
if (!(static_cast<float>(bdata[i]) <= max_val)) {
max_val = static_cast<float>(bdata[i]);
*max_ind = i;
// NumPy stops at the first NaN.
if (Eigen::numext::isnan(max_val)) {
break;
}
}
}
return 0;
}
template <typename T>
int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
void* arr) {
const T* bdata = reinterpret_cast<const T*>(data);
float min_val = std::numeric_limits<float>::quiet_NaN();
// Start with a min_val of NaN, this results in the first iteration preferring
// bdata[0].
for (npy_intp i = 0; i < n; ++i) {
// This condition is chosen so that NaNs are always considered "min".
if (!(static_cast<float>(bdata[i]) >= min_val)) {
min_val = static_cast<float>(bdata[i]);
*min_ind = i;
// NumPy stops at the first NaN.
if (Eigen::numext::isnan(min_val)) {
break;
}
}
}
return 0;
}
// Converts a value to float with a plain static_cast.
template <typename T>
float CastToFloat(T value) {
  const float converted = static_cast<float>(value);
  return converted;
}
// Complex overload: converts only the real component; the imaginary component
// is discarded.
template <typename T>
float CastToFloat(std::complex<T> value) {
  const T real_part = value.real();
  return CastToFloat(real_part);
}
// Performs a NumPy array cast from type 'From' to 'To'.
template <typename From, typename To>
void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
void* toarr) {
const auto* from =
reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
for (npy_intp i = 0; i < n; ++i) {
to[i] = static_cast<typename TypeDescriptor<To>::T>(
static_cast<To>(CastToFloat(from[i])));
}
}
// Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type'
// is the NumPy type corresponding to 'OtherT'.
template <typename T, typename OtherT>
bool RegisterCustomFloatCast(int numpy_type = TypeDescriptor<OtherT>::Dtype()) {
PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
if (PyArray_RegisterCastFunc(descr, TypeDescriptor<T>::Dtype(),
NPyCast<OtherT, T>) < 0) {
return false;
}
if (PyArray_RegisterCastFunc(&CustomFloatType<T>::npy_descr,
numpy_type, NPyCast<T, OtherT>) < 0) {
return false;
}
return true;
}
template <typename T>
bool RegisterFloatCasts() {
if (!RegisterCustomFloatCast<T, Eigen::half>(NPY_HALF)) {
return false;
}
if (!RegisterCustomFloatCast<T, float>(NPY_FLOAT)) {
return false;
}
if (!RegisterCustomFloatCast<T, double>(NPY_DOUBLE)) {
return false;
}
if (!RegisterCustomFloatCast<T, long double>(NPY_LONGDOUBLE)) {
return false;
}
if (!RegisterCustomFloatCast<T, bool>(NPY_BOOL)) {
return false;
}
if (!RegisterCustomFloatCast<T, unsigned char>(NPY_UBYTE)) {
return false;
}
if (!RegisterCustomFloatCast<T, unsigned short>(NPY_USHORT)) { // NOLINT
return false;
}
if (!RegisterCustomFloatCast<T, unsigned int>(NPY_UINT)) {
return false;
}
if (!RegisterCustomFloatCast<T, unsigned long>(NPY_ULONG)) { // NOLINT
return false;
}
if (!RegisterCustomFloatCast<T, unsigned long long>( // NOLINT
NPY_ULONGLONG)) {
return false;
}
if (!RegisterCustomFloatCast<T, signed char>(NPY_BYTE)) {
return false;
}
if (!RegisterCustomFloatCast<T, short>(NPY_SHORT)) { // NOLINT
return false;
}
if (!RegisterCustomFloatCast<T, int>(NPY_INT)) {
return false;
}
if (!RegisterCustomFloatCast<T, long>(NPY_LONG)) { // NOLINT
return false;
}
if (!RegisterCustomFloatCast<T, long long>(NPY_LONGLONG)) { // NOLINT
return false;
}
// Following the numpy convention. imag part is dropped when converting to
// float.
if (!RegisterCustomFloatCast<T, std::complex<float>>(NPY_CFLOAT)) {
return false;
}
if (!RegisterCustomFloatCast<T, std::complex<double>>(NPY_CDOUBLE)) {
return false;
}
if (!RegisterCustomFloatCast<T, std::complex<long double>>(NPY_CLONGDOUBLE)) {
return false;
}
// Safe casts from T to other types
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_DOUBLE,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_LONGDOUBLE,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CFLOAT,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CDOUBLE,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CLONGDOUBLE,
NPY_NOSCALAR) < 0) {
return false;
}
// Safe casts to T from other types
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL),
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
return false;
}
return true;
}
// Registers T's loop for each of the NumPy ufuncs named below by calling
// RegisterUFunc once per ufunc with the `numpy` object passed in. The calls
// are chained with &&, so registration stops at the first call that returns
// false. Returns true only if every registration succeeded.
//
// Note that some functors are intentionally registered under two names:
// TrueDivide for "divide" and "true_divide", Remainder for "remainder" and
// "mod", and Abs for "absolute" and "fabs".
template <typename T>
bool RegisterFloatUFuncs(PyObject* numpy) {
bool ok =
// Arithmetic, exponential and logarithmic functions
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy,
"subtract") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy,
"multiply") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(numpy,
"divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy,
"logaddexp") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(
numpy, "logaddexp2") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy,
"negative") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy,
"positive") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(
numpy, "true_divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
numpy, "floor_divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy,
"remainder") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
RegisterUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy,
"heaviside") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy,
"conjugate") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy,
"reciprocal") &&
// Trigonometric functions
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy,
"arctan2") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy,
"arcsinh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy,
"arccosh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy,
"arctanh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy,
"deg2rad") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy,
"rad2deg") &&
// Comparison functions
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy,
"not_equal") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy,
"less_equal") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy,
"greater_equal") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy,
"maximum") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy,
"minimum") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(
numpy, "logical_and") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(
numpy, "logical_or") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(
numpy, "logical_xor") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(
numpy, "logical_not") &&
// Floating point functions
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy,
"isfinite") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy,
"signbit") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy,
"copysign") &&
RegisterUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
RegisterUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy,
"ldexp") &&
RegisterUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy,
"frexp") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy,
"nextafter") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Spacing<T>>, T>(numpy, "spacing");
return ok;
}
// Returns true if the numpy type for T is successfully registered, including if
// it was already registered (e.g. by a different library). If
// `already_registered` is non-null, it's set to true if the type was already
// registered and false otherwise.
//
// When the type is not registered yet, this function creates the Python
// scalar type, fills in the NumPy function table and descriptor, registers
// the descriptor with NumPy, publishes the type in `numpy.sctypeDict`, and
// finally registers the casts and ufunc loops for T.
template <typename T>
bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
  if (already_registered != nullptr) {
    *already_registered = false;
  }
  // If another module (presumably either TF or JAX) has registered a bfloat16
  // type, use it. We don't want two bfloat16 types if we can avoid it since it
  // leads to confusion if we have two different types with the same name. This
  // assumes that the other module has a sufficiently complete bfloat16
  // implementation. The only known NumPy bfloat16 extension at the time of
  // writing is this one (distributed in TF and JAX).
  // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
  // so we can unambiguously refer to a single canonical definition of bfloat16.
  int typenum =
      PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
  if (typenum != NPY_NOTYPE) {
    PyArray_Descr* descr = PyArray_DescrFromType(typenum);
    // The test for an argmax function here is to verify that the
    // bfloat16 implementation is sufficiently new, and, say, not from
    // an older version of TF or JAX.
    if (descr && descr->f && descr->f->argmax) {
      TypeDescriptor<T>::npy_type = typenum;
      TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(descr->typeobj);
      if (already_registered != nullptr) {
        *already_registered = true;
      }
      return true;
    }
  }
  // It's important that we heap-allocate our type. This is because tp_name
  // is not a fully-qualified name for a heap-allocated type, and
  // PyArray_TypeNumFromName() (above) looks at the tp_name field to find
  // types. Existing implementations in JAX and TensorFlow look for "bfloat16",
  // not "ml_dtypes.bfloat16" when searching for an implementation.
  Safe_PyObjectPtr name =
      make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
  Safe_PyObjectPtr qualname =
      make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
  if (!name || !qualname) {
    // Creating the name strings failed; do not build a type with null names.
    return false;
  }
  PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
      PyType_Type.tp_alloc(&PyType_Type, 0));
  if (!heap_type) {
    return false;
  }
  // Caution: we must not call any functions that might invoke the GC until
  // PyType_Ready() is called. Otherwise the GC might see a half-constructed
  // type object.
  heap_type->ht_name = name.release();
  heap_type->ht_qualname = qualname.release();
  PyTypeObject* type = &heap_type->ht_type;
  type->tp_name = TypeDescriptor<T>::kTypeName;
  type->tp_basicsize = sizeof(PyCustomFloat<T>);
  type->tp_flags =
      Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
  type->tp_base = &PyGenericArrType_Type;
  type->tp_new = PyCustomFloat_New<T>;
  type->tp_repr = PyCustomFloat_Repr<T>;
  type->tp_hash = PyCustomFloat_Hash<T>;
  type->tp_str = PyCustomFloat_Str<T>;
  type->tp_doc = const_cast<char*>(TypeDescriptor<T>::kTpDoc);
  type->tp_richcompare = PyCustomFloat_RichCompare<T>;
  type->tp_as_number = &CustomFloatType<T>::number_methods;
  if (PyType_Ready(type) < 0) {
    return false;
  }
  TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(type);

  Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes"));
  if (!module) {
    return false;
  }
  if (PyObject_SetAttrString(TypeDescriptor<T>::type_ptr, "__module__",
                             module.get()) < 0) {
    return false;
  }

  // Initializes the NumPy descriptor.
  PyArray_ArrFuncs& arr_funcs = CustomFloatType<T>::arr_funcs;
  PyArray_InitArrFuncs(&arr_funcs);
  arr_funcs.getitem = NPyCustomFloat_GetItem<T>;
  arr_funcs.setitem = NPyCustomFloat_SetItem<T>;
  // The `compare` slot is assigned exactly once, to the NaN-aware comparison
  // (NaNs sort to the end). It used to be overwritten further down by
  // NPyCustomFloat_CompareFunc, which returns 0 whenever either operand is
  // NaN and therefore does not define a consistent ordering for sorting.
  arr_funcs.compare = NPyCustomFloat_Compare<T>;
  arr_funcs.copyswapn = NPyCustomFloat_CopySwapN<T>;
  arr_funcs.copyswap = NPyCustomFloat_CopySwap<T>;
  arr_funcs.nonzero = NPyCustomFloat_NonZero<T>;
  arr_funcs.fill = NPyCustomFloat_Fill<T>;
  arr_funcs.dotfunc = NPyCustomFloat_DotFunc<T>;
  arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc<T>;
  arr_funcs.argmin = NPyCustomFloat_ArgMinFunc<T>;

  // Give the statically allocated descriptor its Python type. Which spelling
  // is used depends on whether Py_SET_TYPE is available.
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
  Py_TYPE(&CustomFloatType<T>::npy_descr) = &PyArrayDescr_Type;
#else
  Py_SET_TYPE(&CustomFloatType<T>::npy_descr, &PyArrayDescr_Type);
#endif
  TypeDescriptor<T>::npy_descr.typeobj = type;

  TypeDescriptor<T>::npy_type =
      PyArray_RegisterDataType(&CustomFloatType<T>::npy_descr);
  if (TypeDescriptor<T>::Dtype() < 0) {
    return false;
  }

  Safe_PyObjectPtr typeDict_obj =
      make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
  if (!typeDict_obj) return false;
  // Add the type object to `numpy.sctypeDict`: that makes
  // `numpy.dtype(type_name)` work.
  if (PyDict_SetItemString(typeDict_obj.get(), TypeDescriptor<T>::kTypeName,
                           TypeDescriptor<T>::type_ptr) < 0) {
    return false;
  }

  // Support dtype(type_name)
  if (PyObject_SetAttrString(TypeDescriptor<T>::type_ptr, "dtype",
                             reinterpret_cast<PyObject*>(
                                 &CustomFloatType<T>::npy_descr)) <
      0) {
    return false;
  }

  return RegisterFloatCasts<T>() && RegisterFloatUFuncs<T>(numpy);
}
} // namespace ml_dtypes
#endif // ML_DTYPES_CUSTOM_FLOAT_H_