98% , ale dalej myli pola
This commit is contained in:
parent
6ce5ce0468
commit
2756bab940
@ -7,7 +7,7 @@ from PIL import Image
|
|||||||
class SimpleNN(nn.Module):
|
class SimpleNN(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SimpleNN, self).__init__()
|
super(SimpleNN, self).__init__()
|
||||||
self.fc1 = nn.Linear(64 * 64, 128)
|
self.fc1 = nn.Linear(64 * 64 * 3, 128) # *3 aby wchodziły kolory
|
||||||
self.fc2 = nn.Linear(128, 64)
|
self.fc2 = nn.Linear(128, 64)
|
||||||
self.fc3 = nn.Linear(64, 10)
|
self.fc3 = nn.Linear(64, 10)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
@ -21,8 +21,8 @@ class SimpleNN(nn.Module):
|
|||||||
x = self.log_softmax(self.fc3(x))
|
x = self.log_softmax(self.fc3(x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def train(model, train_loader, n_iter=200):
|
def train(model, train_loader, n_iter=100):
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #adam niby lepszy
|
||||||
criterion = nn.NLLLoss()
|
criterion = nn.NLLLoss()
|
||||||
model.train()
|
model.train()
|
||||||
for epoch in range(n_iter):
|
for epoch in range(n_iter):
|
||||||
@ -42,7 +42,13 @@ def predict_image(image_path, model):
|
|||||||
image_path = "warzywa/" + str(image_path) + ".png"
|
image_path = "warzywa/" + str(image_path) + ".png"
|
||||||
|
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
image = image.convert('RGB')
|
||||||
|
elif image.mode != 'RGB':
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
image = transform(image).unsqueeze(0).to(device)
|
image = transform(image).unsqueeze(0).to(device)
|
||||||
|
|
||||||
class_names = ["fasola","brokul","kapusta","marchewka", "kalafior",
|
class_names = ["fasola","brokul","kapusta","marchewka", "kalafior",
|
||||||
"ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"]
|
"ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -72,8 +78,9 @@ def load_model(model_path):
|
|||||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.Resize((64, 64)),
|
transforms.Resize((64, 64)),
|
||||||
transforms.Grayscale(num_output_channels=1),
|
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), #random jasnosci i wywaliłem ze wszystko jest szare
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
])
|
])
|
||||||
train_data_path = 'train'
|
train_data_path = 'train'
|
||||||
train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
|
train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
|
||||||
@ -86,7 +93,7 @@ model = SimpleNN().to(device)
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train(model, train_loader, n_iter=200)
|
train(model, train_loader, n_iter=100)
|
||||||
torch.save(model.state_dict(), 'model.pth')
|
torch.save(model.state_dict(), 'model.pth')
|
||||||
|
|
||||||
model.load_state_dict(torch.load('model.pth', map_location=device))
|
model.load_state_dict(torch.load('model.pth', map_location=device))
|
||||||
|
Loading…
Reference in New Issue
Block a user