Compare commits
11 Commits
Author | SHA1 | Date | |
---|---|---|---|
edf1e43d33 | |||
0b33ff1803 | |||
86be72ba33 | |||
da73e223e3 | |||
e27acbacaf | |||
28bf53c037 | |||
931e40d88f | |||
d4e382a7f0 | |||
c0c5d3adcd | |||
9715493351 | |||
0099d64e54 |
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +0,0 @@
|
|||||||
.idea
|
|
||||||
.DS_Store
|
|
||||||
__pycache__
|
|
BIN
__pycache__/blocks.cpython-37.pyc
Normal file
BIN
__pycache__/blocks.cpython-37.pyc
Normal file
Binary file not shown.
BIN
__pycache__/land.cpython-37.pyc
Normal file
BIN
__pycache__/land.cpython-37.pyc
Normal file
Binary file not shown.
BIN
__pycache__/soil.cpython-37.pyc
Normal file
BIN
__pycache__/soil.cpython-37.pyc
Normal file
Binary file not shown.
BIN
__pycache__/tractor.cpython-37.pyc
Normal file
BIN
__pycache__/tractor.cpython-37.pyc
Normal file
Binary file not shown.
101
main.py
101
main.py
@ -5,8 +5,90 @@ import land
|
|||||||
import tractor
|
import tractor
|
||||||
import blocks
|
import blocks
|
||||||
import astar_search
|
import astar_search
|
||||||
|
import neural_network.inference
|
||||||
from pygame.locals import *
|
from pygame.locals import *
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
['piasek', 'sucha', 'jalowa', 'żółty'],
|
||||||
|
['czarnoziem', 'wilgotna', 'bogata', 'brazowa'],
|
||||||
|
['rzedzina', 'wilgotna', 'bogata', 'zielona'],
|
||||||
|
['gleby murszowe', 'wilgotna', 'bogata', 'szara'],
|
||||||
|
['pustynne gleby', 'sucha', 'jalowa', 'pomarańczowa'],
|
||||||
|
['torfowiska', 'sucha', 'jalowa', 'czerwona']
|
||||||
|
]
|
||||||
|
|
||||||
|
attributes = ['typ_gleby', 'wilgotność', 'zawartość_składników', 'kolor']
|
||||||
|
|
||||||
|
# Tworzenie obiektu TreeLearn i nauka drzewa decyzyjnego
|
||||||
|
# tree_learner = TreeLearn()
|
||||||
|
# default_class = 'nieznane'
|
||||||
|
|
||||||
|
# tree_learner.train(examples, attributes, default_class)
|
||||||
|
|
||||||
|
class TreeLearn:
|
||||||
|
def __init__(self):
|
||||||
|
self.tree = None
|
||||||
|
|
||||||
|
def train(self, examples, attributes, default_class):
|
||||||
|
self.tree = self.build_tree(examples, attributes, default_class)
|
||||||
|
|
||||||
|
def build_tree(self, examples, attributes, default_class):
|
||||||
|
if not examples:
|
||||||
|
return Node(default_class)
|
||||||
|
|
||||||
|
if self.all_same_class(examples):
|
||||||
|
return Node(examples[0][-1])
|
||||||
|
|
||||||
|
if not attributes:
|
||||||
|
class_counts = self.get_class_counts(examples)
|
||||||
|
default_class = max(class_counts, key=class_counts.get)
|
||||||
|
return Node(default_class)
|
||||||
|
|
||||||
|
best_attribute = self.choose_attribute(examples, attributes)
|
||||||
|
root = Node(best_attribute)
|
||||||
|
|
||||||
|
attribute_values = self.get_attribute_values(examples, best_attribute)
|
||||||
|
|
||||||
|
for value in attribute_values:
|
||||||
|
new_examples = self.filter_examples(examples, best_attribute, value)
|
||||||
|
new_attributes = attributes[:]
|
||||||
|
new_attributes.remove(best_attribute)
|
||||||
|
new_default_class = max(self.get_class_counts(new_examples), key=lambda k: class_counts.get(k, 0))
|
||||||
|
|
||||||
|
subtree = self.build_tree(new_examples, new_attributes, new_default_class)
|
||||||
|
root.add_child(value, subtree)
|
||||||
|
|
||||||
|
return root
|
||||||
|
|
||||||
|
def all_same_class(self, examples):
|
||||||
|
return len(set([example[-1] for example in examples])) == 1
|
||||||
|
|
||||||
|
def get_class_counts(self, examples):
|
||||||
|
class_counts = {}
|
||||||
|
for example in examples:
|
||||||
|
class_label = example[-1]
|
||||||
|
class_counts[class_label] = class_counts.get(class_label, 0) + 1
|
||||||
|
return class_counts
|
||||||
|
|
||||||
|
def choose_attribute(self, examples, attributes):
|
||||||
|
# Placeholder for attribute selection logic
|
||||||
|
return attributes[0]
|
||||||
|
|
||||||
|
def get_attribute_values(self, examples, attribute):
|
||||||
|
return list(set([example[attribute] for example in examples]))
|
||||||
|
|
||||||
|
def filter_examples(self, examples, attribute, value):
|
||||||
|
return [example for example in examples if example[attribute] == value]
|
||||||
|
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
def __init__(self, label):
|
||||||
|
self.label = label
|
||||||
|
self.children = {}
|
||||||
|
|
||||||
|
def add_child(self, value, child):
|
||||||
|
self.children[value] = child
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
cell_size = 50
|
cell_size = 50
|
||||||
@ -70,10 +152,13 @@ class Game:
|
|||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
move_tractor_event = pygame.USEREVENT + 1
|
move_tractor_event = pygame.USEREVENT + 1
|
||||||
pygame.time.set_timer(move_tractor_event, 100) # tractor moves every 1000 ms
|
pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
|
||||||
tractor_next_moves = []
|
tractor_next_moves = []
|
||||||
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
|
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
|
||||||
|
|
||||||
|
veggies = dict()
|
||||||
|
veggies_debug = dict()
|
||||||
|
|
||||||
while running:
|
while running:
|
||||||
clock.tick(60) # manual fps control not to overwork the computer
|
clock.tick(60) # manual fps control not to overwork the computer
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
@ -109,6 +194,20 @@ class Game:
|
|||||||
#bandaid to know about stones
|
#bandaid to know about stones
|
||||||
tractor_next_moves = astar_search_object.astarsearch(
|
tractor_next_moves = astar_search_object.astarsearch(
|
||||||
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
|
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
|
||||||
|
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1])-1)]
|
||||||
|
if(current_veggie in veggies_debug):
|
||||||
|
veggies_debug[current_veggie]+=1
|
||||||
|
else:
|
||||||
|
veggies_debug[current_veggie] = 1
|
||||||
|
|
||||||
|
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2])-1)]
|
||||||
|
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
|
||||||
|
if predicted_veggie in veggies:
|
||||||
|
veggies[predicted_veggie]+=1
|
||||||
|
else:
|
||||||
|
veggies[predicted_veggie] = 1
|
||||||
|
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
|
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
|
||||||
elif event.type == QUIT:
|
elif event.type == QUIT:
|
||||||
|
42
neural_network/datasets.py
Normal file
42
neural_network/datasets.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import torchvision
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
|
||||||
|
|
||||||
|
train_transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), #validate that all images are 224x244
|
||||||
|
transforms.RandomHorizontalFlip(p=0.5),
|
||||||
|
transforms.RandomVerticalFlip(p=0.5),
|
||||||
|
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
|
||||||
|
transforms.RandomRotation(degrees=(30, 70)), #random effects are applied to prevent overfitting
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.5, 0.5, 0.5],
|
||||||
|
std=[0.5, 0.5, 0.5]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
valid_transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.5, 0.5, 0.5],
|
||||||
|
std=[0.5, 0.5, 0.5]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
train_dataset = torchvision.datasets.ImageFolder(root='./images/train', transform=train_transform)
|
||||||
|
|
||||||
|
validation_dataset = torchvision.datasets.ImageFolder(root='./images/validation', transform=valid_transform)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_loader = DataLoader(
|
||||||
|
validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
|
||||||
|
)
|
59
neural_network/inference.py
Normal file
59
neural_network/inference.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import argparse
|
||||||
|
from neural_network.model import CNNModel
|
||||||
|
# construct the argument parser
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-i', '--input',
|
||||||
|
default='',
|
||||||
|
help='path to the input image')
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
def main(path):
|
||||||
|
# the computation device
|
||||||
|
device = ('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
# list containing all the class labels
|
||||||
|
labels = [
|
||||||
|
'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
|
||||||
|
'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
|
||||||
|
'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
|
||||||
|
]
|
||||||
|
|
||||||
|
# initialize the model and load the trained weights
|
||||||
|
model = CNNModel().to(device)
|
||||||
|
checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# define preprocess transforms
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToPILImage(),
|
||||||
|
transforms.Resize(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.5, 0.5, 0.5],
|
||||||
|
std=[0.5, 0.5, 0.5]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# read and preprocess the image
|
||||||
|
image = cv2.imread(path)
|
||||||
|
# get the ground truth class
|
||||||
|
gt_class = path.split('/')[-2]
|
||||||
|
orig_image = image.copy()
|
||||||
|
# convert to RGB format
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
image = transform(image)
|
||||||
|
# add batch dimension
|
||||||
|
image = torch.unsqueeze(image, 0)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(image.to(device))
|
||||||
|
output_label = torch.topk(outputs, 1)
|
||||||
|
pred_class = labels[int(output_label.indices)]
|
||||||
|
|
||||||
|
return pred_class
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main(args['input'])
|
24
neural_network/model.py
Normal file
24
neural_network/model.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class CNNModel(nn.Module): #model of the CNN type
|
||||||
|
def __init__(self):
|
||||||
|
super(CNNModel, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 32, 5)
|
||||||
|
self.conv2 = nn.Conv2d(32, 64, 5)
|
||||||
|
self.conv3 = nn.Conv2d(64, 128, 3)
|
||||||
|
self.conv4 = nn.Conv2d(128, 256, 5)
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(256, 50)
|
||||||
|
|
||||||
|
self.pool = nn.MaxPool2d(2, 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.pool(F.relu(self.conv1(x)))
|
||||||
|
x = self.pool(F.relu(self.conv2(x)))
|
||||||
|
x = self.pool(F.relu(self.conv3(x)))
|
||||||
|
x = self.pool(F.relu(self.conv4(x)))
|
||||||
|
bs, _, _, _ = x.shape
|
||||||
|
x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
|
||||||
|
x = self.fc1(x)
|
||||||
|
return x
|
@ -1,103 +0,0 @@
|
|||||||
#imports
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
import torchvision.datasets as datasets
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
|
|
||||||
#create fully connected network
|
|
||||||
class NN(nn.Module):
|
|
||||||
def __init__(self, input_size, num_classes): #1 layer (28x28 = 784 nodes)
|
|
||||||
super(NN,self)._init_()
|
|
||||||
self.fc1 = nn.Linear(input_size, 50)
|
|
||||||
self.fc2 = nn.Linear(50,num_classes)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
x = self.fc2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
# model = NN(784, 10)
|
|
||||||
# x = torch.rand(64, 784)
|
|
||||||
# print(model(x).shape)
|
|
||||||
|
|
||||||
#set device
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
#hyperparameters
|
|
||||||
input_size = 784
|
|
||||||
num_classes = 10
|
|
||||||
learning_rate = 0.001
|
|
||||||
batch_size = 64
|
|
||||||
num_epochs = 1
|
|
||||||
|
|
||||||
#load data
|
|
||||||
train_dataset = datasets.MNIST(root='dataset/', train = True, transform = transforms.toTensor(), download = True)
|
|
||||||
train_loader = DataLoader(dataset= train_dataset, batch_size = batch_size, shuffle = True)
|
|
||||||
test_dataset = datasets.MNIST(root='dataset/', train = False, transform = transforms.toTensor(), download = True)
|
|
||||||
test_loader = DataLoader(dataset= test_dataset, batch_size = batch_size, shuffle = True)
|
|
||||||
|
|
||||||
|
|
||||||
#initialize network
|
|
||||||
model = NN(input_size=input_size, num_classes=num_classes).to(device)
|
|
||||||
|
|
||||||
#loss and optimizer
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
|
||||||
|
|
||||||
#train network
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
for batch_idx, (data, targets) in enumerate(train_loader):
|
|
||||||
#get data to cuda if possible
|
|
||||||
data = data.to(device = device)
|
|
||||||
targets = targets.to(device = device)
|
|
||||||
|
|
||||||
#get to correct shape
|
|
||||||
data = data.reshape(data.shape[0], -1)
|
|
||||||
|
|
||||||
# forward
|
|
||||||
scores = model(data)
|
|
||||||
loss = criterion(scores, targets)
|
|
||||||
|
|
||||||
#backward
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
#gradient descent or adam step
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
#check accuracy on training and test to see how good our model
|
|
||||||
|
|
||||||
def check_accuracy(loader, model):
|
|
||||||
if loader.dataset.train:
|
|
||||||
print("Checking accuracy on training data")
|
|
||||||
else:
|
|
||||||
print("Checking accuracy on test data")
|
|
||||||
|
|
||||||
num_correct = 0
|
|
||||||
num_samples = 0
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for x,y in loader:
|
|
||||||
x = x.to(device = device)
|
|
||||||
y = y.to(device = device)
|
|
||||||
x = x.reshape(x.shape[0], -1)
|
|
||||||
|
|
||||||
scores = model(x)
|
|
||||||
_, predictions = scores.max(1)
|
|
||||||
num_correct += (predictions == y).sum()
|
|
||||||
num_samples += predictions.size(0)
|
|
||||||
|
|
||||||
print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')
|
|
||||||
|
|
||||||
model.train()
|
|
||||||
return acc
|
|
||||||
|
|
||||||
check_accuracy(train_loader, model)
|
|
||||||
check_accuracy(test_loader, model)
|
|
||||||
|
|
||||||
|
|
BIN
neural_network/outputs/accuracy.png
Normal file
BIN
neural_network/outputs/accuracy.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 40 KiB |
BIN
neural_network/outputs/loss.png
Normal file
BIN
neural_network/outputs/loss.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 41 KiB |
BIN
neural_network/outputs/model.pth
Normal file
BIN
neural_network/outputs/model.pth
Normal file
Binary file not shown.
119
neural_network/train.py
Normal file
119
neural_network/train.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import time
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from neural_network.model import CNNModel
|
||||||
|
from neural_network.datasets import train_loader, valid_loader
|
||||||
|
from neural_network.utils import save_model, save_plots
|
||||||
|
|
||||||
|
# construct the argument parser
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-e', '--epochs', type=int, default=20,
|
||||||
|
help='number of epochs to train our network for')
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
|
||||||
|
lr = 1e-3
|
||||||
|
epochs = args['epochs']
|
||||||
|
device = ('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
print(f"Computation device: {device}\n")
|
||||||
|
|
||||||
|
model = CNNModel().to(device)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
print(f"{total_params:,} total parameters.")
|
||||||
|
total_trainable_params = sum(
|
||||||
|
p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print(f"{total_trainable_params:,} training parameters.")
|
||||||
|
# optimizer
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||||
|
# loss function
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
|
||||||
|
# training
|
||||||
|
def train(model, trainloader, optimizer, criterion):
|
||||||
|
model.train()
|
||||||
|
print('Training')
|
||||||
|
train_running_loss = 0.0
|
||||||
|
train_running_correct = 0
|
||||||
|
counter = 0
|
||||||
|
for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
|
||||||
|
counter += 1
|
||||||
|
image, labels = data
|
||||||
|
image = image.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
# forward pass
|
||||||
|
outputs = model(image)
|
||||||
|
# calculate the loss
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
train_running_loss += loss.item()
|
||||||
|
# calculate the accuracy
|
||||||
|
_, preds = torch.max(outputs.data, 1)
|
||||||
|
train_running_correct += (preds == labels).sum().item()
|
||||||
|
# backpropagation
|
||||||
|
loss.backward()
|
||||||
|
# update the optimizer parameters
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# loss and accuracy for the complete epoch
|
||||||
|
epoch_loss = train_running_loss / counter
|
||||||
|
epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
|
||||||
|
return epoch_loss, epoch_acc
|
||||||
|
|
||||||
|
# validation
|
||||||
|
def validate(model, testloader, criterion):
|
||||||
|
model.eval()
|
||||||
|
print('Validation')
|
||||||
|
valid_running_loss = 0.0
|
||||||
|
valid_running_correct = 0
|
||||||
|
counter = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, data in tqdm(enumerate(testloader), total=len(testloader)):
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
image, labels = data
|
||||||
|
image = image.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
# forward pass
|
||||||
|
outputs = model(image)
|
||||||
|
# calculate the loss
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
valid_running_loss += loss.item()
|
||||||
|
# calculate the accuracy
|
||||||
|
_, preds = torch.max(outputs.data, 1)
|
||||||
|
valid_running_correct += (preds == labels).sum().item()
|
||||||
|
|
||||||
|
# loss and accuracy for the complete epoch
|
||||||
|
epoch_loss = valid_running_loss / counter
|
||||||
|
epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
|
||||||
|
return epoch_loss, epoch_acc
|
||||||
|
|
||||||
|
# lists to keep track of losses and accuracies
|
||||||
|
train_loss, valid_loss = [], []
|
||||||
|
train_acc, valid_acc = [], []
|
||||||
|
# start the training
|
||||||
|
for epoch in range(epochs):
|
||||||
|
print(f"[INFO]: Epoch {epoch+1} of {epochs}")
|
||||||
|
train_epoch_loss, train_epoch_acc = train(model, train_loader,
|
||||||
|
optimizer, criterion)
|
||||||
|
valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
|
||||||
|
criterion)
|
||||||
|
train_loss.append(train_epoch_loss)
|
||||||
|
valid_loss.append(valid_epoch_loss)
|
||||||
|
train_acc.append(train_epoch_acc)
|
||||||
|
valid_acc.append(valid_epoch_acc)
|
||||||
|
print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
|
||||||
|
print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
|
||||||
|
print('-'*50)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# save the trained model weights
|
||||||
|
save_model(epochs, model, optimizer, criterion)
|
||||||
|
# save the loss and accuracy plots
|
||||||
|
save_plots(train_acc, valid_acc, train_loss, valid_loss)
|
||||||
|
print('TRAINING COMPLETE')
|
49
neural_network/utils.py
Normal file
49
neural_network/utils.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
matplotlib.style.use('ggplot')
|
||||||
|
|
||||||
|
def save_model(epochs, model, optimizer, criterion):
|
||||||
|
"""
|
||||||
|
Function to save the trained model to disk.
|
||||||
|
"""
|
||||||
|
torch.save({
|
||||||
|
'epoch': epochs,
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': criterion,
|
||||||
|
}, 'outputs/model.pth')
|
||||||
|
|
||||||
|
def save_plots(train_acc, valid_acc, train_loss, valid_loss):
|
||||||
|
"""
|
||||||
|
Function to save the loss and accuracy plots to disk.
|
||||||
|
"""
|
||||||
|
# accuracy plots
|
||||||
|
plt.figure(figsize=(10, 7))
|
||||||
|
plt.plot(
|
||||||
|
train_acc, color='green', linestyle='-',
|
||||||
|
label='train accuracy'
|
||||||
|
)
|
||||||
|
plt.plot(
|
||||||
|
valid_acc, color='blue', linestyle='-',
|
||||||
|
label='validataion accuracy'
|
||||||
|
)
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Accuracy')
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig('outputs/accuracy.png')
|
||||||
|
|
||||||
|
# loss plots
|
||||||
|
plt.figure(figsize=(10, 7))
|
||||||
|
plt.plot(
|
||||||
|
train_loss, color='orange', linestyle='-',
|
||||||
|
label='train loss'
|
||||||
|
)
|
||||||
|
plt.plot(
|
||||||
|
valid_loss, color='red', linestyle='-',
|
||||||
|
label='validataion loss'
|
||||||
|
)
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Loss')
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig('outputs/loss.png')
|
Loading…
Reference in New Issue
Block a user