mlflow + conda
This commit is contained in:
parent
9d91b561e8
commit
1d1c419a33
1
.gitignore
vendored
1
.gitignore
vendored
@ -176,3 +176,4 @@ ipython_config.py
|
|||||||
|
|
||||||
football_dataset/
|
football_dataset/
|
||||||
venvium/
|
venvium/
|
||||||
|
lightning_logs/
|
11
MLProject
Normal file
11
MLProject
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
name: iumKC
|
||||||
|
|
||||||
|
conda_env: environment.yaml
|
||||||
|
|
||||||
|
entry_points:
|
||||||
|
main:
|
||||||
|
command: "python train.py --train=true --save_model=true"
|
||||||
|
test:
|
||||||
|
parameters:
|
||||||
|
model_filepath: "model"
|
||||||
|
command: "python test.py --test=true --load_model=$model_filepath"
|
BIN
environment.yml
Normal file
BIN
environment.yml
Normal file
Binary file not shown.
1
id2label.json
Normal file
1
id2label.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"0": "Goal Bar", "1": "Referee", "2": "Advertisement", "3": "Ground", "4": "Ball", "5": "Coaches & Officials", "6": "Team A", "7": "Team B", "8": "Goalkeeper A", "9": "Goalkeeper B", "10": "Audience"}
|
45
main.py
45
main.py
@ -1,6 +1,9 @@
|
|||||||
|
import argparse
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
|
import lightning as L
|
||||||
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -12,7 +15,7 @@ from data import (
|
|||||||
from model import UNetLightning
|
from model import UNetLightning
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(train=True, test=True, save_model=False, load_model=None):
|
||||||
all_images, all_paths = get_data()
|
all_images, all_paths = get_data()
|
||||||
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
|
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
|
||||||
train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
|
train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
|
||||||
@ -28,11 +31,13 @@ def main():
|
|||||||
test_dataset = FootballSegDataset(
|
test_dataset = FootballSegDataset(
|
||||||
image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
|
image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
train_loader = DataLoader(
|
||||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
train_dataset, batch_size=32, shuffle=True, num_workers=7, persistent_workers=True)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=32,
|
||||||
|
shuffle=False, num_workers=7, persistent_workers=True)
|
||||||
|
|
||||||
train_indices = random.sample(range(len(train_loader.dataset)), 5)
|
# train_indices = random.sample(range(len(train_loader.dataset)), 5)
|
||||||
test_indices = random.sample(range(len(test_loader.dataset)), 5)
|
# test_indices = random.sample(range(len(test_loader.dataset)), 5)
|
||||||
|
|
||||||
# plot_random_images(train_indices, train_loader, "train")
|
# plot_random_images(train_indices, train_loader, "train")
|
||||||
# plot_random_images(test_indices, test_loader, "test")
|
# plot_random_images(test_indices, test_loader, "test")
|
||||||
@ -40,14 +45,32 @@ def main():
|
|||||||
# statistics.count_colors()
|
# statistics.count_colors()
|
||||||
# statistics.print_statistics()
|
# statistics.print_statistics()
|
||||||
|
|
||||||
|
mlFlowLogger = MLFlowLogger(
|
||||||
|
experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080")
|
||||||
model = UNetLightning(3, 3, learning_rate=1e-3)
|
model = UNetLightning(3, 3, learning_rate=1e-3)
|
||||||
trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1)
|
trainer = Trainer(max_epochs=2, logger=mlFlowLogger, log_every_n_steps=1)
|
||||||
trainer.fit(model, train_loader, test_loader)
|
if train:
|
||||||
|
trainer.fit(model, train_loader, test_loader)
|
||||||
|
|
||||||
model.eval()
|
if save_model:
|
||||||
model.freeze()
|
model.save_hyperparameters()
|
||||||
trainer.test(model, test_loader)
|
|
||||||
|
if load_model:
|
||||||
|
model = UNetLightning.load_from_checkpoint(load_model)
|
||||||
|
model.eval()
|
||||||
|
model.freeze()
|
||||||
|
|
||||||
|
if test:
|
||||||
|
trainer.test(model, test_loader)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
args = argparse.ArgumentParser()
|
||||||
|
args.add_argument("--seed", type=int, default=42)
|
||||||
|
args.add_argument("--train", type=bool, default=False)
|
||||||
|
args.add_argument("--test", type=bool, default=False)
|
||||||
|
args.add_argument("--save_model", type=bool, default=False)
|
||||||
|
args.add_argument("--load_model", type=str, default=None)
|
||||||
|
args = args.parse_args()
|
||||||
|
L.seed_everything(args.seed)
|
||||||
|
main(args.train, args.test, args.save_model, args.load_model)
|
||||||
|
43
model.py
43
model.py
@ -47,14 +47,17 @@ class UNetLightning(L.LightningModule):
|
|||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
# print(max(x.flatten()), min(x.flatten()),
|
|
||||||
# max(y.flatten()), min(y.flatten()))
|
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
loss = nn.CrossEntropyLoss()(y_hat, y)
|
loss = nn.CrossEntropyLoss()(y_hat, y)
|
||||||
self.log("train_loss", loss)
|
self.log("train_loss", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def log_hyperparameters(self):
|
||||||
|
for key, value in self.hparams.items():
|
||||||
|
self.logger.experiment.log_param(key, value)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
|
self.log_hyperparameters()
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
@ -74,39 +77,3 @@ class UNetLightning(L.LightningModule):
|
|||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
loss = nn.CrossEntropyLoss()(y_hat, y)
|
loss = nn.CrossEntropyLoss()(y_hat, y)
|
||||||
self.log("test_loss", loss)
|
self.log("test_loss", loss)
|
||||||
# visualize
|
|
||||||
# if batch_idx == 0:
|
|
||||||
# import matplotlib.pyplot as plt
|
|
||||||
# import numpy as np
|
|
||||||
# import torchvision.transforms as transforms
|
|
||||||
# x = transforms.ToPILImage()(x[0])
|
|
||||||
# y = transforms.ToPILImage()(y[0])
|
|
||||||
# y_hat = transforms.ToPILImage()(y_hat[0])
|
|
||||||
# plt.figure(figsize=(15, 15))
|
|
||||||
# plt.subplot(131)
|
|
||||||
# plt.imshow(np.array(x))
|
|
||||||
# plt.title("Input")
|
|
||||||
# plt.axis("off")
|
|
||||||
# plt.subplot(132)
|
|
||||||
# plt.imshow(np.array(y))
|
|
||||||
# plt.title("Ground Truth")
|
|
||||||
# plt.axis("off")
|
|
||||||
# plt.subplot(133)
|
|
||||||
# plt.imshow(np.array(y_hat))
|
|
||||||
# plt.title("Prediction")
|
|
||||||
# plt.axis("off")
|
|
||||||
# plt.show()
|
|
||||||
|
|
||||||
# def on_test_epoch_end(self):
|
|
||||||
# all_preds, all_labels = [], []
|
|
||||||
# for output in self.trainer.predictions:
|
|
||||||
# # predicted values
|
|
||||||
# probs = list(output['logits'].cpu().detach().numpy())
|
|
||||||
# labels = list(output['labels'].flatten().cpu().detach().numpy())
|
|
||||||
# all_preds.extend(probs)
|
|
||||||
# all_labels.extend(labels)
|
|
||||||
|
|
||||||
# # save predictions and labels
|
|
||||||
# import numpy as np
|
|
||||||
# np.save('predictions.npy', all_preds)
|
|
||||||
# np.save('labels.npy', all_labels)
|
|
||||||
|
@ -2,5 +2,5 @@ kaggle==1.6.6
|
|||||||
matplotlib==3.6.3
|
matplotlib==3.6.3
|
||||||
Pillow==10.2.0
|
Pillow==10.2.0
|
||||||
scikit_learn==1.2.2
|
scikit_learn==1.2.2
|
||||||
torch==2.0.0+cu117
|
# torch==2.0.0+cu117
|
||||||
torchvision==0.15.1+cu117
|
# torchvision==0.15.1+cu117
|
||||||
|
Loading…
Reference in New Issue
Block a user