refactor #26
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,3 +1,5 @@
|
||||
__pycache__/
|
||||
.idea/
|
||||
tree.png
|
||||
tree.png
|
||||
dataset/
|
||||
dataset.zip
|
3
Image.py
3
Image.py
@ -66,6 +66,9 @@ def getRandomImageFromDataBase():
|
||||
imgPath = os.path.join(folderPath, random_image)
|
||||
|
||||
while imgPath in imagePathList:
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
quit()
|
||||
label = random.choice(neuralnetwork.labels)
|
||||
folderPath = f"dataset/test/{label}"
|
||||
files = os.listdir(folderPath)
|
||||
|
@ -197,6 +197,9 @@ class Tractor:
|
||||
count = 0
|
||||
for i in range(initPos[1], dCon.NUM_Y):
|
||||
for j in range(initPos[0], dCon.NUM_X):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
quit()
|
||||
if self.slot.imagePath != None:
|
||||
predictedLabel = nn.predictLabel(self.slot.imagePath, model)
|
||||
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real:", self.slot.label, "predicted:", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect", 'nawożę za pomocą:', nn.fertilizer[predictedLabel])
|
||||
|
@ -9,12 +9,12 @@ from PIL import Image
|
||||
import random
|
||||
|
||||
imageSize = (128, 128)
|
||||
labels = ['beetroot', 'carrot', 'potato'] # musi być w kolejności alfabetycznej
|
||||
labels = ['carrot', 'potato', 'tomato'] # musi być w kolejności alfabetycznej
|
||||
fertilizer = {labels[0]: 'kompost', labels[1]: 'saletra amonowa', labels[2]: 'superfosfat'}
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available () else torch.device('cpu')
|
||||
device = torch.device('cuda')
|
||||
|
||||
def getTransformation():
|
||||
transform=transforms.Compose([
|
||||
@ -67,9 +67,11 @@ def getModel():
|
||||
return model
|
||||
|
||||
def saveModel(model, path):
|
||||
print("Saving model")
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
def loadModel(path):
|
||||
print("Loading model")
|
||||
model = getModel()
|
||||
model.load_state_dict(torch.load(path))
|
||||
return model
|
||||
@ -83,6 +85,8 @@ def trainNewModel(n_iter=100, batch_size=256):
|
||||
def predictLabel(imagePath, model):
|
||||
image = Image.open(imagePath).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
with torch.no_grad():
|
||||
model.eval() # Ustawienie modelu w tryb ewaluacji
|
||||
output = model(image)
|
||||
@ -97,9 +101,10 @@ def predictLabel(imagePath, model):
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
transform = getTransformation()
|
||||
image = transform(image).unsqueeze(0) # Dodanie wymiaru batch_size
|
||||
|
||||
image = transform(image).unsqueeze(0) # Add batch dimension
|
||||
image = image.to(device) # Move the image tensor to the same device as the model
|
||||
return image
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user