forked from s464965/WMICraft
FINISH
This commit is contained in:
parent
431113a04c
commit
4e3e68d4c3
Binary file not shown.
@ -14,7 +14,9 @@ CNN = NeuralNetwork().to(device)
|
|||||||
def train(model):
|
def train(model):
|
||||||
model.train()
|
model.train()
|
||||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
|
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
|
||||||
|
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=setup_photos)
|
||||||
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
|
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
|
||||||
|
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = Adam(model.parameters(), lr=learning_rate)
|
optimizer = Adam(model.parameters(), lr=learning_rate)
|
||||||
@ -33,11 +35,13 @@ def train(model):
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
print("epoch: %3d loss: %.4f" % (epoch, loss.item()))
|
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
|
||||||
|
|
||||||
print("FINISHED!")
|
print("FINISHED TRAINING!")
|
||||||
print("Checking accuracy.")
|
print("Checking accuracy for the train set.")
|
||||||
check_accuracy(train_loader)
|
check_accuracy(train_loader)
|
||||||
|
print("Checking accuracy for the test set.")
|
||||||
|
check_accuracy(test_loader)
|
||||||
|
|
||||||
torch.save(model.state_dict(), "./learnedNetwork.pt")
|
torch.save(model.state_dict(), "./learnedNetwork.pt")
|
||||||
|
|
||||||
@ -62,7 +66,7 @@ def check_accuracy(loader):
|
|||||||
num_correct += (predictions == y).sum()
|
num_correct += (predictions == y).sum()
|
||||||
num_samples += predictions.size(0)
|
num_samples += predictions.size(0)
|
||||||
|
|
||||||
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
|
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
def what_is_it(img_path, show_img=False):
|
def what_is_it(img_path, show_img=False):
|
||||||
@ -83,4 +87,4 @@ def what_is_it(img_path, show_img=False):
|
|||||||
return id_to_class[idx]
|
return id_to_class[idx]
|
||||||
|
|
||||||
|
|
||||||
print(what_is_it('./data/test/sand/sand.png', True))
|
train(CNN)
|
||||||
|
@ -74,7 +74,7 @@ BAR_HEIGHT_MULTIPLIER = 0.1
|
|||||||
#NEURAL_NETWORK
|
#NEURAL_NETWORK
|
||||||
learning_rate = 0.001
|
learning_rate = 0.001
|
||||||
batch_size = 7
|
batch_size = 7
|
||||||
num_epochs = 10
|
num_epochs = 100
|
||||||
|
|
||||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
classes = ['grass', 'sand', 'tree', 'water']
|
classes = ['grass', 'sand', 'tree', 'water']
|
||||||
|
Loading…
Reference in New Issue
Block a user