add model training
parent 094e93bd2d
commit 0fad11aa5c

data.py (29 changes)
@@ -14,19 +14,28 @@ class FootballSegDataset(Dataset):
         self,
         image_paths,
         label_paths,
-        mean=[0.3468, 0.3885, 0.3321],
-        std=[0.2081, 0.2054, 0.2093],
+        img_mean=[0.3468, 0.3885, 0.3321],
+        img_std=[0.2081, 0.2054, 0.2093],
+        label_mean=[0.3468, 0.3885, 0.3321],
+        label_std=[0.2081, 0.2054, 0.2093],
     ):
         self.image_paths = image_paths
         self.label_paths = label_paths
-        self.mean = mean
-        self.std = std
-        self.transform = transforms.Compose(
+        self.transform_img = transforms.Compose(
             [
                 transforms.Resize((256, 256)),
                 transforms.ToTensor(),
                 transforms.Normalize(
-                    mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093]
+                    mean=img_mean, std=img_std
                 ),
             ]
         )
+        self.transform_label = transforms.Compose(
+            [
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=label_mean, std=label_std
+                ),
+            ]
+        )
@@ -52,9 +61,8 @@ class FootballSegDataset(Dataset):
         image = Image.open(self.image_paths[idx]).convert("RGB")
         label_map = Image.open(self.label_paths[idx]).convert("RGB")
 
-        if self.transform:
-            image = self.transform(image)
-            label_map = self.transform(label_map)
+        image = self.transform_img(image)
+        label_map = self.transform_label(label_map)
 
         return image, label_map
@@ -112,7 +120,8 @@ def get_data():
     api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
 
     all_images = glob(download_dir + "/images/*.jpg")
-    all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
+    all_paths = [path.replace(".jpg", ".jpg___fuse.png")
+                 for path in all_images]
     return all_images, all_paths
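A quick sketch of how the reworked dataset is constructed after this change (illustrative only; it assumes calculate_mean_std returns a per-channel (mean, std) pair, as main.py below uses it):

from data import FootballSegDataset, calculate_mean_std, get_data

# Images and label maps now get separate normalization statistics.
all_images, all_paths = get_data()
img_mean, img_std = calculate_mean_std(all_images)      # stats over the .jpg frames
label_mean, label_std = calculate_mean_std(all_paths)   # stats over the ___fuse.png masks

dataset = FootballSegDataset(
    all_images, all_paths,
    img_mean, img_std,
    label_mean, label_std,
)
image, label_map = dataset[0]  # each resized to 256x256, normalized with its own stats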
main.py (31 changes)
@@ -1,15 +1,15 @@
 import random
 
 from lightning import Trainer
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader
 
 from data import (
     FootballSegDataset,
     SegmentationStatistics,
     calculate_mean_std,
     get_data,
     plot_random_images,
 )
+from model import UNetLightning
 
 
 def main():
@@ -19,9 +19,14 @@ def main():
     )
 
     mean, std = calculate_mean_std(image_train_paths)
+    label_mean, label_std = calculate_mean_std(label_train_paths)
+    test_mean, test_std = calculate_mean_std(image_test_paths)
+    test_label_mean, test_label_std = calculate_mean_std(label_test_paths)
 
-    train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
-    test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
+    train_dataset = FootballSegDataset(
+        image_train_paths, label_train_paths, mean, std, label_mean, label_std)
+    test_dataset = FootballSegDataset(
+        image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)
 
     train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
     test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
@@ -29,11 +34,19 @@ def main():
     train_indices = random.sample(range(len(train_loader.dataset)), 5)
     test_indices = random.sample(range(len(test_loader.dataset)), 5)
 
-    plot_random_images(train_indices, train_loader, 'train')
-    plot_random_images(test_indices, test_loader, 'test')
-    statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
-    statistics.count_colors()
-    statistics.print_statistics()
+    # plot_random_images(train_indices, train_loader, "train")
+    # plot_random_images(test_indices, test_loader, "test")
+    # statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
+    # statistics.count_colors()
+    # statistics.print_statistics()
+
+    model = UNetLightning(3, 3, learning_rate=1e-3)
+    trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1)
+    trainer.fit(model, train_loader, test_loader)
+
+    model.eval()
+    model.freeze()
+    trainer.test(model, test_loader)
 
 
 if __name__ == "__main__":
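One note on the new training call: Lightning's Trainer.fit takes train_dataloaders and val_dataloaders, so passing test_loader as the second loader means validation_step runs on the test split during training. The equivalent explicit form:

trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,  # the test split doubles as the validation set here
)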
model.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+import lightning as L
+import torch
+import torch.nn as nn
+
+
+class UNet(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UNet, self).__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.middle = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        x1 = self.encoder(x)
+        x2 = self.middle(x1)
+        x3 = self.decoder(x2)
+        return x3
+
+
+class UNetLightning(L.LightningModule):
+    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
+        super(UNetLightning, self).__init__()
+        self.model = UNet(in_channels, out_channels)
+        self.learning_rate = learning_rate
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        # print(max(x.flatten()), min(x.flatten()),
+        #       max(y.flatten()), min(y.flatten()))
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("val_loss", loss)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        x, y = batch
+        y_hat = self(x)
+        return y_hat
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("test_loss", loss)
+        # visualize
+        # if batch_idx == 0:
+        #     import matplotlib.pyplot as plt
+        #     import numpy as np
+        #     import torchvision.transforms as transforms
+        #     x = transforms.ToPILImage()(x[0])
+        #     y = transforms.ToPILImage()(y[0])
+        #     y_hat = transforms.ToPILImage()(y_hat[0])
+        #     plt.figure(figsize=(15, 15))
+        #     plt.subplot(131)
+        #     plt.imshow(np.array(x))
+        #     plt.title("Input")
+        #     plt.axis("off")
+        #     plt.subplot(132)
+        #     plt.imshow(np.array(y))
+        #     plt.title("Ground Truth")
+        #     plt.axis("off")
+        #     plt.subplot(133)
+        #     plt.imshow(np.array(y_hat))
+        #     plt.title("Prediction")
+        #     plt.axis("off")
+        #     plt.show()
+
+    # def on_test_epoch_end(self):
+    #     all_preds, all_labels = [], []
+    #     for output in self.trainer.predictions:
+    #         # predicted values
+    #         probs = list(output['logits'].cpu().detach().numpy())
+    #         labels = list(output['labels'].flatten().cpu().detach().numpy())
+    #         all_preds.extend(probs)
+    #         all_labels.extend(labels)
+
+    #     # save predictions and labels
+    #     import numpy as np
+    #     np.save('predictions.npy', all_preds)
+    #     np.save('labels.npy', all_labels)
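For orientation, a quick shape check of the new module (a sketch, assuming it is run from the repo root so model.py is importable). The two stride-2 max-pools downsample 256 -> 128 -> 64, and the two stride-2 transposed convolutions upsample back 64 -> 128 -> 256:

import torch
from model import UNet

net = UNet(in_channels=3, out_channels=3)
x = torch.randn(1, 3, 256, 256)        # one normalized RGB frame
y = net(x)
print(y.shape)                         # torch.Size([1, 3, 256, 256])
print(y.min().item(), y.max().item())  # within (0, 1) because of the final Sigmoid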
test.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from model import UNetLightning
+from data import get_data, calculate_mean_std, FootballSegDataset
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+from lightning import Trainer
+import numpy as np
+
+ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
+hparams_path = "lightning_logs/version_0/hparams.yaml"
+model = UNetLightning.load_from_checkpoint(
+    ckpt_path, hparams_path=hparams_path)
+model.eval()
+model.freeze()
+
+all_images, all_paths = get_data()
+image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
+    train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
+)
+
+mean, std = calculate_mean_std(image_test_paths)
+test_dataset = FootballSegDataset(
+    image_test_paths, label_test_paths, mean, std)
+test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+
+trainer = Trainer()
+predictions = trainer.predict(model, test_loader)
+arrays = []
+for batch in predictions:
+    arrays.append(batch.cpu().numpy())
+np.save("predictions.npy", np.concatenate(arrays))
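A small sketch of consuming the saved output (assuming the default 256x256 transforms above; predict_step returns one sigmoid map per image, and the loop above concatenates the per-batch arrays along the batch axis):

import numpy as np

preds = np.load("predictions.npy")
print(preds.shape)  # (num_test_images, 3, 256, 256), values in (0, 1)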