78 lines
2.4 KiB
Python
78 lines
2.4 KiB
Python
import pytest
|
|
|
|
import numpy as np
|
|
from .._ni_support import _get_output
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
'dtype',
|
|
[
|
|
# String specifiers
|
|
'f4', 'float32', 'complex64', 'complex128',
|
|
# Type and dtype specifiers
|
|
np.float32, float, np.dtype('f4'),
|
|
# Derive from input
|
|
None,
|
|
],
|
|
)
|
|
def test_get_output_basic(dtype):
|
|
shape = (2, 3)
|
|
|
|
input_ = np.zeros(shape, 'float32')
|
|
|
|
# For None, derive dtype from input
|
|
expected_dtype = 'float32' if dtype is None else dtype
|
|
|
|
# Output is dtype-specifier, retrieve shape from input
|
|
result = _get_output(dtype, input_)
|
|
assert result.shape == shape
|
|
assert result.dtype == np.dtype(expected_dtype)
|
|
|
|
# Output is dtype specifier, with explicit shape, overriding input
|
|
result = _get_output(dtype, input_, shape=(3, 2))
|
|
assert result.shape == (3, 2)
|
|
assert result.dtype == np.dtype(expected_dtype)
|
|
|
|
# Output is pre-allocated array, return directly
|
|
output = np.zeros(shape, dtype)
|
|
result = _get_output(output, input_)
|
|
assert result is output
|
|
|
|
|
|
def test_get_output_complex():
|
|
shape = (2, 3)
|
|
|
|
input_ = np.zeros(shape)
|
|
|
|
# None, promote input type to complex
|
|
result = _get_output(None, input_, complex_output=True)
|
|
assert result.shape == shape
|
|
assert result.dtype == np.dtype('complex128')
|
|
|
|
# Explicit type, promote type to complex
|
|
with pytest.warns(UserWarning, match='promoting specified output dtype to complex'):
|
|
result = _get_output(float, input_, complex_output=True)
|
|
assert result.shape == shape
|
|
assert result.dtype == np.dtype('complex128')
|
|
|
|
# String specifier, simply verify complex output
|
|
result = _get_output('complex64', input_, complex_output=True)
|
|
assert result.shape == shape
|
|
assert result.dtype == np.dtype('complex64')
|
|
|
|
|
|
def test_get_output_error_cases():
|
|
input_ = np.zeros((2, 3), 'float32')
|
|
|
|
# Two separate paths can raise the same error
|
|
with pytest.raises(RuntimeError, match='output must have complex dtype'):
|
|
_get_output('float32', input_, complex_output=True)
|
|
with pytest.raises(RuntimeError, match='output must have complex dtype'):
|
|
_get_output(np.zeros((2, 3)), input_, complex_output=True)
|
|
|
|
with pytest.raises(RuntimeError, match='output must have numeric dtype'):
|
|
_get_output('void', input_)
|
|
|
|
with pytest.raises(RuntimeError, match='shape not correct'):
|
|
_get_output(np.zeros((3, 2)), input_)
|