From 81fa0ec07f07877f12b89cffa2d779898095a282 Mon Sep 17 00:00:00 2001 From: Maciej Sobkowiak Date: Tue, 22 Jun 2021 02:24:08 +0200 Subject: [PATCH] gru added --- seq.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/seq.py b/seq.py index 847ea74..f47633e 100644 --- a/seq.py +++ b/seq.py @@ -80,18 +80,17 @@ def eval_model(dataset_tokens, dataset_labels, model): return get_scores(Y_true, Y_pred) -class LSTM(torch.nn.Module): +class GRU(torch.nn.Module): def __init__(self, vocab_len): - super(LSTM, self).__init__() + super(GRU, self).__init__() self.emb = torch.nn.Embedding(vocab_len, 100) - self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True) + self.rec = torch.nn.GRU(100, 256, 2, batch_first=True, dropout=0.2) self.fc1 = torch.nn.Linear(256, 9) def forward(self, x): emb = torch.relu(self.emb(x)) - lstm_output, (h_n, c_n) = self.rec(emb) - out_weights = self.fc1(lstm_output) - + gru_output, h_n = self.rec(emb) + out_weights = self.fc1(gru_output) return out_weights # Load data @@ -126,7 +125,7 @@ def train(model, crf, train_tokens, labels_tokens): predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1) optimizer.zero_grad() - loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1)) + loss = -crf(predicted_tags, tags) loss.backward() optimizer.step() @@ -138,12 +137,12 @@ if __name__ == "__main__": vocab_y = build_vocab(Y_train) train_tokens = data_process(X_train, vocab_x) labels_tokens = data_process(Y_train, vocab_y) - print(train_tokens[0]) # model - model = LSTM(len(vocab_x)) + model = GRU(len(vocab_x)) + print(model) crf = CRF(9) p = list(model.parameters()) + list(crf.parameters()) optimizer = torch.optim.Adam(p) - criterion = torch.nn.CrossEntropyLoss() + # mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size) train(model, crf, train_tokens, labels_tokens)