98% , ale dalej myli pola
This commit is contained in:
parent
6ce5ce0468
commit
2756bab940
@ -7,7 +7,7 @@ from PIL import Image
|
||||
class SimpleNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleNN, self).__init__()
|
||||
self.fc1 = nn.Linear(64 * 64, 128)
|
||||
self.fc1 = nn.Linear(64 * 64 * 3, 128) # *3 aby wchodziły kolory
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
self.relu = nn.ReLU()
|
||||
@ -21,8 +21,8 @@ class SimpleNN(nn.Module):
|
||||
x = self.log_softmax(self.fc3(x))
|
||||
return x
|
||||
|
||||
def train(model, train_loader, n_iter=200):
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
|
||||
def train(model, train_loader, n_iter=100):
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #adam niby lepszy
|
||||
criterion = nn.NLLLoss()
|
||||
model.train()
|
||||
for epoch in range(n_iter):
|
||||
@ -42,7 +42,13 @@ def predict_image(image_path, model):
|
||||
image_path = "warzywa/" + str(image_path) + ".png"
|
||||
|
||||
image = Image.open(image_path)
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert('RGB')
|
||||
elif image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
image = transform(image).unsqueeze(0).to(device)
|
||||
|
||||
class_names = ["fasola","brokul","kapusta","marchewka", "kalafior",
|
||||
"ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"]
|
||||
with torch.no_grad():
|
||||
@ -72,8 +78,9 @@ def load_model(model_path):
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((64, 64)),
|
||||
transforms.Grayscale(num_output_channels=1),
|
||||
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), #random jasnosci i wywaliłem ze wszystko jest szare
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
train_data_path = 'train'
|
||||
train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
|
||||
@ -86,7 +93,7 @@ model = SimpleNN().to(device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train(model, train_loader, n_iter=200)
|
||||
train(model, train_loader, n_iter=100)
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
|
||||
model.load_state_dict(torch.load('model.pth', map_location=device))
|
||||
|
Loading…
Reference in New Issue
Block a user