Computer_Vision/Chapter17/fmnist.py
2024-02-13 03:34:51 +01:00

36 lines
1.1 KiB
Python

from torch_snippets import *
# Pick the compute device once at import time: use the GPU when PyTorch
# reports that CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class FMNIST(nn.Module):
    """Single-image classifier: a two-layer MLP with pre-trained weights.

    The network maps a flattened 28x28 input (784 values) to 10 logits, one
    per entry of ``classes``. Weights are loaded from disk at construction.
    """

    # Label for each output index of the final linear layer.
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress',
               'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    def __init__(self, weights_path='fmnist.weights.pth'):
        """Build the network and load its weights.

        Args:
            weights_path: Path of the ``state_dict`` checkpoint for the
                network. A relative path is resolved against the current
                working directory.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1000),
            nn.ReLU(),
            nn.Linear(1000, 10)
        ).to(device)
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source. Consider `weights_only=True`
        # if the installed torch version supports it.
        state = torch.load(weights_path, map_location=device)
        self.model.load_state_dict(state)
        logger.info('Loaded FMNIST Model')

    def forward(self, x):
        """Classify one image tensor.

        Args:
            x: Tensor holding a single image; it is flattened to one row, so
                it must contain exactly 28 * 28 values.

        Returns:
            Tuple ``(confidence, class_name)``: the highest softmax
            probability as a Python float and the matching ``classes`` label.
        """
        # Follow the device the parameters actually live on, so the module
        # keeps working if it is moved with `.to(...)` after construction.
        target = next(self.model.parameters()).device
        # reshape (not view) also accepts non-contiguous input tensors.
        x = x.reshape(1, -1).to(target)
        # Only plain Python values are returned, so no autograd graph is needed.
        with torch.no_grad():
            probs = F.softmax(self.model(x), -1)[0]
        conf, idx = probs.max(-1)
        return conf.item(), self.classes[idx.item()]

    def predict_from_path(self, path):
        """Read the image at ``path`` as grayscale and classify it.

        Args:
            path: Location of the image file.

        Returns:
            The dict produced by ``predict_from_image``.

        Raises:
            OSError: If OpenCV cannot read an image from ``path``.
        """
        x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if x is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise OSError(f'Could not read image from {path!r}')
        return self.predict_from_image(x)

    def predict_from_image(self, image):
        """Classify an in-memory image.

        Args:
            image: Anything ``np.array`` can turn into an image array. It
                must be single-channel: the network takes 28 * 28 inputs and
                ``forward`` flattens every value it is given.

        Returns:
            Dict with ``'class'`` (label string) and ``'confidence'`` (the
            probability formatted to four decimal places, as a string).
        """
        x = np.array(image)
        x = cv2.resize(x, (28, 28))
        # Invert intensities and scale by 1/255.
        # NOTE(review): presumably matches the polarity/range used during
        # training (bright object on dark background) -- confirm.
        x = torch.Tensor(255 - x) / 255.
        conf, clss = self(x)
        return {'class': clss, 'confidence': f'{conf:.4f}'}