ium_464915/test.py

59 lines
1.7 KiB
Python
Raw Normal View History

2024-04-28 18:57:42 +02:00
import csv
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
2024-04-28 19:10:12 +02:00
from train import Model
2024-04-28 18:57:42 +02:00
def get_data(IMG_SIZE: int, BATCH_SIZE: int, root: str = "./test") -> DataLoader:
    """Build a deterministic DataLoader over an ImageFolder-style test set.

    Args:
        IMG_SIZE: Side length in pixels; every image is resized to
            ``IMG_SIZE x IMG_SIZE``.
        BATCH_SIZE: Number of images per batch.
        root: ImageFolder root directory (one sub-directory per class).
            Defaults to ``"./test"``, the path this script originally hard-coded.

    Returns:
        A DataLoader that yields ``(images, class_indices)`` batches in
        dataset order.
    """
    test_transform = transforms.Compose([
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
        transforms.ToTensor(),
        # The standard ImageNet per-channel mean/std values. These presumably
        # must match the normalization used at training time — TODO confirm
        # against train.py.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    test_set = datasets.ImageFolder(root=root, transform=test_transform)
    # shuffle=False keeps batches in dataset order. Predictions are paired
    # with their samples by position, so the order must be deterministic.
    return DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)
def evaluate(model: torch.nn.Module, loader):
    """Run *model* over every batch of *loader* and score its predictions.

    The model is switched to eval mode first, so layers such as dropout and
    batch-norm behave deterministically. Gradients are disabled for the pass.

    Args:
        model: Classifier whose output has class scores along dim 1.
        loader: Iterable of ``(inputs, targets)`` batches, where ``targets``
            holds integer class indices.

    Returns:
        A tuple ``(accuracy, predicted_ids, true_ids)``. ``accuracy`` is the
        fraction of correct predictions. The two lists are class indices,
        aligned position-by-position in the order the loader yielded them.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    predicted_ids, true_ids = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs)
            predicted_ids.extend(outputs.argmax(dim=1).tolist())
            # Collect labels from the same batches as the predictions so the
            # two lists cannot drift out of alignment.
            true_ids.extend(targets.tolist())
    if not true_ids:
        raise ValueError("test loader yielded no samples")
    correct = sum(p == t for p, t in zip(predicted_ids, true_ids))
    return correct / len(true_ids), predicted_ids, true_ids


if __name__ == '__main__':
    IMG_SIZE = 224
    BATCH_SIZE = 32
    MODEL_PATH = 'model.pth'
    names = {0: 'Benign', 1: 'Malignant'}

    test_loader = get_data(IMG_SIZE, BATCH_SIZE)
    model = Model()
    # map_location='cpu': inference below runs on CPU tensors, so a checkpoint
    # saved from a GPU must be remapped to load on CPU-only machines.
    # weights_only=True: torch.load unpickles its input, and a state dict only
    # needs tensors and containers, so arbitrary object loading is refused.
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu', weights_only=True))

    test_accuracy, predicted_ids, true_ids = evaluate(model, test_loader)
    print(f'Accuracy test {test_accuracy:.2%}')

    with open('predictions.csv', 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Predictions", "Labels"])
        writer.writerows(
            (names[pred], names[label]) for pred, label in zip(predicted_ids, true_ids)
        )