2024-04-28 18:57:42 +02:00
|
|
|
import torch
|
2024-05-14 20:31:53 +02:00
|
|
|
import mlflow
|
2024-04-28 18:57:42 +02:00
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from torch import nn
|
|
|
|
from torchvision import transforms, datasets
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
|
|
|
|
class DeviceDataLoader:
    """Wrap a data loader so every batch is passed through ``to_device``.

    ``dl`` is the wrapped iterable of batches and ``device`` is the target
    handed to ``to_device`` for each batch; both are stored unchanged.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        """Return the length reported by the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Yield the wrapped loader's batches after moving each to the device."""
        yield from (to_device(batch, self.device) for batch in self.dl)
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Small CNN: three conv/pool stages followed by two linear layers.

    The network takes 3-channel images and returns log-probabilities over
    two classes (``log_softmax`` along dim 1).

    Args:
        img_size: Side length of the (square) input images. It determines the
            number of features entering the first linear layer. The default
            of 224 gives ``128 * 26 * 26`` features, matching the previously
            hard-coded value.

    Raises:
        ValueError: If ``img_size`` is too small to survive the three
            conv/pool stages.
    """

    def __init__(self, img_size: int = 224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)

        # Size of the flattened conv output depends on the input resolution.
        side = self._feature_map_side(img_size)
        self.fc1 = nn.Linear(128 * side * side, 128)
        self.fc2 = nn.Linear(128, 2)

    @staticmethod
    def _feature_map_side(img_size: int) -> int:
        """Return the spatial side length after the three conv/pool stages."""
        side = img_size
        for _ in range(3):
            # A 3x3 convolution without padding removes 2 pixels; the 2x2
            # max-pool with stride 2 then halves the size, rounding down.
            side = (side - 2) // 2
            if side < 1:
                raise ValueError(
                    f"img_size={img_size} is too small for three conv/pool stages"
                )
        return side

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)

        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, N)`` this never folds leftover elements into the batch
        # dimension when the input resolution differs from ``img_size``;
        # the mismatch surfaces as a shape error in ``fc1`` instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(IMG_SIZE: int, BATCH_SIZE: int, device: torch.device,
             data_dir: str = "./train"):
    """Build device-aware training and validation loaders from an image folder.

    The images under ``data_dir`` are split 80/20 at random. Training samples
    go through random augmentation; validation samples only go through the
    deterministic resize/crop/normalise steps so validation metrics are not
    perturbed by augmentation.

    Args:
        IMG_SIZE: Side length the images are resized and cropped to.
        BATCH_SIZE: Batch size used for both loaders.
        device: Device each batch is moved to when the loaders are iterated.
        data_dir: Root directory in ``ImageFolder`` layout. Defaults to
            ``"./train"``, the previously hard-coded location.

    Returns:
        A tuple ``(train_loader, val_loader, train_size)`` where the loaders
        are ``DeviceDataLoader`` wrappers and ``train_size`` is the number of
        training samples.
    """
    # Deterministic steps shared by both pipelines.
    base_steps = [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    # Random augmentation, applied to training samples only.
    augmentations = [
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomVerticalFlip(p=0.3),
    ]
    train_transform = transforms.Compose(augmentations + base_steps)
    eval_transform = transforms.Compose(base_steps)

    # Two views of the same folder: identical sample order, different
    # transforms. Splitting by index keeps the subsets disjoint.
    augmented_data = datasets.ImageFolder(root=data_dir, transform=train_transform)
    plain_data = datasets.ImageFolder(root=data_dir, transform=eval_transform)

    total = len(augmented_data)
    train_count = int(0.8 * total)
    shuffled = torch.randperm(total).tolist()
    train_set = torch.utils.data.Subset(augmented_data, shuffled[:train_count])
    val_set = torch.utils.data.Subset(plain_data, shuffled[train_count:])

    train_batches = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    # Shuffling has no effect on validation metrics, so keep a fixed order.
    val_batches = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

    train_loader = DeviceDataLoader(train_batches, device)
    val_loader = DeviceDataLoader(val_batches, device)

    return train_loader, val_loader, len(train_set)
|
|
|
|
|
|
|
|
|
|
|
|
def _ratio(total, count):
    """Return ``total / count``, or 0.0 when ``count`` is zero."""
    return total / count if count else 0.0


def train(EPOCHS: int, BATCH_SIZE: int,
          model: nn.Module, train_loader: DataLoader,
          val_loader: DataLoader, criterion: nn.Module,
          optimizer: torch.optim.Optimizer, train_size: int):
    """Train ``model`` for ``EPOCHS`` epochs, validating after each epoch.

    For every epoch the function prints a running training loss every ten
    batches, then prints and logs to MLflow (``train_loss`` / ``val_loss``)
    the mean per-batch losses, and prints training/validation accuracy.

    Args:
        EPOCHS: Number of passes over ``train_loader``.
        BATCH_SIZE: Kept for backward compatibility; averages are now taken
            over the number of batches actually seen.
        model: Network to optimise; switched between train and eval mode.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(output, labels)``.
        optimizer: Optimiser stepping the model parameters.
        train_size: Kept for backward compatibility; no longer used.

    An empty loader yields a loss and accuracy of 0.0 instead of raising.
    """
    report_every = 10

    for epoch in range(EPOCHS):
        print()
        print(f'EPOCH {epoch+1}')
        print()

        model.train(True)

        epoch_loss, train_batches = 0., 0
        # Loss accumulated since the last progress line, and how many
        # batches contributed to it (the first window holds a single batch).
        window_loss, window_batches = 0., 0
        train_correct, train_total = 0, 0

        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()

            output = model(inputs)
            loss = criterion(output, labels)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            train_batches += 1
            window_loss += batch_loss
            window_batches += 1

            if i % report_every == 0:
                last_loss = window_loss / window_batches
                print(f'Batch {i} Loss train: {last_loss:.3f}')
                window_loss, window_batches = 0., 0

            predicted = output.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Mean over the batches seen, which stays correct when the last
        # batch is smaller than the others.
        avg_loss = _ratio(epoch_loss, train_batches)

        model.eval()

        running_vloss, val_batches = 0., 0
        val_correct, val_total = 0, 0

        with torch.no_grad():
            for val_input, val_label in val_loader:
                val_output = model(val_input)
                val_loss = criterion(val_output, val_label)
                running_vloss += val_loss.item()
                val_batches += 1

                vpredicted = val_output.argmax(dim=1)
                val_total += val_label.size(0)
                val_correct += (vpredicted == val_label).sum().item()

        avg_vloss = _ratio(running_vloss, val_batches)

        train_accuracy = _ratio(train_correct, train_total)
        val_accuracy = _ratio(val_correct, val_total)

        mlflow.log_metric("train_loss", avg_loss, step=epoch)
        mlflow.log_metric("val_loss", avg_vloss, step=epoch)

        print(f'Loss train {avg_loss:.3f}, loss valid {avg_vloss:.3f}')
        print(f'Accuracy train {train_accuracy:.2%}, accuracy valid {val_accuracy:.2%}')

    print('Finished Training')
|
|
|
|
|
|
|
|
|
|
|
|
def to_device(data, device: torch.device):
    """Move ``data`` to ``device``, recursing into lists and tuples.

    A list or tuple is converted element by element and returned as a list;
    anything else is moved with ``data.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: train the CNN on the image folder and record the
# hyper-parameters, per-epoch losses and final weights in MLflow.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    IMG_SIZE = 224
    BATCH_SIZE = 32
    EPOCHS = 3
    LEARNING_RATE = 0.001

    mlflow.set_experiment("Skin cancer classification, custom CNN model")

    # Using the run as a context manager guarantees it is ended even when
    # data loading or training raises, instead of being left open.
    with mlflow.start_run():
        mlflow.log_param("IMG_SIZE", IMG_SIZE)
        mlflow.log_param("BATCH_SIZE", BATCH_SIZE)
        mlflow.log_param("EPOCHS", EPOCHS)
        mlflow.log_param("LEARNING_RATE", LEARNING_RATE)

        train_loader, val_loader, train_size = get_data(IMG_SIZE, BATCH_SIZE, device)

        model = Model()
        # nn.Module.to moves the parameters in place, so the return value
        # does not need to be kept.
        to_device(model, device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

        train(EPOCHS, BATCH_SIZE, model, train_loader,
              val_loader, criterion, optimizer, train_size)

        # Persist the trained weights and attach them to the MLflow run.
        torch.save(model.state_dict(), "model.pth")
        mlflow.log_artifact("model.pth")
|