add model training

parent 094e93bd2d
commit 0fad11aa5c

data.py (29 lines changed)
@@ -14,19 +14,28 @@ class FootballSegDataset(Dataset):
         self,
         image_paths,
         label_paths,
-        mean=[0.3468, 0.3885, 0.3321],
-        std=[0.2081, 0.2054, 0.2093],
+        img_mean=[0.3468, 0.3885, 0.3321],
+        img_std=[0.2081, 0.2054, 0.2093],
+        label_mean=[0.3468, 0.3885, 0.3321],
+        label_std=[0.2081, 0.2054, 0.2093],
     ):
         self.image_paths = image_paths
         self.label_paths = label_paths
-        self.mean = mean
-        self.std = std
-        self.transform = transforms.Compose(
+        self.transform_img = transforms.Compose(
             [
                 transforms.Resize((256, 256)),
                 transforms.ToTensor(),
                 transforms.Normalize(
-                    mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093]
+                    mean=img_mean, std=img_std
+                ),
+            ]
+        )
+        self.transform_label = transforms.Compose(
+            [
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=label_mean, std=label_std
                 ),
             ]
         )
@@ -52,9 +61,8 @@ class FootballSegDataset(Dataset):
         image = Image.open(self.image_paths[idx]).convert("RGB")
         label_map = Image.open(self.label_paths[idx]).convert("RGB")

-        if self.transform:
-            image = self.transform(image)
-            label_map = self.transform(label_map)
+        image = self.transform_img(image)
+        label_map = self.transform_label(label_map)

         return image, label_map

@@ -112,7 +120,8 @@ def get_data():
     api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)

     all_images = glob(download_dir + "/images/*.jpg")
-    all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
+    all_paths = [path.replace(".jpg", ".jpg___fuse.png")
+                 for path in all_images]
     return all_images, all_paths
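
Note: a minimal usage sketch of the dataset above. The file paths here are hypothetical placeholders; calculate_mean_std is the helper imported in main.py and is assumed to return per-channel mean/std lists.

    from data import FootballSegDataset, calculate_mean_std

    # Hypothetical paths, for illustration only.
    image_paths = ["download/images/frame_000.jpg"]
    label_paths = ["download/images/frame_000.jpg___fuse.png"]

    img_mean, img_std = calculate_mean_std(image_paths)
    label_mean, label_std = calculate_mean_std(label_paths)

    # Images and label maps now get separate normalization statistics.
    dataset = FootballSegDataset(
        image_paths, label_paths, img_mean, img_std, label_mean, label_std
    )
    image, label_map = dataset[0]  # two 3x256x256 float tensors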

main.py (31 lines changed)
@@ -1,15 +1,15 @@
 import random

+from lightning import Trainer
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader

 from data import (
     FootballSegDataset,
-    SegmentationStatistics,
     calculate_mean_std,
     get_data,
-    plot_random_images,
 )
+from model import UNetLightning


 def main():
@@ -19,9 +19,14 @@ def main():
     )

     mean, std = calculate_mean_std(image_train_paths)
+    label_mean, label_std = calculate_mean_std(label_train_paths)
+    test_mean, test_std = calculate_mean_std(image_test_paths)
+    test_label_mean, test_label_std = calculate_mean_std(label_test_paths)

-    train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
-    test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
+    train_dataset = FootballSegDataset(
+        image_train_paths, label_train_paths, mean, std, label_mean, label_std)
+    test_dataset = FootballSegDataset(
+        image_test_paths, label_test_paths, test_mean, test_std, test_label_mean, test_label_std)

     train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
     test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
@@ -29,11 +34,19 @@ def main():
     train_indices = random.sample(range(len(train_loader.dataset)), 5)
     test_indices = random.sample(range(len(test_loader.dataset)), 5)

-    plot_random_images(train_indices, train_loader, 'train')
-    plot_random_images(test_indices, test_loader, 'test')
-    statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
-    statistics.count_colors()
-    statistics.print_statistics()
+    # plot_random_images(train_indices, train_loader, "train")
+    # plot_random_images(test_indices, test_loader, "test")
+    # statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
+    # statistics.count_colors()
+    # statistics.print_statistics()

+    model = UNetLightning(3, 3, learning_rate=1e-3)
+    trainer = Trainer(max_epochs=5, logger=True, log_every_n_steps=1)
+    trainer.fit(model, train_loader, test_loader)
+
+    model.eval()
+    model.freeze()
+    trainer.test(model, test_loader)


 if __name__ == "__main__":
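
Note: trainer.fit(model, train_loader, test_loader) passes the test loader as the validation dataloader, so the logged val_loss is computed on the test split. A sketch of running training and locating the checkpoint that test.py expects; the lightning_logs layout is Lightning's default, and the version number and step count may differ per run:

    import glob

    from main import main

    main()  # five epochs; Lightning logs to lightning_logs/version_N/

    # Checkpoints are saved under the run's version directory.
    ckpts = glob.glob("lightning_logs/version_*/checkpoints/*.ckpt")
    print(sorted(ckpts)[-1])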

model.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+import lightning as L
+import torch
+import torch.nn as nn
+
+
+class UNet(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UNet, self).__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.middle = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        x1 = self.encoder(x)
+        x2 = self.middle(x1)
+        x3 = self.decoder(x2)
+        return x3
+
+
+class UNetLightning(L.LightningModule):
+    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
+        super(UNetLightning, self).__init__()
+        self.model = UNet(in_channels, out_channels)
+        self.learning_rate = learning_rate
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        # print(max(x.flatten()), min(x.flatten()),
+        #       max(y.flatten()), min(y.flatten()))
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("val_loss", loss)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        x, y = batch
+        y_hat = self(x)
+        return y_hat
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.CrossEntropyLoss()(y_hat, y)
+        self.log("test_loss", loss)
+        # visualize
+        # if batch_idx == 0:
+        #     import matplotlib.pyplot as plt
+        #     import numpy as np
+        #     import torchvision.transforms as transforms
+        #     x = transforms.ToPILImage()(x[0])
+        #     y = transforms.ToPILImage()(y[0])
+        #     y_hat = transforms.ToPILImage()(y_hat[0])
+        #     plt.figure(figsize=(15, 15))
+        #     plt.subplot(131)
+        #     plt.imshow(np.array(x))
+        #     plt.title("Input")
+        #     plt.axis("off")
+        #     plt.subplot(132)
+        #     plt.imshow(np.array(y))
+        #     plt.title("Ground Truth")
+        #     plt.axis("off")
+        #     plt.subplot(133)
+        #     plt.imshow(np.array(y_hat))
+        #     plt.title("Prediction")
+        #     plt.axis("off")
+        #     plt.show()
+
+    # def on_test_epoch_end(self):
+    #     all_preds, all_labels = [], []
+    #     for output in self.trainer.predictions:
+    #         # predicted values
+    #         probs = list(output['logits'].cpu().detach().numpy())
+    #         labels = list(output['labels'].flatten().cpu().detach().numpy())
+    #         all_preds.extend(probs)
+    #         all_labels.extend(labels)
+
+    #     # save predictions and labels
+    #     import numpy as np
+    #     np.save('predictions.npy', all_preds)
+    #     np.save('labels.npy', all_labels)
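
Note: a quick shape check for the UNet above. The encoder and middle blocks each halve the spatial resolution via MaxPool2d, and the two stride-2 transposed convolutions in the decoder restore it, so a 256x256 input comes back as 256x256:

    import torch

    from model import UNet

    net = UNet(in_channels=3, out_channels=3)
    x = torch.randn(1, 3, 256, 256)  # one RGB image, matching the 256x256 transforms
    y = net(x)
    print(y.shape)  # torch.Size([1, 3, 256, 256]); values in (0, 1) from the Sigmoid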

test.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from model import UNetLightning
+from data import get_data, calculate_mean_std, FootballSegDataset
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+from lightning import Trainer
+import numpy as np
+
+ckpt_path = "lightning_logs/version_0/checkpoints/epoch=4-step=15.ckpt"
+hparams_path = "lightning_logs/version_0/hparams.yaml"
+model = UNetLightning.load_from_checkpoint(
+    ckpt_path, hparams_path=hparams_path)
+model.eval()
+model.freeze()
+
+all_images, all_paths = get_data()
+image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
+    train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
+)
+
+mean, std = calculate_mean_std(image_test_paths)
+test_dataset = FootballSegDataset(
+    image_test_paths, label_test_paths, mean, std)
+test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+
+trainer = Trainer()
+predictions = trainer.predict(model, test_loader)
+arrays = []
+for batch in predictions:
+    arrays.append(batch.cpu().numpy())
+np.save("predictions.npy", np.concatenate(arrays))
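
Note: a sketch of loading the saved predictions, assuming the dataset and transforms above (shapes follow from the 256x256 resize and out_channels=3):

    import numpy as np

    preds = np.load("predictions.npy")
    # One row per test image: (num_test_images, 3, 256, 256), float32 in (0, 1).
    print(preds.shape, preds.dtype, preds.min(), preds.max())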