gru added

This commit is contained in:
Maciej Sobkowiak 2021-06-22 02:24:08 +02:00
parent adc78b066c
commit 81fa0ec07f

19
seq.py
View File

@ -80,18 +80,17 @@ def eval_model(dataset_tokens, dataset_labels, model):
return get_scores(Y_true, Y_pred) return get_scores(Y_true, Y_pred)
class LSTM(torch.nn.Module):
    """Sequence tagger from the left-hand (pre-commit) diff column.

    Pipeline: embedding -> ReLU -> single-layer LSTM -> linear projection
    applied at every time step.

    Args:
        vocab_len: Number of rows in the embedding table.
        emb_dim: Embedding width (default 100, the original constant).
        hidden_dim: LSTM hidden size (default 256, the original constant).
        num_tags: Size of the per-token output vector (default 9, the
            original constant).
    """

    def __init__(self, vocab_len, emb_dim=100, hidden_dim=256, num_tags=9):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_len, emb_dim)
        self.rec = torch.nn.LSTM(emb_dim, hidden_dim, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        """Return per-token scores for a batch of token-index sequences.

        Because the recurrent layer is built with ``batch_first=True``, an
        input of shape (batch, seq) yields an output of shape
        (batch, seq, num_tags).
        """
        emb = torch.relu(self.emb(x))
        # Only the per-step outputs are needed; the final (h_n, c_n) state
        # is discarded.
        lstm_output, _ = self.rec(emb)
        return self.fc1(lstm_output)


class GRU(torch.nn.Module):
    """Sequence tagger from the right-hand (post-commit) diff column.

    Pipeline: embedding -> ReLU -> stacked GRU -> linear projection applied
    at every time step.

    Args:
        vocab_len: Number of rows in the embedding table.
        emb_dim: Embedding width (default 100, the original constant).
        hidden_dim: GRU hidden size (default 256, the original constant).
        num_tags: Size of the per-token output vector (default 9, the
            original constant).
        num_layers: Number of stacked GRU layers (default 2, as in the
            original).
        dropout: Dropout between stacked GRU layers (default 0.2, as in the
            original). Ignored when ``num_layers`` is 1.
    """

    def __init__(self, vocab_len, emb_dim=100, hidden_dim=256, num_tags=9,
                 num_layers=2, dropout=0.2):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_len, emb_dim)
        # Inter-layer dropout has no effect on a single-layer GRU and makes
        # PyTorch emit a warning, so it is switched off in that case.
        self.rec = torch.nn.GRU(
            emb_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        """Return per-token scores for a batch of token-index sequences.

        Because the recurrent layer is built with ``batch_first=True``, an
        input of shape (batch, seq) yields an output of shape
        (batch, seq, num_tags).
        """
        emb = torch.relu(self.emb(x))
        # Only the per-step outputs are needed; the final hidden state is
        # discarded.
        gru_output, _ = self.rec(emb)
        return self.fc1(gru_output)
# Load data # Load data
@ -126,7 +125,7 @@ def train(model, crf, train_tokens, labels_tokens):
predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1) predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)
optimizer.zero_grad() optimizer.zero_grad()
loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1)) loss = -crf(predicted_tags, tags)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -138,12 +137,12 @@ if __name__ == "__main__":
vocab_y = build_vocab(Y_train) vocab_y = build_vocab(Y_train)
train_tokens = data_process(X_train, vocab_x) train_tokens = data_process(X_train, vocab_x)
labels_tokens = data_process(Y_train, vocab_y) labels_tokens = data_process(Y_train, vocab_y)
print(train_tokens[0])
# model # model
model = LSTM(len(vocab_x)) model = GRU(len(vocab_x))
print(model)
crf = CRF(9) crf = CRF(9)
p = list(model.parameters()) + list(crf.parameters()) p = list(model.parameters()) + list(crf.parameters())
optimizer = torch.optim.Adam(p) optimizer = torch.optim.Adam(p)
criterion = torch.nn.CrossEntropyLoss() # mask = torch.ByteTensor([1, 1]) # (batch_size, sequence_size)
train(model, crf, train_tokens, labels_tokens) train(model, crf, train_tokens, labels_tokens)