FINISH 95% test 99% train

This commit is contained in:
XsedoX 2022-05-31 09:25:36 +02:00
parent 8d73a85707
commit 162e2df890
4 changed files with 38 additions and 2 deletions

View File

@ -0,0 +1 @@
{}

View File

@ -13,6 +13,7 @@ from pytorch_lightning.callbacks import EarlyStopping
import torchvision.transforms.functional as F
from PIL import Image
def check_accuracy_tiles():
answer = 0
for i in range(100):
@ -72,6 +73,37 @@ def what_is_it(img_path, show_img=False):
return ID_TO_CLASS[idx]
def check_accuracy(tset):
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_23/checkpoints/epoch=3-step=324.ckpt')
num_correct = 0
num_samples = 0
model = model.to(DEVICE)
model.eval()
with torch.no_grad():
for photo, label in tset:
photo = photo.to(DEVICE)
label = label.to(DEVICE)
scores = model(photo)
predictions = scores.argmax(dim=1)
num_correct += (predictions == label).sum()
num_samples += predictions.size(0)
print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')
def check_accuracy_data():
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
print("Accuracy of train_set:")
check_accuracy(train_loader)
print("Accuracy of test_set:")
check_accuracy(test_loader)
#CNN = NeuralNetwork()
#common.helpers.createCSV()
@ -82,9 +114,12 @@ def what_is_it(img_path, show_img=False):
#testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
#train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
#test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
#trainer.fit(CNN, train_loader, test_loader)
#trainer.tune(CNN, train_loader, test_loader)
#check_accuracy_tiles()
#print(what_is_it('../../resources/textures/grass2.png', True))
#check_accuracy_data()
#check_accuracy_tiles()