updat
This commit is contained in:
parent
f3e75f5e1f
commit
616f7be1c8
4134
dev-0/out.tsv
4134
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
2
main.py
2
main.py
@ -19,7 +19,7 @@ OUT_HEADER_FILE_NAME = "out-header.tsv"
|
||||
|
||||
# Model training config
|
||||
BATCH_SIZE = 5
|
||||
EPOCHS = 30
|
||||
EPOCHS = 15
|
||||
THRESHOLD = 0.5
|
||||
|
||||
# Model dimensions
|
||||
|
7
model.py
7
model.py
@ -16,10 +16,11 @@ class Model(nn.Module):
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
|
||||
self.fc2 = nn.Linear(self.hidden_dim, self.output_dim)
|
||||
self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
|
||||
self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)
|
||||
|
||||
self.criterion = nn.BCELoss()
|
||||
self.optimizer = torch.optim.SGD(self.parameters(), lr=0.02)
|
||||
self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
|
||||
|
||||
def forward(self, x):
|
||||
"""Step forward learning fn"""
|
||||
@ -27,6 +28,8 @@ class Model(nn.Module):
|
||||
x = self.fc1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = torch.relu(x)
|
||||
x = self.fc3(x)
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
5136
test-A/out.tsv
5136
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user