modified the CNN to get better results

This commit is contained in:
MarRac 2024-05-27 04:27:46 +02:00
parent 35d7a6ab6e
commit 8ddba0c828
5 changed files with 9 additions and 13 deletions

BIN
source/CNN_model.pth Normal file

Binary file not shown.

View File

@ -10,19 +10,15 @@ class Conv_Neural_Network_Model(nn.Module):
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc1 = nn.Linear(128*12*12,hidden_layer1)
self.fc1 = nn.Linear(64*25*25,hidden_layer1)
self.fc2 = nn.Linear(hidden_layer1,hidden_layer2)
self.out = nn.Linear(hidden_layer2,num_classes)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = self.pool3(F.relu(self.conv3(x)))
x = x.view(-1, 128*12*12) #<----flattening the image
x = self.pool1(F.relu(self.conv2(x)))
x = x.view(-1, 64*25*25) #<----flattening the image
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)

View File

@ -29,7 +29,7 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme
#function for training model
def train(model, dataset, iter=100, batch_size=64):
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model.train()
@ -60,13 +60,13 @@ model = Conv_Neural_Network_Model()
model.to(device)
#loading the already saved model:
# model.load_state_dict(torch.load('model.pth'))
# model.eval()
model.load_state_dict(torch.load('CNN_model.pth'))
model.eval()
#training the model:
# #training the model:
# train(model, train_set)
# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
# torch.save(model.state_dict(), 'model.pth')
# torch.save(model.state_dict(), 'CNN_model.pth')
@ -98,7 +98,7 @@ def guess_image(model, image_tensor):
# print(f"The predicted image is: {prediction}")
#TEST - loading the image and getting results:
testImage_path = 'resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg'
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
testImage = Image.open(testImage_path)
testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0)

Binary file not shown.