"""Train, evaluate and query the water/sand/tree/grass image classifier."""

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.io import read_image, ImageReadMode

from common.constants import device, batch_size, num_epochs, learning_rate, setup_photos, id_to_class
from neural_network import NeuralNetwork
from watersandtreegrass import WaterSandTreeGrass

# File that train() writes the learned weights to and that check_accuracy()
# and what_is_it() read them back from.
MODEL_PATH = "./learnedNetwork.pt"

CNN = NeuralNetwork().to(device)


def _load_model():
    """Build a fresh NeuralNetwork, load the weights stored at MODEL_PATH and move it to ``device``."""
    model = NeuralNetwork()
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    return model.to(device)


def train(model):
    """Train ``model`` on the training CSV, save its weights, then report training accuracy.

    Runs ``num_epochs`` epochs of Adam + cross-entropy over
    ``./data/train_csv_file.csv``, printing the last batch's loss on every
    even epoch. The learned ``state_dict`` is written to MODEL_PATH.

    Args:
        model: the network to train; presumably a ``NeuralNetwork`` already
            moved to ``device`` (e.g. the module-level ``CNN``) - confirm.
    """
    model.train()
    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Stays None if the loader yields no batches, so the progress print
    # below cannot hit an unbound name.
    loss = None
    for epoch in range(num_epochs):
        for data, targets in train_loader:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0 and loss is not None:
            # Reports the loss of the final batch of this epoch only.
            print(f"epoch: {epoch:3d} loss: {loss.item():.4f}")

    print("FINISHED!")
    # Save BEFORE evaluating: check_accuracy() reloads the weights from
    # MODEL_PATH, so it must see the model just trained rather than a
    # missing or stale file from a previous run.
    torch.save(model.state_dict(), MODEL_PATH)
    print("Checking accuracy.")
    check_accuracy(train_loader)


def check_accuracy(loader):
    """Print the classification accuracy of the saved model over ``loader``.

    The weights are reloaded from MODEL_PATH into a fresh network, so this
    evaluates whatever was last saved, not an in-memory model.

    Args:
        loader: iterable of ``(inputs, labels)`` batches.
    """
    model = _load_model()
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += int((predictions == y).sum())
            num_samples += predictions.size(0)

    if num_samples == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print("No samples to check accuracy on.")
        return
    print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")


def what_is_it(img_path, show_img=False):
    """Classify a single image file with the saved model.

    Args:
        img_path: path to the image; it is read as RGB.
        show_img: if True, display the image with matplotlib before classifying.

    Returns:
        The ``id_to_class`` entry for the highest-scoring class index.
    """
    image = read_image(img_path, mode=ImageReadMode.RGB)
    if show_img:
        plt.imshow(plt.imread(img_path))
        plt.show()
    # Add a leading batch dimension of size 1 so one image can be fed to the network.
    image = setup_photos(image).unsqueeze(0).to(device)

    model = _load_model()
    model.eval()
    with torch.no_grad():
        idx = int(model(image).argmax(dim=1))
    return id_to_class[idx]


if __name__ == "__main__":
    # Only run the demo prediction when executed as a script, not on import.
    print(what_is_it('./data/test/sand/sand.png', True))