cnn #28
@ -6,6 +6,7 @@ from neural_network import NeuralNetwork
|
||||
from torchvision.io import read_image, ImageReadMode
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
CNN = NeuralNetwork().to(device)
|
||||
|
||||
@ -64,8 +65,11 @@ def check_accuracy(loader):
|
||||
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}")
|
||||
|
||||
|
||||
def what_is_it(img_path):
|
||||
def what_is_it(img_path, show_img=False):
|
||||
image = read_image(img_path, mode=ImageReadMode.RGB)
|
||||
if show_img:
|
||||
plt.imshow(plt.imread(img_path))
|
||||
plt.show()
|
||||
image = setup_photos(image).unsqueeze(0)
|
||||
model = NeuralNetwork()
|
||||
|
||||
@ -79,4 +83,4 @@ def what_is_it(img_path):
|
||||
return id_to_class[idx]
|
||||
|
||||
|
||||
print(what_is_it('./data/test/water/water.png'))
|
||||
print(what_is_it('./data/test/sand/sand.png', True))
|
||||
|
Loading…
Reference in New Issue
Block a user