06_kenlm/06_kenlm.ipynb

1384 lines
50 KiB
Plaintext
Raw Normal View History

2024-05-15 04:33:04 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import lzma\n",
"\n",
"folders = [\"./challenging-america-word-gap-prediction/dev-0\",\n",
" \"./challenging-america-word-gap-prediction/test-A\",\n",
" \"./challenging-america-word-gap-prediction/train\"]\n",
"\n",
"for folder in folders:\n",
" for file in os.listdir(folder):\n",
" if file.endswith(\".tsv.xz\"):\n",
" file_path = os.path.join(folder, file)\n",
" output_path = os.path.splitext(file_path)[0] # Remove the .xz extension\n",
" with lzma.open(file_path, \"rb\") as compressed_file:\n",
" with open(output_path, \"wb\") as output_file:\n",
" output_file.write(compressed_file.read())"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import nltk\n",
"nltk.download('punkt')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Skipping line 30538: expected 8 fields, saw 9\n",
"Skipping line 37185: expected 8 fields, saw 9\n",
"Skipping line 40930: expected 8 fields, saw 9\n",
"Skipping line 44499: expected 8 fields, saw 9\n",
"Skipping line 46409: expected 8 fields, saw 9\n",
"Skipping line 52642: expected 8 fields, saw 9\n",
"Skipping line 53046: expected 8 fields, saw 9\n",
"\n",
"Skipping line 69658: expected 8 fields, saw 9\n",
"Skipping line 71325: expected 8 fields, saw 9\n",
"Skipping line 72955: expected 8 fields, saw 9\n",
"Skipping line 80528: expected 8 fields, saw 9\n",
"Skipping line 96979: expected 8 fields, saw 9\n",
"Skipping line 121731: expected 8 fields, saw 9\n",
"Skipping line 126630: expected 8 fields, saw 9\n",
"\n",
"Skipping line 132289: expected 8 fields, saw 9\n",
"Skipping line 140251: expected 8 fields, saw 9\n",
"Skipping line 142374: expected 8 fields, saw 9\n",
"Skipping line 149592: expected 8 fields, saw 9\n",
"Skipping line 150041: expected 8 fields, saw 9\n",
"Skipping line 151624: expected 8 fields, saw 9\n",
"Skipping line 158163: expected 8 fields, saw 9\n",
"Skipping line 159665: expected 8 fields, saw 9\n",
"Skipping line 171749: expected 8 fields, saw 9\n",
"Skipping line 174845: expected 8 fields, saw 9\n",
"Skipping line 177638: expected 8 fields, saw 9\n",
"Skipping line 178778: expected 8 fields, saw 9\n",
"Skipping line 188823: expected 8 fields, saw 9\n",
"Skipping line 191398: expected 8 fields, saw 9\n",
"\n",
"Skipping line 196865: expected 8 fields, saw 9\n",
"Skipping line 203572: expected 8 fields, saw 9\n",
"Skipping line 207802: expected 8 fields, saw 9\n",
"Skipping line 214509: expected 8 fields, saw 9\n",
"Skipping line 214633: expected 8 fields, saw 9\n",
"Skipping line 217906: expected 8 fields, saw 9\n",
"Skipping line 220906: expected 8 fields, saw 9\n",
"Skipping line 238000: expected 8 fields, saw 9\n",
"Skipping line 257754: expected 8 fields, saw 9\n",
"Skipping line 259366: expected 8 fields, saw 9\n",
"Skipping line 261826: expected 8 fields, saw 9\n",
"\n",
"Skipping line 272727: expected 8 fields, saw 9\n",
"Skipping line 280527: expected 8 fields, saw 9\n",
"Skipping line 282454: expected 8 fields, saw 9\n",
"Skipping line 285910: expected 8 fields, saw 9\n",
"Skipping line 289865: expected 8 fields, saw 9\n",
"Skipping line 292892: expected 8 fields, saw 9\n",
"Skipping line 292984: expected 8 fields, saw 9\n",
"Skipping line 293058: expected 8 fields, saw 9\n",
"Skipping line 302716: expected 8 fields, saw 9\n",
"Skipping line 303370: expected 8 fields, saw 9\n",
"Skipping line 314194: expected 8 fields, saw 9\n",
"Skipping line 321975: expected 8 fields, saw 9\n",
"Skipping line 324999: expected 8 fields, saw 9\n",
"\n",
"Skipping line 331978: expected 8 fields, saw 9\n",
"Skipping line 345426: expected 8 fields, saw 9\n",
"Skipping line 345951: expected 8 fields, saw 9\n",
"Skipping line 355430: expected 8 fields, saw 9\n",
"Skipping line 358744: expected 8 fields, saw 9\n",
"Skipping line 361491: expected 8 fields, saw 9\n",
"Skipping line 370443: expected 8 fields, saw 9\n",
"Skipping line 388057: expected 8 fields, saw 9\n",
"Skipping line 391061: expected 8 fields, saw 9\n",
"\n",
"Skipping line 395391: expected 8 fields, saw 9\n",
"Skipping line 404270: expected 8 fields, saw 9\n",
"Skipping line 407896: expected 8 fields, saw 9\n",
"Skipping line 409881: expected 8 fields, saw 9\n",
"Skipping line 421230: expected 8 fields, saw 9\n",
"Skipping line 425850: expected 8 fields, saw 9\n",
"Skipping line 427269: expected 8 fields, saw 9\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"in_df:\n",
" 0 1 2 \\\n",
"0 4e04702da929c78c52baf09c1851d3ff ST ChronAm \n",
"1 b374dadd940510271d9675d3e8caf9d8 DAILY ARIZONA SILVER BELT ChronAm \n",
"2 adb666c426bdc10fd949cb824da6c0d0 THE SAVANNAH MORNING NEWS ChronAm \n",
"3 bc2c9aa0b77d724311e3c2e12fc61c92 CHARLES CITY INTELLIGENCER ChronAm \n",
"4 0f612b991a39c712f0d745835b8b2f0d EVENING STAR ChronAm \n",
"\n",
" 3 4 5 \\\n",
"0 1919.604110 30.475470 -90.100911 \n",
"1 1909.097260 33.399478 -110.870950 \n",
"2 1900.913699 32.080926 -81.091177 \n",
"3 1864.974044 43.066361 -92.672411 \n",
"4 1878.478082 38.894955 -77.036646 \n",
"\n",
" 6 \\\n",
"0 came fiom the last place to this\\nplace, and t... \n",
"1 MB. BOOT'S POLITICAL OBEED\\nAttempt to imagine... \n",
"2 Thera were in 1771 only aeventy-nine\\n*ub*erlb... \n",
"3 whenever any prize property shall!*' condemn- ... \n",
"4 SA LKOFVALUABLE UNIMPBOV&D RE\\\\L\\nJSIATF. ON T... \n",
"\n",
" 7 \n",
"0 said\\nit's all squash. The best I could get\\ni... \n",
"1 \\ninto a proper perspective with those\\nminor ... \n",
"2 NaN \n",
"3 the ceitihcate of'\\noperate to prevent tfie ma... \n",
"4 \\nTerms of sale: One-tblrd, togethor with the ... \n",
"\n",
"expected_df:\n",
" 0\n",
"0 lie\n",
"1 himself\n",
"2 of\n",
"3 ably\n",
"4 j\n",
"\n",
"hate_speech_info_df:\n",
"\n",
"in_df info:\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 428517 entries, 0 to 428516\n",
"Data columns (total 8 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 0 428517 non-null object \n",
" 1 1 428517 non-null object \n",
" 2 2 428517 non-null object \n",
" 3 3 428517 non-null float64\n",
" 4 4 428517 non-null float64\n",
" 5 5 428517 non-null float64\n",
" 6 6 428517 non-null object \n",
" 7 7 425735 non-null object \n",
"dtypes: float64(3), object(5)\n",
"memory usage: 26.2+ MB\n",
"None\n",
"\n",
"expected_df info:\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 279623 entries, 0 to 279622\n",
"Data columns (total 1 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 0 279619 non-null object\n",
"dtypes: object(1)\n",
"memory usage: 2.1+ MB\n",
"None\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"# Load the data\n",
"in_df = pd.read_csv('./challenging-america-word-gap-prediction/train/in.tsv', sep='\\t', header=None, on_bad_lines='warn')\n",
"expected_df = pd.read_csv('./challenging-america-word-gap-prediction/train/expected.tsv', sep='\\t', header=None, on_bad_lines='warn')\n",
"\n",
"# Print out the first few rows of each DataFrame\n",
"print(\"in_df:\")\n",
"print(in_df.head())\n",
"print(\"\\nexpected_df:\")\n",
"print(expected_df.head())\n",
"print(\"\\nhate_speech_info_df:\")\n",
"\n",
"# Print out more information about each DataFrame\n",
"print(\"\\nin_df info:\")\n",
"print(in_df.info())\n",
"print(\"\\nexpected_df info:\")\n",
"print(expected_df.info())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5] # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"in_df['text_2'].fillna('', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Replacing '\\n' in 'text_1': 100%|██████████| 428517/428517 [00:02<00:00, 166646.43it/s]\n",
"Replacing '\\t' in 'text_1': 100%|██████████| 428517/428517 [00:01<00:00, 422489.36it/s]\n",
"Replacing '\\n' in 'text_2': 100%|██████████| 428517/428517 [00:02<00:00, 149443.26it/s]\n",
"Replacing '\\t' in 'text_2': 100%|██████████| 428517/428517 [00:01<00:00, 417969.18it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Define a function to replace '\\n' with ' '\n",
"def replace_newline(text):\n",
" if isinstance(text, str):\n",
" return text.replace('\\\\n', ' ')\n",
" return text\n",
"\n",
"def replace_tabulation(text):\n",
" if isinstance(text, str):\n",
" return text.replace('\\\\t', ' ')\n",
" return text\n",
"\n",
"# Apply the function to 'text_1' and 'text_2' columns and show progress\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_newline)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_tabulation)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_newline)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_tabulation)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" text_1 \\\n",
"0 came fiom the last place to this place, and th... \n",
"1 MB. BOOT'S POLITICAL OBEED Attempt to imagine ... \n",
"2 Thera were in 1771 only aeventy-nine *ub*erlbe... \n",
"3 whenever any prize property shall!*' condemn- ... \n",
"4 SA LKOFVALUABLE UNIMPBOV&D RE\\\\L JSIATF. ON TH... \n",
"\n",
" text_2 \n",
"0 said it's all squash. The best I could get in ... \n",
"1 into a proper perspective with those minor se... \n",
"2 \n",
"3 the ceitihcate of' operate to prevent tfie mak... \n",
"4 Terms of sale: One-tblrd, togethor with the e... \n"
]
}
],
"source": [
"print(in_df.head())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"Processing 'text_1': 100%|██████████| 428517/428517 [16:38<00:00, 429.15it/s]\n",
"Processing 'text_2': 100%|██████████| 428517/428517 [14:55<00:00, 478.46it/s]\n"
]
}
],
"source": [
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"# If not already done, download the NLTK English stopwords and the Punkt Tokenizer Models\n",
"nltk.download('punkt')\n",
"nltk.download('stopwords')\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"stemmer = PorterStemmer()\n",
"\n",
"def preprocess_text(text):\n",
" # Lowercase the text\n",
" text = text.lower()\n",
" \n",
" # Remove punctuation\n",
" text = text.translate(str.maketrans('', '', string.punctuation))\n",
" \n",
" # Tokenize the text\n",
" words = word_tokenize(text)\n",
" \n",
" # Remove stopwords and stem the words\n",
" words = [stemmer.stem(word) for word in words if word not in stop_words]\n",
" \n",
" # Join the words back into a single string\n",
" text = ' '.join(words)\n",
" \n",
" return text\n",
"\n",
"# Apply the preprocessing to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(preprocess_text)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(preprocess_text)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" text_1 \\\n",
"0 came fiom last place place place first road ev... \n",
"1 mb boot polit obe attempt imagin piatt make ad... \n",
"2 thera 1771 aeventynin uberlb lo lloyd nearli 1... \n",
"3 whenev prize properti shall condemn appeal dis... \n",
"4 sa lkofvalu unimpbovd rel jsiatf north bideof ... \n",
"\n",
" text_2 \n",
"0 said squash best could get hotel soup sandwich... \n",
"1 proper perspect minor senatori duti tho fill i... \n",
"2 \n",
"3 ceitihc oper prevent tfie make execut district... \n",
"4 term sale onetblrd togethor ex¬ pens sale cash... \n"
]
}
],
"source": [
"print(in_df.head())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Save 'in_df' DataFrame to a .tsv file\n",
"in_df.to_csv('preprocessed_text1_text2.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing 'text_1': 100%|██████████| 428517/428517 [00:21<00:00, 19801.75it/s]\n",
"Processing 'text_2': 100%|██████████| 428517/428517 [00:25<00:00, 16938.99it/s]\n"
]
}
],
"source": [
"import re\n",
"\n",
"def handle_numbers_and_special_chars(text):\n",
" # Remove numbers\n",
" text = re.sub(r'\\d+', '', text)\n",
" \n",
" # Remove special characters\n",
" text = re.sub(r'\\W+', ' ', text)\n",
" \n",
" return text\n",
"\n",
"# Apply the function to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(handle_numbers_and_special_chars)\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(handle_numbers_and_special_chars)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from spellchecker import SpellChecker\n",
"\n",
"spell = SpellChecker()\n",
"\n",
"def correct_spelling(text):\n",
" # Tokenize the text\n",
" words = text.split()\n",
" \n",
" # Correct spelling\n",
" corrected_words = [spell.correction(word) if spell.correction(word) is not None else '' for word in words]\n",
" \n",
" # Join the words back into a single string\n",
" text = ' '.join(corrected_words)\n",
" \n",
" return text\n",
"\n",
"# Apply the spelling correction to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Spelling Correction 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(correct_spelling)\n",
"\n",
"tqdm.pandas(desc=\"Spelling Correction 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(correct_spelling)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to concatenate 'text_1' and 'text_2'\n",
"def concatenate_texts(row):\n",
" return str(row['text_1']) + ' <MASK> ' + str(row['text_2'])\n",
"\n",
"# Apply the function to each row and show progress\n",
"tqdm.pandas(desc=\"Concatenating 'text_1' and 'text_2'\")\n",
"in_df['text'] = in_df.progress_apply(concatenate_texts, axis=1)\n",
"\n",
"# Now you can drop 'text_1' and 'text_2' columns if you want\n",
"in_df.drop(['text_1', 'text_2'], axis=1, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 428517/428517 [00:05<00:00, 83867.84it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"tqdm.pandas()\n",
"\n",
"# Load the preprocessed data\n",
"in_df = pd.read_csv('preprocessed_text1_text2.tsv', sep='\\t')\n",
"\n",
"# Load the expected words\n",
"expected_df = pd.read_csv('./challenging-america-word-gap-prediction/train/expected.tsv', sep='\\t', header=None, on_bad_lines='warn')\n",
"expected_df.columns = ['expected_word']\n",
"\n",
"# Add the expected words to in_df\n",
"in_df = pd.concat([in_df, expected_df], axis=1)\n",
"\n",
"# Define a function to concatenate 'text_1' and 'expected_word' and 'text_2'\n",
"def concatenate_texts(row):\n",
" return str(row['text_1']) + ' ' + str(row['expected_word']) + ' ' + str(row['text_2'])\n",
"\n",
"# Apply the function to each row and show progress\n",
"in_df['unmasked_text'] = in_df.progress_apply(concatenate_texts, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"in_df['unmasked_text'].to_csv('preprocessed_text_join_unmask.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Save the 'text' column to a text file\n",
"in_df['unmasked_text'].to_csv('training_data.txt', index=False, header=False)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== 1/5 Counting and sorting n-grams ===\n",
"Reading /teamspace/studios/this_studio/training_data.txt\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************************************************************\n",
"Unigram tokens 76946114 types 3214378\n",
"=== 2/5 Calculating and sorting adjusted counts ===\n",
"Chain sizes: 1:38572536 2:1266848384 3:2375340800 4:3800545280 5:5542461952\n",
"Statistics:\n",
"1 3214378 D1=0.828129 D2=1.02195 D3+=1.19945\n",
"2 28265496 D1=0.815336 D2=1.00509 D3+=1.20745\n",
"3 64425776 D1=0.933716 D2=1.31704 D3+=1.47374\n",
"4 71021141 D1=0.980912 D2=1.49465 D3+=1.62427\n",
"5 72056555 D1=0.968608 D2=1.34446 D3+=1.42151\n",
"Memory estimate for binary LM:\n",
"type MB\n",
"probing 5063 assuming -p 1.5\n",
"probing 6012 assuming -r models -p 1.5\n",
"trie 2711 without quantization\n",
"trie 1596 assuming -q 8 -b 8 quantization \n",
"trie 2329 assuming -a 22 array pointer compression\n",
"trie 1214 assuming -a 22 -q 8 -b 8 array pointer compression and quantization\n",
"=== 3/5 Calculating and sorting initial probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:1288515520 4:1704507384 5:2017583540\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"####################################################################################################\n",
"=== 4/5 Calculating and writing order-interpolated probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:1288515520 4:1704507384 5:2017583540\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"####################################################################################################\n",
"=== 5/5 Writing ARPA model ===\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"****************************************************************************************************\n",
"Name:lmplz\tVmPeak:12890788 kB\tVmRSS:4496 kB\tRSSMax:5436904 kB\tuser:274.028\tsys:89.3375\tCPU:363.366\treal:343.888\n"
]
}
],
"source": [
"!lmplz -o 5 --discount_fallback < training_data.txt > language_model.arpa"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading language_model.arpa\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************************************************************\n",
"SUCCESS\n"
]
}
],
"source": [
"!build_binary language_model.arpa language_model.binary"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== 1/5 Counting and sorting n-grams ===\n",
"Reading /teamspace/studios/this_studio/training_data.txt\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************************************************************\n",
"Unigram tokens 76946114 types 3214378\n",
"=== 2/5 Calculating and sorting adjusted counts ===\n",
"Chain sizes: 1:38572536 2:799089024 3:1498291968 4:2397267200 5:3496014592 6:4794534400\n",
"Statistics:\n",
"1 3214378 D1=0.828129 D2=1.02195 D3+=1.19945\n",
"2 28265496 D1=0.815336 D2=1.00509 D3+=1.20745\n",
"3 4497710/64425776 D1=0.933716 D2=1.31704 D3+=1.47374\n",
"4 2178876/71021141 D1=0.980912 D2=1.49465 D3+=1.62427\n",
"5 1699108/72056555 D1=0.988779 D2=1.56326 D3+=1.72651\n",
"6 1460655/72365687 D1=0.972772 D2=1.34844 D3+=1.44176\n",
"Memory estimate for binary LM:\n",
"type MB\n",
"probing 943 assuming -p 1.5\n",
"probing 1165 assuming -r models -p 1.5\n",
"trie 553 without quantization\n",
"trie 343 assuming -q 8 -b 8 quantization \n",
"trie 474 assuming -a 22 array pointer compression\n",
"trie 265 assuming -a 22 -q 8 -b 8 array pointer compression and quantization\n",
"=== 3/5 Calculating and sorting initial probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:89954200 4:52293024 5:47575024 6:46740960\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"**##################################################################################################\n",
"=== 4/5 Calculating and writing order-interpolated probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:89954200 4:52293024 5:47575024 6:46740960\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"####################################################################################################\n",
"=== 5/5 Writing ARPA model ===\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"****************************************************************************************************\n",
"Name:lmplz\tVmPeak:12907180 kB\tVmRSS:5872 kB\tRSSMax:7679904 kB\tuser:168.818\tsys:56.6427\tCPU:225.461\treal:147.307\n"
]
}
],
"source": [
"!lmplz -o 6 --discount_fallback -S 80% --prune 0 0 1 1 1 1 < training_data.txt > language_model.arpa"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10519 ./challenging-america-word-gap-prediction/dev-0/in.tsv\n"
]
}
],
"source": [
"!wc -l ./challenging-america-word-gap-prediction/dev-0/in.tsv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import csv\n",
"\n",
"# Load the data\n",
"try:\n",
" in_df = pd.read_csv('./challenging-america-word-gap-prediction/dev-0/in.tsv', sep='\\t', header=None, on_bad_lines='error')\n",
"except Exception as e:\n",
" print(e)\n",
"expected_df = pd.read_csv('./challenging-america-word-gap-prediction/dev-0/expected.tsv', sep='\\t', header=None, on_bad_lines='warn', quoting=csv.QUOTE_NONE)\n",
"\n",
"print(in_df.shape[0])\n",
"\n",
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5] # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']\n",
"\n",
"in_df['text_1'].fillna('placeholder', inplace=True)\n",
"in_df['text_2'].fillna('placeholder', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10519\n"
]
}
],
"source": [
"import pandas as pd\n",
"import csv\n",
"\n",
"# Placeholder line\n",
"placeholder_line = ['placeholder'] * 8 # Adjust the number of fields as needed\n",
"\n",
"# Read the file line by line\n",
"with open('./challenging-america-word-gap-prediction/dev-0/in.tsv', 'r') as f:\n",
" lines = f.readlines()\n",
"\n",
"# Split each line into fields and replace problematic lines with the placeholder line\n",
"lines = [line.strip().split('\\t') if len(line.strip().split('\\t')) == 8 else placeholder_line for line in lines]\n",
"\n",
"# Convert the list of lines into a DataFrame\n",
"in_df = pd.DataFrame(lines)\n",
"\n",
"# Print the number of rows in the DataFrame\n",
"print(in_df.shape[0])\n",
"\n",
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5] # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']\n",
"\n",
"in_df['text_1'].fillna('placeholder', inplace=True)\n",
"in_df['text_2'].fillna('placeholder', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10519\n"
]
}
],
"source": [
"print(in_df.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Replacing '\\n' in 'text_1': 100%|██████████| 10519/10519 [00:00<00:00, 223066.53it/s]\n",
"Replacing '\\t' in 'text_1': 100%|██████████| 10519/10519 [00:00<00:00, 535825.65it/s]\n",
"Replacing '\\n' in 'text_2': 100%|██████████| 10519/10519 [00:00<00:00, 216324.84it/s]\n",
"Replacing '\\t' in 'text_2': 100%|██████████| 10519/10519 [00:00<00:00, 534630.94it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Define a function to replace '\\n' with ' '\n",
"def replace_newline(text):\n",
" if isinstance(text, str):\n",
" return text.replace('\\\\n', ' ')\n",
" return text\n",
"\n",
"def replace_tabulation(text):\n",
" if isinstance(text, str):\n",
" return text.replace('\\\\t', ' ')\n",
" return text\n",
"\n",
"# Apply the function to 'text_1' and 'text_2' columns and show progress\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_newline)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_tabulation)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_newline)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_tabulation)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"Processing 'text_1': 100%|██████████| 10519/10519 [00:23<00:00, 440.48it/s]\n",
"Processing 'text_2': 100%|██████████| 10519/10519 [00:29<00:00, 358.63it/s]\n"
]
}
],
"source": [
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"# If not already done, download the NLTK English stopwords and the Punkt Tokenizer Models\n",
"nltk.download('punkt')\n",
"nltk.download('stopwords')\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"stemmer = PorterStemmer()\n",
"\n",
"def preprocess_text(text):\n",
" # Lowercase the text\n",
" text = text.lower()\n",
" \n",
" # Remove punctuation\n",
" text = text.translate(str.maketrans('', '', string.punctuation))\n",
" \n",
" # Tokenize the text\n",
" words = word_tokenize(text)\n",
" \n",
" # Remove stopwords and stem the words\n",
" words = [stemmer.stem(word) for word in words if word not in stop_words]\n",
" \n",
" # Join the words back into a single string\n",
" text = ' '.join(words)\n",
" \n",
" return text\n",
"\n",
"# Apply the preprocessing to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(preprocess_text)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(preprocess_text)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing 'text_1': 100%|██████████| 10519/10519 [00:00<00:00, 21823.01it/s]\n",
"Processing 'text_2': 100%|██████████| 10519/10519 [00:00<00:00, 21693.77it/s]\n"
]
}
],
"source": [
"import re\n",
"\n",
"def handle_numbers_and_special_chars(text):\n",
" # Remove numbers\n",
" text = re.sub(r'\\d+', '', text)\n",
" \n",
" # Remove special characters\n",
" text = re.sub(r'\\W+', ' ', text)\n",
" \n",
" return text\n",
"\n",
"# Apply the function to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(handle_numbers_and_special_chars)\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(handle_numbers_and_special_chars)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"in_df.to_csv('preprocessed_dev_text1_text2.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"import kenlm\n",
"\n",
"# Load the language model\n",
"model = kenlm.Model('language_model.binary')"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax(x):\n",
" e_x = np.exp(x - np.max(x))\n",
" return e_x / e_x.sum()\n",
"\n",
"def predict_missing_word(model, context_1, context_2):\n",
" # Define the vocabulary\n",
" vocabulary = set(' '.join([context_1, context_2]).split())\n",
"\n",
" # Initialize a dictionary to store the words and their scores\n",
" word_scores = {}\n",
"\n",
" # Iterate over the vocabulary\n",
" for word in vocabulary:\n",
" try:\n",
" # Generate the sentence\n",
" sentence = f\"{context_1} {word} {context_2}\"\n",
" \n",
" # Score the sentence and store it in the dictionary\n",
" word_scores[word] = model.score(sentence)\n",
"\n",
" except Exception as e:\n",
" print(f\"Error processing word '{word}': {e}\")\n",
" continue\n",
"\n",
" # If no word was found, return None for all values\n",
" if not word_scores:\n",
" return None, None, None\n",
"\n",
" # Convert the scores to probabilities using the softmax function\n",
" word_probs = {word: max(0.001, prob) for word, prob in zip(word_scores.keys(), softmax(list(word_scores.values())))}\n",
"\n",
" # Find the word with the highest probability\n",
" best_word, best_prob = max(word_probs.items(), key=lambda x: x[1])\n",
"\n",
" # Calculate the sum of probabilities for the other words\n",
" other_probs_sum = sum(prob for word, prob in word_probs.items() if word != best_word)\n",
"\n",
" return best_word, best_prob, other_probs_sum"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed 0 rows. Current accuracy: 0.0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed 1000 rows. Current accuracy: 0.016983016983016984\n",
"Processed 2000 rows. Current accuracy: 0.026486756621689155\n",
"Processed 3000 rows. Current accuracy: 0.02599133622125958\n",
"Processed 4000 rows. Current accuracy: 0.024493876530867282\n",
"Processed 5000 rows. Current accuracy: 0.02559488102379524\n",
"Processed 6000 rows. Current accuracy: 0.024329278453591067\n",
"Processed 7000 rows. Current accuracy: 0.023853735180688472\n",
"Processed 8000 rows. Current accuracy: 0.023122109736282963\n",
"Processed 9000 rows. Current accuracy: 0.022886345961559827\n",
"Processed 10000 rows. Current accuracy: 0.0225977402259774\n",
"The accuracy of the model is 0.02243559273695218\n"
]
}
],
"source": [
"# Predict the gap word for every dev-0 row, write one line per row and report accuracy.\n",
"# Each split gets its own out.tsv (rather than a shared ./out.tsv) so that predicting\n",
"# test-A cannot overwrite the dev-0 predictions.\n",
"DEV_OUT_PATH = './challenging-america-word-gap-prediction/dev-0/out.tsv'\n",
"# Written when a row cannot be predicted, so line i of out.tsv still matches row i of in.tsv.\n",
"FALLBACK_LINE = 'placeholder:0.0010 :0.0010\\n'\n",
"\n",
"if len(in_df) != len(expected_df):\n",
"    print(f\"Warning: {len(in_df)} input rows vs {len(expected_df)} expected rows; \"\n",
"          \"only the overlapping rows are evaluated\")\n",
"\n",
"correct_predictions = 0\n",
"evaluated_rows = 0\n",
"\n",
"with open(DEV_OUT_PATH, 'w', encoding='utf-8') as f:\n",
"    for i, ((_, input_row), expected_word) in enumerate(zip(in_df.iterrows(), expected_df[0])):\n",
"        evaluated_rows += 1\n",
"        try:\n",
"            predicted_word, prob, other_probs_sum = predict_missing_word(\n",
"                model, input_row['text_1'], input_row['text_2'])\n",
"\n",
"            # If any of the values are None, use placeholder values\n",
"            if predicted_word is None:\n",
"                predicted_word = 'placeholder'\n",
"            if prob is None:\n",
"                prob = 0.001\n",
"            if other_probs_sum is None:\n",
"                other_probs_sum = 0.001\n",
"\n",
"            line = f\"{predicted_word}:{prob:.4f} :{other_probs_sum:.4f}\\n\"\n",
"        except Exception as e:\n",
"            print(f\"Error processing row {i}: {e}\")\n",
"            predicted_word = None\n",
"            line = FALLBACK_LINE\n",
"\n",
"        # Always exactly one line per row, even after an error; skipping a row would\n",
"        # shift every later prediction against the gold file.\n",
"        f.write(line)\n",
"\n",
"        if predicted_word == expected_word:\n",
"            correct_predictions += 1\n",
"\n",
"        # Log progress every 1000 iterations\n",
"        if i % 1000 == 0:\n",
"            print(f\"Processed {i} rows. Current accuracy: {correct_predictions / (i + 1)}\")\n",
"\n",
"# Accuracy over the rows actually compared with a gold word\n",
"accuracy = correct_predictions / evaluated_rows if evaluated_rows else 0.0\n",
"\n",
"print(f\"The accuracy of the model is {accuracy}\")"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7414\n"
]
}
],
"source": [
"import pandas as pd\n",
"import csv\n",
"\n",
"TEST_IN_PATH = './challenging-america-word-gap-prediction/test-A/in.tsv'\n",
"N_FIELDS = 8  # tab-separated fields per in.tsv line; the last two are the left/right contexts\n",
"\n",
"# Placeholder row for malformed lines, so row i of in_df still corresponds to line i of in.tsv\n",
"placeholder_line = ['placeholder'] * N_FIELDS\n",
"\n",
"# newline='\\n' splits records only on line feeds: in the default universal-newline mode\n",
"# a stray '\\r' inside a text field would break one record into two.\n",
"with open(TEST_IN_PATH, 'r', encoding='utf-8', newline='\\n') as f:\n",
"    # Remove only the line terminator; str.strip() would also eat a trailing tab and\n",
"    # turn a record whose last field is empty into a 7-field \"malformed\" line.\n",
"    rows = [line.rstrip('\\r\\n').split('\\t') for line in f]\n",
"\n",
"bad_rows = sum(len(fields) != N_FIELDS for fields in rows)\n",
"if bad_rows:\n",
"    print(f\"Replaced {bad_rows} malformed line(s) with placeholders\")\n",
"lines = [fields if len(fields) == N_FIELDS else placeholder_line for fields in rows]\n",
"\n",
"# Convert the list of lines into a DataFrame\n",
"in_df = pd.DataFrame(lines)\n",
"\n",
"# Print the number of rows in the DataFrame\n",
"print(in_df.shape[0])\n",
"\n",
"# Keep only the two context columns, named for clarity\n",
"in_df = in_df.drop(columns=list(range(N_FIELDS - 2)))\n",
"in_df.columns = ['text_1', 'text_2']\n",
"\n",
"# Reassign instead of column-level fillna(inplace=True): that is chained assignment,\n",
"# which pandas warns about and which has no effect under copy-on-write.\n",
"in_df = in_df.fillna({'text_1': 'placeholder', 'text_2': 'placeholder'})"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Replacing '\\n' in 'text_1': 100%|██████████| 7414/7414 [00:00<00:00, 216746.15it/s]\n",
"Replacing '\\t' in 'text_1': 100%|██████████| 7414/7414 [00:00<00:00, 545247.75it/s]\n",
"Replacing '\\n' in 'text_2': 100%|██████████| 7414/7414 [00:00<00:00, 223832.27it/s]\n",
"Replacing '\\t' in 'text_2': 100%|██████████| 7414/7414 [00:00<00:00, 569784.70it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Replace the literal two-character sequences backslash+'n' and backslash+'t'\n",
"# (presumably escaped line breaks and tabs in the TSV text) with spaces.\n",
"def replace_newline(text):\n",
"    \"\"\"Return text with each literal backslash-n replaced by a space; non-strings pass through.\"\"\"\n",
"    return text.replace('\\\\n', ' ') if isinstance(text, str) else text\n",
"\n",
"\n",
"def replace_tabulation(text):\n",
"    \"\"\"Return text with each literal backslash-t replaced by a space; non-strings pass through.\"\"\"\n",
"    return text.replace('\\\\t', ' ') if isinstance(text, str) else text\n",
"\n",
"\n",
"# Order: newlines then tabs in text_1, then the same for text_2.\n",
"for column in ('text_1', 'text_2'):\n",
"    for label, cleaner in (('\\\\n', replace_newline), ('\\\\t', replace_tabulation)):\n",
"        tqdm.pandas(desc=f\"Replacing '{label}' in '{column}'\")\n",
"        in_df[column] = in_df[column].progress_apply(cleaner)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"Processing 'text_1': 100%|██████████| 7414/7414 [00:20<00:00, 365.75it/s]\n",
"Processing 'text_2': 100%|██████████| 7414/7414 [00:15<00:00, 478.59it/s]\n"
]
}
],
"source": [
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"from tqdm import tqdm\n",
"\n",
"# NLTK resources needed below: Punkt tokenizer models and the stopword lists\n",
"for resource in ('punkt', 'stopwords'):\n",
"    nltk.download(resource)\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"stemmer = PorterStemmer()\n",
"# Translation table that deletes every character in string.punctuation\n",
"punctuation_table = str.maketrans('', '', string.punctuation)\n",
"\n",
"\n",
"def preprocess_text(text):\n",
"    \"\"\"Normalise a context string.\n",
"\n",
"    Lowercases, deletes ASCII punctuation, tokenizes with NLTK, drops English\n",
"    stopwords, Porter-stems the remaining tokens and joins them with spaces.\n",
"    \"\"\"\n",
"    tokens = word_tokenize(text.lower().translate(punctuation_table))\n",
"    return ' '.join(stemmer.stem(token) for token in tokens if token not in stop_words)\n",
"\n",
"\n",
"for column in ('text_1', 'text_2'):\n",
"    tqdm.pandas(desc=f\"Processing '{column}'\")\n",
"    in_df[column] = in_df[column].progress_apply(preprocess_text)"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing 'text_1': 100%|██████████| 7414/7414 [00:00<00:00, 21928.33it/s]\n",
"Processing 'text_2': 100%|██████████| 7414/7414 [00:00<00:00, 20930.30it/s]\n"
]
}
],
"source": [
"import re\n",
"\n",
"# Precompiled patterns: digit runs are deleted, runs of non-word characters collapse to one space\n",
"digits_pattern = re.compile(r'\\d+')\n",
"non_word_pattern = re.compile(r'\\W+')\n",
"\n",
"\n",
"def handle_numbers_and_special_chars(text):\n",
"    \"\"\"Delete all digit runs from text, then replace each run of non-word characters with a single space.\"\"\"\n",
"    without_digits = digits_pattern.sub('', text)\n",
"    return non_word_pattern.sub(' ', without_digits)\n",
"\n",
"\n",
"for column in ('text_1', 'text_2'):\n",
"    tqdm.pandas(desc=f\"Processing '{column}'\")\n",
"    in_df[column] = in_df[column].progress_apply(handle_numbers_and_special_chars)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"# Persist the preprocessed test-A contexts (tab-separated, without the index column)\n",
"PREPROCESSED_TEST_PATH = 'preprocessed_test_text1_text2.tsv'\n",
"in_df.to_csv(PREPROCESSED_TEST_PATH, index=False, sep='\\t')"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed 0 rows. Current accuracy: 0.0\n",
"Processed 1000 rows. Current accuracy: 0.000999000999000999\n",
"Processed 2000 rows. Current accuracy: 0.0004997501249375312\n",
"Processed 3000 rows. Current accuracy: 0.0003332222592469177\n",
"Processed 4000 rows. Current accuracy: 0.0004998750312421895\n",
"Processed 5000 rows. Current accuracy: 0.0005998800239952009\n",
"Processed 6000 rows. Current accuracy: 0.0006665555740709882\n",
"Processed 7000 rows. Current accuracy: 0.0005713469504356521\n",
"The accuracy of the model is 0.0005395198273536552\n"
]
}
],
"source": [
"# Predict the gap word for every test-A row and write one line per row to test-A/out.tsv.\n",
"# No accuracy is computed: this notebook loads no test-A labels, and expected_df still\n",
"# holds the labels used for the dev evaluation, so comparing against it measures nothing.\n",
"# Each split gets its own out.tsv (rather than a shared ./out.tsv) so that predicting\n",
"# test-A cannot overwrite the dev-0 predictions.\n",
"TEST_OUT_PATH = './challenging-america-word-gap-prediction/test-A/out.tsv'\n",
"# Written when a row cannot be predicted, so line i of out.tsv still matches row i of in.tsv.\n",
"FALLBACK_LINE = 'placeholder:0.0010 :0.0010\\n'\n",
"\n",
"written_rows = 0\n",
"\n",
"with open(TEST_OUT_PATH, 'w', encoding='utf-8') as f:\n",
"    # Iterate over every input row (zipping with expected_df would stop at the shorter of the two)\n",
"    for i, (_, input_row) in enumerate(in_df.iterrows()):\n",
"        try:\n",
"            predicted_word, prob, other_probs_sum = predict_missing_word(\n",
"                model, input_row['text_1'], input_row['text_2'])\n",
"\n",
"            # If any of the values are None, use placeholder values\n",
"            if predicted_word is None:\n",
"                predicted_word = 'placeholder'\n",
"            if prob is None:\n",
"                prob = 0.001\n",
"            if other_probs_sum is None:\n",
"                other_probs_sum = 0.001\n",
"\n",
"            line = f\"{predicted_word}:{prob:.4f} :{other_probs_sum:.4f}\\n\"\n",
"        except Exception as e:\n",
"            print(f\"Error processing row {i}: {e}\")\n",
"            line = FALLBACK_LINE\n",
"\n",
"        # Always exactly one line per row, even after an error\n",
"        f.write(line)\n",
"        written_rows += 1\n",
"\n",
"        # Log progress every 1000 iterations\n",
"        if i % 1000 == 0:\n",
"            print(f\"Processed {i} rows.\")\n",
"\n",
"print(f\"Wrote {written_rows} predictions to {TEST_OUT_PATH}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cloudspace",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}