{
" cells " : [
{
" cell_type " : " code " ,
" execution_count " : 1 ,
" id " : " 46fe9a72-d787-46ab-a9df-c6732c173a26 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" import pandas as pd \n " ,
" import numpy as np \n " ,
" import regex as re \n " ,
" import csv \n " ,
" import torch \n " ,
" from torch import nn \n " ,
" from gensim.models import Word2Vec \n " ,
" from nltk.tokenize import word_tokenize "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 2 ,
" id " : " 6fd72918-57e4-44f2-a0e1-a71f322df5f7 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" def clean_text(text): \n " ,
" text = text.lower().replace( ' - \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ n ' , ' ' ).replace( ' \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ n ' , ' ' ) \n " ,
" text = re.sub(r ' \\ p {P} ' , ' ' , text) \n " ,
" text = text.replace( \" ' t \" , \" not \" ).replace( \" ' s \" , \" is \" ).replace( \" ' ll \" , \" will \" ).replace( \" ' m \" , \" am \" ).replace( \" ' ve \" , \" have \" ) \n " ,
" \n " ,
" return text "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 3 ,
" id " : " 87dc8a73-f089-4551-b0f0-dedefd0b5a05 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" train_data = pd.read_csv( ' train/in.tsv.xz ' , sep= ' \\ t ' , on_bad_lines= ' skip ' , header=None, quoting=csv.QUOTE_NONE) \n " ,
" train_labels = pd.read_csv( ' train/expected.tsv ' , sep= ' \\ t ' , on_bad_lines= ' skip ' , header=None, quoting=csv.QUOTE_NONE) \n " ,
" \n " ,
" train_data = train_data[[6, 7]] \n " ,
" train_data = pd.concat([train_data, train_labels], axis=1) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 4 ,
" id " : " 1d9244d4-1dd7-4c02-8e22-6ea5feee9b26 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" class TrainCorpus: \n " ,
" def __init__(self, data): \n " ,
" self.data = data \n " ,
" \n " ,
" def __iter__(self): \n " ,
" for _, row in self.data.iterrows(): \n " ,
" text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7]) \n " ,
" text = clean_text(text) \n " ,
" yield word_tokenize(text) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 5 ,
" id " : " 5b6f017b-990f-42e2-8ef5-ce40824ace61 " ,
" metadata " : { } ,
" outputs " : [
{
" name " : " stdout " ,
" output_type " : " stream " ,
" text " : [
" 81477 \n "
]
}
] ,
" source " : [
" train_sentences = TrainCorpus(train_data.head(80000)) \n " ,
" w2v_model = Word2Vec(vector_size=100, min_count=10) \n " ,
" w2v_model.build_vocab(corpus_iterable=train_sentences) \n " ,
" \n " ,
" key_to_index = w2v_model.wv.key_to_index \n " ,
" index_to_key = w2v_model.wv.index_to_key \n " ,
" \n " ,
" index_to_key.append( ' <unk> ' ) \n " ,
" key_to_index[ ' <unk> ' ] = len(index_to_key) - 1 \n " ,
" \n " ,
" vocab_size = len(index_to_key) \n " ,
" print(vocab_size) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 6 ,
" id " : " 516767c7-ce51-4483-b99f-675cfd4fe99d " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" class TrainDataset(torch.utils.data.IterableDataset): \n " ,
" def __init__(self, data, index_to_key, key_to_index, reversed=False): \n " ,
" self.reversed = reversed \n " ,
" self.data = data \n " ,
" self.index_to_key = index_to_key \n " ,
" self.key_to_index = key_to_index \n " ,
" self.vocab_size = len(key_to_index) \n " ,
" \n " ,
" def __iter__(self): \n " ,
" for _, row in self.data.iterrows(): \n " ,
" text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7]) \n " ,
" text = clean_text(text) \n " ,
" tokens = word_tokenize(text) \n " ,
" if self.reversed: \n " ,
" tokens = list(reversed(tokens)) \n " ,
" for i in range(5, len(tokens), 1): \n " ,
" input_context = tokens[i-5:i] \n " ,
" target_context = tokens[i-4:i+1] \n " ,
" \n " ,
" input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index[ ' <unk> ' ] for word in input_context] \n " ,
" target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index[ ' <unk> ' ] for word in target_context] \n " ,
" \n " ,
" yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 7 ,
" id " : " 678d0388-a2b2-44ca-b686-dc133a0d16e5 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" class Model(nn.Module): \n " ,
" def __init__(self, embed_size, vocab_size): \n " ,
" super(Model, self).__init__() \n " ,
" self.embed_size = embed_size \n " ,
" self.vocab_size = vocab_size \n " ,
" self.lstm_size = 128 \n " ,
" self.num_layers = 2 \n " ,
" \n " ,
" self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size) \n " ,
" self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2) \n " ,
" self.fc = nn.Linear(self.lstm_size, vocab_size) \n " ,
" \n " ,
" def forward(self, x, prev_state = None): \n " ,
" embed = self.embed(x) \n " ,
" output, state = self.lstm(embed, prev_state) \n " ,
" logits = self.fc(output) \n " ,
" probs = torch.softmax(logits, dim=1) \n " ,
" return logits, state \n " ,
" \n " ,
" def init_state(self, sequence_length): \n " ,
" zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device) \n " ,
" return (zeros, zeros) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 17 ,
" id " : " 487df613-15cd-450e-a8b2-7e7cd879f8f7 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" from torch.utils.data import DataLoader \n " ,
" from torch.optim import Adam \n " ,
" \n " ,
" def train(dataset, model, max_epochs, batch_size): \n " ,
" model.train() \n " ,
" \n " ,
" dataloader = DataLoader(dataset, batch_size=batch_size) \n " ,
" criterion = nn.CrossEntropyLoss() \n " ,
" optimizer = Adam(model.parameters(), lr=0.001) \n " ,
" \n " ,
" for epoch in range(max_epochs): \n " ,
" for batch, (x, y) in enumerate(dataloader): \n " ,
" optimizer.zero_grad() \n " ,
" \n " ,
" x = x.to(device) \n " ,
" y = y.to(device) \n " ,
" \n " ,
" y_pred, (state_h, state_c) = model(x) \n " ,
" loss = criterion(y_pred.transpose(1, 2), y) \n " ,
" \n " ,
" loss.backward() \n " ,
" optimizer.step() \n " ,
" \n " ,
" if batch % 100 == 0: \n " ,
" print(f ' epoch: {epoch} , update in batch {batch} /???, loss: { loss.item()} ' ) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 18 ,
" id " : " 79df35ad-cb19-4867-869c-b651455580ae " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" torch.cuda.empty_cache() \n " ,
" device = ' cuda ' if torch.cuda.is_available() else ' cpu ' "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 26 ,
" id " : " e4654fd0-6c0a-47e3-b623-73c7c04ea194 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" train_dataset_front = TrainDataset(train_data.head(8000), index_to_key, key_to_index, False) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 27 ,
" id " : " 7b5f09f5-88c4-44d4-ab8a-aa62ec8f70d5 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" model_front = Model(100, vocab_size).to(device) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 21 ,
" id " : " 7e886972-4994-4bbd-aca8-2ab685c7b8db " ,
" metadata " : { } ,
" outputs " : [
{
" name " : " stdout " ,
" output_type " : " stream " ,
" text " : [
" epoch: 0, update in batch 0/???, loss: 11.315739631652832 \n " ,
" epoch: 0, update in batch 100/???, loss: 8.016324996948242 \n " ,
" epoch: 0, update in batch 200/???, loss: 7.45602560043335 \n " ,
" epoch: 0, update in batch 300/???, loss: 6.306332588195801 \n " ,
" epoch: 0, update in batch 400/???, loss: 8.629552841186523 \n " ,
" epoch: 0, update in batch 500/???, loss: 7.637443542480469 \n " ,
" epoch: 0, update in batch 600/???, loss: 7.67318868637085 \n " ,
" epoch: 0, update in batch 700/???, loss: 7.2209930419921875 \n " ,
" epoch: 0, update in batch 800/???, loss: 7.739532470703125 \n " ,
" epoch: 0, update in batch 900/???, loss: 7.219891548156738 \n " ,
" epoch: 0, update in batch 1000/???, loss: 6.8804473876953125 \n " ,
" epoch: 0, update in batch 1100/???, loss: 7.228173732757568 \n " ,
" epoch: 0, update in batch 1200/???, loss: 6.513087272644043 \n " ,
" epoch: 0, update in batch 1300/???, loss: 7.142991542816162 \n " ,
" epoch: 0, update in batch 1400/???, loss: 7.711663246154785 \n " ,
" epoch: 0, update in batch 1500/???, loss: 6.894327640533447 \n " ,
" epoch: 0, update in batch 1600/???, loss: 7.723884582519531 \n " ,
" epoch: 0, update in batch 1700/???, loss: 8.409640312194824 \n " ,
" epoch: 0, update in batch 1800/???, loss: 6.570927619934082 \n " ,
" epoch: 0, update in batch 1900/???, loss: 6.906421661376953 \n " ,
" epoch: 0, update in batch 2000/???, loss: 7.197023868560791 \n " ,
" epoch: 0, update in batch 2100/???, loss: 6.892503261566162 \n " ,
" epoch: 0, update in batch 2200/???, loss: 7.109471321105957 \n " ,
" epoch: 0, update in batch 2300/???, loss: 8.84702205657959 \n " ,
" epoch: 0, update in batch 2400/???, loss: 7.394454002380371 \n " ,
" epoch: 0, update in batch 2500/???, loss: 7.380859375 \n " ,
" epoch: 0, update in batch 2600/???, loss: 6.635237693786621 \n " ,
" epoch: 0, update in batch 2700/???, loss: 6.869620323181152 \n " ,
" epoch: 0, update in batch 2800/???, loss: 6.656294822692871 \n " ,
" epoch: 0, update in batch 2900/???, loss: 8.090291976928711 \n " ,
" epoch: 0, update in batch 3000/???, loss: 7.012345314025879 \n " ,
" epoch: 0, update in batch 3100/???, loss: 6.7099809646606445 \n " ,
" epoch: 0, update in batch 3200/???, loss: 6.798626899719238 \n " ,
" epoch: 0, update in batch 3300/???, loss: 6.510752201080322 \n " ,
" epoch: 0, update in batch 3400/???, loss: 7.742552757263184 \n " ,
" epoch: 0, update in batch 3500/???, loss: 7.3319292068481445 \n " ,
" epoch: 0, update in batch 3600/???, loss: 8.022462844848633 \n " ,
" epoch: 0, update in batch 3700/???, loss: 5.883602619171143 \n " ,
" epoch: 0, update in batch 3800/???, loss: 6.235389232635498 \n " ,
" epoch: 0, update in batch 3900/???, loss: 7.012289524078369 \n " ,
" epoch: 0, update in batch 4000/???, loss: 7.005420684814453 \n " ,
" epoch: 0, update in batch 4100/???, loss: 6.595402717590332 \n " ,
" epoch: 0, update in batch 4200/???, loss: 6.7428154945373535 \n " ,
" epoch: 0, update in batch 4300/???, loss: 6.358878135681152 \n " ,
" epoch: 0, update in batch 4400/???, loss: 6.6188201904296875 \n " ,
" epoch: 0, update in batch 4500/???, loss: 7.08281946182251 \n " ,
" epoch: 0, update in batch 4600/???, loss: 5.705609321594238 \n " ,
" epoch: 0, update in batch 4700/???, loss: 7.1878180503845215 \n " ,
" epoch: 0, update in batch 4800/???, loss: 7.071160793304443 \n " ,
" epoch: 0, update in batch 4900/???, loss: 6.768280029296875 \n " ,
" epoch: 0, update in batch 5000/???, loss: 6.507267951965332 \n " ,
" epoch: 0, update in batch 5100/???, loss: 6.6431379318237305 \n " ,
" epoch: 0, update in batch 5200/???, loss: 6.719052314758301 \n " ,
" epoch: 0, update in batch 5300/???, loss: 7.172060489654541 \n " ,
" epoch: 0, update in batch 5400/???, loss: 5.98638916015625 \n " ,
" epoch: 0, update in batch 5500/???, loss: 5.674165725708008 \n " ,
" epoch: 0, update in batch 5600/???, loss: 5.612569808959961 \n " ,
" epoch: 0, update in batch 5700/???, loss: 6.307109832763672 \n " ,
" epoch: 0, update in batch 5800/???, loss: 5.382391452789307 \n " ,
" epoch: 0, update in batch 5900/???, loss: 5.712988376617432 \n " ,
" epoch: 0, update in batch 6000/???, loss: 6.371735572814941 \n " ,
" epoch: 0, update in batch 6100/???, loss: 6.417542457580566 \n " ,
" epoch: 0, update in batch 6200/???, loss: 7.14879846572876 \n " ,
" epoch: 0, update in batch 6300/???, loss: 7.0701189041137695 \n " ,
" epoch: 0, update in batch 6400/???, loss: 7.048495292663574 \n " ,
" epoch: 0, update in batch 6500/???, loss: 7.3384833335876465 \n " ,
" epoch: 0, update in batch 6600/???, loss: 6.561330318450928 \n " ,
" epoch: 0, update in batch 6700/???, loss: 6.839573860168457 \n " ,
" epoch: 0, update in batch 6800/???, loss: 6.5179548263549805 \n " ,
" epoch: 0, update in batch 6900/???, loss: 7.246607303619385 \n " ,
" epoch: 0, update in batch 7000/???, loss: 6.5699052810668945 \n " ,
" epoch: 0, update in batch 7100/???, loss: 7.202715873718262 \n " ,
" epoch: 0, update in batch 7200/???, loss: 6.1833648681640625 \n " ,
" epoch: 0, update in batch 7300/???, loss: 5.977782249450684 \n " ,
" epoch: 0, update in batch 7400/???, loss: 6.717446327209473 \n " ,
" epoch: 0, update in batch 7500/???, loss: 6.574376583099365 \n " ,
" epoch: 0, update in batch 7600/???, loss: 5.8418450355529785 \n " ,
" epoch: 0, update in batch 7700/???, loss: 6.282655715942383 \n " ,
" epoch: 0, update in batch 7800/???, loss: 6.065321922302246 \n " ,
" epoch: 0, update in batch 7900/???, loss: 6.415077209472656 \n " ,
" epoch: 0, update in batch 8000/???, loss: 6.482673645019531 \n " ,
" epoch: 0, update in batch 8100/???, loss: 6.670407772064209 \n " ,
" epoch: 0, update in batch 8200/???, loss: 6.799211025238037 \n " ,
" epoch: 0, update in batch 8300/???, loss: 7.299313545227051 \n " ,
" epoch: 0, update in batch 8400/???, loss: 7.42974328994751 \n " ,
" epoch: 0, update in batch 8500/???, loss: 8.549559593200684 \n " ,
" epoch: 0, update in batch 8600/???, loss: 6.794680118560791 \n " ,
" epoch: 0, update in batch 8700/???, loss: 7.390380859375 \n " ,
" epoch: 0, update in batch 8800/???, loss: 7.552660942077637 \n " ,
" epoch: 0, update in batch 8900/???, loss: 6.663547515869141 \n " ,
" epoch: 0, update in batch 9000/???, loss: 6.5236711502075195 \n " ,
" epoch: 0, update in batch 9100/???, loss: 7.666424751281738 \n " ,
" epoch: 0, update in batch 9200/???, loss: 6.479496955871582 \n " ,
" epoch: 0, update in batch 9300/???, loss: 5.5056304931640625 \n " ,
" epoch: 0, update in batch 9400/???, loss: 6.6904096603393555 \n " ,
" epoch: 0, update in batch 9500/???, loss: 6.9318037033081055 \n " ,
" epoch: 0, update in batch 9600/???, loss: 6.521365165710449 \n " ,
" epoch: 0, update in batch 9700/???, loss: 6.376631736755371 \n " ,
" epoch: 0, update in batch 9800/???, loss: 6.4104766845703125 \n " ,
" epoch: 0, update in batch 9900/???, loss: 7.3995232582092285 \n " ,
" epoch: 0, update in batch 10000/???, loss: 6.510337829589844 \n " ,
" epoch: 0, update in batch 10100/???, loss: 6.2512407302856445 \n " ,
" epoch: 0, update in batch 10200/???, loss: 6.048404216766357 \n " ,
" epoch: 0, update in batch 10300/???, loss: 6.832150936126709 \n " ,
" epoch: 0, update in batch 10400/???, loss: 6.7485456466674805 \n " ,
" epoch: 0, update in batch 10500/???, loss: 5.385656833648682 \n " ,
" epoch: 0, update in batch 10600/???, loss: 6.769070625305176 \n " ,
" epoch: 0, update in batch 10700/???, loss: 6.857029914855957 \n " ,
" epoch: 0, update in batch 10800/???, loss: 5.991332530975342 \n " ,
" epoch: 0, update in batch 10900/???, loss: 6.5500006675720215 \n " ,
" epoch: 0, update in batch 11000/???, loss: 6.951509952545166 \n " ,
" epoch: 0, update in batch 11100/???, loss: 6.396986961364746 \n " ,
" epoch: 0, update in batch 11200/???, loss: 6.639346122741699 \n " ,
" epoch: 0, update in batch 11300/???, loss: 5.87351655960083 \n " ,
" epoch: 0, update in batch 11400/???, loss: 5.996974945068359 \n " ,
" epoch: 0, update in batch 11500/???, loss: 7.103158473968506 \n " ,
" epoch: 0, update in batch 11600/???, loss: 6.429941654205322 \n " ,
" epoch: 0, update in batch 11700/???, loss: 5.597273826599121 \n " ,
" epoch: 0, update in batch 11800/???, loss: 7.112508296966553 \n " ,
" epoch: 0, update in batch 11900/???, loss: 6.745194911956787 \n " ,
" epoch: 0, update in batch 12000/???, loss: 7.47100305557251 \n " ,
" epoch: 0, update in batch 12100/???, loss: 6.847914695739746 \n " ,
" epoch: 0, update in batch 12200/???, loss: 6.876992702484131 \n " ,
" epoch: 0, update in batch 12300/???, loss: 6.499053955078125 \n " ,
" epoch: 0, update in batch 12400/???, loss: 7.196413993835449 \n " ,
" epoch: 0, update in batch 12500/???, loss: 6.593430995941162 \n " ,
" epoch: 0, update in batch 12600/???, loss: 6.368945121765137 \n " ,
" epoch: 0, update in batch 12700/???, loss: 6.362246513366699 \n " ,
" epoch: 0, update in batch 12800/???, loss: 7.209506034851074 \n " ,
" epoch: 0, update in batch 12900/???, loss: 6.8092780113220215 \n " ,
" epoch: 0, update in batch 13000/???, loss: 8.273663520812988 \n " ,
" epoch: 0, update in batch 13100/???, loss: 7.061187744140625 \n " ,
" epoch: 0, update in batch 13200/???, loss: 5.778809547424316 \n " ,
" epoch: 0, update in batch 13300/???, loss: 5.650263786315918 \n " ,
" epoch: 0, update in batch 13400/???, loss: 5.9032440185546875 \n " ,
" epoch: 0, update in batch 13500/???, loss: 6.629636287689209 \n " ,
" epoch: 0, update in batch 13600/???, loss: 6.577019691467285 \n " ,
" epoch: 0, update in batch 13700/???, loss: 5.953114032745361 \n " ,
" epoch: 0, update in batch 13800/???, loss: 6.630902290344238 \n " ,
" epoch: 0, update in batch 13900/???, loss: 7.593966484069824 \n " ,
" epoch: 0, update in batch 14000/???, loss: 6.636081695556641 \n " ,
" epoch: 0, update in batch 14100/???, loss: 5.772985458374023 \n " ,
" epoch: 0, update in batch 14200/???, loss: 5.907249450683594 \n " ,
" epoch: 0, update in batch 14300/???, loss: 7.863391876220703 \n " ,
" epoch: 0, update in batch 14400/???, loss: 7.275572776794434 \n " ,
" epoch: 0, update in batch 14500/???, loss: 6.818984031677246 \n " ,
" epoch: 0, update in batch 14600/???, loss: 6.0456342697143555 \n " ,
" epoch: 0, update in batch 14700/???, loss: 6.281990051269531 \n " ,
" epoch: 0, update in batch 14800/???, loss: 6.197850227355957 \n " ,
" epoch: 0, update in batch 14900/???, loss: 5.851240634918213 \n " ,
" epoch: 0, update in batch 15000/???, loss: 6.826748847961426 \n " ,
" epoch: 0, update in batch 15100/???, loss: 7.2189483642578125 \n " ,
" epoch: 0, update in batch 15200/???, loss: 6.609204292297363 \n " ,
" epoch: 0, update in batch 15300/???, loss: 6.947709560394287 \n " ,
" epoch: 0, update in batch 15400/???, loss: 6.604478359222412 \n " ,
" epoch: 0, update in batch 15500/???, loss: 6.222006797790527 \n " ,
" epoch: 0, update in batch 15600/???, loss: 6.515635013580322 \n " ,
" epoch: 0, update in batch 15700/???, loss: 6.40108585357666 \n " ,
" epoch: 0, update in batch 15800/???, loss: 6.36106014251709 \n " ,
" epoch: 0, update in batch 15900/???, loss: 6.533608436584473 \n " ,
" epoch: 0, update in batch 16000/???, loss: 6.662516117095947 \n " ,
" epoch: 0, update in batch 16100/???, loss: 7.284195899963379 \n " ,
" epoch: 0, update in batch 16200/???, loss: 6.6524176597595215 \n " ,
" epoch: 0, update in batch 16300/???, loss: 6.430756568908691 \n " ,
" epoch: 0, update in batch 16400/???, loss: 7.515387058258057 \n " ,
" epoch: 0, update in batch 16500/???, loss: 6.938241481781006 \n " ,
" epoch: 0, update in batch 16600/???, loss: 5.860864162445068 \n " ,
" epoch: 0, update in batch 16700/???, loss: 6.451329231262207 \n " ,
" epoch: 0, update in batch 16800/???, loss: 6.5510663986206055 \n " ,
" epoch: 0, update in batch 16900/???, loss: 7.3591437339782715 \n " ,
" epoch: 0, update in batch 17000/???, loss: 6.158746719360352 \n " ,
" epoch: 0, update in batch 17100/???, loss: 7.202520847320557 \n " ,
" epoch: 0, update in batch 17200/???, loss: 6.80673885345459 \n " ,
" epoch: 0, update in batch 17300/???, loss: 6.698304653167725 \n " ,
" epoch: 0, update in batch 17400/???, loss: 5.743161201477051 \n " ,
" epoch: 0, update in batch 17500/???, loss: 6.518529415130615 \n " ,
" epoch: 0, update in batch 17600/???, loss: 6.021708011627197 \n " ,
" epoch: 0, update in batch 17700/???, loss: 6.354712963104248 \n " ,
" epoch: 0, update in batch 17800/???, loss: 6.323357582092285 \n " ,
" epoch: 0, update in batch 17900/???, loss: 6.61548376083374 \n " ,
" epoch: 0, update in batch 18000/???, loss: 6.600308895111084 \n " ,
" epoch: 0, update in batch 18100/???, loss: 6.794068336486816 \n " ,
" epoch: 0, update in batch 18200/???, loss: 7.487390041351318 \n " ,
" epoch: 0, update in batch 18300/???, loss: 5.973461627960205 \n " ,
" epoch: 0, update in batch 18400/???, loss: 6.891515254974365 \n " ,
" epoch: 0, update in batch 18500/???, loss: 5.897144317626953 \n " ,
" epoch: 0, update in batch 18600/???, loss: 6.6016364097595215 \n " ,
" epoch: 0, update in batch 18700/???, loss: 6.948650360107422 \n " ,
" epoch: 0, update in batch 18800/???, loss: 7.221627235412598 \n " ,
" epoch: 0, update in batch 18900/???, loss: 6.817994117736816 \n " ,
" epoch: 0, update in batch 19000/???, loss: 5.730655193328857 \n " ,
" epoch: 0, update in batch 19100/???, loss: 6.236818790435791 \n " ,
" epoch: 0, update in batch 19200/???, loss: 7.178666114807129 \n " ,
" epoch: 0, update in batch 19300/???, loss: 6.77465295791626 \n " ,
" epoch: 0, update in batch 19400/???, loss: 6.996792793273926 \n " ,
" epoch: 0, update in batch 19500/???, loss: 6.80951452255249 \n " ,
" epoch: 0, update in batch 19600/???, loss: 7.1757965087890625 \n " ,
" epoch: 0, update in batch 19700/???, loss: 8.400952339172363 \n " ,
" epoch: 0, update in batch 19800/???, loss: 7.1904473304748535 \n " ,
" epoch: 0, update in batch 19900/???, loss: 6.339241981506348 \n " ,
" epoch: 0, update in batch 20000/???, loss: 7.078637599945068 \n " ,
" epoch: 0, update in batch 20100/???, loss: 5.015235900878906 \n " ,
" epoch: 0, update in batch 20200/???, loss: 6.763777732849121 \n " ,
" epoch: 0, update in batch 20300/???, loss: 6.543915748596191 \n " ,
" epoch: 0, update in batch 20400/???, loss: 6.027902603149414 \n " ,
" epoch: 0, update in batch 20500/???, loss: 6.710694789886475 \n " ,
" epoch: 0, update in batch 20600/???, loss: 6.800978660583496 \n " ,
" epoch: 0, update in batch 20700/???, loss: 6.371827125549316 \n " ,
" epoch: 0, update in batch 20800/???, loss: 5.952463626861572 \n " ,
" epoch: 0, update in batch 20900/???, loss: 6.317960739135742 \n " ,
" epoch: 0, update in batch 21000/???, loss: 7.178386688232422 \n " ,
" epoch: 0, update in batch 21100/???, loss: 6.887454986572266 \n " ,
" epoch: 0, update in batch 21200/???, loss: 6.468400478363037 \n " ,
" epoch: 0, update in batch 21300/???, loss: 7.8383684158325195 \n " ,
" epoch: 0, update in batch 21400/???, loss: 5.850740909576416 \n " ,
" epoch: 0, update in batch 21500/???, loss: 6.065464973449707 \n " ,
" epoch: 0, update in batch 21600/???, loss: 7.537625312805176 \n " ,
" epoch: 0, update in batch 21700/???, loss: 6.095994472503662 \n " ,
" epoch: 0, update in batch 21800/???, loss: 6.342766761779785 \n " ,
" epoch: 0, update in batch 21900/???, loss: 5.810301780700684 \n " ,
" epoch: 0, update in batch 22000/???, loss: 6.447206974029541 \n " ,
" epoch: 0, update in batch 22100/???, loss: 7.0662946701049805 \n " ,
" epoch: 0, update in batch 22200/???, loss: 6.535088539123535 \n " ,
" epoch: 0, update in batch 22300/???, loss: 7.017588138580322 \n " ,
" epoch: 0, update in batch 22400/???, loss: 5.067782402038574 \n " ,
" epoch: 0, update in batch 22500/???, loss: 6.493170738220215 \n " ,
" epoch: 0, update in batch 22600/???, loss: 5.642627716064453 \n " ,
" epoch: 0, update in batch 22700/???, loss: 7.200662136077881 \n " ,
" epoch: 0, update in batch 22800/???, loss: 6.137134075164795 \n " ,
" epoch: 0, update in batch 22900/???, loss: 6.367280006408691 \n " ,
" epoch: 0, update in batch 23000/???, loss: 7.458652496337891 \n " ,
" epoch: 0, update in batch 23100/???, loss: 6.515708923339844 \n " ,
" epoch: 0, update in batch 23200/???, loss: 7.526422023773193 \n " ,
" epoch: 0, update in batch 23300/???, loss: 6.653852939605713 \n " ,
" epoch: 0, update in batch 23400/???, loss: 6.737251281738281 \n " ,
" epoch: 0, update in batch 23500/???, loss: 6.493605136871338 \n " ,
" epoch: 0, update in batch 23600/???, loss: 6.132809638977051 \n " ,
" epoch: 0, update in batch 23700/???, loss: 6.406940460205078 \n " ,
" epoch: 0, update in batch 23800/???, loss: 6.84005880355835 \n " ,
" epoch: 0, update in batch 23900/???, loss: 6.830739498138428 \n " ,
" epoch: 0, update in batch 24000/???, loss: 5.862464427947998 \n " ,
" epoch: 0, update in batch 24100/???, loss: 6.382696628570557 \n " ,
" epoch: 0, update in batch 24200/???, loss: 5.722895622253418 \n " ,
" epoch: 0, update in batch 24300/???, loss: 6.697083473205566 \n " ,
" epoch: 0, update in batch 24400/???, loss: 6.56771183013916 \n " ,
" epoch: 0, update in batch 24500/???, loss: 7.566462516784668 \n " ,
" epoch: 0, update in batch 24600/???, loss: 6.217026710510254 \n " ,
" epoch: 0, update in batch 24700/???, loss: 7.164259433746338 \n " ,
" epoch: 0, update in batch 24800/???, loss: 6.460946083068848 \n " ,
" epoch: 0, update in batch 24900/???, loss: 6.333778381347656 \n " ,
" epoch: 0, update in batch 25000/???, loss: 6.522342681884766 \n " ,
" epoch: 0, update in batch 25100/???, loss: 6.270648002624512 \n " ,
" epoch: 0, update in batch 25200/???, loss: 7.118265628814697 \n " ,
" epoch: 0, update in batch 25300/???, loss: 5.8695197105407715 \n " ,
" epoch: 0, update in batch 25400/???, loss: 5.92995023727417 \n " ,
" epoch: 0, update in batch 25500/???, loss: 6.202570915222168 \n " ,
" epoch: 0, update in batch 25600/???, loss: 6.4268975257873535 \n " ,
" epoch: 0, update in batch 25700/???, loss: 6.710567474365234 \n " ,
" epoch: 0, update in batch 25800/???, loss: 6.130914688110352 \n " ,
" epoch: 0, update in batch 25900/???, loss: 6.082686424255371 \n " ,
" epoch: 0, update in batch 26000/???, loss: 6.111697196960449 \n " ,
" epoch: 0, update in batch 26100/???, loss: 7.320557594299316 \n " ,
" epoch: 0, update in batch 26200/???, loss: 6.227985858917236 \n " ,
" epoch: 0, update in batch 26300/???, loss: 6.204974174499512 \n " ,
" epoch: 0, update in batch 26400/???, loss: 6.658400058746338 \n " ,
" epoch: 0, update in batch 26500/???, loss: 5.911742687225342 \n " ,
" epoch: 0, update in batch 26600/???, loss: 6.891500949859619 \n " ,
" epoch: 0, update in batch 26700/???, loss: 5.763737201690674 \n " ,
" epoch: 0, update in batch 26800/???, loss: 5.757307529449463 \n " ,
" epoch: 0, update in batch 26900/???, loss: 6.076601982116699 \n " ,
" epoch: 0, update in batch 27000/???, loss: 6.193032264709473 \n " ,
" epoch: 0, update in batch 27100/???, loss: 6.120661735534668 \n " ,
" epoch: 0, update in batch 27200/???, loss: 6.5425519943237305 \n " ,
" epoch: 0, update in batch 27300/???, loss: 6.511394500732422 \n " ,
" epoch: 0, update in batch 27400/???, loss: 7.127263069152832 \n " ,
" epoch: 0, update in batch 27500/???, loss: 6.134243488311768 \n " ,
" epoch: 0, update in batch 27600/???, loss: 6.5747809410095215 \n " ,
" epoch: 0, update in batch 27700/???, loss: 6.351634979248047 \n " ,
" epoch: 0, update in batch 27800/???, loss: 5.589611530303955 \n " ,
" epoch: 0, update in batch 27900/???, loss: 6.916817665100098 \n " ,
" epoch: 0, update in batch 28000/???, loss: 5.711864948272705 \n " ,
" epoch: 0, update in batch 28100/???, loss: 6.921398162841797 \n " ,
" epoch: 0, update in batch 28200/???, loss: 6.785823822021484 \n " ,
" epoch: 0, update in batch 28300/???, loss: 6.007838249206543 \n " ,
" epoch: 0, update in batch 28400/???, loss: 6.338862419128418 \n " ,
" epoch: 0, update in batch 28500/???, loss: 6.9078168869018555 \n " ,
" epoch: 0, update in batch 28600/???, loss: 6.710842132568359 \n " ,
" epoch: 0, update in batch 28700/???, loss: 6.592329502105713 \n " ,
" epoch: 0, update in batch 28800/???, loss: 6.184128761291504 \n " ,
" epoch: 0, update in batch 28900/???, loss: 6.209361553192139 \n " ,
" epoch: 0, update in batch 29000/???, loss: 7.067984104156494 \n " ,
" epoch: 0, update in batch 29100/???, loss: 6.479236602783203 \n " ,
" epoch: 0, update in batch 29200/???, loss: 6.413198947906494 \n " ,
" epoch: 0, update in batch 29300/???, loss: 6.638579368591309 \n " ,
" epoch: 0, update in batch 29400/???, loss: 5.938233375549316 \n " ,
" epoch: 0, update in batch 29500/???, loss: 6.8490891456604 \n " ,
" epoch: 0, update in batch 29600/???, loss: 6.111110210418701 \n " ,
" epoch: 0, update in batch 29700/???, loss: 6.959462642669678 \n " ,
" epoch: 0, update in batch 29800/???, loss: 6.964720726013184 \n " ,
" epoch: 0, update in batch 29900/???, loss: 6.2007527351379395 \n " ,
" epoch: 0, update in batch 30000/???, loss: 6.803907871246338 \n " ,
" epoch: 0, update in batch 30100/???, loss: 5.665301322937012 \n " ,
" epoch: 0, update in batch 30200/???, loss: 6.913702487945557 \n " ,
" epoch: 0, update in batch 30300/???, loss: 6.824265956878662 \n " ,
" epoch: 0, update in batch 30400/???, loss: 6.131905555725098 \n " ,
" epoch: 0, update in batch 30500/???, loss: 5.799595832824707 \n " ,
" epoch: 0, update in batch 30600/???, loss: 6.846949100494385 \n " ,
" epoch: 0, update in batch 30700/???, loss: 6.481771945953369 \n " ,
" epoch: 0, update in batch 30800/???, loss: 6.5581254959106445 \n " ,
" epoch: 0, update in batch 30900/???, loss: 6.111696720123291 \n " ,
" epoch: 0, update in batch 31000/???, loss: 4.8547563552856445 \n " ,
" epoch: 0, update in batch 31100/???, loss: 6.5503740310668945 \n " ,
" epoch: 0, update in batch 31200/???, loss: 6.212404251098633 \n " ,
" epoch: 0, update in batch 31300/???, loss: 5.761624336242676 \n " ,
" epoch: 0, update in batch 31400/???, loss: 7.043508052825928 \n " ,
" epoch: 0, update in batch 31500/???, loss: 8.301980018615723 \n " ,
" epoch: 0, update in batch 31600/???, loss: 5.655745506286621 \n " ,
" epoch: 0, update in batch 31700/???, loss: 7.116888999938965 \n " ,
" epoch: 0, update in batch 31800/???, loss: 6.237078666687012 \n " ,
" epoch: 0, update in batch 31900/???, loss: 6.990937232971191 \n " ,
" epoch: 0, update in batch 32000/???, loss: 6.327075958251953 \n " ,
" epoch: 0, update in batch 32100/???, loss: 6.831456184387207 \n " ,
" epoch: 0, update in batch 32200/???, loss: 6.511493682861328 \n " ,
" epoch: 0, update in batch 32300/???, loss: 6.719797611236572 \n " ,
" epoch: 0, update in batch 32400/???, loss: 6.46258544921875 \n " ,
" epoch: 0, update in batch 32500/???, loss: 7.349535942077637 \n " ,
" epoch: 0, update in batch 32600/???, loss: 5.773186683654785 \n " ,
" epoch: 0, update in batch 32700/???, loss: 6.072037696838379 \n " ,
" epoch: 0, update in batch 32800/???, loss: 7.044579982757568 \n " ,
" epoch: 0, update in batch 32900/???, loss: 6.290024757385254 \n " ,
" epoch: 0, update in batch 33000/???, loss: 7.101686000823975 \n " ,
" epoch: 0, update in batch 33100/???, loss: 6.590539455413818 \n " ,
" epoch: 0, update in batch 33200/???, loss: 6.944089412689209 \n " ,
" epoch: 0, update in batch 33300/???, loss: 6.6709442138671875 \n " ,
" epoch: 0, update in batch 33400/???, loss: 7.119935035705566 \n " ,
" epoch: 0, update in batch 33500/???, loss: 6.845646858215332 \n " ,
" epoch: 0, update in batch 33600/???, loss: 6.941410064697266 \n " ,
" epoch: 0, update in batch 33700/???, loss: 6.341822624206543 \n " ,
" epoch: 0, update in batch 33800/???, loss: 6.98660945892334 \n " ,
" epoch: 0, update in batch 33900/???, loss: 7.544371128082275 \n " ,
" epoch: 0, update in batch 34000/???, loss: 6.844598293304443 \n " ,
" epoch: 0, update in batch 34100/???, loss: 6.958268642425537 \n " ,
" epoch: 0, update in batch 34200/???, loss: 6.6372880935668945 \n "
]
}
] ,
" source " : [
" train(train_dataset_front, model_front, 1, 64) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 22 ,
" id " : " f0512b6e-098d-4264-b52e-1d9ce88311f0 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
"def predict_probs(left_tokens, right_tokens):\n",
"    # Blend the forward and backward language models to predict the gap word:\n",
"    # model_front reads the left context, model_back reads the (already\n",
"    # reversed) right context, and the two softmax distributions are averaged\n",
"    # per vocabulary index.\n",
"    # Returns a line in the challenge format: 'word:prob ... :remaining_mass'.\n",
"    model_front.eval()\n",
"    # BUGFIX: model_back is also run below, so it must be in eval mode too.\n",
"    model_back.eval()\n",
"\n",
"    # BUGFIX: the membership test previously checked the bare name\n",
"    # 'key_to_index' instead of the vocabulary actually used for the lookup.\n",
"    vocab = train_dataset_front.key_to_index\n",
"    x_left = torch.tensor([[vocab[w] if w in vocab else vocab['<unk>'] for w in left_tokens]]).to(device)\n",
"    x_right = torch.tensor([[vocab[w] if w in vocab else vocab['<unk>'] for w in right_tokens]]).to(device)\n",
"    y_pred_left, (state_h_left, state_c_left) = model_front(x_left)\n",
"    y_pred_right, (state_h_right, state_c_right) = model_back(x_right)\n",
"\n",
"    # Logits for the word following each context, turned into probabilities.\n",
"    last_word_logits_left = y_pred_left[0][-1]\n",
"    last_word_logits_right = y_pred_right[0][-1]\n",
"    probs_left = torch.nn.functional.softmax(last_word_logits_left, dim=0).detach().cpu().numpy()\n",
"    probs_right = torch.nn.functional.softmax(last_word_logits_right, dim=0).detach().cpu().numpy()\n",
"\n",
"    probs = [np.mean(k) for k in zip(probs_left, probs_right)]\n",
"\n",
"    # Keep the 30 highest-probability vocabulary indices; the last vocabulary\n",
"    # index is never allowed to replace an incumbent (original selection rule).\n",
"    top_words = []\n",
"    for index in range(len(probs)):\n",
"        if len(top_words) < 30:\n",
"            top_words.append((probs[index], [index]))\n",
"        else:\n",
"            worst_word = None\n",
"            for word in top_words:\n",
"                if not worst_word:\n",
"                    worst_word = word\n",
"                elif word[0] < worst_word[0]:\n",
"                    worst_word = word\n",
"            if worst_word[0] < probs[index] and index != len(probs) - 1:\n",
"                top_words.remove(worst_word)\n",
"                top_words.append((probs[index], [index]))\n",
"\n",
"    # Emit 'word:prob' pairs, then ':<unassigned probability mass>'.\n",
"    prediction = ''\n",
"    sum_prob = 0.0\n",
"    for word in top_words:\n",
"        sum_prob += word[0]\n",
"        word_prob = word[0]\n",
"        word_text = index_to_key[word[1][0]]\n",
"        prediction += f'{word_text}:{word_prob} '\n",
"    prediction += f':{1 - sum_prob}'\n",
"\n",
"    return prediction"
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 23 ,
" id " : " 90938496-416a-4618-bf88-eddb3cc6e8da " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
" dev_data = pd.read_csv( ' dev-0/in.tsv.xz ' , sep= ' \\ t ' , on_bad_lines= ' skip ' , header=None, quoting=csv.QUOTE_NONE) \n " ,
" test_data = pd.read_csv( ' test-A/in.tsv.xz ' , sep= ' \\ t ' , on_bad_lines= ' skip ' , header=None, quoting=csv.QUOTE_NONE) "
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 24 ,
" id " : " a0d4bd42-0499-46af-ab0e-5dd23dacf47a " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
"with open('dev-0/out.tsv', 'w') as out_file:\n",
"    # Write one challenge-format prediction line per dev-0 row; columns 6 and\n",
"    # 7 hold the left and right context around the gap.\n",
"    for _, record in dev_data.iterrows():\n",
"        left_tokens = word_tokenize(clean_text(str(record[6])))\n",
"        right_tokens = word_tokenize(clean_text(str(record[7])))\n",
"        right_tokens.reverse()\n",
"        # Contexts shorter than 6 tokens fall back to the uniform guess ':1.0'.\n",
"        if len(left_tokens) < 6 or len(right_tokens) < 6:\n",
"            out_file.write(':1.0' + '\\n')\n",
"        else:\n",
"            out_file.write(predict_probs(left_tokens[-5:], right_tokens[-5:]) + '\\n')"
]
} ,
{
" cell_type " : " code " ,
" execution_count " : 25 ,
" id " : " 819be88f-6336-4853-8015-33c9f3392aa4 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [
"with open('test-A/out.tsv', 'w') as out_file:\n",
"    # Write one challenge-format prediction line per test-A row; columns 6 and\n",
"    # 7 hold the left and right context around the gap.\n",
"    for _, record in test_data.iterrows():\n",
"        left_tokens = word_tokenize(clean_text(str(record[6])))\n",
"        right_tokens = word_tokenize(clean_text(str(record[7])))\n",
"        right_tokens.reverse()\n",
"        # Contexts shorter than 6 tokens fall back to the uniform guess ':1.0'.\n",
"        if len(left_tokens) < 6 or len(right_tokens) < 6:\n",
"            out_file.write(':1.0' + '\\n')\n",
"        else:\n",
"            out_file.write(predict_probs(left_tokens[-5:], right_tokens[-5:]) + '\\n')"
]
} ,
{
" cell_type " : " code " ,
" execution_count " : null ,
" id " : " 8de69da7-999e-4653-8308-b9d7e6fe24c1 " ,
" metadata " : { } ,
" outputs " : [ ] ,
" source " : [ ]
}
] ,
" metadata " : {
" kernelspec " : {
" display_name " : " Python 3 (ipykernel) " ,
" language " : " python " ,
" name " : " python3 "
} ,
" language_info " : {
" codemirror_mode " : {
" name " : " ipython " ,
" version " : 3
} ,
" file_extension " : " .py " ,
" mimetype " : " text/x-python " ,
" name " : " python " ,
" nbconvert_exporter " : " python " ,
" pygments_lexer " : " ipython3 " ,
" version " : " 3.9.7 "
}
} ,
" nbformat " : 4 ,
" nbformat_minor " : 5
}