"""Run a saved UNetLightning checkpoint over the held-out split.

The script loads the checkpoint, rebuilds the 80/20 split of the data
returned by ``get_data()`` (``random_state=42``), runs prediction on the
20% test portion and writes the concatenated batch outputs to
``predictions.npy``.
"""

import numpy as np
from lightning import Trainer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from data import FootballSegDataset, calculate_mean_std, get_data
from model import UNetLightning

ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
hparams_path = "lightning_logs/version_0/hparams.yaml"

TEST_SIZE = 0.2
RANDOM_STATE = 42
BATCH_SIZE = 32
OUTPUT_PATH = "predictions.npy"

# ``hparams_file`` is the keyword Lightning documents for pointing at the
# YAML file; any other keyword is treated as a constructor override, so the
# previous ``hparams_path=...`` never caused the file to be read.
model = UNetLightning.load_from_checkpoint(ckpt_path, hparams_file=hparams_path)
model.eval()
model.freeze()

# get_data() yields two parallel sequences; the second one is passed to the
# dataset as the label paths below.
image_paths, label_paths = get_data()
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
    train_test_split(
        image_paths,
        label_paths,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )
)

# NOTE(review): the normalisation statistics are computed from the *test*
# images. If the checkpoint was trained with statistics from the training
# images, `image_train_paths` should be used here instead - confirm against
# the training script.
mean, std = calculate_mean_std(image_test_paths)

test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
# shuffle=False keeps the saved predictions in the same order as the dataset.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

trainer = Trainer()
predictions = trainer.predict(model, test_loader)

# trainer.predict is typed as optional; fail with a clear message instead of
# an opaque error from the loop / np.concatenate below.
if not predictions:
    raise RuntimeError("trainer.predict returned no predictions; nothing to save")

# assumes each batch output is a single tensor - TODO confirm against
# UNetLightning.predict_step
arrays = [batch.cpu().numpy() for batch in predictions]
np.save(OUTPUT_PATH, np.concatenate(arrays))