{
"cells": [
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import csv\n",
"import os\n",
"import re\n",
"import random\n",
"from collections import Counter, defaultdict\n",
"import nltk\n",
"import math\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"directory = \"train/in.tsv.xz\"\n",
"directory_dev_0 = \"dev-0/in.tsv.xz\"\n",
"directory_test_A = \"test-A/in.tsv.xz\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MODEL N-GRAM"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"class Model():\n",
" \n",
" def __init__(self, vocab_size=30_000, UNK_token= '<UNK>', n=3):\n",
" if (n <= 1 or n % 2 == 0):\n",
" raise \"change N value !!!\"\n",
" self.n = n\n",
" self.vocab_size = vocab_size\n",
" self.UNK_token = UNK_token\n",
" \n",
" def train(self, corpus:list) -> None:\n",
" if(self.n > 1):\n",
" self.n_grams = list(nltk.ngrams(corpus, n=self.n))\n",
" else:\n",
" self.n_grams = corpus\n",
" self.counter = Counter(self.n_grams)\n",
" self.words_counter = Counter(corpus)\n",
" self.all_quantities = Counter([gram[:math.floor(self.n/2)]+gram[math.ceil(self.n/2):] for gram in self.n_grams])\n",
"\n",
" self.all_grams = defaultdict(set)\n",
"\n",
" for gram in tqdm(self.n_grams):\n",
" previous_words = tuple(gram[:math.floor(self.n/2)])\n",
" next_words = tuple(gram[math.ceil(self.n/2):])\n",
" word = gram[math.floor(self.n/2)]\n",
" self.all_grams[(previous_words, next_words)].add(word)\n",
"\n",
" def get_conditional_prob_for_word(self, left_text: list, right_text: list, word: str) -> float:\n",
" previous_words = tuple(left_text[-math.floor(self.n/2):])\n",
" next_words = tuple(right_text[:math.floor(self.n/2)])\n",
" quantity = self.counter[previous_words + tuple([word]) + next_words]\n",
" all_quantity = self.all_quantities[previous_words + next_words]\n",
" if (all_quantity <= 0):\n",
" return 0\n",
" return quantity/all_quantity\n",
" \n",
" def get_prob_for_text(self, text: list) -> float:\n",
" prob = 1\n",
" for gram in list(nltk.ngrams(text, self.n)):\n",
" prob *= self.get_conditional_prob_for_word(gram[:math.floor(self.n/2)], gram[math.ceil(self.n/2):], gram[math.floor(self.n/2)])\n",
" return prob\n",
" \n",
" def most_probable_words(self, left_text: list, right_text: list) -> str:\n",
" previous_words = tuple(left_text[-math.floor(self.n/2):])\n",
" next_words = tuple(right_text[:math.floor(self.n/2)])\n",
" all_words = self.all_grams[(previous_words, next_words)]\n",
" best_words = []\n",
" for word in all_words:\n",
" probability = self.get_conditional_prob_for_word(list(previous_words), list(next_words), word)\n",
" best_words.append((word, probability))\n",
" return sorted(best_words, key=(lambda l: l[1]), reverse=True)[:20]\n",
" \n",
" def generate_text(self, text_beggining:list, text_ending:list, greedy: bool) -> list:\n",
" words = self.most_probable_words(text_beggining, text_ending)\n",
" return words\n"
]
},
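{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sanity check on a toy corpus (an illustrative sketch, not part of the original run): with `n=3` the model predicts the middle word from one word of left context and one word of right context."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: train a tiny 3-gram model and query the gap\n",
"toy_model = Model(n=3)\n",
"toy_model.train(\"the cat sat on the mat the cat lay on the mat\".split())\n",
"# Between left context 'cat' and right context 'on' the corpus contains\n",
"# 'sat' and 'lay' once each, so both get probability 0.5\n",
"print(toy_model.most_probable_words(['cat'], ['on']))\n",
"# Chaining the centre-word probabilities over a whole sentence: 1 * 0.5 * 1 * 1\n",
"print(toy_model.get_prob_for_text(\"the cat sat on the mat\".split()))"
]
},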
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DATASET"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['came', 'fiom', 'the', 'last', 'place', 'to', 'thisnplace,', 'and', 'this', 'place', 'is', 'Where', 'WenWere,', 'this', 'is', 'the', 'first', 'road', 'I', 'evernwas', 'on', 'where', 'you', 'can', 'ride', 'elsewherenfrom', 'anywhere', 'and', 'be', 'nowhere.nHe', 'says,', 'while', 'this', 'train', 'stops', 'every-nwhere,', 'it', 'never', 'stops', 'anywhere', 'un-nless', 'its', 'somewhere.', 'Well,', 'I', \"says,nI'm\", 'glad', 'to', 'hear', 'that,', 'but,', 'accord-ning', 'to', 'your', 'figures,', 'I', 'left', 'myselfnwhere', '1', 'was,', 'which', 'is', 'five', 'miles', 'near-ner', 'to', 'myself', 'than', 'I', 'was', 'when', 'wenwere', 'where', 'we', 'are', 'now.nWe', 'have', 'now', 'reached', \"Slidell.nThat's\", 'a', 'fine', 'place.', 'The', 'peoplendown', 'there', 'remind', 'me', 'of', 'bananas-nthey', 'come', 'and', 'go', 'in', 'bunches.', '811-ndell', 'used', 'to', 'be', 'noted']\n"
]
}
],
"source": [
"dataframeList = pd.read_csv(directory, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], escapechar='\\\\', quoting=csv.QUOTE_NONE, chunksize=10000)\n",
"\n",
"expectedList = pd.read_csv(directory, sep='\\t', header=None, names=['Word'], escapechar='\\\\', quoting=csv.QUOTE_NONE, chunksize=10000)\n",
"\n",
"DATASET = \"\"\n",
"\n",
"for number, (dataframe, expected) in enumerate(zip(dataframeList, expectedList)):\n",
" left_text = dataframe['LeftContext'].to_list()\n",
" right_text = dataframe['RightContext'].to_list()\n",
" word = expected['Word'].to_list()\n",
"\n",
" lines = zip(left_text, word, right_text)\n",
" lines = list(map(lambda l: \" \".join(l), lines))\n",
" DATASET = DATASET + \" \".join(lines)\n",
"\n",
"FINAL_DATASET = re.split(r\"\\s+\", DATASET)\n",
"print(FINAL_DATASET[:100])"
]
},
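{
"cell_type": "markdown",
"metadata": {},
"source": [
"A side note on tokens like `thisnplace` in the sample above: the corpus stores line breaks inside a context as the two characters `\\n`, and `escapechar='\\\\'` makes the parser consume the backslash and keep a bare `n` glued to the surrounding words. A tiny sketch of the effect (illustrative only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: the escape character swallows the backslash of an\n",
"# encoded newline, leaving a bare 'n' inside the token\n",
"from io import StringIO\n",
"demo = pd.read_csv(StringIO(\"1\\tthis\\\\nplace\\n\"), sep='\\t', header=None, escapechar='\\\\', quoting=csv.QUOTE_NONE)\n",
"print(demo.iloc[0, 1])  # -> thisnplace"
]
},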
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TRAIN"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 180304236/180304236 [13:57<00:00, 215160.70it/s] \n"
]
}
],
"source": [
"model_3gram = Model(n = 3)\n",
"model_3gram.train(FINAL_DATASET)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"model = model_3gram"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PREDICTION"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"def convert_predictions(line):\n",
" sum_predictions = np.sum([pred[1] for pred in line])\n",
" result = \"\"\n",
" all_pred = 0\n",
" for word, pred in line:\n",
" new_pred = math.floor(pred / sum_predictions * 100) / 100\n",
" if(new_pred == 1.0):\n",
" new_pred = 0.99\n",
" all_pred = all_pred + new_pred\n",
" result = result + word + \":\" + str(new_pred) + \" \"\n",
" if(round(all_pred, 2) < 1):\n",
" result = result + \":\" + str(round(1 - all_pred, 2))\n",
" else:\n",
" result = result + \":\" + str(0.01)\n",
" return result"
]
},
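{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the output format on a made-up prediction list (illustrative only): the candidate probabilities are normalised, truncated to two decimals, and whatever mass is left over goes to the empty catch-all token at the end of the line."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: 0.5/0.75 truncates to 0.66 and 0.25/0.75 to 0.33,\n",
"# so the remaining 0.01 lands on the ':' catch-all entry\n",
"print(convert_predictions([(\"the\", 0.5), (\"a\", 0.25)]))  # -> the:0.66 a:0.33 :0.01"
]
},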
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"10519it [00:31, 330.85it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[], [('passage', 0.005712530712530713), ('growth', 0.0049754299754299755), ('use,', 0.004545454545454545), ('functions,', 0.003931203931203931), ('successors', 0.0036855036855036856), ('place,', 0.0035626535626535625), ('own,', 0.0031941031941031942), ('own', 0.0031941031941031942), ('head', 0.00300982800982801), ('power', 0.0029484029484029483), ('action,', 0.002764127764127764), ('work', 0.0025798525798525797), ('members', 0.0025184275184275185), ('value,', 0.0025184275184275185), ('value', 0.002334152334152334), ('vicinity,', 0.002334152334152334), ('name', 0.002334152334152334), ('place', 0.0022727272727272726), ('beauty', 0.0022113022113022115), ('strength', 0.0022113022113022115)], [], [], [('undertook', 1.0)], [('a', 0.2926829268292683), ('two', 0.08536585365853659), ('goodnand', 0.07317073170731707), ('him', 0.054878048780487805), ('means', 0.036585365853658534), ('money', 0.03048780487804878), ('all', 0.024390243902439025), ('force', 0.024390243902439025), ('just', 0.018292682926829267), ('capacity', 0.012195121951219513), ('scarcely', 0.012195121951219513), ('stabling', 0.012195121951219513), ('guns', 0.012195121951219513), ('barely', 0.012195121951219513), ('boats', 0.012195121951219513), ('h', 0.012195121951219513), ('amply', 0.012195121951219513), ('decline', 0.012195121951219513), ('capital', 0.012195121951219513), ('u', 0.012195121951219513)], [], [], [], [('as', 0.7727272727272727), ('a3', 0.09090909090909091), ('that', 0.09090909090909091), ('its', 0.045454545454545456)], [], [('the', 0.4133906633906634), ('show', 0.21375921375921375), ('shew', 0.03194103194103194), ('this', 0.03194103194103194), ('tho', 0.0214987714987715), ('our', 0.016584766584766583), ('a', 0.012285012285012284), ('their', 0.009828009828009828), ('that', 0.009213759213759214), ('tbe', 0.009213759213759214), ('ascertainnthe', 0.005528255528255528), ('benthe', 0.004914004914004914), ('learnnthe', 0.004914004914004914), ('any', 0.0042997542997543), ('tlie', 0.0042997542997543), ('thenreal', 0.0036855036855036856), ('his', 0.0036855036855036856), ('what', 0.003071253071253071), ('said', 0.003071253071253071), ('immediately', 0.003071253071253071)], [], [('and', 0.0744047619047619), ('put', 0.03273809523809524), ('be', 0.026785714285714284), ('placed', 0.023809523809523808), ('again', 0.017857142857142856), ('held', 0.01488095238095238), ('engaged', 0.01488095238095238), ('pending', 0.01488095238095238), ('wrapped', 0.011904761904761904), ('started', 0.011904761904761904), ('went', 0.008928571428571428), ('got', 0.008928571428571428), ('?', 0.008928571428571428), ('roll', 0.005952380952380952), ('playing', 0.005952380952380952), ('13SJ4', 0.005952380952380952), ('specialised', 0.005952380952380952), ('anninfant', 0.005952380952380952), ('streamed.', 0.005952380952380952), ('flew', 0.005952380952380952)], [('to', 0.6538461538461539), ('a', 0.09615384615384616), ('exceptnto', 0.07692307692307693), ('¦', 0.038461538461538464), ('efceptnta', 0.038461538461538464), ('world,', 0.019230769230769232), ('.nto', 0.019230769230769232), ('the', 0.019230769230769232), ('and', 0.019230769230769232), ('anyway.n“Then', 0.019230769230769232)], [], [('to', 0.7), ('Almighty', 0.2), ('that', 0.1)], [('Knocked', 1.0)], [], [], [], [('will', 1.0)], [('went', 0.024908077333649626), ('carried', 0.016486774997034753), ('set', 0.014351796939864785), ('find', 0.014233187047799786), ('go', 0.01399596726366979), ('came', 0.013402917803344799), ('carry', 0.01197959909856482), ('pointed', 0.011505159530304827), ('come', 
0.010793500177914838), ('put', 0.010556280393784841), ('get', 0.00937018147313486), ('paid', 0.009132961689004864), ('sent', 0.009014351796939865), ('started', 0.00865852212074487), ('brought', 0.007828252876289882), ('took', 0.007709642984224884), ('got', 0.007235203415964892), ('take', 0.006879373739769897), ('laid', 0.006523544063574902), ('worn', 0.006404934171509904)], [('City', 0.04151624548736462), ('city', 0.03790613718411552), ('Bay', 0.02707581227436823), ('and', 0.023465703971119134), ('avenue', 0.010830324909747292), ('andndesiri
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# PREDICTION FOR DEV-0\n",
"\n",
"dataframe = pd.read_csv(directory_dev_0, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], escapechar='\\\\', quoting=csv.QUOTE_NONE)\n",
"\n",
"left_text = dataframe['LeftContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"right_text = dataframe['RightContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"\n",
"lines = zip(left_text, right_text)\n",
"lines = list(map(lambda l: model.generate_text(l[0], l[1], False), tqdm(lines)))\n",
"print(lines[:100])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10519/10519 [00:00<00:00, 111905.55it/s]\n"
]
}
],
"source": [
"with open(\"dev-0/out.tsv\", \"w\", encoding=\"UTF-8\") as file:\n",
" result = \"\\n\".join(list(map(lambda l: convert_predictions(l), tqdm(lines))))\n",
" file.write(result)\n",
" file.close()"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"7414it [00:17, 422.07it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[], [('home', 0.08333333333333333), ('decline', 0.08333333333333333), ('or', 0.08333333333333333), ('spread', 0.08333333333333333), ('is', 0.08333333333333333), ('numerous', 0.08333333333333333), ('road', 0.08333333333333333), ('owned', 0.08333333333333333), ('resides', 0.08333333333333333), ('taxes', 0.08333333333333333), ('whitely', 0.08333333333333333), ('water', 0.08333333333333333)], [], [('man', 0.01770717393503997), ('plan', 0.009106546595163412), ('trial', 0.006779318020843873), ('living', 0.00647576646767176), ('statement', 0.005868663361327532), ('law', 0.005868663361327532), ('vote', 0.005261560254983305), ('class', 0.005059192552868562), ('year', 0.00485682485075382), ('sensation', 0.0043509055954669635), ('question', 0.004148537893352221), ('single', 0.004148537893352221), ('bill', 0.0040473540422948494), ('day', 0.0040473540422948494), ('government', 0.0034402509359506223), ('time', 0.0034402509359506223), ('paper', 0.003339067084893251), ('means', 0.002934331680663766), ('speech', 0.002934331680663766), ('way', 0.002934331680663766)], [], [], [], [], [], [], [('on', 0.23076923076923078), ('be', 0.15384615384615385), ('means', 0.15384615384615385), ('about', 0.15384615384615385), ('in', 0.15384615384615385), ('for', 0.07692307692307693), ('Influence', 0.07692307692307693)], [], [('days', 0.06546854942233633), ('minutes', 0.051347881899871634), ('years', 0.03979460847240052), ('moments', 0.03465982028241335), ('weeks', 0.029525032092426188), ('months', 0.026957637997432605), ('cents', 0.01797175866495507), ('at', 0.01668806161745828), ('years,', 0.01540436456996149), ('yearsnago', 0.014120667522464698), ('days,', 0.012836970474967908), ('day', 0.011553273427471117), ('dollars', 0.011553273427471117), ('furnish', 0.010269576379974325), ('months,', 0.010269576379974325), ('in', 0.008985879332477536), ('daysnago', 0.008985879332477536), ('hours', 0.008985879332477536), ('weeksnago', 0.007702182284980745), ('with', 0.006418485237483954)], [], [], [('from', 1.0)], [('which', 0.4), (\"Radway's,\", 0.13333333333333333), ('mules,', 0.13333333333333333), ('interest;', 0.13333333333333333), ('each,', 0.13333333333333333), ('Wednesday', 0.06666666666666667)], [], [('little', 0.16510903426791276), ('whole', 0.04361370716510903), ('beautiful', 0.04361370716510903), ('neighboring', 0.028037383177570093), ('townnor', 0.024922118380062305), ('nearest', 0.024922118380062305), ('agricultural', 0.018691588785046728), ('said', 0.012461059190031152), ('next', 0.012461059190031152), ('Maine', 0.012461059190031152), ('British', 0.012461059190031152), ('present', 0.009345794392523364), ('ancient', 0.009345794392523364), ('incorporated', 0.009345794392523364), ('thriving', 0.009345794392523364), ('small', 0.009345794392523364), ('obscure', 0.009345794392523364), ('city,', 0.009345794392523364), ('Japanese', 0.009345794392523364), ('States,nthe', 0.006230529595015576)], [], [('weakling', 1.0)], [], [('inches', 0.21926910299003322), ('feet', 0.07973421926910298), ('poles', 0.046511627906976744), ('chains', 0.04318936877076412), ('Inches', 0.029900332225913623), ('in.)', 0.019933554817275746), ('years', 0.019933554817275746), ('links', 0.016611295681063124), ('.00', 0.016611295681063124), ('miles', 0.013289036544850499), ('perches', 0.013289036544850499), ('.25', 0.013289036544850499), ('inch-nes', 0.013289036544850499), ('.50', 0.013289036544850499), ('chs,', 0.013289036544850499), ('.', 0.013289036544850499), ('00', 0.009966777408637873), ('feet,', 0.009966777408637873), ('inchee', 
0.009966777408637873), ('10nper.', 0.006644518272425249)], [], [], [], [('puzzling', 0.1111111111111111), ('sufficient', 0.1111111111111111), ('strong', 0.1111111111111111), ('brought', 0.1111111111111111), ('-signed', 0.1111111111111111), ('preparatory', 0.1111111111111111), ('taken', 0.1111111111111111), ('obliged', 0.1111111111111111), ('enough', 0.1111111111111111)], [], [('it', 0.2631578947368421), ('than', 0.15789473684210525), ('that', 0.13157894736842105), ('this', 0.10526315789473684), ('1nthan', 0.05263157894736842), ('allnthat', 0.05263
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# PREDICTION FOR TEST-A\n",
"\n",
"dataframe = pd.read_csv(directory_test_A, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], escapechar='\\\\', quoting=csv.QUOTE_NONE)\n",
"\n",
"left_text = dataframe['LeftContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"right_text = dataframe['RightContext'].apply(lambda l: re.split(r\"\\s+\", l)).to_list()\n",
"\n",
"lines = zip(left_text, right_text)\n",
"lines = list(map(lambda l: model.generate_text(l[0], l[1], False), tqdm(lines)))\n",
"print(lines[:100])"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 7414/7414 [00:00<00:00, 128642.81it/s]\n"
]
}
],
"source": [
"with open(\"test-A/out.tsv\", \"w\", encoding=\"UTF-8\") as file:\n",
" result = \"\\n\".join(list(map(lambda l: convert_predictions(l), tqdm(lines))))\n",
" file.write(result)\n",
" file.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}