Intelegentny_Pszczelarz/.venv/Lib/site-packages/ml_dtypes/tests/finfo_test.py
2023-06-19 00:49:18 +02:00

113 lines
3.1 KiB
Python

# Copyright 2022 The ml_dtypes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
import numpy as np
# Every custom floating-point scalar type whose finfo is checked below.
ALL_DTYPES = [
    ml_dtypes.bfloat16,
    ml_dtypes.float8_e4m3b11fnuz,
    ml_dtypes.float8_e4m3fn,
    ml_dtypes.float8_e4m3fnuz,
    ml_dtypes.float8_e5m2,
    ml_dtypes.float8_e5m2fnuz,
]

# The subset of ALL_DTYPES (in the same order) for which the finfo test
# expects an overflowing value to become NaN rather than +inf: everything
# except bfloat16 and float8_e5m2.
DTYPES_WITH_NO_INFINITY = [
    scalar_type
    for scalar_type in ALL_DTYPES
    if scalar_type not in (ml_dtypes.bfloat16, ml_dtypes.float8_e5m2)
]

# Maps a storage width in bits to the unsigned integer type of that width,
# used to reinterpret a float's raw bit pattern: {8: np.uint8, 16: np.uint16}.
UINT_TYPES = {
    np.dtype(uint_type).itemsize * 8: uint_type
    for uint_type in (np.uint8, np.uint16)
}
class FinfoTest(parameterized.TestCase):
    """Checks ``ml_dtypes.finfo`` against the actual behaviour of each dtype."""

    def assertNanEqual(self, x, y):
        """Assert ``x == y``, treating two NaN values as equal.

        Not called by testFInfo; kept as a NaN-aware comparison helper.
        """
        if np.isnan(x) and np.isnan(y):
            return
        self.assertEqual(x, y)

    @parameterized.named_parameters(
        {"testcase_name": f"_{dtype.__name__}", "dtype": np.dtype(dtype)}
        for dtype in ALL_DTYPES
    )
    def testFInfo(self, dtype):
        """Validate the ``finfo`` fields for ``dtype`` by doing real arithmetic.

        Args:
          dtype: an ``np.dtype`` built from one of the types in ALL_DTYPES.
        """
        info = ml_dtypes.finfo(dtype)

        # Looking finfo up by dtype, by name and by scalar type must return the
        # very same object (identity, not just equality). Use assertIs rather
        # than a bare `assert`: bare asserts are stripped under `python -O` and
        # report no useful message on failure.
        self.assertIs(ml_dtypes.finfo(dtype.name), info)
        self.assertIs(ml_dtypes.finfo(dtype.type), info)

        _ = str(info)  # Only checks that formatting does not raise.

        def make_val(val):
            """Convert ``val`` to a 0-d array of the dtype under test."""
            return np.array(val, dtype=dtype)

        def assert_representable(val):
            """Assert ``val`` round-trips through ``dtype`` unchanged."""
            self.assertEqual(make_val(val).item(), val)

        def assert_infinite(val):
            """Assert ``val`` cast to ``dtype`` is +inf.

            Types listed in DTYPES_WITH_NO_INFINITY are expected to produce NaN
            instead.
            """
            val = make_val(val)
            if dtype in DTYPES_WITH_NO_INFINITY:
                self.assertTrue(np.isnan(val), f"expected NaN, got {val}")
            else:
                self.assertTrue(np.isposinf(val), f"expected inf, got {val}")

        def assert_zero(val):
            """Assert ``val`` rounds to zero when cast to ``dtype``."""
            self.assertEqual(make_val(val), make_val(0))

        self.assertEqual(np.array(0, dtype).dtype, dtype)
        self.assertIs(info.dtype, dtype)

        # Storage layout: one sign bit plus exponent and mantissa bits.
        self.assertEqual(info.bits, np.array(0, dtype).itemsize * 8)
        self.assertEqual(info.nmant + info.nexp + 1, info.bits)

        # The extreme finite values are exact; one spacing beyond them is not.
        assert_representable(info.tiny)
        assert_representable(info.max)
        assert_infinite(np.spacing(info.max))

        assert_representable(info.min)
        # NOTE: negated because np.spacing of a negative value is expected to
        # be negative, while assert_infinite checks for +inf.
        assert_infinite(-np.spacing(info.min))

        # 2**(maxexp - 1) still fits; 2**maxexp must overflow.
        assert_representable(2.0 ** (info.maxexp - 1))
        assert_infinite(2.0**info.maxexp)

        # Half of the smallest subnormal must round down to zero.
        assert_representable(info.smallest_subnormal)
        assert_zero(info.smallest_subnormal * 0.5)
        self.assertEqual(info.tiny, info.smallest_normal)

        # Identities according to the documentation:
        np.testing.assert_allclose(
            info.resolution, make_val(10**-info.precision)
        )
        self.assertEqual(info.epsneg, make_val(2**info.negep))
        self.assertEqual(info.eps, make_val(2**info.machep))
        self.assertEqual(info.iexp, info.nexp)

        # Check that minexp is consistent with nmant: the raw bit pattern of
        # 2**minexp must equal 1 << nmant, i.e. the lowest set bit sits just
        # above the nmant mantissa bits.
        self.assertEqual(
            make_val(2**info.minexp).view(UINT_TYPES[info.bits]),
            2**info.nmant,
        )
def _main():
    """Run the test cases defined in this module via absltest."""
    absltest.main()


if __name__ == "__main__":
    _main()