small fixes
This commit is contained in:
parent
b5d25d710d
commit
c363b09f85
BIN
source/NN/__pycache__/model.cpython-311.pyc
Normal file
BIN
source/NN/__pycache__/model.cpython-311.pyc
Normal file
Binary file not shown.
@ -1,4 +1,6 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class Neural_Network_Model(nn.Module):
|
class Neural_Network_Model(nn.Module):
|
||||||
@ -16,5 +18,4 @@ class Neural_Network_Model(nn.Module):
|
|||||||
x = self.fc2(x)
|
x = self.fc2(x)
|
||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = self.out(x)
|
x = self.out(x)
|
||||||
F.log_softmax(x, dim=-1)
|
return F.log_softmax(x, dim=-1)
|
||||||
return x
|
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms, utils
|
||||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from model import *
|
from model import *
|
||||||
|
|
||||||
|
device = torch.device('cuda')
|
||||||
|
|
||||||
#data transform to tensors:
|
#data transform to tensors:
|
||||||
data_transformer = transforms.Compose
|
data_transformer = transforms.Compose([
|
||||||
([
|
|
||||||
transforms.Resize((150, 150)),
|
transforms.Resize((150, 150)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
@ -46,6 +47,9 @@ def train(model, dataset, iter=100, batch_size=64):
|
|||||||
loss = criterion(output, labels.to(device))
|
loss = criterion(output, labels.to(device))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
if epoch % 10 == 0:
|
||||||
|
print('epoch: %3d loss: %.4f' % (epoch, loss))
|
||||||
|
|
||||||
#function for getting accuracy
|
#function for getting accuracy
|
||||||
def accuracy(model, dataset):
|
def accuracy(model, dataset):
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -57,6 +61,9 @@ def accuracy(model, dataset):
|
|||||||
return correct.float() / len(dataset)
|
return correct.float() / len(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model = Neural_Network_Model()
|
model = Neural_Network_Model()
|
||||||
|
model.to(device)
|
||||||
train(model, train_set)
|
train(model, train_set)
|
||||||
print(accuracy(model, test_set))
|
print(accuracy(model, test_set))
|
||||||
|
Loading…
Reference in New Issue
Block a user