98% , ale dalej myli pola

This commit is contained in:
s464923 2024-05-24 05:46:37 +02:00
parent 6ce5ce0468
commit 2756bab940
2 changed files with 12 additions and 5 deletions

BIN
model.pth

Binary file not shown.

View File

@ -7,7 +7,7 @@ from PIL import Image
class SimpleNN(nn.Module): class SimpleNN(nn.Module):
def __init__(self): def __init__(self):
super(SimpleNN, self).__init__() super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(64 * 64, 128) self.fc1 = nn.Linear(64 * 64 * 3, 128) # *3 aby wchodziły kolory
self.fc2 = nn.Linear(128, 64) self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10) self.fc3 = nn.Linear(64, 10)
self.relu = nn.ReLU() self.relu = nn.ReLU()
@ -21,8 +21,8 @@ class SimpleNN(nn.Module):
x = self.log_softmax(self.fc3(x)) x = self.log_softmax(self.fc3(x))
return x return x
def train(model, train_loader, n_iter=200): def train(model, train_loader, n_iter=100):
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #adam niby lepszy
criterion = nn.NLLLoss() criterion = nn.NLLLoss()
model.train() model.train()
for epoch in range(n_iter): for epoch in range(n_iter):
@ -42,7 +42,13 @@ def predict_image(image_path, model):
image_path = "warzywa/" + str(image_path) + ".png" image_path = "warzywa/" + str(image_path) + ".png"
image = Image.open(image_path) image = Image.open(image_path)
if image.mode == 'RGBA':
image = image.convert('RGB')
elif image.mode != 'RGB':
image = image.convert('RGB')
image = transform(image).unsqueeze(0).to(device) image = transform(image).unsqueeze(0).to(device)
class_names = ["fasola","brokul","kapusta","marchewka", "kalafior", class_names = ["fasola","brokul","kapusta","marchewka", "kalafior",
"ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"] "ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"]
with torch.no_grad(): with torch.no_grad():
@ -72,8 +78,9 @@ def load_model(model_path):
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
transform = transforms.Compose([ transform = transforms.Compose([
transforms.Resize((64, 64)), transforms.Resize((64, 64)),
transforms.Grayscale(num_output_channels=1), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), #random jasnosci i wywaliłem ze wszystko jest szare
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) ])
train_data_path = 'train' train_data_path = 'train'
train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform) train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
@ -86,7 +93,7 @@ model = SimpleNN().to(device)
if __name__ == "__main__": if __name__ == "__main__":
train(model, train_loader, n_iter=200) train(model, train_loader, n_iter=100)
torch.save(model.state_dict(), 'model.pth') torch.save(model.state_dict(), 'model.pth')
model.load_state_dict(torch.load('model.pth', map_location=device)) model.load_state_dict(torch.load('model.pth', map_location=device))