28 KiB
28 KiB
from collections import Counter
import torch
import pandas as pd
from torchtext.vocab import vocab
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
# Select the compute device once, reporting the choice to the user.
# The CPU fallback is handled first so the CUDA branch reads as the special case.
if not torch.cuda.is_available():
    print("CUDA nie jest dostępna. Model będzie uruchomiony na CPU.")
    device = torch.device("cpu")
else:
    print("CUDA jest dostępna!")
    print(f"Nazwa urządzenia: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
CUDA nie jest dostępna. Model będzie uruchomiony na CPU.
# Load the tab-separated training file: the first column holds the
# space-separated tag sequence, the second the space-separated document text.
train_data = pd.read_csv(
    "train.tsv",
    sep='\t',
    header=None,
    names=['labels', 'documents'],
)

# Whitespace-tokenize both columns so tokens and tags are parallel lists.
train_data["tokenized_documents"] = train_data["documents"].apply(
    lambda text: text.split()
)
train_data["tokenized_labels"] = train_data["labels"].apply(
    lambda tag_line: tag_line.split()
)

# Hold out 20% of the documents for validation; the seed is fixed so the
# split is reproducible between runs.
X_train, X_val, y_train, y_val = train_test_split(
    train_data["tokenized_documents"],
    train_data["tokenized_labels"],
    test_size=0.2,
    random_state=42,
)
def build_vocab(dataset):
    """Build a torchtext vocabulary from an iterable of token sequences.

    Args:
        dataset: Iterable of documents, each an iterable of hashable tokens.

    Returns:
        The object returned by ``torchtext.vocab.vocab`` for the token
        frequencies, with the special tokens ``<unk>``, ``<pad>``, ``<bos>``
        and ``<eos>`` registered.
    """
    # Count every token across all documents in first-seen order.
    token_counts = Counter(token for document in dataset for token in document)
    return vocab(token_counts, specials=["<unk>", "<pad>", "<bos>", "<eos>"])
# The vocabulary is built from the training split only.
train_vocab = build_vocab(X_train)
# Look-ups of tokens missing from the vocabulary resolve to the <unk> index
# instead of raising.
train_vocab.set_default_index(train_vocab["<unk>"])
# Index-to-string table of the vocabulary.
itos = train_vocab.get_itos()
def data_process(dt):
    """Convert tokenized documents into tensors of vocabulary ids.

    Each document is wrapped with the ``<bos>`` id in front and the ``<eos>``
    id at the end; tokens are looked up in the module-level ``train_vocab``.

    Args:
        dt: Iterable of documents, each an iterable of tokens.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document.
    """

    def encode(document):
        # Look-ups happen in the same order as the ids appear in the output.
        ids = [train_vocab["<bos>"]]
        ids.extend(train_vocab[token] for token in document)
        ids.append(train_vocab["<eos>"])
        return torch.tensor(ids, dtype=torch.long)

    return list(map(encode, dt))
# Encode both splits with the training vocabulary (train first, then validation).
train_tokens_ids, val_tokens_ids = map(data_process, (X_train, X_val))

# Tag inventory; the list position of a tag is its class index, so "O" is 0.
labels = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
label_to_index = dict(zip(labels, range(len(labels))))
def labels_process(dt, label_to_index):
    """Convert tag sequences into tensors of class indices.

    Every sequence gets a literal ``0`` added in front and at the end, which
    keeps it the same length as the output of ``data_process`` (that function
    adds ``<bos>``/``<eos>`` ids around the tokens).

    Args:
        dt: Iterable of documents, each an iterable of tag strings.
        label_to_index: Mapping from tag string to integer class index.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document, created on the
        module-level ``device``.

    Raises:
        KeyError: If a tag is not present in ``label_to_index``.
    """
    encoded = []
    for document in dt:
        indices = [0]
        indices.extend(label_to_index[label] for label in document)
        indices.append(0)
        encoded.append(torch.tensor(indices, dtype=torch.long, device=device))
    return encoded
# Encode the tag sequences of both splits as index tensors.
train_labels = labels_process(y_train, label_to_index)
val_labels = labels_process(y_val, label_to_index)

# Flat list of every class index that occurs in the training split.
all_label_indices = [
    label_to_index[label]
    for document in y_train
    for label in document
]

# Size the classifier from the tag inventory rather than from the tags that
# happen to occur in the training split: if the split lacked a high-index tag,
# the model would otherwise have too few outputs to represent every tag in
# label_to_index.
num_tags = max(label_to_index.values()) + 1
class LSTM(torch.nn.Module):
    """Per-token tagger: embedding -> ReLU -> single-layer LSTM -> linear.

    The embedding table is sized from the module-level ``train_vocab`` and the
    output layer from the module-level ``num_tags``.
    """

    def __init__(self):
        super().__init__()
        vocab_size = len(train_vocab.get_itos())
        embedding_dim = 100
        hidden_size = 256
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: inputs are passed as (batch, sequence, features).
        self.rec = torch.nn.LSTM(embedding_dim, hidden_size, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_size, num_tags)

    def forward(self, x):
        """Return unnormalized tag scores for every position of ``x``.

        Args:
            x: Tensor of token ids; callers in this file pass a batch of one
                sequence, i.e. shape (1, sequence_length).

        Returns:
            Tensor of scores with a trailing dimension of size ``num_tags``.
        """
        embedded = torch.relu(self.emb(x))
        # Only the per-step outputs are needed; the final states are ignored.
        per_step_output, _ = self.rec(embedded)
        return self.fc1(per_step_output)
# Loss over per-token class scores; constructed with no arguments.
criterion = torch.nn.CrossEntropyLoss()
# The tagger and an Adam optimizer over all of its parameters
# (no optimizer hyperparameters are overridden).
lstm = LSTM()
optimizer = torch.optim.Adam(params=lstm.parameters())
def get_scores(y_true, y_pred):
    """Compute precision, recall and F1 over non-zero (entity) classes.

    Class index 0 is treated as the "no entity" class: a prediction counts as
    selected when it is > 0, a gold label counts as relevant when it is > 0,
    and a true positive is a selected prediction equal to its gold label.

    Args:
        y_true: Sequence of gold class indices.
        y_pred: Sequence of predicted class indices, aligned with ``y_true``.

    Returns:
        Tuple ``(precision, recall, f1)`` of floats. Precision is 1.0 when
        nothing was selected, recall is 1.0 when nothing was relevant, and F1
        is 0.0 when precision and recall are both zero.
    """
    tp = 0
    selected_items = 0
    relevant_items = 0
    for p, t in zip(y_pred, y_true):
        if p > 0:
            selected_items += 1
            if p == t:
                tp += 1
        if t > 0:
            relevant_items += 1

    # An empty denominator means there was nothing to get wrong.
    precision = tp / selected_items if selected_items else 1.0
    recall = tp / relevant_items if relevant_items else 1.0

    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
def eval_model(dataset_tokens, dataset_labels, model):
    """Score ``model`` on a dataset, one document at a time.

    Args:
        dataset_tokens: Sequence of 1-D token-id tensors, one per document.
        dataset_labels: Sequence of 1-D class-index tensors aligned with
            ``dataset_tokens``.
        model: Callable mapping a (1, sequence_length) token tensor to
            per-position class scores.

    Returns:
        The ``(precision, recall, f1)`` tuple produced by ``get_scores`` over
        all positions of all documents.

    Note:
        The model's train/eval mode is not changed here; callers set it.
    """
    Y_true = []
    Y_pred = []
    # No gradients are needed for scoring, so skip building autograd graphs.
    with torch.no_grad():
        for i in tqdm(range(len(dataset_labels))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)
            # tolist() works regardless of the device the labels live on,
            # unlike numpy(), which requires a CPU tensor.
            Y_true += dataset_labels[i].tolist()
            Y_batch_pred_weights = model(batch_tokens).squeeze(0)
            Y_pred += torch.argmax(Y_batch_pred_weights, 1).tolist()
    return get_scores(Y_true, Y_pred)
# Train for a fixed number of epochs with one document per optimizer step,
# printing validation (precision, recall, f1) after every epoch.
NUM_EPOCHS = 20
for epoch in range(NUM_EPOCHS):
    lstm.train()
    # A distinct index name avoids clobbering the epoch counter.
    for sample_idx in tqdm(range(len(train_labels))):
        batch_tokens = train_tokens_ids[sample_idx].unsqueeze(0)
        optimizer.zero_grad()
        # Drop the batch dimension: scores become (sequence_length, num_tags).
        predicted_tags = lstm(batch_tokens).squeeze(0)
        # Label tensors are created on `device`, which may differ from where
        # the model output lives; align them before computing the loss.
        tags = train_labels[sample_idx].to(predicted_tags.device)
        loss = criterion(predicted_tags, tags)
        loss.backward()
        optimizer.step()
    lstm.eval()
    print(eval_model(val_tokens_ids, val_labels, lstm))
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.5042095416276894, 0.07628078120577413, 0.1325138291333743)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.6831020812685827, 0.3901783187093122, 0.4966672671590705)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.7253631723596388, 0.5229266911972827, 0.6077302631578947)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.7647388059701492, 0.5801018964053213, 0.6597456945115081)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.7872690689592098, 0.6091140673648457, 0.6868267773079072)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8001462790272444, 0.6193037078969714, 0.6982050259274033)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8068851251840943, 0.6202943673931502, 0.7013922227556408)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8154405086285196, 0.635295782621002, 0.7141834380717526)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8557729190640583, 0.6314746674214549, 0.7267100977198697)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.837495475931958, 0.6549674497594112, 0.7350698856416772)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8592408926187297, 0.6375601471836966, 0.7319847266227963)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8497704315886134, 0.6548259269742428, 0.7396690911997442)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8601201652271874, 0.6483158788564959, 0.7393479664299548)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8630500758725341, 0.6439286725162752, 0.737558761549684)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8547055586130985, 0.6593546560996321, 0.7444275784932493)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8580931263858093, 0.6572318143221059, 0.7443500560987337)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.857722827089869, 0.6577979054627795, 0.744573488185823)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8601476014760148, 0.6597792244551373, 0.7467563671311869)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.8474636395885066, 0.6761958675346731, 0.7522040302267002)
0%| | 0/756 [00:00<?, ?it/s]
0%| | 0/189 [00:00<?, ?it/s]
(0.855834829443447, 0.6746391168978205, 0.7545109211775879)
def pred_labels(dataset_tokens, model, label_to_index):
    """Predict a tag string for every document.

    Args:
        dataset_tokens: Sequence of 1-D token-id tensors, one per document,
            each including a leading and a trailing boundary id.
        model: Callable mapping a (1, sequence_length) token tensor to
            per-position class scores.
        label_to_index: Mapping from tag string to integer class index.

    Returns:
        A list with one string per document: the predicted tags joined by
        single spaces, with the first and last positions dropped.
    """
    index_to_label = {index: label for label, index in label_to_index.items()}
    Y_pred = []
    # Pure inference: skip building autograd graphs.
    with torch.no_grad():
        for i in tqdm(range(len(dataset_tokens))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)
            scores = model(batch_tokens).squeeze(0)
            predicted_indices = torch.argmax(scores, 1).tolist()
            # The first and last positions belong to the boundary ids, not to
            # real tokens, so they are excluded from the output.
            Y_pred.append(
                " ".join(index_to_label[idx] for idx in predicted_indices[1:-1])
            )
    return Y_pred
# Load the dev-0 inputs and gold tags and the test-A inputs; each file is read
# as a single-column, tab-separated table.
dev_data = pd.read_csv("dev-0/in.tsv", sep="\t", names=["Text"])
dev_labels = pd.read_csv("dev-0/expected.tsv", sep="\t", names=["Label"])
test_A_data = pd.read_csv("test-A/in.tsv", sep="\t", names=["Text"])
# Whitespace-tokenize texts and tag lines, as was done for the training data.
dev_data["tokenized_text"] = dev_data["Text"].apply(lambda x: x.split())
dev_labels["tokenized_labels"] = dev_labels["Label"].apply(lambda x: x.split())
test_A_data["tokenized_text"] = test_A_data["Text"].apply(lambda x: x.split())
# Encode with the training vocabulary (data_process reads train_vocab).
dev_0_tokens_ids = data_process(dev_data["tokenized_text"])
test_A_tokens_ids = data_process(test_A_data["tokenized_text"])
# NOTE(review): dev_0_labels is not used anywhere in the visible code.
dev_0_labels = labels_process(dev_labels["tokenized_labels"], label_to_index)
# Predict dev-0 and write one line of space-separated tags per document,
# without index or header.
dev_0_predictons = pred_labels(dev_0_tokens_ids, lstm, label_to_index)
dev_0_predictons = pd.DataFrame(dev_0_predictons, columns=["Label"])
dev_0_predictons.to_csv("dev-0/out.tsv", index=False, header=False)
# Same for test-A.
test_A_predictions = pred_labels(test_A_tokens_ids, lstm, label_to_index)
test_A_predictions = pd.DataFrame(test_A_predictions, columns=["Label"])
test_A_predictions.to_csv("test-A/out.tsv", index=False, header=False)
0%| | 0/215 [00:00<?, ?it/s]
0%| | 0/230 [00:00<?, ?it/s]
# Token-level accuracy of the written dev-0 predictions against the gold tags.
# Every tag on every line is compared: out.tsv already has the boundary
# positions removed by pred_labels, so nothing may be sliced off here.
with open('dev-0/out.tsv', 'r') as file:
    predicted_labels = [line.strip().split() for line in file]
with open('dev-0/expected.tsv', 'r') as file:
    true_labels = [line.strip().split() for line in file]

# Flatten the per-document lists into two parallel tag sequences.
predicted_labels = [label for sublist in predicted_labels for label in sublist]
true_labels = [label for sublist in true_labels for label in sublist]

accuracy = accuracy_score(true_labels, predicted_labels)
print("Accuracy:", accuracy)
Accuracy: 0.9441390252001624