s478841
This commit is contained in:
parent
82c2482af6
commit
eb69e726e3
254
run.py
Normal file
254
run.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
import lzma
|
||||||
|
from collections import Counter
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchtext.vocab
|
||||||
|
from bidict import bidict
|
||||||
|
from string import punctuation
|
||||||
|
|
||||||
|
LABEL_TO_ID = bidict({
|
||||||
|
'O': 0,
|
||||||
|
'B-PER': 1,
|
||||||
|
'B-LOC': 2,
|
||||||
|
'I-PER': 3,
|
||||||
|
'B-MISC': 4,
|
||||||
|
'I-MISC': 5,
|
||||||
|
'I-LOC': 6,
|
||||||
|
'B-ORG': 7,
|
||||||
|
'I-ORG': 8
|
||||||
|
})
|
||||||
|
ID_TO_LABEL = LABEL_TO_ID.inverse
|
||||||
|
|
||||||
|
|
||||||
|
def read_data(path):
|
||||||
|
print(f"I am reading the data from {path}...")
|
||||||
|
if path[-2:] == 'xz':
|
||||||
|
data = {'text': [], 'tokens': []}
|
||||||
|
with lzma.open(path, 'rt', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip().rsplit('\t')
|
||||||
|
tokens, text = line[0].split(), line[1].split()
|
||||||
|
if len(tokens) == len(text):
|
||||||
|
data['tokens'].append(tokens)
|
||||||
|
data['text'].append(text)
|
||||||
|
else:
|
||||||
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
|
data = [line.strip().split() for line in f]
|
||||||
|
print("Data loaded")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def make_vocabulary(dataset):
|
||||||
|
counter = Counter()
|
||||||
|
for document in dataset:
|
||||||
|
counter.update(document)
|
||||||
|
vocab = torchtext.vocab.vocab(
|
||||||
|
counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
|
||||||
|
vocab.set_default_index(0)
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_data(data, vocab):
|
||||||
|
return [
|
||||||
|
torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] +
|
||||||
|
[vocab['<eos>']],
|
||||||
|
dtype=torch.long) for document in data
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def encode_labels(data):
|
||||||
|
data_num = [[LABEL_TO_ID[label] for label in labels] for labels in data]
|
||||||
|
return [
|
||||||
|
torch.tensor([0] + document + [0], dtype=torch.long)
|
||||||
|
for document in data_num
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_features(x_base, x_str):
|
||||||
|
word_features = [0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||||
|
if len(x_str) > 1 and len(x_str[1]) > 1:
|
||||||
|
word = x_str[1]
|
||||||
|
if word.isupper():
|
||||||
|
word_features[0] = 1
|
||||||
|
if word[0].isupper():
|
||||||
|
word_features[1] = 1
|
||||||
|
if word.isalnum():
|
||||||
|
word_features[2] = 1
|
||||||
|
if word.isnumeric():
|
||||||
|
word_features[3] = 1
|
||||||
|
if '-' in word:
|
||||||
|
word_features[4] = 1
|
||||||
|
if '/' in word:
|
||||||
|
word_features[5] = 1
|
||||||
|
for char in word:
|
||||||
|
if char in punctuation:
|
||||||
|
word_features[6] = 1
|
||||||
|
break
|
||||||
|
if len(word) > 6:
|
||||||
|
word_features[7] = 1
|
||||||
|
if len(word) < 3:
|
||||||
|
word_features[8] = 1
|
||||||
|
extra_features = torch.tensor(word_features)
|
||||||
|
x_features = torch.cat((x_base, extra_features), 0)
|
||||||
|
return x_features
|
||||||
|
|
||||||
|
|
||||||
|
class NERModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(NERModel, self).__init__()
|
||||||
|
self.embedding = nn.Embedding(23627, 200)
|
||||||
|
self.linear = nn.Linear(2400, 9)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.embedding(x)
|
||||||
|
x = x.reshape(2400)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(model,
|
||||||
|
data,
|
||||||
|
train_labels,
|
||||||
|
train_tokens_ids,
|
||||||
|
epochs,
|
||||||
|
save=False):
|
||||||
|
model.train()
|
||||||
|
for epoch in range(epochs):
|
||||||
|
loss_score = 0
|
||||||
|
acc_score = 0
|
||||||
|
prec_score = 0
|
||||||
|
selected_items = 0
|
||||||
|
recall_score = 0
|
||||||
|
relevant_items = 0
|
||||||
|
items_total = 0
|
||||||
|
for i in range(len(train_labels) - 1):
|
||||||
|
for j in range(1, len(train_labels[i]) - 1):
|
||||||
|
X_base = train_tokens_ids[i][j - 1:j + 2]
|
||||||
|
X_string = data['text'][i][j - 1:j + 2]
|
||||||
|
X_extra = add_features(X_base, X_string)
|
||||||
|
Y = train_labels[i][j:j + 1]
|
||||||
|
|
||||||
|
X = X_extra.to(device)
|
||||||
|
Y = Y.to(device)
|
||||||
|
|
||||||
|
Y_predictions = model(X)
|
||||||
|
|
||||||
|
pred_class = torch.argmax(Y_predictions)
|
||||||
|
y_item = Y.item()
|
||||||
|
acc_score += pred_class == Y
|
||||||
|
if pred_class != 0:
|
||||||
|
selected_items += 1
|
||||||
|
if pred_class == y_item:
|
||||||
|
prec_score += 1
|
||||||
|
if y_item != 0:
|
||||||
|
relevant_items += 1
|
||||||
|
if pred_class == y_item:
|
||||||
|
recall_score += 1
|
||||||
|
items_total += 1
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss = criterion(Y_predictions.unsqueeze(0), Y)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
loss_score += loss.item()
|
||||||
|
|
||||||
|
precision = prec_score / selected_items
|
||||||
|
recall = recall_score / relevant_items
|
||||||
|
f1_score = 2 * precision * recall / (
|
||||||
|
precision + recall) if precision and recall else 0
|
||||||
|
|
||||||
|
if i + 1 % 10 == 0:
|
||||||
|
print('Epoch: ', epoch)
|
||||||
|
print('Loss: ', loss_score / items_total)
|
||||||
|
print('Accuracy: ', acc_score / items_total)
|
||||||
|
print('F1-score: ', f1_score)
|
||||||
|
print('Finished epoch: ', epoch)
|
||||||
|
if save:
|
||||||
|
torch.save(model, 'model.pt')
|
||||||
|
|
||||||
|
|
||||||
|
def write_results(data, path):
|
||||||
|
with open(path, 'w') as f:
|
||||||
|
for line in data:
|
||||||
|
f.write(f'{line}\n')
|
||||||
|
print(f"Data written to the file {path}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict(model, x_data, vocab, device):
|
||||||
|
tokens_ids = tokenize_data(x_data, vocab)
|
||||||
|
preds = []
|
||||||
|
# print('Getting into predicting loop')
|
||||||
|
for i in range(len(tokens_ids)):
|
||||||
|
labels = ''
|
||||||
|
# print('I will go with the sentence:\t', i)
|
||||||
|
for j in range(1, len(tokens_ids[i]) - 1):
|
||||||
|
x_base = tokens_ids[i][j - 1:j + 2]
|
||||||
|
x_strings = x_data[i][j - 1:j + 2]
|
||||||
|
x_features = add_features(x_base, x_strings) # .to(device)
|
||||||
|
# print('I will predict on data:\t', x_base, x_strings)
|
||||||
|
try:
|
||||||
|
pred = model(x_features)
|
||||||
|
label = ID_TO_LABEL[int(torch.argmax(pred))]
|
||||||
|
labels += f'{label} '
|
||||||
|
except Exception as ex:
|
||||||
|
print(f'Exception\t→\t{ex}\t{x_strings}→{x_features}')
|
||||||
|
preds.append(labels[:-1])
|
||||||
|
print('Done with the inference, now writing it into the file!\n')
|
||||||
|
lines = []
|
||||||
|
for line in preds:
|
||||||
|
prev_label = None
|
||||||
|
new_line = []
|
||||||
|
for label in line.split():
|
||||||
|
if label[0] == 'I':
|
||||||
|
if prev_label is None or prev_label == 'O':
|
||||||
|
label = label.replace('I', 'B')
|
||||||
|
else:
|
||||||
|
label = 'I' + prev_label[1:]
|
||||||
|
prev_label = label
|
||||||
|
new_line.append(label)
|
||||||
|
lines.append(' '.join(new_line))
|
||||||
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# * Data loading
|
||||||
|
data = read_data('./train/train.tsv.xz')
|
||||||
|
vocab = make_vocabulary(data['text'])
|
||||||
|
train_tokens_ids = tokenize_data(data['text'], vocab)
|
||||||
|
train_labels = encode_labels(data['tokens'])
|
||||||
|
|
||||||
|
# * Model set-up
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
print('My device is ', device)
|
||||||
|
ner_model = NERModel().to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(ner_model.parameters())
|
||||||
|
epochs = 3
|
||||||
|
|
||||||
|
# * Training
|
||||||
|
train_model(ner_model,
|
||||||
|
data,
|
||||||
|
train_labels,
|
||||||
|
train_tokens_ids,
|
||||||
|
epochs,
|
||||||
|
save=True)
|
||||||
|
|
||||||
|
# * Inference time!!!
|
||||||
|
print("Now, let's predict something!")
|
||||||
|
# new_model = torch.load(PATH)
|
||||||
|
ner_model.cpu()
|
||||||
|
ner_model.eval()
|
||||||
|
|
||||||
|
# * Inference on dev-0 data
|
||||||
|
dev_data = read_data('./dev-0/in.tsv')
|
||||||
|
write_results(predict(ner_model, dev_data, vocab, device),
|
||||||
|
'./dev-0/out.tsv')
|
||||||
|
|
||||||
|
# * Inference on test-A data
|
||||||
|
test_data = read_data('./test-A/in.tsv')
|
||||||
|
write_results(predict(ner_model, test_data, vocab, device),
|
||||||
|
'./test-A/out.tsv')
|
Loading…
Reference in New Issue
Block a user