import torch
from torchcrf import CRF


class Model(torch.nn.Module):
    """Sequence tagger: embedding -> ReLU -> single-layer GRU -> linear -> CRF.

    Args:
        num_tags: Number of distinct tags. Used both as the width of the
            emission projection and as the CRF's tag count.
        seq_length: Output width of the auxiliary ``fc1`` layer.
        vocab: Object exposing ``get_itos()``; the length of the returned
            sequence is used as the embedding table size.
    """

    def __init__(self, num_tags: int, seq_length: int, vocab):
        super().__init__()
        self.emb = torch.nn.Embedding(len(vocab.get_itos()), 100)
        self.gru = torch.nn.GRU(100, 256, 1, batch_first=True)
        # The emission width must equal the CRF's tag count; a hard-coded
        # width only works when num_tags happens to match it.
        self.hidden2tag = torch.nn.Linear(256, num_tags)
        self.crf = CRF(num_tags, batch_first=True)
        self.relu = torch.nn.ReLU()
        # NOTE: fc1, softmax and sigm are not used by forward()/decode() in
        # this file. They are kept so attribute names and state_dict keys
        # (fc1 owns parameters) stay unchanged for existing callers.
        self.fc1 = torch.nn.Linear(1, seq_length)
        self.softmax = torch.nn.Softmax(dim=0)
        self.sigm = torch.nn.Sigmoid()

    def _emissions(self, data: torch.Tensor) -> torch.Tensor:
        """Compute per-position tag scores shared by forward() and decode()."""
        # assumes data holds token indices shaped (batch, seq_len), since the
        # GRU is configured with batch_first=True -- TODO confirm with caller
        embedded = self.relu(self.emb(data))
        # The GRU's final hidden state is not needed, only per-step outputs.
        gru_out, _ = self.gru(embedded)
        return self.hidden2tag(gru_out)

    def forward(self, data: torch.Tensor, tags: torch.Tensor) -> torch.Tensor:
        """Return the negated CRF score of ``tags`` for ``data``.

        Args:
            data: Input indices fed to the embedding layer.
            tags: Gold tags. They are transposed before being handed to the
                CRF (which is batch-first), so this presumably arrives as
                (seq_len, batch) -- verify against the caller.

        Returns:
            The negative of the value returned by the CRF layer (per the
            pytorch-crf documentation that value is a log-likelihood, so
            the result can be minimized as a loss).
        """
        emissions = self._emissions(data)
        return -self.crf(emissions, tags.T)

    def decode(self, data: torch.Tensor):
        """Return the CRF's decoded tag sequences for ``data``.

        Args:
            data: Input indices fed to the embedding layer.

        Returns:
            Whatever ``CRF.decode`` returns for the computed emissions
            (per the pytorch-crf documentation, a list of tag-index lists).
        """
        return self.crf.decode(self._emissions(data))

    def train_mode(self):
        """Put the CRF submodule into training mode.

        Only ``self.crf`` is switched; the other submodules keep their
        current mode.
        """
        self.crf.train()

    def eval_mode(self):
        """Put the CRF submodule into evaluation mode.

        Only ``self.crf`` is switched; the other submodules keep their
        current mode.
        """
        self.crf.eval()