forked from s464965/WMICraft

commit 736dbc9616
parent dc411fae42

    the "all" folder is not needed
@@ -12,7 +12,7 @@ CNN = NeuralNetwork().to(device)
 
 def train(model):
     model.train()
-    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', './data/train/all', transform=setup_photos)
+    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
     train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
 
     criterion = nn.CrossEntropyLoss()
@@ -46,7 +46,7 @@ def check_accuracy(loader):
     num_samples = 0
     model = NeuralNetwork()
 
-    model.load_state_dict(torch.load("./learnedNetwork.pt"))
+    model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
     model = model.to(device)
 
     with torch.no_grad():
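For context, a minimal sketch of why map_location matters here, assuming the checkpoint was written on a CUDA machine; the file name comes from this repo, the rest is illustrative:

    import torch

    # Mirrors the device constant added to the config in this commit.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Without map_location, a checkpoint saved on GPU fails to deserialize on a
    # CPU-only host; map_location remaps the stored tensors onto the chosen device.
    state_dict = torch.load("./learnedNetwork.pt", map_location=device)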
@@ -64,18 +64,12 @@ def check_accuracy(loader):
     print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
 
 
-testset_loader = DataLoader(
-    WaterSandTreeGrass('./data/test_csv_file.csv', './data/test/all', transform=setup_photos),
-    batch_size=batch_size
-)
-
-
 def what_is_it(img_path):
     image = read_image(img_path, mode=ImageReadMode.RGB)
     image = setup_photos(image).unsqueeze(0)
     model = NeuralNetwork()
 
-    model.load_state_dict(torch.load("./learnedNetwork.pt"))
+    model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
     model = model.to(device)
     image = image.to(device)
 
@@ -85,7 +79,4 @@ def what_is_it(img_path):
     return id_to_class[idx]
 
 
-
-check_accuracy(testset_loader)
-
 print(what_is_it('./data/test/water/water.png'))
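The module-level test loader and its check_accuracy(testset_loader) call are removed together with the ./data/test/all folder. If evaluation is still wanted, a loader can be rebuilt the same way as the training one; this is only a sketch under that assumption, not part of the commit:

    # Hypothetical: rebuild a test loader from the test CSV, mirroring the train loader.
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=setup_photos)
    testset_loader = DataLoader(testset, batch_size=batch_size)
    check_accuracy(testset_loader)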
@@ -3,22 +3,19 @@ from torch.utils.data import Dataset
 import pandas as pd
 from torchvision.io import read_image, ImageReadMode
 from common.helpers import createCSV
-import os
 
 
 class WaterSandTreeGrass(Dataset):
-    def __init__(self, annotations_file, img_dir, transform=None):
+    def __init__(self, annotations_file, transform=None):
         createCSV()
         self.img_labels = pd.read_csv(annotations_file)
-        self.img_dir = img_dir
         self.transform = transform
 
     def __len__(self):
         return len(self.img_labels)
 
     def __getitem__(self, idx):
-        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
-        image = read_image(img_path, mode=ImageReadMode.RGB)
+        image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB)
         label = torch.tensor(int(self.img_labels.iloc[idx, 1]))
 
         if self.transform:
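A minimal usage sketch of the reworked dataset, assuming the CSV's first column now holds a full image path and the second a numeric class id (as written by createCSV below); the loop body is illustrative:

    from torch.utils.data import DataLoader

    # __init__ regenerates the CSV via createCSV(), so only the CSV path is passed.
    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for images, labels in train_loader:
        # images: transformed RGB tensors, labels: integer class ids from the CSV
        print(images.shape, labels)
        break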
@@ -70,12 +70,13 @@ BAR_ANIMATION_SPEED = 1
 BAR_WIDTH_MULTIPLIER = 0.9 # (0;1>
 BAR_HEIGHT_MULTIPLIER = 0.1
 
+
 #NEURAL_NETWORK
 learning_rate = 0.001
 batch_size = 7
 num_epochs = 10
 
-device = torch.device('cuda')
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 classes = ['grass', 'sand', 'tree', 'water']
 
 setup_photos = transforms.Compose([
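Other files in this commit reference class_to_id (in createCSV) and id_to_class (in what_is_it). Their definitions are not part of this diff; a hedged sketch of how such mappings could be derived from the classes list above, purely as an assumption:

    # Assumed shape of the mappings used by createCSV() and what_is_it();
    # the real definitions live elsewhere in the repo.
    classes = ['grass', 'sand', 'tree', 'water']
    class_to_id = {name: idx for idx, name in enumerate(classes)}  # e.g. {'grass': 0, ...}
    id_to_class = {idx: name for idx, name in enumerate(classes)}  # e.g. {0: 'grass', ...}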
@@ -14,34 +14,43 @@ def draw_text(text, color, surface, x, y, text_size=30, is_bold=False):
     textrect.topleft = (x, y)
     surface.blit(textobj, textrect)
 
-def createCSV():
-    train_csvfile = open('./data/train_csv_file.csv', 'w', newline="")
-    writer = csv.writer(train_csvfile)
-    writer.writerow(["filename", "type"])
 
+def createCSV():
     train_data_path = './data/train'
     test_data_path = './data/test'
 
+    if os.path.exists(train_data_path):
+        train_csvfile = open('./data/train_csv_file.csv', 'w', newline="")
+        writer = csv.writer(train_csvfile)
+        writer.writerow(["filepath", "type"])
+
         for class_name in classes:
             class_dir = train_data_path + "/" + class_name
             for filename in os.listdir(class_dir):
                 f = os.path.join(class_dir, filename)
                 if os.path.isfile(f):
-                    writer.writerow([filename, class_to_id[class_name]])
+                    writer.writerow([f, class_to_id[class_name]])
 
+        train_csvfile.close()
+
+    else:
+        print("Brak plików do uczenia")
+
+    if os.path.exists(train_data_path):
         test_csvfile = open('./data/test_csv_file.csv', 'w', newline="")
         writer = csv.writer(test_csvfile)
-        writer.writerow(["filename", "type"])
+        writer.writerow(["filepath", "type"])
 
         for class_name in classes:
             class_dir = test_data_path + "/" + class_name
             for filename in os.listdir(class_dir):
                 f = os.path.join(class_dir, filename)
                 if os.path.isfile(f):
-                    writer.writerow([filename, class_to_id[class_name]])
+                    writer.writerow([f, class_to_id[class_name]])
 
         test_csvfile.close()
-    train_csvfile.close()
+    else:
+        print("Brak plików do testowania")
 
 
 def print_numbers():
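With the header renamed to "filepath" and full paths written per row, the generated train CSV would look roughly like the comment below (file names are illustrative), and pandas reads it exactly the way WaterSandTreeGrass indexes it:

    # Illustrative contents of ./data/train_csv_file.csv after createCSV() runs:
    #
    #   filepath,type
    #   ./data/train/grass/grass_001.png,0
    #   ./data/train/water/water_017.png,3
    #
    import pandas as pd

    img_labels = pd.read_csv('./data/train_csv_file.csv')
    path, label = img_labels.iloc[0, 0], int(img_labels.iloc[0, 1])  # same indexing as __getitem__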
@@ -155,4 +155,3 @@ class Level:
         # update and draw the game
         self.sprites.draw(self.screen)
         self.sprites.update()
-