forked from s464965/WMICraft
FINISH 95% test 99% train
This commit is contained in:
parent
8d73a85707
commit
162e2df890
Binary file not shown.
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
{}
|
@ -13,6 +13,7 @@ from pytorch_lightning.callbacks import EarlyStopping
|
|||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
def check_accuracy_tiles():
|
def check_accuracy_tiles():
|
||||||
answer = 0
|
answer = 0
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
@ -72,6 +73,37 @@ def what_is_it(img_path, show_img=False):
|
|||||||
return ID_TO_CLASS[idx]
|
return ID_TO_CLASS[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def check_accuracy(tset):
|
||||||
|
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_23/checkpoints/epoch=3-step=324.ckpt')
|
||||||
|
num_correct = 0
|
||||||
|
num_samples = 0
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for photo, label in tset:
|
||||||
|
photo = photo.to(DEVICE)
|
||||||
|
label = label.to(DEVICE)
|
||||||
|
|
||||||
|
scores = model(photo)
|
||||||
|
predictions = scores.argmax(dim=1)
|
||||||
|
num_correct += (predictions == label).sum()
|
||||||
|
num_samples += predictions.size(0)
|
||||||
|
|
||||||
|
print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')
|
||||||
|
|
||||||
|
|
||||||
|
def check_accuracy_data():
|
||||||
|
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
|
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
|
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
||||||
|
|
||||||
|
print("Accuracy of train_set:")
|
||||||
|
check_accuracy(train_loader)
|
||||||
|
print("Accuracy of test_set:")
|
||||||
|
check_accuracy(test_loader)
|
||||||
|
|
||||||
#CNN = NeuralNetwork()
|
#CNN = NeuralNetwork()
|
||||||
#common.helpers.createCSV()
|
#common.helpers.createCSV()
|
||||||
|
|
||||||
@ -82,9 +114,12 @@ def what_is_it(img_path, show_img=False):
|
|||||||
#testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
#testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
#train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
#train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
#test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
#test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
||||||
|
|
||||||
#trainer.fit(CNN, train_loader, test_loader)
|
#trainer.fit(CNN, train_loader, test_loader)
|
||||||
#trainer.tune(CNN, train_loader, test_loader)
|
#trainer.tune(CNN, train_loader, test_loader)
|
||||||
#check_accuracy_tiles()
|
|
||||||
|
|
||||||
#print(what_is_it('../../resources/textures/grass2.png', True))
|
#print(what_is_it('../../resources/textures/grass2.png', True))
|
||||||
|
|
||||||
|
#check_accuracy_data()
|
||||||
|
|
||||||
|
#check_accuracy_tiles()
|
||||||
|
Loading…
Reference in New Issue
Block a user