import random

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from data import (
    FootballSegDataset,
    SegmentationStatistics,
    calculate_mean_std,
    get_data,
    plot_random_images,
)

# Number of images requested for each split's preview plot.
PREVIEW_SAMPLE_SIZE = 5


def _sample_indices(dataset_size, sample_size=PREVIEW_SAMPLE_SIZE):
    """Return up to ``sample_size`` distinct random indices below ``dataset_size``.

    ``random.sample`` raises ``ValueError`` when asked for more items than the
    population holds, so the request is clamped to the dataset size. For
    datasets with at least ``sample_size`` items the result is the same as an
    unclamped ``random.sample`` call.
    """
    return random.sample(range(dataset_size), min(sample_size, dataset_size))


def main():
    """Build the train/test data loaders, preview samples, and print statistics.

    Steps, in order:

    1. Load the two parallel sequences returned by ``get_data()``.
    2. Split them 80/20 with a fixed ``random_state`` so the split is
       reproducible between runs.
    3. Compute ``mean``/``std`` from the training images only and pass the
       same values to both datasets.
    4. Plot a random handful of samples from each split.
    5. Count colors and print the segmentation statistics.
    """
    # NOTE(review): the second sequence is split into ``label_*_paths`` below,
    # so it presumably holds label/mask paths -- confirm against ``get_data``.
    all_images, all_paths = get_data()

    image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
        train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
    )

    # Statistics come from the training split only; the test dataset reuses
    # them rather than computing its own.
    mean, std = calculate_mean_std(image_train_paths)

    train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
    test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # The preview indices are drawn from the unseeded global ``random`` state,
    # so unlike the split itself they differ from run to run.
    train_indices = _sample_indices(len(train_loader.dataset))
    test_indices = _sample_indices(len(test_loader.dataset))

    plot_random_images(train_indices, train_loader, 'train')
    plot_random_images(test_indices, test_loader, 'test')

    statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
    statistics.count_colors()
    statistics.print_statistics()


if __name__ == "__main__":
    main()