669 lines
20 KiB
Plaintext
669 lines
20 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import csv\n",
"from collections import Counter\n",
"\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.metrics import accuracy_score\n",
"from torchtext.vocab import vocab\n",
"from tqdm import tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load the datasets.\n",
"# quoting=csv.QUOTE_NONE keeps '\"' tokens literal: with the default quote handling a\n",
"# document that starts with a '\"' token would be parsed as a quoted field and mangled.\n",
"TSV_OPTIONS = dict(sep='\\t', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"train_set = pd.read_csv('./train/train.tsv', names=['labels', 'text'], **TSV_OPTIONS)\n",
"val_set = pd.read_csv('./dev-0/expected.tsv', names=['labels'], **TSV_OPTIONS)\n",
"val_set['text'] = pd.read_csv('./dev-0/in.tsv', names=['text'], **TSV_OPTIONS)['text']\n",
"test_set = pd.read_csv('./test-A/in.tsv', names=['text'], **TSV_OPTIONS)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Tokenization: each cell holds one whitespace-separated document (or tag sequence);\n",
"# replace it with the list of its tokens.\n",
"# NOTE: not idempotent - re-running this cell on already-split lists fails.\n",
"for frame, columns in ((train_set, ('text', 'labels')), (val_set, ('text', 'labels')), (test_set, ('text',))):\n",
"    for column in columns:\n",
"        frame[column] = frame[column].map(lambda cell: cell.split())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>labels</th>\n",
|
|
" <th>text</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>[B-ORG, O, B-MISC, O, O, O, B-MISC, O, O, O, B...</td>\n",
|
|
" <td>[EU, rejects, German, call, to, boycott, Briti...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>[O, B-PER, O, O, O, O, O, O, O, O, O, B-LOC, O...</td>\n",
|
|
" <td>[Rare, Hendrix, song, draft, sells, for, almos...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>[B-LOC, O, B-LOC, O, O, O, O, O, O, B-LOC, O, ...</td>\n",
|
|
" <td>[China, says, Taiwan, spoils, atmosphere, for,...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>[B-LOC, O, O, O, O, B-LOC, O, O, O, B-LOC, O, ...</td>\n",
|
|
" <td>[China, says, time, right, for, Taiwan, talks,...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>[B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-LO...</td>\n",
|
|
" <td>[German, July, car, registrations, up, 14.2, p...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>5</th>\n",
|
|
" <td>[B-MISC, O, O, O, O, O, O, O, O, O, O, B-LOC, ...</td>\n",
|
|
" <td>[GREEK, SOCIALISTS, GIVE, GREEN, LIGHT, TO, PM...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>6</th>\n",
|
|
" <td>[B-ORG, O, B-MISC, O, O, O, O, O, O, B-LOC, O,...</td>\n",
|
|
" <td>[BayerVB, sets, C$, 100, million, six-year, bo...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>7</th>\n",
|
|
" <td>[B-ORG, O, O, O, O, O, O, O, O, O, B-LOC, O, O...</td>\n",
|
|
" <td>[Venantius, sets, $, 300, million, January, 19...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>8</th>\n",
|
|
" <td>[O, O, O, O, B-LOC, O, B-ORG, I-ORG, O, O, O, ...</td>\n",
|
|
" <td>[Port, conditions, update, -, Syria, -, Lloyds...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>9</th>\n",
|
|
" <td>[B-LOC, O, O, O, O, O, O, B-LOC, O, O, B-PER, ...</td>\n",
|
|
" <td>[Israel, plays, down, fears, of, war, with, Sy...</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" labels \\\n",
|
|
"0 [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O, O, B... \n",
|
|
"1 [O, B-PER, O, O, O, O, O, O, O, O, O, B-LOC, O... \n",
|
|
"2 [B-LOC, O, B-LOC, O, O, O, O, O, O, B-LOC, O, ... \n",
|
|
"3 [B-LOC, O, O, O, O, B-LOC, O, O, O, B-LOC, O, ... \n",
|
|
"4 [B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-LO... \n",
|
|
"5 [B-MISC, O, O, O, O, O, O, O, O, O, O, B-LOC, ... \n",
|
|
"6 [B-ORG, O, B-MISC, O, O, O, O, O, O, B-LOC, O,... \n",
|
|
"7 [B-ORG, O, O, O, O, O, O, O, O, O, B-LOC, O, O... \n",
|
|
"8 [O, O, O, O, B-LOC, O, B-ORG, I-ORG, O, O, O, ... \n",
|
|
"9 [B-LOC, O, O, O, O, O, O, B-LOC, O, O, B-PER, ... \n",
|
|
"\n",
|
|
" text \n",
|
|
"0 [EU, rejects, German, call, to, boycott, Briti... \n",
|
|
"1 [Rare, Hendrix, song, draft, sells, for, almos... \n",
|
|
"2 [China, says, Taiwan, spoils, atmosphere, for,... \n",
|
|
"3 [China, says, time, right, for, Taiwan, talks,... \n",
|
|
"4 [German, July, car, registrations, up, 14.2, p... \n",
|
|
"5 [GREEK, SOCIALISTS, GIVE, GREEN, LIGHT, TO, PM... \n",
|
|
"6 [BayerVB, sets, C$, 100, million, six-year, bo... \n",
|
|
"7 [Venantius, sets, $, 300, million, January, 19... \n",
|
|
"8 [Port, conditions, update, -, Syria, -, Lloyds... \n",
|
|
"9 [Israel, plays, down, fears, of, war, with, Sy... "
|
|
]
|
|
},
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Preview the first ten tokenized documents (tag sequence next to token sequence).\n",
"train_set.head(n=10)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Build the vocabulary\n",
"def build_vocab(dataset):\n",
"    \"\"\"Build a torchtext vocab from an iterable of token lists.\n",
"\n",
"    Every token seen in `dataset` is counted (in first-occurrence order) and the\n",
"    special tokens <unk>, <pad>, <bos>, <eos> are added via `specials`.\n",
"    \"\"\"\n",
"    token_counts = Counter(token for document in dataset for token in document)\n",
"    return vocab(token_counts, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['<unk>',\n",
|
|
" '<pad>',\n",
|
|
" '<bos>',\n",
|
|
" '<eos>',\n",
|
|
" 'EU',\n",
|
|
" 'rejects',\n",
|
|
" 'German',\n",
|
|
" 'call',\n",
|
|
" 'to',\n",
|
|
" 'boycott']"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Vocabulary over the training tokens; unknown tokens (e.g. unseen dev/test words) map to <unk>.\n",
"v = build_vocab(train_set['text'])\n",
"v.set_default_index(v['<unk>'])\n",
"\n",
"# First ten entries of the index-to-string table: the specials, then training tokens.\n",
"itos = v.get_itos()\n",
"itos[:10]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def data_process(dt):\n",
"    \"\"\"Vectorize tokenized documents.\n",
"\n",
"    Each document becomes a LongTensor of vocabulary ids framed as\n",
"    [<bos>, token ids..., <eos>] (ids looked up in the global vocab `v`).\n",
"    \"\"\"\n",
"    tensors = []\n",
"    for document in dt:\n",
"        ids = [v[\"<bos>\"]] + [v[token] for token in document] + [v[\"<eos>\"]]\n",
"        tensors.append(torch.tensor(ids, dtype=torch.long))\n",
"    return tensors\n",
"\n",
"def labels_process(dt):\n",
"    \"\"\"Vectorize NER label-id sequences.\n",
"\n",
"    Each sequence is padded with 0 at both ends so it lines up with the\n",
"    <bos>/<eos> positions added by data_process (0 is the id of tag \"O\").\n",
"    \"\"\"\n",
"    tensors = []\n",
"    for document in dt:\n",
"        tensors.append(torch.tensor([0] + document + [0], dtype=torch.long))\n",
"    return tensors"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NER tag set (BIO scheme); each tag maps to the class index used by the model.\n",
"num_tags = {\n",
"    tag: index\n",
"    for index, tag in enumerate(\n",
"        [\"O\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\", \"B-MISC\", \"I-MISC\"]\n",
"    )\n",
"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def covert_to_int(dt, tags):\n",
"    \"\"\"Map every sequence of tag strings in `dt` to a list of ids via the `tags` dict.\"\"\"\n",
"    return [[tags[tag] for tag in label] for label in dt]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Vectorize tokens and labels for every split.\n",
"train_tokens_ids = data_process(train_set['text'])\n",
"train_labels_ids = labels_process(covert_to_int(train_set['labels'], tags=num_tags))\n",
"\n",
"val_tokens_ids = data_process(val_set['text'])\n",
"val_labels_ids = labels_process(covert_to_int(val_set['labels'], tags=num_tags))\n",
"\n",
"# The test split has no labels; vectorize its own text.\n",
"test_tokens_ids = data_process(test_set['text'])\n",
"\n",
"# Every label sequence must line up position-by-position with its token sequence.\n",
"assert len(train_tokens_ids) == len(train_labels_ids)\n",
"assert len(val_tokens_ids) == len(val_labels_ids)\n",
"assert all(len(x) == len(y) for x, y in zip(train_tokens_ids, train_labels_ids))\n",
"assert all(len(x) == len(y) for x, y in zip(val_tokens_ids, val_labels_ids))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LSTM(torch.nn.Module):\n",
"    \"\"\"Token-level tagger: embedding -> single-layer LSTM -> linear projection to tag logits.\n",
"\n",
"    Args:\n",
"        num_tags: number of output classes.\n",
"        vocab_size: embedding table size; defaults to the size of the global vocab `v`.\n",
"        embedding_dim: size of the token embeddings.\n",
"        hidden_dim: size of the LSTM hidden state.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_tags, vocab_size=None, embedding_dim=100, hidden_dim=256):\n",
"        super().__init__()\n",
"        if vocab_size is None:\n",
"            vocab_size = len(v.get_itos())\n",
"        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)\n",
"        self.rec = torch.nn.LSTM(embedding_dim, hidden_dim, 1, batch_first=True)\n",
"        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map token ids of shape (batch, seq_len) to tag logits of shape (batch, seq_len, num_tags).\"\"\"\n",
"        # NOTE(review): ReLU on raw embeddings zeroes their negative components; kept to preserve behaviour.\n",
"        emb = torch.relu(self.emb(x))\n",
"        lstm_output, _ = self.rec(emb)\n",
"        return self.fc1(lstm_output)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Training hyperparameters\n",
"EPOCHS = 10  # intended number of training epochs\n",
"LR = 1e-3  # learning rate\n",
"NUM_TAGS = len(num_tags)  # number of output classes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Seed torch's RNG so the weight initialisation is reproducible across runs.\n",
"torch.manual_seed(42)\n",
"\n",
"model = LSTM(num_tags=NUM_TAGS)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
"# Token-level multi-class loss over the NUM_TAGS tag classes.\n",
"criterion = torch.nn.CrossEntropyLoss()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_scores(y_true, y_pred):\n",
"    \"\"\"Compute token-level precision, recall, F1 and accuracy over NER tag ids.\n",
"\n",
"    Tag id 0 (\"O\") is the negative class: a prediction is *selected* when it is > 0,\n",
"    a gold tag is *relevant* when it is > 0, and a true positive is a selected\n",
"    prediction equal to the gold tag. Precision / recall fall back to 1.0 when\n",
"    their denominator is 0.\n",
"\n",
"    Args:\n",
"        y_true: gold tag ids (ints, or 1-element tensors/arrays).\n",
"        y_pred: predicted tag ids aligned with y_true.\n",
"\n",
"    Returns:\n",
"        (precision, recall, f1, accuracy); accuracy is 0.0 for empty input.\n",
"\n",
"    Raises:\n",
"        ValueError: if y_true and y_pred differ in length.\n",
"    \"\"\"\n",
"    # Plain ints keep the comparisons cheap (1-element tensors would build a tensor per comparison).\n",
"    gold = [int(t) for t in y_true]\n",
"    pred = [int(p) for p in y_pred]\n",
"    if len(gold) != len(pred):\n",
"        raise ValueError(f'y_true and y_pred differ in length: {len(gold)} != {len(pred)}')\n",
"\n",
"    correct = tp = selected_items = relevant_items = 0\n",
"    for p, t in zip(pred, gold):\n",
"        if p == t:\n",
"            correct += 1\n",
"            if p > 0:\n",
"                tp += 1\n",
"        if p > 0:\n",
"            selected_items += 1\n",
"        if t > 0:\n",
"            relevant_items += 1\n",
"\n",
"    precision = tp / selected_items if selected_items else 1.0\n",
"    recall = tp / relevant_items if relevant_items else 1.0\n",
"    f1 = 0.0 if precision + recall == 0.0 else 2 * precision * recall / (precision + recall)\n",
"    acc = correct / len(gold) if gold else 0.0\n",
"    return precision, recall, f1, acc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def eval_model(dataset_tokens, dataset_labels, model):\n",
"    \"\"\"Tag every sequence and print token-level precision / recall / F1 / accuracy.\n",
"\n",
"    Runs under torch.no_grad() because no autograd graph is needed for evaluation;\n",
"    the caller is responsible for putting the model in eval mode.\n",
"\n",
"    Args:\n",
"        dataset_tokens: list of 1-D LongTensors of token ids.\n",
"        dataset_labels: list of 1-D LongTensors of tag ids, aligned with dataset_tokens.\n",
"        model: module mapping (1, seq_len) ids to (1, seq_len, num_tags) logits.\n",
"\n",
"    Returns:\n",
"        (precision, recall, f1, accuracy) as computed by get_scores.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    with torch.no_grad():\n",
"        for tokens, labels in tqdm(zip(dataset_tokens, dataset_labels), total=len(dataset_labels)):\n",
"            logits = model(tokens.unsqueeze(0)).squeeze(0)\n",
"            Y_true.extend(labels.tolist())\n",
"            Y_pred.extend(torch.argmax(logits, 1).tolist())\n",
"\n",
"    precision, recall, f1, acc = get_scores(Y_true, Y_pred)\n",
"    print(f'precision: {precision}, recall: {recall}, f1: {f1}, val accuracy: {acc}')\n",
"    return precision, recall, f1, acc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:21<00:00, 6.68it/s]\n",
|
|
"100%|██████████| 215/215 [00:02<00:00, 93.42it/s] \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8434014196726061, recall: 0.6783966441388953, f1: 0.7519535033903778, val accuracy: 0.9457621556580554\n",
|
|
"Train accuracy: 0.9983919167623316\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:14<00:00, 7.04it/s]\n",
|
|
"100%|██████████| 215/215 [00:02<00:00, 89.87it/s] \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8440340076223981, recall: 0.6709391750174785, f1: 0.7475980264866269, val accuracy: 0.9454522251189587\n",
|
|
"Train accuracy: 0.9989522403833889\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:08<00:00, 7.34it/s]\n",
|
|
"100%|██████████| 215/215 [00:02<00:00, 90.39it/s] \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.852653120888759, recall: 0.6796783966441389, f1: 0.7564027750761848, val accuracy: 0.9472206523126288\n",
|
|
"Train accuracy: 0.9993303449406877\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:18<00:00, 6.85it/s]\n",
|
|
"100%|██████████| 215/215 [00:03<00:00, 66.03it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8375809935205184, recall: 0.6778140293637847, f1: 0.7492754556578862, val accuracy: 0.9455980747844159\n",
|
|
"Train accuracy: 0.9991891251662749\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:07<00:00, 7.39it/s]\n",
|
|
"100%|██████████| 215/215 [00:03<00:00, 62.19it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8413109098749461, recall: 0.6820088557445817, f1: 0.7533303301370746, val accuracy: 0.9462908606953383\n",
|
|
"Train accuracy: 0.9991435704003353\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:13<00:00, 7.08it/s]\n",
|
|
"100%|██████████| 215/215 [00:03<00:00, 54.46it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8479315263908702, recall: 0.6926124446515963, f1: 0.7624422780913289, val accuracy: 0.9478769758071868\n",
|
|
"Train accuracy: 0.998583246779278\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:11<00:00, 7.20it/s]\n",
|
|
"100%|██████████| 215/215 [00:03<00:00, 57.36it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8470877294406706, recall: 0.6829410393847588, f1: 0.7562092768208503, val accuracy: 0.9471294962717179\n",
|
|
"Train accuracy: 0.999180014213087\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:15<00:00, 6.99it/s]\n",
|
|
"100%|██████████| 215/215 [00:04<00:00, 45.79it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8728230645397337, recall: 0.6949429037520392, f1: 0.7737917612714889, val accuracy: 0.9498824087072251\n",
|
|
"Train accuracy: 0.9993212339874997\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:38<00:00, 5.98it/s]\n",
|
|
"100%|██████████| 215/215 [00:07<00:00, 28.22it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8691318792431028, recall: 0.7011186203682125, f1: 0.7761367300870687, val accuracy: 0.9505934258263296\n",
|
|
"Train accuracy: 0.9996310063958891\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 945/945 [02:03<00:00, 7.62it/s]\n",
|
|
"100%|██████████| 215/215 [00:02<00:00, 77.13it/s] \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"precision: 0.8701146047605054, recall: 0.6900489396411092, f1: 0.7696906680530282, val accuracy: 0.949116697963574\n",
|
|
"Train accuracy: 0.9997540042639261\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for epoch in range(EPOCHS):\n",
"    model.train()\n",
"    train_true = []\n",
"    train_pred = []\n",
"    for tokens, labels in tqdm(zip(train_tokens_ids, train_labels_ids), total=len(train_labels_ids)):\n",
"        # One forward pass per document: the same logits feed both the loss and the\n",
"        # running train-accuracy predictions.\n",
"        logits = model(tokens.unsqueeze(0)).squeeze(0)  # (seq_len, NUM_TAGS)\n",
"        loss = criterion(logits, labels)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        train_true.extend(labels.tolist())\n",
"        train_pred.extend(logits.argmax(dim=1).tolist())\n",
"\n",
"    model.eval()\n",
"    eval_model(val_tokens_ids, val_labels_ids, model)\n",
"    print(f'Train accuracy: {accuracy_score(train_true, train_pred)}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 67,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def save_prediction(test_tokens, test_pred, file_name, tags=None):\n",
"    \"\"\"Write predictions as token<TAB>tag lines, with a blank line after each document.\n",
"\n",
"    Args:\n",
"        test_tokens: list of token sequences.\n",
"        test_pred: per-document lists of predicted tag ids, aligned with test_tokens.\n",
"        file_name: output path (overwritten).\n",
"        tags: tag -> id mapping whose insertion order matches the ids; defaults to num_tags.\n",
"    \"\"\"\n",
"    tag_names = list(num_tags if tags is None else tags)\n",
"    with open(file_name, 'w', encoding='utf-8') as f:\n",
"        for tokens, preds in zip(test_tokens, test_pred):\n",
"            for token, pred in zip(tokens, preds):\n",
"                f.write(f'{token}\\t{tag_names[pred]}\\n')\n",
"            f.write('\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 70,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Predict tags for every test document (inference only: eval mode + no_grad).\n",
"model.eval()\n",
"test_pred = []\n",
"with torch.no_grad():\n",
"    for tokens in test_tokens_ids:\n",
"        logits = model(tokens.unsqueeze(0)).squeeze(0)\n",
"        # Drop the <bos>/<eos> positions added by data_process so predictions align with input tokens.\n",
"        test_pred.append(logits.argmax(dim=1)[1:-1].tolist())"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"# One line per input document with space-separated tags (same layout as dev-0/expected.tsv).\n",
"tag_names = list(num_tags)\n",
"with open('test-A/out.tsv', 'w', encoding='utf-8') as f:\n",
"    for doc_pred in test_pred:\n",
"        f.write(' '.join(tag_names[idx] for idx in doc_pred) + '\\n')\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "dl",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|