learning algoritm

This commit is contained in:
Woj 2023-06-04 12:38:13 +02:00
parent 28b120769d
commit 9172aa397f
2 changed files with 47 additions and 5 deletions

View File

@ -1,10 +1,10 @@
from src.torchvision_resize_dataset import combined_dataset
import glob
from src.torchvision_resize_dataset import combined_dataset, letters_path, package_path
import src.data_model
from torch.optim import Adam
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -15,3 +15,46 @@ train_loader = DataLoader(
)
classes = ["package", "list"]
model = src.data_model.DataModel(num_objects=2).to(device)
#optimizer
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
#loss function
criterion = nn.CrossEntropyLoss()
num_epochs = 10
train_size = len(glob.glob(letters_path, '*.jpg')) + len(glob.glob(package_path, '*.png'))
go_to_accuracy = 0.0
for epoch in range(num_epochs):
#training on dataset
model.train()
train_accuracy = 0.0
train_loss = 0.0
for i, (images, labels) in enumerate(train_loader):
if torch.cuda.is_available():
images = torch.Variable(images.cuda())
labels = torch.Variable(labels.cuda())
# clearing the optimizer gradients
optimizer.zero_grad()
outputs = model(images) # predoction
loss = criterion(outputs, labels) #loss calculation
loss.backward()
optimizer.step()
train_loss += loss.cpu().data*images.size(0)
_, prediction = torch.max(outputs.data, 1)
train_accuracy += int(torch.sum(prediction == labels.data))
train_accuracy = train_accuracy/train_size
train_loss = train_loss/train_size
print('Epoch: '+ str(epoch+1) +' Train Loss: '+ str(int(train_loss)) +' Train Accuracy: '+ str(train_accuracy))
if train_accuracy > go_to_accuracy:
go_to_accuracy= train_accuracy
torch.save(model.state_dict(), "best_model.pth")

View File

@ -1,7 +1,6 @@
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import ConcatDataset
import os
# images have to be the same size for the algorithm to work
transform = transforms.Compose([