This commit is contained in:
nlitkowski 2021-05-26 02:33:40 +02:00
parent d0d7934292
commit 53fd98388c
2 changed files with 298 additions and 290 deletions

File diff suppressed because it is too large Load Diff

View File

@ -16,7 +16,10 @@ class Model(nn.Module):
self.output_dim = output_dim self.output_dim = output_dim
self.fc1 = nn.Linear(self.input_dim, self.hidden_dim) self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.output_dim) self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)
self.relu = nn.ReLU()
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01) self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
@ -25,14 +28,19 @@ class Model(nn.Module):
"""Step forward learning fn""" """Step forward learning fn"""
x = self.fc1(x) x = self.fc1(x)
x = torch.relu(x) x = self.relu(x)
x = self.fc2(x) x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
x = torch.sigmoid(x) x = torch.sigmoid(x)
return x return x
def run_training(self, X_train, Y_train, batch_size, epochs_count): def run_training(self, X_train, Y_train, batch_size, epochs_count):
for _ in range(epochs_count): for _ in range(epochs_count):
self.train() self.train()
print(f"{Y_train.shape[0]}, {Y_train.shape[0] == self.input_dim}")
print(f"{Y_train.shape[0]}, {Y_train.shape[0] == self.hidden_dim}")
print(f"{Y_train.shape[0]}, {Y_train.shape[0] == self.output_dim}")
for i in range(0, Y_train.shape[0], batch_size): for i in range(0, Y_train.shape[0], batch_size):
X = X_train[i: i + batch_size] X = X_train[i: i + batch_size]
X = torch.tensor(X) X = torch.tensor(X)