Traktor/neuralnetwork.py

102 lines
3.4 KiB
Python
Raw Normal View History

2024-05-23 01:57:24 +02:00
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image
class SimpleNN(nn.Module):
    """Fully connected classifier for small square colour images.

    Each image is flattened and passed through three linear layers
    (``inputs -> 128 -> 64 -> num_classes``) with ReLU activations in
    between. The output is per-class log-probabilities (``LogSoftmax`` over
    dim 1), which is what ``nn.NLLLoss`` expects during training.

    Args:
        image_size: Side length of the square input images. The default of 64
            matches the ``Resize((64, 64))`` preprocessing used in this module.
        channels: Number of colour channels (3 for RGB).
        num_classes: Number of output classes.
    """

    def __init__(self, image_size=64, channels=3, num_classes=10):
        super().__init__()
        # One input per pixel per channel (the factor of 3 keeps colour information).
        self.fc1 = nn.Linear(image_size * image_size * channels, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        Each sample in ``x`` must contain ``image_size * image_size * channels``
        elements, e.g. a ``(batch, 3, 64, 64)`` tensor with the defaults.
        """
        # Flatten the images. Unlike .view, torch.flatten also accepts non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.log_softmax(self.fc3(x))
def train(model, train_loader, n_iter=100, lr=0.001):
    """Train ``model`` in place with Adam and negative log-likelihood loss.

    Args:
        model: Module whose outputs are log-probabilities (e.g. ``SimpleNN``).
        train_loader: Iterable of ``(images, targets)`` batches. ``len()`` of it
            is used to average the loss in the progress output.
        n_iter: Number of epochs, i.e. full passes over ``train_loader``.
        lr: Adam learning rate.

    Batches are moved to the device holding the model's parameters, so the
    caller only has to place the model. The average loss is printed every
    10 epochs.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Adam chosen over plain SGD
    criterion = nn.NLLLoss()  # expects log-probabilities, i.e. a LogSoftmax output
    # Follow the model's device rather than a module-level global, so inputs
    # and weights always end up on the same device.
    model_device = next(model.parameters()).device
    model.train()
    n_batches = max(len(train_loader), 1)  # avoid ZeroDivisionError on an empty loader
    for epoch in range(n_iter):
        running_loss = 0.0
        for images, targets in train_loader:
            images, targets = images.to(model_device), targets.to(model_device)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if epoch % 10 == 0:
            print(f'Epoch: {epoch:3d}, Loss: {running_loss / n_batches:.4f}')
# Label for each output index of SimpleNN.
# NOTE(review): datasets.ImageFolder numbers classes by the *sorted* folder
# names, and this order is not alphabetical. Confirm it matches the sorted
# training folder names (``train_dataset.classes``), or pass those names
# explicitly to predict_image().
CLASS_NAMES = ("fasola", "brokul", "kapusta", "marchewka", "kalafior",
               "ogorek", "ziemniak", "dynia", "rzodkiewka", "pomidor")


def predict_image(image_path, model, class_names=None):
    """Classify a single image file and return its class label.

    Args:
        image_path: Path of the image to classify. Any PIL mode (RGBA, L, P,
            ...) is converted to RGB first.
        model: Trained classifier that returns per-class scores for a batch.
        class_names: Optional sequence mapping output index to label. Defaults
            to ``CLASS_NAMES``.

    Returns:
        The label of the highest-scoring class.
    """
    # Inference must be deterministic: reuse the module's preprocessing but
    # drop the random ColorJitter augmentation meant for training.
    eval_transform = transforms.Compose(
        [t for t in transform.transforms if not isinstance(t, transforms.ColorJitter)]
    )
    # convert() loads the pixels, so the file can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Send the input to wherever the model lives. load_model() returns a CPU
    # model, so using the global CUDA device here would cause a mismatch.
    param = next(model.parameters(), None)
    model_device = param.device if param is not None else torch.device('cpu')
    batch = eval_transform(image).unsqueeze(0).to(model_device)

    names = CLASS_NAMES if class_names is None else class_names
    with torch.no_grad():
        output = model(batch)
    return names[output.argmax(dim=1).item()]
def accuracy(model, dataset, batch_size=256):
    """Return the fraction of samples in ``dataset`` that ``model`` classifies correctly.

    Args:
        model: Classifier returning per-class scores of shape ``(batch, classes)``.
            It is switched to eval mode and left there.
        dataset: Map-style dataset yielding ``(image, target)`` pairs.
        batch_size: Number of samples evaluated per forward pass.

    Returns:
        Accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``dataset`` is empty, since accuracy is undefined then.
    """
    model.eval()
    # Use the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in DataLoader(dataset, batch_size=batch_size):
            images, targets = images.to(model_device), targets.to(model_device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)
    if total == 0:
        raise ValueError("accuracy() needs a non-empty dataset")
    return correct / total
def load_model(model_path, device='cpu'):
    """Build a ``SimpleNN`` and load a saved ``state_dict`` into it.

    Args:
        model_path: Path to a file written with
            ``torch.save(model.state_dict(), ...)``.
        device: Device the weights are mapped to and the model is placed on.
            Defaults to CPU, so checkpoints saved on a GPU load anywhere.

    Returns:
        The model, in evaluation mode.
    """
    model = SimpleNN()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot run arbitrary code while loading.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training preprocessing:
#   1. resize to the 64x64 input SimpleNN expects;
#   2. randomly perturb brightness, contrast, saturation and hue (the step
#      that made every image gray was removed so colour reaches the network);
#   3. scale each RGB channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation preprocessing: the same pipeline without the random
# augmentation, so measured accuracy is reproducible.
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.ColorJitter)]
)

train_data_path = 'train'
train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_data_path = 'test'
# Evaluate on the held-out 'test' folder, not on the training images.
test_dataset = datasets.ImageFolder(root=test_data_path, transform=eval_transform)

model = SimpleNN().to(device)
def main(n_iter=100):
    """Train the module-level ``model``, save it and report test accuracy.

    The weights are written to ``model.pth`` and read back before evaluation,
    so the saved file is exercised too.

    Args:
        n_iter: Number of training epochs.
    """
    train(model, train_loader, n_iter=n_iter)
    torch.save(model.state_dict(), 'model.pth')
    model.load_state_dict(torch.load('model.pth', map_location=device))
    model_accuracy = accuracy(model, test_dataset)
    print(f'Dokładność modelu: {model_accuracy * 100:.2f}%')


# Training is opt-in, so running this file does not start a long training
# run by accident. Use `python neuralnetwork.py --train` to train explicitly.
if __name__ == "__main__":
    import sys

    if "--train" in sys.argv[1:]:
        main()