31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
|
from model import UNetLightning
|
||
|
from data import get_data, calculate_mean_std, FootballSegDataset
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
from torch.utils.data import DataLoader
|
||
|
from lightning import Trainer
|
||
|
import numpy as np
|
||
|
|
||
|
# Checkpoint and hyperparameter files from an earlier training run.
# NOTE(review): both paths are hard-coded to one specific run
# (version_0, epoch 4) — update them when evaluating a different run.
ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
hparams_path = "lightning_logs/version_0/hparams.yaml"

# `load_from_checkpoint` reads a hyperparameter YAML only through its
# `hparams_file` parameter. Any other keyword (previously `hparams_path=`)
# is forwarded to the model constructor as an extra/override argument
# instead of being loaded as a file, so the YAML was silently ignored.
model = UNetLightning.load_from_checkpoint(ckpt_path, hparams_file=hparams_path)

# Inference only: eval() switches layers to inference behaviour, and
# freeze() turns off gradients for every parameter (it also calls eval()).
model.eval()
model.freeze()

# Project helper returning two parallel collections. Per the unpacking
# below, the first holds image paths and the second holds label paths
# (presumably in matching order — TODO confirm against data.get_data).
all_images, all_paths = get_data()

# Hold out 20% of the samples for evaluation. train_test_split yields a
# (train, test) pair for each input array, in input order. The fixed
# random_state makes the split reproducible.
# NOTE(review): this must match the split used at training time for the
# test images to be genuinely unseen — confirm.
(
    image_train_paths,
    image_test_paths,
    label_train_paths,
    label_test_paths,
) = train_test_split(
    all_images,
    all_paths,
    random_state=42,
    test_size=0.2,
)

# Normalisation statistics come from the training images, not the test
# images. Deriving them from the held-out set leaks test information.
# It also feeds the model inputs normalised differently from training.
# Assumes the training script normalised with training-split statistics
# (TODO confirm against the training script).
mean, std = calculate_mean_std(image_train_paths)

test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)

# shuffle=False keeps predictions in the dataset's index order, presumably
# the order of image_test_paths (confirm in FootballSegDataset). This lets
# saved predictions be matched back to their source images.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

trainer = Trainer()

# Runs the model's predict step over every test batch. The result is a
# list with one entry per batch, or None when predictions are not returned
# (e.g. return_predictions=False).
predictions = trainer.predict(model, test_loader)

if not predictions:
    # Without this guard, None fails while iterating and [] fails inside
    # np.concatenate with an unhelpful "need at least one array" message.
    raise RuntimeError("trainer.predict returned no predictions; nothing to save")

# Assumes each per-batch output is a tensor (TODO confirm predict_step).
# Move each batch to host memory as a NumPy array, then join the batches
# along axis 0 (presumably the batch axis).
arrays = [batch.cpu().numpy() for batch in predictions]
np.save("predictions.npy", np.concatenate(arrays))