projektAI/venv/Lib/site-packages/mlxtend/plotting/scatterplotmatrix.py
2021-06-06 22:13:05 +02:00

85 lines
2.5 KiB
Python

# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# Matplotlib wrapper for generating scatterplot matrices.
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import matplotlib.pyplot as plt
import numpy as np
def scatterplotmatrix(X, fig_axes=None, names=None,
                      figsize=(8, 8), alpha=1.0, **kwargs):
    """
    Lower triangular of a scatterplot matrix

    The subplot grid has ``num_features x num_features`` panels:

    - Diagonal panel ``[i, i]``: a histogram of feature ``i``.
    - Lower-triangle panel ``[j, i]`` (with ``j > i``): a scatter plot with
      feature ``j`` on the x-axis and feature ``i`` on the y-axis.
    - Upper-triangle panels: switched off.

    Parameters
    -----------
    X : array-like, shape={num_examples, num_features}
        Design matrix containing data instances (examples)
        with multiple exploratory variables (features).
        It is converted with `np.asarray`, so lists of rows and
        similar 2-D array-likes are accepted.

    fig_axes : tuple (default: None)
        A `(fig, axes)` tuple, where fig is a figure object
        and axes is a 2-D array of axes objects created via matplotlib,
        for example, by calling the pyplot `subplots` function
        `fig, axes = plt.subplots(...)`

    names : list (default: None)
        A list of string names, which should have the same number
        of elements as there are features (columns) in `X`.
        Defaults to `['X1', 'X2', ...]`.

    figsize : tuple (default: (8, 8))
        Width and height of the figure holding the subplot grid.
        Ignored if `fig_axes` is not `None`.

    alpha : float (default: 1.0)
        Transparency for both the scatter plots and the
        histograms along the diagonal.

    **kwargs : kwargs
        Keyword arguments for the scatterplots only; they are
        not forwarded to the diagonal histograms.

    Returns
    --------
    fig_axes : tuple
        A `(fig, axes)` tuple, where fig is a figure object
        and axes is the 2-D array of axes objects the plots
        were drawn on.

    Examples
    ----------
    For more usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/scatterplotmatrix/

    """
    # Accept any 2-D array-like: the column slicing X[:, j] below
    # requires a NumPy array.
    X = np.asarray(X)
    _, num_features = X.shape

    if fig_axes is None:
        # squeeze=False keeps `axes` a 2-D array even for a single
        # feature. Otherwise plt.subplots returns a bare Axes, which
        # np.triu_indices_from and the [i, i] indexing cannot handle.
        fig, axes = plt.subplots(nrows=num_features,
                                 ncols=num_features,
                                 figsize=figsize,
                                 squeeze=False)
    else:
        fig, axes = fig_axes

    if names is None:
        names = ['X%d' % (i + 1) for i in range(num_features)]

    # (i, j) with i < j walks the strict upper triangle. The scatter plot
    # goes into the mirrored lower-triangle cell [j, i], and the redundant
    # upper-triangle cell [i, j] is hidden.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        axes[j, i].scatter(X[:, j], X[:, i], alpha=alpha, **kwargs)
        axes[j, i].set_xlabel(names[j])
        axes[j, i].set_ylabel(names[i])
        axes[i, j].set_axis_off()

    # Diagonal: per-feature histograms.
    for i in range(num_features):
        axes[i, i].hist(X[:, i], alpha=alpha)
        axes[i, i].set_ylabel('Count')
        axes[i, i].set_xlabel(names[i])

    return fig, axes