n-gram model
This commit is contained in:
parent
4607559b8a
commit
4470830adf
10519
dev-0/in.tsv
Normal file
10519
dev-0/in.tsv
Normal file
File diff suppressed because it is too large
Load Diff
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
769
src/04_statystyczny_model_językowy.ipynb
Normal file
769
src/04_statystyczny_model_językowy.ipynb
Normal file
@ -0,0 +1,769 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"<h1> Modelowanie języka </h1>\n",
|
||||||
|
"<h2> 4. <i>Statystyczny model językowy</i> [ćwiczenia]</h2> "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"NR_INDEKSU = 452629"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"https://web.stanford.edu/~jurafsky/slp3/3.pdf"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from collections import Counter, defaultdict\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"import re\n",
|
||||||
|
"import nltk\n",
|
||||||
|
"import math\n",
|
||||||
|
"import random"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class Model():\n",
|
||||||
|
" \n",
|
||||||
|
" def __init__(self, vocab_size = 30_000, UNK_token = '<UNK>', n = 2):\n",
|
||||||
|
" self.n = n\n",
|
||||||
|
" self.vocab_size = vocab_size\n",
|
||||||
|
" self.UNK_token = UNK_token\n",
|
||||||
|
" self.ngrams = defaultdict(lambda: defaultdict(int))\n",
|
||||||
|
" self.contexts = defaultdict(int)\n",
|
||||||
|
" self.vocab = set()\n",
|
||||||
|
" \n",
|
||||||
|
" def train(self, corpus: list) -> None:\n",
|
||||||
|
" self.vocab = set()\n",
|
||||||
|
" self.vocab.add(self.UNK_token)\n",
|
||||||
|
"\n",
|
||||||
|
" counts = Counter(corpus)\n",
|
||||||
|
" most_common = counts.most_common(self.vocab_size - 1)\n",
|
||||||
|
" for word, _ in most_common:\n",
|
||||||
|
" self.vocab.add(word)\n",
|
||||||
|
"\n",
|
||||||
|
" corpus = [word if word in self.vocab else self.UNK_token for word in corpus]\n",
|
||||||
|
"\n",
|
||||||
|
" n_grams = list(nltk.ngrams(corpus, self.n))\n",
|
||||||
|
" for gram in tqdm(n_grams):\n",
|
||||||
|
" context = gram[:-1]\n",
|
||||||
|
" word = gram[-1]\n",
|
||||||
|
"\n",
|
||||||
|
" if word == self.UNK_token:\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" self.ngrams[context][word] += 1\n",
|
||||||
|
" self.contexts[context] += 1\n",
|
||||||
|
" \n",
|
||||||
|
" def get_conditional_prob_for_word(self, text: list, word: str) -> float:\n",
|
||||||
|
" if len(text) < self.n - 1:\n",
|
||||||
|
" raise ValueError(\"Text is too short for the given n-gram order.\")\n",
|
||||||
|
" \n",
|
||||||
|
" context = tuple(text[-self.n + 1:])\n",
|
||||||
|
" if context not in self.ngrams:\n",
|
||||||
|
" return 0.0\n",
|
||||||
|
" \n",
|
||||||
|
" total_count = sum(self.ngrams[context].values())\n",
|
||||||
|
" word_count = self.ngrams[context][word]\n",
|
||||||
|
" \n",
|
||||||
|
" if total_count == 0:\n",
|
||||||
|
" return 0.0\n",
|
||||||
|
" else:\n",
|
||||||
|
" return word_count / total_count\n",
|
||||||
|
" \n",
|
||||||
|
" def get_prob_for_text(self, text: list) -> float:\n",
|
||||||
|
" if len(text) < self.n - 1:\n",
|
||||||
|
" raise ValueError(\"Text is too short for the given n-gram order.\")\n",
|
||||||
|
" \n",
|
||||||
|
" prob = 1.0\n",
|
||||||
|
" n_grams = list(nltk.ngrams(text, self.n))\n",
|
||||||
|
" for gram in n_grams:\n",
|
||||||
|
" context = gram[:-1]\n",
|
||||||
|
" word = gram[-1]\n",
|
||||||
|
" prob *= self.get_conditional_prob_for_word(context, word)\n",
|
||||||
|
" \n",
|
||||||
|
" return prob\n",
|
||||||
|
" \n",
|
||||||
|
" def most_probable_next_word(self, text: list) -> str:\n",
|
||||||
|
" '''nie powinien zwracań nigdy <UNK>'''\n",
|
||||||
|
" if len(text) < self.n - 1:\n",
|
||||||
|
" raise ValueError(\"Text is too short for the given n-gram order.\")\n",
|
||||||
|
" \n",
|
||||||
|
" context = tuple(text[-self.n+1:])\n",
|
||||||
|
" if context not in self.ngrams:\n",
|
||||||
|
" return \"\"\n",
|
||||||
|
" \n",
|
||||||
|
" most_probable_word = max(self.ngrams[context], key=self.ngrams[context].get)\n",
|
||||||
|
" return most_probable_word\n",
|
||||||
|
" \n",
|
||||||
|
" def generate_text(self, text_beginning: list, length: int, greedy: bool) -> list:\n",
|
||||||
|
" '''nie powinien zwracań nigdy <UNK>'''\n",
|
||||||
|
" if len(text_beginning) < self.n - 1:\n",
|
||||||
|
" raise ValueError(\"Text beginning is too short for the given n-gram order.\")\n",
|
||||||
|
" \n",
|
||||||
|
" text_beginning = [word if word in self.vocab else self.UNK_token for word in text_beginning]\n",
|
||||||
|
" \n",
|
||||||
|
" generated_text = text_beginning[:]\n",
|
||||||
|
" while len(generated_text) < length:\n",
|
||||||
|
" if self.n == 1:\n",
|
||||||
|
" context = ()\n",
|
||||||
|
" else:\n",
|
||||||
|
" context = tuple(generated_text[-self.n+1:])\n",
|
||||||
|
" if greedy:\n",
|
||||||
|
" next_word = self.most_probable_next_word(context)\n",
|
||||||
|
" else:\n",
|
||||||
|
" candidate_words = list(self.ngrams[context].keys())\n",
|
||||||
|
" probabilities = [self.get_prob_for_text(generated_text + [word]) for word in candidate_words]\n",
|
||||||
|
" next_word = random.choices(candidate_words, weights=probabilities)[0]\n",
|
||||||
|
" \n",
|
||||||
|
" if next_word == self.UNK_token:\n",
|
||||||
|
" break\n",
|
||||||
|
" generated_text.append(next_word)\n",
|
||||||
|
" \n",
|
||||||
|
" return generated_text\n",
|
||||||
|
"\n",
|
||||||
|
" def get_perplexity(self, text: list) -> float:\n",
|
||||||
|
" if len(text) < self.n - 1:\n",
|
||||||
|
" raise ValueError(\"Text is too short for the given n-gram order.\")\n",
|
||||||
|
" \n",
|
||||||
|
" log_prob = 0.0\n",
|
||||||
|
" N = 0\n",
|
||||||
|
" for i in range(len(text) - self.n + 1):\n",
|
||||||
|
" context = text[i:i + self.n - 1]\n",
|
||||||
|
" word = text[i + self.n - 1]\n",
|
||||||
|
" prob = self.get_prob_for_text(context + [word])\n",
|
||||||
|
" if prob == 0.0:\n",
|
||||||
|
" return float('inf')\n",
|
||||||
|
" else:\n",
|
||||||
|
" log_prob += math.log2(self.get_prob_for_text(context + [word]))\n",
|
||||||
|
" N += 1\n",
|
||||||
|
" \n",
|
||||||
|
" if N == 0:\n",
|
||||||
|
" return float('inf')\n",
|
||||||
|
"\n",
|
||||||
|
" avg_log_prob = log_prob / N\n",
|
||||||
|
" perplexity = 2 ** (-avg_log_prob)\n",
|
||||||
|
" return perplexity"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Zadanie (60 punktów)\n",
|
||||||
|
"\n",
|
||||||
|
"- Wybierz tekst w dowolnym języku (10 000 000 słów).\n",
|
||||||
|
"- Podziel zbiór na train/test w proporcji 9:1.\n",
|
||||||
|
"- Stwórz unigramowy model językowy.\n",
|
||||||
|
"- Stwórz bigramowy model językowy.\n",
|
||||||
|
"- Stwórz trigramowy model językowy.\n",
|
||||||
|
"- Wymyśl 5 krótkich zdań. Dla każdego oblicz jego prawdopodobieństwo.\n",
|
||||||
|
"- Napisz włąsnoręcznie funkcję, która liczy perplexity na korpusie i policz perplexity na każdym z modeli dla podzbiorów train i test.\n",
|
||||||
|
"- Wygeneruj tekst, zaczynając od wymyślonych 5 początków. Postaraj się, żeby dla obu funkcji, a przynajmniej dla `high_probable_next_word`, teksty były orginalne.\n",
|
||||||
|
"- Stwórz model dla korpusu z ZADANIE 1 i policz perplexity dla każdego z tekstów (zrób split 9:1) dla train i test.\n",
|
||||||
|
"\n",
|
||||||
|
"Dodatkowo:\n",
|
||||||
|
"- Dokonaj klasyfikacji za pomocą modelu językowego.\n",
|
||||||
|
" - Znajdź duży zbiór danych dla klasyfikacji binarnej, wytrenuj osobne modele dla każdej z klas i użyj dla klasyfikacji.\n",
|
||||||
|
"- Zastosuj wygładzanie metodą Laplace'a."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### START ZADANIA"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Podział korpusu na train/test"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train size: 90.004%\n",
|
||||||
|
"Test size: 9.996%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"corpus = re.split(r'\\s+', open(\"04_materialy/pan-tadeusz.txt\", encoding=\"UTF-8\").read())\n",
|
||||||
|
"\n",
|
||||||
|
"split_index = int(len(corpus) * 0.9)\n",
|
||||||
|
"\n",
|
||||||
|
"while corpus[split_index].endswith(('.', '?', '!')) == False:\n",
|
||||||
|
" split_index += 1\n",
|
||||||
|
"split_index += 1\n",
|
||||||
|
"\n",
|
||||||
|
"train = corpus[:split_index]\n",
|
||||||
|
"test = corpus[split_index:]\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Train size: {len(train)/len(corpus)*100:.3f}%\")\n",
|
||||||
|
"print(f\"Test size: {len(test)/len(corpus)*100:.3f}%\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Unigramowy model języka"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training unigram model...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 62189/62189 [00:00<00:00, 1066841.60it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Training unigram model...\")\n",
|
||||||
|
"unigram_model = Model(vocab_size = 300_000, n = 1)\n",
|
||||||
|
"unigram_model.train(train)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Generating text with unigram model... (greedy)\n",
|
||||||
|
"Śród takich pól przed laty, w w w w w w w w w w w w w w w\n",
|
||||||
|
"Generating text with unigram model... (non-greedy)\n",
|
||||||
|
"Śród takich pól przed laty, dworskiej od\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Generating text with unigram model... (greedy)\")\n",
|
||||||
|
"text = unigram_model.generate_text(re.split(r'\\s+', 'Śród takich pól przed laty,'), 20, greedy = True)\n",
|
||||||
|
"print(' '.join(text))\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Generating text with unigram model... (non-greedy)\")\n",
|
||||||
|
"text = unigram_model.generate_text(re.split(r'\\s+', 'Śród takich pól przed laty,'), 7, greedy = False)\n",
|
||||||
|
"print(' '.join(text))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Bigramowy model języka"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training bigram model...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/62188 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 62188/62188 [00:00<00:00, 714486.32it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Training bigram model...\")\n",
|
||||||
|
"bigram_model = Model(vocab_size = 300_000, n = 2)\n",
|
||||||
|
"bigram_model.train(train)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Generating text with bigram model... (greedy)\n",
|
||||||
|
"Śród takich pól przed laty, nad nim się w tym łacniej w tym łacniej w tym łacniej w tym łacniej\n",
|
||||||
|
"Generating text with bigram model... (non-greedy)\n",
|
||||||
|
"Śród takich pól przed laty, nad Woźnego lepiej niedźwiedź kości; Pójdź, księże, w sądy podkomorskie. Dotąd mej Birbante-rokka: Oby ten\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Generating text with bigram model... (greedy)\")\n",
|
||||||
|
"text = bigram_model.generate_text(re.split(r'\\s+', 'Śród takich pól przed laty,'), 20, greedy = True)\n",
|
||||||
|
"print(' '.join(text))\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Generating text with bigram model... (non-greedy)\")\n",
|
||||||
|
"text = bigram_model.generate_text(re.split(r'\\s+', 'Śród takich pól przed laty,'), 20, greedy = False)\n",
|
||||||
|
"print(' '.join(text))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Trigramowy model języka"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training trigram model...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 62187/62187 [00:00<00:00, 295370.31it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Training trigram model...\")\n",
|
||||||
|
"trigram_model = Model(vocab_size = 300_000, n = 3)\n",
|
||||||
|
"trigram_model.train(train)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Generating text with trigram model... (greedy)\n",
|
||||||
|
"Śród takich pól przed laty, nad brzegiem ruczaju, Na pagórku niewielkim, we brzozowym gaju, Stał dwór szlachecki, z drzewa, lecz podmurowany; Świeciły się z nim na miejscu pustym oczy swe osadzał. Dziwna rzecz! miejsca wkoło są siedzeniem dziewic, Na które\n",
|
||||||
|
"Generating text with trigram model... (non-greedy)\n",
|
||||||
|
"Śród takich pól przed laty, nad brzegiem ruczaju, Na pagórku niewielkim, we brzozowym gaju, Stał dwór szlachecki, z drzewa, gotyckiej naśladowstwo sztuki. Z wierzchu ozdoby sztuczne, nie rylcem, nie dłutem, Ale zręcznie ciesielskim wyrzezane sklutem, Krzywe jak szabasowych ramiona świeczników;\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Generating text with trigram model... (greedy)\")\n",
|
||||||
|
"text = trigram_model.generate_text(re.split(r'\\s+', 'Śród takich pól przed laty,'), 40, greedy = True)\n",
|
||||||
|
"print(' '.join(text))\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Generating text with trigram model... (non-greedy)\")\n",
|
||||||
|
"text = trigram_model.generate_text(re.split(r'\\s+', 'Śród takich pól przed laty,'), 40, greedy = False)\n",
|
||||||
|
"print(' '.join(text))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Wymyśl 5 krótkich zdań. Dla każdego oblicz jego prawdopodobieństwo."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Sentence: Nikt go na polowanie\n",
|
||||||
|
"Unigram model: 0.0000000000\n",
|
||||||
|
"Bigram model: 0.0000027142\n",
|
||||||
|
"Trigram model: 0.2000000000\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Podróżny długo w oknie stał\n",
|
||||||
|
"Unigram model: 0.0000000000\n",
|
||||||
|
"Bigram model: 0.0000124784\n",
|
||||||
|
"Trigram model: 0.3333333333\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Rzekł z uśmiechem,\n",
|
||||||
|
"Unigram model: 0.0000000001\n",
|
||||||
|
"Bigram model: 0.0000521023\n",
|
||||||
|
"Trigram model: 1.0000000000\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Pod płotem wąskie, długie, wypukłe pagórki,\n",
|
||||||
|
"Unigram model: 0.0000000000\n",
|
||||||
|
"Bigram model: 0.0192307692\n",
|
||||||
|
"Trigram model: 1.0000000000\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Hrabia oczy roztworzył.\n",
|
||||||
|
"Unigram model: 0.0000000000\n",
|
||||||
|
"Bigram model: 0.0004479283\n",
|
||||||
|
"Trigram model: 0.5000000000\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"sentences = [\n",
|
||||||
|
" \"Nikt go na polowanie\",\n",
|
||||||
|
" \"Podróżny długo w oknie stał\",\n",
|
||||||
|
" \"Rzekł z uśmiechem,\",\n",
|
||||||
|
" \"Pod płotem wąskie, długie, wypukłe pagórki,\",\n",
|
||||||
|
" \"Hrabia oczy roztworzył.\"\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"for sentence in sentences:\n",
|
||||||
|
" sentence = re.split(r'\\s+', sentence)\n",
|
||||||
|
" print(f\"Sentence: {' '.join(sentence)}\")\n",
|
||||||
|
" print(f\"Unigram model: {unigram_model.get_prob_for_text(sentence):.10f}\")\n",
|
||||||
|
" print(f\"Bigram model: {bigram_model.get_prob_for_text(sentence):.10f}\")\n",
|
||||||
|
" print(f\"Trigram model: {trigram_model.get_prob_for_text(sentence):.10f}\")\n",
|
||||||
|
" print()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Napisz włąsnoręcznie funkcję, która liczy perplexity na korpusie i policz perplexity na każdym z modeli dla podzbiorów train i test."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Calculating perplexity for unigram model...\n",
|
||||||
|
"Train perplexity: 5666.4901484896\n",
|
||||||
|
"Test perplexity: inf\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Calculating perplexity for bigram model...\n",
|
||||||
|
"Train perplexity: 9.1369500910\n",
|
||||||
|
"Test perplexity: inf\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Calculating perplexity for trigram model...\n",
|
||||||
|
"Train perplexity: 1.1857475475\n",
|
||||||
|
"Test perplexity: inf\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Calculating perplexity for unigram model...\")\n",
|
||||||
|
"train_perplexity = unigram_model.get_perplexity(train)\n",
|
||||||
|
"test_perplexity = unigram_model.get_perplexity(test)\n",
|
||||||
|
"print(f\"Train perplexity: {train_perplexity:.10f}\")\n",
|
||||||
|
"print(f\"Test perplexity: {test_perplexity:.10f}\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Calculating perplexity for bigram model...\")\n",
|
||||||
|
"train_perplexity = bigram_model.get_perplexity(train)\n",
|
||||||
|
"test_perplexity = bigram_model.get_perplexity(test)\n",
|
||||||
|
"print(f\"Train perplexity: {train_perplexity:.10f}\")\n",
|
||||||
|
"print(f\"Test perplexity: {test_perplexity:.10f}\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Calculating perplexity for trigram model...\")\n",
|
||||||
|
"train_perplexity = trigram_model.get_perplexity(train)\n",
|
||||||
|
"test_perplexity = trigram_model.get_perplexity(test)\n",
|
||||||
|
"print(f\"Train perplexity: {train_perplexity:.10f}\")\n",
|
||||||
|
"print(f\"Test perplexity: {test_perplexity:.10f}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Wygeneruj tekst, zaczynając od wymyślonych 5 początków. Postaraj się, żeby dla obu funkcji, a przynajmniej dla `high_probable_next_word`, teksty były orginalne."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Sentence: Nikt go na polowanie\n",
|
||||||
|
"Unigram model: Nikt go na polowanie w w w w w w w w w w w w w w w w\n",
|
||||||
|
"Bigram model: Nikt go na polowanie uprosić nie jest w tym łacniej w tym łacniej w tym łacniej w tym łacniej w\n",
|
||||||
|
"Trigram model: Nikt go na polowanie uprosić nie może, Białopiotrowiczowi samemu odmówił! Bo cóż by on na waszych polowaniach łowił? Piękna byłaby\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Podróżny długo w oknie stał\n",
|
||||||
|
"Unigram model: Podróżny długo w oknie stał w w w w w w w w w w w w w w w\n",
|
||||||
|
"Bigram model: Podróżny długo w oknie stał w tym łacniej w tym łacniej w tym łacniej w tym łacniej w tym łacniej\n",
|
||||||
|
"Trigram model: Podróżny długo w oknie stał patrząc, dumając, Wonnymi powiewami kwiatów oddychając. Oblicze aż na krzaki fijołkowe skłonił, Oczyma ciekawymi po\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Rzekł z uśmiechem,\n",
|
||||||
|
"Unigram model: Rzekł z uśmiechem, w w w w w w w w w w w w w w w w w\n",
|
||||||
|
"Bigram model: Rzekł z uśmiechem, a na kształt ogromnego gmachu, Słońce ostatnich kresów nieba dochodziło, Mniej silnie, ale nie jest w tym\n",
|
||||||
|
"Trigram model: Rzekł z uśmiechem, a był to pan kapitan Ryków, Stary żołnierz, stał w bliskiej wiosce na kwaterze, Pan Sędzia nagłym\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Pod płotem wąskie, długie, wypukłe pagórki,\n",
|
||||||
|
"Unigram model: Pod płotem wąskie, długie, wypukłe pagórki, w w w w w w w w w w w w w w\n",
|
||||||
|
"Bigram model: Pod płotem wąskie, długie, wypukłe pagórki, Bez Suwarowa to nie jest w tym łacniej w tym łacniej w tym łacniej\n",
|
||||||
|
"Trigram model: Pod płotem wąskie, długie, wypukłe pagórki, Bez drzew, krzewów i kwiatów: ogród na ogórki. Pięknie wyrosły; liściem wielkim, rozłożystym, Okryły\n",
|
||||||
|
"\n",
|
||||||
|
"Sentence: Hrabia oczy roztworzył.\n",
|
||||||
|
"Unigram model: Hrabia oczy roztworzył. w w w w w w w w w w w w w w w w w\n",
|
||||||
|
"Bigram model: Hrabia oczy roztworzył. Zmieszany, zdziwiony, Milczał; bo w tym łacniej w tym łacniej w tym łacniej w tym łacniej w\n",
|
||||||
|
"Trigram model: Hrabia oczy roztworzył. Zmieszany, zdziwiony, Milczał; wreszcie, zniżając swej rozmowy tony: «Przepraszam — rzekł — mój Rejencie, prawda bez wątpienia,\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"sentences = [\n",
|
||||||
|
" \"Nikt go na polowanie\",\n",
|
||||||
|
" \"Podróżny długo w oknie stał\",\n",
|
||||||
|
" \"Rzekł z uśmiechem,\",\n",
|
||||||
|
" \"Pod płotem wąskie, długie, wypukłe pagórki,\",\n",
|
||||||
|
" \"Hrabia oczy roztworzył.\"\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"for sentence in sentences:\n",
|
||||||
|
" sentence = re.split(r'\\s+', sentence)\n",
|
||||||
|
" print(f\"Sentence: {' '.join(sentence)}\")\n",
|
||||||
|
" print(f\"Unigram model: {' '.join(unigram_model.generate_text(sentence, 20, greedy = True))}\")\n",
|
||||||
|
" print(f\"Bigram model: {' '.join(bigram_model.generate_text(sentence, 20, greedy = True))}\")\n",
|
||||||
|
" print(f\"Trigram model: {' '.join(trigram_model.generate_text(sentence, 20, greedy = True))}\")\n",
|
||||||
|
" print()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Dokonaj klasyfikacji za pomocą modelu językowego.\n",
|
||||||
|
"- Znajdź duży zbiór danych dla klasyfikacji binarnej, wytrenuj osobne modele dla każdej z klas i użyj dla klasyfikacji."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 136592/136592 [00:00<00:00, 357332.74it/s]\n",
|
||||||
|
"100%|██████████| 126878/126878 [00:00<00:00, 299366.46it/s]\n",
|
||||||
|
"100%|██████████| 498/498 [00:00<00:00, 71213.51it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Accuracy: 0.645\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from datasets import load_dataset\n",
|
||||||
|
"\n",
|
||||||
|
"# Load dataset as sentiment140\n",
|
||||||
|
"dataset = load_dataset(\"sentiment140\")\n",
|
||||||
|
"train = zip(dataset[\"train\"][\"text\"], dataset[\"train\"][\"sentiment\"])\n",
|
||||||
|
"test = zip(dataset[\"test\"][\"text\"], dataset[\"test\"][\"sentiment\"])\n",
|
||||||
|
"\n",
|
||||||
|
"train = list(train)\n",
|
||||||
|
"random.shuffle(train)\n",
|
||||||
|
"train = list(train)[:20_000]\n",
|
||||||
|
"\n",
|
||||||
|
"test = list(test)\n",
|
||||||
|
"random.shuffle(test)\n",
|
||||||
|
"test = list(test)[:1_000]\n",
|
||||||
|
"\n",
|
||||||
|
"pos = [text.split() for text, label in train if label == 0]\n",
|
||||||
|
"neg = [text.split() for text, label in train if label > 0]\n",
|
||||||
|
"\n",
|
||||||
|
"pos_model = Model(vocab_size = 6_000_000, n = 3)\n",
|
||||||
|
"neg_model = Model(vocab_size = 6_000_000, n = 3)\n",
|
||||||
|
"\n",
|
||||||
|
"pos_model.train(sum(pos, []))\n",
|
||||||
|
"neg_model.train(sum(neg, []))\n",
|
||||||
|
"\n",
|
||||||
|
"correct = 0\n",
|
||||||
|
"for text, label in tqdm(test):\n",
|
||||||
|
" text = text.split()\n",
|
||||||
|
" pos_perplexity = pos_model.get_perplexity(text)\n",
|
||||||
|
" neg_perplexity = neg_model.get_perplexity(text)\n",
|
||||||
|
" result = \"pos\" if pos_perplexity < neg_perplexity else \"neg\"\n",
|
||||||
|
" if result == \"pos\" and label == 0:\n",
|
||||||
|
" correct += 1\n",
|
||||||
|
" elif result == \"neg\" and label > 0:\n",
|
||||||
|
" correct += 1\n",
|
||||||
|
"\n",
|
||||||
|
"accuracy = correct / len(test)\n",
|
||||||
|
"print(f\"Accuracy: {accuracy:.3f}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Zastosuj wygładzanie metodą Laplace'a."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"on , gdy tańczyłem , krzyknął : » precz za drzwi złodzieja ! « że wtenczas za pułkowej okradzenie kasy\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from nltk.lm import Laplace\n",
|
||||||
|
"from nltk.lm.preprocessing import padded_everygram_pipeline\n",
|
||||||
|
"from nltk.tokenize import sent_tokenize, word_tokenize\n",
|
||||||
|
"\n",
|
||||||
|
"n = 5\n",
|
||||||
|
"tokenized_text = [list(map(str.lower, word_tokenize(sent))) for sent in sent_tokenize(open(\"04_materialy/pan-tadeusz.txt\", encoding=\"UTF-8\").read())]\n",
|
||||||
|
"train, vocab = padded_everygram_pipeline(n, tokenized_text)\n",
|
||||||
|
"\n",
|
||||||
|
"model = Laplace(n)\n",
|
||||||
|
"model.fit(train, vocab)\n",
|
||||||
|
"\n",
|
||||||
|
"print(' '.join(model.generate(20, random_seed=42)))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### KONIEC ZADANIA"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## WYKONANIE ZADAŃ\n",
|
||||||
|
"Zgodnie z instrukcją 01_Kodowanie_tekstu.ipynb"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Teoria informacji"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Wygładzanie modeli językowych"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"author": "Jakub Pokrywka",
|
||||||
|
"email": "kubapok@wmi.amu.edu.pl",
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"lang": "pl",
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.1"
|
||||||
|
},
|
||||||
|
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
|
||||||
|
"title": "Ekstrakcja informacji",
|
||||||
|
"year": "2021"
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
48
src/evaluate.py
Normal file
48
src/evaluate.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import csv
|
||||||
|
from model import Model
|
||||||
|
from tqdm import tqdm
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
print("Loading model")
|
||||||
|
dataset_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'train', 'in.tsv.xz'))
|
||||||
|
model = Model.load(os.path.abspath(os.path.join(os.path.dirname(dataset_dir), 'model.pkl')))
|
||||||
|
|
||||||
|
print("Evaluating")
|
||||||
|
dataset_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', sys.argv[1], 'in.tsv.xz'))
|
||||||
|
output_dir = os.path.abspath(os.path.join(os.path.dirname(dataset_dir), 'out.tsv'))
|
||||||
|
|
||||||
|
df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)
|
||||||
|
df = df.replace(r'\\r+|\\n+|\\t+','', regex=True)
|
||||||
|
|
||||||
|
final = ""
|
||||||
|
|
||||||
|
for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
|
||||||
|
text = ""
|
||||||
|
prob_sum = 0.0
|
||||||
|
|
||||||
|
probs = model.fill_gap(re.split(r"\s+", row['LeftContext']), re.split(r"\s+", row['RightContext']))
|
||||||
|
|
||||||
|
if len(probs) == 0:
|
||||||
|
text = ":1"
|
||||||
|
else:
|
||||||
|
prob_sum = sum([prob for _, prob in probs])
|
||||||
|
|
||||||
|
for word, prob in probs:
|
||||||
|
new_prob = math.floor(prob / prob_sum * 1000) / 1000
|
||||||
|
|
||||||
|
if new_prob == 1.0:
|
||||||
|
new_prob = 0.999
|
||||||
|
|
||||||
|
text += f"{word}:{new_prob} "
|
||||||
|
|
||||||
|
text += ":0.001"
|
||||||
|
|
||||||
|
final += text + "\n"
|
||||||
|
|
||||||
|
with open(output_dir, 'w', encoding="UTF-8") as f:
|
||||||
|
f.write(final)
|
108
src/model.py
Normal file
108
src/model.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
from collections import defaultdict, Counter
|
||||||
|
from tqdm import tqdm
|
||||||
|
import nltk
|
||||||
|
import random
|
||||||
|
import pickle
|
||||||
|
import math
|
||||||
|
|
||||||
|
class Model():
|
||||||
|
|
||||||
|
def __init__(self, UNK_token = '<UNK>', n = 3):
|
||||||
|
self.n = n
|
||||||
|
self.UNK_token = UNK_token
|
||||||
|
self.ngrams = defaultdict(defaultdict(int).copy)
|
||||||
|
self.contexts = defaultdict(int)
|
||||||
|
self.tokenizer = { UNK_token: 0 }
|
||||||
|
self.reverse_tokenizer = { 0: UNK_token }
|
||||||
|
self._tokenizer_index = 1
|
||||||
|
self.vocab = set()
|
||||||
|
|
||||||
|
self.n_split = self.n // 2
|
||||||
|
|
||||||
|
def train_tokenizer(self, corpus: list) -> list[int]:
|
||||||
|
for word in tqdm(corpus):
|
||||||
|
if word not in self.vocab:
|
||||||
|
self.vocab.add(word)
|
||||||
|
self.tokenizer[word] = self._tokenizer_index
|
||||||
|
self.reverse_tokenizer[self._tokenizer_index] = word
|
||||||
|
|
||||||
|
self._tokenizer_index += 1
|
||||||
|
|
||||||
|
def tokenize(self, corpus: list, verbose = False) -> list[int]:
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for word in tqdm(corpus) if verbose else corpus:
|
||||||
|
if word not in self.vocab:
|
||||||
|
result.append(self.tokenizer[self.UNK_token])
|
||||||
|
else:
|
||||||
|
result.append(self.tokenizer[word])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def train(self, corpus: list) -> None:
|
||||||
|
|
||||||
|
print("Training tokenizer")
|
||||||
|
self.train_tokenizer(corpus)
|
||||||
|
|
||||||
|
print("Tokenizing corpus")
|
||||||
|
corpus = self.tokenize(corpus, verbose = True)
|
||||||
|
|
||||||
|
print("Saving n-grams")
|
||||||
|
n_grams = list(nltk.ngrams(corpus, self.n))
|
||||||
|
for gram in tqdm(n_grams):
|
||||||
|
left_context = gram[:self.n_split]
|
||||||
|
right_context = gram[self.n_split + 1:]
|
||||||
|
word = gram[self.n_split]
|
||||||
|
|
||||||
|
if word == self.UNK_token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.ngrams[(left_context, right_context)][word] += 1
|
||||||
|
self.contexts[(left_context, right_context)] += 1
|
||||||
|
|
||||||
|
def get_conditional_probability_for_word(self, left_context: list, right_context: list, word: str) -> float:
|
||||||
|
left_context = tuple(left_context[-self.n_split:])
|
||||||
|
right_context = tuple(right_context[:self.n_split])
|
||||||
|
|
||||||
|
total_count = self.contexts[(left_context, right_context)]
|
||||||
|
|
||||||
|
if total_count == 0:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
word_count = self.ngrams[(left_context, right_context)][word]
|
||||||
|
|
||||||
|
return word_count / total_count
|
||||||
|
|
||||||
|
def get_probabilities(self, left_context: list, right_context: list) -> float:
|
||||||
|
left_context = tuple(left_context[-self.n_split:])
|
||||||
|
right_context = tuple(right_context[:self.n_split])
|
||||||
|
|
||||||
|
words = list(self.ngrams[(left_context, right_context)].keys())
|
||||||
|
probs = []
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
prob = self.get_conditional_probability_for_word(left_context, right_context, word)
|
||||||
|
probs.append((word, prob))
|
||||||
|
|
||||||
|
return sorted(probs, reverse = True, key = lambda x: x[0])[:10]
|
||||||
|
|
||||||
|
def fill_gap(self, left_context: list, right_context: list) -> list:
|
||||||
|
left_context = self.tokenize(left_context)
|
||||||
|
right_context = self.tokenize(right_context)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
probabilities = self.get_probabilities(left_context, right_context)
|
||||||
|
for probability in probabilities:
|
||||||
|
word = self.reverse_tokenizer[probability[0]]
|
||||||
|
result.append((word, probability[1]))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def save(self, output_dir: str) -> None:
|
||||||
|
with open(output_dir, 'wb') as f:
|
||||||
|
pickle.dump(self, f)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(model_path: str) -> 'Model':
|
||||||
|
with open(model_path, 'rb') as f:
|
||||||
|
return pickle.load(f)
|
41
src/train.py
Normal file
41
src/train.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from collections import Counter, defaultdict
|
||||||
|
from tqdm import tqdm
|
||||||
|
import re
|
||||||
|
import nltk
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pickle
|
||||||
|
import csv
|
||||||
|
import pandas as pd
|
||||||
|
from model import Model
|
||||||
|
|
||||||
|
dataset_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'train', 'in.tsv.xz'))
|
||||||
|
expected_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'train', 'expected.tsv'))
|
||||||
|
|
||||||
|
model = Model(n = 3)
|
||||||
|
|
||||||
|
df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, chunksize=10 ** 2)
|
||||||
|
expected_df = pd.read_csv(expected_dir, sep='\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, chunksize=10 ** 2)
|
||||||
|
|
||||||
|
print('Loading training corpus...')
|
||||||
|
corpus = []
|
||||||
|
for j, chunk in tqdm(enumerate(zip(df, expected_df)), total=4321):
|
||||||
|
df, expected_df = chunk
|
||||||
|
|
||||||
|
df = df.replace(r'\\r+|\\n+|\\t+','', regex=True)
|
||||||
|
|
||||||
|
for (_, row1), (_, row2) in zip(df.iterrows(), expected_df.iterrows()):
|
||||||
|
word = row2['Word']
|
||||||
|
left_context = row1['LeftContext']
|
||||||
|
right_context = row1['RightContext']
|
||||||
|
|
||||||
|
corpus.extend(left_context.split() + [word] + right_context.split())
|
||||||
|
|
||||||
|
# if j > 50:
|
||||||
|
# break
|
||||||
|
|
||||||
|
print('Training model...')
|
||||||
|
model.train(corpus)
|
||||||
|
print('Saving model...')
|
||||||
|
model.save(os.path.abspath(os.path.join(os.path.dirname(dataset_dir), 'model.pkl')))
|
7414
test-A/in.tsv
Normal file
7414
test-A/in.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user