# en-ner-conll-2003/model.py
# (source-viewer metadata: 34 lines, 1.0 KiB, Python)

import torch
from torchcrf import CRF
class Model(torch.nn.Module):
    """Embedding -> GRU -> linear emission scores, decoded with a CRF.

    Token ids are embedded, passed through a ReLU and a single-layer
    batch-first GRU, then projected to one score per tag. A
    ``CRF(num_tags, batch_first=True)`` turns those emission scores into a
    training loss (:meth:`forward`) or a best-path tag sequence
    (:meth:`decode`).

    Args:
        num_tags: Number of output tags. Sizes both the emission projection
            and the CRF.
        seq_length: Only used to size the unused ``fc1`` layer (kept for
            checkpoint compatibility).
        vocab: Object exposing ``get_itos()``; its length sets the embedding
            table size (presumably a torchtext ``Vocab`` -- confirm).
        embedding_dim: Embedding size (default 100, the original value).
        hidden_dim: GRU hidden size (default 256, the original value).
    """

    def __init__(self, num_tags, seq_length, vocab,
                 embedding_dim=100, hidden_dim=256):
        super().__init__()
        self.emb = torch.nn.Embedding(len(vocab.get_itos()), embedding_dim)
        self.gru = torch.nn.GRU(embedding_dim, hidden_dim, 1, batch_first=True)
        # Project to num_tags (previously hard-coded to 9): the CRF rejects
        # emissions whose last dimension differs from its own num_tags.
        self.hidden2tag = torch.nn.Linear(hidden_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)
        self.relu = torch.nn.ReLU()
        # NOTE(review): fc1, softmax and sigm are never used by forward/decode.
        # They are kept so state_dict keys (fc1.weight / fc1.bias) still match
        # checkpoints saved from the original model; drop them once those
        # checkpoints are retired.
        self.fc1 = torch.nn.Linear(1, seq_length)
        self.softmax = torch.nn.Softmax(dim=0)
        self.sigm = torch.nn.Sigmoid()

    def _emissions(self, data):
        """Compute per-token tag scores: embed -> ReLU -> GRU -> linear."""
        embedded = self.relu(self.emb(data))
        gru_out, _ = self.gru(embedded)
        return self.hidden2tag(gru_out)

    def forward(self, data, tags):
        """Return the negated CRF log-likelihood of ``tags`` given ``data``.

        Args:
            data: Token-id tensor for the batch-first GRU
                (presumably (batch, seq) -- confirm against caller).
            tags: Gold tag indices. They are transposed before reaching the
                batch-first CRF, so they are presumably (seq, batch) --
                confirm against caller.

        Returns:
            The CRF log-likelihood negated, i.e. a loss to minimize (uses the
            CRF's default reduction).
        """
        # NOTE(review): no mask is passed, so any padded positions are scored
        # as real tokens.
        return -self.crf(self._emissions(data), tags.T)

    def decode(self, data):
        """Return the CRF's best tag sequence for each item in ``data``."""
        return self.crf.decode(self._emissions(data))

    def train_mode(self):
        """Switch only the CRF submodule to training mode."""
        self.crf.train()

    def eval_mode(self):
        """Switch only the CRF submodule to evaluation mode."""
        self.crf.eval()