# AI-Tech-WKO-Projekt/experiments/plots.py
# (108 lines, 3.4 KiB, Python)

from datetime import timedelta
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage
import metrics
def plot_metrics(
        title: str,
        test_metrics: metrics.Metrics,
        train_metrics: metrics.Metrics,
        n_epochs: int,
        time: int,
        image_size: int,
        device: str = 'cpu'
):
    """
    Show a figure with the loss, accuracy, precision, recall and F-score curves.

    The figure has one wide loss panel (log-scaled y axis) on top and a 2x2
    grid with the remaining metrics below. In every panel the test series is
    drawn first and the train series second; a single figure-level legend
    explains both.

    :param title: Plot title
    :param test_metrics: Metrics collected on the test set. The attributes
        ``loss``, ``accuracy``, ``precision``, ``recall`` and ``f_score`` are
        read and each one is passed to ``Axes.plot``.
    :param train_metrics: Metrics collected on the training set.
        See `test_metrics` for the attributes that are read.
    :param n_epochs: Number shown in the subtitle as the iteration count
    :param time: Time taken to train the model in seconds; fractional seconds
        are dropped from the displayed value
    :param image_size: A number corresponding to the size of images used to
        train the model (shown as ``<size>x<size>`` in the subtitle)
    :param device: What was used to train the model
    :return: None, the figure is displayed with ``plt.show()``
    """
    plt.style.use('classic')
    sns.set()

    # No layout engine is requested here on purpose: the layout is produced by
    # the explicit tight_layout() + subplots_adjust() calls below, and asking
    # for constrained layout as well only triggers a matplotlib warning.
    fig, axis = plt.subplot_mosaic([['l', 'l'],
                                    ['a', 'p'],
                                    ['r', 'f']],
                                   figsize=(10, 10))

    # (mosaic key, attribute read from the metrics objects, panel title)
    panels = (
        ('l', 'loss', "Loss"),
        ('a', 'accuracy', "Accuracy"),
        ('p', 'precision', "Precision"),
        ('r', 'recall', "Recall"),
        ('f', 'f_score', "F-score"),
    )
    for key, attribute, panel_title in panels:
        axis[key].plot(getattr(test_metrics, attribute), label='test')
        axis[key].plot(getattr(train_metrics, attribute), label='train')
        axis[key].set_title(panel_title)
    axis['l'].set_yscale('log')

    fig.tight_layout()
    # Reserve room for the title/subtitle at the top and the legend at the bottom.
    fig.subplots_adjust(top=0.90, bottom=0.05)
    fig.suptitle(title, fontsize=24)

    # Every panel uses the same two series, so the handles of the loss panel
    # are enough to describe the whole figure.
    handles, labels = axis['l'].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center")

    # str(timedelta) yields "H:MM:SS[.ffffff]"; keep only the part before the dot.
    duration = str(timedelta(seconds=time)).split(".", maxsplit=1)[0]
    fig.text(0.30, 0.93,
             f'{device}, {image_size}x{image_size}, {n_epochs} iteracji, czas treningu: '
             f'{duration}',
             fontsize=14)
    plt.show()
def show_missclassified(
        dataset: Dataset,
        preds: np.ndarray,
        label_names: dict,
        count_per_class: int = 5
):
    """
    Show a grid of randomly chosen misclassified images, one row per predicted class.

    Samples are visited in a random order. A sample is kept when its
    prediction differs from its true label and the row of the predicted class
    still has room. Each image is titled with the predicted class name and the
    expected (true) class name.

    :param dataset: Indexable dataset; ``dataset[i]`` must return an
        ``(image_tensor, true_label)`` pair, where the label is convertible
        with ``int()`` and the image is accepted by ``ToPILImage``.
    :param preds: Predicted labels; ``preds[i]`` is the prediction for
        ``dataset[i]``. Every value must be a key of `label_names`.
    :param label_names: Maps a class label to the name shown in the titles.
        One grid row is created per key.
    :param count_per_class: Maximum number of images shown per predicted class
    :return: None, the figure is displayed with ``plt.show()``
    """
    results = {label: [] for label in label_names}
    to_pil = ToPILImage()
    unfilled = len(results)

    for idx in np.random.permutation(len(preds)):
        pred = preds[idx]
        bucket = results[pred]
        # Check the bucket before touching the dataset: fetching an item may
        # be expensive and is pointless once this class has enough examples.
        if len(bucket) >= count_per_class:
            continue

        image_tensor, true = dataset[idx]
        actual = int(true)
        if pred == actual:
            continue

        bucket.append({
            "image": to_pil(image_tensor),
            "actual": actual
        })
        if len(bucket) == count_per_class:
            unfilled -= 1
            if not unfilled:
                # Every class has its full set of examples, stop scanning.
                break

    sns.reset_orig()
    plt.figure(figsize=[20, 30])
    n_rows = len(label_names)
    for row, (label, images) in enumerate(results.items()):
        for col, image in enumerate(images):
            # Subplot indices are 1-based and run row by row.
            plt.subplot(n_rows, count_per_class, row * count_per_class + col + 1)
            plt.imshow(image["image"], interpolation="bicubic")
            plt.title(f'{label_names[label]}, expected {label_names[image["actual"]]}')
            plt.axis('off')
    plt.show()