Fix train

This commit is contained in:
wojciechbatruszewicz 2023-06-27 15:19:55 +02:00
parent 8e066818ea
commit 91decc353d
3 changed files with 7 additions and 3 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
gender_classification_v7.csv
gender_classification_val.csv
gender_classification_test.csv
gender_classification_train.csv

BIN
model.pt Normal file

Binary file not shown.

View File

@ -11,11 +11,11 @@ import argparse
class MyNeuralNetwork(nn.Module): class MyNeuralNetwork(nn.Module):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super(MyNeuralNetwork, self).__init__(*args, **kwargs) super(MyNeuralNetwork, self).__init__(*args, **kwargs)
self.fc1 = nn.Linear(32, 7) self.fc1 = nn.Linear(7, 12)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.fc1 = nn.Linear(12, 64) self.fc1 = nn.Linear(7, 12)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.fc2 = nn.Linear(64, 1) self.fc2 = nn.Linear(12, 1)
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):