Polish RoBERTa (base), epoch 5, seq_len 512, active dropout
This commit is contained in:
parent
ddce23e0d4
commit
ea4b155ee6
8
0-get-models.sh
Executable file
8
0-get-models.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
wget https://github.com/sdadas/polish-roberta/releases/download/models/roberta_base_fairseq.zip
|
||||
|
||||
unzip roberta_base_fairseq.zip -d roberta_base_fairseq
|
10
1-create-data.sh
Executable file
10
1-create-data.sh
Executable file
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
spm_encode --model=roberta_base_fairseq/sentencepiece.bpe.model < data/train/in.tsv > data/train.input0.spm
|
||||
spm_encode --model=roberta_base_fairseq/sentencepiece.bpe.model < data/dev-0/in.tsv > data/dev.input.spm
|
||||
|
||||
cp data/dev-0/expected.tsv data/dev.label
|
||||
cp data/train/expected.tsv data/train.label
|
18
2-preproc.sh
Executable file
18
2-preproc.sh
Executable file
@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
fairseq-preprocess \
|
||||
--only-source \
|
||||
--trainpref "data/train.input0.spm" \
|
||||
--validpref "data/dev.input0.spm" \
|
||||
--destdir "data-bin/input0" \
|
||||
--workers 4 --srcdict roberta_base_fairseq/dict.txt
|
||||
|
||||
fairseq-preprocess \
|
||||
--only-source \
|
||||
--trainpref "data/train.label" \
|
||||
--validpref "data/dev.label" \
|
||||
--destdir "data-bin/label" \
|
||||
--workers 4
|
31
3-train.py
Executable file
31
3-train.py
Executable file
@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
TOTAL_NUM_UPDATES=1000000000000000 # 10 epochs through IMDB for bsz 32
|
||||
WARMUP_UPDATES=216085 # 6 percent of the number of updates
|
||||
LR=1e-05 # Peak LR for polynomial LR scheduler.
|
||||
HEAD_NAME=hesaid # Custom name for the classification head.
|
||||
NUM_CLASSES=2 # Number of classes for the classification task.
|
||||
MAX_SENTENCES=35 # Batch size.
|
||||
ROBERTA_PATH="roberta_base_fairseq/model.pt"
|
||||
|
||||
fairseq-train data-bin/ \
|
||||
--restore-file $ROBERTA_PATH \
|
||||
--max-positions 512 \
|
||||
--max-sentences $MAX_SENTENCES \
|
||||
--max-tokens 8192 \
|
||||
--task sentence_prediction \
|
||||
--reset-optimizer --reset-dataloader --reset-meters \
|
||||
--required-batch-size-multiple 2 \
|
||||
--init-token 0 --separator-token 2 \
|
||||
--arch roberta_base \
|
||||
--criterion sentence_prediction \
|
||||
--classification-head-name $HEAD_NAME \
|
||||
--num-classes $NUM_CLASSES \
|
||||
--dropout 0.1 --attention-dropout 0.1 \
|
||||
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
|
||||
--clip-norm 0.0 \
|
||||
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
|
||||
--max-epoch 5 --log-format tqdm --log-interval 1 --save-interval-updates 15000 --keep-interval-updates 5 --skip-invalid-size-inputs-valid-test \
|
||||
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
||||
--find-unused-parameters \
|
||||
--update-freq 1
|
73
6_predict.py
Executable file
73
6_predict.py
Executable file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
from fairseq.models.roberta import RobertaModel
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
|
||||
|
||||
def get_batches(data_path: str, max_seq: int,
|
||||
batch_size: int, pad_index: int) -> List[torch.Tensor]:
|
||||
lines = []
|
||||
with open(data_path, 'rt') as f:
|
||||
for line in tqdm(f, desc=f'Reading {data_path}'):
|
||||
line = roberta.encode(line.rstrip('\n'))[:max_seq]
|
||||
lines.append(line)
|
||||
|
||||
tensor_list = []
|
||||
for i in tqdm(range(0, len(lines), batch_size), desc='Batching'):
|
||||
batch_text = lines[i: i + batch_size]
|
||||
# Get max length of batch
|
||||
max_len = max([tokens.size(0) for tokens in batch_text])
|
||||
|
||||
# Create empty tensor with padding index
|
||||
input_tensor = torch.LongTensor(len(batch_text), max_len).fill_(pad_index)
|
||||
# Fill tensor with tokens
|
||||
for i, tokens in enumerate(batch_text):
|
||||
input_tensor[i][:tokens.size(0)] = tokens
|
||||
tensor_list.append(input_tensor)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
def predict(roberta: RobertaModel, batches: List[torch.Tensor], save_file: str):
|
||||
with open(save_file, 'wt') as fout:
|
||||
for batch in tqdm(batches, desc='Processing'):
|
||||
raw_prediction = roberta.predict('hesaid', batch)
|
||||
# Get probability for second class (M class)
|
||||
out_tensor = torch.exp(raw_prediction[:, 1])
|
||||
for line_prediction in out_tensor:
|
||||
# Get probability for first class
|
||||
fout.write(f'{line_prediction.item()}\n')
|
||||
|
||||
|
||||
def load_model():
|
||||
roberta = RobertaModel.from_pretrained(
|
||||
model_name_or_path='checkpoints',
|
||||
data_name_or_path='data-bin',
|
||||
sentencepiece_vocab='roberta_base_fairseq/sentencepiece.bpe.model',
|
||||
checkpoint_file='checkpoint_best.pt',
|
||||
bpe='sentencepiece',
|
||||
)
|
||||
return roberta
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
roberta = load_model()
|
||||
roberta.cuda()
|
||||
roberta.train()
|
||||
|
||||
max_seq = 512
|
||||
batch_size = 5
|
||||
pad_index = roberta.task.source_dictionary.pad()
|
||||
|
||||
for dir_name in ['dev-0', 'dev-1', 'test-A']:
|
||||
batches = get_batches(f'data/{dir_name}/in.tsv', max_seq, batch_size, pad_index)
|
||||
for i in range(12):
|
||||
print(f'Processing iteration: {i}')
|
||||
j = str(i)
|
||||
predict(roberta, batches, f'data/{dir_name}/out.tsv' + j)
|
43
7_average.py
Executable file
43
7_average.py
Executable file
@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
REGEX_FILE_NAME = re.compile(r'^out.tsv[0-9]+$')
|
||||
|
||||
|
||||
def avarage(dir_name: str):
|
||||
print(f'Processing {dir_name}')
|
||||
file_names = [f for f in os.listdir(dir_name)
|
||||
if REGEX_FILE_NAME.match(f)]
|
||||
|
||||
if not file_names:
|
||||
print('ERROR! Not found files!')
|
||||
return
|
||||
|
||||
print(f'Reading from files: {file_names}')
|
||||
files = [open(dir_name + '/' + f) for f in file_names]
|
||||
f_out = open(dir_name + '/out-model=best_sum.tsv', 'w')
|
||||
progress = tqdm(desc=dir_name)
|
||||
|
||||
while True:
|
||||
try:
|
||||
hyps = [float(next(x).rstrip()) for x in files]
|
||||
except StopIteration:
|
||||
break
|
||||
avg_v = np.mean(hyps)
|
||||
f_out.write(str(avg_v) + '\n')
|
||||
progress.update(1)
|
||||
|
||||
f_out.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
avarage('data/dev-0')
|
||||
avarage('data/dev-1')
|
||||
avarage('data/test-A')
|
137314
dev-0/out.tsv
Normal file
137314
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
dev-0/out.tsv0.xz
Normal file
BIN
dev-0/out.tsv0.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv1.xz
Normal file
BIN
dev-0/out.tsv1.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv10.xz
Normal file
BIN
dev-0/out.tsv10.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv11.xz
Normal file
BIN
dev-0/out.tsv11.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv2.xz
Normal file
BIN
dev-0/out.tsv2.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv3.xz
Normal file
BIN
dev-0/out.tsv3.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv4.xz
Normal file
BIN
dev-0/out.tsv4.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv5.xz
Normal file
BIN
dev-0/out.tsv5.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv6.xz
Normal file
BIN
dev-0/out.tsv6.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv7.xz
Normal file
BIN
dev-0/out.tsv7.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv8.xz
Normal file
BIN
dev-0/out.tsv8.xz
Normal file
Binary file not shown.
BIN
dev-0/out.tsv9.xz
Normal file
BIN
dev-0/out.tsv9.xz
Normal file
Binary file not shown.
156606
dev-1/out.tsv
Normal file
156606
dev-1/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
dev-1/out.tsv0.xz
Normal file
BIN
dev-1/out.tsv0.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv1.xz
Normal file
BIN
dev-1/out.tsv1.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv10.xz
Normal file
BIN
dev-1/out.tsv10.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv11.xz
Normal file
BIN
dev-1/out.tsv11.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv2.xz
Normal file
BIN
dev-1/out.tsv2.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv3.xz
Normal file
BIN
dev-1/out.tsv3.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv4.xz
Normal file
BIN
dev-1/out.tsv4.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv5.xz
Normal file
BIN
dev-1/out.tsv5.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv6.xz
Normal file
BIN
dev-1/out.tsv6.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv7.xz
Normal file
BIN
dev-1/out.tsv7.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv8.xz
Normal file
BIN
dev-1/out.tsv8.xz
Normal file
Binary file not shown.
BIN
dev-1/out.tsv9.xz
Normal file
BIN
dev-1/out.tsv9.xz
Normal file
Binary file not shown.
134618
test-A/out.tsv
Normal file
134618
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
test-A/out.tsv0.xz
Normal file
BIN
test-A/out.tsv0.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv1.xz
Normal file
BIN
test-A/out.tsv1.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv10.xz
Normal file
BIN
test-A/out.tsv10.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv11.xz
Normal file
BIN
test-A/out.tsv11.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv2.xz
Normal file
BIN
test-A/out.tsv2.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv3.xz
Normal file
BIN
test-A/out.tsv3.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv4.xz
Normal file
BIN
test-A/out.tsv4.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv5.xz
Normal file
BIN
test-A/out.tsv5.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv6.xz
Normal file
BIN
test-A/out.tsv6.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv7.xz
Normal file
BIN
test-A/out.tsv7.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv8.xz
Normal file
BIN
test-A/out.tsv8.xz
Normal file
Binary file not shown.
BIN
test-A/out.tsv9.xz
Normal file
BIN
test-A/out.tsv9.xz
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user