This commit is contained in:
XsedoX 2022-05-18 15:44:07 +02:00
parent 431113a04c
commit 4e3e68d4c3
3 changed files with 10 additions and 6 deletions

View File

@ -14,7 +14,9 @@ CNN = NeuralNetwork().to(device)
def train(model): def train(model):
model.train() model.train()
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos) trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=setup_photos)
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate) optimizer = Adam(model.parameters(), lr=learning_rate)
@ -33,11 +35,13 @@ def train(model):
optimizer.step() optimizer.step()
if epoch % 2 == 0: if epoch % 2 == 0:
print("epoch: %3d loss: %.4f" % (epoch, loss.item())) print("epoch: %d loss: %.4f" % (epoch, loss.item()))
print("FINISHED!") print("FINISHED TRAINING!")
print("Checking accuracy.") print("Checking accuracy for the train set.")
check_accuracy(train_loader) check_accuracy(train_loader)
print("Checking accuracy for the test set.")
check_accuracy(test_loader)
torch.save(model.state_dict(), "./learnedNetwork.pt") torch.save(model.state_dict(), "./learnedNetwork.pt")
@ -62,7 +66,7 @@ def check_accuracy(loader):
num_correct += (predictions == y).sum() num_correct += (predictions == y).sum()
num_samples += predictions.size(0) num_samples += predictions.size(0)
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}") print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%")
def what_is_it(img_path, show_img=False): def what_is_it(img_path, show_img=False):
@ -83,4 +87,4 @@ def what_is_it(img_path, show_img=False):
return id_to_class[idx] return id_to_class[idx]
print(what_is_it('./data/test/sand/sand.png', True)) train(CNN)

View File

@ -74,7 +74,7 @@ BAR_HEIGHT_MULTIPLIER = 0.1
#NEURAL_NETWORK #NEURAL_NETWORK
learning_rate = 0.001 learning_rate = 0.001
batch_size = 7 batch_size = 7
num_epochs = 10 num_epochs = 100
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
classes = ['grass', 'sand', 'tree', 'water'] classes = ['grass', 'sand', 'tree', 'water']