gru added
This commit is contained in:
parent
adc78b066c
commit
81fa0ec07f
19
seq.py
19
seq.py
@ -80,18 +80,17 @@ def eval_model(dataset_tokens, dataset_labels, model):
|
||||
return get_scores(Y_true, Y_pred)
|
||||
|
||||
|
||||
class LSTM(torch.nn.Module):
|
||||
class GRU(torch.nn.Module):
|
||||
def __init__(self, vocab_len):
|
||||
super(LSTM, self).__init__()
|
||||
super(GRU, self).__init__()
|
||||
self.emb = torch.nn.Embedding(vocab_len, 100)
|
||||
self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)
|
||||
self.rec = torch.nn.GRU(100, 256, 2, batch_first=True, dropout=0.2)
|
||||
self.fc1 = torch.nn.Linear(256, 9)
|
||||
|
||||
def forward(self, x):
|
||||
emb = torch.relu(self.emb(x))
|
||||
lstm_output, (h_n, c_n) = self.rec(emb)
|
||||
out_weights = self.fc1(lstm_output)
|
||||
|
||||
gru_output, h_n = self.rec(emb)
|
||||
out_weights = self.fc1(gru_output)
|
||||
return out_weights
|
||||
|
||||
# Load data
|
||||
@ -126,7 +125,7 @@ def train(model, crf, train_tokens, labels_tokens):
|
||||
predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))
|
||||
loss = -crf(predicted_tags, tags)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@ -138,12 +137,12 @@ if __name__ == "__main__":
|
||||
vocab_y = build_vocab(Y_train)
|
||||
train_tokens = data_process(X_train, vocab_x)
|
||||
labels_tokens = data_process(Y_train, vocab_y)
|
||||
print(train_tokens[0])
|
||||
|
||||
# model
|
||||
model = LSTM(len(vocab_x))
|
||||
model = GRU(len(vocab_x))
|
||||
print(model)
|
||||
crf = CRF(9)
|
||||
p = list(model.parameters()) + list(crf.parameters())
|
||||
optimizer = torch.optim.Adam(p)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
# mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size)
|
||||
train(model, crf, train_tokens, labels_tokens)
|
||||
|
Loading…
Reference in New Issue
Block a user