Vowpal wabbit logistic b=20
This commit is contained in:
parent
d043e30286
commit
16f61dceee
35
Makefile
Normal file
35
Makefile
Normal file
@ -0,0 +1,35 @@
|
||||
SHELL=/bin/bash
|
||||
|
||||
.SECONDARY:
|
||||
|
||||
# $< - piersza zaleznosc
|
||||
# $@ - cel
|
||||
|
||||
predict: test-A/out.tsv dev-0/out.tsv
|
||||
|
||||
test-A/out.tsv: test-A/out.num.tsv num2label.py
|
||||
./num2label.py < $< > $@
|
||||
|
||||
test-A/out.num.tsv: test-A/in.vw.txt train/train.vw.model
|
||||
vw -t $< -i train/train.vw.model --loss_function logistic --probabilities -p /dev/stdout > $@
|
||||
|
||||
test-A/in.vw.txt: test-A/in.tsv
|
||||
./tsv2vw.py < $< > $@
|
||||
|
||||
dev-0/out.tsv: dev-0/out.num.tsv num2label.py
|
||||
./num2label.py < $< > $@
|
||||
|
||||
dev-0/out.num.tsv: dev-0/in.vw.txt train/train.vw.model
|
||||
vw -t $< -i train/train.vw.model --loss_function logistic --probabilities -p /dev/stdout > $@
|
||||
|
||||
dev-0/in.vw.txt: dev-0/in.tsv
|
||||
./tsv2vw.py < $< > $@
|
||||
|
||||
train/train.vw.model: train/train.vw.txt
|
||||
vw $< -f $@ --passes 50 -b 20 --random_seed 2020 --oaa 8 --loss_function logistic --probabilities -k --cache_file vw-meta-cache
|
||||
|
||||
train/train.vw.txt : train/train.tsv.gz tsv2vw.py
|
||||
zcat $< | ./tsv2vw.py > $@
|
||||
|
||||
clean:
|
||||
rm -rf dev-0/out.tsv dev-0/in.vw.txt train/train.vw.model train/train.vw.txt dev-0/out.num.tsv test-A/out.tsv test-A/out.num.tsv test-A/in.vw.txt
|
5453
dev-0/out.tsv
Normal file
5453
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
13
num2label.py
Executable file
13
num2label.py
Executable file
@ -0,0 +1,13 @@
|
||||
#!/usr/bin/python3
|
||||
import sys
|
||||
|
||||
out_dict = {'pilka-nozna':1, 'siatkowka':2, 'sporty-walki':3, 'pilka-reczna':4, 'koszykowka':5, 'tenis':6, 'moto':7, 'zimowe':8}
|
||||
out_inv = {v: k for k, v in out_dict.items()}
|
||||
|
||||
|
||||
for line in sys.stdin:
|
||||
probs = line.split(' ')
|
||||
for prob in probs:
|
||||
fields = prob.split(':')
|
||||
first = out_inv[int(fields[0])]
|
||||
print(first + ":" + fields[1], end=" ")
|
5448
test-A/out.tsv
Normal file
5448
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
33
tsv2vw.py
Executable file
33
tsv2vw.py
Executable file
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/python3
|
||||
|
||||
import sys
|
||||
|
||||
out_dict = {'pilka-nozna':1, 'siatkowka':2, 'sporty-walki':3, 'pilka-reczna':4, 'koszykowka':5, 'tenis':6, 'moto':7, 'zimowe':8}
|
||||
counter = 1
|
||||
|
||||
def process_item(out,inp):
|
||||
if out is None:
|
||||
out = ''
|
||||
else:
|
||||
global counter
|
||||
if out not in out_dict:
|
||||
out = 8
|
||||
else:
|
||||
out = out_dict[out]
|
||||
if out is None:
|
||||
out = ''
|
||||
if out == '0':
|
||||
out = "-1"
|
||||
if out == 9:
|
||||
out = 8
|
||||
|
||||
inp = inp.replace(':',' ')
|
||||
return str(out) + ' | ' + inp
|
||||
|
||||
for line in sys.stdin:
|
||||
line = line.rstrip()
|
||||
fields = line.split('\t')
|
||||
if len(fields) == 2:
|
||||
print(process_item(fields[0], fields[1]))
|
||||
else:
|
||||
print(process_item(None, fields[0]))
|
Loading…
Reference in New Issue
Block a user