# iumKC/main.py
# Python, 41 lines, 1.3 KiB as originally captured; dated 2024-03-13 23:33:36 +01:00

import random
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from data import (
FootballSegDataset,
SegmentationStatistics,
calculate_mean_std,
get_data,
plot_random_images,
)
def _plot_samples(loader, split_name, num_samples):
    """Plot up to ``num_samples`` randomly chosen items of ``loader.dataset``.

    ``random.sample`` raises ``ValueError`` when asked for more items than the
    population holds, so the count is clamped to the dataset size. Nothing is
    plotted when the dataset is empty.
    """
    population = range(len(loader.dataset))
    indices = random.sample(population, min(num_samples, len(population)))
    if indices:
        plot_random_images(indices, loader, split_name)


def main(*, test_size=0.2, random_state=42, batch_size=32, num_samples=5):
    """Split the data, build train/test loaders, plot samples and print stats.

    Every keyword argument defaults to the value that used to be hard-coded
    here, so ``main()`` behaves as before.

    Args:
        test_size: Fraction of the data held out for the test split
            (forwarded to ``train_test_split``).
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible.
        batch_size: Batch size used for both DataLoaders.
        num_samples: Maximum number of randomly chosen items plotted per split.
    """
    # get_data() returns two parallel sequences that are split together; the
    # second one is used as the labels (label_*_paths) below.
    # NOTE(review): presumably image paths and matching mask paths — confirm in data.py.
    all_images, all_paths = get_data()
    image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
        train_test_split(
            all_images, all_paths, test_size=test_size, random_state=random_state
        )
    )

    # mean/std are computed from the training images only and reused for the
    # test set, so the test split does not influence them.
    mean, std = calculate_mean_std(image_train_paths)
    train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
    test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    _plot_samples(train_loader, "train", num_samples)
    _plot_samples(test_loader, "test", num_samples)

    statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
    statistics.count_colors()
    statistics.print_statistics()
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()