mlflow + conda

This commit is contained in:
Karol Cyganik 2024-05-13 18:51:43 +02:00
parent 9d91b561e8
commit 1d1c419a33
7 changed files with 55 additions and 52 deletions

1
.gitignore vendored
View File

@ -176,3 +176,4 @@ ipython_config.py
football_dataset/ football_dataset/
venvium/ venvium/
lightning_logs/

11
MLProject Normal file
View File

@ -0,0 +1,11 @@
name: iumKC
conda_env: environment.yaml
entry_points:
main:
command: "python train.py --train=true --save_model=true"
test:
parameters:
model_filepath: "model"
command: "python test.py --test=true --load_model=$model_filepath"

BIN
environment.yml Normal file

Binary file not shown.

1
id2label.json Normal file
View File

@ -0,0 +1 @@
{"0": "Goal Bar", "1": "Referee", "2": "Advertisement", "3": "Ground", "4": "Ball", "5": "Coaches & Officials", "6": "Team A", "7": "Team B", "8": "Goalkeeper A", "9": "Goalkeeper B", "10": "Audience"}

37
main.py
View File

@ -1,6 +1,9 @@
import argparse
import random import random
from lightning import Trainer from lightning import Trainer
import lightning as L
from lightning.pytorch.loggers import MLFlowLogger
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -12,7 +15,7 @@ from data import (
from model import UNetLightning from model import UNetLightning
def main(): def main(train=True, test=True, save_model=False, load_model=None):
all_images, all_paths = get_data() all_images, all_paths = get_data()
image_train_paths, image_test_paths, label_train_paths, label_test_paths = ( image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
train_test_split(all_images, all_paths, test_size=0.2, random_state=42) train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
@ -28,11 +31,13 @@ def main():
test_dataset = FootballSegDataset( test_dataset = FootballSegDataset(
image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std) image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) train_loader = DataLoader(
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) train_dataset, batch_size=32, shuffle=True, num_workers=7, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=32,
shuffle=False, num_workers=7, persistent_workers=True)
train_indices = random.sample(range(len(train_loader.dataset)), 5) # train_indices = random.sample(range(len(train_loader.dataset)), 5)
test_indices = random.sample(range(len(test_loader.dataset)), 5) # test_indices = random.sample(range(len(test_loader.dataset)), 5)
# plot_random_images(train_indices, train_loader, "train") # plot_random_images(train_indices, train_loader, "train")
# plot_random_images(test_indices, test_loader, "test") # plot_random_images(test_indices, test_loader, "test")
@ -40,14 +45,32 @@ def main():
# statistics.count_colors() # statistics.count_colors()
# statistics.print_statistics() # statistics.print_statistics()
mlFlowLogger = MLFlowLogger(
experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080")
model = UNetLightning(3, 3, learning_rate=1e-3) model = UNetLightning(3, 3, learning_rate=1e-3)
trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1) trainer = Trainer(max_epochs=2, logger=mlFlowLogger, log_every_n_steps=1)
if train:
trainer.fit(model, train_loader, test_loader) trainer.fit(model, train_loader, test_loader)
if save_model:
model.save_hyperparameters()
if load_model:
model = UNetLightning.load_from_checkpoint(load_model)
model.eval() model.eval()
model.freeze() model.freeze()
if test:
trainer.test(model, test_loader) trainer.test(model, test_loader)
if __name__ == "__main__": if __name__ == "__main__":
main() args = argparse.ArgumentParser()
args.add_argument("--seed", type=int, default=42)
args.add_argument("--train", type=bool, default=False)
args.add_argument("--test", type=bool, default=False)
args.add_argument("--save_model", type=bool, default=False)
args.add_argument("--load_model", type=str, default=None)
args = args.parse_args()
L.seed_everything(args.seed)
main(args.train, args.test, args.save_model, args.load_model)

View File

@ -47,14 +47,17 @@ class UNetLightning(L.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
x, y = batch x, y = batch
# print(max(x.flatten()), min(x.flatten()),
# max(y.flatten()), min(y.flatten()))
y_hat = self(x) y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y) loss = nn.CrossEntropyLoss()(y_hat, y)
self.log("train_loss", loss) self.log("train_loss", loss)
return loss return loss
def log_hyperparameters(self):
for key, value in self.hparams.items():
self.logger.experiment.log_param(key, value)
def configure_optimizers(self): def configure_optimizers(self):
self.log_hyperparameters()
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer return optimizer
@ -74,39 +77,3 @@ class UNetLightning(L.LightningModule):
y_hat = self(x) y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y) loss = nn.CrossEntropyLoss()(y_hat, y)
self.log("test_loss", loss) self.log("test_loss", loss)
# visualize
# if batch_idx == 0:
# import matplotlib.pyplot as plt
# import numpy as np
# import torchvision.transforms as transforms
# x = transforms.ToPILImage()(x[0])
# y = transforms.ToPILImage()(y[0])
# y_hat = transforms.ToPILImage()(y_hat[0])
# plt.figure(figsize=(15, 15))
# plt.subplot(131)
# plt.imshow(np.array(x))
# plt.title("Input")
# plt.axis("off")
# plt.subplot(132)
# plt.imshow(np.array(y))
# plt.title("Ground Truth")
# plt.axis("off")
# plt.subplot(133)
# plt.imshow(np.array(y_hat))
# plt.title("Prediction")
# plt.axis("off")
# plt.show()
# def on_test_epoch_end(self):
# all_preds, all_labels = [], []
# for output in self.trainer.predictions:
# # predicted values
# probs = list(output['logits'].cpu().detach().numpy())
# labels = list(output['labels'].flatten().cpu().detach().numpy())
# all_preds.extend(probs)
# all_labels.extend(labels)
# # save predictions and labels
# import numpy as np
# np.save('predictions.npy', all_preds)
# np.save('labels.npy', all_labels)

View File

@ -2,5 +2,5 @@ kaggle==1.6.6
matplotlib==3.6.3 matplotlib==3.6.3
Pillow==10.2.0 Pillow==10.2.0
scikit_learn==1.2.2 scikit_learn==1.2.2
torch==2.0.0+cu117 # torch==2.0.0+cu117
torchvision==0.15.1+cu117 # torchvision==0.15.1+cu117