84 lines
2.5 KiB
Python
84 lines
2.5 KiB
Python
import csv
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from torch import nn
|
|
from torchvision import transforms, datasets
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
class Model(nn.Module):
    """Three-stage CNN classifier that returns per-class log-probabilities.

    Each stage is a 3x3 convolution (stride 1, no padding) followed by ReLU
    and 2x2 max pooling; the first two stages also apply batch normalization
    before the ReLU. The final feature map is flattened and passed through two
    fully connected layers, and ``log_softmax`` is applied over the class
    dimension.

    Args:
        img_size: Side length of the square input images (3 channels). The
            first fully connected layer is sized from it. The default of 224
            reproduces the original ``128 * 26 * 26`` input features, so
            checkpoints saved with the default configuration still load.
        num_classes: Number of output classes (default 2).

    Raises:
        ValueError: If ``img_size`` is too small to survive the three
            conv + pool stages.
    """

    def __init__(self, img_size: int = 224, num_classes: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)

        side = self._feature_side(img_size)
        if side < 1:
            raise ValueError(
                f"img_size={img_size} is too small for three conv+pool stages"
            )
        self.fc1 = nn.Linear(128 * side * side, 128)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _feature_side(img_size: int) -> int:
        """Return the spatial side length left after the three conv+pool stages."""
        side = img_size
        for _ in range(3):
            # A 3x3 conv without padding trims 2 pixels; a 2x2/stride-2 max pool
            # halves the result (rounding down). Once this drops below 1 it
            # stays non-positive, so checking the final value is sufficient.
            side = (side - 2) // 2
        return side

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, 3, img_size, img_size)``.
        """
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten everything except the batch dimension. Unlike view(-1, N),
        # a wrong input resolution raises at fc1 instead of silently changing
        # the batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
|
|
|
|
def get_data(IMG_SIZE: int, BATCH_SIZE: int, root: str = "./test"):
    """Build a non-shuffled DataLoader over an ImageFolder test set.

    Each image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, converted to a
    tensor, and normalized with fixed per-channel mean/std constants.
    NOTE(review): these presumably match the normalization used at training
    time; verify.

    Args:
        IMG_SIZE: Target side length for the square resize.
        BATCH_SIZE: Number of images per batch.
        root: Directory passed to ``datasets.ImageFolder``. The default
            ``"./test"`` matches the previous hard-coded path.

    Returns:
        A ``DataLoader`` over the ImageFolder dataset with ``shuffle=False``,
        so batches follow the dataset's own ordering.
    """
    test_transform = transforms.Compose([
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    test_set = datasets.ImageFolder(root=root, transform=test_transform)

    # No shuffling: the caller pairs predictions with dataset.targets by position.
    return DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
|
|
if __name__ == '__main__':
    IMG_SIZE = 224
    BATCH_SIZE = 32
    MODEL_PATH = 'model.pth'

    # Maps integer class indices to display names.
    # NOTE(review): presumably ImageFolder's index order (sorted sub-directory
    # names) puts the benign class at 0 and malignant at 1; verify.
    names = {0: 'Benign', 1: 'Malignant'}
    predictions = []

    test_loader = get_data(IMG_SIZE, BATCH_SIZE)
    labels = [names[i] for i in test_loader.dataset.targets]

    model = Model()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
    # eval() makes BatchNorm use its stored running statistics. Without it,
    # inference normalizes with per-batch statistics and mutates running stats.
    model.eval()

    test_correct, test_total = 0, 0
    with torch.no_grad():
        for images, targets in test_loader:
            output = model(images)
            predicted = output.argmax(dim=1)
            test_total += targets.size(0)
            test_correct += (predicted == targets).sum().item()
            predictions.extend(predicted.tolist())

    # Guard against an empty loader instead of dividing by zero.
    test_accuracy = test_correct / test_total if test_total else 0.0

    predictions = [names[pred] for pred in predictions]

    print(f'Accuracy test {test_accuracy:.2%}')

    # Rows line up because the loader is not shuffled (see get_data).
    with open('predictions.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Predictions", "Labels"])
        for pred, label in zip(predictions, labels):
            writer.writerow([pred, label])