2023-06-04 12:38:13 +02:00
|
|
|
import glob
|
2023-06-04 13:13:03 +02:00
|
|
|
from src.torchvision_resize_dataset import combined_dataset, images_path, classes
|
2023-06-04 12:38:13 +02:00
|
|
|
import src.data_model
|
|
|
|
from torch.optim import Adam
|
2023-06-04 00:09:57 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
# Use the GPU when torch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Serve the combined dataset in batches of 256, in a new random order on
# every pass over the loader.
train_loader = DataLoader(dataset=combined_dataset, batch_size=256, shuffle=True)

# Build the network and place its parameters on the selected device.
# NOTE(review): num_objects=2 is presumably the number of output classes;
# confirm against src.data_model.DataModel.
model = src.data_model.DataModel(num_objects=2)
model = model.to(device)

# Adam over all model parameters, with a small weight-decay penalty.
optimizer = Adam(params=model.parameters(), lr=0.001, weight_decay=0.0001)

# Cross-entropy loss between the model outputs and the integer labels.
criterion = nn.CrossEntropyLoss()
|
|
|
|
|
2023-06-04 16:41:25 +02:00
|
|
|
# Number of full passes over the training data.
num_epochs = 20

# Number of samples seen per epoch; the per-epoch loss and accuracy sums are
# divided by this to get averages. It is read from the dataset itself instead
# of being hard-coded, so the averages stay correct when images are added or
# removed. The loader is built with shuffle=True, which already requires the
# dataset to support len().
train_size = len(combined_dataset)

# Best training accuracy seen so far; a checkpoint is written whenever an
# epoch beats it.
go_to_accuracy = 0.0
|
|
|
|
# Train for num_epochs passes over train_loader. After each epoch, report the
# average loss and accuracy over the training set and checkpoint the weights
# whenever the training accuracy improves on the best value seen so far.
for epoch in range(num_epochs):
    # Put the model in training mode (undoes the eval() call made at the end
    # of the previous epoch).
    model.train()
    train_accuracy = 0.0
    train_loss = 0.0

    for images, labels in train_loader:
        # Move the batch onto the same device as the model. This replaces the
        # old `torch.Variable(x.cuda())` calls: `torch.Variable` does not
        # exist (Variable lives in torch.autograd and is deprecated), so the
        # loop crashed with AttributeError on any CUDA machine.
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        outputs = model(images)            # forward pass / prediction
        loss = criterion(outputs, labels)  # mean loss over the batch
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean, so weight it by the batch
        # size to accumulate a per-sample sum. item() yields a plain float
        # detached from the autograd graph.
        train_loss += loss.item() * images.size(0)

        # Predicted class is the index of the largest output per sample.
        prediction = outputs.argmax(dim=1)
        train_accuracy += int((prediction == labels).sum())

    # Turn the per-epoch sums into per-sample averages.
    train_accuracy = train_accuracy / train_size
    train_loss = train_loss / train_size

    model.eval()

    # Print the loss with decimals; truncating it to an int showed 0 for any
    # loss below 1, which hid all progress.
    print(
        f"Epoch: {epoch + 1} Train Loss: {train_loss:.4f} "
        f"Train Accuracy: {train_accuracy}"
    )

    # Keep only the weights from the best epoch so far.
    if train_accuracy > go_to_accuracy:
        go_to_accuracy = train_accuracy
        torch.save(model.state_dict(), "best_model.pth")
|