learning algoritm
This commit is contained in:
parent
28b120769d
commit
9172aa397f
@ -1,10 +1,10 @@
|
||||
from src.torchvision_resize_dataset import combined_dataset
|
||||
import glob
|
||||
from src.torchvision_resize_dataset import combined_dataset, letters_path, package_path
|
||||
import src.data_model
|
||||
from torch.optim import Adam
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -15,3 +15,46 @@ train_loader = DataLoader(
|
||||
)
|
||||
classes = ["package", "list"]
|
||||
|
||||
model = src.data_model.DataModel(num_objects=2).to(device)
|
||||
|
||||
#optimizer
|
||||
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
|
||||
#loss function
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
num_epochs = 10
|
||||
train_size = len(glob.glob(letters_path, '*.jpg')) + len(glob.glob(package_path, '*.png'))
|
||||
|
||||
go_to_accuracy = 0.0
|
||||
for epoch in range(num_epochs):
|
||||
#training on dataset
|
||||
model.train()
|
||||
train_accuracy = 0.0
|
||||
train_loss = 0.0
|
||||
for i, (images, labels) in enumerate(train_loader):
|
||||
if torch.cuda.is_available():
|
||||
images = torch.Variable(images.cuda())
|
||||
labels = torch.Variable(labels.cuda())
|
||||
# clearing the optimizer gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(images) # predoction
|
||||
loss = criterion(outputs, labels) #loss calculation
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.cpu().data*images.size(0)
|
||||
_, prediction = torch.max(outputs.data, 1)
|
||||
|
||||
train_accuracy += int(torch.sum(prediction == labels.data))
|
||||
|
||||
train_accuracy = train_accuracy/train_size
|
||||
train_loss = train_loss/train_size
|
||||
|
||||
print('Epoch: '+ str(epoch+1) +' Train Loss: '+ str(int(train_loss)) +' Train Accuracy: '+ str(train_accuracy))
|
||||
|
||||
if train_accuracy > go_to_accuracy:
|
||||
go_to_accuracy= train_accuracy
|
||||
torch.save(model.state_dict(), "best_model.pth")
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torch.utils.data import ConcatDataset
|
||||
import os
|
||||
|
||||
# images have to be the same size for the algorithm to work
|
||||
transform = transforms.Compose([
|
||||
|
Loading…
Reference in New Issue
Block a user