This commit is contained in:
nlitkowski 2021-05-26 04:12:04 +02:00
parent f3e75f5e1f
commit 616f7be1c8
4 changed files with 4641 additions and 4638 deletions

File diff suppressed because it is too large Load Diff

View File

@ -19,7 +19,7 @@ OUT_HEADER_FILE_NAME = "out-header.tsv"
# Model training config # Model training config
BATCH_SIZE = 5 BATCH_SIZE = 5
EPOCHS = 30 EPOCHS = 15
THRESHOLD = 0.5 THRESHOLD = 0.5
# Model dimensions # Model dimensions

View File

@ -16,10 +16,11 @@ class Model(nn.Module):
self.output_dim = output_dim self.output_dim = output_dim
self.fc1 = nn.Linear(self.input_dim, self.hidden_dim) self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.output_dim) self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
self.optimizer = torch.optim.SGD(self.parameters(), lr=0.02) self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
def forward(self, x): def forward(self, x):
"""Step forward learning fn""" """Step forward learning fn"""
@ -27,6 +28,8 @@ class Model(nn.Module):
x = self.fc1(x) x = self.fc1(x)
x = torch.relu(x) x = torch.relu(x)
x = self.fc2(x) x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
x = torch.sigmoid(x) x = torch.sigmoid(x)
return x return x

File diff suppressed because it is too large Load Diff