modified the CNN to get better results
This commit is contained in:
parent
35d7a6ab6e
commit
8ddba0c828
BIN
source/CNN_model.pth
Normal file
BIN
source/CNN_model.pth
Normal file
Binary file not shown.
Binary file not shown.
@ -10,19 +10,15 @@ class Conv_Neural_Network_Model(nn.Module):
|
|||||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
|
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
|
||||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||||
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
|
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
|
||||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
|
||||||
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
|
|
||||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
|
||||||
|
|
||||||
self.fc1 = nn.Linear(128*12*12,hidden_layer1)
|
self.fc1 = nn.Linear(64*25*25,hidden_layer1)
|
||||||
self.fc2 = nn.Linear(hidden_layer1,hidden_layer2)
|
self.fc2 = nn.Linear(hidden_layer1,hidden_layer2)
|
||||||
self.out = nn.Linear(hidden_layer2,num_classes)
|
self.out = nn.Linear(hidden_layer2,num_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.pool1(F.relu(self.conv1(x)))
|
x = self.pool1(F.relu(self.conv1(x)))
|
||||||
x = self.pool2(F.relu(self.conv2(x)))
|
x = self.pool1(F.relu(self.conv2(x)))
|
||||||
x = self.pool3(F.relu(self.conv3(x)))
|
x = x.view(-1, 64*25*25) #<----flattening the image
|
||||||
x = x.view(-1, 128*12*12) #<----flattening the image
|
|
||||||
x = self.fc1(x)
|
x = self.fc1(x)
|
||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = self.fc2(x)
|
x = self.fc2(x)
|
||||||
|
@ -29,7 +29,7 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme
|
|||||||
|
|
||||||
#function for training model
|
#function for training model
|
||||||
def train(model, dataset, iter=100, batch_size=64):
|
def train(model, dataset, iter=100, batch_size=64):
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
criterion = nn.NLLLoss()
|
criterion = nn.NLLLoss()
|
||||||
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||||
model.train()
|
model.train()
|
||||||
@ -60,13 +60,13 @@ model = Conv_Neural_Network_Model()
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
#loading the already saved model:
|
#loading the already saved model:
|
||||||
# model.load_state_dict(torch.load('model.pth'))
|
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||||
# model.eval()
|
model.eval()
|
||||||
|
|
||||||
#training the model:
|
# #training the model:
|
||||||
# train(model, train_set)
|
# train(model, train_set)
|
||||||
# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
|
# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
|
||||||
# torch.save(model.state_dict(), 'model.pth')
|
# torch.save(model.state_dict(), 'CNN_model.pth')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ def guess_image(model, image_tensor):
|
|||||||
# print(f"The predicted image is: {prediction}")
|
# print(f"The predicted image is: {prediction}")
|
||||||
|
|
||||||
#TEST - loading the image and getting results:
|
#TEST - loading the image and getting results:
|
||||||
testImage_path = 'resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg'
|
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||||
testImage = Image.open(testImage_path)
|
testImage = Image.open(testImage_path)
|
||||||
testImage = data_transformer(testImage)
|
testImage = data_transformer(testImage)
|
||||||
testImage = testImage.unsqueeze(0)
|
testImage = testImage.unsqueeze(0)
|
||||||
|
BIN
source/model.pth
BIN
source/model.pth
Binary file not shown.
Loading…
Reference in New Issue
Block a user