In [6]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [7]:
root_path = '/content/gdrive/MyDrive/challenging-america-word-gap-prediction'

In [8]:
import torch
torch.cuda.is_available()

True

In [9]:
import torch
import csv
torch.cuda.empty_cache()
from torch.utils.data import DataLoader
import pandas as pd
from os.path import exists
from torchtext.vocab import build_vocab_from_iterator
import itertools
import regex as re
from csv import QUOTE_NONE
from torch import nn

ENCODING = "utf-8"

REP = re.compile(r"[{}\[\]\&%^$*#\(\)@\t\n0123456789]+")
REM = re.compile(r"'s|[\-]\\n|\-\\n|\p{P}")

def read_csv(fname):
    return pd.read_csv(fname, sep="\t", on_bad_lines='skip', header=None, quoting=QUOTE_NONE, encoding=ENCODING)

def clean_text(text):
    res = str(text).lower().strip()
    res = res.replace("’", "'")
    res = REM.sub("", res)
    res = REP.sub(" ", res)
    res = res.replace("'t", " not")
    res = res.replace("'s", " is")
    res = res.replace("'ll", " will")
    res = res.replace("'ve'", "have")
    return res.replace("'m", " am")

def get_words_from_line(line, specials = True):
    line = line.rstrip()
    if specials:
      yield '<s>'
    for m in re.finditer(r'[\p{L}0-9\*]+|\p{P}+', line):
        yield m.group(0).lower()
    if specials:
      yield '</s>'


def get_word_lines_from_data(d):
    for line in d:
        yield get_words_from_line(line)




class Bigrams(torch.utils.data.IterableDataset):
    def __init__(self, data, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
           get_word_lines_from_data(data),
           max_tokens = vocabulary_size,
           specials = ['<unk>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.data = data

    @staticmethod
    def look_ahead_iterator(gen):
        w1 = None
        for item in gen:
            if w1 is not None:
                yield (w1, item)
            w1 = item

    def __iter__(self):
       return self.look_ahead_iterator(
           (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_data(self.data))))

class SimpleBigramNeuralLanguageModel(torch.nn.Module):
    def __init__(self, vocabulary_size, embedding_size):
        super(SimpleBigramNeuralLanguageModel, self).__init__()
        self.model = nn.Sequential(
            nn.Embedding(vocabulary_size, embedding_size),
            nn.Linear(embedding_size, vocabulary_size),
            nn.Softmax(),
        )

    def forward(self, x):
        return self.model(x)




In [10]:


data = read_csv("/content/gdrive/MyDrive/challenging-america-word-gap-prediction/train/in.tsv.xz")
train_words = read_csv("/content/gdrive/MyDrive/challenging-america-word-gap-prediction/train/expected.tsv")

train_data = data[[6, 7]]
train_data = pd.concat([train_data, train_words], axis=1)
train_data = train_data[6] + train_data[0] + train_data[7]
train_data = train_data.apply(clean_text)

vocab_size = 30000
embed_size = 150

train_dataset = Bigrams(train_data, vocab_size)



device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
print(device)
if(not exists('model1.bin')):
    data = DataLoader(train_dataset, batch_size=8000)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.NLLLoss()

    model.train()
    step = 0
    for i in range(2):
      print(f"EPOCH {i}=========================")
      for x, y in data:
          x = x.to(device)
          y = y.to(device)
          optimizer.zero_grad()
          ypredicted = model(x)
          loss = criterion(torch.log(ypredicted), y)
          if step % 100 == 0:
              print(step, loss)
          step += 1
          loss.backward()
          optimizer.step()

    torch.save(model.state_dict(), 'model1.bin')
else:
    print("Loading model1")
    model.load_state_dict(torch.load('model1.bin'))



vocab = train_dataset.vocab

def predict(tokens):
    ixs = torch.tensor(vocab.forward(tokens)).to(device)
    out = model(ixs)
    top = torch.topk(out[0], 8)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    result = ""
    for word, prob in list(zip(top_words, top_probs)):
                result += f"{word}:{prob} "
    # result  += f':0.01'
    return result

DEFAULT_PREDICTION = "from:0.2 the:0.2 to:0.2 a:0.1 and:0.1 of:0.1 :0.1"

def predict_file(result_path, data):
    with open(result_path, "w+", encoding="UTF-8") as f:
        for row in data:
            result = {}
            before = None
            for before in get_words_from_line(clean_text(str(row)), False):
              pass
            before = [before]
            print(before)
            if(len(before) < 1):
                result = DEFAULT_PREDICTION
            else:
                result = predict(before)
            result = result.strip()
            f.write(result + "\n")
            print(result)



cuda
Loading model1


In [11]:
dev_data = pd.read_csv("/content/gdrive/MyDrive/challenging-america-word-gap-prediction/dev-0/in.tsv.xz", sep='\t', header=None, quoting=csv.QUOTE_NONE)[6]
dev_data = dev_data.apply(clean_text)
predict_file("/content/gdrive/MyDrive/challenging-america-word-gap-prediction/dev-0/out.tsv", dev_data)

test_data = pd.read_csv("/content/gdrive/MyDrive/challenging-america-word-gap-prediction/test-A/in.tsv.xz", sep='\t', header=None, quoting=csv.QUOTE_NONE)[6]
test_data = test_data.apply(clean_text)
predict_file("/content/gdrive/MyDrive/challenging-america-word-gap-prediction/test-A/out.tsv", test_data)

['fromn']
to:0.6022941470146179 <unk>:0.08413857221603394 the:0.02710115723311901 a:0.024180108681321144 cents:0.01099067460745573 and:0.010230590589344501 oclock:0.01012511644512415 at:0.007927066646516323
['its']
<unk>:0.2702571153640747 own:0.019152523949742317 use:0.006631065625697374 way:0.005629365798085928 provisions:0.004911798983812332 power:0.004610280506312847 origin:0.0041810800321400166 present:0.004065223038196564
['ot']
the:0.22001349925994873 <unk>:0.19508947432041168 a:0.030033614486455917 this:0.01654713787138462 tho:0.016085060313344002 his:0.013413750566542149 tbe:0.011244924739003181 said:0.010236472822725773
['singlenspiing']
<unk>:0.13795071840286255 the:0.050966303795576096 of:0.046845439821481705 and:0.042942408472299576 to:0.03077627159655094 in:0.023394089192152023 a:0.01842011697590351 that:0.01086820662021637
['com']
<unk>:0.3745647668838501 nmittee:0.0985867902636528 npany:0.07008005678653717 nplete:0.0227628406137228 nmenced:0.022177640348672867 nmunity:0

  input = module(input)


[1;30;43mStrumieniowane dane wyjściowe obcięte do 5000 ostatnich wierszy.[0m
['wellnspread']
<unk>:0.13795071840286255 the:0.050966303795576096 of:0.046845439821481705 and:0.042942408472299576 to:0.03077627159655094 in:0.023394089192152023 a:0.01842011697590351 that:0.01086820662021637
['from']
the:0.25287488102912903 <unk>:0.15713191032409668 a:0.034414686262607574 his:0.014321111142635345 tho:0.014128337614238262 this:0.012060044333338737 which:0.011743685230612755 to:0.011461544781923294
['of']
the:0.24628451466560364 <unk>:0.16549083590507507 a:0.030842546373605728 this:0.018929913640022278 his:0.013870411552488804 tho:0.011849009431898594 said:0.010169874876737595 their:0.00833516288548708
['that']
the:0.1410231739282608 <unk>:0.11953254044055939 he:0.044784024357795715 it:0.03494413569569588 they:0.022156892344355583 is:0.019587429240345955 a:0.019473088905215263 there:0.015096952207386494
['say']
that:0.21495364606380463 <unk>:0.11226599663496017 the:0.04550514370203018 to:0.0

  input = module(input)


[1;30;43mStrumieniowane dane wyjściowe obcięte do 5000 ostatnich wierszy.[0m
['employed']
in:0.22053653001785278 by:0.08955898135900497 <unk>:0.08067802339792252 to:0.05950067192316055 as:0.03592720627784729 and:0.035128992050886154 at:0.02778000943362713 the:0.02518022619187832
['the']
<unk>:0.22546915709972382 same:0.00775930006057024 state:0.007242937106639147 first:0.006000572349876165 city:0.005373408552259207 most:0.005041860044002533 people:0.005007922183722258 united:0.0049338326789438725
['man']
<unk>:0.12974177300930023 who:0.11130981147289276 of:0.04574340209364891 and:0.042749278247356415 in:0.04084893316030502 is:0.02529185637831688 to:0.02484465390443802 was:0.01994374394416809
['acre']
<unk>:0.12164300680160522 of:0.08069809526205063 and:0.07071764022111893 in:0.05700080841779709 the:0.03265472128987312 on:0.02822159044444561 for:0.025697262957692146 is:0.021180791780352592
['muchnengaged']
<unk>:0.13795071840286255 the:0.050966303795576096 of:0.046845439821481705 and: