98% , ale dalej myli pola

This commit is contained in:
s464923 2024-05-24 05:46:37 +02:00
parent 6ce5ce0468
commit 2756bab940
2 changed files with 12 additions and 5 deletions

BIN
model.pth

Binary file not shown.

View File

@ -7,7 +7,7 @@ from PIL import Image
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(64 * 64, 128)
self.fc1 = nn.Linear(64 * 64 * 3, 128) # *3 aby wchodziły kolory
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
self.relu = nn.ReLU()
@ -21,8 +21,8 @@ class SimpleNN(nn.Module):
x = self.log_softmax(self.fc3(x))
return x
def train(model, train_loader, n_iter=200):
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
def train(model, train_loader, n_iter=100):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #adam niby lepszy
criterion = nn.NLLLoss()
model.train()
for epoch in range(n_iter):
@ -42,7 +42,13 @@ def predict_image(image_path, model):
image_path = "warzywa/" + str(image_path) + ".png"
image = Image.open(image_path)
if image.mode == 'RGBA':
image = image.convert('RGB')
elif image.mode != 'RGB':
image = image.convert('RGB')
image = transform(image).unsqueeze(0).to(device)
class_names = ["fasola","brokul","kapusta","marchewka", "kalafior",
"ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"]
with torch.no_grad():
@ -72,8 +78,9 @@ def load_model(model_path):
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.Grayscale(num_output_channels=1),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), #random jasnosci i wywaliłem ze wszystko jest szare
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data_path = 'train'
train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
@ -86,7 +93,7 @@ model = SimpleNN().to(device)
if __name__ == "__main__":
train(model, train_loader, n_iter=200)
train(model, train_loader, n_iter=100)
torch.save(model.state_dict(), 'model.pth')
model.load_state_dict(torch.load('model.pth', map_location=device))