nie potrzeba filderu all

This commit is contained in:
XsedoX 2022-05-18 10:29:05 +02:00
parent dc411fae42
commit 736dbc9616
5 changed files with 37 additions and 40 deletions

View File

@ -12,7 +12,7 @@ CNN = NeuralNetwork().to(device)
def train(model): def train(model):
model.train() model.train()
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', './data/train/all', transform=setup_photos) trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
@ -46,7 +46,7 @@ def check_accuracy(loader):
num_samples = 0 num_samples = 0
model = NeuralNetwork() model = NeuralNetwork()
model.load_state_dict(torch.load("./learnedNetwork.pt")) model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
model = model.to(device) model = model.to(device)
with torch.no_grad(): with torch.no_grad():
@ -64,18 +64,12 @@ def check_accuracy(loader):
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}") print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
testset_loader = DataLoader(
WaterSandTreeGrass('./data/test_csv_file.csv', './data/test/all', transform=setup_photos),
batch_size=batch_size
)
def what_is_it(img_path): def what_is_it(img_path):
image = read_image(img_path, mode=ImageReadMode.RGB) image = read_image(img_path, mode=ImageReadMode.RGB)
image = setup_photos(image).unsqueeze(0) image = setup_photos(image).unsqueeze(0)
model = NeuralNetwork() model = NeuralNetwork()
model.load_state_dict(torch.load("./learnedNetwork.pt")) model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
model = model.to(device) model = model.to(device)
image = image.to(device) image = image.to(device)
@ -85,7 +79,4 @@ def what_is_it(img_path):
return id_to_class[idx] return id_to_class[idx]
check_accuracy(testset_loader)
print(what_is_it('./data/test/water/water.png')) print(what_is_it('./data/test/water/water.png'))

View File

@ -3,22 +3,19 @@ from torch.utils.data import Dataset
import pandas as pd import pandas as pd
from torchvision.io import read_image, ImageReadMode from torchvision.io import read_image, ImageReadMode
from common.helpers import createCSV from common.helpers import createCSV
import os
class WaterSandTreeGrass(Dataset): class WaterSandTreeGrass(Dataset):
def __init__(self, annotations_file, img_dir, transform=None): def __init__(self, annotations_file, transform=None):
createCSV() createCSV()
self.img_labels = pd.read_csv(annotations_file) self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform self.transform = transform
def __len__(self): def __len__(self):
return len(self.img_labels) return len(self.img_labels)
def __getitem__(self, idx): def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB)
image = read_image(img_path, mode=ImageReadMode.RGB)
label = torch.tensor(int(self.img_labels.iloc[idx, 1])) label = torch.tensor(int(self.img_labels.iloc[idx, 1]))
if self.transform: if self.transform:

View File

@ -70,12 +70,13 @@ BAR_ANIMATION_SPEED = 1
BAR_WIDTH_MULTIPLIER = 0.9 # (0;1> BAR_WIDTH_MULTIPLIER = 0.9 # (0;1>
BAR_HEIGHT_MULTIPLIER = 0.1 BAR_HEIGHT_MULTIPLIER = 0.1
#NEURAL_NETWORK #NEURAL_NETWORK
learning_rate = 0.001 learning_rate = 0.001
batch_size = 7 batch_size = 7
num_epochs = 10 num_epochs = 10
device = torch.device('cuda') device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
classes = ['grass', 'sand', 'tree', 'water'] classes = ['grass', 'sand', 'tree', 'water']
setup_photos = transforms.Compose([ setup_photos = transforms.Compose([

View File

@ -14,34 +14,43 @@ def draw_text(text, color, surface, x, y, text_size=30, is_bold=False):
textrect.topleft = (x, y) textrect.topleft = (x, y)
surface.blit(textobj, textrect) surface.blit(textobj, textrect)
def createCSV():
train_csvfile = open('./data/train_csv_file.csv', 'w', newline="")
writer = csv.writer(train_csvfile)
writer.writerow(["filename", "type"])
def createCSV():
train_data_path = './data/train' train_data_path = './data/train'
test_data_path = './data/test' test_data_path = './data/test'
if os.path.exists(train_data_path):
train_csvfile = open('./data/train_csv_file.csv', 'w', newline="")
writer = csv.writer(train_csvfile)
writer.writerow(["filepath", "type"])
for class_name in classes: for class_name in classes:
class_dir = train_data_path + "/" + class_name class_dir = train_data_path + "/" + class_name
for filename in os.listdir(class_dir): for filename in os.listdir(class_dir):
f = os.path.join(class_dir, filename) f = os.path.join(class_dir, filename)
if os.path.isfile(f): if os.path.isfile(f):
writer.writerow([filename, class_to_id[class_name]]) writer.writerow([f, class_to_id[class_name]])
train_csvfile.close()
else:
print("Brak plików do uczenia")
if os.path.exists(train_data_path):
test_csvfile = open('./data/test_csv_file.csv', 'w', newline="") test_csvfile = open('./data/test_csv_file.csv', 'w', newline="")
writer = csv.writer(test_csvfile) writer = csv.writer(test_csvfile)
writer.writerow(["filename", "type"]) writer.writerow(["filepath", "type"])
for class_name in classes: for class_name in classes:
class_dir = test_data_path + "/" + class_name class_dir = test_data_path + "/" + class_name
for filename in os.listdir(class_dir): for filename in os.listdir(class_dir):
f = os.path.join(class_dir, filename) f = os.path.join(class_dir, filename)
if os.path.isfile(f): if os.path.isfile(f):
writer.writerow([filename, class_to_id[class_name]]) writer.writerow([f, class_to_id[class_name]])
test_csvfile.close() test_csvfile.close()
train_csvfile.close() else:
print("Brak plików do testowania")
def print_numbers(): def print_numbers():

View File

@ -155,4 +155,3 @@ class Level:
# update and draw the game # update and draw the game
self.sprites.draw(self.screen) self.sprites.draw(self.screen)
self.sprites.update() self.sprites.update()