mlflow + conda
This commit is contained in:
parent
9d91b561e8
commit
1d1c419a33
1
.gitignore
vendored
1
.gitignore
vendored
@ -176,3 +176,4 @@ ipython_config.py
|
||||
|
||||
football_dataset/
|
||||
venvium/
|
||||
lightning_logs/
|
11
MLProject
Normal file
11
MLProject
Normal file
@ -0,0 +1,11 @@
|
||||
name: iumKC
|
||||
|
||||
conda_env: environment.yaml
|
||||
|
||||
entry_points:
|
||||
main:
|
||||
command: "python train.py --train=true --save_model=true"
|
||||
test:
|
||||
parameters:
|
||||
model_filepath: "model"
|
||||
command: "python test.py --test=true --load_model=$model_filepath"
|
BIN
environment.yml
Normal file
BIN
environment.yml
Normal file
Binary file not shown.
1
id2label.json
Normal file
1
id2label.json
Normal file
@ -0,0 +1 @@
|
||||
{"0": "Goal Bar", "1": "Referee", "2": "Advertisement", "3": "Ground", "4": "Ball", "5": "Coaches & Officials", "6": "Team A", "7": "Team B", "8": "Goalkeeper A", "9": "Goalkeeper B", "10": "Audience"}
|
37
main.py
37
main.py
@ -1,6 +1,9 @@
|
||||
import argparse
|
||||
import random
|
||||
|
||||
from lightning import Trainer
|
||||
import lightning as L
|
||||
from lightning.pytorch.loggers import MLFlowLogger
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@ -12,7 +15,7 @@ from data import (
|
||||
from model import UNetLightning
|
||||
|
||||
|
||||
def main():
|
||||
def main(train=True, test=True, save_model=False, load_model=None):
|
||||
all_images, all_paths = get_data()
|
||||
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
|
||||
train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
|
||||
@ -28,11 +31,13 @@ def main():
|
||||
test_dataset = FootballSegDataset(
|
||||
image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=32, shuffle=True, num_workers=7, persistent_workers=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32,
|
||||
shuffle=False, num_workers=7, persistent_workers=True)
|
||||
|
||||
train_indices = random.sample(range(len(train_loader.dataset)), 5)
|
||||
test_indices = random.sample(range(len(test_loader.dataset)), 5)
|
||||
# train_indices = random.sample(range(len(train_loader.dataset)), 5)
|
||||
# test_indices = random.sample(range(len(test_loader.dataset)), 5)
|
||||
|
||||
# plot_random_images(train_indices, train_loader, "train")
|
||||
# plot_random_images(test_indices, test_loader, "test")
|
||||
@ -40,14 +45,32 @@ def main():
|
||||
# statistics.count_colors()
|
||||
# statistics.print_statistics()
|
||||
|
||||
mlFlowLogger = MLFlowLogger(
|
||||
experiment_name="football-segmentation", tracking_uri="http://127.0.0.1:8080")
|
||||
model = UNetLightning(3, 3, learning_rate=1e-3)
|
||||
trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1)
|
||||
trainer = Trainer(max_epochs=2, logger=mlFlowLogger, log_every_n_steps=1)
|
||||
if train:
|
||||
trainer.fit(model, train_loader, test_loader)
|
||||
|
||||
if save_model:
|
||||
model.save_hyperparameters()
|
||||
|
||||
if load_model:
|
||||
model = UNetLightning.load_from_checkpoint(load_model)
|
||||
model.eval()
|
||||
model.freeze()
|
||||
|
||||
if test:
|
||||
trainer.test(model, test_loader)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
args = argparse.ArgumentParser()
|
||||
args.add_argument("--seed", type=int, default=42)
|
||||
args.add_argument("--train", type=bool, default=False)
|
||||
args.add_argument("--test", type=bool, default=False)
|
||||
args.add_argument("--save_model", type=bool, default=False)
|
||||
args.add_argument("--load_model", type=str, default=None)
|
||||
args = args.parse_args()
|
||||
L.seed_everything(args.seed)
|
||||
main(args.train, args.test, args.save_model, args.load_model)
|
||||
|
43
model.py
43
model.py
@ -47,14 +47,17 @@ class UNetLightning(L.LightningModule):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
# print(max(x.flatten()), min(x.flatten()),
|
||||
# max(y.flatten()), min(y.flatten()))
|
||||
y_hat = self(x)
|
||||
loss = nn.CrossEntropyLoss()(y_hat, y)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def log_hyperparameters(self):
|
||||
for key, value in self.hparams.items():
|
||||
self.logger.experiment.log_param(key, value)
|
||||
|
||||
def configure_optimizers(self):
|
||||
self.log_hyperparameters()
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
@ -74,39 +77,3 @@ class UNetLightning(L.LightningModule):
|
||||
y_hat = self(x)
|
||||
loss = nn.CrossEntropyLoss()(y_hat, y)
|
||||
self.log("test_loss", loss)
|
||||
# visualize
|
||||
# if batch_idx == 0:
|
||||
# import matplotlib.pyplot as plt
|
||||
# import numpy as np
|
||||
# import torchvision.transforms as transforms
|
||||
# x = transforms.ToPILImage()(x[0])
|
||||
# y = transforms.ToPILImage()(y[0])
|
||||
# y_hat = transforms.ToPILImage()(y_hat[0])
|
||||
# plt.figure(figsize=(15, 15))
|
||||
# plt.subplot(131)
|
||||
# plt.imshow(np.array(x))
|
||||
# plt.title("Input")
|
||||
# plt.axis("off")
|
||||
# plt.subplot(132)
|
||||
# plt.imshow(np.array(y))
|
||||
# plt.title("Ground Truth")
|
||||
# plt.axis("off")
|
||||
# plt.subplot(133)
|
||||
# plt.imshow(np.array(y_hat))
|
||||
# plt.title("Prediction")
|
||||
# plt.axis("off")
|
||||
# plt.show()
|
||||
|
||||
# def on_test_epoch_end(self):
|
||||
# all_preds, all_labels = [], []
|
||||
# for output in self.trainer.predictions:
|
||||
# # predicted values
|
||||
# probs = list(output['logits'].cpu().detach().numpy())
|
||||
# labels = list(output['labels'].flatten().cpu().detach().numpy())
|
||||
# all_preds.extend(probs)
|
||||
# all_labels.extend(labels)
|
||||
|
||||
# # save predictions and labels
|
||||
# import numpy as np
|
||||
# np.save('predictions.npy', all_preds)
|
||||
# np.save('labels.npy', all_labels)
|
||||
|
@ -2,5 +2,5 @@ kaggle==1.6.6
|
||||
matplotlib==3.6.3
|
||||
Pillow==10.2.0
|
||||
scikit_learn==1.2.2
|
||||
torch==2.0.0+cu117
|
||||
torchvision==0.15.1+cu117
|
||||
# torch==2.0.0+cu117
|
||||
# torchvision==0.15.1+cu117
|
||||
|
Loading…
Reference in New Issue
Block a user