Intelegentny_Pszczelarz/.venv/Lib/site-packages/ml_dtypes/_src/dtypes.cc
2023-06-19 00:49:18 +02:00

370 lines
14 KiB
C++

/* Copyright 2017 The ml_dtypes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Enable cmath defines on Windows
#define _USE_MATH_DEFINES
// Must be included first
// clang-format off
#include "_src/numpy.h" //NOLINT
// clang-format on
#include <array> // NOLINT
#include <cmath> // NOLINT
#include <limits> // NOLINT
#include <locale> // NOLINT
// Place `<locale>` before <Python.h> to avoid a build failure in macOS.
#include <Python.h>
#include "Eigen/Core"
#include "_src/custom_float.h"
#include "_src/int4.h"
#include "include/float8.h"
namespace ml_dtypes {
using bfloat16 = Eigen::bfloat16;
// Compile-time description of the bfloat16 scalar type: its names,
// documentation text and the NumPy descriptor characters it is registered
// with.
template <>
struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
  using T = bfloat16;

  // Classification flags.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "bfloat16";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.bfloat16";
  static constexpr const char* kTpDoc = "bfloat16 floating-point values";

  // NumPy treats two types that share both kind and size as equal, yet
  // float16 != bfloat16, so bfloat16 cannot use kind "f". The price of
  // choosing a different kind is that NumPy scalar promotion does not work
  // with bfloat16 values.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'E';
};
// Compile-time description of the float8_e4m3b11fnuz scalar type: its names,
// documentation text and the NumPy descriptor characters it is registered
// with.
template <>
struct TypeDescriptor<float8_e4m3b11fnuz>
    : CustomFloatType<float8_e4m3b11fnuz> {
  using T = float8_e4m3b11fnuz;

  // Classification flags.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "float8_e4m3b11fnuz";
  static constexpr const char* kQualifiedTypeName =
      "ml_dtypes.float8_e4m3b11fnuz";
  static constexpr const char* kTpDoc =
      "float8_e4m3b11fnuz floating-point values";

  // NumPy treats two types that share both kind and size as equal, and
  // several 1-byte floating point types are expected to coexist, so
  // float8_e4m3b11fnuz cannot use kind "f". The price of choosing a different
  // kind is that NumPy scalar promotion does not work with
  // float8_e4m3b11fnuz values.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'L';
};
// Compile-time description of the float8_e4m3fn scalar type: its names,
// documentation text and the NumPy descriptor characters it is registered
// with.
template <>
struct TypeDescriptor<float8_e4m3fn> : CustomFloatType<float8_e4m3fn> {
  using T = float8_e4m3fn;

  // Classification flags.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "float8_e4m3fn";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3fn";
  static constexpr const char* kTpDoc = "float8_e4m3fn floating-point values";

  // NumPy treats two types that share both kind and size as equal, so
  // float8_e4m3fn needs a kind of its own. 'V' is used to mirror the
  // bfloat16 vs float16 situation. The price is that NumPy scalar promotion
  // does not work with float8 values.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = '4';
};
// Compile-time description of the float8_e4m3fnuz scalar type: its names,
// documentation text and the NumPy descriptor characters it is registered
// with.
template <>
struct TypeDescriptor<float8_e4m3fnuz> : CustomFloatType<float8_e4m3fnuz> {
  using T = float8_e4m3fnuz;

  // Classification flags.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "float8_e4m3fnuz";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3fnuz";
  static constexpr const char* kTpDoc = "float8_e4m3fnuz floating-point values";

  // NumPy descriptor characters.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'G';
};
// Compile-time description of the float8_e5m2 scalar type: its names,
// documentation text and the NumPy descriptor characters it is registered
// with.
template <>
struct TypeDescriptor<float8_e5m2> : CustomFloatType<float8_e5m2> {
  using T = float8_e5m2;

  // Classification flags.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "float8_e5m2";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e5m2";
  static constexpr const char* kTpDoc = "float8_e5m2 floating-point values";

  // e5m2 is IEEE-754 compliant, so it is treated as the natural "float"
  // type and keeps kind 'f'.
  static constexpr char kNpyDescrKind = 'f';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = '5';
};
// Compile-time description of the float8_e5m2fnuz scalar type: its names,
// documentation text and the NumPy descriptor characters it is registered
// with.
template <>
struct TypeDescriptor<float8_e5m2fnuz> : CustomFloatType<float8_e5m2fnuz> {
  using T = float8_e5m2fnuz;

  // Classification flags.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "float8_e5m2fnuz";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e5m2fnuz";
  static constexpr const char* kTpDoc = "float8_e5m2fnuz floating-point values";

  // NumPy descriptor characters.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'C';
};
// Compile-time description of the int4 scalar type: its names, documentation
// text and the NumPy descriptor characters it is registered with.
template <>
struct TypeDescriptor<int4> : Int4TypeDescriptor<int4> {
  using T = int4;

  // Classification flags.
  static constexpr bool is_integral = true;
  static constexpr bool is_floating = false;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "int4";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.int4";
  static constexpr const char* kTpDoc = "int4 integer values";

  // NumPy descriptor characters.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'a';
};
// Compile-time description of the uint4 scalar type: its names, documentation
// text and the NumPy descriptor characters it is registered with.
template <>
struct TypeDescriptor<uint4> : Int4TypeDescriptor<uint4> {
  using T = uint4;

  // Classification flags.
  static constexpr bool is_integral = true;
  static constexpr bool is_floating = false;

  // Short name, module-qualified name and documentation text.
  static constexpr const char* kTypeName = "uint4";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.uint4";
  static constexpr const char* kTpDoc = "uint4 integer values";

  // NumPy descriptor characters.
  static constexpr char kNpyDescrKind = 'V';
  static constexpr char kNpyDescrByteorder = '=';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'A';
};
namespace {
// Performs a NumPy array cast from type 'From' to 'To' via float.
template <typename From, typename To>
void FloatPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
void* toarr) {
const auto* from = static_cast<From*>(from_void);
auto* to = static_cast<To*>(to_void);
for (npy_intp i = 0; i < n; ++i) {
to[i] = static_cast<To>(static_cast<float>(from[i]));
}
}
// Registers float-mediated NumPy casts in both directions between `Type1`
// and `Type2` (Type1 -> Type2 first, then Type2 -> Type1).
//
// Returns true when both registrations succeed. Returns false as soon as a
// descriptor lookup or a registration fails; in that case the second
// direction may not have been attempted.
template <typename Type1, typename Type2>
bool RegisterTwoWayCustomCast() {
  const int nptype1 = TypeDescriptor<Type1>::npy_type;
  const int nptype2 = TypeDescriptor<Type2>::npy_type;

  // PyArray_DescrFromType yields a null descriptor for a type number NumPy
  // does not know (e.g. a type that was never registered). Handing null to
  // PyArray_RegisterCastFunc would dereference it, so fail cleanly instead.
  PyArray_Descr* descr1 = PyArray_DescrFromType(nptype1);
  if (descr1 == nullptr ||
      PyArray_RegisterCastFunc(descr1, nptype2, FloatPyCast<Type1, Type2>) <
          0) {
    return false;
  }

  PyArray_Descr* descr2 = PyArray_DescrFromType(nptype2);
  if (descr2 == nullptr ||
      PyArray_RegisterCastFunc(descr2, nptype1, FloatPyCast<Type2, Type1>) <
          0) {
    return false;
  }
  return true;
}
} // namespace
// Initializes the module: imports NumPy, registers each custom dtype
// (bfloat16, the float8 variants, int4 and uint4) with it, and then registers
// the casts between those dtypes.
//
// Returns true on success and false on failure. The order of the
// registration calls below is kept exactly as written.
bool Initialize() {
ml_dtypes::ImportNumpy();
// NOTE(review): import_umath1 is a NumPy macro; presumably it makes this
// function return `false` when the umath import fails - confirm.
import_umath1(false);
// Import the `numpy` module object that is handed to the registration
// helpers below.
Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
if (!numpy_str) {
return false;
}
Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
if (!numpy) {
return false;
}
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
return false;
}
// For the float8 types the registration helper also receives an out-flag.
// The flags are consulted further down to skip casts that are assumed to
// exist already.
bool float8_e4m3b11fnuz_already_registered;
if (!RegisterFloatDtype<float8_e4m3b11fnuz>(
numpy.get(), &float8_e4m3b11fnuz_already_registered)) {
return false;
}
bool float8_e4m3fn_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e4m3fn>(
numpy.get(), &float8_e4m3fn_already_registered)) {
return false;
}
bool float8_e4m3fnuz_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e4m3fnuz>(
numpy.get(), &float8_e4m3fnuz_already_registered)) {
return false;
}
bool float8_e5m2_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e5m2>(
numpy.get(), &float8_e5m2_already_registered)) {
return false;
}
bool float8_e5m2fnuz_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e5m2fnuz>(
numpy.get(), &float8_e5m2fnuz_already_registered)) {
return false;
}
if (!ml_dtypes::RegisterInt4Dtype<int4>(numpy.get())) {
return false;
}
if (!ml_dtypes::RegisterInt4Dtype<uint4>(numpy.get())) {
return false;
}
// Casts between bfloat16 and float8_e4m3b11fnuz. Only perform the cast if
// float8_e4m3b11fnuz hasn't been previously registered, presumably by a
// different library. In this case, we assume the cast has also already been
// registered, and registering it again can cause segfaults due to accessing
// an uninitialized type descriptor in this library.
if (!float8_e4m3b11fnuz_already_registered &&
!RegisterCustomFloatCast<float8_e4m3b11fnuz, bfloat16>()) {
return false;
}
// The next two casts are only registered when neither of the two types
// involved was flagged as already registered.
if (!float8_e4m3fnuz_already_registered &&
!float8_e5m2fnuz_already_registered &&
!RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2fnuz>()) {
return false;
}
if (!float8_e4m3fn_already_registered && !float8_e5m2_already_registered &&
!RegisterCustomFloatCast<float8_e4m3fn, float8_e5m2>()) {
return false;
}
// From here on a failure no longer returns early: every remaining
// registration is attempted and the results are AND-ed together.
bool success = true;
// Continue trying to register casts, just in case some types are not
// registered (i.e. float8_e4m3b11fnuz)
success &= RegisterTwoWayCustomCast<float8_e4m3b11fnuz, float8_e4m3fn>();
success &= RegisterTwoWayCustomCast<float8_e4m3b11fnuz, float8_e5m2>();
success &= RegisterTwoWayCustomCast<bfloat16, float8_e4m3fn>();
success &= RegisterTwoWayCustomCast<bfloat16, float8_e5m2>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, bfloat16>();
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, bfloat16>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e4m3b11fnuz>();
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e4m3b11fnuz>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e4m3fn>();
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e4m3fn>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2>();
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e5m2>();
return success;
}
// Definition record for the `_custom_floats` extension module. Only the
// header, the name and the two fields spelled out below are listed; all
// remaining fields are left zero-initialized.
static PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    /*m_name=*/"_custom_floats",
    /*m_doc=*/nullptr,
    /*m_size=*/0,
};
// EXPORT_SYMBOL marks a symbol as exported: `__declspec(dllexport)` when
// WIN32 or _WIN32 is defined, default visibility everywhere else.
// TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+.
// Just use PyMODINIT_FUNC after dropping Python 3.8 support.
#if !defined(WIN32) && !defined(_WIN32)
#define EXPORT_SYMBOL __attribute__((visibility("default")))
#else
#define EXPORT_SYMBOL __declspec(dllexport)
#endif
// Entry point of the `_custom_floats` extension module. Creates the module
// object, runs Initialize(), and then publishes every custom type on the
// module as an attribute. Returns the new module, or nullptr on failure.
extern "C" EXPORT_SYMBOL PyObject* PyInit__custom_floats() {
  Safe_PyObjectPtr module = make_safe(PyModule_Create(&module_def));
  if (!module) {
    return nullptr;
  }

  if (!Initialize()) {
    // Keep any error already raised during initialization; only supply a
    // generic one when none is pending.
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_RuntimeError, "cannot load _custom_floats module.");
    }
    return nullptr;
  }

  // Attribute name / type object pairs. The table is built after
  // Initialize() has run and is walked in order.
  struct ExportedType {
    const char* attr_name;
    PyObject* type_object;
  };
  const ExportedType exported_types[] = {
      {"float8_e4m3b11fnuz",
       reinterpret_cast<PyObject*>(
           TypeDescriptor<float8_e4m3b11fnuz>::type_ptr)},
      {"float8_e4m3fn",
       reinterpret_cast<PyObject*>(TypeDescriptor<float8_e4m3fn>::type_ptr)},
      {"float8_e4m3fnuz",
       reinterpret_cast<PyObject*>(TypeDescriptor<float8_e4m3fnuz>::type_ptr)},
      {"float8_e5m2",
       reinterpret_cast<PyObject*>(TypeDescriptor<float8_e5m2>::type_ptr)},
      {"float8_e5m2fnuz",
       reinterpret_cast<PyObject*>(TypeDescriptor<float8_e5m2fnuz>::type_ptr)},
      {"bfloat16",
       reinterpret_cast<PyObject*>(TypeDescriptor<bfloat16>::type_ptr)},
      {"int4", reinterpret_cast<PyObject*>(TypeDescriptor<int4>::type_ptr)},
      {"uint4", reinterpret_cast<PyObject*>(TypeDescriptor<uint4>::type_ptr)},
  };

  for (const ExportedType& entry : exported_types) {
    if (PyObject_SetAttrString(module.get(), entry.attr_name,
                               entry.type_object) < 0) {
      return nullptr;
    }
  }
  return module.release();
}
} // namespace ml_dtypes