challenging-america-word-ga.../run_3.ipynb
2023-05-10 20:47:58 +02:00

24 KiB
Raw Blame History

import pandas as pd
import csv
import regex as re
from nltk import trigrams, word_tokenize
from collections import Counter, defaultdict
import nltk
# Make sure the 'punkt' resource is available locally before tokenizing;
# presumably this is what word_tokenize needs for sentence splitting.
# The call is a no-op when the package is already up to date.
nltk.download('punkt')
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
True
# Google Colab only: expose Google Drive under /content/gdrive so the
# dataset files referenced below can be read and the outputs written back.
from google.colab import drive
drive.mount('/content/gdrive')
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
# IPython "cd" magic: switch to the challenge directory so that every relative
# path used below (train/, dev-0/, test-A/) resolves inside it.
cd '/content/gdrive/MyDrive/challenging-america-word-gap-prediction/'
/content/gdrive/MyDrive/challenging-america-word-gap-prediction
# Load the training inputs and the expected gap words. Both files are
# headerless TSVs; QUOTE_NONE keeps quote characters in the text verbatim.
# on_bad_lines='skip' silently drops malformed rows. It replaces the
# deprecated error_bad_lines=False / warn_bad_lines=False pair, which newer
# pandas releases no longer accept.
# NOTE(review): rows are matched to labels by position below; if a row were
# skipped in only one of the two files the pairing would shift - confirm the
# files parse to the same length.
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)

# Keep only columns 6 and 7 (presumably the text before and after the gap)
# and append the label as column 0.
train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)
train_data
<ipython-input-4-06713320f790>:1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  train_data = pd.read_csv('train/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
<ipython-input-4-06713320f790>:1: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  train_data = pd.read_csv('train/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
<ipython-input-4-06713320f790>:2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  train_labels = pd.read_csv('train/expected.tsv', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
<ipython-input-4-06713320f790>:2: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  train_labels = pd.read_csv('train/expected.tsv', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
6 7 0
0 came fiom the last place to this\nplace, and t... said\nit's all squash. The best I could get\ni... lie
1 MB. BOOT'S POLITICAL OBEED\nAttempt to imagine... \ninto a proper perspective with those\nminor ... himself
2 "Thera were in 1771 only aeventy-nine\n*ub*erl... all notU\nashore and afloat arc subjects for I... of
3 A gixnl man y nitereRtiiiv dii-clos-\nur«s reg... ceucju l< d no; <o waste it nud so\nsunk it in... ably
4 Tin: 188UB TV THF BBABBT QABJE\nMr. Schiffs *t... ascertained w? OCt the COOltS of ibis\nletale ... j
... ... ... ...
432017 Sam Clendenin bad a fancy for Ui«\nscience of ... \nSam was arrested.\nThe case excited a great ... and
432018 Wita.htt halting the party ware dilven to the ... through the alnp the »Uitors laapeeeed tia.»\n... paasliic
432019 It was the last thing that either of\nthem exp... Agua Negra across the line.\nIt was a grim pla... for
432020 settlement with the department.\nIt is also sh... \na note of Wood, Dialogue fc Co., for\nc27,im... for
432021 Flour quotations—low extras at 1 R0®2 50;\ncit... 3214c;do White at 3614c: Mixed Western at\n331... at

432022 rows × 3 columns

# Train on the first 120000 rows only. .copy() turns the slice into an
# independent frame, so adding a column below no longer raises
# SettingWithCopyWarning.
train_data = train_data[:120000].copy()
# Rebuild each passage as "<column 6> <gap word> <column 7>". The explicit
# spaces keep the gap word from fusing with the neighbouring words; extra
# whitespace is harmless because the text is tokenized afterwards.
train_data['final'] = train_data[6] + ' ' + train_data[0] + ' ' + train_data[7]
<ipython-input-6-b31274590998>:1: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_data['final'] = train_data[6] + train_data[0] + train_data[7]
# Two-level table: model[(w2, w3)][w1] holds the count (later the
# probability) of w1 appearing directly before the pair (w2, w3).
# Missing inner entries start at 0.
model = defaultdict(lambda: defaultdict(int))
def clean_text(text):
    """Normalize a raw passage before tokenization.

    The text is lower-cased, words hyphenated across an escaped line break
    are joined back together, remaining escaped line breaks become spaces,
    and all Unicode punctuation is removed.

    The passages encode line breaks as a backslash followed by ``n``. The
    previous literal ``'\\\\n'`` only matched *two* backslashes, so the
    replacement never fired and the punctuation strip then left a stray
    ``n`` glued into the words. The patterns below accept one or more
    backslashes, so both encodings are handled.
    """
    text = text.lower()
    # "exam-\nple" -> "example": drop the hyphen together with the line break.
    text = re.sub(r'-\\+n', '', text)
    # Any other escaped line break is just a word separator.
    text = re.sub(r'\\+n', ' ', text)
    # \p{P} (any Unicode punctuation) needs the third-party `regex` module,
    # which this file imports under the name `re`.
    text = re.sub(r'\p{P}', '', text)
    return text
# Count, for every pair of adjacent tokens, which token occurred directly
# before that pair in the training passages.
for passage in train_data['final']:
    tokens = word_tokenize(clean_text(str(passage)))
    for first, second, third in trigrams(tokens, pad_right=True, pad_left=True):
        # Padding produces None placeholders at both ends of the passage;
        # only fully populated trigrams are counted.
        if not (first and second and third):
            continue
        model[(second, third)][first] += 1
# Turn the raw counts into conditional probabilities: within each context
# the values are divided by the context's total count.
for preceding_counts in model.values():
    context_total = float(sum(preceding_counts.values()))
    for candidate in preceding_counts:
        preceding_counts[candidate] /= context_total
def predict_probs(word1, word2):
    """Return the formatted distribution for the word preceding a token pair.

    Looks up the context ``(word1, word2)`` in the module-level ``model`` and
    formats its six most probable entries as ``"word:prob word:prob ... :rest"``,
    where the trailing ``:rest`` is the probability mass left for all other
    words (never reported below 0.01).

    Args:
        word1: First token of the context.
        word2: Second token of the context.

    Returns:
        The prediction line as a string. A fixed fallback distribution is
        returned when the context is unknown or carries no probability mass.
    """
    # Use .get() rather than model[...]: indexing the defaultdict would insert
    # an empty entry for every unseen context and grow the model during
    # prediction.
    candidates = model.get((word1, word2))
    if not candidates:
        return 'from:0.2 the:0.2 to:0.2 a:0.1 and:0.1 of:0.1 :0.1'

    top_entries = Counter(candidates).most_common(6)

    total_prob = 0.0
    for _, prob in top_entries:
        total_prob += prob

    if total_prob == 0.0:
        return 'from:0.2 the:0.2 to:0.2 a:0.1 and:0.1 of:0.1 :0.1'

    # Always leave a little mass for the words that are not listed.
    remaining_prob = max(1 - total_prob, 0.01)

    parts = [f'{word}:{prob}' for word, prob in top_entries]
    parts.append(f':{remaining_prob}')
    return ' '.join(parts)
# Load the development and test inputs the same way as the training data:
# headerless TSV, quotes kept verbatim. on_bad_lines='skip' silently drops
# malformed rows and replaces the deprecated error_bad_lines=False /
# warn_bad_lines=False arguments.
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
<ipython-input-12-94466712d0ba>:1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
<ipython-input-12-94466712d0ba>:1: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
<ipython-input-12-94466712d0ba>:2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
<ipython-input-12-94466712d0ba>:2: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.


  test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
# Distribution written when a row offers too little context to query the model.
DEFAULT_PREDICTION = 'from:0.2 the:0.2 to:0.2 a:0.1 and:0.1 of:0.1 :0.1'


def _write_predictions(data, out_path):
    """Write one prediction line per row of `data` to `out_path`.

    Column 7 of each row is cleaned and tokenized, and its first two tokens
    are used as the context for `predict_probs`. Only those two tokens are
    read, so the fallback is used only when fewer than two are available
    (the earlier `< 4` guard discarded usable contexts).
    """
    # Explicit encoding: the predicted words may contain non-ASCII letters.
    with open(out_path, 'w', encoding='utf-8') as file:
        for context in data[7]:
            words = word_tokenize(clean_text(str(context)))
            if len(words) < 2:
                prediction = DEFAULT_PREDICTION
            else:
                prediction = predict_probs(words[0], words[1])
            file.write(prediction + '\n')


_write_predictions(dev_data, 'dev-0/out.tsv')
_write_predictions(test_data, 'test-A/out.tsv')