
# Captured from a web file view (164 lines, 5.1 KiB); last changed 2021-06-06 22:13:05 +02:00.
import sys
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import image_comparison
import matplotlib.transforms as mtransforms
# Platform flags, used below to loosen image-comparison tolerances on the
# platforms where rendering is known to differ slightly.
on_win, on_mac = sys.platform == 'win32', sys.platform == 'darwin'
def velocity_field():
    """Return ``(X, Y, U, V)`` for a 100x100 test flow on [-3, 3] x [-3, 3].

    ``X`` varies along columns and ``Y`` along rows (both from ``np.mgrid``);
    the velocity components are ``U = -1 - X**2 + Y`` and ``V = 1 + X - Y**2``.
    """
    grid_y, grid_x = np.mgrid[-3:3:100j, -3:3:100j]
    u = -1 - grid_x**2 + grid_y
    v = 1 + grid_x - grid_y**2
    return grid_x, grid_y, u, v
def swirl_velocity_field():
    """Return ``(x, y, U, V)`` for a swirling flow on a 100x100 grid.

    ``x`` and ``y`` are the 1-D axes (100 points on [-3, 3]); ``U`` and ``V``
    are 2-D and are the field ``(-Y, X)`` rotated by a fixed angle of 0.1 rad.
    """
    xs = np.linspace(-3., 3., 100)
    ys = np.linspace(-3., 3., 100)
    grid_x, grid_y = np.meshgrid(xs, ys)
    angle = 0.1
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    u = cos_a * (-grid_y) - sin_a * grid_x
    v = sin_a * (-grid_y) + cos_a * grid_x
    return xs, ys, u, v
@image_comparison(['streamplot_startpoints'], remove_text=True, style='mpl20')
def test_startpoints():
    """Streamlines seeded from ten explicit points on the grid diagonal."""
    X, Y, U, V = velocity_field()
    seeds_x = np.linspace(X.min(), X.max(), 10)
    seeds_y = np.linspace(Y.min(), Y.max(), 10)
    plt.streamplot(X, Y, U, V,
                   start_points=np.column_stack([seeds_x, seeds_y]))
    # Draw the seed points too, so they appear in the compared image.
    plt.plot(seeds_x, seeds_y, 'ok')
@image_comparison(['streamplot_colormap'],
                  tol=.04, remove_text=True, style='mpl20')
def test_colormap():
    """Streamlines coloured by the U component, plus a colorbar."""
    # Remove this line when this test image is regenerated.
    plt.rcParams['pcolormesh.snap'] = False

    X, Y, U, V = velocity_field()
    # NOTE(review): the decorator head, the tail of this call and the
    # colorbar were lost from the scraped file; they are restored from
    # upstream matplotlib. Confirm against the 'streamplot_colormap' baseline.
    plt.streamplot(X, Y, U, V, color=U, density=0.6, linewidth=2,
                   cmap=plt.cm.autumn)
    plt.colorbar()
@image_comparison(['streamplot_linewidth'], remove_text=True, style='mpl20')
def test_linewidth():
    """Streamline width varies with local speed, scaled to a maximum of 5."""
    X, Y, U, V = velocity_field()
    speed = np.hypot(U, V)
    # Normalise so the fastest point gets width 5.
    lw = 5 * speed / speed.max()
    # Compatibility for old test image
    df = 25 / 30
    ax = plt.figure().subplots()
    ax.set(xlim=(-3.0, 2.9999999999999947),
           ylim=(-3.0000000000000004, 2.9999999999999947))
    # The speed-dependent widths computed above must actually be passed in.
    ax.streamplot(X, Y, U, V, density=[0.5 * df, 1. * df], color='k',
                  linewidth=lw)
@image_comparison(['streamplot_masks_and_nans'],
                  remove_text=True, style='mpl20', tol=0.04 if on_win else 0)
def test_masks_and_nans():
    """Streamplot over a field with a masked centre square and a NaN corner."""
    X, Y, U, V = velocity_field()
    mask = np.zeros(U.shape, dtype=bool)
    mask[40:60, 40:60] = 1
    U[:20, :20] = np.nan
    # Wrap U as a masked array so the central 20x20 square is masked out.
    U = np.ma.array(U, mask=mask)
    # Compatibility for old test image
    ax = plt.figure().subplots()
    ax.set(xlim=(-3.0, 2.9999999999999947),
           ylim=(-3.0000000000000004, 2.9999999999999947))
    # Presumably silences invalid-value warnings caused by the NaNs in U.
    with np.errstate(invalid='ignore'):
        # NOTE(review): the call tail (cmap) and the decorator head were lost
        # from the scraped file; restored from upstream matplotlib. Confirm.
        ax.streamplot(X, Y, U, V, color=U, cmap=plt.cm.Blues)
# The '.png' suffix restricts this comparison to PNG output only.
@image_comparison(['streamplot_maxlength.png'],
                  remove_text=True, style='mpl20',
                  tol=0.002 if on_mac else 0)
def test_maxlength():
    """One seeded streamline on the swirl field with ``maxlength=10``."""
    x, y, U, V = swirl_velocity_field()
    ax = plt.figure().subplots()
    ax.streamplot(x, y, U, V, maxlength=10., start_points=[[0., 1.5]],
                  linewidth=2, density=2)
    # Autoscaled limits should end exactly at the grid edge (3) on both axes.
    assert ax.get_xlim()[-1] == ax.get_ylim()[-1] == 3
    # Compatibility for old test image
    ax.set(xlim=(None, 3.2555988021882305), ylim=(None, 3.078326760195413))
# The '.png' suffix restricts this comparison to PNG output only.
@image_comparison(['streamplot_direction.png'],
                  remove_text=True, style='mpl20')
def test_direction():
    """Backward-only integration from one seed point on the swirl field."""
    x, y, U, V = swirl_velocity_field()
    plt.streamplot(x, y, U, V, integration_direction='backward',
                   maxlength=1.5, start_points=[[1.5, 0.]],
                   linewidth=2, density=2)
def test_streamplot_limits():
    """Data limits track a barbs plot drawn through an offset transform."""
    ax = plt.axes()
    x = np.linspace(-5, 10, 20)
    y = np.linspace(-2, 4, 10)
    y, x = np.meshgrid(y, x)
    # Translate by (25, 32) in data units before the data transform, so the
    # expected bounds are x in [20, 35] and y in [30, 36].
    trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
    plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
    # The calculated bounds are approximately the bounds of the original data,
    # this is because the entire path is taken into account when updating the
    # datalim.
    # NOTE(review): the closing argument was lost from the scraped file;
    # decimal=1 matches the "approximately" above (upstream value).
    assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
                              decimal=1)
def test_streamplot_grid():
    """Check the shape, dimension and spacing validation of x/y grids."""
    u = np.ones((2, 2))
    v = np.zeros((2, 2))

    # Test for same rows and columns: each pair breaks exactly one rule.
    for x, y, message in [
            (np.array([[10, 20], [10, 30]]), np.array([[10, 10], [20, 20]]),
             "The rows of 'x' must be equal"),
            (np.array([[10, 20], [10, 20]]), np.array([[10, 10], [20, 30]]),
             "The columns of 'y' must be equal"),
    ]:
        with pytest.raises(ValueError, match=message):
            plt.streamplot(x, y, u, v)

    # A consistent 2-D grid is accepted without error.
    plt.streamplot(np.array([[10, 20], [10, 20]]),
                   np.array([[10, 10], [20, 20]]), u, v)

    # Test for maximum dimensions
    with pytest.raises(ValueError, match="'y' can have at maximum "
                                         "2 dimensions"):
        plt.streamplot(np.array([0, 10]), np.array([[[0, 10]]]), u, v)

    # Test for equal spacing, then for strictly increasing, on a 3x3 field.
    u = np.ones((3, 3))
    v = np.zeros((3, 3))
    for x, y, message in [
            ([0, 10, 20], [0, 10, 30], "'y' values must be equally spaced"),
            ([0, 20, 40], [0, 20, 10], "'y' must be strictly increasing"),
    ]:
        with pytest.raises(ValueError, match=message):
            plt.streamplot(np.array(x), np.array(y), u, v)