{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import csv\n",
"import os\n",
"import re\n",
"import random\n",
"from collections import Counter, defaultdict\n",
"import nltk\n",
"import math\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"directory = \"train/in.tsv.xz\"\n",
"directory_dev_0 = \"dev-0/in.tsv.xz\"\n",
"directory_test_A = \"test-A/in.tsv.xz\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MODEL N-GRAM"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class Model():\n",
" \n",
" def __init__(self, vocab_size=30_000, UNK_token= '<UNK>', n=3):\n",
" if (n <= 1 or n % 2 == 0):\n",
" raise \"change N value !!!\"\n",
" self.n = n\n",
" self.vocab_size = vocab_size\n",
" self.UNK_token = UNK_token\n",
" \n",
" def train(self, corpus:list) -> None:\n",
" if(self.n > 1):\n",
" self.n_grams = list(nltk.ngrams(corpus, n=self.n))\n",
" else:\n",
" self.n_grams = corpus\n",
" self.counter = Counter(self.n_grams)\n",
" self.words_counter = Counter(corpus)\n",
" self.all_quantities = Counter([gram[:math.floor(self.n/2)]+gram[math.ceil(self.n/2):] for gram in self.n_grams])\n",
"\n",
" self.all_grams = defaultdict(set)\n",
"\n",
" for gram in tqdm(self.n_grams):\n",
" previous_words = tuple(gram[:math.floor(self.n/2)])\n",
" next_words = tuple(gram[math.ceil(self.n/2):])\n",
" word = gram[math.floor(self.n/2)]\n",
" self.all_grams[(previous_words, next_words)].add(word)\n",
"\n",
" def get_conditional_prob_for_word(self, left_text: list, right_text: list, word: str) -> float:\n",
" previous_words = tuple(left_text[-math.floor(self.n/2):])\n",
" next_words = tuple(right_text[:math.floor(self.n/2)])\n",
" quantity = self.counter[previous_words + tuple([word]) + next_words]\n",
" all_quantity = self.all_quantities[previous_words + next_words]\n",
" if (all_quantity <= 0):\n",
" return 0\n",
" return quantity/all_quantity\n",
" \n",
" def get_prob_for_text(self, text: list) -> float:\n",
" prob = 1\n",
" for gram in list(nltk.ngrams(text, self.n)):\n",
" prob *= self.get_conditional_prob_for_word(gram[:math.floor(self.n/2)], gram[math.ceil(self.n/2):], gram[math.floor(self.n/2)])\n",
" return prob\n",
" \n",
" def most_probable_words(self, left_text: list, right_text: list) -> str:\n",
" previous_words = tuple(left_text[-math.floor(self.n/2):])\n",
" next_words = tuple(right_text[:math.floor(self.n/2)])\n",
" all_words = self.all_grams[(previous_words, next_words)]\n",
" best_words = []\n",
" for word in all_words:\n",
" probability = self.get_conditional_prob_for_word(list(previous_words), list(next_words), word)\n",
" best_words.append((word, probability))\n",
" return sorted(best_words, key=(lambda l: l[1]), reverse=True)[:20]\n",
" \n",
" def generate_text(self, text_beggining:list, text_ending:list, greedy: bool) -> list:\n",
" words = self.most_probable_words(text_beggining, text_ending)\n",
" return words\n"
]
},
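{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the class above on a toy corpus (a hypothetical example, not part of the original pipeline): with n=3, both 'sat' and 'lay' occur between 'cat' and 'on', so each should receive probability 0.5."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_model = Model(n=3)\n",
"toy_model.train(\"the cat sat on the mat the cat lay on the mat\".split())\n",
"# query the middle word given one word of left and one word of right context\n",
"print(toy_model.most_probable_words(['cat'], ['on']))"
]
},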
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DATASET"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['came', 'fiom', 'the', 'last', 'place', 'to', 'this', 'place,', 'and', 'this', 'place', 'is', 'Where', 'We', 'Were,', 'this', 'is', 'the', 'first', 'road', 'I', 'ever', 'was', 'on', 'where', 'you', 'can', 'ride', 'elsewhere', 'from', 'anywhere', 'and', 'be', 'nowhere.', 'He', 'says,', 'while', 'this', 'train', 'stops', 'every-', 'where,', 'it', 'never', 'stops', 'anywhere', 'un-', 'less', 'its', 'somewhere.', 'Well,', 'I', 'says,', \"I'm\", 'glad', 'to', 'hear', 'that,', 'but,', 'accord-', 'ing', 'to', 'your', 'figures,', 'I', 'left', 'myself', 'where', '1', 'was,', 'which', 'is', 'five', 'miles', 'near-', 'er', 'to', 'myself', 'than', 'I', 'was', 'when', 'we', 'were', 'where', 'we', 'are', 'now.', 'We', 'have', 'now', 'reached', 'Slidell.', \"That's\", 'a', 'fine', 'place.', 'The', 'people', 'down']\n"
]
}
],
"source": [
"dataframeList = pd.read_csv(directory, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, chunksize=10000)\n",
"\n",
"# the gap words come from the expected.tsv file, not from in.tsv.xz\n",
"expectedList = pd.read_csv(directory_expected, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, chunksize=10000)\n",
"\n",
"DATASET = \"\"\n",
"\n",
"for number, (dataframe, expected) in enumerate(zip(dataframeList, expectedList)):\n",
2024-04-28 20:32:46 +02:00
" dataframe = dataframe.replace(r'\\\\r|\\\\n|\\n|\\\\t', ' ', regex=True)\n",
"\n",
2024-04-24 02:48:35 +02:00
" left_text = dataframe['LeftContext'].to_list()\n",
" right_text = dataframe['RightContext'].to_list()\n",
" word = expected['Word'].to_list()\n",
"\n",
" lines = zip(left_text, word, right_text)\n",
" lines = list(map(lambda l: \" \".join(l), lines))\n",
" DATASET = DATASET + \" \".join(lines)\n",
"\n",
"FINAL_DATASET = re.split(r\"\\s+\", DATASET)\n",
"print(FINAL_DATASET[:100])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TRAIN"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 199099663/199099663 [11:00<00:00, 301572.30it/s] \n"
]
}
],
"source": [
"model_3gram = Model(n = 3)\n",
"model_3gram.train(FINAL_DATASET)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model = model_3gram"
]
},
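{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_prob_for_text` is not used by the prediction pipeline below; as a hypothetical usage example, it can score a short phrase with the trained model (the printed value depends on the training data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# probability of 'the' occurring between 'in' and 'morning' under the trained trigram model\n",
"print(model.get_prob_for_text(\"in the morning\".split()))"
]
},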
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PREDICTION"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def convert_predictions(line):\n",
" sum_predictions = np.sum([pred[1] for pred in line])\n",
" result = \"\"\n",
" all_pred = 0\n",
" for word, pred in line:\n",
" new_pred = math.floor(pred / sum_predictions * 100) / 100\n",
" if(new_pred == 1.0):\n",
" new_pred = 0.99\n",
" all_pred = all_pred + new_pred\n",
" result = result + word + \":\" + str(new_pred) + \" \"\n",
" if(round(all_pred, 2) < 1):\n",
" result = result + \":\" + str(round(1 - all_pred, 2))\n",
" else:\n",
" result = result + \":\" + str(0.01)\n",
" return result"
]
},
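{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration of `convert_predictions` with hypothetical numbers: each output line holds `word:probability` pairs plus a trailing `:rest` entry carrying the leftover mass for unseen words, so the probabilities sum to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 0.6/0.9 and 0.3/0.9 are floored to 0.66 and 0.33, leaving 0.01 as the remainder\n",
"print(convert_predictions([('the', 0.6), ('a', 0.3)]))  # the:0.66 a:0.33 :0.01"
]
},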
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"10519it [00:51, 206.24it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[], [('passage', 0.005959701068962256), ('growth', 0.005202913631633715), ('successors', 0.005108315201967647), ('place,', 0.004682622268470343), ('use,', 0.004115031690473938), ('head', 0.003452842682811465), ('own', 0.003452842682811465), ('own,', 0.003310945038312364), ('power', 0.00326364582347933), ('place', 0.0032163466086462963), ('functions,', 0.0031690473938132627), ('members', 0.0031690473938132627), ('work', 0.0030271497493141613), ('value', 0.00288525210481506), ('principles', 0.002743354460315959), ('strength', 0.002696055245482925), ('beauty', 0.002459559171317756), ('action,', 0.0023649607416516886), ('history', 0.0023176615268186546), ('value,', 0.002270362311985621)], [('a', 0.5714285714285714), ('the', 0.2857142857142857), ('lha', 0.14285714285714285)], [], [], [('a', 0.31221719457013575), ('him', 0.07239819004524888), ('two', 0.06334841628959276), ('means', 0.03167420814479638), ('just', 0.027149321266968326), ('only', 0.02262443438914027), ('money', 0.02262443438914027), ('good\\\\nand', 0.01809954751131222), ('all', 0.01809954751131222), ('force', 0.01809954751131222), ('the', 0.013574660633484163), ('capital', 0.013574660633484163), ('no', 0.013574660633484163), ('barely', 0.013574660633484163), ('capacity', 0.00904977375565611), ('bills', 0.00904977375565611), ('scarcely', 0.00904977375565611), ('boats', 0.00904977375565611), ('stabling', 0.00904977375565611), ('applicants', 0.00904977375565611)], [], [], [('arc', 1.0)], [('as', 0.7895791583166333), ('that', 0.11022044088176353), ('ns', 0.018036072144288578), ('sure,', 0.008016032064128256), ('sure', 0.006012024048096192), ('confident', 0.006012024048096192), ('defrauded', 0.004008016032064128), ('that,', 0.004008016032064128), ('r.s', 0.004008016032064128), ('us', 0.004008016032064128), ('but', 0.004008016032064128), ('tbat', 0.004008016032064128), ('thst', 0.004008016032064128), ('a>', 0.004008016032064128), ('its', 0.002004008016032064), ('ts', 0.002004008016032064), ('a3', 0.002004008016032064), ('alike;', 0.002004008016032064), ('\"as', 0.002004008016032064), ('bad.', 0.002004008016032064)], [], [('the', 0.4470046082949309), ('show', 0.25161290322580643), ('shew', 0.04470046082949309), ('this', 0.027188940092165898), ('tho', 0.02165898617511521), ('our', 0.01889400921658986), ('a', 0.013364055299539171), ('tbe', 0.00967741935483871), ('that', 0.009216589861751152), ('their', 0.009216589861751152), ('any', 0.005529953917050691), ('immediately', 0.004147465437788019), ('said', 0.004147465437788019), ('tlie', 0.003686635944700461), ('some', 0.0027649769585253456), ('his', 0.0027649769585253456), ('what', 0.0027649769585253456), ('find', 0.002304147465437788), ('thow', 0.002304147465437788), ('snow', 0.002304147465437788)], [], [], [('to', 0.71875), ('a', 0.109375), ('except\\\\nto', 0.03125), ('and', 0.03125), ('¦', 0.03125), ('world,', 0.015625), ('the', 0.015625), ('uud', 0.015625), ('for', 0.015625), ('efcept\\\\nta', 0.015625)], [('There', 0.5), ('To', 0.5)], [('to', 0.5416666666666666), ('Almighty', 0.20833333333333334), ('that', 0.16666666666666666), ('for', 0.041666666666666664), ('thai', 0.020833333333333332), ('the', 0.020833333333333332)], [('as', 0.29831387808041504), ('posted', 0.05188067444876784), ('informed', 0.04798962386511025), ('up', 0.04669260700389105), ('started', 0.03501945525291829), ('known', 0.03501945525291829), ('fed', 0.016861219195849545), ('down', 0.01556420233463035), ('Informed', 0.014267185473411154), ('represented', 0.01297016861219196), ('along', 0.011673151750972763), ('out', 
0.011673151750972763), ('back', 0.010376134889753566), ('and', 0.010376134889753566), ('established', 0.009079118028534372), ('that', 0.007782101167315175), ('aa', 0.007782101167315175), ('satisfied', 0.00648508430609598), ('is', 0.005188067444876783), ('advanced', 0.005188067444876783)], [], [], [], [('will', 0.7142857142857143), ('to', 0.2857142857142857)], [('went', 0.031111497349439472), ('carried', 0.018510471886677673), ('came', 0.016424784913530895), ('find', 0.015642652298600852), ('set', 0.014773616059789694), ('with-',
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# PREDICTION FOR DEV-0\n",
"\n",
"dataframe = pd.read_csv(directory_dev_0, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
"dataframe = dataframe.replace(r'\\\\r|\\\\n|\\n|\\\\t', ' ', regex=True)\n",
"\n",
"left_text = dataframe['LeftContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"right_text = dataframe['RightContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"\n",
"lines = zip(left_text, right_text)\n",
"lines = list(map(lambda l: model.generate_text(l[0], l[1], False), tqdm(lines)))\n",
"print(lines[:100])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10519/10519 [00:00<00:00, 106542.75it/s]\n"
]
}
],
"source": [
"with open(\"dev-0/out.tsv\", \"w\", encoding=\"UTF-8\") as file:\n",
" result = \"\\n\".join(list(map(lambda l: convert_predictions(l), tqdm(lines))))\n",
" file.write(result)\n",
" file.close()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"7414it [00:25, 290.39it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[], [], [('the', 0.9090909090909091), ('tho', 0.09090909090909091)], [('man', 0.022790697674418603), ('plan', 0.011348837209302326), ('trial', 0.009674418604651163), ('living', 0.008744186046511627), ('statement', 0.008), ('law', 0.007720930232558139), ('class', 0.007162790697674419), ('time', 0.006232558139534884), ('year', 0.005767441860465117), ('vote', 0.005488372093023256), ('government', 0.005209302325581395), ('single', 0.005209302325581395), ('day', 0.005023255813953489), ('question', 0.004930232558139535), ('sensation', 0.0048372093023255815), ('bill', 0.004651162790697674), ('little', 0.004372093023255814), ('corporation,', 0.003813953488372093), ('way', 0.003813953488372093), ('means', 0.003813953488372093)], [], [('here', 0.17647058823529413), ('to', 0.17647058823529413), ('plaster,', 0.11764705882352941), ('arms,', 0.11764705882352941), ('youngsters,', 0.058823529411764705), ('mist', 0.058823529411764705), ('baby\\\\nlingers', 0.058823529411764705), ('close,', 0.058823529411764705), ('vines', 0.058823529411764705), ('material', 0.058823529411764705), ('fabrics', 0.058823529411764705)], [('tho', 1.0)], [], [('hanging', 0.06666666666666667), ('used', 0.06666666666666667), ('alarmed', 0.06666666666666667), ('seated', 0.06666666666666667), ('landed', 0.06666666666666667), ('received', 0.06666666666666667), ('outlined', 0.06666666666666667), ('drunk', 0.03333333333333333), ('employed', 0.03333333333333333), ('above', 0.03333333333333333), ('gathered', 0.03333333333333333), ('put', 0.03333333333333333), ('burned', 0.03333333333333333), ('early', 0.03333333333333333), ('conferred', 0.03333333333333333), ('apartments', 0.03333333333333333), ('living', 0.03333333333333333), ('poured', 0.03333333333333333), ('practical-\\\\nly', 0.03333333333333333), ('visible', 0.03333333333333333)], [], [('since', 0.1415119720204466), ('in', 0.03847188592951305), ('on', 0.02824858757062147), ('been', 0.026096314231907454), ('had', 0.024751143395211193), ('seen', 0.011299435028248588), ('to', 0.011030400860909336), ('saw', 0.010761366693570083), ('be', 0.009954264191552327), ('see', 0.008340059187516815), ('that', 0.008071025020177562), ('heard', 0.007532956685499058), ('reached', 0.007263922518159807), ('In', 0.007263922518159807), ('before', 0.006994888350820554), ('of', 0.006725854183481302), ('forget', 0.006187785848802798), ('at', 0.005918751681463546), ('occupied', 0.005918751681463546), ('have', 0.005649717514124294)], [], [('days', 0.1013157894736842), ('minutes', 0.06710526315789474), ('years', 0.05789473684210526), ('weeks', 0.04473684210526316), ('moments', 0.042105263157894736), ('months', 0.04078947368421053), ('cents', 0.02894736842105263), ('at', 0.02631578947368421), ('dollars', 0.019736842105263157), ('days,', 0.018421052631578946), ('years,', 0.018421052631578946), ('hours', 0.013157894736842105), ('day', 0.011842105263157895), ('months,', 0.010526315789473684), ('furnish', 0.010526315789473684), ('in', 0.010526315789473684), ('weeks,', 0.007894736842105263), ('with', 0.006578947368421052), ('that', 0.006578947368421052), ('hundreds', 0.005263157894736842)], [], [('there', 1.0)], [('from', 1.0)], [('which', 0.40717029449423814), ('what', 0.05505761843790013), ('whom', 0.03393085787451985), ('until', 0.029449423815621), ('as', 0.01792573623559539), ('them', 0.01088348271446863), ('till', 0.009603072983354673), ('when', 0.0076824583866837385), ('this', 0.007042253521126761), ('if', 0.007042253521126761), ('and', 0.007042253521126761), ('them,', 0.006402048655569782), ('hand', 
0.006402048655569782), ('all', 0.005761843790012804), ('hand,', 0.005121638924455826), ('whether', 0.005121638924455826), ('Saturday', 0.005121638924455826), ('that', 0.005121638924455826), ('board,', 0.005121638924455826), ('where', 0.005121638924455826)], [('this', 0.4672489082969432), ('any', 0.09606986899563319), ('every', 0.09388646288209607), ('the', 0.07860262008733625), ('that', 0.07860262008733625), ('a', 0.024017467248908297), ('one', 0.021834061135371178), ('some', 0.019650655021834062), ('each', 0.013100436681222707), ('an\\\
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# PREDICTION FOR TEST-A\n",
"\n",
"dataframe = pd.read_csv(directory_test_A, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
"dataframe = dataframe.replace(r'\\\\r|\\\\n|\\n|\\\\t', ' ', regex=True)\n",
"\n",
"left_text = dataframe['LeftContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"right_text = dataframe['RightContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"\n",
"lines = zip(left_text, right_text)\n",
"lines = list(map(lambda l: model.generate_text(l[0], l[1], False), tqdm(lines)))\n",
"print(lines[:100])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 7414/7414 [00:00<00:00, 114060.60it/s]\n"
]
}
],
"source": [
"with open(\"test-A/out.tsv\", \"w\", encoding=\"UTF-8\") as file:\n",
" result = \"\\n\".join(list(map(lambda l: convert_predictions(l), tqdm(lines))))\n",
" file.write(result)\n",
" file.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}