1562 lines
52 KiB
Python
1562 lines
52 KiB
Python
from __future__ import division, absolute_import, print_function
|
|
|
|
import warnings
|
|
import sys
|
|
import os
|
|
import itertools
|
|
import textwrap
|
|
import pytest
|
|
import weakref
|
|
|
|
import numpy as np
|
|
from numpy.testing import (
|
|
assert_equal, assert_array_equal, assert_almost_equal,
|
|
assert_array_almost_equal, assert_array_less, build_err_msg, raises,
|
|
assert_raises, assert_warns, assert_no_warnings, assert_allclose,
|
|
assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
|
|
clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
|
|
tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
|
|
)
|
|
from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
|
|
|
|
|
|
class _GenericTest(object):
    """Equality checks shared by several ``assert_*`` test classes.

    Concrete subclasses store the assertion under test in
    ``self._assert_func``; every check below goes through it.
    """

    def _test_equal(self, a, b):
        """Fail unless ``self._assert_func`` accepts ``a`` and ``b``."""
        assertion = self._assert_func
        assertion(a, b)

    def _test_not_equal(self, a, b):
        """Fail unless ``self._assert_func`` raises AssertionError."""
        assertion = self._assert_func
        with assert_raises(AssertionError):
            assertion(a, b)

    def test_array_rank1_eq(self):
        """Two identical rank-1 arrays are reported equal."""
        left, right = np.array([1, 2]), np.array([1, 2])
        self._test_equal(left, right)

    def test_array_rank1_noteq(self):
        """Two rank-1 arrays differing in one element are reported unequal."""
        left, right = np.array([1, 2]), np.array([2, 2])
        self._test_not_equal(left, right)

    def test_array_rank2_eq(self):
        """Two identical rank-2 arrays are reported equal."""
        left = np.array([[1, 2], [3, 4]])
        right = np.array([[1, 2], [3, 4]])
        self._test_equal(left, right)

    def test_array_diffshape(self):
        """Arrays whose shapes differ are reported unequal."""
        row = np.array([1, 2])
        grid = np.array([[1, 2], [1, 2]])
        self._test_not_equal(row, grid)

    def test_objarray(self):
        """An object array compares equal to a matching scalar."""
        ones = np.array([1, 1], dtype=object)
        self._test_equal(ones, 1)

    def test_array_likes(self):
        """A list and a tuple holding the same items compare equal."""
        self._test_equal([1, 2, 3], (1, 2, 3))
|
|
|
|
|
|
class TestArrayEqual(_GenericTest):
    """Tests for ``assert_array_equal`` plus the shared ``_GenericTest`` checks."""

    # Boolean and numeric dtype type codes exercised by the generic tests.
    _NUMERIC_TYPECODES = '?bhilqpBHILQPfdgFDG'

    def setup_method(self):
        # pytest >= 8 no longer runs nose-style ``setup``; delegate to it so
        # subclasses that override ``setup`` still select their function.
        self.setup()

    def setup(self):
        self._assert_func = assert_array_equal

    def _check_filled(self, shape, dtype):
        """Ones-filled copies compare equal; a zero-filled copy does not."""
        a = np.empty(shape, dtype)
        a.fill(1)
        b = a.copy()
        c = a.copy()
        c.fill(0)
        self._test_equal(a, b)
        self._test_not_equal(c, b)

    def test_generic_rank1(self):
        """Test rank 1 array for all dtypes."""
        # Numeric and boolean types, then 1-character byte/unicode strings.
        for t in list(self._NUMERIC_TYPECODES) + ['S1', 'U1']:
            self._check_filled(2, t)

    def test_generic_rank3(self):
        """Test rank 3 array for all dtypes."""
        # Numeric and boolean types, then 1-character byte/unicode strings.
        for t in list(self._NUMERIC_TYPECODES) + ['S1', 'U1']:
            self._check_filled((4, 2, 3), t)

    def test_nan_array(self):
        """Test arrays with nan values in them."""
        a = np.array([1, 2, np.nan])
        b = np.array([1, 2, np.nan])

        self._test_equal(a, b)

        c = np.array([1, 2, 3])
        self._test_not_equal(c, b)

    def test_string_arrays(self):
        """Test equal string arrays match and differing ones do not."""
        a = np.array(['floupi', 'floupa'])
        b = np.array(['floupi', 'floupa'])

        self._test_equal(a, b)

        c = np.array(['floupipi', 'floupa'])

        self._test_not_equal(c, b)

    def test_recarrays(self):
        """Test record arrays."""
        a = np.empty(2, [('floupi', float), ('floupa', float)])
        a['floupi'] = [1, 2]
        a['floupa'] = [1, 2]
        b = a.copy()

        self._test_equal(a, b)

        c = np.empty(2, [('floupipi', float), ('floupa', float)])
        c['floupipi'] = a['floupi'].copy()
        c['floupa'] = a['floupa'].copy()

        # Comparing records with different field names is expected to emit
        # exactly one FutureWarning ("elementwise == ...") besides failing.
        with suppress_warnings() as sup:
            recorded = sup.record(FutureWarning, message="elementwise == ")
            self._test_not_equal(c, b)
            assert_equal(len(recorded), 1)

    def test_masked_nan_inf(self):
        # Regression test for gh-11121
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
        b = np.array([3., np.nan, 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
        b = np.array([np.inf, 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)

    def test_subclass_that_overrides_eq(self):
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return bool(np.equal(self, other).all())

            def __ne__(self, other):
                return not self == other

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # MyArray.__eq__ returns a plain bool rather than an array.  (The
        # previous ``assert_(type(a == a), bool)`` passed ``bool`` as the
        # message argument and therefore could never fail.)
        assert_(type(a == a) is bool)
        assert_(a == a)
        assert_(a != b)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

    @pytest.mark.skipif(
        not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
    def test_subclass_that_does_not_implement_npall(self):
        class MyArray(np.ndarray):
            def __array_function__(self, *args, **kwargs):
                return NotImplemented

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        with assert_raises(TypeError):
            np.all(a)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)
|
|
|
|
|
|
class TestBuildErrorMessage(object):
    """Tests for the text produced by ``build_err_msg``."""

    def test_build_err_msg_defaults(self):
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg)
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
             '2.00003, 3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_no_verbose(self):
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, verbose=False)
        b = '\nItems are not equal: There is a mismatch'
        assert_equal(a, b)

    def test_build_err_msg_custom_names(self):
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
        b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
             '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
             '3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_custom_precision(self):
        x = np.array([1.000000001, 2.00002, 3.00003])
        y = np.array([1.000000002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, precision=10)
        # numpy's array repr pads shorter entries with trailing spaces so
        # every element is as wide as the longest one ("1.000000001").
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.000000001, 2.00002    , 3.00003    ])\n DESIRED: array(['
             '1.000000002, 2.00003    , 3.00004    ])')
        assert_equal(a, b)
|
|
|
|
|
|
class TestEqual(TestArrayEqual):
    """Run the ``TestArrayEqual`` suite against ``assert_equal`` and add
    checks for scalar items (nan, inf, NaT, complex, signed zero)."""

    def setup_method(self):
        # pytest >= 8 no longer runs nose-style ``setup``; delegate to it.
        self.setup()

    def setup(self):
        self._assert_func = assert_equal

    def test_nan_items(self):
        self._assert_func(np.nan, np.nan)
        self._assert_func([np.nan], [np.nan])
        self._test_not_equal(np.nan, [np.nan])
        self._test_not_equal(np.nan, 1)

    def test_inf_items(self):
        self._assert_func(np.inf, np.inf)
        self._assert_func([np.inf], [np.inf])
        self._test_not_equal(np.inf, [np.inf])

    def test_datetime(self):
        self._test_equal(
            np.datetime64("2017-01-01", "s"),
            np.datetime64("2017-01-01", "s")
        )
        self._test_equal(
            np.datetime64("2017-01-01", "s"),
            np.datetime64("2017-01-01", "m")
        )

        # gh-10081
        self._test_not_equal(
            np.datetime64("2017-01-01", "s"),
            np.datetime64("2017-01-02", "s")
        )
        self._test_not_equal(
            np.datetime64("2017-01-01", "s"),
            np.datetime64("2017-01-02", "m")
        )

    def test_nat_items(self):
        # NaT ("not a time") datetime64 values: unitless, seconds, nanoseconds
        nadt_no_unit = np.datetime64("NaT")
        nadt_s = np.datetime64("NaT", "s")
        nadt_ns = np.datetime64("NaT", "ns")
        # NaT timedelta64 values with the same units
        natd_no_unit = np.timedelta64("NaT")
        natd_s = np.timedelta64("NaT", "s")
        natd_ns = np.timedelta64("NaT", "ns")

        dts = [nadt_no_unit, nadt_s, nadt_ns]
        tds = [natd_no_unit, natd_s, natd_ns]
        # Same-kind NaTs compare equal across units; a one-element list
        # does not compare equal to a bare scalar.
        for a, b in itertools.product(dts, dts):
            self._assert_func(a, b)
            self._assert_func([a], [b])
            self._test_not_equal([a], b)

        for a, b in itertools.product(tds, tds):
            self._assert_func(a, b)
            self._assert_func([a], [b])
            self._test_not_equal([a], b)

        # A timedelta NaT never equals a datetime NaT, nor any real value.
        for a, b in itertools.product(tds, dts):
            self._test_not_equal(a, b)
            self._test_not_equal(a, [b])
            self._test_not_equal([a], [b])
            self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
            self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
            self._test_not_equal([a], np.timedelta64(123, "s"))
            self._test_not_equal([b], np.timedelta64(123, "s"))

    def test_non_numeric(self):
        self._assert_func('ab', 'ab')
        self._test_not_equal('ab', 'abb')

    def test_complex_item(self):
        self._assert_func(complex(1, 2), complex(1, 2))
        self._assert_func(complex(1, np.nan), complex(1, np.nan))
        self._test_not_equal(complex(1, np.nan), complex(1, 2))
        self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
        self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))

    def test_negative_zero(self):
        # Plain signed-zero literals; the old ``np.PZERO``/``np.NZERO``
        # aliases held these same floats but were removed in NumPy 2.0.
        self._test_not_equal(0.0, -0.0)

    def test_complex(self):
        x = np.array([complex(1, 2), complex(1, np.nan)])
        y = np.array([complex(1, 2), complex(1, 2)])
        self._assert_func(x, x)
        self._test_not_equal(x, y)

    def test_error_message(self):
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(np.array([1, 2]), np.array([[1, 2]]))
        # Python 2 may print shapes with long-integer suffixes ("2L");
        # normalize them so a single reference string covers both.
        msg = str(exc_info.value).replace("shapes (2L,), (1L, 2L)",
                                          "shapes (2,), (1, 2)")
        msg_reference = textwrap.dedent("""\

        Arrays are not equal

        (shapes (2,), (1, 2) mismatch)
         x: array([1, 2])
         y: array([[1, 2]])""")
        assert_equal(msg, msg_reference)

    def test_object(self):
        #gh-12942
        import datetime
        a = np.array([datetime.datetime(2000, 1, 1),
                      datetime.datetime(2000, 1, 2)])
        self._test_not_equal(a, a[::-1])
|
|
|
|
|
|
class TestArrayAlmostEqual(_GenericTest):
    """Tests for ``assert_array_almost_equal`` (decimal-place comparison)."""

    def setup_method(self):
        # pytest >= 8 no longer runs nose-style ``setup``; delegate to it.
        self.setup()

    def setup(self):
        self._assert_func = assert_array_almost_equal

    def test_closeness(self):
        # Note that in the course of time we ended up with
        #     `abs(x - y) < 1.5 * 10**(-decimal)`
        # instead of the previously documented
        #     `abs(x - y) < 0.5 * 10**(-decimal)`
        # so this check serves to preserve the wrongness.

        # test scalars
        self._assert_func(1.499999, 0.0, decimal=0)
        assert_raises(AssertionError,
                      lambda: self._assert_func(1.5, 0.0, decimal=0))

        # test arrays
        self._assert_func([1.499999], [0.0], decimal=0)
        assert_raises(AssertionError,
                      lambda: self._assert_func([1.5], [0.0], decimal=0))

    def test_simple(self):
        x = np.array([1234.2222])
        y = np.array([1234.2223])

        self._assert_func(x, y, decimal=3)
        self._assert_func(x, y, decimal=4)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, decimal=5))

    def test_nan(self):
        anan = np.array([np.nan])
        aone = np.array([1])
        ainf = np.array([np.inf])
        self._assert_func(anan, anan)
        assert_raises(AssertionError,
                      lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError,
                      lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError,
                      lambda: self._assert_func(ainf, anan))

    def test_inf(self):
        a = np.array([[1., 2.], [3., 4.]])
        b = a.copy()
        a[0, 0] = np.inf
        assert_raises(AssertionError,
                      lambda: self._assert_func(a, b))
        b[0, 0] = -np.inf
        assert_raises(AssertionError,
                      lambda: self._assert_func(a, b))

    def test_subclass(self):
        a = np.array([[1., 2.], [3., 4.]])
        b = np.ma.masked_array([[1., 2.], [0., 4.]],
                               [[False, False], [True, False]])
        self._assert_func(a, b)
        self._assert_func(b, a)
        self._assert_func(b, b)

        # Test fully masked as well (see gh-11123).
        a = np.ma.MaskedArray(3.5, mask=True)
        b = np.array([3., 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.masked
        b = np.array([3., 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
        b = np.array([1., 2., 3.])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
        b = np.array(1.)
        self._test_equal(a, b)
        self._test_equal(b, a)

    def test_subclass_that_cannot_be_bool(self):
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return super(MyArray, self).__eq__(other).view(np.ndarray)

            def __lt__(self, other):
                return super(MyArray, self).__lt__(other).view(np.ndarray)

            def all(self, *args, **kwargs):
                raise NotImplementedError

        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)
|
|
|
|
|
|
class TestAlmostEqual(_GenericTest):
    """Tests for ``assert_almost_equal`` on scalars and arrays."""

    def setup_method(self):
        # pytest >= 8 no longer runs nose-style ``setup``; delegate to it.
        self.setup()

    def setup(self):
        self._assert_func = assert_almost_equal

    def test_closeness(self):
        # Note that in the course of time we ended up with
        #     `abs(x - y) < 1.5 * 10**(-decimal)`
        # instead of the previously documented
        #     `abs(x - y) < 0.5 * 10**(-decimal)`
        # so this check serves to preserve the wrongness.

        # test scalars
        self._assert_func(1.499999, 0.0, decimal=0)
        assert_raises(AssertionError,
                      lambda: self._assert_func(1.5, 0.0, decimal=0))

        # test arrays
        self._assert_func([1.499999], [0.0], decimal=0)
        assert_raises(AssertionError,
                      lambda: self._assert_func([1.5], [0.0], decimal=0))

    def test_nan_item(self):
        self._assert_func(np.nan, np.nan)
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.nan, 1))
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.nan, np.inf))
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.inf, np.nan))

    def test_inf_item(self):
        self._assert_func(np.inf, np.inf)
        self._assert_func(-np.inf, -np.inf)
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.inf, 1))
        assert_raises(AssertionError,
                      lambda: self._assert_func(-np.inf, np.inf))

    def test_simple_item(self):
        self._test_not_equal(1, 2)

    def test_complex_item(self):
        self._assert_func(complex(1, 2), complex(1, 2))
        self._assert_func(complex(1, np.nan), complex(1, np.nan))
        self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
        self._test_not_equal(complex(1, np.nan), complex(1, 2))
        self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
        self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))

    def test_complex(self):
        x = np.array([complex(1, 2), complex(1, np.nan)])
        z = np.array([complex(1, 2), complex(np.nan, 1)])
        y = np.array([complex(1, 2), complex(1, 2)])
        self._assert_func(x, x)
        self._test_not_equal(x, y)
        self._test_not_equal(x, z)

    def test_error_message(self):
        """Check the message is formatted correctly for the decimal value.
        Also check the message when input includes inf or nan (gh-12200)"""
        x = np.array([1.00000000001, 2.00000000002, 3.00003])
        y = np.array([1.00000000002, 2.00000000003, 3.00004])

        # Test with a different amount of decimal digits
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(x, y, decimal=12)
        msgs = str(exc_info.value).split('\n')
        assert_equal(msgs[3], 'Mismatch: 100%')
        assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
        assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
        # numpy's array repr pads shorter entries with trailing spaces so
        # every element is as wide as the longest one.
        assert_equal(
            msgs[6],
            ' x: array([1.00000000001, 2.00000000002, 3.00003      ])')
        assert_equal(
            msgs[7],
            ' y: array([1.00000000002, 2.00000000003, 3.00004      ])')

        # With the default value of decimal digits, only the 3rd element
        # differs. Note that we only check for the formatting of the arrays
        # themselves.
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(x, y)
        msgs = str(exc_info.value).split('\n')
        assert_equal(msgs[3], 'Mismatch: 33.3%')
        assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
        assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
        assert_equal(msgs[6], ' x: array([1.     , 2.     , 3.00003])')
        assert_equal(msgs[7], ' y: array([1.     , 2.     , 3.00004])')

        # Check the error message when input includes inf
        x = np.array([np.inf, 0])
        y = np.array([np.inf, 1])
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(x, y)
        msgs = str(exc_info.value).split('\n')
        assert_equal(msgs[3], 'Mismatch: 50%')
        assert_equal(msgs[4], 'Max absolute difference: 1.')
        assert_equal(msgs[5], 'Max relative difference: 1.')
        assert_equal(msgs[6], ' x: array([inf, 0.])')
        assert_equal(msgs[7], ' y: array([inf, 1.])')

        # Check the error message when dividing by zero
        x = np.array([1, 2])
        y = np.array([0, 0])
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(x, y)
        msgs = str(exc_info.value).split('\n')
        assert_equal(msgs[3], 'Mismatch: 100%')
        assert_equal(msgs[4], 'Max absolute difference: 2')
        assert_equal(msgs[5], 'Max relative difference: inf')

    def test_subclass_that_cannot_be_bool(self):
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return super(MyArray, self).__eq__(other).view(np.ndarray)

            def __lt__(self, other):
                return super(MyArray, self).__lt__(other).view(np.ndarray)

            def all(self, *args, **kwargs):
                raise NotImplementedError

        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)
|
|
|
|
|
|
class TestApproxEqual(object):
    """Tests for ``assert_approx_equal`` (significant-digit comparison)."""

    def setup_method(self):
        # pytest >= 8 no longer runs nose-style ``setup``; delegate to it.
        self.setup()

    def setup(self):
        self._assert_func = assert_approx_equal

    def test_simple_arrays(self):
        x = np.array([1234.22])
        y = np.array([1234.23])

        self._assert_func(x, y, significant=5)
        self._assert_func(x, y, significant=6)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, significant=7))

    def test_simple_items(self):
        x = 1234.22
        y = 1234.23

        self._assert_func(x, y, significant=4)
        self._assert_func(x, y, significant=5)
        self._assert_func(x, y, significant=6)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, significant=7))

    def test_nan_array(self):
        anan = np.array(np.nan)
        aone = np.array(1)
        ainf = np.array(np.inf)
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))

    def test_nan_items(self):
        # Scalar counterpart of test_nan_array: plain Python numbers
        # instead of 0-d arrays.
        anan = np.nan
        aone = 1
        ainf = np.inf
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
|
|
|
|
|
|
class TestArrayAssertLess(object):
    """Tests for ``assert_array_less`` (strict element-wise ``<``)."""

    def setup_method(self):
        # pytest >= 8 no longer runs nose-style ``setup``; delegate to it.
        self.setup()

    def setup(self):
        self._assert_func = assert_array_less

    def test_simple_arrays(self):
        x = np.array([1.1, 2.2])
        y = np.array([1.2, 2.3])

        self._assert_func(x, y)
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

        y = np.array([1.0, 2.3])

        assert_raises(AssertionError, lambda: self._assert_func(x, y))
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

    def test_rank2(self):
        x = np.array([[1.1, 2.2], [3.3, 4.4]])
        y = np.array([[1.2, 2.3], [3.4, 4.5]])

        self._assert_func(x, y)
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

        y = np.array([[1.0, 2.3], [3.4, 4.5]])

        assert_raises(AssertionError, lambda: self._assert_func(x, y))
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

    def test_rank3(self):
        x = np.ones(shape=(2, 2, 2))
        y = np.ones(shape=(2, 2, 2))+1

        self._assert_func(x, y)
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

        y[0, 0, 0] = 0

        assert_raises(AssertionError, lambda: self._assert_func(x, y))
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

    def test_simple_items(self):
        x = 1.1
        y = 2.2

        self._assert_func(x, y)
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

        y = np.array([2.2, 3.3])

        self._assert_func(x, y)
        assert_raises(AssertionError, lambda: self._assert_func(y, x))

        y = np.array([1.0, 3.3])

        assert_raises(AssertionError, lambda: self._assert_func(x, y))

    def test_nan_noncompare(self):
        anan = np.array(np.nan)
        aone = np.array(1)
        ainf = np.array(np.inf)
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(aone, anan))
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))

    def test_nan_noncompare_array(self):
        x = np.array([1.1, 2.2, 3.3])
        anan = np.array(np.nan)

        assert_raises(AssertionError, lambda: self._assert_func(x, anan))
        assert_raises(AssertionError, lambda: self._assert_func(anan, x))

        x = np.array([1.1, 2.2, np.nan])

        assert_raises(AssertionError, lambda: self._assert_func(x, anan))
        assert_raises(AssertionError, lambda: self._assert_func(anan, x))

        y = np.array([1.0, 2.0, np.nan])

        self._assert_func(y, x)
        assert_raises(AssertionError, lambda: self._assert_func(x, y))

    def test_inf_compare(self):
        aone = np.array(1)
        ainf = np.array(np.inf)

        self._assert_func(aone, ainf)
        self._assert_func(-ainf, aone)
        self._assert_func(-ainf, ainf)
        assert_raises(AssertionError, lambda: self._assert_func(ainf, aone))
        assert_raises(AssertionError, lambda: self._assert_func(aone, -ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, -ainf))
        assert_raises(AssertionError, lambda: self._assert_func(-ainf, -ainf))

    def test_inf_compare_array(self):
        x = np.array([1.1, 2.2, np.inf])
        ainf = np.array(np.inf)

        assert_raises(AssertionError, lambda: self._assert_func(x, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, x))
        assert_raises(AssertionError, lambda: self._assert_func(x, -ainf))
        assert_raises(AssertionError, lambda: self._assert_func(-x, -ainf))
        assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
        self._assert_func(-ainf, x)
|
|
|
|
|
|
@pytest.mark.skip(reason="The raises decorator depends on Nose")
class TestRaises(object):
    """Checks for the nose-backed ``raises`` decorator.

    The whole class is skipped under pytest by the marker on it.
    """

    def setup(self):
        # A fresh exception type per test, defined locally so that nothing
        # outside this class raises it.
        class MyException(Exception):
            pass

        self.e = MyException

    def raises_exception(self, e):
        """Raise ``e``."""
        raise e

    def does_not_raise_exception(self):
        """Return without raising anything."""
        return None

    def test_correct_catch(self):
        # ``self.e`` is the expected type, so the wrapper must absorb it.
        guarded = raises(self.e)(self.raises_exception)
        guarded(self.e)

    def test_wrong_exception(self):
        # An unexpected exception type must propagate out of the wrapper.
        propagated = False
        try:
            raises(self.e)(self.raises_exception)(RuntimeError)
        except RuntimeError:
            propagated = True
        if not propagated:
            raise AssertionError("should have caught RuntimeError")

    def test_catch_no_raise(self):
        # The wrapper must fail when the wrapped call raises nothing.
        failed_as_expected = False
        try:
            raises(self.e)(self.does_not_raise_exception)()
        except AssertionError:
            failed_as_expected = True
        if not failed_as_expected:
            raise AssertionError("should have raised an AssertionError")
|
|
|
|
|
|
class TestWarns(object):
    """Tests for ``assert_warns`` and ``assert_no_warnings``."""

    def test_warn(self):
        def f():
            warnings.warn("yo")
            return 3

        before_filters = warnings.filters[:]
        assert_equal(assert_warns(UserWarning, f), 3)
        # Copy rather than alias the live list, so this snapshot reflects
        # the state immediately after assert_warns.
        after_filters = warnings.filters[:]

        assert_raises(AssertionError, assert_no_warnings, f)
        assert_equal(assert_no_warnings(lambda x: x, 1), 1)

        # Check that the warnings state is unchanged
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")
        assert_equal(before_filters, warnings.filters[:],
                     "assert_no_warnings does not preserve warnings state")

    def test_context_manager(self):
        before_filters = warnings.filters[:]
        with assert_warns(UserWarning):
            warnings.warn("yo")
        after_filters = warnings.filters[:]

        def no_warnings():
            with assert_no_warnings():
                warnings.warn("yo")

        assert_raises(AssertionError, no_warnings)
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_warn_wrong_warning(self):
        def f():
            warnings.warn("yo", DeprecationWarning)

        with warnings.catch_warnings():
            # Escalate DeprecationWarning to an exception: assert_warns must
            # let it propagate instead of swallowing it while it waits for
            # a UserWarning.
            warnings.simplefilter("error", DeprecationWarning)
            assert_raises(DeprecationWarning, assert_warns, UserWarning, f)
|
|
|
|
|
|
class TestAssertAllclose(object):
|
|
|
|
def test_simple(self):
|
|
x = 1e-3
|
|
y = 1e-9
|
|
|
|
assert_allclose(x, y, atol=1)
|
|
assert_raises(AssertionError, assert_allclose, x, y)
|
|
|
|
a = np.array([x, y, x, y])
|
|
b = np.array([x, y, x, x])
|
|
|
|
assert_allclose(a, b, atol=1)
|
|
assert_raises(AssertionError, assert_allclose, a, b)
|
|
|
|
b[-1] = y * (1 + 1e-8)
|
|
assert_allclose(a, b)
|
|
assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)
|
|
|
|
assert_allclose(6, 10, rtol=0.5)
|
|
assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
|
|
|
|
def test_min_int(self):
|
|
a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
|
|
# Should not raise:
|
|
assert_allclose(a, a)
|
|
|
|
def test_report_fail_percentage(self):
|
|
a = np.array([1, 1, 1, 1])
|
|
b = np.array([1, 1, 1, 2])
|
|
|
|
with pytest.raises(AssertionError) as exc_info:
|
|
assert_allclose(a, b)
|
|
msg = str(exc_info.value)
|
|
assert_('Mismatch: 25%\nMax absolute difference: 1\n'
|
|
'Max relative difference: 0.5' in msg)
|
|
|
|
def test_equal_nan(self):
|
|
a = np.array([np.nan])
|
|
b = np.array([np.nan])
|
|
# Should not raise:
|
|
assert_allclose(a, b, equal_nan=True)
|
|
|
|
def test_not_equal_nan(self):
|
|
a = np.array([np.nan])
|
|
b = np.array([np.nan])
|
|
assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
|
|
|
|
def test_equal_nan_default(self):
|
|
# Make sure equal_nan default behavior remains unchanged. (All
|
|
# of these functions use assert_array_compare under the hood.)
|
|
# None of these should raise.
|
|
a = np.array([np.nan])
|
|
b = np.array([np.nan])
|
|
assert_array_equal(a, b)
|
|
assert_array_almost_equal(a, b)
|
|
assert_array_less(a, b)
|
|
assert_allclose(a, b)
|
|
|
|
def test_report_max_relative_error(self):
|
|
a = np.array([0, 1])
|
|
b = np.array([0, 2])
|
|
|
|
with pytest.raises(AssertionError) as exc_info:
|
|
assert_allclose(a, b)
|
|
msg = str(exc_info.value)
|
|
assert_('Max relative difference: 0.5' in msg)
|
|
|
|
|
|
class TestArrayAlmostEqualNulp(object):
|
|
|
|
def test_float64_pass(self):
|
|
# The number of units of least precision
|
|
# In this case, use a few places above the lowest level (ie nulp=1)
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float64)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
|
|
# Addition
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp/2.
|
|
assert_array_almost_equal_nulp(x, y, nulp)
|
|
|
|
# Subtraction
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp/2.
|
|
assert_array_almost_equal_nulp(x, y, nulp)
|
|
|
|
def test_float64_fail(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float64)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
x, y, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
x, y, nulp)
|
|
|
|
def test_float32_pass(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float32)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp/2.
|
|
assert_array_almost_equal_nulp(x, y, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp/2.
|
|
assert_array_almost_equal_nulp(x, y, nulp)
|
|
|
|
def test_float32_fail(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float32)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
x, y, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
x, y, nulp)
|
|
|
|
def test_float16_pass(self):
|
|
nulp = 5
|
|
x = np.linspace(-4, 4, 10, dtype=np.float16)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp/2.
|
|
assert_array_almost_equal_nulp(x, y, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp/2.
|
|
assert_array_almost_equal_nulp(x, y, nulp)
|
|
|
|
def test_float16_fail(self):
|
|
nulp = 5
|
|
x = np.linspace(-4, 4, 10, dtype=np.float16)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
x, y, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
x, y, nulp)
|
|
|
|
def test_complex128_pass(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float64)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
xi = x + x*1j
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp/2.
|
|
assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
|
|
assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
|
|
# The test condition needs to be at least a factor of sqrt(2) smaller
|
|
# because the real and imaginary parts both change
|
|
y = x + x*eps*nulp/4.
|
|
assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp/2.
|
|
assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
|
|
assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
|
|
y = x - x*epsneg*nulp/4.
|
|
assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
|
|
|
|
def test_complex128_fail(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float64)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
xi = x + x*1j
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, x + y*1j, nulp)
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + x*1j, nulp)
|
|
# The test condition needs to be at least a factor of sqrt(2) smaller
|
|
# because the real and imaginary parts both change
|
|
y = x + x*eps*nulp
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + y*1j, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, x + y*1j, nulp)
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + x*1j, nulp)
|
|
y = x - x*epsneg*nulp
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + y*1j, nulp)
|
|
|
|
def test_complex64_pass(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float32)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
xi = x + x*1j
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp/2.
|
|
assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
|
|
assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
|
|
y = x + x*eps*nulp/4.
|
|
assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp/2.
|
|
assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
|
|
assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
|
|
y = x - x*epsneg*nulp/4.
|
|
assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
|
|
|
|
def test_complex64_fail(self):
|
|
nulp = 5
|
|
x = np.linspace(-20, 20, 50, dtype=np.float32)
|
|
x = 10**x
|
|
x = np.r_[-x, x]
|
|
xi = x + x*1j
|
|
|
|
eps = np.finfo(x.dtype).eps
|
|
y = x + x*eps*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, x + y*1j, nulp)
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + x*1j, nulp)
|
|
y = x + x*eps*nulp
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + y*1j, nulp)
|
|
|
|
epsneg = np.finfo(x.dtype).epsneg
|
|
y = x - x*epsneg*nulp*2.
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, x + y*1j, nulp)
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + x*1j, nulp)
|
|
y = x - x*epsneg*nulp
|
|
assert_raises(AssertionError, assert_array_almost_equal_nulp,
|
|
xi, y + y*1j, nulp)
|
|
|
|
|
|
class TestULP(object):
    """Tests for ``assert_array_max_ulp``."""

    def test_equal(self):
        # Any array is zero ULPs away from itself.
        x = np.random.randn(10)
        assert_array_max_ulp(x, x, maxulp=0)

    def test_single(self):
        # Generate 1 + small deviation, check that adding eps gives a few ULP
        x = np.ones(10).astype(np.float32)
        x += 0.01 * np.random.randn(10).astype(np.float32)
        eps = np.finfo(np.float32).eps
        assert_array_max_ulp(x, x+eps, maxulp=20)

    def test_double(self):
        # Generate 1 + small deviation, check that adding eps gives a few ULP
        x = np.ones(10).astype(np.float64)
        x += 0.01 * np.random.randn(10).astype(np.float64)
        eps = np.finfo(np.float64).eps
        assert_array_max_ulp(x, x+eps, maxulp=200)

    def test_inf(self):
        # The largest finite value must count as close to infinity.
        for dt in [np.float32, np.float64]:
            inf = np.array([np.inf]).astype(dt)
            big = np.array([np.finfo(dt).max])
            assert_array_max_ulp(inf, big, maxulp=200)

    def test_nan(self):
        # Test that nan is 'far' from small, tiny, inf, max and min
        for dt in [np.float32, np.float64]:
            maxulp = 1e6 if dt == np.float32 else 1e12
            nan = np.array([np.nan]).astype(dt)
            # ``np.PZERO`` / ``np.NZERO`` were removed in NumPy 2.0; the
            # literals 0.0 and -0.0 produce the same signed zeros on every
            # NumPy version.
            others = [
                np.array([np.inf]).astype(dt),   # inf
                np.array([np.finfo(dt).max]),    # big
                np.array([np.finfo(dt).tiny]),   # tiny
                np.array([0.0]).astype(dt),      # zero
                np.array([-0.0]).astype(dt),     # negative zero
            ]
            for other in others:
                with assert_raises(AssertionError):
                    assert_array_max_ulp(nan, other, maxulp=maxulp)
|
|
|
|
|
|
class TestStringEqual(object):
    """Tests for ``assert_string_equal``."""

    def test_simple(self):
        # Identical single-line and multi-line strings compare equal.
        for text in ("hello", "hello\nmultiline"):
            assert_string_equal(text, text)

        # A mismatching line is reported as a "-"/"+" pair; the line that
        # agrees ("bar") does not appear in the message.
        with assert_raises(AssertionError) as ctx:
            assert_string_equal("foo\nbar", "hello\nbar")
        assert_equal(str(ctx.exception),
                     "Differences in strings:\n- foo\n+ hello")

        with assert_raises(AssertionError):
            assert_string_equal("foo", "hello")

    def test_regex(self):
        # Comparison is literal: regex metacharacters carry no meaning.
        assert_string_equal("a+*b", "a+*b")

        with assert_raises(AssertionError):
            assert_string_equal("aaa", "a+b")
|
|
|
|
|
|
def assert_warn_len_equal(mod, n_in_context, py34=None, py37=None):
    """Assert that ``mod``'s warning registry holds ``n_in_context`` entries.

    Parameters
    ----------
    mod : module-like object
        Object whose ``__warningregistry__`` attribute is inspected.  A
        missing registry counts as empty.
    n_in_context : int
        Expected number of registry entries.  A ``'version'`` key, if
        present, is not counted.
    py34, py37 : int, optional
        Replacement expectation used on Python 3.4-3.6 and on Python >= 3.7
        respectively, since registry behaviour differs between versions.
    """
    # No ``__warningregistry__`` means no warning has occurred in ``mod``.
    # That can happen in parallel test runs; in serial runs an earlier
    # warning has always created the attribute already.
    registry = getattr(mod, '__warningregistry__', {})

    # Python 3 adds a 'version' entry to the registry; it is not a warning.
    count = len(registry) - (1 if 'version' in registry else 0)

    # Python 3.4 appears to clear pre-existing warnings of the same type when
    # warning inside a catch_warnings block, and Python 3.7 makes no entry
    # for ignored warnings, so callers may pass per-version expectations.
    version = sys.version_info[:2]
    if version >= (3, 7):
        override = py37
    elif version >= (3, 4):
        override = py34
    else:
        override = None
    if override is not None:
        n_in_context = override

    assert_equal(count, n_in_context)
|
|
|
|
def test_warn_len_equal_call_scenarios():
    # ``assert_warn_len_equal`` is used both in parallel runs (the module may
    # not have warned yet) and in serial runs (registry already present).
    # Probe both code paths and make sure neither raises unexpectedly.

    class NoRegistry(object):
        """Stand-in for a module that has not emitted any warning yet."""

    assert_warn_len_equal(mod=NoRegistry(), n_in_context=0)

    class WithRegistry(object):
        """Stand-in for a module whose registry holds two entries."""

        def __init__(self):
            self.__warningregistry__ = {'warning1': 1, 'warning2': 2}

    assert_warn_len_equal(mod=WithRegistry(), n_in_context=2)
|
|
|
|
|
|
def _get_fresh_mod():
|
|
# Get this module, with warning registry empty
|
|
my_mod = sys.modules[__name__]
|
|
try:
|
|
my_mod.__warningregistry__.clear()
|
|
except AttributeError:
|
|
# will not have a __warningregistry__ unless warning has been
|
|
# raised in the module at some point
|
|
pass
|
|
return my_mod
|
|
|
|
|
|
def test_clear_and_catch_warnings():
    """Check how clear_and_catch_warnings treats a module's warning registry.

    When the module is listed in ``modules``, warnings raised inside the
    context do not stick in its registry afterwards.  Without ``modules``
    the registry is left alone, so entries may be added (how many depends on
    the Python version, see ``assert_warn_len_equal``).
    """
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
    # Module listed: the warning raised inside must not remain registered.
    with clear_and_catch_warnings(modules=[my_mod]):
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(my_mod.__warningregistry__, {})
    # Without specified modules, don't clear warnings during context
    # Python 3.7 catch_warnings doesn't make an entry for 'ignore'.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
    # Confirm that specifying module keeps old warning, does not add new
    with clear_and_catch_warnings(modules=[my_mod]):
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
    # Another warning, no module spec does add to warnings dict, except on
    # Python 3.4 (see comments in `assert_warn_len_equal`)
    # Python 3.7 catch_warnings doesn't make an entry for 'ignore'.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(my_mod, 2, py34=1, py37=0)
|
|
|
|
|
|
def test_suppress_warnings_module():
    """Check module-based filtering with ``suppress_warnings.filter``.

    Filters by another module (the one implementing ``np.apply_along_axis``)
    and by this test module, checking the recorded log and this module's
    ``__warningregistry__`` after each context.
    """
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})

    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning 2", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    # Test module based warning suppression:
    assert_warn_len_equal(my_mod, 0)
    with suppress_warnings() as sup:
        # Record every UserWarning that no later filter handles.
        sup.record(UserWarning)
        # suppress warning from other module (may have .pyc ending),
        # if apply_along_axis is moved, had to be changed.
        # NOTE(review): newer NumPy releases made this module private
        # (``np.lib._shape_base_impl``); confirm against the NumPy in use.
        sup.filter(module=np.lib.shape_base)
        warnings.warn("Some warning")
        warn_other_module()
    # Check that the suppression matched the file correctly: only the warning
    # raised from this module was recorded; the one attributed to
    # apply_along_axis's module was filtered out.
    assert_equal(len(sup.log), 1)
    assert_equal(sup.log[0].message.args[0], "Some warning")
    assert_warn_len_equal(my_mod, 0, py37=0)
    sup = suppress_warnings()
    # Filter every warning attributed to this test module.
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # And test repeat works:
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)

    # Without specified modules, don't clear warnings during context
    # Python 3.7 does not add ignored warnings.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
|
|
|
|
def test_suppress_warnings_type():
    """Category-based filtering with ``suppress_warnings.filter``."""
    module = _get_fresh_mod()
    # Start from a module that has not registered any warning.
    assert_equal(getattr(module, '__warningregistry__', {}), {})

    # One-shot context manager filtering by category.
    with suppress_warnings() as suppressor:
        suppressor.filter(UserWarning)
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)

    # Reusable instance: a permanent category filter ...
    suppressor = suppress_warnings()
    suppressor.filter(UserWarning)
    with suppressor:
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)
    # ... then add a module filter and enter the same instance again.
    suppressor.filter(module=module)
    with suppressor:
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)

    # A context without filters does not clear the registry.  Python 3.7
    # makes no entry for ignored warnings, hence the py37 override.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 1, py37=0)
|
|
|
|
|
|
def test_suppress_warnings_decorate_no_record():
    """A suppress_warnings instance can decorate a function (no recording)."""
    suppressor = suppress_warnings()
    suppressor.filter(UserWarning)

    def warn(category):
        warnings.warn('Some warning', category)

    # Same as decorating ``warn`` with ``@suppressor``.
    warn = suppressor(warn)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        for category in (UserWarning, RuntimeWarning):
            # UserWarning is suppressed by the filter; RuntimeWarning is not.
            warn(category)
        assert_equal(len(caught), 1)
|
|
|
|
|
|
def test_suppress_warnings_record():
    """Recording with suppress_warnings: reuse of an instance and nesting."""
    sup = suppress_warnings()
    # Permanent catch-all recorder.
    log1 = sup.record()

    def run_and_check():
        with sup:
            # Temporary rules for this ``with`` only.  Per the assertions
            # below, 'Some warning' is filtered, 'Some other warning 2' goes
            # to log2 only, and 'Some other warning' reaches log1.
            log2 = sup.record(message='Some other warning 2')
            sup.filter(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            warnings.warn('Some other warning 2')

        assert_equal(len(sup.log), 2)
        assert_equal(len(log1), 1)
        assert_equal(len(log2), 1)
        assert_equal(log2[0].message.args[0], 'Some other warning 2')

    # Entering the same instance a second time must give identical counts,
    # i.e. nothing recorded in the first run survives into the second.
    run_and_check()
    run_and_check()

    # Nested contexts: a warning recorded by the inner instance is not seen
    # by the outer one.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings() as inner:
            inner.record(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            assert_equal(len(inner.log), 1)
        assert_equal(len(outer.log), 1)
|
|
|
|
|
|
def test_suppress_warnings_forwarding():
    """Check how each forwarding rule of suppress_warnings de-duplicates.

    In every scenario an outer ``suppress_warnings`` records whatever the
    inner one, configured with the rule under test, forwards outwards.
    """
    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    def forwarded_count(rule, emit):
        """Call ``emit`` twice under ``rule``; return how many got through."""
        with suppress_warnings() as outer:
            outer.record()
            with suppress_warnings(rule):
                for _ in range(2):
                    emit()
        return len(outer.log)

    # Each warning below sits on its own line so that location-based
    # de-duplication sees distinct source lines.
    def one_line():
        warnings.warn("Some warning")

    def two_lines():
        warnings.warn("Some warning")
        warnings.warn("Some warning")

    def two_lines_plus_other_module():
        warnings.warn("Some warning")
        warnings.warn("Some warning")
        warn_other_module()

    def two_messages_plus_other_module():
        warnings.warn("Some warning")
        warnings.warn("Some other warning")
        warn_other_module()

    # "always": every occurrence is forwarded.
    assert_equal(forwarded_count("always", one_line), 2)
    # "location": once per distinct source line.
    assert_equal(forwarded_count("location", two_lines), 2)
    # "module": once per module (this file and apply_along_axis's).
    assert_equal(forwarded_count("module", two_lines_plus_other_module), 2)
    # "once": once per distinct message.
    assert_equal(forwarded_count("once", two_messages_plus_other_module), 2)
|
|
|
|
|
|
def test_tempdir():
    """``tempdir`` removes its directory on normal exit and on error."""
    with tempdir() as tdir:
        # Leave a file behind so the cleanup must cope with a non-empty dir.
        with open(os.path.join(tdir, 'tmp'), 'w'):
            pass
    assert_(not os.path.isdir(tdir))

    # The exception must propagate out, and the directory must still go.
    with assert_raises(ValueError):
        with tempdir() as tdir:
            raise ValueError()
    assert_(not os.path.isdir(tdir))
|
|
|
|
|
|
def test_temppath():
    """``temppath`` deletes its file on normal exit and on error."""
    with temppath() as fpath:
        # Actually create the file so there is something to delete.
        with open(fpath, 'w'):
            pass
    assert_(not os.path.isfile(fpath))

    # The exception must propagate out, and the path must not be left over.
    with assert_raises(ValueError):
        with temppath() as fpath:
            raise ValueError()
    assert_(not os.path.isfile(fpath))
|
|
|
|
|
|
class my_cacw(clear_and_catch_warnings):
    """``clear_and_catch_warnings`` with this test module as a default.

    Used by ``test_clear_and_catch_warnings_inherit`` to check that a
    subclass can supply default modules through ``class_modules``.
    """

    class_modules = tuple([sys.modules[__name__]])
|
|
|
|
|
|
def test_clear_and_catch_warnings_inherit():
    # A subclass can name default modules via ``class_modules``, so this
    # module is handled without passing ``modules=`` explicitly.
    module = _get_fresh_mod()
    with my_cacw():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
        # The registry is checked while still inside the context.
        assert_equal(module.__warningregistry__, {})
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
|
|
class TestAssertNoGcCycles(object):
|
|
""" Test assert_no_gc_cycles """
|
|
def test_passes(self):
|
|
def no_cycle():
|
|
b = []
|
|
b.append([])
|
|
return b
|
|
|
|
with assert_no_gc_cycles():
|
|
no_cycle()
|
|
|
|
assert_no_gc_cycles(no_cycle)
|
|
|
|
def test_asserts(self):
|
|
def make_cycle():
|
|
a = []
|
|
a.append(a)
|
|
a.append(a)
|
|
return a
|
|
|
|
with assert_raises(AssertionError):
|
|
with assert_no_gc_cycles():
|
|
make_cycle()
|
|
|
|
with assert_raises(AssertionError):
|
|
assert_no_gc_cycles(make_cycle)
|
|
|
|
@pytest.mark.slow
|
|
def test_fails(self):
|
|
"""
|
|
Test that in cases where the garbage cannot be collected, we raise an
|
|
error, instead of hanging forever trying to clear it.
|
|
"""
|
|
|
|
class ReferenceCycleInDel(object):
|
|
"""
|
|
An object that not only contains a reference cycle, but creates new
|
|
cycles whenever it's garbage-collected and its __del__ runs
|
|
"""
|
|
make_cycle = True
|
|
|
|
def __init__(self):
|
|
self.cycle = self
|
|
|
|
def __del__(self):
|
|
# break the current cycle so that `self` can be freed
|
|
self.cycle = None
|
|
|
|
if ReferenceCycleInDel.make_cycle:
|
|
# but create a new one so that the garbage collector has more
|
|
# work to do.
|
|
ReferenceCycleInDel()
|
|
|
|
try:
|
|
w = weakref.ref(ReferenceCycleInDel())
|
|
try:
|
|
with assert_raises(RuntimeError):
|
|
# this will be unable to get a baseline empty garbage
|
|
assert_no_gc_cycles(lambda: None)
|
|
except AssertionError:
|
|
# the above test is only necessary if the GC actually tried to free
|
|
# our object anyway, which python 2.7 does not.
|
|
if w() is not None:
|
|
pytest.skip("GC does not call __del__ on cyclic objects")
|
|
raise
|
|
|
|
finally:
|
|
# make sure that we stop creating reference cycles
|
|
ReferenceCycleInDel.make_cycle = False
|