From ba715dfce1418c6e0c64cd70760132ddaaa55b5e Mon Sep 17 00:00:00 2001 From: filnow Date: Sun, 28 Apr 2024 19:10:12 +0200 Subject: [PATCH] import model --- .gitignore | 4 ++-- test.py | 27 +-------------------------- 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 5878d1d..7cf8443 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,6 @@ ipython_config.py /train/ /test/ -# Remove previous ipynb_checkpoints -# git rm -r .ipynb_checkpoints/ + +__pycache__/ diff --git a/test.py b/test.py index 2fa7455..e7ae738 100644 --- a/test.py +++ b/test.py @@ -1,35 +1,10 @@ import csv import torch -import torch.nn.functional as F -from torch import nn from torchvision import transforms, datasets from torch.utils.data import DataLoader +from train import Model - -class Model(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 32, 3, 1) - self.batchnorm1 = nn.BatchNorm2d(32) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.batchnorm2 = nn.BatchNorm2d(64) - self.conv3 = nn.Conv2d(64, 128, 3, 1) - self.fc1 = nn.Linear(128*26*26, 128) - self.fc2 = nn.Linear(128, 2) - - def forward(self, x): - x = F.relu(self.batchnorm1(self.conv1(x))) - x = F.max_pool2d(x, 2, 2) - x = F.relu(self.batchnorm2(self.conv2(x))) - x = F.max_pool2d(x, 2, 2) - x = F.relu(self.conv3(x)) - x = F.max_pool2d(x, 2, 2) - x = x.view(-1, 128*26*26) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - - return F.log_softmax(x, dim=1) def get_data(IMG_SIZE: int, BATCH_SIZE: int): testTransformer = transforms.Compose([