Compare commits
2 Commits
b5d25d710d ... 9e978d6032

| Author | SHA1 | Date |
|---|---|---|
|  | 9e978d6032 |  |
|  | c363b09f85 |  |
BIN  source/NN/__pycache__/model.cpython-311.pyc  (Normal file)

source/NN/model.py

@@ -1,4 +1,6 @@
 import torch.nn as nn
 import torch
 import torch.nn.functional as F
 
 
 class Neural_Network_Model(nn.Module):
@@ -16,5 +18,4 @@ class Neural_Network_Model(nn.Module):
         x = self.fc2(x)
         x = torch.relu(x)
         x = self.out(x)
-        F.log_softmax(x, dim=-1)
-        return x
+        return F.log_softmax(x, dim=-1)
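The model.py change above stops discarding the log-softmax and returns it, so the network now emits log-probabilities, which is what the nn.NLLLoss criterion used in the training script expects (log_softmax followed by NLLLoss is equivalent to CrossEntropyLoss on raw logits). A minimal sketch of a model with this forward pass is below; fc2 and out appear in the diff, but fc1, the layer sizes, the class count, and the flatten step are assumptions, since the constructor is not shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Neural_Network_Model(nn.Module):
    # fc1, the layer sizes and num_classes are placeholders; only fc2, out and the
    # tail of forward() are visible in the diff.
    def __init__(self, num_classes=5):
        super().__init__()
        self.fc1 = nn.Linear(3 * 150 * 150, 512)  # images are resized to 150x150 RGB in the training script
        self.fc2 = nn.Linear(512, 128)
        self.out = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.out(x)
        # the fix in this commit: return the log-softmax instead of computing and discarding it
        return F.log_softmax(x, dim=-1)
```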
@@ -1,15 +1,17 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torchvision import datasets, transforms
+from torchvision import datasets, transforms, utils
 from torchvision.transforms import Compose, Lambda, ToTensor
 import matplotlib.pyplot as plt
 import numpy as np
 from model import *
+from PIL import Image
 
 device = torch.device('cuda')
 
 #data transform to tensors:
-data_transformer = transforms.Compose
-([
+data_transformer = transforms.Compose([
     transforms.Resize((150, 150)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
@@ -31,10 +33,9 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme
 #print(train_set.targets[3002])
 
 
-#loading your own image: <-- I'll do this at the end - feeding in a specific image to get a result
 #function for training model
 def train(model, dataset, iter=100, batch_size=64):
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
     criterion = nn.NLLLoss()
     train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
     model.train()
@@ -46,9 +47,13 @@ def train(model, dataset, iter=100, batch_size=64):
         loss = criterion(output, labels.to(device))
         loss.backward()
         optimizer.step()
         if epoch % 10 == 0:
             print('epoch: %3d loss: %.4f' % (epoch, loss))
 
 #function for getting accuracy
 def accuracy(model, dataset):
     model.eval()
     with torch.no_grad():
         correct = sum([
             (model(inputs.to(device)).argmax(dim=1) == labels.to(device)).sum()
             for inputs, labels in DataLoader(dataset, batch_size=64, shuffle=True)
@@ -57,6 +62,34 @@ def accuracy(model, dataset):
     return correct.float() / len(dataset)
 
-
+
+
+
 model = Neural_Network_Model()
-train(model, train_set)
-print(accuracy(model, test_set))
+model.to(device)
+
+model.load_state_dict(torch.load('model.pth'))
+model.eval()
+
+#training the model:
+# train(model, train_set)
+# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
+# torch.save(model.state_dict(), 'model.pth')
+
+
+#TEST - loading the image and getting results:
+testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg'
+testImage = Image.open(testImage_path)
+testImage = data_transformer(testImage)
+testImage = testImage.unsqueeze(0)
+testImage = testImage.to(device)
+
+model.load_state_dict(torch.load('model.pth'))
+model.to(device)
+model.eval()
+
+testOutput = model(testImage)
+_, predicted = torch.max(testOutput, 1)
+predicted_class = train_set.classes[predicted.item()]
+print(f'The predicted class is: {predicted_class}')
+
+
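Only fragments of train() are visible in the hunks above. A minimal sketch of how those fragments could fit together is shown here, reusing the imports, device, and DataLoader from the script; the epoch loop, optimizer.zero_grad(), and the forward call output = model(inputs.to(device)) do not appear in the diff and are assumptions.

```python
def train(model, dataset, iter=100, batch_size=64):
    # assumes torch, nn, DataLoader and device defined earlier in the script
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # the value this commit switches to
    criterion = nn.NLLLoss()  # expects log-probabilities, matching the model's log_softmax output
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.train()
    for epoch in range(iter):                  # assumed loop; not visible in the diff
        for inputs, labels in train_loader:
            optimizer.zero_grad()              # assumed; not visible in the diff
            output = model(inputs.to(device))  # assumed forward call
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            print('epoch: %3d loss: %.4f' % (epoch, loss))
```

The tail of the diff also loads model.pth and moves the model to the GPU in two separate places. A consolidated sketch of the same load-and-predict flow is below, assuming model.pth was produced by the commented-out torch.save call; the .convert('RGB') call and the torch.no_grad() context are additions not present in the diff (the first guards against grayscale or RGBA images, since the Normalize transform expects three channels, the second just disables gradient tracking at inference).

```python
# assumes Image, data_transformer, train_set, testImage_path and device from the script above
model = Neural_Network_Model()
model.load_state_dict(torch.load('model.pth'))  # weights written by torch.save(model.state_dict(), 'model.pth')
model.to(device)
model.eval()

testImage = Image.open(testImage_path).convert('RGB')            # convert('RGB') is not in the diff
testImage = data_transformer(testImage).unsqueeze(0).to(device)  # add a batch dimension

with torch.no_grad():
    testOutput = model(testImage)
predicted_class = train_set.classes[testOutput.argmax(dim=1).item()]
print(f'The predicted class is: {predicted_class}')
```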
BIN  source/model.pth  (Normal file)
BIN  source/resources/images/plant_photos/00187550-Wheat-field.jpg  (Normal file)
BIN  source/resources/images/plant_photos/apple01-lg.jpg  (Normal file)
BIN  source/resources/images/plant_photos/apple1.jpg  (Normal file)