challenging-america-word-ga.../run.ipynb
2022-04-10 21:13:49 +02:00

216 lines
6.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 136,
"id": "e1ae390b",
"metadata": {},
"outputs": [],
"source": [
"import lzma\n",
"import nltk\n",
"\n",
"from nltk.tokenize import word_tokenize\n",
"from nltk import trigrams\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.tokenize import word_tokenize\n",
"from statistics import mean\n",
"from wordcloud import WordCloud,STOPWORDS\n",
"from collections import defaultdict, Counter\n",
"import plotly.express as px\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from nltk import ngrams\n",
"import pandas as pd\n",
"import csv\n",
"import cleantext\n",
"import re\n",
"import string"
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "32ece3fd",
"metadata": {},
"outputs": [],
"source": [
"model = defaultdict(lambda: defaultdict(lambda: 0))"
]
},
{
"cell_type": "code",
"execution_count": 173,
"id": "9174b2bb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_755/1610768154.py:1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
"\n",
"\n",
" train_file_in = pd.read_csv(\"train/in.tsv.xz\", sep=\"\\t\", error_bad_lines=False, header=None, quoting=csv.QUOTE_NONE, nrows=200000)\n",
"/tmp/ipykernel_755/1610768154.py:2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
"\n",
"\n",
" train_file_out = pd.read_csv(\"train/expected.tsv\", sep=\"\\t\", error_bad_lines=False, header=None, quoting=csv.QUOTE_NONE, nrows=200000)\n"
]
}
],
"source": [
"train_file_in = pd.read_csv(\"train/in.tsv.xz\", sep=\"\\t\", error_bad_lines=False, header=None, quoting=csv.QUOTE_NONE, nrows=200000)\n",
"train_file_out = pd.read_csv(\"train/expected.tsv\", sep=\"\\t\", error_bad_lines=False, header=None, quoting=csv.QUOTE_NONE, nrows=200000)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"id": "52848cf3",
"metadata": {},
"outputs": [],
"source": [
"stop_words= nltk.corpus.stopwords.words('english')\n",
"\n",
"def get_20common_2grams(text, n):\n",
" outputTrigrams = []\n",
" n_grams = ngrams(nltk.tokenize.word_tokenize(text), n)\n",
" for grams in n_grams:\n",
" outputTrigrams.append(grams)\n",
" return outputTrigrams\n",
"\n",
"def get_20common_2grams_no_stop(text, n):\n",
" tokenized_world = nltk.tokenize.word_tokenize(text)\n",
" stop_words= nltk.corpus.stopwords.words('english') \n",
" tokenized_no_stop = [i for i in tokenized_world if i not in stop_words]\n",
" n_grams = ngrams(tokenized_no_stop, n)\n",
" return n_grams\n",
"\n",
"def predict(word_before, word_after):\n",
" print(\"tu jestem\")\n",
" prob_list = dict(Counter(model[(word_before, word_after)]).most_common(6)).items()\n",
" predictions = []\n",
" prob_sum = 0.0\n",
" for key, value in prob_list:\n",
" print(\"tu jestem .................................\")\n",
" prob_sum += value\n",
" predictions.append(f'{key}:{value}')\n",
" if prob_sum == 0.0:\n",
" print(\"a teraz tu\")\n",
" return 'the:0:2 be:0.2 to:0.2 of:0.15 and:0.15 :0.1'\n",
" remaining_prob = 1 - prob_sum\n",
" if remaining_prob < 0.01:\n",
" predictions.append(f':{0.01}')\n",
" return ' '.join(predictions)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 175,
"id": "b7757d06",
"metadata": {},
"outputs": [],
"source": [
"train = train_file_in[[6, 7]]\n",
"train = pd.concat([train, train_file_out], axis=1)\n",
"\n",
"train[\"result\"] = train[6] + train[0] + train[7]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b9ce1a1",
"metadata": {},
"outputs": [],
"source": [
"for index, row in train.iterrows():\n",
" lower= str(row[\"result\"]).lower()\n",
" new_doc = re.sub(\"s+\",\" \", lower)\n",
" text_clean = \"\".join([i for i in new_doc if i not in string.punctuation])\n",
" words = word_tokenize(text_clean)\n",
" for w1, w2, w3 in trigrams(words, pad_right=True, pad_left=True):\n",
" if w1 and w2 and w3:\n",
" model[(w2, w3)][w1] += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57c08749",
"metadata": {},
"outputs": [],
"source": [
"for key in model:\n",
" total_count = float(sum(model[key].values()))\n",
" for value in model[key]:\n",
" model[key][value] /= total_count"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0ad2b1a",
"metadata": {},
"outputs": [],
"source": [
"dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
"test_a_data = pd.read_csv('test-A/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e91e51e5",
"metadata": {},
"outputs": [],
"source": [
"with open('dev-0/out.tsv', 'w') as file:\n",
" for index, row in dev_data.iterrows():\n",
" lower= str(row[7]).lower()\n",
" new_doc = re.sub(\"s+\",\" \", lower)\n",
" text_clean = \"\".join([i for i in new_doc if i not in string.punctuation])\n",
" words = word_tokenize(text_clean)\n",
" if len(words) < 4:\n",
" print(words)\n",
" prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'\n",
" else:\n",
" prediction = predict(words[0], words[1])\n",
" file.write(prediction + '\\n')\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01e3f7e8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}