{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import lzma\n",
"\n",
"folders = [\"./challenging-america-word-gap-prediction/dev-0\",\n",
"           \"./challenging-america-word-gap-prediction/test-A\",\n",
"           \"./challenging-america-word-gap-prediction/train\"]\n",
"\n",
"for folder in folders:\n",
"    for file in os.listdir(folder):\n",
"        if file.endswith(\".tsv.xz\"):\n",
"            file_path = os.path.join(folder, file)\n",
"            output_path = os.path.splitext(file_path)[0]  # Remove the .xz extension\n",
"            with lzma.open(file_path, \"rb\") as compressed_file:\n",
"                with open(output_path, \"wb\") as output_file:\n",
"                    output_file.write(compressed_file.read())"
]
},
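{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sketch (not run here): the loop above reads each decompressed file fully into memory with `read()`. For larger `.xz` archives, a streaming copy via the standard-library `shutil.copyfileobj` keeps memory use constant. It assumes the same `folders` list and paths as the cell above.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import lzma\n",
"import shutil\n",
"\n",
"# Same folders as above; stream the decompressed bytes to disk instead of read()-ing everything at once\n",
"for folder in folders:\n",
"    for file in os.listdir(folder):\n",
"        if file.endswith(\".tsv.xz\"):\n",
"            file_path = os.path.join(folder, file)\n",
"            output_path = os.path.splitext(file_path)[0]\n",
"            with lzma.open(file_path, \"rb\") as compressed_file, open(output_path, \"wb\") as output_file:\n",
"                shutil.copyfileobj(compressed_file, output_file)"
]
},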
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import nltk\n",
"nltk.download('punkt')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Skipping line 30538: expected 8 fields, saw 9\n",
"Skipping line 37185: expected 8 fields, saw 9\n",
"Skipping line 40930: expected 8 fields, saw 9\n",
"Skipping line 44499: expected 8 fields, saw 9\n",
"Skipping line 46409: expected 8 fields, saw 9\n",
"Skipping line 52642: expected 8 fields, saw 9\n",
"Skipping line 53046: expected 8 fields, saw 9\n",
"\n",
"Skipping line 69658: expected 8 fields, saw 9\n",
"Skipping line 71325: expected 8 fields, saw 9\n",
"Skipping line 72955: expected 8 fields, saw 9\n",
"Skipping line 80528: expected 8 fields, saw 9\n",
"Skipping line 96979: expected 8 fields, saw 9\n",
"Skipping line 121731: expected 8 fields, saw 9\n",
"Skipping line 126630: expected 8 fields, saw 9\n",
"\n",
"Skipping line 132289: expected 8 fields, saw 9\n",
"Skipping line 140251: expected 8 fields, saw 9\n",
"Skipping line 142374: expected 8 fields, saw 9\n",
"Skipping line 149592: expected 8 fields, saw 9\n",
"Skipping line 150041: expected 8 fields, saw 9\n",
"Skipping line 151624: expected 8 fields, saw 9\n",
"Skipping line 158163: expected 8 fields, saw 9\n",
"Skipping line 159665: expected 8 fields, saw 9\n",
"Skipping line 171749: expected 8 fields, saw 9\n",
"Skipping line 174845: expected 8 fields, saw 9\n",
"Skipping line 177638: expected 8 fields, saw 9\n",
"Skipping line 178778: expected 8 fields, saw 9\n",
"Skipping line 188823: expected 8 fields, saw 9\n",
"Skipping line 191398: expected 8 fields, saw 9\n",
"\n",
"Skipping line 196865: expected 8 fields, saw 9\n",
"Skipping line 203572: expected 8 fields, saw 9\n",
"Skipping line 207802: expected 8 fields, saw 9\n",
"Skipping line 214509: expected 8 fields, saw 9\n",
"Skipping line 214633: expected 8 fields, saw 9\n",
"Skipping line 217906: expected 8 fields, saw 9\n",
"Skipping line 220906: expected 8 fields, saw 9\n",
"Skipping line 238000: expected 8 fields, saw 9\n",
"Skipping line 257754: expected 8 fields, saw 9\n",
"Skipping line 259366: expected 8 fields, saw 9\n",
"Skipping line 261826: expected 8 fields, saw 9\n",
"\n",
"Skipping line 272727: expected 8 fields, saw 9\n",
"Skipping line 280527: expected 8 fields, saw 9\n",
"Skipping line 282454: expected 8 fields, saw 9\n",
"Skipping line 285910: expected 8 fields, saw 9\n",
"Skipping line 289865: expected 8 fields, saw 9\n",
"Skipping line 292892: expected 8 fields, saw 9\n",
"Skipping line 292984: expected 8 fields, saw 9\n",
"Skipping line 293058: expected 8 fields, saw 9\n",
"Skipping line 302716: expected 8 fields, saw 9\n",
"Skipping line 303370: expected 8 fields, saw 9\n",
"Skipping line 314194: expected 8 fields, saw 9\n",
"Skipping line 321975: expected 8 fields, saw 9\n",
"Skipping line 324999: expected 8 fields, saw 9\n",
"\n",
"Skipping line 331978: expected 8 fields, saw 9\n",
"Skipping line 345426: expected 8 fields, saw 9\n",
"Skipping line 345951: expected 8 fields, saw 9\n",
"Skipping line 355430: expected 8 fields, saw 9\n",
"Skipping line 358744: expected 8 fields, saw 9\n",
"Skipping line 361491: expected 8 fields, saw 9\n",
"Skipping line 370443: expected 8 fields, saw 9\n",
"Skipping line 388057: expected 8 fields, saw 9\n",
"Skipping line 391061: expected 8 fields, saw 9\n",
"\n",
"Skipping line 395391: expected 8 fields, saw 9\n",
"Skipping line 404270: expected 8 fields, saw 9\n",
"Skipping line 407896: expected 8 fields, saw 9\n",
"Skipping line 409881: expected 8 fields, saw 9\n",
"Skipping line 421230: expected 8 fields, saw 9\n",
"Skipping line 425850: expected 8 fields, saw 9\n",
"Skipping line 427269: expected 8 fields, saw 9\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"in_df:\n",
" 0 1 2 \\\n",
"0 4e04702da929c78c52baf09c1851d3ff ST ChronAm \n",
"1 b374dadd940510271d9675d3e8caf9d8 DAILY ARIZONA SILVER BELT ChronAm \n",
"2 adb666c426bdc10fd949cb824da6c0d0 THE SAVANNAH MORNING NEWS ChronAm \n",
"3 bc2c9aa0b77d724311e3c2e12fc61c92 CHARLES CITY INTELLIGENCER ChronAm \n",
"4 0f612b991a39c712f0d745835b8b2f0d EVENING STAR ChronAm \n",
"\n",
" 3 4 5 \\\n",
"0 1919.604110 30.475470 -90.100911 \n",
"1 1909.097260 33.399478 -110.870950 \n",
"2 1900.913699 32.080926 -81.091177 \n",
"3 1864.974044 43.066361 -92.672411 \n",
"4 1878.478082 38.894955 -77.036646 \n",
"\n",
" 6 \\\n",
"0 came fiom the last place to this\\nplace, and t... \n",
"1 MB. BOOT'S POLITICAL OBEED\\nAttempt to imagine... \n",
"2 Thera were in 1771 only aeventy-nine\\n*ub*erlb... \n",
"3 whenever any prize property shall!*' condemn- ... \n",
"4 SA LKOFVALUABLE UNIMPBOV&D RE\\\\L\\nJSIATF. ON T... \n",
"\n",
" 7 \n",
"0 said\\nit's all squash. The best I could get\\ni... \n",
"1 \\ninto a proper perspective with those\\nminor ... \n",
"2 NaN \n",
"3 the ceitihcate of'\\noperate to prevent tfie ma... \n",
"4 \\nTerms of sale: One-tblrd, togethor with the ... \n",
"\n",
"expected_df:\n",
" 0\n",
"0 lie\n",
"1 himself\n",
"2 of\n",
"3 ably\n",
"4 j\n",
"\n",
"in_df info:\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 428517 entries, 0 to 428516\n",
"Data columns (total 8 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 0 428517 non-null object \n",
" 1 1 428517 non-null object \n",
" 2 2 428517 non-null object \n",
" 3 3 428517 non-null float64\n",
" 4 4 428517 non-null float64\n",
" 5 5 428517 non-null float64\n",
" 6 6 428517 non-null object \n",
" 7 7 425735 non-null object \n",
"dtypes: float64(3), object(5)\n",
"memory usage: 26.2+ MB\n",
"None\n",
"\n",
"expected_df info:\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 279623 entries, 0 to 279622\n",
"Data columns (total 1 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 0 279619 non-null object\n",
"dtypes: object(1)\n",
"memory usage: 2.1+ MB\n",
"None\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"# Load the data\n",
"in_df = pd.read_csv('./challenging-america-word-gap-prediction/train/in.tsv', sep='\\t', header=None, on_bad_lines='warn')\n",
"expected_df = pd.read_csv('./challenging-america-word-gap-prediction/train/expected.tsv', sep='\\t', header=None, on_bad_lines='warn')\n",
"\n",
"# Print out the first few rows of each DataFrame\n",
"print(\"in_df:\")\n",
"print(in_df.head())\n",
"print(\"\\nexpected_df:\")\n",
"print(expected_df.head())\n",
"\n",
"# Print out more information about each DataFrame\n",
"print(\"\\nin_df info:\")\n",
"print(in_df.info())\n",
"print(\"\\nexpected_df info:\")\n",
"print(expected_df.info())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5]  # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"in_df['text_2'].fillna('', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Replacing '\\n' in 'text_1': 100%|██████████| 428517/428517 [00:02<00:00, 166646.43it/s]\n",
"Replacing '\\t' in 'text_1': 100%|██████████| 428517/428517 [00:01<00:00, 422489.36it/s]\n",
"Replacing '\\n' in 'text_2': 100%|██████████| 428517/428517 [00:02<00:00, 149443.26it/s]\n",
"Replacing '\\t' in 'text_2': 100%|██████████| 428517/428517 [00:01<00:00, 417969.18it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Define functions that replace the literal '\\n' and '\\t' escape sequences stored in the text fields with spaces\n",
"def replace_newline(text):\n",
"    if isinstance(text, str):\n",
"        return text.replace('\\\\n', ' ')\n",
"    return text\n",
"\n",
"def replace_tabulation(text):\n",
"    if isinstance(text, str):\n",
"        return text.replace('\\\\t', ' ')\n",
"    return text\n",
"\n",
"# Apply the functions to 'text_1' and 'text_2' columns and show progress\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_newline)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_tabulation)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_newline)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_tabulation)"
]
},
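{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reference sketch (not run): the same literal escape-sequence replacements can be done with pandas' vectorized `str.replace(..., regex=False)`, avoiding a Python-level call per row, which is usually faster on 400k+ rows. The `progress_apply` version above is what actually produced the outputs.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Vectorized equivalent of the per-row replacements above (literal replacement, not regex)\n",
"for col in ['text_1', 'text_2']:\n",
"    in_df[col] = (in_df[col]\n",
"                  .str.replace('\\\\n', ' ', regex=False)\n",
"                  .str.replace('\\\\t', ' ', regex=False))"
]
},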
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" text_1 \\\n",
"0 came fiom the last place to this place, and th... \n",
"1 MB. BOOT'S POLITICAL OBEED Attempt to imagine ... \n",
"2 Thera were in 1771 only aeventy-nine *ub*erlbe... \n",
"3 whenever any prize property shall!*' condemn- ... \n",
"4 SA LKOFVALUABLE UNIMPBOV&D RE\\\\L JSIATF. ON TH... \n",
"\n",
" text_2 \n",
"0 said it's all squash. The best I could get in ... \n",
"1 into a proper perspective with those minor se... \n",
"2 \n",
"3 the ceitihcate of' operate to prevent tfie mak... \n",
"4 Terms of sale: One-tblrd, togethor with the e... \n"
]
}
],
"source": [
"print(in_df.head())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"Processing 'text_1': 100%|██████████| 428517/428517 [16:38<00:00, 429.15it/s]\n",
"Processing 'text_2': 100%|██████████| 428517/428517 [14:55<00:00, 478.46it/s]\n"
]
}
],
"source": [
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"# If not already done, download the NLTK English stopwords and the Punkt Tokenizer Models\n",
"nltk.download('punkt')\n",
"nltk.download('stopwords')\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"stemmer = PorterStemmer()\n",
"\n",
"def preprocess_text(text):\n",
"    # Lowercase the text\n",
"    text = text.lower()\n",
"\n",
"    # Remove punctuation\n",
"    text = text.translate(str.maketrans('', '', string.punctuation))\n",
"\n",
"    # Tokenize the text\n",
"    words = word_tokenize(text)\n",
"\n",
"    # Remove stopwords and stem the words\n",
"    words = [stemmer.stem(word) for word in words if word not in stop_words]\n",
"\n",
"    # Join the words back into a single string\n",
"    text = ' '.join(words)\n",
"\n",
"    return text\n",
"\n",
"# Apply the preprocessing to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(preprocess_text)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(preprocess_text)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" text_1 \\\n",
"0 came fiom last place place place first road ev... \n",
"1 mb boot polit obe attempt imagin piatt make ad... \n",
"2 thera 1771 aeventynin uberlb lo lloyd nearli 1... \n",
"3 whenev prize properti shall condemn appeal dis... \n",
"4 sa lkofvalu unimpbovd rel jsiatf north bideof ... \n",
"\n",
" text_2 \n",
"0 said squash best could get hotel soup sandwich... \n",
"1 proper perspect minor senatori duti tho fill i... \n",
"2 \n",
"3 ceitihc oper prevent tfie make execut district... \n",
"4 term sale onetblrd togethor ex¬ pens sale cash... \n"
]
}
],
"source": [
"print(in_df.head())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Save 'in_df' DataFrame to a .tsv file\n",
"in_df.to_csv('preprocessed_text1_text2.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing 'text_1': 100%|██████████| 428517/428517 [00:21<00:00, 19801.75it/s]\n",
"Processing 'text_2': 100%|██████████| 428517/428517 [00:25<00:00, 16938.99it/s]\n"
]
}
],
"source": [
"import re\n",
"\n",
"def handle_numbers_and_special_chars(text):\n",
"    # Remove numbers\n",
"    text = re.sub(r'\\d+', '', text)\n",
"\n",
"    # Remove special characters\n",
"    text = re.sub(r'\\W+', ' ', text)\n",
"\n",
"    return text\n",
"\n",
"# Apply the function to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(handle_numbers_and_special_chars)\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(handle_numbers_and_special_chars)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from spellchecker import SpellChecker\n",
"\n",
"spell = SpellChecker()\n",
"\n",
"def correct_spelling(text):\n",
"    # Tokenize the text\n",
"    words = text.split()\n",
"\n",
"    # Correct spelling\n",
"    corrected_words = [spell.correction(word) if spell.correction(word) is not None else '' for word in words]\n",
"\n",
"    # Join the words back into a single string\n",
"    text = ' '.join(corrected_words)\n",
"\n",
"    return text\n",
"\n",
"# Apply the spelling correction to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Spelling Correction 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(correct_spelling)\n",
"\n",
"tqdm.pandas(desc=\"Spelling Correction 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(correct_spelling)"
]
},
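{
"cell_type": "markdown",
"metadata": {},
"source": [
"The spell-correction cell above was left unexecuted, presumably because `spell.correction` is called twice per token over tens of millions of tokens. If correction is ever wanted, a cached variant along these lines (a sketch using `functools.lru_cache`; falling back to the original word rather than an empty string is an assumption) does one lookup per distinct word.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from functools import lru_cache\n",
"\n",
"from spellchecker import SpellChecker\n",
"\n",
"spell = SpellChecker()\n",
"\n",
"@lru_cache(maxsize=None)\n",
"def cached_correction(word):\n",
"    # spell.correction() may return None for unknown tokens; keep the original word in that case\n",
"    # (the unexecuted cell above used an empty string instead)\n",
"    corrected = spell.correction(word)\n",
"    return corrected if corrected is not None else word\n",
"\n",
"def correct_spelling(text):\n",
"    return ' '.join(cached_correction(word) for word in text.split())"
]
},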
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to concatenate 'text_1' and 'text_2'\n",
"def concatenate_texts(row):\n",
"    return str(row['text_1']) + ' <MASK> ' + str(row['text_2'])\n",
"\n",
"# Apply the function to each row and show progress\n",
"tqdm.pandas(desc=\"Concatenating 'text_1' and 'text_2'\")\n",
"in_df['text'] = in_df.progress_apply(concatenate_texts, axis=1)\n",
"\n",
"# Now you can drop 'text_1' and 'text_2' columns if you want\n",
"in_df.drop(['text_1', 'text_2'], axis=1, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 428517/428517 [00:05<00:00, 83867.84it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"tqdm.pandas()\n",
"\n",
"# Load the preprocessed data\n",
"in_df = pd.read_csv('preprocessed_text1_text2.tsv', sep='\\t')\n",
"\n",
"# Load the expected words\n",
"expected_df = pd.read_csv('./challenging-america-word-gap-prediction/train/expected.tsv', sep='\\t', header=None, on_bad_lines='warn')\n",
"expected_df.columns = ['expected_word']\n",
"\n",
"# Add the expected words to in_df\n",
"# Note: in_df has 428517 rows while expected_df has 279623 (see the info() output above), so this\n",
"# positional concat leaves NaN expected words for the extra rows and alignment is only approximate\n",
"in_df = pd.concat([in_df, expected_df], axis=1)\n",
"\n",
"# Define a function to concatenate 'text_1' and 'expected_word' and 'text_2'\n",
"def concatenate_texts(row):\n",
"    return str(row['text_1']) + ' ' + str(row['expected_word']) + ' ' + str(row['text_2'])\n",
"\n",
"# Apply the function to each row and show progress\n",
"in_df['unmasked_text'] = in_df.progress_apply(concatenate_texts, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"in_df['unmasked_text'].to_csv('preprocessed_text_join_unmask.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Save the 'unmasked_text' column to a plain-text file (one line per document) for KenLM training\n",
"in_df['unmasked_text'].to_csv('training_data.txt', index=False, header=False)"
]
},
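{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity-check sketch: `lmplz` treats each line of `training_data.txt` as one sentence, and `to_csv` can quote fields containing special characters, so it is worth confirming that the number of lines written matches the number of rows before training.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the number of lines written with the number of rows in the DataFrame\n",
"with open('training_data.txt', 'r') as f:\n",
"    n_lines = sum(1 for _ in f)\n",
"\n",
"print(n_lines, len(in_df), n_lines == len(in_df))"
]
},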
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== 1/5 Counting and sorting n-grams ===\n",
"Reading /teamspace/studios/this_studio/training_data.txt\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************************************************************\n",
"Unigram tokens 76946114 types 3214378\n",
"=== 2/5 Calculating and sorting adjusted counts ===\n",
"Chain sizes: 1:38572536 2:1266848384 3:2375340800 4:3800545280 5:5542461952\n",
"Statistics:\n",
"1 3214378 D1=0.828129 D2=1.02195 D3+=1.19945\n",
"2 28265496 D1=0.815336 D2=1.00509 D3+=1.20745\n",
"3 64425776 D1=0.933716 D2=1.31704 D3+=1.47374\n",
"4 71021141 D1=0.980912 D2=1.49465 D3+=1.62427\n",
"5 72056555 D1=0.968608 D2=1.34446 D3+=1.42151\n",
"Memory estimate for binary LM:\n",
"type MB\n",
"probing 5063 assuming -p 1.5\n",
"probing 6012 assuming -r models -p 1.5\n",
"trie 2711 without quantization\n",
"trie 1596 assuming -q 8 -b 8 quantization \n",
"trie 2329 assuming -a 22 array pointer compression\n",
"trie 1214 assuming -a 22 -q 8 -b 8 array pointer compression and quantization\n",
"=== 3/5 Calculating and sorting initial probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:1288515520 4:1704507384 5:2017583540\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"####################################################################################################\n",
"=== 4/5 Calculating and writing order-interpolated probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:1288515520 4:1704507384 5:2017583540\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"####################################################################################################\n",
"=== 5/5 Writing ARPA model ===\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"****************************************************************************************************\n",
"Name:lmplz\tVmPeak:12890788 kB\tVmRSS:4496 kB\tRSSMax:5436904 kB\tuser:274.028\tsys:89.3375\tCPU:363.366\treal:343.888\n"
]
}
],
"source": [
"!lmplz -o 5 --discount_fallback < training_data.txt > language_model.arpa"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading language_model.arpa\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************************************************************\n",
"SUCCESS\n"
]
}
],
"source": [
"!build_binary language_model.arpa language_model.binary"
]
},
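{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick check on the compiled model (a sketch, not run here; the sample fragment is made up): the `kenlm` Python bindings can load the binary file, `score` returns a total log10 probability and `perplexity` a per-word perplexity.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import kenlm\n",
"\n",
"# Load the binary model written by build_binary and score an arbitrary preprocessed fragment\n",
"check_model = kenlm.Model('language_model.binary')\n",
"print(check_model.order)  # n-gram order the model was built with\n",
"\n",
"sample = \"said squash best could get hotel soup sandwich\"\n",
"print(check_model.score(sample, bos=True, eos=True))  # total log10 probability\n",
"print(check_model.perplexity(sample))  # per-word perplexity"
]
},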
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== 1/5 Counting and sorting n-grams ===\n",
"Reading /teamspace/studios/this_studio/training_data.txt\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************************************************************\n",
"Unigram tokens 76946114 types 3214378\n",
"=== 2/5 Calculating and sorting adjusted counts ===\n",
"Chain sizes: 1:38572536 2:799089024 3:1498291968 4:2397267200 5:3496014592 6:4794534400\n",
"Statistics:\n",
"1 3214378 D1=0.828129 D2=1.02195 D3+=1.19945\n",
"2 28265496 D1=0.815336 D2=1.00509 D3+=1.20745\n",
"3 4497710/64425776 D1=0.933716 D2=1.31704 D3+=1.47374\n",
"4 2178876/71021141 D1=0.980912 D2=1.49465 D3+=1.62427\n",
"5 1699108/72056555 D1=0.988779 D2=1.56326 D3+=1.72651\n",
"6 1460655/72365687 D1=0.972772 D2=1.34844 D3+=1.44176\n",
"Memory estimate for binary LM:\n",
"type MB\n",
"probing 943 assuming -p 1.5\n",
"probing 1165 assuming -r models -p 1.5\n",
"trie 553 without quantization\n",
"trie 343 assuming -q 8 -b 8 quantization \n",
"trie 474 assuming -a 22 array pointer compression\n",
"trie 265 assuming -a 22 -q 8 -b 8 array pointer compression and quantization\n",
"=== 3/5 Calculating and sorting initial probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:89954200 4:52293024 5:47575024 6:46740960\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"**##################################################################################################\n",
"=== 4/5 Calculating and writing order-interpolated probabilities ===\n",
"Chain sizes: 1:38572536 2:452247936 3:89954200 4:52293024 5:47575024 6:46740960\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"####################################################################################################\n",
"=== 5/5 Writing ARPA model ===\n",
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
"****************************************************************************************************\n",
"Name:lmplz\tVmPeak:12907180 kB\tVmRSS:5872 kB\tRSSMax:7679904 kB\tuser:168.818\tsys:56.6427\tCPU:225.461\treal:147.307\n"
]
}
],
"source": [
"!lmplz -o 6 --discount_fallback -S 80% --prune 0 0 1 1 1 1 < training_data.txt > language_model.arpa"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10519 ./challenging-america-word-gap-prediction/dev-0/in.tsv\n"
]
}
],
"source": [
"!wc -l ./challenging-america-word-gap-prediction/dev-0/in.tsv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import csv\n",
"\n",
"# Load the data\n",
"try:\n",
"    in_df = pd.read_csv('./challenging-america-word-gap-prediction/dev-0/in.tsv', sep='\\t', header=None, on_bad_lines='error')\n",
"except Exception as e:\n",
"    print(e)\n",
"expected_df = pd.read_csv('./challenging-america-word-gap-prediction/dev-0/expected.tsv', sep='\\t', header=None, on_bad_lines='warn', quoting=csv.QUOTE_NONE)\n",
"\n",
"print(in_df.shape[0])\n",
"\n",
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5]  # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']\n",
"\n",
"in_df['text_1'].fillna('placeholder', inplace=True)\n",
"in_df['text_2'].fillna('placeholder', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10519\n"
]
}
],
"source": [
"import pandas as pd\n",
"import csv\n",
"\n",
"# Placeholder line\n",
"placeholder_line = ['placeholder'] * 8  # Adjust the number of fields as needed\n",
"\n",
"# Read the file line by line\n",
"with open('./challenging-america-word-gap-prediction/dev-0/in.tsv', 'r') as f:\n",
"    lines = f.readlines()\n",
"\n",
"# Split each line into fields and replace problematic lines with the placeholder line\n",
"lines = [line.strip().split('\\t') if len(line.strip().split('\\t')) == 8 else placeholder_line for line in lines]\n",
"\n",
"# Convert the list of lines into a DataFrame\n",
"in_df = pd.DataFrame(lines)\n",
"\n",
"# Print the number of rows in the DataFrame\n",
"print(in_df.shape[0])\n",
"\n",
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5]  # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']\n",
"\n",
"in_df['text_1'].fillna('placeholder', inplace=True)\n",
"in_df['text_2'].fillna('placeholder', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10519\n"
]
}
],
"source": [
"print(in_df.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Replacing '\\n' in 'text_1': 100%|██████████| 10519/10519 [00:00<00:00, 223066.53it/s]\n",
"Replacing '\\t' in 'text_1': 100%|██████████| 10519/10519 [00:00<00:00, 535825.65it/s]\n",
"Replacing '\\n' in 'text_2': 100%|██████████| 10519/10519 [00:00<00:00, 216324.84it/s]\n",
"Replacing '\\t' in 'text_2': 100%|██████████| 10519/10519 [00:00<00:00, 534630.94it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Define functions that replace the literal '\\n' and '\\t' escape sequences stored in the text fields with spaces\n",
"def replace_newline(text):\n",
"    if isinstance(text, str):\n",
"        return text.replace('\\\\n', ' ')\n",
"    return text\n",
"\n",
"def replace_tabulation(text):\n",
"    if isinstance(text, str):\n",
"        return text.replace('\\\\t', ' ')\n",
"    return text\n",
"\n",
"# Apply the functions to 'text_1' and 'text_2' columns and show progress\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_newline)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_tabulation)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_newline)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_tabulation)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"Processing 'text_1': 100%|██████████| 10519/10519 [00:23<00:00, 440.48it/s]\n",
"Processing 'text_2': 100%|██████████| 10519/10519 [00:29<00:00, 358.63it/s]\n"
]
}
],
"source": [
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"# If not already done, download the NLTK English stopwords and the Punkt Tokenizer Models\n",
"nltk.download('punkt')\n",
"nltk.download('stopwords')\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"stemmer = PorterStemmer()\n",
"\n",
"def preprocess_text(text):\n",
"    # Lowercase the text\n",
"    text = text.lower()\n",
"\n",
"    # Remove punctuation\n",
"    text = text.translate(str.maketrans('', '', string.punctuation))\n",
"\n",
"    # Tokenize the text\n",
"    words = word_tokenize(text)\n",
"\n",
"    # Remove stopwords and stem the words\n",
"    words = [stemmer.stem(word) for word in words if word not in stop_words]\n",
"\n",
"    # Join the words back into a single string\n",
"    text = ' '.join(words)\n",
"\n",
"    return text\n",
"\n",
"# Apply the preprocessing to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(preprocess_text)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(preprocess_text)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing 'text_1': 100%|██████████| 10519/10519 [00:00<00:00, 21823.01it/s]\n",
"Processing 'text_2': 100%|██████████| 10519/10519 [00:00<00:00, 21693.77it/s]\n"
]
}
],
"source": [
"import re\n",
"\n",
"def handle_numbers_and_special_chars(text):\n",
"    # Remove numbers\n",
"    text = re.sub(r'\\d+', '', text)\n",
"\n",
"    # Remove special characters\n",
"    text = re.sub(r'\\W+', ' ', text)\n",
"\n",
"    return text\n",
"\n",
"# Apply the function to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(handle_numbers_and_special_chars)\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(handle_numbers_and_special_chars)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"in_df.to_csv('preprocessed_dev_text1_text2.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"import kenlm\n",
"\n",
"# Load the language model\n",
"model = kenlm.Model('language_model.binary')"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax(x):\n",
"    e_x = np.exp(x - np.max(x))\n",
"    return e_x / e_x.sum()\n",
"\n",
"def predict_missing_word(model, context_1, context_2):\n",
"    # Define the vocabulary\n",
"    vocabulary = set(' '.join([context_1, context_2]).split())\n",
"\n",
"    # Initialize a dictionary to store the words and their scores\n",
"    word_scores = {}\n",
"\n",
"    # Iterate over the vocabulary\n",
"    for word in vocabulary:\n",
"        try:\n",
"            # Generate the sentence\n",
"            sentence = f\"{context_1} {word} {context_2}\"\n",
"\n",
"            # Score the sentence and store it in the dictionary\n",
"            word_scores[word] = model.score(sentence)\n",
"\n",
"        except Exception as e:\n",
"            print(f\"Error processing word '{word}': {e}\")\n",
"            continue\n",
"\n",
"    # If no word was found, return None for all values\n",
"    if not word_scores:\n",
"        return None, None, None\n",
"\n",
"    # Convert the log10 scores to probabilities with softmax; the 0.001 floor keeps every value\n",
"    # positive, so the reported probabilities can sum to slightly more than 1\n",
"    word_probs = {word: max(0.001, prob) for word, prob in zip(word_scores.keys(), softmax(list(word_scores.values())))}\n",
"\n",
"    # Find the word with the highest probability\n",
"    best_word, best_prob = max(word_probs.items(), key=lambda x: x[1])\n",
"\n",
"    # Calculate the sum of probabilities for the other words\n",
"    other_probs_sum = sum(prob for word, prob in word_probs.items() if word != best_word)\n",
"\n",
"    return best_word, best_prob, other_probs_sum"
]
},
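{
"cell_type": "markdown",
"metadata": {},
"source": [
"Minimal usage example for `predict_missing_word` (the two context fragments below are made up, not taken from the data). Candidates are drawn only from words already present in the two contexts, which strongly limits what the function can ever predict.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hand-made (hypothetical) preprocessed contexts; the real rows come from in_df\n",
"left_context = \"presid unit state deliv\"\n",
"right_context = \"congress yesterday\"\n",
"\n",
"word, prob, rest = predict_missing_word(model, left_context, right_context)\n",
"print(word, round(prob, 4), round(rest, 4))"
]
},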
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed 0 rows. Current accuracy: 0.0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed 1000 rows. Current accuracy: 0.016983016983016984\n",
"Processed 2000 rows. Current accuracy: 0.026486756621689155\n",
"Processed 3000 rows. Current accuracy: 0.02599133622125958\n",
"Processed 4000 rows. Current accuracy: 0.024493876530867282\n",
"Processed 5000 rows. Current accuracy: 0.02559488102379524\n",
"Processed 6000 rows. Current accuracy: 0.024329278453591067\n",
"Processed 7000 rows. Current accuracy: 0.023853735180688472\n",
"Processed 8000 rows. Current accuracy: 0.023122109736282963\n",
"Processed 9000 rows. Current accuracy: 0.022886345961559827\n",
"Processed 10000 rows. Current accuracy: 0.0225977402259774\n",
"The accuracy of the model is 0.02243559273695218\n"
]
}
],
"source": [
"# Initialize a counter for the correct predictions\n",
"correct_predictions = 0\n",
"\n",
"# Open the output file\n",
"with open('out.tsv', 'w') as f:\n",
"    # Iterate over the rows of the input DataFrame and the expected DataFrame\n",
"    for i, ((_, input_row), expected_word) in enumerate(zip(in_df.iterrows(), expected_df[0])):\n",
"        try:\n",
"            # Get the context\n",
"            context_1 = input_row['text_1']\n",
"            context_2 = input_row['text_2']\n",
"\n",
"            # Predict the missing word and get the probabilities\n",
"            predicted_word, prob, other_probs_sum = predict_missing_word(model, context_1, context_2)\n",
"\n",
"            # If any of the values are None, use placeholder values\n",
"            if predicted_word is None:\n",
"                predicted_word = 'placeholder'\n",
"            if prob is None:\n",
"                prob = 0.001\n",
"            if other_probs_sum is None:\n",
"                other_probs_sum = 0.001\n",
"\n",
"            # Write the output to the file\n",
"            f.write(f\"{predicted_word}:{prob:.4f} :{other_probs_sum:.4f}\\n\")\n",
"\n",
"            # Check if the prediction is correct\n",
"            if predicted_word == expected_word:\n",
"                correct_predictions += 1\n",
"\n",
"            # Log progress every 1000 iterations\n",
"            if i % 1000 == 0:\n",
"                print(f\"Processed {i} rows. Current accuracy: {correct_predictions / (i+1)}\")\n",
"\n",
"        except Exception as e:\n",
"            print(f\"Error processing row {i}: {e}\")\n",
"\n",
"# Calculate the accuracy\n",
"accuracy = correct_predictions / len(in_df)\n",
"\n",
"print(f\"The accuracy of the model is {accuracy}\")"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7414\n"
]
}
],
"source": [
"import pandas as pd\n",
"import csv\n",
"\n",
"# Placeholder line\n",
"placeholder_line = ['placeholder'] * 8  # Adjust the number of fields as needed\n",
"\n",
"# Read the file line by line\n",
"with open('./challenging-america-word-gap-prediction/test-A/in.tsv', 'r') as f:\n",
"    lines = f.readlines()\n",
"\n",
"# Split each line into fields and replace problematic lines with the placeholder line\n",
"lines = [line.strip().split('\\t') if len(line.strip().split('\\t')) == 8 else placeholder_line for line in lines]\n",
"\n",
"# Convert the list of lines into a DataFrame\n",
"in_df = pd.DataFrame(lines)\n",
"\n",
"# Print the number of rows in the DataFrame\n",
"print(in_df.shape[0])\n",
"\n",
"# Drop unnecessary columns\n",
"columns_to_drop = [0, 1, 2, 3, 4, 5]  # Column indices to drop\n",
"in_df.drop(columns_to_drop, axis=1, inplace=True)\n",
"\n",
"# Rename remaining columns for clarity\n",
"in_df.columns = ['text_1', 'text_2']\n",
"\n",
"in_df['text_1'].fillna('placeholder', inplace=True)\n",
"in_df['text_2'].fillna('placeholder', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Replacing '\\n' in 'text_1': 100%|██████████| 7414/7414 [00:00<00:00, 216746.15it/s]\n",
"Replacing '\\t' in 'text_1': 100%|██████████| 7414/7414 [00:00<00:00, 545247.75it/s]\n",
"Replacing '\\n' in 'text_2': 100%|██████████| 7414/7414 [00:00<00:00, 223832.27it/s]\n",
"Replacing '\\t' in 'text_2': 100%|██████████| 7414/7414 [00:00<00:00, 569784.70it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Define functions that replace the literal '\\n' and '\\t' escape sequences stored in the text fields with spaces\n",
"def replace_newline(text):\n",
"    if isinstance(text, str):\n",
"        return text.replace('\\\\n', ' ')\n",
"    return text\n",
"\n",
"def replace_tabulation(text):\n",
"    if isinstance(text, str):\n",
"        return text.replace('\\\\t', ' ')\n",
"    return text\n",
"\n",
"# Apply the functions to 'text_1' and 'text_2' columns and show progress\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_newline)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(replace_tabulation)\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\n' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_newline)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Replacing '\\\\t' in 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(replace_tabulation)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /teamspace/studios/this_studio/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"Processing 'text_1': 100%|██████████| 7414/7414 [00:20<00:00, 365.75it/s]\n",
"Processing 'text_2': 100%|██████████| 7414/7414 [00:15<00:00, 478.59it/s]\n"
]
}
],
"source": [
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"# If not already done, download the NLTK English stopwords and the Punkt Tokenizer Models\n",
"nltk.download('punkt')\n",
"nltk.download('stopwords')\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"stemmer = PorterStemmer()\n",
"\n",
"def preprocess_text(text):\n",
"    # Lowercase the text\n",
"    text = text.lower()\n",
"\n",
"    # Remove punctuation\n",
"    text = text.translate(str.maketrans('', '', string.punctuation))\n",
"\n",
"    # Tokenize the text\n",
"    words = word_tokenize(text)\n",
"\n",
"    # Remove stopwords and stem the words\n",
"    words = [stemmer.stem(word) for word in words if word not in stop_words]\n",
"\n",
"    # Join the words back into a single string\n",
"    text = ' '.join(words)\n",
"\n",
"    return text\n",
"\n",
"# Apply the preprocessing to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(preprocess_text)\n",
"\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(preprocess_text)"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing 'text_1': 100%|██████████| 7414/7414 [00:00<00:00, 21928.33it/s]\n",
"Processing 'text_2': 100%|██████████| 7414/7414 [00:00<00:00, 20930.30it/s]\n"
]
}
],
"source": [
"import re\n",
"\n",
"def handle_numbers_and_special_chars(text):\n",
"    # Remove numbers\n",
"    text = re.sub(r'\\d+', '', text)\n",
"\n",
"    # Remove special characters\n",
"    text = re.sub(r'\\W+', ' ', text)\n",
"\n",
"    return text\n",
"\n",
"# Apply the function to the 'text_1' and 'text_2' columns\n",
"tqdm.pandas(desc=\"Processing 'text_1'\")\n",
"in_df['text_1'] = in_df['text_1'].progress_apply(handle_numbers_and_special_chars)\n",
"\n",
"tqdm.pandas(desc=\"Processing 'text_2'\")\n",
"in_df['text_2'] = in_df['text_2'].progress_apply(handle_numbers_and_special_chars)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"in_df.to_csv('preprocessed_test_text1_text2.tsv', sep='\\t', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed 0 rows. Current accuracy: 0.0\n",
"Processed 1000 rows. Current accuracy: 0.000999000999000999\n",
"Processed 2000 rows. Current accuracy: 0.0004997501249375312\n",
"Processed 3000 rows. Current accuracy: 0.0003332222592469177\n",
"Processed 4000 rows. Current accuracy: 0.0004998750312421895\n",
"Processed 5000 rows. Current accuracy: 0.0005998800239952009\n",
"Processed 6000 rows. Current accuracy: 0.0006665555740709882\n",
"Processed 7000 rows. Current accuracy: 0.0005713469504356521\n",
"The accuracy of the model is 0.0005395198273536552\n"
]
}
],
"source": [
"# Note: expected_df still holds the dev-0 labels loaded earlier in the notebook, so the accuracy\n",
"# printed below does not measure test-A performance; only the out.tsv predictions matter here\n",
"\n",
"# Initialize a counter for the correct predictions\n",
"correct_predictions = 0\n",
"\n",
"# Open the output file\n",
"with open('out.tsv', 'w') as f:\n",
"    # Iterate over the rows of the input DataFrame and the expected DataFrame\n",
"    for i, ((_, input_row), expected_word) in enumerate(zip(in_df.iterrows(), expected_df[0])):\n",
"        try:\n",
"            # Get the context\n",
"            context_1 = input_row['text_1']\n",
"            context_2 = input_row['text_2']\n",
"\n",
"            # Predict the missing word and get the probabilities\n",
"            predicted_word, prob, other_probs_sum = predict_missing_word(model, context_1, context_2)\n",
"\n",
"            # If any of the values are None, use placeholder values\n",
"            if predicted_word is None:\n",
"                predicted_word = 'placeholder'\n",
"            if prob is None:\n",
"                prob = 0.001\n",
"            if other_probs_sum is None:\n",
"                other_probs_sum = 0.001\n",
"\n",
"            # Write the output to the file\n",
"            f.write(f\"{predicted_word}:{prob:.4f} :{other_probs_sum:.4f}\\n\")\n",
"\n",
"            # Check if the prediction is correct\n",
"            if predicted_word == expected_word:\n",
"                correct_predictions += 1\n",
"\n",
"            # Log progress every 1000 iterations\n",
"            if i % 1000 == 0:\n",
"                print(f\"Processed {i} rows. Current accuracy: {correct_predictions / (i+1)}\")\n",
"\n",
"        except Exception as e:\n",
"            print(f\"Error processing row {i}: {e}\")\n",
"\n",
"# Calculate the accuracy\n",
"accuracy = correct_predictions / len(in_df)\n",
"\n",
"print(f\"The accuracy of the model is {accuracy}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cloudspace",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}