import torch
import pandas as pd
from collections import Counter
from torchtext.vocab import vocab
from sklearn.metrics import accuracy_score
from tqdm import tqdm
# Load the datasets
train_set = pd.read_csv('./train/train.tsv', sep='\t', header=None, names=['labels', 'text'])
val_set = pd.read_csv('./dev-0/expected.tsv', sep='\t', header=None, names=['labels'])
val_set['text'] = pd.read_csv('./dev-0/in.tsv', sep='\t', header=None, names=['text'])['text']
test_set = pd.read_csv('./test-A/in.tsv', sep='\t', header=None, names=['text'])
# Tokenize by whitespace
train_set['text'] = train_set["text"].apply(lambda x : x.split())
train_set['labels'] = train_set["labels"].apply(lambda x : x.split())
val_set['text'] = val_set["text"].apply(lambda x : x.split())
val_set['labels'] = val_set["labels"].apply(lambda x : x.split())
test_set['text'] = test_set["text"].apply(lambda x : x.split())
train_set.head(10)
| | labels | text |
|---|---|---|
| 0 | [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O, O, B... | [EU, rejects, German, call, to, boycott, Briti... |
| 1 | [O, B-PER, O, O, O, O, O, O, O, O, O, B-LOC, O... | [Rare, Hendrix, song, draft, sells, for, almos... |
| 2 | [B-LOC, O, B-LOC, O, O, O, O, O, O, B-LOC, O, ... | [China, says, Taiwan, spoils, atmosphere, for,... |
| 3 | [B-LOC, O, O, O, O, B-LOC, O, O, O, B-LOC, O, ... | [China, says, time, right, for, Taiwan, talks,... |
| 4 | [B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-LO... | [German, July, car, registrations, up, 14.2, p... |
| 5 | [B-MISC, O, O, O, O, O, O, O, O, O, O, B-LOC, ... | [GREEK, SOCIALISTS, GIVE, GREEN, LIGHT, TO, PM... |
| 6 | [B-ORG, O, B-MISC, O, O, O, O, O, O, B-LOC, O,... | [BayerVB, sets, C$, 100, million, six-year, bo... |
| 7 | [B-ORG, O, O, O, O, O, O, O, O, O, B-LOC, O, O... | [Venantius, sets, $, 300, million, January, 19... |
| 8 | [O, O, O, O, B-LOC, O, B-ORG, I-ORG, O, O, O, ... | [Port, conditions, update, -, Syria, -, Lloyds... |
| 9 | [B-LOC, O, O, O, O, O, O, B-LOC, O, O, B-PER, ... | [Israel, plays, down, fears, of, war, with, Sy... |
# Build the vocabulary
def build_vocab(dataset):
counter = Counter()
for document in dataset:
counter.update(document)
return vocab(counter, specials=["<unk>", "<pad>", "<bos>", "<eos>"])
v = build_vocab(train_set['text'])
v.set_default_index(v["<unk>"])
itos = v.get_itos()
itos[:10]
['<unk>', '<pad>', '<bos>', '<eos>', 'EU', 'rejects', 'German', 'call', 'to', 'boycott']
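A quick sanity check of the lookup behavior (illustrative; "xyzzy" stands in for any out-of-vocabulary token):
print(v["EU"])     # 4, matching itos above
print(v["xyzzy"])  # 0, i.e. <unk>, thanks to set_default_index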
def data_process(dt):
# Vectorize documents: map each token to its vocab index and wrap with <bos>/<eos>.
return [
torch.tensor(
[v["<bos>"]] + [v[token] for token in document] + [v["<eos>"]],
dtype=torch.long,
)
for document in dt
]
def labels_process(dt):
# Vectorize NER labels; the leading/trailing 0 ("O") align with <bos>/<eos>.
return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]
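A minimal alignment check (illustrative values; 3 and 0 are the B-ORG and O tag ids defined just below): the <bos>/<eos> markers added by data_process are mirrored by the two extra 0 labels from labels_process, so each token tensor and its label tensor end up the same length.
example_tokens = data_process([["EU", "rejects"]])[0]  # <bos>, EU, rejects, <eos>
example_labels = labels_process([[3, 0]])[0]           # 0, B-ORG, O, 0
assert len(example_tokens) == len(example_labels) == 4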
# The NER tag set
num_tags = {
"O" : 0,
"B-PER" : 1,
"I-PER" : 2,
"B-ORG" : 3,
"I-ORG" : 4,
"B-LOC" : 5,
"I-LOC" : 6,
"B-MISC" : 7,
"I-MISC" : 8,
}
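For decoding predictions back to tag strings later, the inverse mapping is handy (a small sketch; int_to_tag is a new helper name):
int_to_tag = {idx: tag for tag, idx in num_tags.items()}
print(int_to_tag[5])  # 'B-LOC'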
def convert_to_int(dt, tags):
labels = []
for label in dt:
labels.append([tags[i] for i in label])
return labels
train_tokens_ids = data_process(train_set['text'])
train_labels_ids = labels_process(convert_to_int(train_set['labels'], tags=num_tags))
val_tokens_ids = data_process(val_set['text'])
val_labels_ids = labels_process(convert_to_int(val_set['labels'], tags=num_tags))
test_tokens_ids = data_process(test_set['text'])
class LSTM(torch.nn.Module):
def __init__(self, num_tags):
super(LSTM, self).__init__()
self.emb = torch.nn.Embedding(len(v.get_itos()), 100)
self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)
self.fc1 = torch.nn.Linear(256, num_tags)
def forward(self, x):
emb = torch.relu(self.emb(x))
lstm_output, (h_n, c_n) = self.rec(emb)
out_weights = self.fc1(lstm_output)
return out_weights
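A shape check on the untrained model (illustrative): for a batch of one 7-token sentence the network returns one score vector per token.
toy_input = torch.randint(0, len(v.get_itos()), (1, 7))
print(LSTM(num_tags=9)(toy_input).shape)  # torch.Size([1, 7, 9])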
EPOCHS = 10
LR = 0.001
NUM_TAGS = len(num_tags)
model = LSTM(num_tags=NUM_TAGS)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.CrossEntropyLoss()
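For reference (illustrative tensors): CrossEntropyLoss here compares per-token scores of shape (seq_len, num_tags) against integer targets of shape (seq_len,), which is exactly what the training loop below feeds it after squeezing the batch dimension.
toy_scores = torch.randn(5, NUM_TAGS)           # scores for 5 tokens
toy_targets = torch.randint(0, NUM_TAGS, (5,))  # gold tag ids
print(criterion(toy_scores, toy_targets))       # a scalar loss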
def get_scores(y_true, y_pred):
    # Returns precision, recall, F1, and accuracy; "positive" means any non-"O" tag (id > 0).
    tp = 0
    selected_items = 0
    relevant_items = 0
    for p, t in zip(y_pred, y_true):
        if p > 0 and p == t:
            tp += 1
        if p > 0:
            selected_items += 1
        if t > 0:
            relevant_items += 1
if selected_items == 0:
precision = 1.0
else:
precision = tp / selected_items
if relevant_items == 0:
recall = 1.0
else:
recall = tp / relevant_items
if precision + recall == 0.0:
f1 = 0.0
else:
f1 = 2 * precision * recall / (precision + recall)
acc = accuracy_score(y_true, y_pred)
return precision, recall, f1, acc
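A worked example (made-up labels): with two entity positions in y_true and only one of them found, precision is 1.0, recall 0.5, F1 ≈ 0.667, and accuracy 0.75.
print(get_scores([0, 1, 2, 0], [0, 1, 0, 0]))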
def eval_model(dataset_tokens, dataset_labels, model):
    Y_true = []
    Y_pred = []
    with torch.no_grad():
        for i in tqdm(range(len(dataset_labels))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)
            tags = dataset_labels[i]
            Y_true += tags.tolist()
            Y_batch_pred_weights = model(batch_tokens).squeeze(0)
            Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
            Y_pred += Y_batch_pred.tolist()
    precision, recall, f1, acc = get_scores(Y_true, Y_pred)
    print(f'precision: {precision}, recall: {recall}, f1: {f1}, val accuracy: {acc}')
for epoch in range(EPOCHS):
    model.train()
    train_true = []
    train_pred = []
    for i in tqdm(range(len(train_set['labels']))):
        batch_tokens = train_tokens_ids[i].unsqueeze(0)
        tags = train_labels_ids[i]
        train_true += tags.tolist()
        predicted_tags = model(batch_tokens)  # single forward pass, reused for loss and metrics
        Y_batch_pred = torch.argmax(predicted_tags.squeeze(0), 1)
        train_pred += Y_batch_pred.tolist()
        optimizer.zero_grad()
        loss = criterion(predicted_tags.squeeze(0), tags)
        loss.backward()
        optimizer.step()
    model.eval()
    eval_model(val_tokens_ids, val_labels_ids, model)
    print(f'Train accuracy: {accuracy_score(train_true, train_pred)}')
100%|██████████| 945/945 [02:21<00:00, 6.68it/s] 100%|██████████| 215/215 [00:02<00:00, 93.42it/s]
precision: 0.8434014196726061, recall: 0.6783966441388953, f1: 0.7519535033903778, val accuracy: 0.9457621556580554 Train accuracy: 0.9983919167623316
100%|██████████| 945/945 [02:14<00:00, 7.04it/s] 100%|██████████| 215/215 [00:02<00:00, 89.87it/s]
precision: 0.8440340076223981, recall: 0.6709391750174785, f1: 0.7475980264866269, val accuracy: 0.9454522251189587 Train accuracy: 0.9989522403833889
100%|██████████| 945/945 [02:08<00:00, 7.34it/s] 100%|██████████| 215/215 [00:02<00:00, 90.39it/s]
precision: 0.852653120888759, recall: 0.6796783966441389, f1: 0.7564027750761848, val accuracy: 0.9472206523126288 Train accuracy: 0.9993303449406877
100%|██████████| 945/945 [02:18<00:00, 6.85it/s] 100%|██████████| 215/215 [00:03<00:00, 66.03it/s]
precision: 0.8375809935205184, recall: 0.6778140293637847, f1: 0.7492754556578862, val accuracy: 0.9455980747844159 Train accuracy: 0.9991891251662749
100%|██████████| 945/945 [02:07<00:00, 7.39it/s] 100%|██████████| 215/215 [00:03<00:00, 62.19it/s]
precision: 0.8413109098749461, recall: 0.6820088557445817, f1: 0.7533303301370746, val accuracy: 0.9462908606953383 Train accuracy: 0.9991435704003353
100%|██████████| 945/945 [02:13<00:00, 7.08it/s] 100%|██████████| 215/215 [00:03<00:00, 54.46it/s]
precision: 0.8479315263908702, recall: 0.6926124446515963, f1: 0.7624422780913289, val accuracy: 0.9478769758071868 Train accuracy: 0.998583246779278
100%|██████████| 945/945 [02:11<00:00, 7.20it/s] 100%|██████████| 215/215 [00:03<00:00, 57.36it/s]
precision: 0.8470877294406706, recall: 0.6829410393847588, f1: 0.7562092768208503, val accuracy: 0.9471294962717179 Train accuracy: 0.999180014213087
100%|██████████| 945/945 [02:15<00:00, 6.99it/s] 100%|██████████| 215/215 [00:04<00:00, 45.79it/s]
precision: 0.8728230645397337, recall: 0.6949429037520392, f1: 0.7737917612714889, val accuracy: 0.9498824087072251 Train accuracy: 0.9993212339874997
100%|██████████| 945/945 [02:38<00:00, 5.98it/s] 100%|██████████| 215/215 [00:07<00:00, 28.22it/s]
precision: 0.8691318792431028, recall: 0.7011186203682125, f1: 0.7761367300870687, val accuracy: 0.9505934258263296 Train accuracy: 0.9996310063958891
100%|██████████| 945/945 [02:03<00:00, 7.62it/s] 100%|██████████| 215/215 [00:02<00:00, 77.13it/s]
precision: 0.8701146047605054, recall: 0.6900489396411092, f1: 0.7696906680530282, val accuracy: 0.949116697963574 Train accuracy: 0.9997540042639261
def save_prediction(test_tokens, test_pred, file_name):
    # Write one "token<TAB>tag" line per token, with a blank line between sentences.
    with open(file_name, 'w') as f:
        for i in range(len(test_tokens)):
            for j in range(len(test_tokens[i])):
                f.write(f'{test_tokens[i][j]}\t{list(num_tags.keys())[test_pred[i][j]]}\n')
            f.write('\n')
test_pred = []
with torch.no_grad():
for i in range(len(test_tokens_ids)):
batch_tokens = test_tokens_ids[i].unsqueeze(0)
Y_batch_pred_weights = model(batch_tokens).squeeze(0)
Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
test_pred += list(Y_batch_pred.numpy())
with open('test-A/out.tsv', 'w') as f:
for i in range(len(test_pred)):
tag = list(num_tags.keys())[test_pred[i]]
f.write(tag)
f.write('\n')
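Note that test_pred is flat (one entry per token, including the <bos>/<eos> positions), so sentence boundaries are lost in out.tsv. A hedged sketch of regrouping the predictions per sentence using the token-tensor lengths:
offset = 0
per_sentence_pred = []
for tokens in test_tokens_ids:
    per_sentence_pred.append(test_pred[offset:offset + len(tokens)])
    offset += len(tokens)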