ium_464915/train.py

179 lines
5.4 KiB
Python
Raw Normal View History

2024-04-28 18:57:42 +02:00
import torch
2024-05-14 20:31:53 +02:00
import mlflow
2024-04-28 18:57:42 +02:00
import torch.nn.functional as F
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
class DeviceDataLoader:
    """Iterable wrapper that moves each batch of ``dl`` onto ``device``.

    Iteration delegates to the wrapped loader and passes every batch through
    the module-level ``to_device`` helper; ``len()`` is the wrapped loader's.
    """

    def __init__(self, dl, device):
        # Public attribute names are kept: other code may read them.
        self.dl, self.device = dl, device

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)
class Model(nn.Module):
    """Small three-stage CNN classifier that returns log-probabilities.

    Each stage is an unpadded 3x3 convolution (stride 1) followed by ReLU and
    a 2x2 max-pool; the first two stages apply batch normalisation before the
    ReLU. The flattened feature map feeds two fully connected layers.

    Args:
        img_size: Side length of the square input images. The default of 224
            reproduces the previously hard-coded ``128 * 26 * 26`` feature size.
        num_classes: Number of output classes (default 2).

    Raises:
        ValueError: If ``img_size`` is too small to survive the three
            conv/pool stages.
    """

    def __init__(self, img_size: int = 224, num_classes: int = 2):
        super().__init__()
        # Attribute names are unchanged so existing state_dict checkpoints
        # (e.g. "model.pth") still load.
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        side = self._feature_side(img_size)
        self.fc1 = nn.Linear(128 * side * side, 128)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _feature_side(img_size: int, stages: int = 3) -> int:
        """Return the spatial side length after ``stages`` conv(3x3)+pool(2) steps."""
        side = img_size
        for _ in range(stages):
            # Unpadded 3x3 conv shrinks by 2; 2x2/stride-2 pool floors the halving.
            side = (side - 2) // 2
        if side < 1:
            raise ValueError(f"img_size={img_size} is too small for this network")
        return side

    def forward(self, x):
        """Return log-softmax class scores, one row per sample.

        Assumes ``x`` is (batch, 3, img_size, img_size) -- TODO confirm with callers.
        """
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample; unlike view(-1, N) this never folds a wrongly
        # sized input into extra "batch" rows -- fc1 raises instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # NOTE: the script trains this with nn.CrossEntropyLoss, which applies
        # log_softmax again; harmless because log_softmax is idempotent.
        return F.log_softmax(x, dim=1)
def get_data(IMG_SIZE: int, BATCH_SIZE: int, device: torch.device,
             root: str = "./train", train_fraction: float = 0.8, seed=None):
    """Build device-aware train/validation loaders from an ImageFolder tree.

    The images under ``root`` are split into train/validation subsets. Only the
    training subset gets random augmentation (rotation and flips). The
    validation subset gets just the deterministic resize + normalisation, so
    validation metrics are not perturbed by random transforms.

    Args:
        IMG_SIZE: Side length images are resized to.
        BATCH_SIZE: Batch size for both loaders.
        device: Device every batch is moved to.
        root: ImageFolder root directory (default "./train").
        train_fraction: Fraction of images used for training (default 0.8).
        seed: Optional integer seed for a reproducible split. ``None`` uses
            torch's global RNG, as before.

    Returns:
        ``(train_loader, val_loader, train_size)``: two ``DeviceDataLoader``
        wrappers and the number of training images.
    """
    # Resize already yields IMG_SIZE x IMG_SIZE, so CenterCrop is a no-op;
    # it is kept to preserve the original pipeline exactly.
    resize_and_normalize = [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    train_transform = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomVerticalFlip(p=0.3),
        *resize_and_normalize,
    ])
    eval_transform = transforms.Compose(resize_and_normalize)

    # Two views of the same folder so each subset can use its own transform.
    # With a single dataset, random_split would make the validation subset
    # share the augmenting transform.
    train_full = datasets.ImageFolder(root=root, transform=train_transform)
    eval_full = datasets.ImageFolder(root=root, transform=eval_transform)

    n_total = len(train_full)
    n_train = int(train_fraction * n_total)
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    perm = torch.randperm(n_total, generator=generator).tolist()
    train_set = torch.utils.data.Subset(train_full, perm[:n_train])
    val_set = torch.utils.data.Subset(eval_full, perm[n_train:])

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    # Validation order does not affect the metrics, so no shuffling.
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
    return (DeviceDataLoader(train_loader, device),
            DeviceDataLoader(val_loader, device),
            len(train_set))
def train(EPOCHS: int, BATCH_SIZE: int,
model: nn.Module, train_loader: DataLoader,
val_loader: DataLoader, criterion: nn.Module,
optimizer: torch.optim.Optimizer, train_size: int):
for epoch in range(EPOCHS):
print()
print(f'EPOCH {epoch+1}')
print()
model.train(True)
running_loss, last_loss, avg_loss = 0., 0., 0.
train_correct, train_total = 0, 0
for i, data in enumerate(train_loader):
input, label = data
optimizer.zero_grad()
output = model(input)
loss = criterion(output, label)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss += loss.item()
if i % 10 == 0:
last_loss = running_loss / 5
print(f'Batch {i} Loss train: {last_loss:.3f}')
running_loss = 0.
_, predicted = torch.max(output.data, 1)
train_total += label.size(0)
train_correct += (predicted == label).sum().item()
avg_loss /= train_size/BATCH_SIZE
running_vloss = 0.
model.eval()
val_correct, val_total = 0, 0
with torch.no_grad():
for i, val_data in enumerate(val_loader):
val_input, val_label = val_data
val_output = model(val_input)
val_loss = criterion(val_output, val_label)
running_vloss += val_loss.item()
_, vpredicted = torch.max(val_output.data, 1)
val_total += val_label.size(0)
val_correct += (vpredicted == val_label).sum().item()
avg_vloss = running_vloss / (i + 1)
train_accuracy = train_correct / train_total
val_accuracy = val_correct / val_total
2024-05-14 20:31:53 +02:00
mlflow.log_metric("train_loss", avg_loss, step=epoch)
mlflow.log_metric("val_loss", avg_vloss, step=epoch)
2024-04-28 18:57:42 +02:00
print(f'Loss train {avg_loss:.3f}, loss valid {avg_vloss:.3f}')
print(f'Accuracy train {train_accuracy:.2%}, accuracy valid {val_accuracy:.2%}')
print('Finished Training')
def to_device(data, device: torch.device):
    """Move ``data`` onto ``device``.

    Lists and tuples are handled recursively and always come back as a
    ``list``. Anything else must provide ``.to(device, non_blocking=...)``
    (e.g. tensors or modules) and is returned as that call's result.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    IMG_SIZE = 224
    BATCH_SIZE = 32
    EPOCHS = 3
    LEARNING_RATE = 0.001

    mlflow.set_experiment("Skin cancer classification, custom CNN model")
    # The context manager ends the run even if training raises (marking it
    # failed); a bare start_run() leaves the run open on error.
    with mlflow.start_run():
        mlflow.log_params({
            "IMG_SIZE": IMG_SIZE,
            "BATCH_SIZE": BATCH_SIZE,
            "EPOCHS": EPOCHS,
            "LEARNING_RATE": LEARNING_RATE,
        })

        train_loader, val_loader, train_size = get_data(IMG_SIZE, BATCH_SIZE, device)
        # Model() is sized for 224x224 inputs (fc1 expects 128*26*26 features),
        # matching IMG_SIZE above.
        model = to_device(Model(), device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        train(EPOCHS, BATCH_SIZE, model, train_loader,
              val_loader, criterion, optimizer, train_size)

        torch.save(model.state_dict(), "model.pth")
        mlflow.log_artifact("model.pth")