Neural bigrams v2
This commit is contained in:
parent
37775b359b
commit
762dc9dff7
21010
dev-0/out.tsv
21010
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
17
run.py
17
run.py
@ -22,9 +22,11 @@ import pandas as pd
|
||||
def read_train_data(file):
|
||||
data = pd.read_csv(file, sep="\t", error_bad_lines=False, index_col=0, header=None)
|
||||
with open('input_train.txt', 'w') as f:
|
||||
for index, row in data[:500000].iterrows():
|
||||
first_part = str(row[6]).replace('\\n', '')
|
||||
sec_part = str(row[7]).replace('\\n', '')
|
||||
for index, row in data.iterrows():
|
||||
first_part = str(row[6]).replace('\\n', " ").replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")
|
||||
first_part = first_part.replace("'", "")
|
||||
sec_part = str(row[7]).replace('\\n', " ").replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")
|
||||
sec_part = sec_part.replace("'", "")
|
||||
if first_part != 'nan':
|
||||
f.write(first_part + '\n')
|
||||
if sec_part != 'nan':
|
||||
@ -53,7 +55,7 @@ from itertools import islice
|
||||
for line in fh:
|
||||
yield get_words_from_line(line)
|
||||
|
||||
vocab_size = 30000
|
||||
vocab_size = 50000
|
||||
|
||||
vocab = build_vocab_from_iterator(
|
||||
get_word_lines_from_file('input_train.txt'),
|
||||
@ -113,6 +115,10 @@ class Bigrams(IterableDataset):
|
||||
|
||||
train_dataset = Bigrams('input_train.shuf.txt', vocab_size)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
!nvidia-smi
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
@ -179,12 +185,13 @@ def predict_word(word):
|
||||
to_return += f':{total}'
|
||||
return to_return
|
||||
|
||||
!pip install nltk
|
||||
!pip install nltk pandas
|
||||
|
||||
from nltk.tokenize import RegexpTokenizer
|
||||
tokenizer = RegexpTokenizer(r"\w+")
|
||||
|
||||
import csv
|
||||
import pandas as pd
|
||||
|
||||
def generate_outputs(input_file, output_file):
|
||||
data = pd.read_csv(input_file, sep='\t', error_bad_lines=False, index_col=0, header=None, quoting=csv.QUOTE_NONE)
|
||||
|
14594
test-A/out.tsv
14594
test-A/out.tsv
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user