forked from kubapok/en-ner-conll-2003
34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
|
import torch
|
||
|
from torchcrf import CRF
|
||
|
|
||
|
class Model(torch.nn.Module):
    """GRU + CRF sequence tagger.

    Token ids are embedded, passed through a ReLU and a single-layer
    unidirectional GRU (``batch_first=True``), projected to per-tag emission
    scores, and then scored / decoded by a linear-chain CRF (``torchcrf.CRF``).

    Args:
        num_tags: Number of output tags. Sizes both the emission projection
            and the CRF.
        seq_length: Only used to size the (unused) ``fc1`` layer; kept so the
            constructor signature and ``state_dict`` keys stay unchanged.
        vocab: Vocabulary object exposing ``get_itos()``; the length of that
            list sets the embedding table size.
        emb_dim: Embedding size (default 100, the previously hard-coded value).
        hidden_dim: GRU hidden size (default 256, the previously hard-coded
            value).
    """

    def __init__(self, num_tags, seq_length, vocab,
                 emb_dim=100, hidden_dim=256):
        super().__init__()
        self.emb = torch.nn.Embedding(len(vocab.get_itos()), emb_dim)
        self.gru = torch.nn.GRU(emb_dim, hidden_dim, 1, batch_first=True)
        # Previously hard-coded to 9 outputs, so any num_tags != 9 produced
        # emissions the CRF would reject. Size it from num_tags instead.
        self.hidden2tag = torch.nn.Linear(hidden_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)
        self.relu = torch.nn.ReLU()
        # The layers below are not used by forward()/decode(). They are kept
        # so previously saved state_dicts (which contain fc1.*) still load and
        # any external attribute access keeps working.
        self.fc1 = torch.nn.Linear(1, seq_length)
        self.softmax = torch.nn.Softmax(dim=0)
        self.sigm = torch.nn.Sigmoid()

    def _emissions(self, data):
        """Compute per-token tag emission scores for a batch of token ids."""
        # GRU is batch_first, so data is presumably (batch, seq) token ids
        # -- TODO confirm against the caller.
        embedded = self.relu(self.emb(data))
        out, _ = self.gru(embedded)
        return self.hidden2tag(out)

    def forward(self, data, tags):
        """Return the negated CRF log-likelihood of ``tags`` given ``data``.

        The result is meant to be minimized as a training loss. ``tags`` is
        transposed (``tags.T``) before being passed to the CRF, exactly as in
        the original code.
        NOTE(review): the caller presumably supplies tags as (seq, batch) so
        that ``tags.T`` matches the emissions' leading dims -- confirm.
        """
        return -self.crf(self._emissions(data), tags.T)

    def decode(self, data):
        """Return the best-scoring tag sequence for each sequence in ``data``.

        Delegates to ``CRF.decode`` on the emission scores.
        """
        return self.crf.decode(self._emissions(data))

    def train_mode(self):
        """Put the whole model into training mode."""
        # The original toggled only self.crf; switching every submodule keeps
        # modes consistent (e.g. if dropout is added later). No visible layer
        # behaves differently between modes today, so outputs are unchanged.
        self.train()

    def eval_mode(self):
        """Put the whole model into evaluation mode."""
        self.eval()