roberta_base_finetune_existing
This commit is contained in:
parent
0a1b9e7815
commit
cde1e52f78
71190
dev-0/out.tsv
71190
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
16
join.py
Normal file
16
join.py
Normal file
@ -0,0 +1,16 @@
|
||||
|
||||
|
||||
def process(f_in_path, f_exp_path, f_whole_path):
|
||||
with open(f_in_path) as f_in, open(f_exp_path) as f_exp, open(f_whole_path,'w') as f_whole:
|
||||
for line_in, line_exp in zip(f_in, f_exp):
|
||||
_, _, left, right = line_in.rstrip('\n').split('\t')
|
||||
middle = line_exp.rstrip('\n')
|
||||
text = left + ' ' + middle + ' ' + right
|
||||
text = text.replace('\\n', '\n') + '\n\n'
|
||||
f_whole.write(text)
|
||||
|
||||
|
||||
process('train/in.tsv', 'train/expected.tsv', 'train/wiki.train.raw')
|
||||
process('dev-0/in.tsv', 'dev-0/expected.tsv', 'dev-0/wiki.valid.raw')
|
||||
process('dev-0/in.tsv', 'dev-0/expected.tsv', 'dev-0/wiki.test.raw')
|
||||
|
1
roberta_base_finetune_existing/0_get_datasets.sh
Normal file
1
roberta_base_finetune_existing/0_get_datasets.sh
Normal file
@ -0,0 +1 @@
|
||||
(cd .. ; python join.py ; cp train/wiki.train.raw roberta_base_finetune/wikitext-103-raw/ ; cp dev-0/wiki.{valid,test}.raw roberta_base_finetune/wikitext-103-raw/ )
|
12
roberta_base_finetune_existing/1_encode.sh
Normal file
12
roberta_base_finetune_existing/1_encode.sh
Normal file
@ -0,0 +1,12 @@
|
||||
mkdir -p gpt2_bpe
|
||||
wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
|
||||
wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
|
||||
for SPLIT in train valid test; do \
|
||||
python -m multiprocessing_bpe_encoder \
|
||||
--encoder-json gpt2_bpe/encoder.json \
|
||||
--vocab-bpe gpt2_bpe/vocab.bpe \
|
||||
--inputs wikitext-103-raw/wiki.${SPLIT}.raw \
|
||||
--outputs wikitext-103-raw/wiki.${SPLIT}.bpe \
|
||||
--keep-empty \
|
||||
--workers 10; \
|
||||
done
|
9
roberta_base_finetune_existing/2_binarize.sh
Normal file
9
roberta_base_finetune_existing/2_binarize.sh
Normal file
@ -0,0 +1,9 @@
|
||||
wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
|
||||
fairseq-preprocess \
|
||||
--only-source \
|
||||
--srcdict gpt2_bpe/dict.txt \
|
||||
--trainpref wikitext-103-raw/wiki.train.bpe \
|
||||
--validpref wikitext-103-raw/wiki.valid.bpe \
|
||||
--testpref wikitext-103-raw/wiki.test.bpe \
|
||||
--destdir data-bin/wikitext-103 \
|
||||
--workers 10
|
20
roberta_base_finetune_existing/3_train.sh
Normal file
20
roberta_base_finetune_existing/3_train.sh
Normal file
@ -0,0 +1,20 @@
|
||||
TOTAL_UPDATES=12500000 # Total number of training steps
|
||||
WARMUP_UPDATES=1000 # Warmup the learning rate over this many updates
|
||||
PEAK_LR=0.0005 # Peak learning rate, adjust as needed
|
||||
TOKENS_PER_SAMPLE=512 # Max sequence length
|
||||
MAX_POSITIONS=512 # Num. positional embeddings (usually same as above)
|
||||
MAX_SENTENCES=4 # Number of sequences per batch (batch size)
|
||||
UPDATE_FREQ=64 # Increase the batch size 16x
|
||||
|
||||
DATA_DIR=data-bin/wikitext-103
|
||||
ulimit -n 4096
|
||||
|
||||
fairseq-train --fp16 $DATA_DIR \
|
||||
--task masked_lm --criterion masked_lm \
|
||||
--arch roberta_base --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \
|
||||
--optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
|
||||
--lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
|
||||
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
|
||||
--batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \
|
||||
--max-update $TOTAL_UPDATES --log-format simple --log-interval 1 \
|
||||
--restore-file roberta.base/model.pt
|
62
roberta_base_finetune_existing/predict.py
Normal file
62
roberta_base_finetune_existing/predict.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from fairseq.models.roberta import RobertaModel
|
||||
from fairseq import hub_utils
|
||||
from fairseq.models.roberta import RobertaModel, RobertaHubInterface
|
||||
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
roberta = RobertaModel.from_pretrained('checkpoint_final')
|
||||
|
||||
roberta.eval()
|
||||
roberta.cuda()
|
||||
|
||||
|
||||
preds = roberta.fill_mask('I like <mask> and apples', topk=3)
|
||||
#import pdb; pdb.set_trace()
|
||||
|
||||
# raise CUDA RuntimeError from which
|
||||
# the process does not recover
|
||||
BLACKLIST = ['aeeadb08042bbd49dcbefcefa1f13806',
|
||||
'01ba303704bb62bcb59f8cb7cb5663d7',
|
||||
'98bdfa711364f45f1bcffb1359793614',
|
||||
'a9da7950abcbd531a5207c04c3bdc840',
|
||||
'4cd7f730ee72451406afa89c5c6431d6',
|
||||
]
|
||||
|
||||
def predict(f_in_path,f_out_path):
|
||||
f_in = open(f_in_path,'r', newline='\n')
|
||||
f_out = open(f_out_path,'w', newline='\n')
|
||||
|
||||
for line in tqdm(f_in,total = 88000):
|
||||
id,_, before, after = line.split('\t')
|
||||
before = before.replace('\\n', '\n')
|
||||
after = after.replace('\\n', '\n')
|
||||
before = ' '.join(before.split(' ')[-40:]) # tu można poprawić, żeby śmigał na tokenal spm a nie zakładał że jest jak ze spacjami
|
||||
after = ' '.join(after.split(' ')[:40])
|
||||
input = before + ' <mask> ' + after
|
||||
try:
|
||||
if id in BLACKLIST:
|
||||
f_out.write(':1\n')
|
||||
continue
|
||||
preds = roberta.fill_mask(input, topk=10)
|
||||
hyps = []
|
||||
probs_sum = 0.0
|
||||
for pred in preds:
|
||||
if pred[2] == '<unk>':
|
||||
continue
|
||||
hyps.append(pred[2].rstrip().lstrip() + ':' + str(pred[1]))
|
||||
probs_sum += pred[1]
|
||||
hyps.append(':' + str(1 - probs_sum))
|
||||
preds_line = ' '.join(hyps)
|
||||
f_out.write(preds_line + '\n')
|
||||
except RuntimeError:
|
||||
import pdb ; pdb.set_trace()
|
||||
print('RUNTIMEERROR')
|
||||
f_out.write(':1\n')
|
||||
|
||||
f_out.close()
|
||||
|
||||
predict('../dev-0/in.tsv', '../dev-0/out.tsv')
|
||||
predict('../test-A/in.tsv', '../test-A/out.tsv')
|
71070
test-A/out.tsv
71070
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user