"""Train ``UNetLightning`` on ``FootballSegDataset`` and evaluate it on a held-out split."""

import random

from lightning import Trainer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from data import (
    FootballSegDataset,
    calculate_mean_std,
    get_data,
)
from model import UNetLightning


def main(
    batch_size: int = 32,
    max_epochs: int = 5,
    learning_rate: float = 1e-3,
    test_size: float = 0.2,
    seed: int = 42,
) -> None:
    """Split the data, train the model, then run ``Trainer.test`` on the held-out split.

    Args:
        batch_size: Batch size used by both the train and test loaders.
        max_epochs: Number of training epochs passed to ``Trainer``.
        learning_rate: Learning rate passed to ``UNetLightning``.
        test_size: Fraction of samples held out by ``train_test_split``.
        seed: ``random_state`` for ``train_test_split`` so the split is reproducible.
    """
    all_images, all_paths = get_data()
    image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
        train_test_split(
            all_images, all_paths, test_size=test_size, random_state=seed
        )
    )

    # Normalisation statistics are computed on the training split only and are
    # reused for the test split. Computing them on the test split would leak
    # held-out data into preprocessing. It would also feed the model test inputs
    # transformed differently from the ones it was trained on.
    mean, std = calculate_mean_std(image_train_paths)
    label_mean, label_std = calculate_mean_std(label_train_paths)

    train_dataset = FootballSegDataset(
        image_train_paths, label_train_paths, mean, std, label_mean, label_std
    )
    test_dataset = FootballSegDataset(
        image_test_paths, label_test_paths, mean, std, label_mean, label_std
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # NOTE(review): the positional (3, 3) are presumably input/output channel
    # counts — confirm against model.UNetLightning.
    model = UNetLightning(3, 3, learning_rate=learning_rate)
    trainer = Trainer(max_epochs=max_epochs, logger=True, log_every_n_steps=1)

    # NOTE(review): the test split also serves as the validation set during
    # fit, so reported validation and test metrics are on the same data.
    # Consider carving a separate validation split out of the training data.
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)

    model.eval()
    model.freeze()
    trainer.test(model, dataloaders=test_loader)


if __name__ == "__main__":
    main()