cnn #28
@ -6,6 +6,7 @@ from neural_network import NeuralNetwork
|
|||||||
from torchvision.io import read_image, ImageReadMode
|
from torchvision.io import read_image, ImageReadMode
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
CNN = NeuralNetwork().to(device)
|
CNN = NeuralNetwork().to(device)
|
||||||
|
|
||||||
@ -64,8 +65,11 @@ def check_accuracy(loader):
|
|||||||
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
|
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
|
||||||
|
|
||||||
|
|
||||||
def what_is_it(img_path):
|
def what_is_it(img_path, show_img=False):
|
||||||
image = read_image(img_path, mode=ImageReadMode.RGB)
|
image = read_image(img_path, mode=ImageReadMode.RGB)
|
||||||
|
if show_img:
|
||||||
|
plt.imshow(plt.imread(img_path))
|
||||||
|
plt.show()
|
||||||
image = setup_photos(image).unsqueeze(0)
|
image = setup_photos(image).unsqueeze(0)
|
||||||
model = NeuralNetwork()
|
model = NeuralNetwork()
|
||||||
|
|
||||||
@ -79,4 +83,4 @@ def what_is_it(img_path):
|
|||||||
return id_to_class[idx]
|
return id_to_class[idx]
|
||||||
|
|
||||||
|
|
||||||
print(what_is_it('./data/test/water/water.png'))
|
print(what_is_it('./data/test/sand/sand.png', True))
|
||||||
|
Loading…
Reference in New Issue
Block a user