iumKC/main.py

77 lines
2.8 KiB
Python
Raw Normal View History

2024-05-13 18:51:43 +02:00
import argparse
2024-03-13 23:33:36 +01:00
import random
2024-04-14 22:55:07 +02:00
from lightning import Trainer
2024-05-13 18:51:43 +02:00
import lightning as L
from lightning.pytorch.loggers import MLFlowLogger
2024-03-13 23:33:36 +01:00
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from data import (
FootballSegDataset,
calculate_mean_std,
get_data,
)
2024-04-14 22:55:07 +02:00
from model import UNetLightning
2024-03-13 23:33:36 +01:00
2024-05-13 18:51:43 +02:00
def _build_dataloaders(batch_size=32, num_workers=7):
    """Split the dataset 80/20 and return ``(train_loader, test_loader)``.

    Normalization statistics for images and labels are computed on the
    training split only and applied to *both* splits. Using separate
    statistics for the test split would leak test-set information into
    preprocessing and would feed the model inputs scaled differently from
    the ones it was trained on.

    Args:
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader; workers are kept alive
            between epochs only when this is positive.

    Returns:
        A ``(train_loader, test_loader)`` tuple; only the train loader
        shuffles.
    """
    # NOTE(review): assumes get_data() returns (image paths, label paths)
    # aligned index-by-index -- the names below rely on it; confirm.
    all_images, all_paths = get_data()
    # Fixed random_state keeps the split identical across runs and seeds.
    image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
        train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
    )
    mean, std = calculate_mean_std(image_train_paths)
    label_mean, label_std = calculate_mean_std(label_train_paths)

    train_dataset = FootballSegDataset(
        image_train_paths, label_train_paths, mean, std, label_mean, label_std)
    test_dataset = FootballSegDataset(
        image_test_paths, label_test_paths, mean, std, label_mean, label_std)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


def main(train=True, test=True, save_model=False, load_model=None,
         checkpoint_path="model.ckpt"):
    """Train and/or evaluate the football-segmentation U-Net.

    Args:
        train: Fit the model on the training split, validating on the test
            split, for 2 epochs.
        test: Run ``Trainer.test`` on the test split at the end.
        save_model: After training, write a Lightning checkpoint to
            ``checkpoint_path``. Has no effect when ``train`` is false,
            since there is then no freshly trained model to save.
        load_model: Optional checkpoint path. When given, the model is
            initialised from it instead of from scratch; with ``train``
            true it is then fine-tuned, otherwise it is only evaluated.
        checkpoint_path: Destination file used when ``save_model`` is true.
    """
    train_loader, test_loader = _build_dataloaders()

    mlflow_logger = MLFlowLogger(
        experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080")

    if load_model:
        # Load *before* training so a newly trained model is never silently
        # replaced by the checkpoint afterwards.
        model = UNetLightning.load_from_checkpoint(load_model)
        if not train:
            # Evaluation only: eval mode and no gradients.
            model.freeze()
    else:
        model = UNetLightning(3, 3, learning_rate=1e-3)

    trainer = Trainer(max_epochs=2, logger=mlflow_logger, log_every_n_steps=1)
    if train:
        trainer.fit(model, train_loader, test_loader)
        if save_model:
            # Persist the trained weights (and hyperparameters) to disk.
            trainer.save_checkpoint(checkpoint_path)
    if test:
        trainer.test(model, test_loader)
2024-03-13 23:33:36 +01:00
def parse_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string --
    including ``"False"`` -- as True, so boolean options need an explicit
    converter.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    # nargs="?" with const=True accepts both a bare flag (`--train`) and an
    # explicit value (`--train true` / `--train false`).
    parser.add_argument("--train", type=parse_bool, nargs="?", const=True, default=False)
    parser.add_argument("--test", type=parse_bool, nargs="?", const=True, default=False)
    parser.add_argument("--save_model", type=parse_bool, nargs="?", const=True, default=False)
    parser.add_argument("--load_model", type=str, default=None)
    args = parser.parse_args()
    L.seed_everything(args.seed)
    main(args.train, args.test, args.save_model, args.load_model)