fix: small change to main_network

This commit is contained in:
Jager72 2024-06-10 20:30:49 +02:00
parent da938193c3
commit 5932ba1f93
1 changed files with 20 additions and 6 deletions

View File

@ -41,6 +41,25 @@ hidden_size = 135*64
#żeby to skończyć kiedykolwiek i jak uda mi się to CUDA zrobić to zmieniam rozdziałke z 1024 na 128
#Użycie konwolucyjnej warstwy dobiło mi z 85 do 90
# model = torch.nn.Sequential(
# torch.nn.Conv2d(3, 6, 5),
# torch.nn.MaxPool2d(2, 2),
# torch.nn.Conv2d(6, 16, 5),
# torch.nn.Flatten(),
# torch.nn.Linear(53824, hidden_size),
# torch.nn.ReLU(),
# torch.nn.Linear(hidden_size, 32*32),
# torch.nn.ReLU(),
# torch.nn.Linear(32*32, 10),
# torch.nn.LogSoftmax(dim=-1)
# ).to(device)
# train(model, trainset)
# print(accuracy(model, testset))
# torch.save(model.state_dict(), 'model.pt')
model = torch.nn.Sequential(
torch.nn.Conv2d(3, 6, 5),
torch.nn.MaxPool2d(2, 2),
@ -53,9 +72,4 @@ model = torch.nn.Sequential(
torch.nn.Linear(32*32, 10),
torch.nn.LogSoftmax(dim=-1)
).to(device)
train(model, trainset)
print(accuracy(model, testset))
# torch.save(model.state_dict(), 'model.pt')
model.load_state_dict(torch.load('model.pt'))