"""Train and evaluate the water/sand/tree/grass image classifier.

Running this file as a script calls ``common.helpers.createCSV()``
(presumably regenerating the ``./data/*_csv_file.csv`` split files — TODO
confirm) and then trains ``NeuralNetwork`` with PyTorch Lightning.  The
helper functions below can be imported on their own (e.g. ``what_is_it`` to
classify a single texture) without starting a training run.
"""
import functools

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.io import read_image, ImageReadMode

import common.helpers
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
from neural_network import NeuralNetwork
from watersandtreegrass import WaterSandTreeGrass

# Lightning checkpoint used by what_is_it when no model is passed in.
CHECKPOINT_PATH = './lightning_logs/version_13/checkpoints/epoch=4-step=405.ckpt'
# Directory holding the texture images checked by check_accuracy_tiles.
TEXTURES_DIR = '../../resources/textures/'
# (file name, expected class) pairs, in the order they are reported.
TILE_EXPECTATIONS = (
    ('grass_with_tree.jpg', 'tree'),
    ('grass2.png', 'grass'),
    ('grass3.png', 'grass'),
    ('grass4.png', 'grass'),
    ('grass1.png', 'grass'),
    ('water.png', 'water'),
    ('sand.png', 'sand'),
)


def train(model):
    """Train *model* with a plain PyTorch loop, save its weights and report accuracy.

    Reads the train/test splits from ``./data/train_csv_file.csv`` and
    ``./data/test_csv_file.csv``, minimises cross-entropy with Adam for
    ``NUM_EPOCHS`` epochs, writes the state dict to ``./learnednetwork.pth``,
    then prints accuracy on both splits and on the texture tiles, all using
    the model that was just trained.
    """
    model = model.to(DEVICE)
    model.train()
    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
    train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    # Evaluation order does not affect accuracy, so the test split is not shuffled.
    test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (data, targets) in enumerate(train_loader):
            data = data.to(device=DEVICE)
            targets = targets.to(device=DEVICE)
            loss = criterion(model(data), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Only log every 4th batch to keep the output short.
            if batch_idx % 4 == 0:
                print("epoch: %d loss: %.4f" % (epoch, loss.item()))
    print("FINISHED TRAINING!")
    torch.save(model.state_dict(), "./learnednetwork.pth")
    print("Checking accuracy for the train set.")
    check_accuracy(train_loader, model)
    print("Checking accuracy for the test set.")
    check_accuracy(test_loader, model)
    print("Checking accuracy for the tiles.")
    check_accuracy_tiles(model=model)


def check_accuracy(loader, model):
    """Print and return the fraction of samples in *loader* that *model* labels correctly.

    Assumes each batch is ``(data, targets)`` with integer class-index
    targets, as fed to ``nn.CrossEntropyLoss`` in ``train`` (TODO confirm the
    dataset never yields one-hot/probability targets).  The model's
    train/eval mode is restored before returning.  Returns 0.0 for an empty
    loader.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device=DEVICE)
            targets = targets.to(device=DEVICE)
            predictions = model(data).argmax(dim=1)
            correct += int((predictions == targets).sum())
            total += targets.size(0)
    model.train(was_training)
    accuracy = correct / total if total else 0.0
    print(f"Got {correct}/{total} correct ({100 * accuracy:.2f}%)")
    return accuracy


def check_accuracy_tiles(trials=100, *, model=None):
    """Print, per texture in ``TILE_EXPECTATIONS``, the % of predictions that match.

    Each tile is classified *trials* times.  This is presumably because
    ``SETUP_PHOTOS`` applies random augmentation, so repeated predictions on
    the same file can differ (TODO confirm).  With the default of 100 trials,
    the printed number equals the count of correct predictions.

    Args:
        trials: number of predictions per tile; must be positive.
        model: network to evaluate; ``None`` uses the default checkpoint
            (see ``what_is_it``).

    Raises:
        ValueError: if *trials* is not positive.
    """
    if trials <= 0:
        raise ValueError("trials must be positive, got %r" % (trials,))
    for file_name, expected in TILE_EXPECTATIONS:
        path = TEXTURES_DIR + file_name
        correct = sum(what_is_it(path, model=model) == expected for _ in range(trials))
        print("Accuracy(%) " + file_name, round(100 * correct / trials))


@functools.lru_cache(maxsize=8)
def _load_model(checkpoint_path):
    """Load the Lightning checkpoint at *checkpoint_path* once and keep it in eval mode."""
    model = NeuralNetwork.load_from_checkpoint(checkpoint_path)
    model.eval()
    return model


def what_is_it(img_path, show_img=False, *, model=None, checkpoint_path=CHECKPOINT_PATH):
    """Classify the image at *img_path* and return its class name.

    Args:
        img_path: path of the image file; it is read as RGB and passed
            through ``SETUP_PHOTOS``.
        show_img: if true, display the image with matplotlib first.
        model: network to use.  When ``None``, the checkpoint at
            *checkpoint_path* is loaded; the load is cached across calls.
        checkpoint_path: Lightning checkpoint used when *model* is ``None``.

    Returns:
        The ``ID_TO_CLASS`` entry for the highest-scoring output.
    """
    image = read_image(img_path, mode=ImageReadMode.RGB)
    if show_img:
        plt.imshow(plt.imread(img_path))
        plt.show()
    # Add a leading batch dimension of size 1.
    batch = SETUP_PHOTOS(image).unsqueeze(0)
    if model is None:
        model = _load_model(checkpoint_path)
    # Put the input on whatever device the weights live on.
    batch = batch.to(next(model.parameters()).device)
    was_training = model.training
    model.eval()
    with torch.no_grad():
        idx = int(model(batch).argmax(dim=1))
    model.train(was_training)
    return ID_TO_CLASS[idx]


if __name__ == "__main__":
    CNN = NeuralNetwork()
    common.helpers.createCSV()
    # Alternative trainer: stop once the logged 'val_loss' stops improving.
    # trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
    # NOTE(review): in PyTorch Lightning 1.x, auto_lr_find only takes effect
    # through trainer.tune() (commented out below); trainer.fit() alone ignores
    # it, and Lightning 2.0 removed the argument. Confirm the installed version.
    trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
    train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
    # The test split doubles as the validation dataloader.
    trainer.fit(CNN, train_loader, test_loader)
    # trainer.tune(CNN, train_loader, test_loader)
    # check_accuracy_tiles()
    # print(what_is_it('../../resources/textures/sand.png', True))