add model training

This commit is contained in:
Karol Cyganik 2024-04-14 22:55:07 +02:00
parent 094e93bd2d
commit 0fad11aa5c
4 changed files with 183 additions and 19 deletions

29
data.py
View File

@ -14,19 +14,28 @@ class FootballSegDataset(Dataset):
self, self,
image_paths, image_paths,
label_paths, label_paths,
mean=[0.3468, 0.3885, 0.3321], img_mean=[0.3468, 0.3885, 0.3321],
std=[0.2081, 0.2054, 0.2093], img_std=[0.2081, 0.2054, 0.2093],
label_mean=[0.3468, 0.3885, 0.3321],
label_std=[0.2081, 0.2054, 0.2093],
): ):
self.image_paths = image_paths self.image_paths = image_paths
self.label_paths = label_paths self.label_paths = label_paths
self.mean = mean self.transform_img = transforms.Compose(
self.std = std
self.transform = transforms.Compose(
[ [
transforms.Resize((256, 256)), transforms.Resize((256, 256)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093] mean=img_mean, std=img_std
),
]
)
self.transform_label = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(
mean=label_mean, std=label_std
), ),
] ]
) )
@ -52,9 +61,8 @@ class FootballSegDataset(Dataset):
image = Image.open(self.image_paths[idx]).convert("RGB") image = Image.open(self.image_paths[idx]).convert("RGB")
label_map = Image.open(self.label_paths[idx]).convert("RGB") label_map = Image.open(self.label_paths[idx]).convert("RGB")
if self.transform: image = self.transform_img(image)
image = self.transform(image) label_map = self.transform_label(label_map)
label_map = self.transform(label_map)
return image, label_map return image, label_map
@ -112,7 +120,8 @@ def get_data():
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True) api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
all_images = glob(download_dir + "/images/*.jpg") all_images = glob(download_dir + "/images/*.jpg")
all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images] all_paths = [path.replace(".jpg", ".jpg___fuse.png")
for path in all_images]
return all_images, all_paths return all_images, all_paths

31
main.py
View File

@ -1,15 +1,15 @@
import random import random
from lightning import Trainer
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from data import ( from data import (
FootballSegDataset, FootballSegDataset,
SegmentationStatistics,
calculate_mean_std, calculate_mean_std,
get_data, get_data,
plot_random_images,
) )
from model import UNetLightning
def main(): def main():
@ -19,9 +19,14 @@ def main():
) )
mean, std = calculate_mean_std(image_train_paths) mean, std = calculate_mean_std(image_train_paths)
label_mean, label_std = calculate_mean_std(label_train_paths)
test_mean, test_std = calculate_mean_std(image_test_paths)
test_label_mean, test_label_std = calculate_mean_std(label_test_paths)
train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std) train_dataset = FootballSegDataset(
test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std) image_train_paths, label_train_paths, mean, std, label_mean, label_std)
test_dataset = FootballSegDataset(
image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
@ -29,11 +34,19 @@ def main():
train_indices = random.sample(range(len(train_loader.dataset)), 5) train_indices = random.sample(range(len(train_loader.dataset)), 5)
test_indices = random.sample(range(len(test_loader.dataset)), 5) test_indices = random.sample(range(len(test_loader.dataset)), 5)
plot_random_images(train_indices, train_loader, 'train') # plot_random_images(train_indices, train_loader, "train")
plot_random_images(test_indices, test_loader, 'test') # plot_random_images(test_indices, test_loader, "test")
statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std) # statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
statistics.count_colors() # statistics.count_colors()
statistics.print_statistics() # statistics.print_statistics()
model = UNetLightning(3, 3, learning_rate=1e-3)
trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1)
trainer.fit(model, train_loader, test_loader)
model.eval()
model.freeze()
trainer.test(model, test_loader)
if __name__ == "__main__": if __name__ == "__main__":

112
model.py Normal file
View File

@ -0,0 +1,112 @@
import lightning as L
import torch
import torch.nn as nn
class UNet(nn.Module):
    """Small convolutional encoder-decoder for image-to-image prediction.

    The input is downsampled twice by 2x2 max pooling and brought back up by
    two stride-2 transposed convolutions, so the output has the same height
    and width as the input when both are multiples of 4. The stages are
    applied strictly one after another; there are no skip connections
    between the downsampling and upsampling paths.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    @staticmethod
    def _conv_relu(in_ch, out_ch):
        """Return a 3x3 padded convolution followed by an in-place ReLU."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Stage 1: two 3x3 convolutions at 64 channels, then halve H and W.
        self.encoder = nn.Sequential(
            *self._conv_relu(in_channels, 64),
            *self._conv_relu(64, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Stage 2: two 3x3 convolutions at 128 channels, then halve again.
        self.middle = nn.Sequential(
            *self._conv_relu(64, 128),
            *self._conv_relu(128, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Stage 3: two stride-2 transposed convolutions undo both poolings.
        # The final Sigmoid squashes every output value into [0, 1], so the
        # network does not emit raw logits.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            *self._conv_relu(64, 64),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through encoder, middle and decoder in sequence."""
        return self.decoder(self.middle(self.encoder(x)))
class UNetLightning(L.LightningModule):
    """Lightning module that trains ``UNet`` with a cross-entropy loss.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of channels produced by the network.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
        super().__init__()
        # Store the constructor arguments in checkpoints and hparams.yaml so
        # that ``load_from_checkpoint`` can rebuild the module with the same
        # values instead of silently falling back to the defaults.
        self.save_hyperparameters()
        self.model = UNet(in_channels, out_channels)
        self.learning_rate = learning_rate
        # Built once instead of on every step. With default arguments the
        # loss holds no parameters or buffers, so existing checkpoints still
        # load unchanged.
        # NOTE(review): nn.CrossEntropyLoss expects unnormalized logits, but
        # UNet ends in a Sigmoid. The right fix depends on how the label
        # tensors are encoded -- confirm the intended loss/activation pairing.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def _loss_and_log(self, batch, log_name):
        """Compute the loss for an ``(input, target)`` batch and log it.

        Args:
            batch: Pair ``(x, y)`` as yielded by the dataloader.
            log_name: Metric name passed to ``self.log``.

        Returns:
            The loss tensor for this batch.
        """
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._loss_and_log(batch, "train_loss")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for ``batch`` as ``val_loss``."""
        self._loss_and_log(batch, "val_loss")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the network output for the inputs of ``batch``.

        The target half of the batch is ignored.
        """
        x, _ = batch
        return self(x)

    def test_step(self, batch, batch_idx):
        """Log the test loss for ``batch`` as ``test_loss``."""
        self._loss_and_log(batch, "test_loss")

30
test.py Normal file
View File

@ -0,0 +1,30 @@
from model import UNetLightning
from data import get_data, calculate_mean_std, FootballSegDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from lightning import Trainer
import numpy as np
# Artifacts of a finished training run: the checkpoint to restore and the
# hyperparameter file the logger wrote next to it.
ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
hparams_path = "lightning_logs/version_0/hparams.yaml"

# ``load_from_checkpoint`` reads the hyperparameter file from the keyword
# ``hparams_file``. Any other keyword is treated as an extra constructor
# argument, so the previous ``hparams_path=`` spelling never loaded the YAML.
model = UNetLightning.load_from_checkpoint(ckpt_path, hparams_file=hparams_path)
model.eval()
model.freeze()

# Recreate the split with the same test_size and random_state so the held-out
# images are selected deterministically.
all_images, all_paths = get_data()
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
    train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
)

# NOTE(review): the normalization statistics are computed on the test images
# themselves, and the label statistics fall back to the dataset defaults.
# Confirm whether the statistics used during training should be reused here.
mean, std = calculate_mean_std(image_test_paths)
test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ``Trainer.predict`` returns one output per batch; stack them along the
# batch axis and store the result as a single array.
trainer = Trainer()
predictions = trainer.predict(model, test_loader)
arrays = [batch.cpu().numpy() for batch in predictions]
np.save("predictions.npy", np.concatenate(arrays))