polish roberta
This commit is contained in:
parent
bbf7c5f350
commit
bee7eaa312
10570
dev-0/out.tsv
Normal file
10570
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
4
polish_roberta_finetuned/2_bpe_encode.sh
Normal file
4
polish_roberta_finetuned/2_bpe_encode.sh
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
MODEL=sentencepiece.bpe.model
|
||||||
|
SPM=~/vcpkg/buildtrees/sentencepiece/x64-linux-dbg/src/spm_encode
|
||||||
|
${SPM} --model=${MODEL} <data/train.input0 >data/train.input0.spm
|
||||||
|
${SPM} --model=${MODEL} <data/dev.input0 >data/dev.input0.spm
|
15
polish_roberta_finetuned/4_preprocess_data.sh
Normal file
15
polish_roberta_finetuned/4_preprocess_data.sh
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
|
||||||
|
fairseq-preprocess \
|
||||||
|
--only-source \
|
||||||
|
--trainpref "data/train.input0.spm" \
|
||||||
|
--validpref "data/dev.input0.spm" \
|
||||||
|
--destdir "data-bin/input0" \
|
||||||
|
--workers 8 \
|
||||||
|
--srcdict dict.txt
|
||||||
|
|
||||||
|
fairseq-preprocess \
|
||||||
|
--only-source \
|
||||||
|
--trainpref "data/train.label" \
|
||||||
|
--validpref "data/dev.label" \
|
||||||
|
--destdir "data-bin/label" \
|
||||||
|
--workers 8
|
31
polish_roberta_finetuned/5_run_training.sh
Normal file
31
polish_roberta_finetuned/5_run_training.sh
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
TOTAL_NUM_UPDATES=500000 # 10 epochs through IMDB for bsz 32
|
||||||
|
WARMUP_UPDATES=1000 # 6 percent of the number of updates
|
||||||
|
LR=1e-05 # Peak LR for polynomial LR scheduler.
|
||||||
|
HEAD_NAME=imdb_head # Custom name for the classification head.
|
||||||
|
NUM_CLASSES=2 # Number of classes for the classification task.
|
||||||
|
MAX_SENTENCES=128 # Batch size.
|
||||||
|
ROBERTA_PATH=model.pt
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/ \
|
||||||
|
--restore-file $ROBERTA_PATH \
|
||||||
|
--max-positions 512 \
|
||||||
|
--batch-size $MAX_SENTENCES \
|
||||||
|
--max-tokens 4400 \
|
||||||
|
--task sentence_prediction \
|
||||||
|
--reset-optimizer --reset-dataloader --reset-meters \
|
||||||
|
--required-batch-size-multiple 1 \
|
||||||
|
--init-token 0 --separator-token 2 \
|
||||||
|
--arch roberta_base \
|
||||||
|
--criterion sentence_prediction \
|
||||||
|
--classification-head-name $HEAD_NAME \
|
||||||
|
--num-classes $NUM_CLASSES \
|
||||||
|
--dropout 0.1 --attention-dropout 0.1 \
|
||||||
|
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
|
||||||
|
--clip-norm 0.0 \
|
||||||
|
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
|
||||||
|
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
|
||||||
|
--max-epoch 10 \
|
||||||
|
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
||||||
|
--shorten-method "truncate" \
|
||||||
|
--find-unused-parameters \
|
||||||
|
--update-freq 4
|
28
polish_roberta_finetuned/6_predict.py
Normal file
28
polish_roberta_finetuned/6_predict.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from fairseq.models.roberta import RobertaModel
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
roberta = RobertaModel.from_pretrained(
|
||||||
|
model_name_or_path='checkpoints',
|
||||||
|
data_name_or_path='checkpoints',
|
||||||
|
sentencepiece_vocab="sentencepiece.bpe.model",
|
||||||
|
sentencepiece_model="sentencepiece.bpe.model",
|
||||||
|
checkpoint_file='checkpoint_best.pt',
|
||||||
|
#load_checkpoint_heads=True,
|
||||||
|
#data_name_or_path='data-bin',
|
||||||
|
bpe = 'sentencepiece',
|
||||||
|
)
|
||||||
|
|
||||||
|
roberta.cuda()
|
||||||
|
roberta.eval()
|
||||||
|
|
||||||
|
def predict(fpath_in, fpath_out):
|
||||||
|
f_in = open(fpath_in,newline='\n')
|
||||||
|
f_out = open(fpath_out,'w')
|
||||||
|
for i in tqdm(f_in,total=137314):
|
||||||
|
tokens = roberta.encode(i.rstrip('\n'))[:512]
|
||||||
|
pred = str(np.exp(roberta.predict('imdb_head', tokens)[0][1].item()))
|
||||||
|
f_out.write(pred + '\n')
|
||||||
|
f_out.close()
|
||||||
|
|
||||||
|
predict('../dev-0/in.tsv', '../dev-0/out.tsv')
|
||||||
|
predict('../test-A/in.tsv', '../test-A/out.tsv')
|
11834
test-A/out.tsv
Normal file
11834
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user