WMICraft/algorithms/neural_network/neural_network_interface.py

91 lines
2.7 KiB
Python
Raw Normal View History

import torch
from common.constants import device, batch_size, num_epochs, learning_rate, setup_photos, id_to_class
from watersandtreegrass import WaterSandTreeGrass
from torch.utils.data import DataLoader
from neural_network import NeuralNetwork
from torchvision.io import read_image, ImageReadMode
import torch.nn as nn
from torch.optim import Adam
2022-05-18 12:18:59 +02:00
import matplotlib.pyplot as plt
# Module-level network instance; the call at the bottom of this file trains it.
CNN = NeuralNetwork()
CNN = CNN.to(device)
def train(model):
    """Train *model*, save its weights to ``./learnedNetwork.pt`` and report accuracy.

    The train and test sets are read from ``./data/train_csv_file.csv`` and
    ``./data/test_csv_file.csv`` through ``WaterSandTreeGrass`` with
    ``setup_photos`` as the transform. The model is optimised with Adam
    (``learning_rate``) against a cross-entropy loss for ``num_epochs`` epochs.
    The learned ``state_dict`` is then written to ``./learnedNetwork.pt``, and
    the accuracy on both sets is printed via ``check_accuracy``.

    Args:
        model: the network to train; its parameters are updated in place.
    """
    model.train()

    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=setup_photos)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Stays None if the training loader yields no batches, so the log line
    # below never touches an unbound name.
    loss = None
    for epoch in range(num_epochs):
        for data, targets in train_loader:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0 and loss is not None:
            # Loss of the last batch seen, logged every other epoch.
            print("epoch: %d loss: %.4f" % (epoch, loss.item()))

    print("FINISHED TRAINING!")

    # check_accuracy() loads its weights from ./learnedNetwork.pt, so the
    # freshly trained weights must be saved *before* it runs. Otherwise it
    # would evaluate a stale checkpoint, or fail when none exists yet.
    torch.save(model.state_dict(), "./learnedNetwork.pt")

    print("Checking accuracy for the train set.")
    check_accuracy(train_loader)
    print("Checking accuracy for the test set.")
    check_accuracy(test_loader)
def check_accuracy(loader, model=None):
    """Print and return the classification accuracy over every batch in *loader*.

    Args:
        loader: iterable of ``(inputs, labels)`` batches (e.g. a ``DataLoader``).
        model: optional network to evaluate. When omitted (the original
            behaviour), a fresh ``NeuralNetwork`` is built and its weights are
            loaded from ``./learnedNetwork.pt``. A caller-supplied model is
            moved to ``device``, and its train/eval mode is restored afterwards.

    Returns:
        The fraction of correctly classified samples, or ``None`` when
        *loader* yielded no samples.
    """
    if model is None:
        model = NeuralNetwork()
        model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
    model = model.to(device)

    # Remember the mode so a caller's model is not left stuck in eval mode.
    was_training = model.training
    model.eval()

    num_correct = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device=device)
                y = y.to(device=device)
                scores = model(x)
                _, predictions = scores.max(1)
                num_correct += int((predictions == y).sum())
                num_samples += predictions.size(0)
    finally:
        model.train(was_training)

    # Guard against ZeroDivisionError on an empty loader.
    if num_samples == 0:
        print("No samples to check accuracy on.")
        return None

    accuracy = num_correct / num_samples
    print(f"Got {num_correct}/{num_samples} with accuracy {accuracy*100:.2f}%")
    return accuracy
def what_is_it(img_path, show_img=False):
    """Return the class label the saved network predicts for one image file.

    The file is read as RGB and transformed with ``setup_photos``. It is then
    classified as a batch of one by a ``NeuralNetwork`` whose weights are
    loaded from ``./learnedNetwork.pt``.

    Args:
        img_path: path of the image to classify.
        show_img: if true, also display the image with matplotlib before
            classifying it.

    Returns:
        ``id_to_class`` looked up at the index of the highest-scoring output.
    """
    raw = read_image(img_path, mode=ImageReadMode.RGB)

    if show_img:
        plt.imshow(plt.imread(img_path))
        plt.show()

    # Add a leading batch dimension of size 1.
    batch = setup_photos(raw).unsqueeze(0)

    net = NeuralNetwork()
    weights = torch.load("./learnedNetwork.pt", map_location=device)
    net.load_state_dict(weights)
    net = net.to(device)
    batch = batch.to(device)

    net.eval()
    with torch.no_grad():
        logits = net(batch)
    predicted = int(logits.argmax(dim=1))
    return id_to_class[predicted]
# Train the module-level CNN whenever this module is executed.
# NOTE(review): this also runs when the module is merely imported (e.g. to
# use what_is_it), which retrains and overwrites ./learnedNetwork.pt; consider
# guarding it with `if __name__ == "__main__":`.
train(model=CNN)