Fix train
This commit is contained in:
parent
8e066818ea
commit
91decc353d
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
gender_classification_v7.csv
|
||||
gender_classification_val.csv
|
||||
gender_classification_test.csv
|
||||
gender_classification_train.csv
|
6
train.py
6
train.py
@ -11,11 +11,11 @@ import argparse
|
||||
class MyNeuralNetwork(nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super(MyNeuralNetwork, self).__init__(*args, **kwargs)
|
||||
self.fc1 = nn.Linear(32, 7)
|
||||
self.fc1 = nn.Linear(7, 12)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc1 = nn.Linear(12, 64)
|
||||
self.fc1 = nn.Linear(7, 12)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(64, 1)
|
||||
self.fc2 = nn.Linear(12, 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
|
Loading…
Reference in New Issue
Block a user