2024-03-13 23:33:36 +01:00
|
|
|
import random
|
|
|
|
|
2024-04-14 22:55:07 +02:00
|
|
|
from lightning import Trainer
|
2024-03-13 23:33:36 +01:00
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
from data import (
|
|
|
|
FootballSegDataset,
|
|
|
|
calculate_mean_std,
|
|
|
|
get_data,
|
|
|
|
)
|
2024-04-14 22:55:07 +02:00
|
|
|
from model import UNetLightning
|
2024-03-13 23:33:36 +01:00
|
|
|
|
|
|
|
|
|
|
|
def main(
    *,
    test_size=0.2,
    random_state=42,
    batch_size=32,
    max_epochs=5,
    learning_rate=1e-3,
):
    """Train a ``UNetLightning`` model and evaluate it on a held-out split.

    The pipeline is: load the data via ``get_data``, split it into train and
    test parts, compute normalisation statistics on the *training* part only,
    build datasets/loaders, fit the model and finally run ``trainer.test``.

    Args:
        test_size: Fraction of the samples held out, forwarded to
            ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible.
        batch_size: Batch size used for both the train and the test loader.
        max_epochs: Number of epochs passed to the Lightning ``Trainer``.
        learning_rate: Learning rate passed to ``UNetLightning``.
    """
    # NOTE(review): presumably `get_data` returns parallel collections of
    # image paths and label (mask) paths -- confirm against `data.get_data`.
    all_images, all_paths = get_data()
    (
        image_train_paths,
        image_test_paths,
        label_train_paths,
        label_test_paths,
    ) = train_test_split(
        all_images,
        all_paths,
        test_size=test_size,
        random_state=random_state,
    )

    # Normalisation statistics are computed on the training split only and
    # reused for the test split: the model must see evaluation data on the
    # same scale it was trained on, and the held-out data must not influence
    # preprocessing.
    mean, std = calculate_mean_std(image_train_paths)
    label_mean, label_std = calculate_mean_std(label_train_paths)

    train_dataset = FootballSegDataset(
        image_train_paths, label_train_paths, mean, std, label_mean, label_std
    )
    test_dataset = FootballSegDataset(
        image_test_paths, label_test_paths, mean, std, label_mean, label_std
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # NOTE: optional sanity checks (random-sample plots, segmentation colour
    # statistics) used to be stubbed out here as commented code; re-add them
    # explicitly if they are needed again.

    model = UNetLightning(3, 3, learning_rate=learning_rate)
    trainer = Trainer(max_epochs=max_epochs, logger=True, log_every_n_steps=1)
    # NOTE(review): the held-out split doubles as the validation loader during
    # fitting, so the final test score is not fully independent -- consider a
    # separate validation split.
    trainer.fit(model, train_loader, test_loader)

    # Put the model in inference mode and freeze its parameters before the
    # final evaluation run.
    model.eval()
    model.freeze()
    trainer.test(model, test_loader)
|
2024-03-13 23:33:36 +01:00
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|