diff --git a/.gitignore b/.gitignore index 1cdc336..90187e5 100644 --- a/.gitignore +++ b/.gitignore @@ -175,4 +175,5 @@ ipython_config.py # git rm -r .ipynb_checkpoints/ football_dataset/ -venvium/ \ No newline at end of file +venvium/ +lightning_logs/ \ No newline at end of file diff --git a/MLProject b/MLProject new file mode 100644 index 0000000..b48a269 --- /dev/null +++ b/MLProject @@ -0,0 +1,11 @@ +name: iumKC + +conda_env: environment.yaml + +entry_points: + main: + command: "python train.py --train=true --save_model=true" + test: + parameters: + model_filepath: "model" + command: "python test.py --test=true --load_model=$model_filepath" \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..3a7ae17 Binary files /dev/null and b/environment.yml differ diff --git a/id2label.json b/id2label.json new file mode 100644 index 0000000..1c9f213 --- /dev/null +++ b/id2label.json @@ -0,0 +1 @@ +{"0": "Goal Bar", "1": "Referee", "2": "Advertisement", "3": "Ground", "4": "Ball", "5": "Coaches & Officials", "6": "Team A", "7": "Team B", "8": "Goalkeeper A", "9": "Goalkeeper B", "10": "Audience"} \ No newline at end of file diff --git a/main.py b/main.py index 2c5eef8..f3200c3 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,9 @@ +import argparse import random from lightning import Trainer +import lightning as L +from lightning.pytorch.loggers import MLFlowLogger from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader @@ -12,7 +15,7 @@ from data import ( from model import UNetLightning -def main(): +def main(train=True, test=True, save_model=False, load_model=None): all_images, all_paths = get_data() image_train_paths, image_test_paths, label_train_paths, label_test_paths = ( train_test_split(all_images, all_paths, test_size=0.2, random_state=42) @@ -28,11 +31,13 @@ def main(): test_dataset = FootballSegDataset( image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std) - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) - test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) + train_loader = DataLoader( + train_dataset, batch_size=32, shuffle=True, num_workers=7, persistent_workers=True) + test_loader = DataLoader(test_dataset, batch_size=32, + shuffle=False, num_workers=7, persistent_workers=True) - train_indices = random.sample(range(len(train_loader.dataset)), 5) - test_indices = random.sample(range(len(test_loader.dataset)), 5) + # train_indices = random.sample(range(len(train_loader.dataset)), 5) + # test_indices = random.sample(range(len(test_loader.dataset)), 5) # plot_random_images(train_indices, train_loader, "train") # plot_random_images(test_indices, test_loader, "test") @@ -40,14 +45,32 @@ def main(): # statistics.count_colors() # statistics.print_statistics() + mlFlowLogger = MLFlowLogger( + experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080") model = UNetLightning(3, 3, learning_rate=1e-3) - trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1) - trainer.fit(model, train_loader, test_loader) + trainer = Trainer(max_epochs=2, logger=mlFlowLogger, log_every_n_steps=1) + if train: + trainer.fit(model, train_loader, test_loader) - model.eval() - model.freeze() - trainer.test(model, test_loader) + if save_model: + model.save_hyperparameters() + + if load_model: + model = UNetLightning.load_from_checkpoint(load_model) + model.eval() + model.freeze() + + if test: + trainer.test(model, test_loader) if __name__ == "__main__": - main() + args = argparse.ArgumentParser() + args.add_argument("--seed", type=int, default=42) + args.add_argument("--train", type=bool, default=False) + args.add_argument("--test", type=bool, default=False) + args.add_argument("--save_model", type=bool, default=False) + args.add_argument("--load_model", type=str, default=None) + args = args.parse_args() + L.seed_everything(args.seed) + main(args.train, args.test, args.save_model, args.load_model) diff --git a/model.py b/model.py index d873fc9..02cadbd 100644 --- a/model.py +++ b/model.py @@ -47,14 +47,17 @@ class UNetLightning(L.LightningModule): def training_step(self, batch, batch_idx): x, y = batch - # print(max(x.flatten()), min(x.flatten()), - # max(y.flatten()), min(y.flatten())) y_hat = self(x) loss = nn.CrossEntropyLoss()(y_hat, y) self.log("train_loss", loss) return loss + def log_hyperparameters(self): + for key, value in self.hparams.items(): + self.logger.experiment.log_param(key, value) + def configure_optimizers(self): + self.log_hyperparameters() optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer @@ -74,39 +77,3 @@ class UNetLightning(L.LightningModule): y_hat = self(x) loss = nn.CrossEntropyLoss()(y_hat, y) self.log("test_loss", loss) - # visualize - # if batch_idx == 0: - # import matplotlib.pyplot as plt - # import numpy as np - # import torchvision.transforms as transforms - # x = transforms.ToPILImage()(x[0]) - # y = transforms.ToPILImage()(y[0]) - # y_hat = transforms.ToPILImage()(y_hat[0]) - # plt.figure(figsize=(15, 15)) - # plt.subplot(131) - # plt.imshow(np.array(x)) - # plt.title("Input") - # plt.axis("off") - # plt.subplot(132) - # plt.imshow(np.array(y)) - # plt.title("Ground Truth") - # plt.axis("off") - # plt.subplot(133) - # plt.imshow(np.array(y_hat)) - # plt.title("Prediction") - # plt.axis("off") - # plt.show() - - # def on_test_epoch_end(self): - # all_preds, all_labels = [], [] - # for output in self.trainer.predictions: - # # predicted values - # probs = list(output['logits'].cpu().detach().numpy()) - # labels = list(output['labels'].flatten().cpu().detach().numpy()) - # all_preds.extend(probs) - # all_labels.extend(labels) - - # # save predictions and labels - # import numpy as np - # np.save('predictions.npy', all_preds) - # np.save('labels.npy', all_labels) diff --git a/requirements.txt b/requirements.txt index 1ed36d5..8a57100 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ kaggle==1.6.6 matplotlib==3.6.3 Pillow==10.2.0 scikit_learn==1.2.2 -torch==2.0.0+cu117 -torchvision==0.15.1+cu117 +# torch==2.0.0+cu117 +# torchvision==0.15.1+cu117