882 lines
26 KiB
C++
882 lines
26 KiB
C++
/* Copyright 2023 The ml_dtypes Authors
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef ML_DTYPES_INT4_H_
|
|
#define ML_DTYPES_INT4_H_
|
|
|
|
// Must be included first
|
|
// clang-format off
|
|
#include "_src/numpy.h"
|
|
// clang-format on
|
|
|
|
#include <cstdint> //NOLINT
|
|
#include <optional> //NOLINT
|
|
#include <ostream> //NOLINT
|
|
#include <sstream> //NOLINT
|
|
|
|
#include "Eigen/Core"
|
|
#include "_src/common.h" // NOLINT
|
|
#include "_src/ufuncs.h" // NOLINT
|
|
|
|
namespace ml_dtypes {
|
|
|
|
// Four-bit integer value type.
//
// The value is stored in a 4-bit bit-field of `UnderlyingTy`; the signedness
// of `UnderlyingTy` decides whether lowest()/highest() report the signed or
// the unsigned nibble range. Every arithmetic result is routed back through a
// constructor, which masks the value with 0x0F before storing it.
//
// Constructors, conversions, arithmetic and comparisons are constexpr so that
// lowest()/highest() (which were already declared constexpr) can actually be
// evaluated in constant expressions.
template <typename UnderlyingTy>
struct i4 {
 private:
  UnderlyingTy v : 4;

 public:
  // Zero-initialized value.
  constexpr i4() : v(0) {}
  // Keeps only the low four bits of `val`.
  constexpr explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
  // Converts any other source type to `UnderlyingTy` first, then masks.
  template <typename T>
  constexpr explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
  constexpr i4(const i4& other) = default;

  // Smallest value: -8 for a signed underlying type, 0 otherwise.
  static constexpr i4 lowest() {
    return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
  }
  // Largest value: 7 for a signed underlying type, 15 otherwise.
  static constexpr i4 highest() {
    return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
  }

  // Explicit conversion to any arithmetic type.
  template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
  constexpr explicit operator T() const {
    return static_cast<T>(v);
  }
  // Implicit conversion to an engaged optional holding the value.
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr operator std::optional<int64_t>() const {
    return static_cast<int64_t>(v);
  }

  // Arithmetic is evaluated on the (integer-promoted) stored value and then
  // truncated back to four bits by the constructor.
  constexpr i4 operator-() const { return i4(-v); }
  constexpr i4 operator+(const i4& other) const { return i4((v + other.v)); }
  constexpr i4 operator-(const i4& other) const { return i4((v - other.v)); }
  constexpr i4 operator*(const i4& other) const { return i4((v * other.v)); }
  // Division and remainder do not guard against a zero divisor; callers must
  // check for it themselves.
  constexpr i4 operator/(const i4& other) const { return i4((v / other.v)); }
  constexpr i4 operator%(const i4& other) const { return i4((v % other.v)); }

  constexpr i4 operator>>(const int amount) const { return i4((v >> amount)); }
  constexpr i4 operator<<(const int amount) const { return i4((v << amount)); }

  // Comparisons against another i4.
  constexpr bool operator==(const i4& other) const { return v == other.v; }
  constexpr bool operator!=(const i4& other) const { return v != other.v; }
  constexpr bool operator<(const i4& other) const { return v < other.v; }
  constexpr bool operator>(const i4& other) const { return v > other.v; }
  constexpr bool operator<=(const i4& other) const { return v <= other.v; }
  constexpr bool operator>=(const i4& other) const { return v >= other.v; }

  // Comparisons against a plain 64-bit integer.
  constexpr bool operator==(const int64_t other) const { return v == other; }
  constexpr bool operator!=(const int64_t other) const { return v != other; }
  constexpr bool operator<(const int64_t other) const { return v < other; }
  constexpr bool operator>(const int64_t other) const { return v > other; }
  constexpr bool operator<=(const int64_t other) const { return v <= other; }
  constexpr bool operator>=(const int64_t other) const { return v >= other; }

  // Pre-increment; the result is masked to four bits.
  constexpr i4& operator++() {
    v = (v + 1) & 0x0F;
    return *this;
  }

  // Streams the value as a decimal number (widened to int16_t so that it is
  // not printed as a character).
  friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
    os << static_cast<int16_t>(num.v);
    return os;
  }

  // Decimal string form of the value.
  std::string ToString() const {
    std::ostringstream os;
    os << static_cast<int16_t>(v);
    return os.str();
  }
};
|
|
|
|
// Signed 4-bit integer (backed by int8_t).
using int4 = i4<int8_t>;
// Unsigned 4-bit integer (backed by uint8_t).
using uint4 = i4<uint8_t>;
|
|
|
|
template <typename T>
|
|
struct Int4TypeDescriptor {
|
|
static int Dtype() { return npy_type; }
|
|
|
|
// Registered numpy type ID. Global variable populated by the registration
|
|
// code. Protected by the GIL.
|
|
static int npy_type;
|
|
|
|
// Pointer to the python type object we are using. This is either a pointer
|
|
// to type, if we choose to register it, or to the python type
|
|
// registered by another system into NumPy.
|
|
static PyObject* type_ptr;
|
|
|
|
static PyNumberMethods number_methods;
|
|
static PyArray_ArrFuncs arr_funcs;
|
|
static PyArray_Descr npy_descr;
|
|
};
|
|
|
|
template <typename T>
|
|
int Int4TypeDescriptor<T>::npy_type = NPY_NOTYPE;
|
|
template <typename T>
|
|
PyObject* Int4TypeDescriptor<T>::type_ptr = nullptr;
|
|
|
|
// Representation of a Python custom float object.
|
|
template <typename T>
|
|
struct PyInt4 {
|
|
PyObject_HEAD; // Python object header
|
|
T value;
|
|
};
|
|
|
|
// Returns true if 'object' is a PyInt4.
|
|
template <typename T>
|
|
bool PyInt4_Check(PyObject* object) {
|
|
return PyObject_IsInstance(object, TypeDescriptor<T>::type_ptr);
|
|
}
|
|
|
|
// Extracts the value of a PyInt4 object.
|
|
template <typename T>
|
|
T PyInt4_Value_Unchecked(PyObject* object) {
|
|
return reinterpret_cast<PyInt4<T>*>(object)->value;
|
|
}
|
|
|
|
template <typename T>
|
|
bool PyInt4_Value(PyObject* arg, T* output) {
|
|
if (PyInt4_Check<T>(arg)) {
|
|
*output = PyInt4_Value_Unchecked<T>(arg);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Constructs a PyInt4 object from PyInt4<T>::T.
|
|
template <typename T>
|
|
Safe_PyObjectPtr PyInt4_FromValue(T x) {
|
|
PyTypeObject* type =
|
|
reinterpret_cast<PyTypeObject*>(TypeDescriptor<T>::type_ptr);
|
|
Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0));
|
|
PyInt4<T>* p = reinterpret_cast<PyInt4<T>*>(ref.get());
|
|
if (p) {
|
|
p->value = x;
|
|
}
|
|
return ref;
|
|
}
|
|
|
|
// Converts a Python object to a reduced float value. Returns true on success,
|
|
// returns false and reports a Python error on failure.
|
|
template <typename T>
|
|
bool CastToInt4(PyObject* arg, T* output) {
|
|
if (PyInt4_Check<T>(arg)) {
|
|
*output = PyInt4_Value_Unchecked<T>(arg);
|
|
return true;
|
|
}
|
|
if (PyFloat_Check(arg)) {
|
|
double d = PyFloat_AsDouble(arg);
|
|
if (PyErr_Occurred()) {
|
|
return false;
|
|
}
|
|
if (std::isnan(d)) {
|
|
PyErr_SetString(PyExc_ValueError, "cannot convert float NaN to integer");
|
|
}
|
|
if (std::isinf(d)) {
|
|
PyErr_SetString(PyExc_OverflowError,
|
|
"cannot convert float infinity to integer");
|
|
}
|
|
if (d < static_cast<double>(T::lowest()) ||
|
|
d > static_cast<double>(T::highest())) {
|
|
PyErr_SetString(PyExc_OverflowError,
|
|
"out of range value cannot be converted to int4");
|
|
}
|
|
*output = T(d);
|
|
return true;
|
|
}
|
|
if (PyLong_Check(arg)) {
|
|
long l = PyLong_AsLong(arg); // NOLINT
|
|
if (PyErr_Occurred()) {
|
|
return false;
|
|
}
|
|
*output = T(l);
|
|
return true;
|
|
}
|
|
if (PyArray_IsScalar(arg, Integer)) {
|
|
int64_t v;
|
|
PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64));
|
|
*output = T(v);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Constructs a new PyInt4.
|
|
template <typename T>
|
|
PyObject* PyInt4_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
|
|
if (kwds && PyDict_Size(kwds)) {
|
|
PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
|
|
return nullptr;
|
|
}
|
|
Py_ssize_t size = PyTuple_Size(args);
|
|
if (size != 1) {
|
|
PyErr_Format(PyExc_TypeError,
|
|
"expected number as argument to %s constructor",
|
|
TypeDescriptor<T>::kTypeName);
|
|
return nullptr;
|
|
}
|
|
PyObject* arg = PyTuple_GetItem(args, 0);
|
|
|
|
T value;
|
|
if (PyInt4_Check<T>(arg)) {
|
|
Py_INCREF(arg);
|
|
return arg;
|
|
} else if (CastToInt4<T>(arg, &value)) {
|
|
return PyInt4_FromValue<T>(value).release();
|
|
} else if (PyArray_Check(arg)) {
|
|
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
|
|
if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
|
|
return PyArray_Cast(arr, TypeDescriptor<T>::Dtype());
|
|
} else {
|
|
Py_INCREF(arg);
|
|
return arg;
|
|
}
|
|
} else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) {
|
|
// Parse float from string, then cast to T.
|
|
PyObject* f = PyLong_FromUnicodeObject(arg, /*base=*/0);
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
if (CastToInt4<T>(f, &value)) {
|
|
return PyInt4_FromValue<T>(value).release();
|
|
}
|
|
}
|
|
PyErr_Format(PyExc_TypeError, "expected number, got %s",
|
|
Py_TYPE(arg)->tp_name);
|
|
return nullptr;
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_float(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
return PyFloat_FromDouble(static_cast<double>(static_cast<float>(x)));
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_int(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
long y = static_cast<long>(static_cast<float>(x)); // NOLINT
|
|
return PyLong_FromLong(y);
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_negative(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
return PyInt4_FromValue<T>(-x).release();
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_positive(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
return PyInt4_FromValue<T>(x).release();
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_add(PyObject* a, PyObject* b) {
|
|
T x, y;
|
|
if (PyInt4_Value<T>(a, &x) && PyInt4_Value<T>(b, &y)) {
|
|
return PyInt4_FromValue<T>(x + y).release();
|
|
}
|
|
return PyArray_Type.tp_as_number->nb_add(a, b);
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_subtract(PyObject* a, PyObject* b) {
|
|
T x, y;
|
|
if (PyInt4_Value<T>(a, &x) && PyInt4_Value<T>(b, &y)) {
|
|
return PyInt4_FromValue<T>(x - y).release();
|
|
}
|
|
return PyArray_Type.tp_as_number->nb_subtract(a, b);
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_multiply(PyObject* a, PyObject* b) {
|
|
T x, y;
|
|
if (PyInt4_Value<T>(a, &x) && PyInt4_Value<T>(b, &y)) {
|
|
return PyInt4_FromValue<T>(x * y).release();
|
|
}
|
|
return PyArray_Type.tp_as_number->nb_multiply(a, b);
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_remainder(PyObject* a, PyObject* b) {
|
|
T x, y;
|
|
if (PyInt4_Value<T>(a, &x) && PyInt4_Value<T>(b, &y)) {
|
|
if (y == 0) {
|
|
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
|
|
return nullptr;
|
|
}
|
|
T v = x % y;
|
|
if (v != 0 && ((v < 0) != (y < 0))) {
|
|
v = v + y;
|
|
}
|
|
return PyInt4_FromValue<T>(v).release();
|
|
}
|
|
return PyArray_Type.tp_as_number->nb_remainder(a, b);
|
|
}
|
|
|
|
template <typename T>
|
|
PyObject* PyInt4_nb_floor_divide(PyObject* a, PyObject* b) {
|
|
T x, y;
|
|
if (PyInt4_Value<T>(a, &x) && PyInt4_Value<T>(b, &y)) {
|
|
if (y == 0) {
|
|
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
|
|
return nullptr;
|
|
}
|
|
T v = x / y;
|
|
if (((x > 0) != (y > 0)) && x % y != 0) {
|
|
v = v - T(1);
|
|
}
|
|
return PyInt4_FromValue<T>(v).release();
|
|
}
|
|
return PyArray_Type.tp_as_number->nb_floor_divide(a, b);
|
|
}
|
|
|
|
// Python number methods for PyInt4 objects.
|
|
template <typename T>
|
|
PyNumberMethods Int4TypeDescriptor<T>::number_methods = {
|
|
PyInt4_nb_add<T>, // nb_add
|
|
PyInt4_nb_subtract<T>, // nb_subtract
|
|
PyInt4_nb_multiply<T>, // nb_multiply
|
|
PyInt4_nb_remainder<T>, // nb_remainder
|
|
nullptr, // nb_divmod
|
|
nullptr, // nb_power
|
|
PyInt4_nb_negative<T>, // nb_negative
|
|
PyInt4_nb_positive<T>, // nb_positive
|
|
nullptr, // nb_absolute
|
|
nullptr, // nb_nonzero
|
|
nullptr, // nb_invert
|
|
nullptr, // nb_lshift
|
|
nullptr, // nb_rshift
|
|
nullptr, // nb_and
|
|
nullptr, // nb_xor
|
|
nullptr, // nb_or
|
|
PyInt4_nb_int<T>, // nb_int
|
|
nullptr, // reserved
|
|
PyInt4_nb_float<T>, // nb_float
|
|
|
|
nullptr, // nb_inplace_add
|
|
nullptr, // nb_inplace_subtract
|
|
nullptr, // nb_inplace_multiply
|
|
nullptr, // nb_inplace_remainder
|
|
nullptr, // nb_inplace_power
|
|
nullptr, // nb_inplace_lshift
|
|
nullptr, // nb_inplace_rshift
|
|
nullptr, // nb_inplace_and
|
|
nullptr, // nb_inplace_xor
|
|
nullptr, // nb_inplace_or
|
|
|
|
PyInt4_nb_floor_divide<T>, // nb_floor_divide
|
|
nullptr, // nb_true_divide
|
|
nullptr, // nb_inplace_floor_divide
|
|
nullptr, // nb_inplace_true_divide
|
|
nullptr, // nb_index
|
|
};
|
|
|
|
// Implementation of repr() for PyInt4.
|
|
template <typename T>
|
|
PyObject* PyInt4_Repr(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
std::string s = x.ToString();
|
|
return PyUnicode_FromString(s.c_str());
|
|
}
|
|
|
|
// Implementation of str() for PyInt4.
|
|
template <typename T>
|
|
PyObject* PyInt4_Str(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
std::string s = x.ToString();
|
|
return PyUnicode_FromString(s.c_str());
|
|
}
|
|
|
|
// Hash function for PyInt4.
|
|
template <typename T>
|
|
Py_hash_t PyInt4_Hash(PyObject* self) {
|
|
T x = PyInt4_Value_Unchecked<T>(self);
|
|
// Hash functions must not return -1.
|
|
return static_cast<int>(x) == -1 ? static_cast<Py_hash_t>(-2)
|
|
: static_cast<Py_hash_t>(x);
|
|
}
|
|
|
|
// Comparisons on PyInt4s.
|
|
template <typename T>
|
|
PyObject* PyInt4_RichCompare(PyObject* a, PyObject* b, int op) {
|
|
T x, y;
|
|
if (!PyInt4_Value<T>(a, &x) || !PyInt4_Value<T>(b, &y)) {
|
|
return PyGenericArrType_Type.tp_richcompare(a, b, op);
|
|
}
|
|
bool result;
|
|
switch (op) {
|
|
case Py_LT:
|
|
result = x < y;
|
|
break;
|
|
case Py_LE:
|
|
result = x <= y;
|
|
break;
|
|
case Py_EQ:
|
|
result = x == y;
|
|
break;
|
|
case Py_NE:
|
|
result = x != y;
|
|
break;
|
|
case Py_GT:
|
|
result = x > y;
|
|
break;
|
|
case Py_GE:
|
|
result = x >= y;
|
|
break;
|
|
default:
|
|
PyErr_SetString(PyExc_ValueError, "Invalid op type");
|
|
return nullptr;
|
|
}
|
|
return PyBool_FromLong(result);
|
|
}
|
|
|
|
// Numpy support
|
|
template <typename T>
|
|
PyArray_ArrFuncs Int4TypeDescriptor<T>::arr_funcs;
|
|
|
|
template <typename T>
|
|
PyArray_Descr Int4TypeDescriptor<T>::npy_descr = {
|
|
PyObject_HEAD_INIT(nullptr)
|
|
/*typeobj=*/nullptr, // Filled in later
|
|
/*kind=*/TypeDescriptor<T>::kNpyDescrKind,
|
|
/*type=*/TypeDescriptor<T>::kNpyDescrType,
|
|
/*byteorder=*/TypeDescriptor<T>::kNpyDescrByteorder,
|
|
/*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
|
|
/*type_num=*/0,
|
|
/*elsize=*/sizeof(T),
|
|
/*alignment=*/alignof(T),
|
|
/*subarray=*/nullptr,
|
|
/*fields=*/nullptr,
|
|
/*names=*/nullptr,
|
|
/*f=*/&Int4TypeDescriptor<T>::arr_funcs,
|
|
/*metadata=*/nullptr,
|
|
/*c_metadata=*/nullptr,
|
|
/*hash=*/-1, // -1 means "not computed yet".
|
|
};
|
|
|
|
// Implementations of NumPy array methods.
|
|
|
|
template <typename T>
|
|
PyObject* NPyInt4_GetItem(void* data, void* arr) {
|
|
T x;
|
|
memcpy(&x, data, sizeof(T));
|
|
return PyLong_FromLong(static_cast<int>(x));
|
|
}
|
|
|
|
template <typename T>
|
|
int NPyInt4_SetItem(PyObject* item, void* data, void* arr) {
|
|
T x;
|
|
if (!CastToInt4<T>(item, &x)) {
|
|
PyErr_Format(PyExc_TypeError, "expected number, got %s",
|
|
Py_TYPE(item)->tp_name);
|
|
return -1;
|
|
}
|
|
memcpy(data, &x, sizeof(T));
|
|
return 0;
|
|
}
|
|
|
|
// Three-way comparison of the elements stored at 'a' and 'b': returns -1, 0
// or 1 when the first is less than, equal to or greater than the second.
template <typename T>
int NPyInt4_Compare(const void* a, const void* b, void* arr) {
  T first;
  T second;
  memcpy(&first, a, sizeof(T));
  memcpy(&second, b, sizeof(T));
  const int lhs = static_cast<int>(first);
  const int rhs = static_cast<int>(second);
  return (lhs < rhs) ? -1 : ((rhs < lhs) ? 1 : 0);
}
|
|
|
|
template <typename T>
|
|
void NPyInt4_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
|
|
npy_intp sstride, npy_intp n, int swap, void* arr) {
|
|
char* dst = reinterpret_cast<char*>(dstv);
|
|
char* src = reinterpret_cast<char*>(srcv);
|
|
if (!src) {
|
|
return;
|
|
}
|
|
if (dstride == sizeof(T) && sstride == sizeof(T)) {
|
|
memcpy(dst, src, n * sizeof(T));
|
|
} else {
|
|
for (npy_intp i = 0; i < n; i++) {
|
|
memcpy(dst + dstride * i, src + sstride * i, sizeof(T));
|
|
}
|
|
}
|
|
}
|
|
|
|
// NumPy copyswap: copies a single element from 'src' to 'dst'. 'swap' is
// ignored. Does nothing when 'src' is null.
template <typename T>
void NPyInt4_CopySwap(void* dst, void* src, int swap, void* arr) {
  if (src != nullptr) {
    memcpy(dst, src, sizeof(T));
  }
}
|
|
|
|
template <typename T>
|
|
npy_bool NPyInt4_NonZero(void* data, void* arr) {
|
|
T x;
|
|
memcpy(&x, data, sizeof(x));
|
|
return x != static_cast<T>(0);
|
|
}
|
|
|
|
template <typename T>
|
|
int NPyInt4_Fill(void* buffer_raw, npy_intp length, void* ignored) {
|
|
T* const buffer = reinterpret_cast<T*>(buffer_raw);
|
|
const int start(buffer[0]);
|
|
const int delta = static_cast<int>(buffer[1]) - start;
|
|
for (npy_intp i = 2; i < length; ++i) {
|
|
buffer[i] = static_cast<T>(start + i * delta);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
template <typename T>
|
|
void NPyInt4_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2, void* op,
|
|
npy_intp n, void* arr) {
|
|
char* c1 = reinterpret_cast<char*>(ip1);
|
|
char* c2 = reinterpret_cast<char*>(ip2);
|
|
int acc = 0;
|
|
for (npy_intp i = 0; i < n; ++i) {
|
|
T* const b1 = reinterpret_cast<T*>(c1);
|
|
T* const b2 = reinterpret_cast<T*>(c2);
|
|
acc += static_cast<int>(*b1) * static_cast<int>(*b2);
|
|
c1 += is1;
|
|
c2 += is2;
|
|
}
|
|
T* out = reinterpret_cast<T*>(op);
|
|
*out = static_cast<T>(acc);
|
|
}
|
|
|
|
// Three-way comparison of the T values that 'v1' and 'v2' point at, using T's
// own ordering operators: returns -1, 0 or 1.
template <typename T>
int NPyInt4_CompareFunc(const void* v1, const void* v2, void* arr) {
  const T first = *reinterpret_cast<const T*>(v1);
  const T second = *reinterpret_cast<const T*>(v2);
  if (first < second) {
    return -1;
  }
  return (first > second) ? 1 : 0;
}
|
|
|
|
template <typename T>
|
|
int NPyInt4_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) {
|
|
const T* bdata = reinterpret_cast<const T*>(data);
|
|
// Start with a max_val of NaN, this results in the first iteration preferring
|
|
// bdata[0].
|
|
int max_val = std::numeric_limits<int>::max();
|
|
for (npy_intp i = 0; i < n; ++i) {
|
|
// This condition is chosen so that NaNs are always considered "max".
|
|
if (!(static_cast<int>(bdata[i]) <= max_val)) {
|
|
max_val = static_cast<int>(bdata[i]);
|
|
*max_ind = i;
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
template <typename T>
|
|
int NPyInt4_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) {
|
|
const T* bdata = reinterpret_cast<const T*>(data);
|
|
int min_val = std::numeric_limits<int>::lowest();
|
|
// Start with a min_val of NaN, this results in the first iteration preferring
|
|
// bdata[0].
|
|
for (npy_intp i = 0; i < n; ++i) {
|
|
// This condition is chosen so that NaNs are always considered "min".
|
|
if (!(static_cast<int>(bdata[i]) >= min_val)) {
|
|
min_val = static_cast<int>(bdata[i]);
|
|
*min_ind = i;
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Converts a floating-point value (or Eigen::half) to int. NaN, infinities
// and values outside the range of int map to 0.
// NOTE(review): the bounds are compared after the int limits are converted to
// T; for narrow T this may not reject a value exactly at the upper limit —
// confirm.
template <typename T, std::enable_if_t<(std::is_floating_point<T>::value ||
                                        std::is_same<T, Eigen::half>::value),
                                       bool> = true>
int CastToInt(T value) {
  const bool not_finite = std::isnan(value) || std::isinf(value);
  const bool out_of_range = not_finite ||
                            value < std::numeric_limits<int>::lowest() ||
                            value > std::numeric_limits<int>::max();
  if (out_of_range) {
    return 0;
  }
  return static_cast<int>(value);
}
|
|
|
|
// Converts any built-in integral value to int with a plain static_cast.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
int CastToInt(T value) {
  const int converted = static_cast<int>(value);
  return converted;
}
|
|
|
|
// Converts an int4 to int. Declared inline because this is a non-template
// function defined in a header: without it, every translation unit that
// includes the header emits its own external definition.
inline int CastToInt(int4 value) { return static_cast<int>(value); }
|
|
|
|
// Converts a uint4 to int. Declared inline because this is a non-template
// function defined in a header: without it, every translation unit that
// includes the header emits its own external definition.
inline int CastToInt(uint4 value) { return static_cast<int>(value); }
|
|
|
|
// Converts a complex value to int using only its real part; the imaginary
// part is dropped.
template <typename T>
int CastToInt(std::complex<T> value) {
  const T real_part = value.real();
  return CastToInt(real_part);
}
|
|
|
|
// Performs a NumPy array cast from type 'From' to 'To'.
|
|
template <typename From, typename To>
|
|
void IntegerCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
|
|
void* toarr) {
|
|
const auto* from =
|
|
reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
|
|
auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
|
|
for (npy_intp i = 0; i < n; ++i) {
|
|
to[i] = static_cast<typename TypeDescriptor<To>::T>(
|
|
static_cast<To>(CastToInt(from[i])));
|
|
}
|
|
}
|
|
|
|
// Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type'
|
|
// is the NumPy type corresponding to 'OtherT'.
|
|
template <typename T, typename OtherT>
|
|
bool RegisterCustomIntCast(int numpy_type = TypeDescriptor<OtherT>::Dtype()) {
|
|
PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
|
|
if (PyArray_RegisterCastFunc(descr, TypeDescriptor<T>::Dtype(),
|
|
IntegerCast<OtherT, T>) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCastFunc(&Int4TypeDescriptor<T>::npy_descr, numpy_type,
|
|
IntegerCast<T, OtherT>) < 0) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <typename T>
|
|
bool RegisterInt4Casts() {
|
|
if (!RegisterCustomIntCast<T, Eigen::half>(NPY_HALF)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, float>(NPY_FLOAT)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, double>(NPY_DOUBLE)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, long double>(NPY_LONGDOUBLE)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, bool>(NPY_BOOL)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, unsigned char>(NPY_UBYTE)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, unsigned short>(NPY_USHORT)) { // NOLINT
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, unsigned int>(NPY_UINT)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, unsigned long>(NPY_ULONG)) { // NOLINT
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, unsigned long long>( // NOLINT
|
|
NPY_ULONGLONG)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, signed char>(NPY_BYTE)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, short>(NPY_SHORT)) { // NOLINT
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, int>(NPY_INT)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, long>(NPY_LONG)) { // NOLINT
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, long long>(NPY_LONGLONG)) { // NOLINT
|
|
return false;
|
|
}
|
|
// Following the numpy convention. imag part is dropped when converting to
|
|
// float.
|
|
if (!RegisterCustomIntCast<T, std::complex<float>>(NPY_CFLOAT)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, std::complex<double>>(NPY_CDOUBLE)) {
|
|
return false;
|
|
}
|
|
if (!RegisterCustomIntCast<T, std::complex<long double>>(NPY_CLONGDOUBLE)) {
|
|
return false;
|
|
}
|
|
|
|
// Safe casts from T to other types
|
|
// TODO(phawkins): add integer types
|
|
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
|
|
NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_DOUBLE,
|
|
NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_LONGDOUBLE,
|
|
NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CFLOAT,
|
|
NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CDOUBLE,
|
|
NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CLONGDOUBLE,
|
|
NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
|
|
// Safe casts to T from other types
|
|
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL),
|
|
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
|
|
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
|
|
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
template <typename T>
|
|
bool RegisterInt4UFuncs(PyObject* numpy) {
|
|
bool ok = RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
|
|
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(
|
|
numpy, "subtract") &&
|
|
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(
|
|
numpy, "multiply") &&
|
|
RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
|
|
numpy, "floor_divide") &&
|
|
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(
|
|
numpy, "remainder");
|
|
|
|
return ok;
|
|
}
|
|
|
|
template <typename T>
|
|
bool RegisterInt4Dtype(PyObject* numpy) {
|
|
Safe_PyObjectPtr name =
|
|
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
|
|
Safe_PyObjectPtr qualname =
|
|
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
|
|
|
|
PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
|
|
PyType_Type.tp_alloc(&PyType_Type, 0));
|
|
if (!heap_type) {
|
|
return false;
|
|
}
|
|
// Caution: we must not call any functions that might invoke the GC until
|
|
// PyType_Ready() is called. Otherwise the GC might see a half-constructed
|
|
// type object.
|
|
heap_type->ht_name = name.release();
|
|
heap_type->ht_qualname = qualname.release();
|
|
PyTypeObject* type = &heap_type->ht_type;
|
|
type->tp_name = TypeDescriptor<T>::kTypeName;
|
|
type->tp_basicsize = sizeof(PyInt4<T>);
|
|
type->tp_flags =
|
|
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
|
type->tp_base = &PyGenericArrType_Type;
|
|
type->tp_new = PyInt4_tp_new<T>;
|
|
type->tp_repr = PyInt4_Repr<T>;
|
|
type->tp_hash = PyInt4_Hash<T>;
|
|
type->tp_str = PyInt4_Str<T>;
|
|
type->tp_doc = const_cast<char*>(TypeDescriptor<T>::kTpDoc);
|
|
type->tp_richcompare = PyInt4_RichCompare<T>;
|
|
type->tp_as_number = &Int4TypeDescriptor<T>::number_methods;
|
|
if (PyType_Ready(type) < 0) {
|
|
return false;
|
|
}
|
|
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(type);
|
|
|
|
Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes"));
|
|
if (!module) {
|
|
return false;
|
|
}
|
|
if (PyObject_SetAttrString(TypeDescriptor<T>::type_ptr, "__module__",
|
|
module.get()) < 0) {
|
|
return false;
|
|
}
|
|
|
|
// Initializes the NumPy descriptor.
|
|
PyArray_ArrFuncs& arr_funcs = Int4TypeDescriptor<T>::arr_funcs;
|
|
PyArray_InitArrFuncs(&arr_funcs);
|
|
arr_funcs.getitem = NPyInt4_GetItem<T>;
|
|
arr_funcs.setitem = NPyInt4_SetItem<T>;
|
|
arr_funcs.compare = NPyInt4_Compare<T>;
|
|
arr_funcs.copyswapn = NPyInt4_CopySwapN<T>;
|
|
arr_funcs.copyswap = NPyInt4_CopySwap<T>;
|
|
arr_funcs.nonzero = NPyInt4_NonZero<T>;
|
|
arr_funcs.fill = NPyInt4_Fill<T>;
|
|
arr_funcs.dotfunc = NPyInt4_DotFunc<T>;
|
|
arr_funcs.compare = NPyInt4_CompareFunc<T>;
|
|
arr_funcs.argmax = NPyInt4_ArgMaxFunc<T>;
|
|
arr_funcs.argmin = NPyInt4_ArgMinFunc<T>;
|
|
|
|
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
|
|
Py_TYPE(&Int4TypeDescriptor<T>::npy_descr) = &PyArrayDescr_Type;
|
|
#else
|
|
Py_SET_TYPE(&Int4TypeDescriptor<T>::npy_descr, &PyArrayDescr_Type);
|
|
#endif
|
|
TypeDescriptor<T>::npy_descr.typeobj = type;
|
|
|
|
TypeDescriptor<T>::npy_type =
|
|
PyArray_RegisterDataType(&Int4TypeDescriptor<T>::npy_descr);
|
|
if (TypeDescriptor<T>::Dtype() < 0) {
|
|
return false;
|
|
}
|
|
|
|
Safe_PyObjectPtr typeDict_obj =
|
|
make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
|
|
if (!typeDict_obj) return false;
|
|
// Add the type object to `numpy.typeDict`: that makes
|
|
// `numpy.dtype(type_name)` work.
|
|
if (PyDict_SetItemString(typeDict_obj.get(), TypeDescriptor<T>::kTypeName,
|
|
TypeDescriptor<T>::type_ptr) < 0) {
|
|
return false;
|
|
}
|
|
|
|
// Support dtype(type_name)
|
|
if (PyObject_SetAttrString(
|
|
TypeDescriptor<T>::type_ptr, "dtype",
|
|
reinterpret_cast<PyObject*>(&Int4TypeDescriptor<T>::npy_descr)) < 0) {
|
|
return false;
|
|
}
|
|
|
|
return RegisterInt4Casts<T>() && RegisterInt4UFuncs<T>(numpy);
|
|
}
|
|
|
|
} // namespace ml_dtypes
|
|
|
|
#endif // ML_DTYPES_INT4_H_
|