Compare commits

30 Commits

| SHA1 |
|---|
| 6d075df43a |
| 2d9ff2e1f3 |
| 0afbf6c8c6 |
| ab758851de |
| af823aa3c9 |
| 92da4991fe |
| b137f9b6e5 |
| e5af613eb7 |
| 7c351de460 |
| 4d4ef0e385 |
| 2ed6fcbddb |
| 0872ed69df |
| 971ddd924a |
| 92a6d4700c |
| 9b31cf7b6e |
| 5b68a15361 |
| 00fd68f8f1 |
| 12faadda76 |
| d0d1459d8d |
| cf2a6a1363 |
| 764fb1201f |
| 9461f6f674 |
| 1b8e25c9b3 |
| 6ac706db11 |
| 18b4fcc89d |
| 5a8e513cd8 |
| d2d4d75e95 |
| aa676c091a |
| 93f784284f |
| efd5c0ee22 |
dev-0/out.tsv (21038 lines): file diff suppressed because it is too large.

gonito.yaml (new file, 7 lines):

@@ -0,0 +1,7 @@
description: s478873, left context
tags:
  - neural-network
  - gpt2
params:
  top_k: 30
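For reference, gonito.yaml is the submission metadata file read by the Gonito evaluation platform: a free-text description, tags, and arbitrary params (here top_k: 30). A minimal sketch of loading it, assuming PyYAML is installed; the variable names are illustrative:

```python
# Minimal sketch: parse the submission metadata above with PyYAML (assumed available).
import yaml

with open("gonito.yaml", encoding="utf-8") as fh:
    cfg = yaml.safe_load(fh)

print(cfg["description"])      # 's478873, left context'
print(cfg["tags"])             # ['neural-network', 'gpt2']
print(cfg["params"]["top_k"])  # 30
```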
test-A/out.tsv (new file, 7414 lines): file diff suppressed because it is too large.

tetragrams_final.ipynb (new file, 351 lines):

@@ -0,0 +1,351 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lzma\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "from collections import Counter, defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_text(line: str):\n",
    "    separated = line.split('\\t')\n",
    "    prefix = separated[6].replace(r'\\n', ' ').replace('  ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')\n",
    "    suffix = separated[7].replace(r'\\n', ' ').replace('  ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')\n",
    "    return prefix + ' ' + suffix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unigrams(filename):\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        with tqdm(total=432022) as pbar:\n",
    "            for line in fid:\n",
    "                text = clean_text(line)\n",
    "                for word in text.split():\n",
    "                    yield word\n",
    "                pbar.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bigrams(filename, V: dict):\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        pbar = tqdm(total=432022)\n",
    "        first_word = ''\n",
    "        for line in fid:\n",
    "            text = clean_text(line)\n",
    "            for second_word in text.split():\n",
    "                if V.get(second_word) is None:\n",
    "                    second_word = 'UNK'\n",
    "                if second_word:\n",
    "                    yield first_word, second_word\n",
    "                first_word = second_word\n",
    "            pbar.update(1)\n",
    "        pbar.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trigrams(filename, V: dict):\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        print('Trigrams')\n",
    "        for line in tqdm(fid, total=432022):\n",
    "            text = clean_text(line)\n",
    "            words = text.split()\n",
    "            for i in range(len(words)-2):\n",
    "                # keep the word itself when it is in the vocabulary, otherwise map it to 'UNK'\n",
    "                trigram = tuple(word if word in V else 'UNK' for word in words[i:i+3])\n",
    "                yield trigram\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tetragrams(filename, V: dict):\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        print('Tetragrams')\n",
    "        for i, line in enumerate(tqdm(fid, total=432022)):\n",
    "            text = clean_text(line)\n",
    "            words = [word if word in V else 'UNK' for word in text.split()]\n",
    "            for first_word, second_word, third_word, fourth_word in zip(words, words[1:], words[2:], words[3:]):\n",
    "                yield first_word, second_word, third_word, fourth_word\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def P(first_word, second_word=None, third_word=None, fourth_word=None):\n",
    "    # Conditional n-gram probability of the last argument given the preceding ones;\n",
    "    # returns 0 for an unseen context instead of dividing by zero.\n",
    "    if second_word is None:\n",
    "        return V_common_dict.get(first_word, 0) / total\n",
    "    elif third_word is None:\n",
    "        denom = V_common_dict.get(first_word, 0)\n",
    "        return V2_bigrams_dict.get((first_word, second_word), 0) / denom if denom else 0\n",
    "    elif fourth_word is None:\n",
    "        denom = V2_bigrams_dict.get((first_word, second_word), 0)\n",
    "        return V3_trigrams_dict.get((first_word, second_word, third_word), 0) / denom if denom else 0\n",
    "    else:\n",
    "        denom = V3_trigrams_dict.get((first_word, second_word, third_word), 0)\n",
    "        return V4_tetragrams_dict.get((first_word, second_word, third_word, fourth_word), 0) / denom if denom else 0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_tetragram_probability(tetragram):\n",
    "    return 0.5 * P(*tetragram) + 0.35 * P(tetragram[1], *tetragram[2:]) + \\\n",
    "        0.1 * P(tetragram[2], *tetragram[3:]) + 0.05 * P(tetragram[3])\n",
    "\n",
    "def get_context(position, sentence):\n",
    "    context = []\n",
    "    for i in range(position-3, position):\n",
    "        if i < 0:\n",
    "            context.append('')\n",
    "        else:\n",
    "            context.append(sentence[i])\n",
    "    for i in range(position, position+4):\n",
    "        if i >= len(sentence):\n",
    "            context.append('')\n",
    "        else:\n",
    "            context.append(sentence[i])\n",
    "    return context\n",
    "\n",
    "def compute_candidates(left_context, right_context):\n",
    "    candidate_probabilities = {}\n",
    "    for word in V_common_dict:\n",
    "        # score each candidate from the three left-context words ('left context', cf. gonito.yaml)\n",
    "        tetragram = left_context[-3:] + [word]\n",
    "        probability = compute_tetragram_probability(tetragram)\n",
    "        candidate_probabilities[word] = probability\n",
    "    sorted_candidates = sorted(candidate_probabilities.items(), key=lambda x: x[1], reverse=True)[:5]\n",
    "    total_probability = sum([c[1] for c in sorted_candidates])\n",
    "    normalized_candidates = [(c[0], c[1] / total_probability) for c in sorted_candidates]\n",
    "    for index, elem in enumerate(normalized_candidates):\n",
    "        if 'UNK' in elem:\n",
    "            normalized_candidates.pop(index)\n",
    "            normalized_candidates.append(('', elem[1]))\n",
    "            break\n",
    "    else:\n",
    "        normalized_candidates[-1] = ('', normalized_candidates[-1][1])\n",
    "    return ' '.join([f'{x[0]}:{x[1]}' for x in normalized_candidates])\n",
    "\n",
    "def candidates(left_context, right_context):\n",
    "    left_context = [w if V_common_dict.get(w) else 'UNK' for w in left_context]\n",
    "    right_context = [w if V_common_dict.get(w) else 'UNK' for w in right_context]\n",
    "    return compute_candidates(left_context, right_context)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_vocab(filename, word_limit):\n",
    "    V = Counter(unigrams(filename))\n",
    "    V_common = V.most_common(word_limit)\n",
    "    V_common_dict = dict(V_common)\n",
    "    UNK = sum(v for k, v in V.items() if k not in V_common_dict)\n",
    "    V_common_dict['UNK'] = UNK\n",
    "    V_common_tuple = tuple((k, v) for k, v in V_common_dict.items())\n",
    "    with open('V.pickle', 'wb') as handle:\n",
    "        pickle.dump(V_common_tuple, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    return V_common_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pickle(filename):\n",
    "    with open(filename, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "def save_pickle(obj, filename):\n",
    "    with open(filename, 'wb') as handle:\n",
    "        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_vocab('train/in.tsv.xz', 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('V.pickle', 'rb') as handle:\n",
    "    # V.pickle stores a tuple of (word, count) pairs; convert it back to a dict\n",
    "    V_common_dict = dict(pickle.load(handle))\n",
    "\n",
    "total = sum(V_common_dict.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('V.pickle', 'rb') as handle:\n",
    "    V_common_tuple = pickle.load(handle)\n",
    "\n",
    "V_common_dict = dict(V_common_tuple)\n",
    "\n",
    "total = sum(V_common_dict.values())\n",
    "\n",
    "V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))\n",
    "V2_bigrams_dict = dict(V2_bigrams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')\n",
    "\n",
    "V2_bigrams_dict = load_pickle('V2_bigrams.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))\n",
    "V2_bigrams_dict = dict(V2_bigrams)\n",
    "save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')\n",
    "\n",
    "V2_bigrams_dict = load_pickle('V2_bigrams.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V3_trigrams = Counter(trigrams('train/in.tsv.xz', V_common_dict))\n",
    "V3_trigrams_dict = dict(V3_trigrams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_pickle(V3_trigrams_dict, 'V3_trigrams.pickle')\n",
    "V3_trigrams_dict = load_pickle('V3_trigrams.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V4_tetragrams = Counter(tetragrams('train/in.tsv.xz', V_common_dict))\n",
    "V4_tetragrams_dict = dict(V4_tetragrams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_pickle(V4_tetragrams_dict, 'V4_tetragrams.pickle')\n",
    "V4_tetragrams_dict = load_pickle('V4_tetragrams.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_outs(folder_name):\n",
    "    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
    "        with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:\n",
    "            for line in tqdm(fid):\n",
    "                separated = line.split('\\t')\n",
    "                prefix = separated[6].replace(r'\\n', ' ').replace('  ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').split()\n",
    "                suffix = separated[7].replace(r'\\n', ' ').replace('  ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').split()\n",
    "                left_context = [x if V_common_dict.get(x) else 'UNK' for x in prefix[-3:]]\n",
    "                right_context = [x if V_common_dict.get(x) else 'UNK' for x in suffix[:3]]\n",
    "                w = candidates(left_context, right_context)\n",
    "                f.write(w + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_outs('dev-0')\n",
    "save_outs('test-A')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
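The heart of the notebook above is compute_tetragram_probability, which linearly interpolates the 4-gram, trigram, bigram, and unigram estimates with weights 0.5, 0.35, 0.1, and 0.05, so contexts never seen at a higher order still receive probability mass from the lower orders. A self-contained sketch of the same interpolation idea on a toy corpus (the corpus and the lambda weight below are illustrative, not the notebook's values):

```python
# Toy illustration of linear interpolation between bigram and unigram estimates,
# the same back-off idea compute_tetragram_probability applies across four orders.
from collections import Counter

corpus = "the cat sat on the mat the cat ate".split()
unigram_counts = Counter(corpus)
bigram_counts = Counter(zip(corpus, corpus[1:]))
total = sum(unigram_counts.values())

def p_unigram(w):
    return unigram_counts[w] / total

def p_bigram(w1, w2):
    denom = unigram_counts[w1]
    return bigram_counts[(w1, w2)] / denom if denom else 0.0

def p_interpolated(w1, w2, lam=0.7):
    # The bigram estimate gets most of the weight; the unigram term keeps the
    # probability nonzero for any in-vocabulary word after an unseen context.
    return lam * p_bigram(w1, w2) + (1 - lam) * p_unigram(w2)

print(p_interpolated("the", "cat"))  # seen bigram: dominated by the bigram estimate
print(p_interpolated("the", "ate"))  # unseen bigram: survives via the unigram term
```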
zad7neural_networks.ipynb (new file, 1211 lines): file diff suppressed because it is too large.

zad8_trigrams_nn.ipynb (new file, 328 lines):

@@ -0,0 +1,328 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import lzma\n",
    "from itertools import islice\n",
    "import re\n",
    "import sys\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "from torch import nn\n",
    "from torch.utils.data import IterableDataset, DataLoader\n",
    "import itertools\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 2_000\n",
    "EMBED_SIZE = 500"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_words_from_line(line):\n",
    "    line = line.rstrip()\n",
    "    line = line.split(\"\\t\")\n",
    "    text = line[-2] + \" \" + line[-1]\n",
    "    text = re.sub(r\"\\\\+n\", \" \", text)\n",
    "    text = re.sub('[^A-Za-z ]+', '', text)\n",
    "    for t in text.split():\n",
    "        yield t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_word_lines_from_file(file_name):\n",
    "    with lzma.open(file_name, encoding='utf8', mode=\"rt\") as fh:\n",
    "        for line in fh:\n",
    "            yield get_words_from_line(line)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def look_ahead_iterator(gen):\n",
    "    first = None\n",
    "    second = None\n",
    "    for item in gen:\n",
    "        if first is not None and second is not None:\n",
    "            yield ((first, item), second)\n",
    "        first = second\n",
    "        second = item"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = build_vocab_from_iterator(\n",
    "    get_word_lines_from_file(\"train/in.tsv.xz\"),\n",
    "    max_tokens = VOCAB_SIZE,\n",
    "    specials = ['<unk>'])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trigram class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trigrams(IterableDataset):\n",
    "    def __init__(self, text_file, vocabulary_size):\n",
    "        self.vocab = vocab\n",
    "        self.vocab.set_default_index(self.vocab['<unk>'])\n",
    "        self.vocabulary_size = VOCAB_SIZE\n",
    "        self.text_file = text_file\n",
    "\n",
    "    def __iter__(self):\n",
    "        return look_ahead_iterator(\n",
    "            (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
    "\n",
    "train_dataset = Trigrams(\"train/in.tsv.xz\", VOCAB_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrigramNNModel(nn.Module):\n",
    "    def __init__(self, VOCAB_SIZE, EMBED_SIZE):\n",
    "        super(TrigramNNModel, self).__init__()\n",
    "        self.embeddings = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)\n",
    "        self.hidden_layer = nn.Linear(EMBED_SIZE*2, 1200)\n",
    "        self.output_layer = nn.Linear(1200, VOCAB_SIZE)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        emb_2 = self.embeddings(x[0])\n",
    "        emb_1 = self.embeddings(x[1])\n",
    "        x = torch.cat([emb_2, emb_1], dim=1)\n",
    "        x = self.hidden_layer(x)\n",
    "        x = self.output_layer(x)\n",
    "        x = self.softmax(x)\n",
    "        return x\n",
    "\n",
    "model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE)\n",
    "\n",
    "vocab.set_default_index(vocab['<unk>'])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu'\n",
    "model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
    "data = DataLoader(train_dataset, batch_size=1_000)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "criterion = torch.nn.NLLLoss()\n",
    "\n",
    "loss_track = []\n",
    "last_loss = 1_000\n",
    "trigger_count = 0\n",
    "\n",
    "model.train()\n",
    "step = 0\n",
    "for x, y in data:\n",
    "    x[0] = x[0].to(device)\n",
    "    x[1] = x[1].to(device)\n",
    "    y = y.to(device)\n",
    "    optimizer.zero_grad()\n",
    "    ypredicted = model(x)\n",
    "    loss = criterion(torch.log(ypredicted), y)\n",
    "    if step % 100 == 0:\n",
    "        print(step, loss)\n",
    "    step += 1\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if loss > last_loss:\n",
    "        trigger_count += 1\n",
    "        print(trigger_count, 'LOSS DIFF:', loss, last_loss)\n",
    "\n",
    "    if trigger_count >= 500:\n",
    "        break\n",
    "\n",
    "    # store plain floats, not tensors, so the autograd graphs can be freed\n",
    "    loss_track.append(loss.item())\n",
    "    last_loss = loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), f'model_trigram-EMBED_SIZE={EMBED_SIZE}.bin')\n",
    "vocab_unique = set(vocab.get_stoi().keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = []\n",
    "pattern = re.compile('[^A-Za-z]+')\n",
    "\n",
    "with lzma.open(\"dev-0/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n",
    "    for line in file:\n",
    "        line = line.split(\"\\t\")\n",
    "        first_word = pattern.sub(' ', line[-2]).split()[-1]\n",
    "        second_word = pattern.sub(' ', line[-1]).split(maxsplit=1)[0]\n",
    "\n",
    "        first_word = re.sub('[^A-Za-z]+', '', first_word)\n",
    "        second_word = re.sub('[^A-Za-z]+', '', second_word)\n",
    "\n",
    "        first_word = \"<unk>\" if first_word not in vocab_unique else first_word\n",
    "        second_word = \"<unk>\" if second_word not in vocab_unique else second_word\n",
    "\n",
    "        input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n",
    "        out = model(input_tokens)\n",
    "\n",
    "        top = torch.topk(out[0], 10)\n",
    "        top_indices = top.indices.tolist()\n",
    "        top_probs = top.values.tolist()\n",
    "        unk_bonus = 1 - sum(top_probs)\n",
    "        top_words = vocab.lookup_tokens(top_indices)\n",
    "        top_zipped = list(zip(top_words, top_probs))\n",
    "\n",
    "        res = \" \".join([f\"{w}:{p:.4f}\" if w != \"<unk>\" else f\":{(p + unk_bonus):.4f}\" for w, p in top_zipped])\n",
    "        res += \"\\n\"\n",
    "        output.append(res)\n",
    "\n",
    "with open(f\"dev-0/out-EMBED_SIZE={EMBED_SIZE}.tsv\", mode=\"w\") as file:\n",
    "    file.writelines(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = []\n",
    "pattern = re.compile('[^A-Za-z]+')\n",
    "\n",
    "with lzma.open(\"test-A/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n",
    "    for line in file:\n",
    "        line = line.split(\"\\t\")\n",
    "        first_word = pattern.sub(' ', line[-2]).split()[-1]\n",
    "        second_word = pattern.sub(' ', line[-1]).split(maxsplit=1)[0]\n",
    "\n",
    "        first_word = re.sub('[^A-Za-z]+', '', first_word)\n",
    "        second_word = re.sub('[^A-Za-z]+', '', second_word)\n",
    "\n",
    "        first_word = \"<unk>\" if first_word not in vocab_unique else first_word\n",
    "        second_word = \"<unk>\" if second_word not in vocab_unique else second_word\n",
    "\n",
    "        input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n",
    "        out = model(input_tokens)\n",
    "\n",
    "        top = torch.topk(out[0], 10)\n",
    "        top_indices = top.indices.tolist()\n",
    "        top_probs = top.values.tolist()\n",
    "        unk_bonus = 1 - sum(top_probs)\n",
    "        top_words = vocab.lookup_tokens(top_indices)\n",
    "        top_zipped = list(zip(top_words, top_probs))\n",
    "\n",
    "        res = \" \".join([f\"{w}:{p:.4f}\" if w != \"<unk>\" else f\":{(p + unk_bonus):.4f}\" for w, p in top_zipped])\n",
    "        res += \"\\n\"\n",
    "        output.append(res)\n",
    "\n",
    "with open(f\"test-A/out-EMBED_SIZE={EMBED_SIZE}.tsv\", mode=\"w\") as file:\n",
    "    file.writelines(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
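In this second notebook, look_ahead_iterator is what reshapes the token stream into training pairs: for every trigram (a, b, c) it yields ((a, c), b), so the network learns to predict the middle word from its two neighbours. That matches the inference cells, which feed the last word of the left context and the first word of the right context. A quick standalone check of that pairing (plain Python, no torch required):

```python
# Standalone check of the pairing produced by look_ahead_iterator above:
# for each trigram (a, b, c) in the stream it yields ((a, c), b).
def look_ahead_iterator(gen):
    first = None
    second = None
    for item in gen:
        if first is not None and second is not None:
            yield ((first, item), second)
        first = second
        second = item

tokens = iter(["the", "cat", "sat", "on", "mat"])
print(list(look_ahead_iterator(tokens)))
# [(('the', 'sat'), 'cat'), (('cat', 'on'), 'sat'), (('sat', 'mat'), 'on')]
```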