Added what's needed

SzamanFL 2020-12-14 22:53:02 +01:00
parent 1858d47a3f
commit 649ec662a1
17 changed files with 4618422 additions and 13 deletions

4
.gitignore vendored Normal file

@@ -0,0 +1,4 @@
*~
model*
vocab*
venv/*

38
README.md Normal file

@@ -0,0 +1,38 @@
"He Said She Said" classification challenge (2nd edition)
=========================================================
Give the probability that a text in Polish was written by a man.
This challenge is based on the "He Said She Said" corpus for Polish.
The corpus was created by grepping for gender-specific first-person
expressions (e.g. "zrobiłem/zrobiłam", "jestem zadowolony/zadowolona",
"będę robił/robiła") in the Common Crawl corpus. Such expressions were
then normalised into their masculine forms.
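For illustration only, a minimal sketch of this kind of normalisation; the word list, function name, and regexes here are assumptions, not the actual corpus-building pipeline:

    import re

    # Hypothetical feminine -> masculine pairs; the real pipeline used a
    # much larger, hand-picked inventory of first-person expressions.
    FEM_TO_MASC = {
        "zrobiłam": "zrobiłem",
        "zadowolona": "zadowolony",
        "robiła": "robił",
    }

    def normalise(text):
        # Replace whole-word feminine forms with masculine counterparts.
        for fem, masc in FEM_TO_MASC.items():
            text = re.sub(rf"\b{fem}\b", masc, text)
        return text

    print(normalise("jestem zadowolona, bo zrobiłam to"))
    # jestem zadowolony, bo zrobiłem to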
Classes
-------
* `0` — text written by a woman
* `1` — text written by a man
Directory structure
-------------------
* `README.md` — this file
* `config.txt` — configuration file
* `train/` — directory with training data
* `train/train.tsv.gz` — train set (gzipped); the class is given in the first column,
  a text fragment in the second one (see the loading sketch after this list)
* `train/meta.tsv.gz` — metadata (do not use during training)
* `dev-0/` — directory with dev (test) data
* `dev-0/in.tsv` — input data for the dev set (text fragments)
* `dev-0/expected.tsv` — expected (reference) data for the dev set
* `dev-0/meta.tsv` — metadata (not used during testing)
* `dev-1/` — directory with extra dev (test) data
* `dev-1/in.tsv` — input data for the extra dev set (text fragments)
* `dev-1/expected.tsv` — expected (reference) data for the extra dev set
* `dev-1/meta.tsv` — metadata (not used during testing)
* `test-A/` — directory with test data
* `test-A/in.tsv` — input data for the test set (text fragments)
* `test-A/expected.tsv` — expected (reference) data for the test set (hidden)
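A minimal loading sketch for this layout, assuming only the paths listed above (the counting logic is illustrative):

    import gzip

    # Count class frequencies in the gzipped train set:
    # label in the first column, text fragment in the second.
    counts = {"0": 0, "1": 0}
    with gzip.open("train/train.tsv.gz", "rt", encoding="utf-8") as f:
        for line in f:
            label, text = line.rstrip("\n").split("\t", 1)
            counts[label] += 1
    print(counts)

The `out.tsv` files added in this commit follow the usual convention for this layout: one prediction per line, aligned with the lines of the corresponding `in.tsv`.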

1
config.txt Normal file

@@ -0,0 +1 @@
--metric Likelihood --metric Accuracy --metric {Likelihood:N<Likelihood>,Accuracy:N<Accuracy>}P<2>{f<in[2]:for-humans>N<+H>,f<in[3]:contaminated>N<+C>,f<in[3]:not-contaminated>N<-C>} --precision 5
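The `Likelihood` metric rewards well-calibrated probabilities rather than hard labels. A hedged sketch of the usual definition (GEval's exact handling of edge cases, e.g. clipping probabilities away from 0 and 1, may differ): the geometric mean, over all examples, of the probability the submission assigned to the true class.

    import math

    def likelihood(expected, predicted):
        # expected: true labels (0/1); predicted: P(class == 1) per line.
        # Geometric mean of the probability given to the correct class;
        # real implementations clip p to avoid log(0).
        log_sum = 0.0
        for y, p in zip(expected, predicted):
            log_sum += math.log(p if y == 1 else 1.0 - p)
        return math.exp(log_sum / len(expected))

    print(likelihood([1, 0], [0.9, 0.2]))  # ~0.8485

This is presumably why the prediction script below switches from thresholded 0/1 output to raw probabilities.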

137314
dev-0/expected.tsv Normal file

File diff suppressed because it is too large

137314
dev-0/in.tsv Normal file

File diff suppressed because it is too large

137314
dev-0/meta.tsv Normal file

File diff suppressed because it is too large

543
dev-0/out.tsv Normal file

@@ -0,0 +1,543 @@
1
0
1
1
1
1
1
0
1
1
0
1
0
0
1
0
1
0
1
1
0
1
0
1
1
0
1
1
1
0
1
1
1
1
0
0
1
1
0
1
0
1
0
0
0
1
1
1
0
0
0
0
0
1
1
1
0
0
1
1
1
0
0
1
1
0
1
1
1
1
0
1
0
1
1
1
1
0
1
1
1
1
1
1
1
0
0
1
1
1
1
0
1
0
1
1
1
1
1
0
1
1
1
0
1
1
1
1
0
0
1
0
1
1
0
1
1
1
1
1
0
0
1
1
0
1
1
1
1
0
1
1
1
1
1
1
1
1
0
0
1
1
0
0
1
1
1
1
0
1
1
0
0
1
0
1
1
0
1
1
1
1
1
1
1
0
0
1
1
1
0
1
1
0
1
0
0
0
1
0
1
0
1
1
0
0
1
1
0
1
0
0
0
1
1
1
0
1
0
1
0
0
1
1
1
0
1
0
1
1
1
1
1
1
1
1
1
0
1
1
1
1
0
0
0
0
0
0
1
1
1
1
1
1
0
0
1
1
1
1
0
0
0
1
1
1
1
1
1
1
1
1
1
0
0
1
1
1
0
0
0
1
1
0
1
1
1
1
1
1
1
1
1
1
1
1
0
0
1
1
0
1
0
0
1
1
1
0
1
0
1
1
1
1
1
1
0
0
1
1
1
1
1
1
1
0
1
1
1
1
1
0
0
1
0
1
1
1
1
0
1
1
1
0
1
0
0
1
0
1
0
1
1
1
0
0
1
1
1
1
1
1
1
0
1
1
0
1
1
1
0
1
0
1
1
1
1
0
0
1
1
1
0
0
0
0
0
0
1
1
1
1
1
1
0
1
1
1
0
1
1
1
1
0
1
0
1
0
1
1
1
0
0
1
1
1
1
0
1
0
1
1
1
1
0
1
0
1
1
0
1
1
1
1
0
1
1
1
0
0
1
0
1
1
1
1
1
1
1
1
0
0
1
1
1
1
1
0
1
0
1
1
1
0
1
1
1
1
0
1
0
1
1
1
1
1
1
1
1
0
1
1
0
1
1
0
0
0
1
1
0
1
0
1
1
1
0
0
1
0
1
1
1
1
1
1
1
1
1
1
1
0
0
1
1
1
1
1
0
0
1
1
0
1
1
0
1
0
0
0
1
1
1
1
0
1
1
0
0
1
0
1
0
1
1
1
1
1
0
1
1
0
1
1
0
1
1
1
1
1
1
0
0
0

156606
dev-1/expected.tsv Normal file

File diff suppressed because it is too large

156606
dev-1/in.tsv Normal file

File diff suppressed because it is too large

156606
dev-1/meta.tsv Normal file

File diff suppressed because it is too large

View File

@@ -44,10 +44,7 @@ def main():
             one_hot = create_one_hot(vocab, line)
             one_hot_tensor = torch.tensor(one_hot, dtype=torch.float64).T
             y_predicted = model(one_hot_tensor.float())
-            if y_predicted > 0.5:
-                out.write('1\n')
-            elif y_predicted <= 0.5:
-                out.write('0\n')
+            out.write(f'{y_predicted}\n')
             if counter % 100 == 0:
                 print(f"{counter}")
             counter +=1

View File

@@ -5,6 +5,7 @@ import sys
 import pickle
 import torch
 from Network import Network
+import random
 
 def clear_tokens(text):
     text = text.rstrip('\n').lower()
@@ -15,33 +16,40 @@ def clear_tokens(text):
 
 def create_vocab(texts_and_scores):
     vocab = {}
+    celling = 100000
     for text in texts_and_scores.keys():
         tokens = clear_tokens(text)
         for token in tokens:
             vocab[token] = 0
+            if len(vocab.keys()) > celling:
+                print(f"Short vocab length : {len(vocab.keys())}")
+                return vocab
     return vocab
 
 def create_one_hot(vocab, text):
     one_hot = dict(vocab)
     tokens = clear_tokens(text)
     for token in tokens:
-        one_hot[token] += 1
+        try:
+            one_hot[token] += 1
+        except KeyError:
+            pass
     return [[i] for i in one_hot.values()]
 
 def main():
-    if len(sys.argv) == 4 or len(sys.argv) == 3:
+    if len(sys.argv) == 4 or len(sys.argv) == 3 or len(sys.argv) == 5:
         pass
     else:
         print('Not sufficient number of args')
         return
+    print("Reading data")
     texts_and_scores = {}
     with open(sys.argv[1], 'r') as file_in, open(sys.argv[2], 'r') as file_exp:
         for line_in, line_exp in zip(file_in, file_exp):
             texts_and_scores[line_in] = int(line_exp.rstrip('\n'))
     print(f"Data read")
-    if len(sys.argv) == 4:
+    if len(sys.argv) == 5:
         print(f"Loading vocab from {sys.argv[3]}")
         with open(sys.argv[3], 'rb') as f:
             vocab = pickle.load(f)
@@ -57,6 +65,11 @@ def main():
     input_size = len(vocab)
     model = Network(input_size)
+    rand_num = random.uniform(1,500)
+    if len(sys.argv) == 5:
+        model.load_state_dict(torch.load(sys.argv[4]))
+        print(f"Starting from checkpoint {sys.argv[4]}")
+
     lr = 0.1
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
     criterion = torch.nn.BCELoss()
@@ -64,8 +77,9 @@ def main():
     print("Starting training")
     model.train()
     counter = 0
+    random_set = sorted(texts_and_scores.items(), key=lambda x: random.random())
     try:
-        for text, score in texts_and_scores.items():
+        for text, score in random_set:
             one_hot = create_one_hot(vocab, text)
             one_hot_tensor = torch.tensor(one_hot, dtype=torch.float64).T
             y = torch.tensor([[score]])
@@ -78,10 +92,10 @@ def main():
             if counter % 50 == 0:
                 print(f"{counter} : {loss}")
             if counter % 100 == 0:
-                print(f"Saving checkpoint model-{counter}-{lr}.ckpt")
-                torch.save(model.state_dict(), f"model-{counter}-{lr}.ckpt")
+                print(f"Saving checkpoint model-{counter}-{lr}-{rand_num}.ckpt")
+                torch.save(model.state_dict(), f"model-{counter}-{lr}-{rand_num}.ckpt")
             counter += 1
     except KeyboardInterrupt:
-        torch.save(model.state_dict(), f"model-interrupted.ckpt")
-    torch.save(model.state_dict(), f"model-final.ckpt")
+        torch.save(model.state_dict(), f"model-interrupted-{lr}-{rand_num}.ckpt")
+    torch.save(model.state_dict(), f"model-final-{lr}-{rand_num}.ckpt")
 
 main()
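`Network` is imported from a module not shown in this commit. Given the bag-of-words input vector, `BCELoss`, and a scalar output compared against 0.5 in the prediction script, a compatible minimal definition would be a logistic-regression-style module; the following is an assumption about that file, not its actual contents:

    import torch

    class Network(torch.nn.Module):
        # One linear layer plus sigmoid: maps a bag-of-words vector of
        # size input_size to P(text written by a man).
        def __init__(self, input_size):
            super().__init__()
            self.linear = torch.nn.Linear(input_size, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            return self.sigmoid(self.linear(x))

Side note on the shuffle: `sorted(texts_and_scores.items(), key=lambda x: random.random())` does produce a random order, but building a list and calling `random.shuffle` on it is the idiomatic equivalent and avoids an unnecessary O(n log n) sort.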

134618
test-A/in.tsv Normal file

File diff suppressed because it is too large

10
test-A/out.tsv Normal file

@@ -0,0 +1,10 @@
0.7156
0.0171
0.9889
0.0059
0.3557
0.0706
0.6441
0.1192
0.9071
0.0488

3601424
train/expected.tsv Normal file

File diff suppressed because it is too large

BIN
train/in.tsv.xz Normal file

Binary file not shown.

BIN
train/meta.tsv.gz Normal file

Binary file not shown.