iumKC/test.py

31 lines
1.1 KiB
Python
Raw Normal View History

2024-04-14 22:55:07 +02:00
from model import UNetLightning
from data import get_data, calculate_mean_std, FootballSegDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from lightning import Trainer
import numpy as np
# Evaluation script: restore a trained UNetLightning checkpoint, run it over the
# held-out test split and write the stacked predictions to "predictions.npy".

# NOTE(review): these paths pin one specific training run (version_0, epoch 4);
# update them when evaluating a different checkpoint.
ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
hparams_path = "lightning_logs/version_0/hparams.yaml"

# `hparams_file` is the keyword load_from_checkpoint uses for a hparams YAML.
# An unrecognized keyword (the former `hparams_path=`) is treated as a model
# constructor argument instead, so the YAML file was never read.
model = UNetLightning.load_from_checkpoint(ckpt_path, hparams_file=hparams_path)
# freeze() disables gradients and also puts the module in eval mode.
model.freeze()

all_images, all_paths = get_data()
# NOTE(review): test_size / random_state presumably mirror the training script
# so the test images here are exactly the held-out ones -- verify they match.
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
    train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
)

# Normalize with statistics of the *training* images: the model must see test
# inputs preprocessed the same way as its training inputs, and deriving the
# statistics from the test split leaks test information into preprocessing.
# NOTE(review): assumes the training script also used train-split statistics.
mean, std = calculate_mean_std(image_train_paths)
test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
# shuffle=False keeps prediction order aligned with image_test_paths.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

trainer = Trainer()
# One output per batch; assumes UNetLightning.predict_step returns a tensor.
predictions = trainer.predict(model, test_loader)
if not predictions:
    raise RuntimeError("No predictions were produced; is the test split empty?")

np.save(
    "predictions.npy",
    np.concatenate([batch.cpu().numpy() for batch in predictions]),
)