# Copyright 2022 The ml_dtypes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
import ml_dtypes
|
|
import numpy as np
|
|
|
|
# Custom floating-point types whose finfo is verified by FinfoTest.
ALL_DTYPES = [
    ml_dtypes.bfloat16,
    ml_dtypes.float8_e4m3b11fnuz,
    ml_dtypes.float8_e4m3fn,
    ml_dtypes.float8_e4m3fnuz,
    ml_dtypes.float8_e5m2,
    ml_dtypes.float8_e5m2fnuz,
]

# Members of ALL_DTYPES that have no infinity. For these types, values that
# overflow are expected to become NaN instead of +inf (this is what
# assert_infinite in FinfoTest.testFInfo checks).
DTYPES_WITH_NO_INFINITY = [
    ml_dtypes.float8_e4m3b11fnuz,
    ml_dtypes.float8_e4m3fn,
    ml_dtypes.float8_e4m3fnuz,
    ml_dtypes.float8_e5m2fnuz,
]

# Unsigned integer type for each storage width (in bits) used above. It is
# used to reinterpret a value's raw bit pattern as an integer.
UINT_TYPES = {
    np.iinfo(uint_type).bits: uint_type
    for uint_type in (np.uint8, np.uint16)
}
|
|
|
|
|
|
class FinfoTest(parameterized.TestCase):
  """Tests ``ml_dtypes.finfo`` for every dtype in ``ALL_DTYPES``."""

  def assertNanEqual(self, x, y):
    """Asserts that ``x`` equals ``y``, treating NaN as equal to NaN.

    Args:
      x: First value to compare.
      y: Second value to compare.
    """
    if np.isnan(x) and np.isnan(y):
      # NaN != NaN, so two NaNs would otherwise fail assertEqual.
      return
    self.assertEqual(x, y)

  @parameterized.named_parameters(
      {"testcase_name": f"_{dtype.__name__}", "dtype": np.dtype(dtype)}
      for dtype in ALL_DTYPES
  )
  def testFInfo(self, dtype):
    """Checks the attributes of ``ml_dtypes.finfo(dtype)`` for consistency.

    Args:
      dtype: ``np.dtype`` built from one of the types in ``ALL_DTYPES``.
    """
    info = ml_dtypes.finfo(dtype)

    # Looking finfo up by dtype name or by scalar type must return the very
    # same object. Use assertIs rather than a bare `assert`: bare asserts are
    # stripped when Python runs with -O, which would silently skip the check.
    self.assertIs(ml_dtypes.finfo(dtype.name), info)
    self.assertIs(ml_dtypes.finfo(dtype.type), info)

    _ = str(info)  # doesn't crash

    def make_val(val):
      """Returns ``val`` converted to an array of the dtype under test."""
      return np.array(val, dtype=dtype)

    def assert_representable(val):
      """Asserts ``val`` round-trips through the dtype unchanged."""
      self.assertEqual(make_val(val).item(), val)

    def assert_infinite(val):
      """Asserts ``val`` overflows: +inf, or NaN for types without inf."""
      val = make_val(val)
      if dtype in DTYPES_WITH_NO_INFINITY:
        self.assertTrue(np.isnan(val), f"expected NaN, got {val}")
      else:
        self.assertTrue(np.isposinf(val), f"expected inf, got {val}")

    def assert_zero(val):
      """Asserts ``val`` becomes zero when converted to the dtype."""
      self.assertEqual(make_val(val), make_val(0))

    self.assertEqual(np.array(0, dtype).dtype, dtype)
    self.assertIs(info.dtype, dtype)

    # The sign bit, exponent bits and mantissa bits together fill the storage.
    self.assertEqual(info.bits, np.array(0, dtype).itemsize * 8)
    self.assertEqual(info.nmant + info.nexp + 1, info.bits)

    assert_representable(info.tiny)

    # One spacing step beyond the largest finite value must overflow.
    assert_representable(info.max)
    assert_infinite(np.spacing(info.max))

    # The same holds beyond the most negative finite value. The step is negated
    # so that assert_infinite, which checks for +inf, can be reused.
    assert_representable(info.min)
    assert_infinite(-np.spacing(info.min))

    # 2**(maxexp - 1) is still finite, but 2**maxexp overflows.
    assert_representable(2.0 ** (info.maxexp - 1))
    assert_infinite(2.0**info.maxexp)

    # Half of the smallest subnormal must round to zero.
    assert_representable(info.smallest_subnormal)
    assert_zero(info.smallest_subnormal * 0.5)
    self.assertEqual(info.tiny, info.smallest_normal)

    # Identities according to the documentation:
    np.testing.assert_allclose(info.resolution, make_val(10**-info.precision))
    self.assertEqual(info.epsneg, make_val(2**info.negep))
    self.assertEqual(info.eps, make_val(2**info.machep))
    self.assertEqual(info.iexp, info.nexp)

    # Check that minexp is consistent with nmant. Read as an unsigned integer,
    # the bit pattern of 2**minexp must equal 2**nmant: only the lowest
    # exponent bit is set and all mantissa bits are clear.
    self.assertEqual(
        make_val(2**info.minexp).view(UINT_TYPES[info.bits]),
        2**info.nmant,
    )
|
|
|
|
|
|
def _main():
  """Hands control to absltest, which runs this module's test cases."""
  absltest.main()


if __name__ == "__main__":
  _main()
|