2024-05-23 01:57:24 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torchvision import datasets, transforms
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
class SimpleNN(nn.Module):
    """Three-layer fully connected classifier that outputs log-probabilities.

    Each sample of the input batch is flattened and passed through two
    ReLU-activated hidden layers (128 and 64 units), then a linear output
    layer followed by a log-softmax over the class dimension. The output is
    therefore in the form expected by ``nn.NLLLoss``.

    Args:
        in_features: Number of values per flattened sample. The default,
            ``64 * 64 * 3``, matches a 64x64 image with 3 colour channels.
        num_classes: Number of output classes.
    """

    def __init__(self, in_features: int = 64 * 64 * 3, num_classes: int = 10):
        super().__init__()
        # Attribute names are kept as-is so the state_dict keys
        # (fc1.*, fc2.*, fc3.*) stay compatible with saved checkpoints.
        self.fc1 = nn.Linear(in_features, 128)  # *3 so the colour channels fit
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities for a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must multiply to the
                ``in_features`` the network was built with.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding log-probabilities.
        """
        # Flatten the images. reshape() behaves like view() for contiguous
        # input but also accepts non-contiguous tensors (e.g. after permute),
        # where view() would raise.
        x = x.reshape(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.log_softmax(self.fc3(x))
        return x
|
|
|
|
|
2024-05-24 05:46:37 +02:00
|
|
|
def train(model, train_loader, n_iter=100):
    """Train ``model`` in place with Adam and negative log-likelihood loss.

    Args:
        model: Network to optimise. Its output is fed to ``nn.NLLLoss``, so
            it must produce per-class log-probabilities.
        train_loader: Iterable yielding ``(images, targets)`` batches.
        n_iter: Number of passes (epochs) over ``train_loader``.

    Returns:
        None. The model is left in training mode. Every tenth epoch
        (0, 10, 20, ...) the mean batch loss of that epoch is printed.

    Raises:
        ValueError: Raised by the optimizer if the model has no parameters.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam is reportedly the better choice
    criterion = nn.NLLLoss()

    # Move batches to wherever the model actually lives instead of relying on
    # a module-level device that may not match (e.g. a CPU-loaded model on a
    # CUDA machine). The optimizer above already rejected parameterless models.
    target_device = next(model.parameters()).device

    model.train()
    for epoch in range(n_iter):
        running_loss = 0.0
        n_batches = 0
        for images, targets in train_loader:
            images, targets = images.to(target_device), targets.to(target_device)
            optimizer.zero_grad()
            out = model(images)
            loss = criterion(out, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        if epoch % 10 == 0:
            # Count batches ourselves: avoids dividing by zero on an empty
            # loader and works for loaders that do not define __len__.
            mean_loss = running_loss / n_batches if n_batches else float('nan')
            print(f'Epoch: {epoch:3d}, Loss: {mean_loss:.4f}')
|
|
|
|
|
|
|
|
def predict_image(image_path, model):
    """Classify the image stored at ``image_path`` and return its class name.

    The image is converted to RGB, run through the module-level ``transform``
    pipeline, wrapped into a batch of one and fed to ``model``. The index of
    the highest-scoring output selects the returned name.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: Network returning one score per class for a batch of images.

    Returns:
        The class name (str) with the highest model output.
    """
    # NOTE(review): the order of these names must match the label indices used
    # during training (ImageFolder assigns indices from the sorted class
    # folder names) — confirm against the dataset directories.
    class_names = ["fasola","brokul","kapusta","marchewka", "kalafior",
                   "ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor"]

    # The context manager closes the underlying file; convert() yields a
    # fully loaded RGB image whatever the source mode (RGBA, L, P, ...).
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Send the input to the device the model lives on; fall back to the
    # module-level device only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    image = transform(image).unsqueeze(0).to(target_device)

    with torch.no_grad():
        output = model(image)
    _, predicted = torch.max(output, 1)
    return class_names[predicted.item()]
|
|
|
|
|
|
|
|
def accuracy(model, dataset):
    """Return the fraction of samples in ``dataset`` that ``model`` classifies correctly.

    The model is switched to evaluation mode (and left there) and the dataset
    is processed in batches of 256 without gradient tracking.

    Args:
        model: Network returning one score per class for a batch of inputs.
        dataset: Dataset yielding ``(input, target)`` pairs, with integer
            class indices as targets.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when the dataset is empty.
    """
    model.eval()

    # Evaluate on the device the model lives on; fall back to the module-level
    # device only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in DataLoader(dataset, batch_size=256):
            images, targets = images.to(target_device), targets.to(target_device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)

    # An empty dataset would otherwise raise ZeroDivisionError.
    return correct / total if total else 0.0
|
|
|
|
|
|
|
|
def load_model(model_path):
    """Build a ``SimpleNN``, load weights from ``model_path`` and return it in eval mode.

    Args:
        model_path: Location of a state dict saved with ``torch.save``.

    Returns:
        A ``SimpleNN`` instance with the stored weights, mapped to the CPU and
        switched to evaluation mode.
    """
    net = SimpleNN()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    weights = torch.load(model_path, map_location=torch.device('cpu'))
    net.load_state_dict(weights)
    # eval() returns the module itself.
    return net.eval()
|
|
|
|
|
|
|
|
# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Deterministic preprocessing for evaluation and single-image prediction:
# resize to 64x64, convert to a tensor and scale each channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training-only pipeline: same as above plus random colour augmentation.
# Keeping the jitter out of `transform` makes evaluation and prediction
# reproducible for a given image.
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # random brightness/colour jitter; the grayscale conversion was removed
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_data_path = 'train'
train_dataset = datasets.ImageFolder(root=train_data_path, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_data_path = 'test'
# Read the held-out images from test_data_path; the training directory was
# used here before, so the reported accuracy was measured on training data.
test_dataset = datasets.ImageFolder(root=test_data_path, transform=transform)

model = SimpleNN().to(device)
|
|
|
|
|
2024-05-26 05:12:46 +02:00
|
|
|
# Training / evaluation entry point, currently disabled: it is wrapped in a
# bare string literal, so none of the statements below are executed.
# NOTE(review): the original indentation was lost; the statements after the
# guard are assumed to belong to its body — confirm before re-enabling.
'''
if __name__ == "__main__":
    train(model, train_loader, n_iter=100)
    torch.save(model.state_dict(), 'model.pth')

    model.load_state_dict(torch.load('model.pth', map_location=device))

    model_accuracy = accuracy(model, test_dataset)
    print(f'Dokładność modelu: {model_accuracy * 100:.2f}%')
'''
|