diff --git a/source/NN/neural_network.py b/source/NN/neural_network.py index 824a578..fa61040 100644 --- a/source/NN/neural_network.py +++ b/source/NN/neural_network.py @@ -6,6 +6,7 @@ from torchvision.transforms import Compose, Lambda, ToTensor import matplotlib.pyplot as plt import numpy as np from model import * +from PIL import Image device = torch.device('cuda') @@ -32,10 +33,9 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme #print(train_set.targets[3002]) -#loading your own image: <-- zrobię to na koniec - wrzucanie konkretnego obrazka aby uzyskac wynik #function for training model def train(model, dataset, iter=100, batch_size=64): - optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) model.train() @@ -53,17 +53,43 @@ def train(model, dataset, iter=100, batch_size=64): #function for getting accuracy def accuracy(model, dataset): model.eval() - correct = sum([ - (model(inputs.to(device)).argmax(dim=1) == labels.to(device)).sum() - for inputs, labels in DataLoader(dataset, batch_size=64, shuffle=True) - ]) + with torch.no_grad(): + correct = sum([ + (model(inputs.to(device)).argmax(dim=1) == labels.to(device)).sum() + for inputs, labels in DataLoader(dataset, batch_size=64, shuffle=True) + ]) return correct.float() / len(dataset) + model = Neural_Network_Model() model.to(device) -train(model, train_set) -print(accuracy(model, test_set)) + +model.load_state_dict(torch.load('model.pth')) +model.eval() + +#training the model: +# train(model, train_set) +# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%") +# torch.save(model.state_dict(), 'model.pth') + + +#TEST - loading the image and getting results: +testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg' +testImage = Image.open(testImage_path) +testImage = data_transformer(testImage) +testImage = testImage.unsqueeze(0) +testImage = testImage.to(device) + +model.load_state_dict(torch.load('model.pth')) +model.to(device) +model.eval() + +testOutput = model(testImage) +_, predicted = torch.max(testOutput, 1) +predicted_class = train_set.classes[predicted.item()] +print(f'The predicted class is: {predicted_class}') + diff --git a/source/model.pth b/source/model.pth new file mode 100644 index 0000000..b752136 Binary files /dev/null and b/source/model.pth differ diff --git a/source/resources/images/plant_photos/00187550-Wheat-field.jpg b/source/resources/images/plant_photos/00187550-Wheat-field.jpg new file mode 100644 index 0000000..7d3f6a1 Binary files /dev/null and b/source/resources/images/plant_photos/00187550-Wheat-field.jpg differ diff --git a/source/resources/images/plant_photos/Triticum_monococcum_IMG_3029.jpg b/source/resources/images/plant_photos/Triticum_monococcum_IMG_3029.jpg new file mode 100644 index 0000000..1fe28d1 Binary files /dev/null and b/source/resources/images/plant_photos/Triticum_monococcum_IMG_3029.jpg differ diff --git a/source/resources/images/plant_photos/apple01-lg.jpg b/source/resources/images/plant_photos/apple01-lg.jpg new file mode 100644 index 0000000..e45a004 Binary files /dev/null and b/source/resources/images/plant_photos/apple01-lg.jpg differ diff --git a/source/resources/images/plant_photos/apple1.jpg b/source/resources/images/plant_photos/apple1.jpg new file mode 100644 index 0000000..01448cb Binary files /dev/null and b/source/resources/images/plant_photos/apple1.jpg differ diff --git a/source/resources/images/plant_photos/pexels-justus-menke-3490295-5213093.jpg b/source/resources/images/plant_photos/pexels-justus-menke-3490295-5213093.jpg new file mode 100644 index 0000000..2c64187 Binary files /dev/null and b/source/resources/images/plant_photos/pexels-justus-menke-3490295-5213093.jpg differ diff --git a/source/resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg b/source/resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg new file mode 100644 index 0000000..5b667e7 Binary files /dev/null and b/source/resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg differ diff --git a/source/resources/images/plant_photos/pexels-karolina-grabowska-4038801.jpg b/source/resources/images/plant_photos/pexels-karolina-grabowska-4038801.jpg new file mode 100644 index 0000000..f57029d Binary files /dev/null and b/source/resources/images/plant_photos/pexels-karolina-grabowska-4038801.jpg differ