From 736dbc9616c1e7d1f28ace5149540eb6eb41828e Mon Sep 17 00:00:00 2001 From: XsedoX Date: Wed, 18 May 2022 10:29:05 +0200 Subject: [PATCH] nie potrzeba filderu all --- .../neural_network_interface.py | 15 ++---- .../neural_network/watersandtreegrass.py | 7 +-- common/constants.py | 3 +- common/helpers.py | 51 +++++++++++-------- logic/level.py | 1 - 5 files changed, 37 insertions(+), 40 deletions(-) diff --git a/algorithms/neural_network/neural_network_interface.py b/algorithms/neural_network/neural_network_interface.py index 3a3b1c6..d0e9e4d 100644 --- a/algorithms/neural_network/neural_network_interface.py +++ b/algorithms/neural_network/neural_network_interface.py @@ -12,7 +12,7 @@ CNN = NeuralNetwork().to(device) def train(model): model.train() - trainset = WaterSandTreeGrass('./data/train_csv_file.csv', './data/train/all', transform=setup_photos) + trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) criterion = nn.CrossEntropyLoss() @@ -46,7 +46,7 @@ def check_accuracy(loader): num_samples = 0 model = NeuralNetwork() - model.load_state_dict(torch.load("./learnedNetwork.pt")) + model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device)) model = model.to(device) with torch.no_grad(): @@ -64,18 +64,12 @@ def check_accuracy(loader): print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}") -testset_loader = DataLoader( - WaterSandTreeGrass('./data/test_csv_file.csv', './data/test/all', transform=setup_photos), - batch_size=batch_size -) - - def what_is_it(img_path): image = read_image(img_path, mode=ImageReadMode.RGB) image = setup_photos(image).unsqueeze(0) model = NeuralNetwork() - model.load_state_dict(torch.load("./learnedNetwork.pt")) + model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device)) model = model.to(device) image = image.to(device) @@ -85,7 +79,4 @@ def what_is_it(img_path): return id_to_class[idx] - -check_accuracy(testset_loader) - print(what_is_it('./data/test/water/water.png')) diff --git a/algorithms/neural_network/watersandtreegrass.py b/algorithms/neural_network/watersandtreegrass.py index 835d540..93525d0 100644 --- a/algorithms/neural_network/watersandtreegrass.py +++ b/algorithms/neural_network/watersandtreegrass.py @@ -3,22 +3,19 @@ from torch.utils.data import Dataset import pandas as pd from torchvision.io import read_image, ImageReadMode from common.helpers import createCSV -import os class WaterSandTreeGrass(Dataset): - def __init__(self, annotations_file, img_dir, transform=None): + def __init__(self, annotations_file, transform=None): createCSV() self.img_labels = pd.read_csv(annotations_file) - self.img_dir = img_dir self.transform = transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): - img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) - image = read_image(img_path, mode=ImageReadMode.RGB) + image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB) label = torch.tensor(int(self.img_labels.iloc[idx, 1])) if self.transform: diff --git a/common/constants.py b/common/constants.py index e346f82..2e73dee 100644 --- a/common/constants.py +++ b/common/constants.py @@ -70,12 +70,13 @@ BAR_ANIMATION_SPEED = 1 BAR_WIDTH_MULTIPLIER = 0.9 # (0;1> BAR_HEIGHT_MULTIPLIER = 0.1 + #NEURAL_NETWORK learning_rate = 0.001 batch_size = 7 num_epochs = 10 -device = torch.device('cuda') +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') classes = ['grass', 'sand', 'tree', 'water'] setup_photos = transforms.Compose([ diff --git a/common/helpers.py b/common/helpers.py index 22e44b7..e5c7a32 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -14,34 +14,43 @@ def draw_text(text, color, surface, x, y, text_size=30, is_bold=False): textrect.topleft = (x, y) surface.blit(textobj, textrect) -def createCSV(): - train_csvfile = open('./data/train_csv_file.csv', 'w', newline="") - writer = csv.writer(train_csvfile) - writer.writerow(["filename", "type"]) +def createCSV(): train_data_path = './data/train' test_data_path = './data/test' - for class_name in classes: - class_dir = train_data_path + "/" + class_name - for filename in os.listdir(class_dir): - f = os.path.join(class_dir, filename) - if os.path.isfile(f): - writer.writerow([filename, class_to_id[class_name]]) + if os.path.exists(train_data_path): + train_csvfile = open('./data/train_csv_file.csv', 'w', newline="") + writer = csv.writer(train_csvfile) + writer.writerow(["filepath", "type"]) - test_csvfile = open('./data/test_csv_file.csv', 'w', newline="") - writer = csv.writer(test_csvfile) - writer.writerow(["filename", "type"]) + for class_name in classes: + class_dir = train_data_path + "/" + class_name + for filename in os.listdir(class_dir): + f = os.path.join(class_dir, filename) + if os.path.isfile(f): + writer.writerow([f, class_to_id[class_name]]) - for class_name in classes: - class_dir = test_data_path + "/" + class_name - for filename in os.listdir(class_dir): - f = os.path.join(class_dir, filename) - if os.path.isfile(f): - writer.writerow([filename, class_to_id[class_name]]) + train_csvfile.close() - test_csvfile.close() - train_csvfile.close() + else: + print("Brak plików do uczenia") + + if os.path.exists(train_data_path): + test_csvfile = open('./data/test_csv_file.csv', 'w', newline="") + writer = csv.writer(test_csvfile) + writer.writerow(["filepath", "type"]) + + for class_name in classes: + class_dir = test_data_path + "/" + class_name + for filename in os.listdir(class_dir): + f = os.path.join(class_dir, filename) + if os.path.isfile(f): + writer.writerow([f, class_to_id[class_name]]) + + test_csvfile.close() + else: + print("Brak plików do testowania") def print_numbers(): diff --git a/logic/level.py b/logic/level.py index f8e5ee5..8062cbf 100644 --- a/logic/level.py +++ b/logic/level.py @@ -155,4 +155,3 @@ class Level: # update and draw the game self.sprites.draw(self.screen) self.sprites.update() - \ No newline at end of file