"""
Functions for changing global ufunc configuration

This provides helpers which wrap `umath.geterrobj` and `umath.seterrobj`
"""
|
||
|
from __future__ import division, absolute_import, print_function
|
||
|
|
||
|
try:
|
||
|
# Accessing collections abstract classes from collections
|
||
|
# has been deprecated since Python 3.3
|
||
|
import collections.abc as collections_abc
|
||
|
except ImportError:
|
||
|
import collections as collections_abc
|
||
|
import contextlib
|
||
|
|
||
|
from .overrides import set_module
|
||
|
from .umath import (
|
||
|
UFUNC_BUFSIZE_DEFAULT,
|
||
|
ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, ERR_DEFAULT,
|
||
|
SHIFT_DIVIDEBYZERO, SHIFT_OVERFLOW, SHIFT_UNDERFLOW, SHIFT_INVALID,
|
||
|
)
|
||
|
from . import umath
|
||
|
|
||
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "seterr",
    "geterr",
    "setbufsize",
    "getbufsize",
    "seterrcall",
    "geterrcall",
    "errstate",
]

# Error-handling mode name -> numeric mode code imported from `umath`.
_errdict = {
    "ignore": ERR_IGNORE,
    "warn": ERR_WARN,
    "raise": ERR_RAISE,
    "call": ERR_CALL,
    "print": ERR_PRINT,
    "log": ERR_LOG,
}

# Numeric mode code -> mode name (the inverse of `_errdict`), used by
# `geterr` to decode the packed error mask.
_errdict_rev = {code: mode for mode, code in _errdict.items()}
|
||
|
|
||
|
|
||
|
@set_module('numpy')
def seterr(all=None, divide=None, over=None, under=None, invalid=None):
    """
    Set how floating-point errors are handled.

    Note that operations on integer scalar types (such as `int16`) are
    handled like floating point, and are affected by these settings.

    Parameters
    ----------
    all : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Set treatment for all types of floating-point errors at once:

        - ignore: Take no action when the exception occurs.
        - warn: Print a `RuntimeWarning` (via the Python `warnings` module).
        - raise: Raise a `FloatingPointError`.
        - call: Call a function specified using the `seterrcall` function.
        - print: Print a warning directly to ``stdout``.
        - log: Record error in a Log object specified by `seterrcall`.

        The default is not to change the current behavior.
    divide : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for division by zero.
    over : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for floating-point overflow.
    under : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for floating-point underflow.
    invalid : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for invalid floating-point operation.

    Returns
    -------
    old_settings : dict
        Dictionary containing the old settings.

    See Also
    --------
    seterrcall : Set a callback function for the 'call' mode.
    geterr, geterrcall, errstate

    Notes
    -----
    A category given explicitly takes precedence over `all`; a category
    left as ``None`` falls back to `all`, and then to its current setting.

    The floating-point exceptions are defined in the IEEE 754 standard [1]_:

    - Division by zero: infinite result obtained from finite numbers.
    - Overflow: result too large to be expressed.
    - Underflow: result so close to zero that some precision
      was lost.
    - Invalid operation: result is not an expressible number, typically
      indicates that a NaN was produced.

    .. [1] https://en.wikipedia.org/wiki/IEEE_754

    Examples
    --------
    >>> old_settings = np.seterr(all='ignore')  #seterr to known value
    >>> np.seterr(over='raise')
    {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
    >>> np.seterr(**old_settings)  # reset to default
    {'divide': 'ignore', 'over': 'raise', 'under': 'ignore', 'invalid': 'ignore'}

    >>> np.int16(32000) * np.int16(3)
    30464
    >>> old_settings = np.seterr(all='warn', over='raise')
    >>> np.int16(32000) * np.int16(3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    FloatingPointError: overflow encountered in short_scalars

    >>> from collections import OrderedDict
    >>> old_settings = np.seterr(all='print')
    >>> OrderedDict(np.geterr())
    OrderedDict([('divide', 'print'), ('over', 'print'), ('under', 'print'), ('invalid', 'print')])
    >>> np.int16(32000) * np.int16(3)
    30464

    """
    errobj = umath.geterrobj()
    # Snapshot the current settings before slot 1 of the error object is
    # overwritten below; this snapshot is what the caller gets back.
    previous = geterr()

    # (category, explicit request, bit offset of its field in the mask),
    # processed in this fixed order.
    categories = (
        ('divide', divide, SHIFT_DIVIDEBYZERO),
        ('over', over, SHIFT_OVERFLOW),
        ('under', under, SHIFT_UNDERFLOW),
        ('invalid', invalid, SHIFT_INVALID),
    )

    packed = 0
    for category, mode, shift in categories:
        if mode is None:
            # Not given explicitly: use `all`, else keep the current mode.
            mode = all or previous[category]
        packed += _errdict[mode] << shift

    errobj[1] = packed
    umath.seterrobj(errobj)
    return previous
|
||
|
|
||
|
|
||
|
@set_module('numpy')
def geterr():
    """
    Get the current way of handling floating-point errors.

    Returns
    -------
    res : dict
        A dictionary with keys "divide", "over", "under", and "invalid",
        whose values are from the strings "ignore", "print", "log", "warn",
        "raise", and "call". The keys represent possible floating-point
        exceptions, and the values define how these exceptions are handled.

    See Also
    --------
    geterrcall, seterr, seterrcall

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> from collections import OrderedDict
    >>> sorted(np.geterr().items())
    [('divide', 'warn'), ('invalid', 'warn'), ('over', 'warn'), ('under', 'ignore')]
    >>> np.arange(3.) / np.arange(3.)
    array([nan, 1., 1.])

    >>> oldsettings = np.seterr(all='warn', over='raise')
    >>> OrderedDict(sorted(np.geterr().items()))
    OrderedDict([('divide', 'warn'), ('invalid', 'warn'), ('over', 'raise'), ('under', 'warn')])
    >>> np.arange(3.) / np.arange(3.)
    array([nan, 1., 1.])

    """
    packed = umath.geterrobj()[1]
    field = 7  # each category occupies a 3-bit field of the packed mask

    # (category, bit offset of its field), in the order the keys are inserted.
    layout = (
        ('divide', SHIFT_DIVIDEBYZERO),
        ('over', SHIFT_OVERFLOW),
        ('under', SHIFT_UNDERFLOW),
        ('invalid', SHIFT_INVALID),
    )
    return {
        category: _errdict_rev[(packed >> shift) & field]
        for category, shift in layout
    }
|
||
|
|
||
|
|
||
|
@set_module('numpy')
def setbufsize(size):
    """
    Set the size of the buffer used in ufuncs.

    Parameters
    ----------
    size : int
        Size of buffer. Must be at least 5, at most ``10e6``, and a
        multiple of 16.

    Returns
    -------
    old : int
        The buffer size that was in effect before this call.

    Raises
    ------
    ValueError
        If `size` is too big, too small, or not a multiple of 16.

    """
    # The checks run in this order and stop at the first one that fails.
    if size > 10e6:
        complaint = "Buffer size, %s, is too big."
    elif size < 5:
        complaint = "Buffer size, %s, is too small."
    elif size % 16 != 0:
        complaint = "Buffer size, %s, is not a multiple of 16."
    else:
        complaint = None

    if complaint is not None:
        raise ValueError(complaint % size)

    errobj = umath.geterrobj()
    # Read the previous size before slot 0 is overwritten.
    previous = getbufsize()
    errobj[0] = size
    umath.seterrobj(errobj)
    return previous
|
||
|
|
||
|
|
||
|
@set_module('numpy')
def getbufsize():
    """
    Return the size of the buffer used in ufuncs.

    Returns
    -------
    getbufsize : int
        Size of ufunc buffer in bytes.

    """
    # Slot 0 of the error object holds the buffer size (see `setbufsize`).
    errobj = umath.geterrobj()
    return errobj[0]
|
||
|
|
||
|
|
||
|
@set_module('numpy')
def seterrcall(func):
    """
    Set the floating-point error callback function or log object.

    There are two ways to capture floating-point error messages.  The first
    is to set the error-handler to 'call', using `seterr`.  Then, set
    the function to call using this function.

    The second is to set the error-handler to 'log', using `seterr`.
    Floating-point errors then trigger a call to the 'write' method of
    the provided object.

    Parameters
    ----------
    func : callable f(err, flag) or object with write method
        Function to call upon floating-point errors ('call'-mode) or
        object whose 'write' method is used to log such message ('log'-mode).

        The call function takes two arguments. The first is a string describing
        the type of error (such as "divide by zero", "overflow", "underflow",
        or "invalid value"), and the second is the status flag.  The flag is a
        byte, whose four least-significant bits indicate the type of error, one
        of "divide", "over", "under", "invalid"::

          [0 0 0 0 divide over under invalid]

        In other words, ``flags = divide + 2*over + 4*under + 8*invalid``.

        If an object is provided, its write method should take one argument,
        a string.

        ``None`` is also accepted.

    Returns
    -------
    h : callable, log instance or None
        The old error handler.

    Raises
    ------
    ValueError
        If `func` is not ``None``, is not callable, and has no callable
        ``write`` attribute.

    See Also
    --------
    seterr, geterr, geterrcall

    Examples
    --------
    Callback upon error:

    >>> def err_handler(type, flag):
    ...     print("Floating point error (%s), with flag %s" % (type, flag))
    ...

    >>> saved_handler = np.seterrcall(err_handler)
    >>> save_err = np.seterr(all='call')
    >>> from collections import OrderedDict

    >>> np.array([1, 2, 3]) / 0.0
    Floating point error (divide by zero), with flag 1
    array([inf, inf, inf])

    >>> np.seterrcall(saved_handler)
    <function err_handler at 0x...>
    >>> OrderedDict(sorted(np.seterr(**save_err).items()))
    OrderedDict([('divide', 'call'), ('invalid', 'call'), ('over', 'call'), ('under', 'call')])

    Log error message:

    >>> class Log(object):
    ...     def write(self, msg):
    ...         print("LOG: %s" % msg)
    ...

    >>> log = Log()
    >>> saved_handler = np.seterrcall(log)
    >>> save_err = np.seterr(all='log')

    >>> np.array([1, 2, 3]) / 0.0
    LOG: Warning: divide by zero encountered in true_divide
    array([inf, inf, inf])

    >>> np.seterrcall(saved_handler)
    <numpy.core.numeric.Log object at 0x...>
    >>> OrderedDict(sorted(np.seterr(**save_err).items()))
    OrderedDict([('divide', 'log'), ('invalid', 'log'), ('over', 'log'), ('under', 'log')])

    """
    callable_abc = collections_abc.Callable

    # Accept None, any callable, or a log-like object with a callable
    # ``write``; the clauses short-circuit left to right.
    acceptable = (
        func is None
        or isinstance(func, callable_abc)
        or (hasattr(func, 'write') and isinstance(func.write, callable_abc))
    )
    if not acceptable:
        raise ValueError("Only callable can be used as callback")

    errobj = umath.geterrobj()
    # Fetch the previous handler before slot 2 is overwritten.
    previous = geterrcall()
    errobj[2] = func
    umath.seterrobj(errobj)
    return previous
|
||
|
|
||
|
|
||
|
@set_module('numpy')
def geterrcall():
    """
    Return the current callback function used on floating-point errors.

    When the error handling for a floating-point error (one of "divide",
    "over", "under", or "invalid") is set to 'call' or 'log', the function
    that is called or the log instance that is written to is returned by
    `geterrcall`. This function or log instance has been set with
    `seterrcall`.

    Returns
    -------
    errobj : callable, log instance or None
        The current error handler. If no handler was set through `seterrcall`,
        ``None`` is returned.

    See Also
    --------
    seterrcall, seterr, geterr

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> np.geterrcall()  # we did not yet set a handler, returns None

    >>> oldsettings = np.seterr(all='call')
    >>> def err_handler(type, flag):
    ...     print("Floating point error (%s), with flag %s" % (type, flag))
    >>> oldhandler = np.seterrcall(err_handler)
    >>> np.array([1, 2, 3]) / 0.0
    Floating point error (divide by zero), with flag 1
    array([inf, inf, inf])

    >>> cur_handler = np.geterrcall()
    >>> cur_handler is err_handler
    True

    """
    # Slot 2 of the error object holds the handler (see `seterrcall`).
    errobj = umath.geterrobj()
    return errobj[2]
|
||
|
|
||
|
|
||
|
class _unspecified(object):
|
||
|
pass
|
||
|
|
||
|
|
||
|
_Unspecified = _unspecified()
|
||
|
|
||
|
|
||
|
@set_module('numpy')
class errstate(contextlib.ContextDecorator):
    """
    errstate(**kwargs)

    Context manager for floating-point error handling.

    Using an instance of `errstate` as a context manager allows statements in
    that context to execute with a known error handling behavior. Upon entering
    the context the error handling is set with `seterr` and `seterrcall`, and
    upon exiting it is reset to what it was before.

    ..  versionchanged:: 1.17.0
        `errstate` is also usable as a function decorator, saving
        a level of indentation if an entire function is wrapped.
        See :py:class:`contextlib.ContextDecorator` for more information.

    Parameters
    ----------
    kwargs : {divide, over, under, invalid}
        Keyword arguments. The valid keywords are the possible floating-point
        exceptions. Each keyword should have a string value that defines the
        treatment for the particular error. Possible values are
        {'ignore', 'warn', 'raise', 'call', 'print', 'log'}.
        These keywords are forwarded unchanged to `seterr`.
    call : callable, log instance or None, optional
        If given, it is installed with `seterrcall` on entry and the previous
        handler is reinstalled on exit. If omitted, the handler is left
        untouched.

    See Also
    --------
    seterr, geterr, seterrcall, geterrcall

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    The same instance may be entered again while it is already active (for
    example when it decorates a recursive function); each exit restores the
    settings saved by the matching entry.

    Examples
    --------
    >>> from collections import OrderedDict
    >>> olderr = np.seterr(all='ignore')  # Set error handling to known state.

    >>> np.arange(3) / 0.
    array([nan, inf, inf])
    >>> with np.errstate(divide='warn'):
    ...     np.arange(3) / 0.
    array([nan, inf, inf])

    >>> np.sqrt(-1)
    nan
    >>> with np.errstate(invalid='raise'):
    ...     np.sqrt(-1)
    Traceback (most recent call last):
      File "<stdin>", line 2, in <module>
    FloatingPointError: invalid value encountered in sqrt

    Outside the context the error handling behavior has not changed:

    >>> OrderedDict(sorted(np.geterr().items()))
    OrderedDict([('divide', 'ignore'), ('invalid', 'ignore'), ('over', 'ignore'), ('under', 'ignore')])

    """

    def __init__(self, **kwargs):
        # `call` is handled by seterrcall; everything else goes to seterr.
        self.call = kwargs.pop('call', _Unspecified)
        self.kwargs = kwargs
        # One (oldstate, oldcall) pair per active entry. A stack is needed
        # because ContextDecorator re-enters the same instance on every call
        # of a decorated function, including nested/recursive calls; a single
        # saved slot would be overwritten and the outermost state lost.
        self._saved = []

    def __enter__(self):
        oldstate = seterr(**self.kwargs)
        oldcall = _Unspecified
        if self.call is not _Unspecified:
            try:
                oldcall = seterrcall(self.call)
            except BaseException:
                # __exit__ is not invoked when __enter__ raises, so undo the
                # seterr change here instead of leaking it to the caller.
                seterr(**oldstate)
                raise
        self._saved.append((oldstate, oldcall))

        # Keep the attributes earlier versions exposed, for code that may
        # inspect them; they mirror the most recent entry.
        self.oldstate = oldstate
        if oldcall is not _Unspecified:
            self.oldcall = oldcall

    def __exit__(self, *exc_info):
        # Restore what the matching __enter__ saved (LIFO for nested use).
        oldstate, oldcall = self._saved.pop()
        seterr(**oldstate)
        if oldcall is not _Unspecified:
            seterrcall(oldcall)
|
||
|
|
||
|
|
||
|
def _setdef():
    """
    Install the default error object.

    The object is the three-item list ``[buffer size, error mask, callback]``
    built from `UFUNC_BUFSIZE_DEFAULT`, `ERR_DEFAULT` and ``None``.
    """
    umath.seterrobj([UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None])


# Apply the defaults when this module is imported.
_setdef()
|