From 1d1c419a337f555db6886a37f8ae149e5061ec48 Mon Sep 17 00:00:00 2001 From: Karol Cyganik Date: Mon, 13 May 2024 18:51:43 +0200 Subject: [PATCH] mlflow + conda --- .gitignore | 3 ++- MLProject | 11 +++++++++++ environment.yml | Bin 0 -> 10288 bytes id2label.json | 1 + main.py | 45 ++++++++++++++++++++++++++++++++++----------- model.py | 43 +++++-------------------------------------- requirements.txt | 4 ++-- 7 files changed, 55 insertions(+), 52 deletions(-) create mode 100644 MLProject create mode 100644 environment.yml create mode 100644 id2label.json diff --git a/.gitignore b/.gitignore index 1cdc336..90187e5 100644 --- a/.gitignore +++ b/.gitignore @@ -175,4 +175,5 @@ ipython_config.py # git rm -r .ipynb_checkpoints/ football_dataset/ -venvium/ \ No newline at end of file +venvium/ +lightning_logs/ \ No newline at end of file diff --git a/MLProject b/MLProject new file mode 100644 index 0000000..b48a269 --- /dev/null +++ b/MLProject @@ -0,0 +1,11 @@ +name: iumKC + +conda_env: environment.yaml + +entry_points: + main: + command: "python train.py --train=true --save_model=true" + test: + parameters: + model_filepath: "model" + command: "python test.py --test=true --load_model=$model_filepath" \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..3a7ae17f615c9c741eb6d63ab5cdbd6f1678503a GIT binary patch literal 10288 zcmb`N-A*c77KQ7&lYWQu0_=z=Vkcg8ze3;bgaCrNRpbb6w?2G#k8cg;u2sdV0#7;w zL8$pR=I@%tfB&|djI$8Z`Go@*XL_Jmeo_W)wg{5haLkxH`PbA zt6r+tbY1?4eUFk7T z|IMpuHED7D6t9i0tT{~d{$8KV^;qdK)a$ewH|Q1#gN5vPl|B!;>OpqR^d8G@M17{` zg)ABAIn)E5Gd-@WOFc&VcB*S~>zdqh9_QkLhS-Hg6G^!h&sox}%>_BH(qbuX2ErQ; zKkGA4`%D_c8;rk6!IxXERZm`kQd6IrQuGCY_0#MS47m4iBLN z2%XDIaBU(|)A$`|2OV)T6vgFDzd*yzeyhc3Kd8btN6jOo>LRsU3MN)mbC8U zK<4njC)DHR_=}Uyp}#wX*5pNwu$Z!p}VerNF#g^97bQ# z>xZt0%H$bbz~K z9KT^Bb>mpv*`B%^LwCKWB{q0#h5WSC zlTVKEdds*)1iO=+%g^n)YFie{*085tF`TnxWFIwCfcs8;y?bXfs%m+r$fcc1N!MUf zBBJQAV0NkJP?3DtA!~_2m&^PL{Z`TZq&m(-gY1ZB>LC+LH?}!tF8wPUsF0pnqMD*R zaehp8O_|Gv@;C=itIVCuZO-9gLiu<$65HLiotgtDqI*`zN~nnu`VJljWoEI%>}*#7 zD5G`^r1zU}o6C2u3fSx{A1X`7iX&#S!(N0K39y*6+`W=ZHu?npQ%GRz8z^)6FS|!s8{Jpzr5(+;M2`Y*Siv+X(>)i z9)1(v&Z>(u90vYP9s}hnXz~rOJ}HtXuC~~98s|@G!8yv0i z<8!At;B2?klkRh)m~!VK2b(9~ZZ4dAbGoYrrhRnSf_%JuDhE%1Ow;bHcql(R=JK2i z@4}p#vy*3$HB)s23A|ysuAZPt{U+YW%C#$@LnWvCwq?+@9_81KVsxhK>gn1!mQT5D z?=<7y+YThKIUvr%u6t=QR~`51nmo0B%zRCY*_RAbw!A9y*hqvv*n%(n^l3*6+@`WA zW*pGA*lc`n7Sraim2d2Q0}0sol?D7CHXKT9w{f9QFh@~7Tr zv7zFe1!p`k6y|5bc_@sLmE-Ig^*u%GV>ZN$5vTOn(`qx?G7mFR z1&1qPLC?&y=9o@wW@EGI=)xQBU3m9Ac>-_XiK$-U9yEEh=ATbhqvpaYPP}v+>Iu|F z>yQ#GA}?y6+7|3$kI4y(-CoQ?@d=yfNz$;JFSXy);Ek0rUo2%O73@lXr}97_kHc>3 z_=Vr8ebq>Xy#wVZ{hArykwB#hUE8{R$ol6L`DC7Zdeh_m)t?yrQ}UPz*qd8>d~&?P zHLKHBI+6`hS%$y!&^`m>d+}Y#^PCLMmZw_f964@~_nbFUqrsQ?*SasYfvTlW zfd|s#mo#B(yVYlOPxL_-Pq2ZnA-pagS0AVQ202E-Rb4^i!HpKR}7tWwSu`k$;zZ*Umkz=W1rA_LdnQqZGmH z-^gqbXCp@zUFk1pLpU}k!=Lh#ryAnybiVBF*yxU>DN)_#`!3l;FF66q7m|M*pJtwU zDxanpBhIJ$s>3M+v`<02$?=!&a(SKyFZOyM)I2LUeNQCf9pgItnJD*TBJ<#ChA67nRk7L zY+w75rUp=f2Fa|RdfKV5nZaw%KEIy zzUfJ}QUzA~tn!Ra43#q9m4Y4}b2+hAt}R^a34veC*vhkx`0@FPT8ie(%Vf{sH&RiN zjZv3&K0lEe{CdV6?{y!$L0LwanV7R;bYGtYO9*UL6NvUry4oj+E(CmfqQld7>GG~z zjzMCo>?6}fl(@bUQI=NH%4c1yX71#U{ZcYt(rZ{%&U513TLOOh#qY*k9njGoxhyZJ z*j2`f0!i*Rj-oQ<=RSg6O?|rSj-q#AjX$a8Y^WSZeSO*YAN>0i5O6G+GEgi>ER@td zerZKk`4?!81<+#3j%e_9iu;7W(r^3zslVT=f6E4>Z1w(MdPjx&uE;=3=AqbY Rj`S%uF#mtoH}Si`{{bZfQ-}Zn literal 0 HcmV?d00001 diff --git a/id2label.json b/id2label.json new file mode 100644 index 0000000..1c9f213 --- /dev/null +++ b/id2label.json @@ -0,0 +1 @@ +{"0": "Goal Bar", "1": "Referee", "2": "Advertisement", "3": "Ground", "4": "Ball", "5": "Coaches & Officials", "6": "Team A", "7": "Team B", "8": "Goalkeeper A", "9": "Goalkeeper B", "10": "Audience"} \ No newline at end of file diff --git a/main.py b/main.py index 2c5eef8..f3200c3 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,9 @@ +import argparse import random from lightning import Trainer +import lightning as L +from lightning.pytorch.loggers import MLFlowLogger from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader @@ -12,7 +15,7 @@ from data import ( from model import UNetLightning -def main(): +def main(train=True, test=True, save_model=False, load_model=None): all_images, all_paths = get_data() image_train_paths, image_test_paths, label_train_paths, label_test_paths = ( train_test_split(all_images, all_paths, test_size=0.2, random_state=42) @@ -28,11 +31,13 @@ def main(): test_dataset = FootballSegDataset( image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std) - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) - test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) + train_loader = DataLoader( + train_dataset, batch_size=32, shuffle=True, num_workers=7, persistent_workers=True) + test_loader = DataLoader(test_dataset, batch_size=32, + shuffle=False, num_workers=7, persistent_workers=True) - train_indices = random.sample(range(len(train_loader.dataset)), 5) - test_indices = random.sample(range(len(test_loader.dataset)), 5) + # train_indices = random.sample(range(len(train_loader.dataset)), 5) + # test_indices = random.sample(range(len(test_loader.dataset)), 5) # plot_random_images(train_indices, train_loader, "train") # plot_random_images(test_indices, test_loader, "test") @@ -40,14 +45,32 @@ def main(): # statistics.count_colors() # statistics.print_statistics() + mlFlowLogger = MLFlowLogger( + experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080") model = UNetLightning(3, 3, learning_rate=1e-3) - trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1) - trainer.fit(model, train_loader, test_loader) + trainer = Trainer(max_epochs=2, logger=mlFlowLogger, log_every_n_steps=1) + if train: + trainer.fit(model, train_loader, test_loader) - model.eval() - model.freeze() - trainer.test(model, test_loader) + if save_model: + model.save_hyperparameters() + + if load_model: + model = UNetLightning.load_from_checkpoint(load_model) + model.eval() + model.freeze() + + if test: + trainer.test(model, test_loader) if __name__ == "__main__": - main() + args = argparse.ArgumentParser() + args.add_argument("--seed", type=int, default=42) + args.add_argument("--train", type=bool, default=False) + args.add_argument("--test", type=bool, default=False) + args.add_argument("--save_model", type=bool, default=False) + args.add_argument("--load_model", type=str, default=None) + args = args.parse_args() + L.seed_everything(args.seed) + main(args.train, args.test, args.save_model, args.load_model) diff --git a/model.py b/model.py index d873fc9..02cadbd 100644 --- a/model.py +++ b/model.py @@ -47,14 +47,17 @@ class UNetLightning(L.LightningModule): def training_step(self, batch, batch_idx): x, y = batch - # print(max(x.flatten()), min(x.flatten()), - # max(y.flatten()), min(y.flatten())) y_hat = self(x) loss = nn.CrossEntropyLoss()(y_hat, y) self.log("train_loss", loss) return loss + def log_hyperparameters(self): + for key, value in self.hparams.items(): + self.logger.experiment.log_param(key, value) + def configure_optimizers(self): + self.log_hyperparameters() optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer @@ -74,39 +77,3 @@ class UNetLightning(L.LightningModule): y_hat = self(x) loss = nn.CrossEntropyLoss()(y_hat, y) self.log("test_loss", loss) - # visualize - # if batch_idx == 0: - # import matplotlib.pyplot as plt - # import numpy as np - # import torchvision.transforms as transforms - # x = transforms.ToPILImage()(x[0]) - # y = transforms.ToPILImage()(y[0]) - # y_hat = transforms.ToPILImage()(y_hat[0]) - # plt.figure(figsize=(15, 15)) - # plt.subplot(131) - # plt.imshow(np.array(x)) - # plt.title("Input") - # plt.axis("off") - # plt.subplot(132) - # plt.imshow(np.array(y)) - # plt.title("Ground Truth") - # plt.axis("off") - # plt.subplot(133) - # plt.imshow(np.array(y_hat)) - # plt.title("Prediction") - # plt.axis("off") - # plt.show() - - # def on_test_epoch_end(self): - # all_preds, all_labels = [], [] - # for output in self.trainer.predictions: - # # predicted values - # probs = list(output['logits'].cpu().detach().numpy()) - # labels = list(output['labels'].flatten().cpu().detach().numpy()) - # all_preds.extend(probs) - # all_labels.extend(labels) - - # # save predictions and labels - # import numpy as np - # np.save('predictions.npy', all_preds) - # np.save('labels.npy', all_labels) diff --git a/requirements.txt b/requirements.txt index 1ed36d5..8a57100 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ kaggle==1.6.6 matplotlib==3.6.3 Pillow==10.2.0 scikit_learn==1.2.2 -torch==2.0.0+cu117 -torchvision==0.15.1+cu117 +# torch==2.0.0+cu117 +# torchvision==0.15.1+cu117