ium_464915/test.py
2024-04-28 18:57:42 +02:00

84 lines
2.5 KiB
Python

import csv
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
class Model(nn.Module):
    """Small CNN classifier that returns log-probabilities over two classes.

    Three stages of (3x3 conv, stride 1) -> (2x2 max-pool) reduce a 224x224
    input to a 128 x 26 x 26 feature map (224 -> 222 -> 111 -> 109 -> 54 ->
    52 -> 26), which is flattened and fed to ``fc1``.

    Attribute names double as ``state_dict`` keys, so they must not be renamed
    or previously saved checkpoints will no longer load.
    """

    # Number of features entering fc1 for a 224x224 input (see class docstring).
    FLAT_FEATURES = 128 * 26 * 26

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        self.fc1 = nn.Linear(self.FLAT_FEATURES, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return log-softmax scores of shape ``(batch, 2)`` for image batch ``x``.

        ``x`` is expected as ``(batch, 3, H, W)`` with H and W reducing to a
        26x26 feature map (e.g. 224x224). Other sizes raise a RuntimeError in
        ``fc1`` instead of being silently reshaped into a different batch size.
        """
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample, keeping the batch dimension intact. The previous
        # view(-1, FLAT_FEATURES) could merge several small images into one row.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def get_data(IMG_SIZE: int, BATCH_SIZE: int, root: str = "./test"):
    """Build a non-shuffled DataLoader over an ImageFolder test set.

    Args:
        IMG_SIZE: side length (pixels) that every image is resized to; output
            images are square.
        BATCH_SIZE: number of images per batch.
        root: dataset directory laid out as ``root/<class_name>/<image>``.
            Defaults to ``"./test"`` so existing callers are unaffected.

    Returns:
        A DataLoader yielding ``(images, targets)`` batches in dataset order.
    """
    testTransformer = transforms.Compose([
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std. Presumably the same values were
        # used at training time; confirm against the training script.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    testSet = datasets.ImageFolder(root=root, transform=testTransformer)
    # shuffle=False keeps batch order identical to testSet.targets, so
    # predictions can be paired with labels by position.
    return DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False)
if __name__ == '__main__':
    # Evaluate a saved Model checkpoint on ./test, print the accuracy and
    # write per-image predictions next to ground-truth labels to a CSV file.
    IMG_SIZE = 224
    BATCH_SIZE = 32
    MODEL_PATH = 'model.pth'
    names = {0: 'Benign', 1: 'Malignant'}
    predictions = []
    test_loader = get_data(IMG_SIZE, BATCH_SIZE)
    # The loader is not shuffled, so this list lines up with the predictions.
    labels = [names[target] for target in test_loader.dataset.targets]
    model = Model()
    # map_location='cpu': inference below runs on CPU, and a checkpoint saved
    # from a GPU would otherwise fail to load on a CPU-only machine.
    # weights_only=True: only tensors are expected in a state_dict.
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu', weights_only=True))
    # eval() makes BatchNorm use its stored running statistics instead of
    # per-batch statistics, and stops it from updating them during testing.
    model.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for images, targets in test_loader:
            output = model(images)
            predicted = output.argmax(dim=1)
            test_total += targets.size(0)
            test_correct += (predicted == targets).sum().item()
            predictions.extend(predicted.tolist())
    test_accuracy = test_correct / test_total
    predictions = [names[pred] for pred in predictions]
    print(f'Accuracy test {test_accuracy:.2%}')
    with open('predictions.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Predictions", "Labels"])
        writer.writerows(zip(predictions, labels))