cnn #28

Merged
s464869 merged 10 commits from cnn into master 2022-05-25 19:57:09 +02:00
Showing only changes of commit 431113a04c - Show all commits

View File

@ -6,6 +6,7 @@ from neural_network import NeuralNetwork
from torchvision.io import read_image, ImageReadMode from torchvision.io import read_image, ImageReadMode
import torch.nn as nn import torch.nn as nn
from torch.optim import Adam from torch.optim import Adam
import matplotlib.pyplot as plt
CNN = NeuralNetwork().to(device) CNN = NeuralNetwork().to(device)
@ -64,8 +65,11 @@ def check_accuracy(loader):
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}") print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
def what_is_it(img_path): def what_is_it(img_path, show_img=False):
image = read_image(img_path, mode=ImageReadMode.RGB) image = read_image(img_path, mode=ImageReadMode.RGB)
if show_img:
plt.imshow(plt.imread(img_path))
plt.show()
image = setup_photos(image).unsqueeze(0) image = setup_photos(image).unsqueeze(0)
model = NeuralNetwork() model = NeuralNetwork()
@ -79,4 +83,4 @@ def what_is_it(img_path):
return id_to_class[idx] return id_to_class[idx]
print(what_is_it('./data/test/water/water.png')) print(what_is_it('./data/test/sand/sand.png', True))