challenging-america-word-ga.../main.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import csv\n",
"import re\n",
"from collections import Counter, defaultdict\n",
"import nltk\n",
"import math\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"directory = \"train/in.tsv.xz\"\n",
"directory_expected = \"train/expected.tsv\"\n",
"directory_dev_0 = \"dev-0/in.tsv.xz\"\n",
"directory_test_A = \"test-A/in.tsv.xz\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MODEL N-GRAM"
]
},
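{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model below treats the task as gap filling: for an odd $n$, it estimates the probability of the middle word $w$ given $\\lfloor n/2 \\rfloor$ words of left context $l$ and right context $r$ as $P(w \\mid l, r) = C(l, w, r) / C(l, r)$, where $C$ counts occurrences over the training n-grams."
]
},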
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Model():\n",
" \n",
" def __init__(self, vocab_size=30_000, UNK_token= '<UNK>', n=3):\n",
" if (n <= 1 or n % 2 == 0):\n",
" raise \"change N value !!!\"\n",
" self.n = n\n",
" self.vocab_size = vocab_size\n",
" self.UNK_token = UNK_token\n",
" \n",
" def train(self, corpus:list) -> None:\n",
" if(self.n > 1):\n",
" self.n_grams = list(nltk.ngrams(corpus, n=self.n))\n",
" else:\n",
" self.n_grams = corpus\n",
" self.counter = Counter(self.n_grams)\n",
" self.words_counter = Counter(corpus)\n",
" self.all_quantities = Counter([gram[:math.floor(self.n/2)]+gram[math.ceil(self.n/2):] for gram in self.n_grams])\n",
"\n",
" self.all_grams = defaultdict(set)\n",
"\n",
" for gram in tqdm(self.n_grams):\n",
" previous_words = tuple(gram[:math.floor(self.n/2)])\n",
" next_words = tuple(gram[math.ceil(self.n/2):])\n",
" word = gram[math.floor(self.n/2)]\n",
" self.all_grams[(previous_words, next_words)].add(word)\n",
"\n",
" def get_conditional_prob_for_word(self, left_text: list, right_text: list, word: str) -> float:\n",
" previous_words = tuple(left_text[-math.floor(self.n/2):])\n",
" next_words = tuple(right_text[:math.floor(self.n/2)])\n",
" quantity = self.counter[previous_words + tuple([word]) + next_words]\n",
" all_quantity = self.all_quantities[previous_words + next_words]\n",
" if (all_quantity <= 0):\n",
" return 0\n",
" return quantity/all_quantity\n",
" \n",
" def get_prob_for_text(self, text: list) -> float:\n",
" prob = 1\n",
" for gram in list(nltk.ngrams(text, self.n)):\n",
" prob *= self.get_conditional_prob_for_word(gram[:math.floor(self.n/2)], gram[math.ceil(self.n/2):], gram[math.floor(self.n/2)])\n",
" return prob\n",
" \n",
" def most_probable_words(self, left_text: list, right_text: list) -> str:\n",
" previous_words = tuple(left_text[-math.floor(self.n/2):])\n",
" next_words = tuple(right_text[:math.floor(self.n/2)])\n",
" all_words = self.all_grams[(previous_words, next_words)]\n",
" best_words = []\n",
" for word in all_words:\n",
" probability = self.get_conditional_prob_for_word(list(previous_words), list(next_words), word)\n",
" best_words.append((word, probability))\n",
" return sorted(best_words, key=(lambda l: l[1]), reverse=True)[:20]\n",
" \n",
" def generate_text(self, text_beggining:list, text_ending:list, greedy: bool) -> list:\n",
" words = self.most_probable_words(text_beggining, text_ending)\n",
" return words\n"
]
},
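{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sanity check of the class above (not part of the original pipeline; the toy sentence is made up for illustration): train on a tiny corpus and query the middle word between one word of left and one word of right context."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical toy corpus: 'cat' appears twice and 'dog' once between 'the' and 'sat'.\n",
"toy_corpus = ['the', 'cat', 'sat', 'down', 'and', 'the', 'dog', 'sat', 'down', 'and', 'the', 'cat', 'sat']\n",
"toy_model = Model(n=3)\n",
"toy_model.train(toy_corpus)\n",
"# Expected: [('cat', 0.666...), ('dog', 0.333...)] for the context ('the', _, 'sat').\n",
"print(toy_model.most_probable_words(['the'], ['sat']))\n",
"# The probability of a full trigram uses the same counts: P('cat' | 'the', 'sat') = 2/3.\n",
"print(toy_model.get_prob_for_text(['the', 'cat', 'sat']))"
]
},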
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DATASET"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['came', 'fiom', 'the', 'last', 'place', 'to', 'this', 'place,', 'and', 'this', 'place', 'is', 'Where', 'We', 'Were,', 'this', 'is', 'the', 'first', 'road', 'I', 'ever', 'was', 'on', 'where', 'you', 'can', 'ride', 'elsewhere', 'from', 'anywhere', 'and', 'be', 'nowhere.', 'He', 'says,', 'while', 'this', 'train', 'stops', 'every-', 'where,', 'it', 'never', 'stops', 'anywhere', 'un-', 'less', 'its', 'somewhere.', 'Well,', 'I', 'says,', \"I'm\", 'glad', 'to', 'hear', 'that,', 'but,', 'accord-', 'ing', 'to', 'your', 'figures,', 'I', 'left', 'myself', 'where', '1', 'was,', 'which', 'is', 'five', 'miles', 'near-', 'er', 'to', 'myself', 'than', 'I', 'was', 'when', 'we', 'were', 'where', 'we', 'are', 'now.', 'We', 'have', 'now', 'reached', 'Slidell.', \"That's\", 'a', 'fine', 'place.', 'The', 'people', 'down']\n"
]
}
],
"source": [
"dataframeList = pd.read_csv(directory, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, chunksize=10000)\n",
"\n",
"expectedList = pd.read_csv(directory_expected, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, chunksize=10000)\n",
"\n",
"DATASET = \"\"\n",
"\n",
"for number, (dataframe, expected) in enumerate(zip(dataframeList, expectedList)):\n",
" dataframe = dataframe.reset_index()\n",
2024-04-28 20:32:46 +02:00
" dataframe = dataframe.replace(r'\\\\r|\\\\n|\\n|\\\\t', ' ', regex=True)\n",
"\n",
" expected['Word'] = expected['Word'].apply(lambda x: str(x).strip())\n",
" word = expected['Word']\n",
"\n",
2024-04-24 02:48:35 +02:00
" left_text = dataframe['LeftContext'].to_list()\n",
" right_text = dataframe['RightContext'].to_list()\n",
" word = expected['Word'].to_list()\n",
"\n",
" lines = zip(left_text, word, right_text)\n",
" lines = list(map(lambda l: \" \".join(l), lines))\n",
" DATASET = DATASET + \" \".join(lines)\n",
"\n",
"FINAL_DATASET = re.split(r\"\\s+\", DATASET)\n",
"print(FINAL_DATASET[:100])"
]
},
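{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration of the reassembly above (row values are made up): each training row contributes its left context, the expected gap word, and the right context as one whitespace-joined string, and the concatenated text is then split back into tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical single row, for illustration of the stitching step only.\n",
"toy_left = ['I came from the']\n",
"toy_word = ['last']\n",
"toy_right = ['place to this place']\n",
"toy_lines = list(map(lambda l: \" \".join(l), zip(toy_left, toy_word, toy_right)))\n",
"print(toy_lines)  # ['I came from the last place to this place']\n",
"print(re.split(r\"\\s+\", toy_lines[0]))  # ['I', 'came', 'from', ..., 'place']"
]
},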
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TRAIN"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 139475976/139475976 [04:39<00:00, 498903.89it/s]\n"
]
}
],
"source": [
"model_3gram = Model(n = 3)\n",
"model_3gram.train(FINAL_DATASET)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = model_3gram"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PREDICTION"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def convert_predictions(line):\n",
" sum_predictions = np.sum([pred[1] for pred in line])\n",
" result = \"\"\n",
" all_pred = 0\n",
" for word, pred in line:\n",
" new_pred = math.floor(pred / sum_predictions * 100) / 100\n",
" if(new_pred == 1.0):\n",
" new_pred = 0.99\n",
" all_pred = all_pred + new_pred\n",
" result = result + word + \":\" + str(new_pred) + \" \"\n",
" if(round(all_pred, 2) < 1):\n",
" result = result + \":\" + str(round(1 - all_pred, 2))\n",
" else:\n",
" result = result + \":\" + str(0.01)\n",
" return result"
]
},
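{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the output format (inputs are made up): the per-word probabilities are normalized, floored to two decimals, and followed by a final `:rest` entry that carries the leftover probability mass."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Made-up prediction lists, for illustration of the format only.\n",
"print(convert_predictions([('the', 0.9), ('tho', 0.1)]))  # the:0.9 tho:0.1 :0.01\n",
"print(convert_predictions([('arc', 1.0)]))  # arc:0.99 :0.01 (a lone word is capped at 0.99)"
]
},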
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"10519it [00:27, 385.35it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[], [('successors', 0.006017228274638966), ('passage', 0.005193818089688371), ('place,', 0.005067139599695972), ('growth', 0.004813782619711173), ('use,', 0.004117050924752977), ('head', 0.003737015454775779), ('functions,', 0.0034836584747909806), ('power', 0.0034836584747909806), ('place', 0.003356979984798581), ('own,', 0.0032936407398023817), ('own', 0.0032936407398023817), ('members', 0.0032936407398023817), ('work', 0.003230301494806182), ('principles', 0.0031669622498099823), ('strength', 0.003040283759817583), ('value', 0.003040283759817583), ('beauty', 0.0026602482898403852), ('business', 0.0025969090448441853), ('size', 0.0025969090448441853), ('history', 0.0025969090448441853)], [('a', 0.5), ('lha', 0.25), ('the', 0.25)], [], [], [('a', 0.32934131736526945), ('him', 0.0718562874251497), ('two', 0.0718562874251497), ('only', 0.029940119760479042), ('just', 0.029940119760479042), ('means', 0.023952095808383235), ('money', 0.017964071856287425), ('force', 0.017964071856287425), ('the', 0.017964071856287425), ('barely', 0.011976047904191617), ('earth', 0.011976047904191617), ('all', 0.011976047904191617), ('no', 0.011976047904191617), ('applicants', 0.011976047904191617), ('capital', 0.011976047904191617), ('in\\xad', 0.011976047904191617), ('capacity', 0.005988023952095809), ('corncobs', 0.005988023952095809), ('water', 0.005988023952095809), ('stabling', 0.005988023952095809)], [], [], [('arc', 1.0)], [('as', 0.7678571428571429), ('that', 0.11904761904761904), ('ns', 0.020833333333333332), ('confident', 0.008928571428571428), ('sure,', 0.005952380952380952), ('that,', 0.005952380952380952), ('aa', 0.005952380952380952), ('sure', 0.005952380952380952), ('alike;', 0.002976190476190476), ('a3', 0.002976190476190476), ('.as', 0.002976190476190476), ('thst', 0.002976190476190476), ('\"as', 0.002976190476190476), ('a', 0.002976190476190476), ('sbuah', 0.002976190476190476), ('bad.', 0.002976190476190476), ('its', 0.002976190476190476), ('tbat', 0.002976190476190476), ('aggravated', 0.002976190476190476), ('defrauded', 0.002976190476190476)], [], [('the', 0.46712158808933), ('show', 0.25930521091811415), ('shew', 0.04404466501240695), ('this', 0.03163771712158809), ('tho', 0.020471464019851116), ('our', 0.018610421836228287), ('a', 0.013027295285359801), ('tbe', 0.009305210918114143), ('their', 0.008064516129032258), ('that', 0.008064516129032258), ('any', 0.00620347394540943), ('said', 0.004342431761786601), ('immediately', 0.004342431761786601), ('find', 0.003101736972704715), ('tlie', 0.003101736972704715), ('some', 0.003101736972704715), ('what', 0.0024813895781637717), ('give', 0.0024813895781637717), ('chow', 0.0018610421836228288), ('snow', 0.0018610421836228288)], [], [], [('to', 0.7446808510638298), ('a', 0.10638297872340426), ('and', 0.0425531914893617), ('for', 0.02127659574468085), ('world,', 0.02127659574468085), ('the', 0.02127659574468085), ('uud', 0.02127659574468085), ('¦', 0.02127659574468085)], [('There', 0.5), ('To', 0.5)], [('to', 0.5135135135135135), ('that', 0.1891891891891892), ('Almighty', 0.1891891891891892), ('for', 0.05405405405405406), ('thai', 0.02702702702702703), ('the', 0.02702702702702703)], [('as', 0.3048780487804878), ('posted', 0.059233449477351915), ('up', 0.050522648083623695), ('informed', 0.050522648083623695), ('started', 0.03832752613240418), ('known', 0.03832752613240418), ('down', 0.01916376306620209), ('fed', 0.017421602787456445), ('Informed', 0.017421602787456445), ('represented', 0.0156794425087108), ('out', 0.013937282229965157), 
('back', 0.010452961672473868), ('along', 0.010452961672473868), ('and', 0.008710801393728223), ('established', 0.006968641114982578), ('that', 0.006968641114982578), ('put', 0.005226480836236934), ('over', 0.005226480836236934), ('placed', 0.005226480836236934), ('is', 0.005226480836236934)], [], [], [], [('will', 0.8), ('to', 0.2)], [('went', 0.032606199770378874), ('carried', 0.019402985074626865), ('with-', 0.01928817451205511), ('came', 0.01791044776119403), ('find', 0.01584385763490241), ('set', 0.015499425947187142), ('pointed'
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# PREDICTION FOR DEV-0\n",
"\n",
"dataframe = pd.read_csv(directory_dev_0, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
"dataframe = dataframe.replace(r'\\\\r|\\\\n|\\n|\\\\t', ' ', regex=True)\n",
"\n",
"left_text = dataframe['LeftContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"right_text = dataframe['RightContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"\n",
"lines = zip(left_text, right_text)\n",
"lines = list(map(lambda l: model.generate_text(l[0], l[1], False), tqdm(lines)))\n",
"print(lines[:100])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10519/10519 [00:00<00:00, 106254.34it/s]\n"
]
}
],
"source": [
"with open(\"dev-0/out.tsv\", \"w\", encoding=\"UTF-8\") as file:\n",
" result = \"\\n\".join(list(map(lambda l: convert_predictions(l), tqdm(lines))))\n",
" file.write(result)\n",
" file.close()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"7414it [00:16, 457.43it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[], [], [('the', 0.9), ('tho', 0.1)], [('man', 0.02725228204788993), ('plan', 0.012567799973541474), ('trial', 0.010715703135335361), ('living', 0.009921947347532743), ('statement', 0.009525069453631433), ('law', 0.008334435771927504), ('class', 0.008202143140627068), ('time', 0.007937557878026195), ('government', 0.005953168408519645), ('bill', 0.0054239978833179), ('year', 0.0054239978833179), ('question', 0.005291705252017462), ('sensation', 0.005291705252017462), ('day', 0.005159412620717026), ('corporation,', 0.005159412620717026), ('little', 0.0050271199894165895), ('vote', 0.004894827358116153), ('single', 0.004762534726815717), ('means', 0.00423336420161397), ('speech', 0.004101071570313534)], [], [('to', 0.16666666666666666), ('here', 0.16666666666666666), ('youngsters,', 0.08333333333333333), ('vines', 0.08333333333333333), ('material', 0.08333333333333333), ('plaster,', 0.08333333333333333), ('fabrics', 0.08333333333333333), ('mist', 0.08333333333333333), ('arms,', 0.08333333333333333), ('close,', 0.08333333333333333)], [('tho', 1.0)], [], [('employed', 0.045454545454545456), ('alarmed', 0.045454545454545456), ('burned', 0.045454545454545456), ('put', 0.045454545454545456), ('above', 0.045454545454545456), ('early', 0.045454545454545456), ('visible', 0.045454545454545456), ('hanging', 0.045454545454545456), ('allowed', 0.045454545454545456), ('received', 0.045454545454545456), ('seated', 0.045454545454545456), ('outlined', 0.045454545454545456), ('drunk', 0.045454545454545456), ('typewritten', 0.045454545454545456), ('conferred', 0.045454545454545456), ('landed', 0.045454545454545456), ('gathered', 0.045454545454545456), ('poured', 0.045454545454545456), ('living', 0.045454545454545456), ('apartments', 0.045454545454545456)], [], [('since', 0.15228966986155484), ('in', 0.042243521476748314), ('had', 0.03123890663826766), ('on', 0.029463968761093362), ('been', 0.028399006034788784), ('saw', 0.012069577564785232), ('seen', 0.011714589989350373), ('to', 0.011714589989350373), ('be', 0.011004614838480652), ('that', 0.008874689385871494), ('see', 0.008874689385871494), ('heard', 0.008519701810436636), ('of', 0.007809726659566915), ('In', 0.007454739084132056), ('forget', 0.0067447639332623354), ('before', 0.006389776357827476), ('reached', 0.006389776357827476), ('visited', 0.006389776357827476), ('at', 0.006034788782392616), ('for', 0.006034788782392616)], [], [('days', 0.122), ('minutes', 0.072), ('years', 0.07), ('months', 0.05), ('weeks', 0.05), ('moments', 0.042), ('cents', 0.036), ('at', 0.03), ('dollars', 0.028), ('days,', 0.026), ('years,', 0.02), ('hours', 0.016), ('day', 0.012), ('that', 0.01), ('weeks,', 0.01), ('months,', 0.01), ('in', 0.01), ('furnish', 0.008), ('miles', 0.008), ('of', 0.008)], [], [('there', 1.0)], [('from', 1.0)], [('which', 0.45422535211267606), ('what', 0.05985915492957746), ('whom', 0.03873239436619718), ('until', 0.02992957746478873), ('as', 0.017605633802816902), ('them', 0.011443661971830986), ('if', 0.008802816901408451), ('till', 0.008802816901408451), ('and', 0.007922535211267605), ('when', 0.007922535211267605), ('this', 0.007042253521126761), ('them,', 0.006161971830985915), ('whether', 0.006161971830985915), ('that', 0.006161971830985915), ('board,', 0.00528169014084507), ('Monday', 0.00528169014084507), ('hand,', 0.00528169014084507), ('Saturday', 0.00528169014084507), ('hand', 0.00528169014084507), ('everything', 0.0044014084507042256)], [('this', 0.4745762711864407), ('any', 0.11016949152542373), ('every', 0.0903954802259887), ('the', 
0.08757062146892655), ('that', 0.08192090395480225), ('some', 0.022598870056497175), ('a', 0.01977401129943503), ('one', 0.01977401129943503), ('what', 0.011299435028248588), ('no', 0.00847457627118644), ('each', 0.00847457627118644), (\"th's\", 0.005649717514124294), ('nny', 0.002824858757062147), ('thie', 0.002824858757062147), ('aay', 0.002824858757062147), ('all', 0.002824858757062147), ('ono', 0.002824858757062147), ('ever?', 0.002824858757062147), ('driving', 0.002824858757062147), ('tha', 0.002824858757062147)], [('little'
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# PREDICTION FOR TEST-A\n",
"\n",
2024-04-28 20:32:46 +02:00
"dataframe = pd.read_csv(directory_test_A, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
"dataframe = dataframe.replace(r'\\\\r|\\\\n|\\n|\\\\t', ' ', regex=True)\n",
"\n",
"left_text = dataframe['LeftContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"right_text = dataframe['RightContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"\n",
"lines = zip(left_text, right_text)\n",
"lines = list(map(lambda l: model.generate_text(l[0], l[1], False), tqdm(lines)))\n",
"print(lines[:100])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 7414/7414 [00:00<00:00, 112933.06it/s]\n"
]
}
],
"source": [
"with open(\"test-A/out.tsv\", \"w\", encoding=\"UTF-8\") as file:\n",
" result = \"\\n\".join(list(map(lambda l: convert_predictions(l), tqdm(lines))))\n",
" file.write(result)\n",
" file.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}