diff --git a/algorithms/neural_network/neural_network_interface.py b/algorithms/neural_network/neural_network_interface.py index 4033778..6ec7cb4 100644 --- a/algorithms/neural_network/neural_network_interface.py +++ b/algorithms/neural_network/neural_network_interface.py @@ -13,44 +13,6 @@ from pytorch_lightning.callbacks import EarlyStopping import torchvision.transforms.functional as F from PIL import Image -def train(model): - model = model.to(DEVICE) - model.train() - trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS) - testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS) - train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) - test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True) - - criterion = nn.CrossEntropyLoss() - optimizer = Adam(model.parameters(), lr=LEARNING_RATE) - - for epoch in range(NUM_EPOCHS): - for batch_idx, (data, targets) in enumerate(train_loader): - data = data.to(device=DEVICE) - targets = targets.to(device=DEVICE) - - scores = model(data) - loss = criterion(scores, targets) - - optimizer.zero_grad() - loss.backward() - - optimizer.step() - - if batch_idx % 4 == 0: - print("epoch: %d loss: %.4f" % (epoch, loss.item())) - - print("FINISHED TRAINING!") - torch.save(model.state_dict(), "./learnednetwork.pth") - - print("Checking accuracy for the train set.") - check_accuracy(train_loader) - print("Checking accuracy for the test set.") - check_accuracy(test_loader) - print("Checking accuracy for the tiles.") - check_accuracy_tiles() - - def check_accuracy_tiles(): answer = 0 for i in range(100):