# Copyright 2022 The ml_dtypes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.2.0'  # Keep in sync with pyproject.toml:version
# Public names exported by a star-import of this module.
__all__ = [
    '__version__',
    'bfloat16',
    'finfo',
    'float8_e4m3b11fnuz',
    'float8_e4m3fn',
    'float8_e4m3fnuz',
    'float8_e5m2',
    'float8_e5m2fnuz',
    'iinfo',
    'int4',
    'uint4',
]

from typing import Type

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz
from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo

import numpy as np

# Bare annotations (no assignment): they declare a static type for each name
# without binding or rebinding any value in this module.
bfloat16: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
float8_e4m3fn: Type[np.generic]
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
int4: Type[np.generic]
uint4: Type[np.generic]

# `np` and `Type` were only needed for the annotations above; remove them so
# they are not left behind as attributes of this module.
del np, Type


# TODO(jakevdp) remove this deprecated name.
def __getattr__(name: str):  # pylint: disable=invalid-name
    """Resolve a module attribute that normal lookup did not find.

    Keeps the deprecated alias ``float8_e4m3b11`` working: looking it up emits
    a ``DeprecationWarning`` and yields ``float8_e4m3b11fnuz``.

    Args:
        name: The attribute name being looked up.

    Returns:
        ``float8_e4m3b11fnuz`` when ``name`` is ``'float8_e4m3b11'``.

    Raises:
        AttributeError: For every other name.
    """
    if name == 'float8_e4m3b11':
        # Imported here rather than at module top: it is only needed on the
        # deprecated path.
        import warnings  # pylint: disable=g-import-not-at-top

        warnings.warn(
            (
                'ml_dtypes.float8_e4m3b11 is deprecated. Use'
                ' ml_dtypes.float8_e4m3b11fnuz'
            ),
            DeprecationWarning,
            # Point the warning at the code that accessed the alias, not at
            # this function.
            stacklevel=2,
        )
        return float8_e4m3b11fnuz
    raise AttributeError(f'cannot import name {name!r} from {__name__!r}')