"""Train and/or evaluate the football-segmentation U-Net from the command line."""

import argparse
import random
import warnings

from lightning import Trainer
import lightning as L
from lightning.pytorch.loggers import MLFlowLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from data import (
    FootballSegDataset,
    calculate_mean_std,
    get_data,
)
from model import UNetLightning


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main(
    train=True,
    test=True,
    save_model=False,
    load_model=None,
    *,
    checkpoint_path="model.ckpt",
    batch_size=32,
    num_workers=7,
    max_epochs=2,
):
    """Build the data pipeline, then optionally train, save, load and test the model.

    The steps run in this order: train -> save -> load -> test. A model loaded
    from ``load_model`` therefore replaces the in-memory (possibly just trained)
    model for the test step.

    Args:
        train: Fit the model on the training split (the test split is used for validation).
        test: Run ``trainer.test`` on the test split.
        save_model: Write a Lightning checkpoint to ``checkpoint_path`` after training.
            Only possible when ``train`` is true; otherwise a warning is emitted.
        load_model: Optional checkpoint path loaded via
            ``UNetLightning.load_from_checkpoint`` before testing.
        checkpoint_path: Destination file for ``save_model``.
        batch_size: Batch size for both data loaders.
        num_workers: Worker processes per data loader (0 = load in the main process).
        max_epochs: Maximum number of training epochs.
    """
    # The split below unpacks these as image paths and label paths respectively.
    image_paths, label_paths = get_data()
    (
        image_train_paths,
        image_test_paths,
        label_train_paths,
        label_test_paths,
    ) = train_test_split(image_paths, label_paths, test_size=0.2, random_state=42)

    # Normalisation statistics come from the training split only and are reused
    # for the test split: the model must see test inputs scaled exactly like the
    # data it was trained on, and test data must not influence preprocessing.
    mean, std = calculate_mean_std(image_train_paths)
    label_mean, label_std = calculate_mean_std(label_train_paths)

    train_dataset = FootballSegDataset(
        image_train_paths, label_train_paths, mean, std, label_mean, label_std
    )
    test_dataset = FootballSegDataset(
        image_test_paths, label_test_paths, mean, std, label_mean, label_std
    )

    # DataLoader raises if persistent_workers=True while num_workers == 0.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Debug helpers (NOTE: plot_random_images / SegmentationStatistics are not
    # imported in this file; import them before re-enabling these lines).
    # train_indices = random.sample(range(len(train_loader.dataset)), 5)
    # test_indices = random.sample(range(len(test_loader.dataset)), 5)
    # plot_random_images(train_indices, train_loader, "train")
    # plot_random_images(test_indices, test_loader, "test")
    # statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
    # statistics.count_colors()
    # statistics.print_statistics()

    mlflow_logger = MLFlowLogger(
        experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080"
    )
    model = UNetLightning(3, 3, learning_rate=1e-3)
    trainer = Trainer(max_epochs=max_epochs, logger=mlflow_logger, log_every_n_steps=1)

    if train:
        trainer.fit(model, train_loader, test_loader)

    if save_model:
        # save_hyperparameters() only records hparams on the module and writes
        # nothing to disk; a checkpoint file must be written explicitly.
        # NOTE(review): load_from_checkpoint can rebuild the model without extra
        # kwargs only if UNetLightning.__init__ calls save_hyperparameters() — confirm.
        if train:
            trainer.save_checkpoint(checkpoint_path)
        else:
            # Trainer.save_checkpoint needs a model attached by fit/test first.
            warnings.warn(
                "save_model=True has no effect without train=True; no checkpoint was written.",
                stacklevel=2,
            )

    if load_model:
        model = UNetLightning.load_from_checkpoint(load_model)
        # freeze() disables gradients and also switches the module to eval mode.
        model.freeze()

    if test:
        trainer.test(model, test_loader)


def _parse_args(argv=None):
    """Parse the command-line options for this script."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--seed", type=int, default=42)
    # Each flag accepts an explicit value (``--train True``/``--train False``)
    # or may be given bare (``--train``) to mean True.
    for flag in ("--train", "--test", "--save_model"):
        parser.add_argument(flag, type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--load_model", type=str, default=None)
    parser.add_argument("--checkpoint_path", type=str, default="model.ckpt")
    return parser.parse_args(argv)


if __name__ == "__main__":
    cli_args = _parse_args()
    L.seed_everything(cli_args.seed)
    main(
        cli_args.train,
        cli_args.test,
        cli_args.save_model,
        cli_args.load_model,
        checkpoint_path=cli_args.checkpoint_path,
    )