forked from s464965/WMICraft
83 lines
2.4 KiB
Python
import torch
|
|
from common.constants import device, batch_size, num_epochs, learning_rate, setup_photos, id_to_class
|
|
from watersandtreegrass import WaterSandTreeGrass
|
|
from torch.utils.data import DataLoader
|
|
from neural_network import NeuralNetwork
|
|
from torchvision.io import read_image, ImageReadMode
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
|
|
# Module-level network instance, built once at import time and placed on `device`.
CNN = NeuralNetwork()
CNN = CNN.to(device)
|
|
|
|
|
|
def train(model):
    """Train ``model`` on the training CSV, save its weights, then report accuracy.

    The training set is read from ``./data/train_csv_file.csv`` through
    ``WaterSandTreeGrass`` (transformed with ``setup_photos``) and iterated for
    ``num_epochs`` epochs with cross-entropy loss and the Adam optimizer. The
    learned ``state_dict`` is written to ``./learnedNetwork.pt`` and then
    evaluated on the training loader by ``check_accuracy``.

    Args:
        model: the network to train. Every batch is moved to ``device`` before
            the forward pass, so the model is presumably already on ``device``
            (e.g. the module-level ``CNN``). TODO confirm for other callers.
    """
    model.train()

    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Stays None if the loader yields no batches, so the report below can't
    # hit an unbound name.
    loss = None
    for epoch in range(num_epochs):
        for data, targets in train_loader:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the last batch's loss every other epoch.
        if epoch % 2 == 0 and loss is not None:
            print("epoch: %3d loss: %.4f" % (epoch, loss.item()))

    print("FINISHED!")

    # Save BEFORE checking accuracy: check_accuracy() evaluates the weights it
    # loads from ./learnedNetwork.pt. Saving afterwards made it score a stale
    # checkpoint from a previous run (or fail when none existed yet).
    torch.save(model.state_dict(), "./learnedNetwork.pt")

    print("Checking accuracy.")
    check_accuracy(train_loader)
|
|
|
|
|
|
def check_accuracy(loader):
    """Print the classification accuracy of the saved network over ``loader``.

    This does NOT evaluate an in-memory model. It builds a fresh
    ``NeuralNetwork`` and loads its weights from ``./learnedNetwork.pt``, so that
    file must exist and hold the weights you want to measure.

    Args:
        loader: iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``.
    """
    num_correct = 0
    num_samples = 0

    model = NeuralNetwork()
    model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
    model = model.to(device)

    with torch.no_grad():
        model.eval()
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)

            # Predicted class = index of the highest score along dim 1.
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

    # An empty loader previously crashed with ZeroDivisionError here.
    if num_samples == 0:
        print("Got 0/0 samples; accuracy is undefined for an empty loader.")
        return

    print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
|
|
|
|
|
|
def what_is_it(img_path):
    """Classify one image file with the saved network and return its class.

    The network weights are read from ``./learnedNetwork.pt`` on every call.

    Args:
        img_path: path to the image; it is decoded as 3-channel RGB.

    Returns:
        The ``id_to_class`` entry for the highest-scoring output index.
    """
    # Decode the file, apply the shared transform, and add a leading
    # dimension of size 1 (presumably the batch axis the model expects).
    rgb = read_image(img_path, mode=ImageReadMode.RGB)
    batch = setup_photos(rgb).unsqueeze(0)

    # Rebuild the network and restore the checkpoint that train() writes.
    net = NeuralNetwork()
    checkpoint = torch.load("./learnedNetwork.pt", map_location=device)
    net.load_state_dict(checkpoint)
    net = net.to(device)
    batch = batch.to(device)

    with torch.no_grad():
        net.eval()
        logits = net(batch)
        label_index = int(logits.argmax(dim=1))
    return id_to_class[label_index]
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo prediction. Guarded so that importing this module (e.g. for `train`
    # or `CNN`) does not load ./learnedNetwork.pt and classify an image as an
    # import-time side effect.
    print(what_is_it('./data/test/water/water.png'))
|