forked from s464965/WMICraft
FINISH
This commit is contained in:
parent
431113a04c
commit
4e3e68d4c3
Binary file not shown.
@ -14,7 +14,9 @@ CNN = NeuralNetwork().to(device)
|
||||
def train(model):
|
||||
model.train()
|
||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
|
||||
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=setup_photos)
|
||||
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
|
||||
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = Adam(model.parameters(), lr=learning_rate)
|
||||
@ -33,11 +35,13 @@ def train(model):
|
||||
optimizer.step()
|
||||
|
||||
if epoch % 2 == 0:
|
||||
print("epoch: %3d loss: %.4f" % (epoch, loss.item()))
|
||||
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
|
||||
|
||||
print("FINISHED!")
|
||||
print("Checking accuracy.")
|
||||
print("FINISHED TRAINING!")
|
||||
print("Checking accuracy for the train set.")
|
||||
check_accuracy(train_loader)
|
||||
print("Checking accuracy for the test set.")
|
||||
check_accuracy(test_loader)
|
||||
|
||||
torch.save(model.state_dict(), "./learnedNetwork.pt")
|
||||
|
||||
@ -62,7 +66,7 @@ def check_accuracy(loader):
|
||||
num_correct += (predictions == y).sum()
|
||||
num_samples += predictions.size(0)
|
||||
|
||||
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
|
||||
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%")
|
||||
|
||||
|
||||
def what_is_it(img_path, show_img=False):
|
||||
@ -83,4 +87,4 @@ def what_is_it(img_path, show_img=False):
|
||||
return id_to_class[idx]
|
||||
|
||||
|
||||
print(what_is_it('./data/test/sand/sand.png', True))
|
||||
train(CNN)
|
||||
|
@ -74,7 +74,7 @@ BAR_HEIGHT_MULTIPLIER = 0.1
|
||||
#NEURAL_NETWORK
|
||||
learning_rate = 0.001
|
||||
batch_size = 7
|
||||
num_epochs = 10
|
||||
num_epochs = 100
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
classes = ['grass', 'sand', 'tree', 'water']
|
||||
|
Loading…
Reference in New Issue
Block a user