Added what's needed
This commit is contained in:
parent
1858d47a3f
commit
649ec662a1
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
*~
model*
vocab*
venv/*
38
README.md
Normal file
@@ -0,0 +1,38 @@
"He Said She Said" classification challenge (2nd edition)
=========================================================

Give the probability that a text in Polish was written by a man.

This challenge is based on the "He Said She Said" corpus for Polish.
The corpus was created by grepping gender-specific first-person
expressions (e.g. "zrobiłem/zrobiłam", "jestem zadowolony/zadowolona",
"będę robił/robiła") in the Common Crawl corpus. Such expressions were
normalised here into masculine forms.

Classes
-------

* `0` — text written by a woman
* `1` — text written by a man

Directory structure
-------------------

* `README.md` — this file
* `config.txt` — configuration file
* `train/` — directory with training data
* `train/train.tsv.gz` — train set (gzipped); the class is given in the first column,
  a text fragment in the second one
* `train/meta.tsv.gz` — metadata (do not use during training)
* `dev-0/` — directory with dev (test) data
* `dev-0/in.tsv` — input data for the dev set (text fragments)
* `dev-0/expected.tsv` — expected (reference) data for the dev set
* `dev-0/meta.tsv` — metadata (not used during testing)
* `dev-1/` — directory with extra dev (test) data
* `dev-1/in.tsv` — input data for the extra dev set (text fragments)
* `dev-1/expected.tsv` — expected (reference) data for the extra dev set
* `dev-1/meta.tsv` — metadata (not used during testing)
* `test-A/` — directory with test data
* `test-A/in.tsv` — input data for the test set (text fragments)
* `test-A/expected.tsv` — expected (reference) data for the test set (hidden)
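A minimal sketch of reading the train set laid out above, assuming Python 3, a header-less gzipped TSV, and a hypothetical helper name read_train:

    import gzip
    import itertools

    # Minimal sketch: stream train/train.tsv.gz, where column 1 is the
    # class (0 = woman, 1 = man) and column 2 is the text fragment.
    # Assumes no header row; read_train is a hypothetical helper.
    def read_train(path="train/train.tsv.gz"):
        with gzip.open(path, mode="rt", encoding="utf-8") as f:
            for line in f:
                label, text = line.rstrip("\n").split("\t", 1)
                yield int(label), text

    # Example: rough class balance over the first 10000 examples.
    labels = [y for y, _ in itertools.islice(read_train(), 10000)]
    print(sum(labels) / len(labels))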
1
config.txt
Normal file
@@ -0,0 +1 @@
--metric Likelihood --metric Accuracy --metric {Likelihood:N<Likelihood>,Accuracy:N<Accuracy>}P<2>{f<in[2]:for-humans>N<+H>,f<in[3]:contaminated>N<+C>,f<in[3]:not-contaminated>N<-C>} --precision 5
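The config requests the Likelihood and Accuracy metrics. Assuming Likelihood here is, as in GEval, the geometric mean of the probabilities a submission assigns to the true classes (p for class 1, 1 − p for class 0), a sketch of how it could be computed:

    import math

    # Hedged sketch of a Likelihood-style metric: geometric mean of the
    # probabilities assigned to the true classes. probs are the submitted
    # P(class = 1) values; golds are the 0/1 reference labels.
    def likelihood(probs, golds):
        logs = [math.log(p if y == 1 else 1.0 - p) for p, y in zip(probs, golds)]
        return math.exp(sum(logs) / len(logs))

    print(likelihood([0.9, 0.2, 0.7], [1, 0, 1]))  # ≈ 0.80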
137314
dev-0/expected.tsv
Normal file
File diff suppressed because it is too large
137314
dev-0/in.tsv
Normal file
File diff suppressed because it is too large
137314
dev-0/meta.tsv
Normal file
File diff suppressed because it is too large
543
dev-0/out.tsv
Normal file
@@ -0,0 +1,543 @@
(543 lines, a single 0 or 1 prediction per line)
156606
dev-1/expected.tsv
Normal file
File diff suppressed because it is too large
156606
dev-1/in.tsv
Normal file
File diff suppressed because it is too large
156606
dev-1/meta.tsv
Normal file
File diff suppressed because it is too large
@@ -44,10 +44,7 @@ def main():
             one_hot = create_one_hot(vocab, line)
             one_hot_tensor = torch.tensor(one_hot, dtype=torch.float64).T
             y_predicted = model(one_hot_tensor.float())
-            if y_predicted > 0.5:
-                out.write('1\n')
-            elif y_predicted <= 0.5:
-                out.write('0\n')
+            out.write(f'{y_predicted}\n')
             if counter % 100 == 0:
                 print(f"{counter}")
             counter += 1
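The change above replaces the hard 0/1 thresholding with the raw model score, which is what a probability-based metric such as Likelihood can reward; a small sketch of the two behaviours side by side, with a hypothetical score:

    # Sketch: the same prediction written as a hard label vs. a raw probability.
    # Under a Likelihood-style metric, a confident correct score (0.99) earns
    # more than one near the threshold (0.51); hard 0/1 output collapses both
    # to the same value.
    y_predicted = 0.51                              # hypothetical sigmoid output
    hard = '1\n' if y_predicted > 0.5 else '0\n'    # old behaviour
    soft = f'{y_predicted}\n'                       # new behaviour
    print(hard, soft)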
32
src/train.py
@@ -5,6 +5,7 @@ import sys
 import pickle
 import torch
 from Network import Network
+import random
 
 def clear_tokens(text):
     text = text.rstrip('\n').lower()
@@ -15,33 +16,40 @@ def clear_tokens(text):
 
 def create_vocab(texts_and_scores):
     vocab = {}
+    celling = 100000
     for text in texts_and_scores.keys():
         tokens = clear_tokens(text)
         for token in tokens:
             vocab[token] = 0
+            if len(vocab.keys()) > celling:
+                print(f"Short vocab length : {len(vocab.keys())}")
+                return vocab
     return vocab
 
 def create_one_hot(vocab, text):
     one_hot = dict(vocab)
     tokens = clear_tokens(text)
     for token in tokens:
-        one_hot[token] += 1
+        try:
+            one_hot[token] += 1
+        except KeyError:
+            pass
     return [[i] for i in one_hot.values()]
 
 def main():
-    if len(sys.argv) == 4 or len(sys.argv) == 3:
+    if len(sys.argv) == 4 or len(sys.argv) == 3 or len(sys.argv) == 5:
         pass
     else:
         print('Not sufficient number of args')
         return
 
     print("Reading data")
     texts_and_scores = {}
     with open(sys.argv[1], 'r') as file_in, open(sys.argv[2], 'r') as file_exp:
         for line_in, line_exp in zip(file_in, file_exp):
             texts_and_scores[line_in] = int(line_exp.rstrip('\n'))
     print(f"Data read")
 
-    if len(sys.argv) == 4:
+    if len(sys.argv) == 5:
         print(f"Loading vocab from {sys.argv[3]}")
         with open(sys.argv[3], 'rb') as f:
             vocab = pickle.load(f)
@@ -57,6 +65,11 @@ def main():
     input_size = len(vocab)
 
     model = Network(input_size)
+    rand_num = random.uniform(1,500)
+    if len(sys.argv) == 5:
+        model.load_state_dict(torch.load(sys.argv[4]))
+        print(f"Starting from checkpoint {sys.argv[4]}")
 
     lr = 0.1
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
     criterion = torch.nn.BCELoss()
@@ -64,8 +77,9 @@ def main():
     print("Starting training")
     model.train()
     counter = 0
+    random_set = sorted(texts_and_scores.items(), key=lambda x: random.random())
     try:
-        for text, score in texts_and_scores.items():
+        for text, score in random_set:
             one_hot = create_one_hot(vocab, text)
             one_hot_tensor = torch.tensor(one_hot, dtype=torch.float64).T
             y = torch.tensor([[score]])
@@ -78,10 +92,10 @@ def main():
             if counter % 50 == 0:
                 print(f"{counter} : {loss}")
             if counter % 100 == 0:
-                print(f"Saving checkpoint model-{counter}-{lr}.ckpt")
-                torch.save(model.state_dict(), f"model-{counter}-{lr}.ckpt")
+                print(f"Saving checkpoint model-{counter}-{lr}-{rand_num}.ckpt")
+                torch.save(model.state_dict(), f"model-{counter}-{lr}-{rand_num}.ckpt")
             counter += 1
     except KeyboardInterrupt:
-        torch.save(model.state_dict(), f"model-interrupted.ckpt")
-    torch.save(model.state_dict(), f"model-final.ckpt")
+        torch.save(model.state_dict(), f"model-interrupted-{lr}-{rand_num}.ckpt")
+    torch.save(model.state_dict(), f"model-final-{lr}-{rand_num}.ckpt")
 main()
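One note on the shuffling added above: sorted(..., key=lambda x: random.random()) works, but random.shuffle on a materialised list is the idiomatic O(n) equivalent; a minimal sketch, with hypothetical data:

    import random

    # Idiomatic equivalent of sorted(items, key=lambda x: random.random()):
    # materialise the (text, score) pairs and shuffle them in place.
    texts_and_scores = {"przyklad pierwszy": 1, "przyklad drugi": 0}  # hypothetical
    random_set = list(texts_and_scores.items())
    random.shuffle(random_set)
    for text, score in random_set:
        pass  # training loop as in src/train.py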
134618
test-A/in.tsv
Normal file
File diff suppressed because it is too large
10
test-A/out.tsv
Normal file
@@ -0,0 +1,10 @@
0.7156
0.0171
0.9889
0.0059
0.3557
0.0706
0.6441
0.1192
0.9071
0.0488
3601424
train/expected.tsv
Normal file
File diff suppressed because it is too large
BIN
train/in.tsv.xz
Normal file
Binary file not shown.
BIN
train/meta.tsv.gz
Normal file
Binary file not shown.