gru added

parent adc78b066c
commit 81fa0ec07f

seq.py (19 lines changed)
@@ -80,18 +80,17 @@ def eval_model(dataset_tokens, dataset_labels, model):
     return get_scores(Y_true, Y_pred)
 
-
-class LSTM(torch.nn.Module):
+class GRU(torch.nn.Module):
     def __init__(self, vocab_len):
-        super(LSTM, self).__init__()
+        super(GRU, self).__init__()
         self.emb = torch.nn.Embedding(vocab_len, 100)
-        self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)
+        self.rec = torch.nn.GRU(100, 256, 2, batch_first=True, dropout=0.2)
         self.fc1 = torch.nn.Linear(256, 9)
 
     def forward(self, x):
         emb = torch.relu(self.emb(x))
-        lstm_output, (h_n, c_n) = self.rec(emb)
-        out_weights = self.fc1(lstm_output)
+        gru_output, h_n = self.rec(emb)
+        out_weights = self.fc1(gru_output)
 
         return out_weights
 
 # Load data
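Note: assembled from the hunk above, the module after this commit reads as follows (a sketch; it assumes the usual import torch at the top of seq.py, which is outside this diff). Two details worth calling out: torch.nn.GRU returns (output, h_n) where torch.nn.LSTM returned (output, (h_n, c_n)), and torch.nn.GRU only applies dropout between stacked layers, which is presumably why num_layers was bumped from 1 to 2 along with dropout=0.2.

import torch

class GRU(torch.nn.Module):
    def __init__(self, vocab_len):
        super(GRU, self).__init__()
        self.emb = torch.nn.Embedding(vocab_len, 100)  # 100-dim token embeddings
        # dropout only acts between stacked recurrent layers, hence num_layers=2
        self.rec = torch.nn.GRU(100, 256, 2, batch_first=True, dropout=0.2)
        self.fc1 = torch.nn.Linear(256, 9)  # per-token scores over 9 tags

    def forward(self, x):
        emb = torch.relu(self.emb(x))
        gru_output, h_n = self.rec(emb)  # no cell state, unlike the old LSTM
        out_weights = self.fc1(gru_output)
        return out_weights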
@@ -126,7 +125,7 @@ def train(model, crf, train_tokens, labels_tokens):
         predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)
 
         optimizer.zero_grad()
-        loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))
+        loss = -crf(predicted_tags, tags)
 
         loss.backward()
         optimizer.step()
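Note: the substantive change in this hunk is the loss. Per-token cross-entropy is replaced by the negated CRF log-likelihood, so tag transitions are learned jointly with the per-token scores. Assuming CRF is torchcrf.CRF from the pytorch-crf package (the import is outside this diff), crf(emissions, tags) returns a log-likelihood, hence the minus sign; by default it expects emissions of shape (seq_len, batch, num_tags) and tags of shape (seq_len, batch), which is what the .squeeze(0).unsqueeze(1) reshape above produces for batch size 1. A minimal sketch under those assumptions:

import torch
from torchcrf import CRF  # pytorch-crf; assumed, the import is not in this diff

crf = CRF(9)                          # 9 tags, matching fc1's output size
emissions = torch.randn(12, 1, 9)     # (seq_len, batch, num_tags)
tags = torch.randint(0, 9, (12, 1))   # (seq_len, batch), long dtype
loss = -crf(emissions, tags)          # negated log-likelihood, a scalar
loss.backward()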
@@ -138,12 +137,12 @@ if __name__ == "__main__":
     vocab_y = build_vocab(Y_train)
     train_tokens = data_process(X_train, vocab_x)
     labels_tokens = data_process(Y_train, vocab_y)
-    print(train_tokens[0])
 
     # model
-    model = LSTM(len(vocab_x))
+    model = GRU(len(vocab_x))
+    print(model)
     crf = CRF(9)
     p = list(model.parameters()) + list(crf.parameters())
     optimizer = torch.optim.Adam(p)
-    criterion = torch.nn.CrossEntropyLoss()
+    # mask = torch.ByteTensor([1, 1])  # (batch_size, sequence_size)
     train(model, crf, train_tokens, labels_tokens)
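Note: the diff leaves evaluation untouched; with a CRF in the loop, the learned transition scores only pay off if predictions come from Viterbi decoding rather than a per-token argmax. A sketch, assuming the pytorch-crf API and the same (seq_len, 1, num_tags) emission shape used in train; batch_tokens stands for one encoded sentence. The commented-out mask above could be passed to both crf() and crf.decode() to skip padding, though pytorch-crf expects it as (seq_len, batch), not (batch_size, sequence_size) as that comment suggests.

# Viterbi decoding with the trained CRF; a sketch, assuming pytorch-crf
with torch.no_grad():
    emissions = model(batch_tokens).squeeze(0).unsqueeze(1)  # (seq_len, 1, 9)
    best_paths = crf.decode(emissions)  # list of tag-index lists, one per batch item
    print(best_paths[0])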