From 0fad11aa5c72dd393d942bd9625f2aa08f17bdfe Mon Sep 17 00:00:00 2001
From: Karol Cyganik
Date: Sun, 14 Apr 2024 22:55:07 +0200
Subject: [PATCH] add model training

---
 data.py  |  29 +++++++++-----
 main.py  |  31 ++++++++++-----
 model.py | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 test.py  |  30 +++++++++++++++
 4 files changed, 183 insertions(+), 19 deletions(-)
 create mode 100644 model.py
 create mode 100644 test.py

diff --git a/data.py b/data.py
index e180905..350ac95 100644
--- a/data.py
+++ b/data.py
@@ -14,19 +14,28 @@ class FootballSegDataset(Dataset):
         self,
         image_paths,
         label_paths,
-        mean=[0.3468, 0.3885, 0.3321],
-        std=[0.2081, 0.2054, 0.2093],
+        img_mean=[0.3468, 0.3885, 0.3321],
+        img_std=[0.2081, 0.2054, 0.2093],
+        label_mean=[0.3468, 0.3885, 0.3321],
+        label_std=[0.2081, 0.2054, 0.2093],
     ):
         self.image_paths = image_paths
         self.label_paths = label_paths
-        self.mean = mean
-        self.std = std
-        self.transform = transforms.Compose(
+        self.transform_img = transforms.Compose(
             [
                 transforms.Resize((256, 256)),
                 transforms.ToTensor(),
                 transforms.Normalize(
-                    mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093]
+                    mean=img_mean, std=img_std
+                ),
+            ]
+        )
+        self.transform_label = transforms.Compose(
+            [
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=label_mean, std=label_std
                 ),
             ]
         )
@@ -52,9 +61,8 @@ class FootballSegDataset(Dataset):
         image = Image.open(self.image_paths[idx]).convert("RGB")
         label_map = Image.open(self.label_paths[idx]).convert("RGB")
 
-        if self.transform:
-            image = self.transform(image)
-            label_map = self.transform(label_map)
+        image = self.transform_img(image)
+        label_map = self.transform_label(label_map)
 
         return image, label_map
 
@@ -112,7 +120,8 @@ def get_data():
     api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
 
     all_images = glob(download_dir + "/images/*.jpg")
-    all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
+    all_paths = [path.replace(".jpg", ".jpg___fuse.png")
+                 for path in all_images]
 
     return all_images, all_paths
 
diff --git a/main.py b/main.py
index 40c9ba3..2c5eef8 100644
--- a/main.py
+++ b/main.py
@@ -1,15 +1,15 @@
 import random
 
+from lightning import Trainer
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader
 
 from data import (
     FootballSegDataset,
-    SegmentationStatistics,
     calculate_mean_std,
     get_data,
-    plot_random_images,
 )
+from model import UNetLightning
 
 
 def main():
@@ -19,9 +19,14 @@
     )
 
     mean, std = calculate_mean_std(image_train_paths)
+    label_mean, label_std = calculate_mean_std(label_train_paths)
+    test_mean, test_std = calculate_mean_std(image_test_paths)
+    test_label_mean, test_label_std = calculate_mean_std(label_test_paths)
 
-    train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
-    test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
+    train_dataset = FootballSegDataset(
+        image_train_paths, label_train_paths, mean, std, label_mean, label_std)
+    test_dataset = FootballSegDataset(
+        image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
 
     train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
     test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
@@ -29,11 +34,19 @@
     train_indices = random.sample(range(len(train_loader.dataset)), 5)
     test_indices = random.sample(range(len(test_loader.dataset)), 5)
 
-    plot_random_images(train_indices, train_loader, 'train')
-    plot_random_images(test_indices, test_loader, 'test')
-    statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
-    statistics.count_colors()
-    statistics.print_statistics()
+    # plot_random_images(train_indices, train_loader, "train")
+    # plot_random_images(test_indices, test_loader, "test")
+    # statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
+    # statistics.count_colors()
+    # statistics.print_statistics()
+
+    model = UNetLightning(3, 3, learning_rate=1e-3)
+    trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1)
+    trainer.fit(model, train_loader, test_loader)
+
+    model.eval()
+    model.freeze()
+    trainer.test(model, test_loader)
 
 
 if __name__ == "__main__":
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d873fc9
--- /dev/null
+++ b/model.py
@@ -0,0 +1,112 @@
+import lightning as L
+import torch
+import torch.nn as nn
+
+
+class UNet(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UNet, self).__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.middle = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        x1 = self.encoder(x)
+        x2 = self.middle(x1)
+        x3 = self.decoder(x2)
+        return x3
+
+
+class UNetLightning(L.LightningModule):
+    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
+        super(UNetLightning, self).__init__()
+        self.model = UNet(in_channels, out_channels)
+        self.learning_rate = learning_rate
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        # print(max(x.flatten()), min(x.flatten()),
+        #       max(y.flatten()), min(y.flatten()))
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("val_loss", loss)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        x, y = batch
+        y_hat = self(x)
+        return y_hat
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("test_loss", loss)
+        # visualize
+        # if batch_idx == 0:
+        #     import matplotlib.pyplot as plt
+        #     import numpy as np
+        #     import torchvision.transforms as transforms
+        #     x = transforms.ToPILImage()(x[0])
+        #     y = transforms.ToPILImage()(y[0])
+        #     y_hat = transforms.ToPILImage()(y_hat[0])
+        #     plt.figure(figsize=(15, 15))
+        #     plt.subplot(131)
+        #     plt.imshow(np.array(x))
+        #     plt.title("Input")
+        #     plt.axis("off")
+        #     plt.subplot(132)
+        #     plt.imshow(np.array(y))
+        #     plt.title("Ground Truth")
+        #     plt.axis("off")
+        #     plt.subplot(133)
+        #     plt.imshow(np.array(y_hat))
+        #     plt.title("Prediction")
+        #     plt.axis("off")
+        #     plt.show()
+
+    # def on_test_epoch_end(self):
+    #     all_preds, all_labels = [], []
+    #     for output in self.trainer.predictions:
+    #         # predicted values
+    #         probs = list(output['logits'].cpu().detach().numpy())
+    #         labels = list(output['labels'].flatten().cpu().detach().numpy())
+    #         all_preds.extend(probs)
+    #         all_labels.extend(labels)
+
+    #     # save predictions and labels
+    #     import numpy as np
+    #     np.save('predictions.npy', all_preds)
+    #     np.save('labels.npy', all_labels)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..92b5789
--- /dev/null
+++ b/test.py
@@ -0,0 +1,30 @@
+from model import UNetLightning
+from data import get_data, calculate_mean_std, FootballSegDataset
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+from lightning import Trainer
+import numpy as np
+
+ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
+hparams_path = "lightning_logs/version_0/hparams.yaml"
+model = UNetLightning.load_from_checkpoint(
+    ckpt_path, hparams_path=hparams_path)
+model.eval()
+model.freeze()
+
+all_images, all_paths = get_data()
+image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
+    train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
+)
+
+mean, std = calculate_mean_std(image_test_paths)
+test_dataset = FootballSegDataset(
+    image_test_paths, label_test_paths, mean, std)
+test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+
+trainer = Trainer()
+predictions = trainer.predict(model, test_loader)
+arrays = []
+for batch in predictions:
+    arrays.append(batch.cpu().numpy())
+np.save("predictions.npy", np.concatenate(arrays))