115 lines
3.9 KiB
Python
115 lines
3.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import datasets
|
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
|
import torchvision.transforms as transforms
|
|
import matplotlib.pyplot as plt
|
|
from PIL import Image
|
|
import random
|
|
|
|
# Resolution every image is resized to before being flattened for the model.
imageSize = (128, 128)

# Crop class names. They must stay in alphabetical order: ImageFolder numbers
# classes by the sorted names of the dataset sub-directories, and predicted
# class indices are looked up in this list.
labels = ['carrot', 'corn', 'potato', 'tomato']

# Recommended fertilizer for each crop, keyed by label (same order as labels).
fertilizer = dict(zip(labels, ['kompost', 'saletra amonowa', 'superfosfat', 'obornik kurzy']))

# Two-crop variant: replace the two assignments above with
#   labels = ['corn', 'tomato']
#   fertilizer = {labels[0]: 'kompost', labels[1]: 'saletra amonowa'}

# Fixed seed so weight initialisation is reproducible between runs.
torch.manual_seed(42)

# Everything runs on the CPU. Alternative device selections:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   device = torch.device("mps") if torch.backends.mps.is_available() else torch.device('cpu')
device = torch.device("cpu")
|
|
|
|
def getTransformation():
    """Build the preprocessing pipeline shared by training, testing and prediction.

    Steps, in order:
      1. ``ToTensor`` - PIL image -> float tensor with values in [0, 1].
      2. ``Normalize`` with mean/std 0.5 per channel - rescales to [-1, 1].
      3. ``Resize`` to ``imageSize``.
      4. Flatten to a 1-D vector, as expected by the fully connected model.

    Returns:
        A ``transforms.Compose`` callable.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        transforms.Resize(imageSize),
        Lambda(lambda x: x.flatten()),
    ]
    return transforms.Compose(steps)
|
|
|
|
def getDataset(train=True):
    """Load one split of the image dataset as a torchvision ``ImageFolder``.

    Args:
        train: Truthy selects ``dataset/train``; falsy selects ``dataset/test``.

    Returns:
        An ``ImageFolder`` whose samples are preprocessed by ``getTransformation()``.
    """
    transform = getTransformation()
    root = 'dataset/train' if train else 'dataset/test'
    return datasets.ImageFolder(root=root, transform=transform)
|
|
|
|
|
|
def train(model, dataset, n_iter=100, batch_size=256):
|
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
|
criterion = nn.NLLLoss()
|
|
dl = DataLoader(dataset, batch_size=batch_size)
|
|
model.train()
|
|
for epoch in range(n_iter):
|
|
for images, targets in dl:
|
|
optimizer.zero_grad()
|
|
out = model(images.to(device))
|
|
loss = criterion(out, targets.to(device))
|
|
loss.backward()
|
|
optimizer.step()
|
|
if epoch % 10 == 0:
|
|
print('epoch: %3d loss: %.4f' % (epoch, loss))
|
|
return model
|
|
|
|
def accuracy(model, dataset, batch_size=256):
    """Return the fraction of samples in *dataset* that *model* classifies correctly.

    Args:
        model: Classifier producing one score per class along dim 1; it is
            left in eval mode afterwards.
        dataset: Dataset with ``len()`` yielding ``(input, target)`` pairs.
        batch_size: Evaluation batch size (256 by default, as before).

    Returns:
        A 0-dim float tensor in [0, 1], on the model's device.

    Raises:
        ValueError: If *dataset* is empty (accuracy is undefined).
    """
    if len(dataset) == 0:
        raise ValueError("accuracy() requires a non-empty dataset")

    # Evaluate on the device holding the model's weights; fall back to the
    # module-level device only for a parameterless model.
    param = next(model.parameters(), None)
    eval_device = param.device if param is not None else device

    model.eval()
    correct = torch.zeros((), dtype=torch.long, device=eval_device)
    # Evaluation needs no autograd graph; building one per batch only wastes memory.
    with torch.no_grad():
        for images, targets in DataLoader(dataset, batch_size=batch_size):
            predictions = model(images.to(eval_device)).argmax(dim=1)
            correct += (predictions == targets.to(eval_device)).sum()
    return correct.float() / len(dataset)
|
|
|
|
def getModel():
    """Create an untrained fully connected classifier on the module-level ``device``.

    Architecture: flattened RGB image (``imageSize[0] * imageSize[1] * 3``
    values) -> 500 hidden units with ReLU -> one log-probability per entry
    of ``labels``.
    """
    hidden_units = 500
    flat_pixels = imageSize[0] * imageSize[1] * 3  # three colour channels
    layers = [
        nn.Linear(flat_pixels, hidden_units),
        nn.ReLU(),
        nn.Linear(hidden_units, len(labels)),
        nn.LogSoftmax(dim=-1),
    ]
    return nn.Sequential(*layers).to(device)
|
|
|
|
def saveModel(model, path):
    """Write only the model's parameters (its ``state_dict``) to *path*."""
    print("Saving model")
    weights = model.state_dict()
    torch.save(weights, path)
|
|
|
|
def loadModel(path):
    """Build a fresh model with ``getModel()`` and fill it with weights from *path*.

    Args:
        path: File previously written by ``saveModel`` (a ``state_dict``).

    Returns:
        The model with the loaded parameters.
    """
    print("Loading model")
    model = getModel()
    # Map all tensors to the CPU while loading, so checkpoints saved on a CUDA
    # machine can be opened where CUDA is unavailable or broken; load_state_dict
    # then copies the values into the model's own parameters.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    return model
|
|
|
|
def trainNewModel(n_iter=100, batch_size=256):
    """Create a fresh model and train it on the ``dataset/train`` split.

    Args:
        n_iter: Number of training epochs, forwarded to ``train``.
        batch_size: Mini-batch size, forwarded to ``train``.

    Returns:
        The trained model.
    """
    trainset = getDataset(True)
    model = getModel()
    # Forward the caller's settings; previously they were silently ignored and
    # train() always ran with its own defaults.
    return train(model, trainset, n_iter=n_iter, batch_size=batch_size)
|
|
|
|
def predictLabel(imagePath, model):
    """Classify the image stored at *imagePath* and return its crop label.

    Args:
        imagePath: Path of an image file readable by PIL; it is converted to RGB.
        model: Classifier such as the one built by ``getModel``; it is switched
            to eval mode but is no longer moved to another device.

    Returns:
        The entry of ``labels`` with the highest model score.
    """
    # Open the file in a context manager so its handle is released promptly.
    with Image.open(imagePath) as img:
        image = preprocess_image(img.convert("RGB"))

    # Run on whatever device the model already lives on, instead of moving the
    # caller's model to CUDA as a side effect; a model built by getModel() thus
    # stays on the module-level ``device``.
    param = next(model.parameters(), None)
    if param is not None:
        image = image.to(param.device)

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        output = model(image)

    # Batch of one: take the index of the class with the highest score.
    predicted_class = torch.argmax(output, dim=1).item()
    return labels[predicted_class]
|
|
|
|
|
|
def preprocess_image(image):
    """Turn a PIL image into a model-ready batch containing a single sample.

    The image goes through ``getTransformation()``, gets a leading batch
    dimension, and is moved to CUDA when available, otherwise to the CPU.

    Args:
        image: PIL image (callers pass it already converted to RGB).

    Returns:
        Tensor with a batch dimension of size 1 in front.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = getTransformation()
    batch = pipeline(image).unsqueeze(0)  # prepend the batch dimension
    return batch.to(target)
|
|
|
|
|