2023-06-05 03:35:16 +02:00
|
|
|
import torch
|
|
|
|
import argparse
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.optim as optim
|
|
|
|
import time
|
|
|
|
from tqdm.auto import tqdm
|
2023-06-05 05:25:13 +02:00
|
|
|
from neural_network.model import CNNModel
|
|
|
|
from neural_network.datasets import train_loader, valid_loader
|
|
|
|
from neural_network.utils import save_model, save_plots
|
2023-06-05 03:35:16 +02:00
|
|
|
|
|
|
|
# Command-line interface: the epoch count is the only configurable option.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-e', '--epochs',
    type=int,
    default=20,
    help='number of epochs to train our network for',
)
args = vars(parser.parse_args())

# Learning rate, epoch count and computation device used by the rest of the
# script.
lr = 1e-3
epochs = args['epochs']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Computation device: {device}\n")

# Build the network on the selected device and print its layout.
model = CNNModel().to(device)
print(model)

# Report the parameter count for every parameter, then for only those with
# requires_grad set.
total_params = sum(param.numel() for param in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    param.numel()
    for param in filter(lambda p: p.requires_grad, model.parameters())
)
print(f"{total_trainable_params:,} training parameters.")

# Adam over all model parameters, paired with a cross-entropy loss.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
|
|
|
|
|
|
|
|
|
|
|
|
# training
|
|
|
|
def train(model, trainloader, optimizer, criterion):
    """Run one training epoch and return its mean loss and accuracy.

    Args:
        model: Network to optimise; it is switched to training mode here.
        trainloader: Sized iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose gradients are zeroed and stepped once
            per batch.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the per-batch losses averaged
        over the number of batches, and the accuracy in percent over the
        samples actually processed in this epoch.
    """
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    batches_seen = 0
    samples_seen = 0
    for image, labels in tqdm(trainloader, total=len(trainloader)):
        batches_seen += 1
        # `device` is the module-level computation device.
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # forward pass
        outputs = model(image)
        # calculate the loss
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        # calculate the accuracy; argmax yields plain class indices, so
        # there is no need to reach into `outputs.data`
        preds = outputs.argmax(dim=1)
        train_running_correct += (preds == labels).sum().item()
        # Count what was really iterated: len(trainloader.dataset) is wrong
        # when the loader drops the last batch or samples only a subset.
        # Assumes dim 0 of `labels` is the batch dimension.
        samples_seen += labels.size(0)
        # backpropagation
        loss.backward()
        # update the optimizer parameters
        optimizer.step()

    # loss and accuracy for the complete epoch
    epoch_loss = train_running_loss / batches_seen
    epoch_acc = 100. * (train_running_correct / samples_seen)
    return epoch_loss, epoch_acc
|
|
|
|
|
|
|
|
# validation
|
|
|
|
def validate(model, testloader, criterion):
    """Evaluate the model for one pass and return its mean loss and accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode here.
        testloader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the per-batch losses averaged
        over the number of batches, and the accuracy in percent over the
        samples actually processed in this pass.
    """
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    batches_seen = 0
    samples_seen = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for image, labels in tqdm(testloader, total=len(testloader)):
            batches_seen += 1

            # `device` is the module-level computation device.
            image = image.to(device)
            labels = labels.to(device)
            # forward pass
            outputs = model(image)
            # calculate the loss
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # calculate the accuracy; argmax yields plain class indices, so
            # there is no need to reach into `outputs.data`
            preds = outputs.argmax(dim=1)
            valid_running_correct += (preds == labels).sum().item()
            # Count what was really iterated: len(testloader.dataset) is
            # wrong when the loader drops the last batch or samples only a
            # subset. Assumes dim 0 of `labels` is the batch dimension.
            samples_seen += labels.size(0)

    # loss and accuracy for the complete epoch
    epoch_loss = valid_running_loss / batches_seen
    epoch_acc = 100. * (valid_running_correct / samples_seen)
    return epoch_loss, epoch_acc
|
|
|
|
|
|
|
|
# Per-epoch history of losses and accuracies, later passed to save_plots.
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []

# Main loop: one training pass followed by one validation pass per epoch.
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    train_epoch_loss, train_epoch_acc = train(
        model, train_loader, optimizer, criterion
    )
    valid_epoch_loss, valid_epoch_acc = validate(
        model, valid_loader, criterion
    )

    # Record this epoch's metrics.
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)

    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-' * 50)
    # Pause for five seconds before the next epoch.
    time.sleep(5)

# Save the trained model state via the project helper.
save_model(epochs, model, optimizer, criterion)
# Save the loss and accuracy plots built from the recorded history.
save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')
|