import pandas as pd
import csv
import regex as re
from nltk import trigrams, word_tokenize
from collections import Counter, defaultdict
import nltk
# Trigram-based word-gap predictor.
#
# The notebook cells have been flattened into a plain script: definitions
# first, then a ``main()`` entry point that trains on the training split and
# writes predictions for the dev and test splits.

# Fallback distribution, emitted when nothing is known about the context.
DEFAULT_PREDICTION = 'from:0.2 the:0.2 to:0.2 a:0.1 and:0.1 of:0.1 :0.1'
# Number of leading training rows used to build the model.
TRAIN_ROWS = 120000
# How many of the most probable candidates are listed per prediction.
TOP_K = 6
# Lower bound for the probability mass left to "any other word".
MIN_REMAINING_PROB = 0.01
# Contexts with fewer tokens than this get the fallback distribution.
# (Only the first two tokens are used; the threshold is kept as it was.)
MIN_CONTEXT_TOKENS = 4

# model[(w2, w3)][w1]: how often w1 was directly followed by the pair
# (w2, w3); turned into relative frequencies by normalize_model().
model = defaultdict(lambda: defaultdict(int))


def read_tsv(path):
    """Read a header-less, unquoted TSV file, silently skipping bad lines.

    ``on_bad_lines='skip'`` replaces the deprecated
    ``error_bad_lines=False, warn_bad_lines=False`` pair.
    """
    return pd.read_csv(
        path,
        sep='\t',
        on_bad_lines='skip',
        header=None,
        quoting=csv.QUOTE_NONE,
    )


def clean_text(text):
    """Lower-case *text*, undo escaped line breaks and strip punctuation.

    The data stores line breaks as the two characters backslash + ``n``; a
    hyphen right before one is dropped together with the break, and any
    other escaped break becomes a single space.  ``\\p{P}`` needs the
    third-party ``regex`` module, which the file imports as ``re``.
    """
    text = text.lower().replace('-\\n', '').replace('\\n', ' ')
    return re.sub(r'\p{P}', '', text)


def load_training_texts(in_path='train/in.tsv.xz',
                        expected_path='train/expected.tsv',
                        limit=TRAIN_ROWS):
    """Return one training string per row: column 6, expected word, column 7.

    Columns 6 and 7 are presumably the left and right context of the gap
    (verify against the dataset description).  The three parts are joined
    with spaces so that the words at the boundaries do not get glued
    together, and missing fields are treated as empty strings.
    """
    contexts = read_tsv(in_path)[[6, 7]]
    labels = read_tsv(expected_path)
    # NOTE(review): rows are paired by position; if a bad line is skipped in
    # only one of the two files the pairing shifts -- confirm the data is clean.
    frame = pd.concat([contexts, labels], axis=1).iloc[:limit]
    before = frame[6].fillna('').astype(str)
    word = frame[0].fillna('').astype(str)
    after = frame[7].fillna('').astype(str)
    # Built as a separate Series instead of assigning a new column on a
    # slice, which is what triggered pandas' SettingWithCopyWarning.
    return before + ' ' + word + ' ' + after


def normalize_model():
    """Convert the raw counts in ``model`` into relative frequencies in place."""
    for candidates in model.values():
        total_count = float(sum(candidates.values()))
        for word in candidates:
            candidates[word] /= total_count


def train(texts):
    """Count (word, next two words) occurrences over *texts*, then normalise."""
    for text in texts:
        words = word_tokenize(clean_text(str(text)))
        for w1, w2, w3 in trigrams(words, pad_right=True, pad_left=True):
            # Padding yields None entries; only complete trigrams are counted.
            if w1 and w2 and w3:
                model[(w2, w3)][w1] += 1
    normalize_model()


def predict_probs(word1, word2):
    """Format the most likely words preceding the pair (*word1*, *word2*).

    Returns ``'word:prob word:prob ... :rest'`` with at most ``TOP_K``
    candidates, where ``rest`` is the probability mass left over (never
    below ``MIN_REMAINING_PROB``).  ``DEFAULT_PREDICTION`` is returned when
    the pair was never seen.
    """
    # .get() instead of indexing: indexing a defaultdict would insert an
    # empty entry for every unseen pair and grow the model during inference.
    candidates = model.get((word1, word2), {})
    best = Counter(candidates).most_common(TOP_K)

    total_prob = 0.0
    parts = []
    for word, prob in best:
        total_prob += prob
        parts.append(f'{word}:{prob}')

    if total_prob == 0.0:
        return DEFAULT_PREDICTION

    remaining_prob = max(1 - total_prob, MIN_REMAINING_PROB)
    parts.append(f':{remaining_prob}')
    return ' '.join(parts)


def predict_for_context(right_context):
    """Return the prediction line for one value of column 7."""
    words = word_tokenize(clean_text(str(right_context)))
    if len(words) < MIN_CONTEXT_TOKENS:
        return DEFAULT_PREDICTION
    return predict_probs(words[0], words[1])


def write_predictions(in_path, out_path):
    """Read the split at *in_path* and write one prediction per row to *out_path*."""
    data = read_tsv(in_path)
    with open(out_path, 'w', encoding='utf-8') as out:
        out.writelines(predict_for_context(context) + '\n' for context in data[7])


def main():
    """Train on the training split and write dev-0 and test-A predictions."""
    nltk.download('punkt')  # tokenizer data required by word_tokenize
    train(load_training_texts())
    write_predictions('dev-0/in.tsv.xz', 'dev-0/out.tsv')
    write_predictions('test-A/in.tsv.xz', 'test-A/out.tsv')


if __name__ == '__main__':
    main()