biological-lms/1.py

98 lines
3.4 KiB
Python
Raw Permalink Normal View History

2024-08-04 14:02:57 +02:00
import copy
import sys
from sacrebleu.metrics import BLEU, CHRF, TER
import pandas as pd
# pip install sacrebleu pandas
# Example usage with one file:       python 1.py model_cv_1_0_preds.csv
# Example usage with multiple files: python 1.py model_cv_1_0_preds.csv model_cv_1_1_preds.csv model_cv_1_2_preds.csv
2024-08-04 14:23:12 +02:00
# PREDICTED_COLUMN_NAME = 'query_annot'
# LABEL_COLUMN_NAME = 'target_annot'
# COLUMN_SEPARATOR = '\t'
2024-08-08 17:05:41 +02:00
# NAMES_IN_TRAIN_COLUMN_NAME = 'annot_present_in_target' # or leave empty NAMES_IN_TRAIN_COLUMN_NAME = '' if no such column
2024-08-04 14:02:57 +02:00
# Names of the CSV columns that hold the model prediction and the reference
# label, and the field delimiter handed to pandas.read_csv for every input
# file (also used when writing the per-row report).
PREDICTED_COLUMN_NAME = 'plm_names'
LABEL_COLUMN_NAME = 'targets'
COLUMN_SEPARATOR = ','
2024-08-08 17:05:41 +02:00
# Column whose True/False values split the rows into "names in train" and
# "names not in train". Leave it as '' when the input has no such column;
# the two split reports at the bottom of the script are then skipped.
NAMES_IN_TRAIN_COLUMN_NAME = ''
2024-08-04 14:02:57 +02:00
2024-08-04 14:23:12 +02:00
2024-08-04 14:02:57 +02:00
2024-08-04 14:23:12 +02:00
# BLEU scorer used for corpus-level scores (called via corpus_score).
bleu = BLEU()
# Separate BLEU scorer used for per-row scores (called via sentence_score).
# NOTE(review): effective_order=True is presumably set because sacrebleu
# recommends it for sentence-level BLEU -- confirm against the sacrebleu docs.
bleu_one_sentence = BLEU(effective_order=True)
2024-08-04 14:02:57 +02:00
# chrF scorer shared by the per-row (sentence_score) and corpus-level
# (corpus_score) computations.
chrf = CHRF()
2024-08-08 17:05:41 +02:00
2024-08-04 14:02:57 +02:00
def get_statistics(r):
    """Score every row of *r* and compute corpus-level metrics.

    Three columns are added to *r* in place (callers rely on this):
    ``score_bleu`` and ``score_chrf`` hold the sentence-level scores rounded
    to two decimals, and ``score_exact_match`` is 1 when the prediction
    equals the label and 0 otherwise.

    Args:
        r: DataFrame containing the ``PREDICTED_COLUMN_NAME`` and
            ``LABEL_COLUMN_NAME`` columns.

    Returns:
        A ``(r, metrics)`` tuple. ``metrics`` maps ``'bleu'``, ``'chrf'`` and
        ``'exact'`` to corpus-level values rounded to two decimals;
        ``'exact'`` is the percentage of exactly matching rows. For a frame
        with no rows all three metrics are reported as ``0.0``.
    """
    hyps = r[PREDICTED_COLUMN_NAME].tolist()
    labels = r[LABEL_COLUMN_NAME].tolist()

    # The per-row columns are built from plain lists instead of
    # ``DataFrame.apply(axis=1)``: on a frame with no rows ``apply`` may hand
    # back a DataFrame rather than a Series, which cannot be stored in a
    # single column. A list of length zero assigns cleanly.
    r['score_bleu'] = [
        round(bleu_one_sentence.sentence_score(hyp, [label]).score, 2)
        for hyp, label in zip(hyps, labels)
    ]
    r['score_chrf'] = [
        round(chrf.sentence_score(hyp, [label]).score, 2)
        for hyp, label in zip(hyps, labels)
    ]
    r['score_exact_match'] = [
        1 if hyp == label else 0 for hyp, label in zip(hyps, labels)
    ]

    if not hyps:
        # Nothing to score (e.g. a filter removed every row): skip the corpus
        # scorers, which expect at least one hypothesis, and avoid a NaN mean.
        return r, {'bleu': 0.0, 'chrf': 0.0, 'exact': 0.0}

    # sacrebleu expects a list of reference streams; there is exactly one.
    references = [labels]
    metrics = {
        'bleu': round(bleu.corpus_score(hyps, references).score, 2),
        'chrf': round(chrf.corpus_score(hyps, references).score, 2),
        'exact': round(float(100 * r['score_exact_match'].mean()), 2),
    }
    return r, metrics
2024-08-08 17:05:41 +02:00
def main(names_in_train=None):
    """Evaluate every predictions file named on the command line.

    Each path in ``sys.argv[1:]`` is read with ``COLUMN_SEPARATOR``,
    optionally filtered on ``NAMES_IN_TRAIN_COLUMN_NAME``, scored with
    ``get_statistics`` and summarised on stdout. A per-row report, sorted by
    chrF and de-duplicated, is written next to the input as
    ``<name>_metrics<ext>``. When more than one file is given, metrics pooled
    over all files are printed as well.

    Args:
        names_in_train: ``None`` keeps all rows; ``True`` keeps only rows
            whose ``NAMES_IN_TRAIN_COLUMN_NAME`` value is True; ``False``
            keeps only rows whose value is False.
    """
    import os  # local import: only needed to build the report path

    assert names_in_train in (True, False, None)
    predicted_all_splits = []
    label_all_splits = []
    report_columns = [LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME,
                      'score_bleu', 'score_chrf', 'score_exact_match']

    for file_path in sys.argv[1:]:
        r = pd.read_csv(file_path, sep=COLUMN_SEPARATOR)
        if names_in_train is not None:
            # Element-wise comparison against the requested flag. The copy
            # gives get_statistics an independent frame to add columns to,
            # instead of a slice of the frame that was read.
            wanted = bool(names_in_train)
            r = r[r[NAMES_IN_TRAIN_COLUMN_NAME] == wanted].copy()

        print(file_path + ':')
        report_with_metrics, metrics = get_statistics(r)

        predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list())
        label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list())
        print('samples:', len(r))
        print(metrics)

        report_with_metrics = (
            report_with_metrics
            .sort_values(by='score_chrf', ascending=False)[report_columns]
            .drop_duplicates()
        )
        # Insert the suffix before the extension only. Replacing every '.'
        # mangles paths such as './preds/model.v1.csv', and a path without
        # any dot would make the report overwrite the input file.
        root, ext = os.path.splitext(file_path)
        report_with_metrics.to_csv(root + '_metrics' + ext,
                                   sep=COLUMN_SEPARATOR, index=False)

    if len(sys.argv) > 2:
        print('ALL SPLITS:')
        # Count the pooled samples before wrapping the labels into the
        # single-reference-stream list that sacrebleu expects.
        total = len(predicted_all_splits)
        print('samples:', total)
        if total:
            references = [label_all_splits]
            # Exact match over the pooled rows of every file, not just the
            # rows of the last file that was read.
            exact_hits = sum(
                1 for hyp, label in zip(predicted_all_splits, label_all_splits)
                if hyp == label
            )
            metrics = {
                'bleu': round(
                    bleu.corpus_score(predicted_all_splits, references).score, 2),
                'chrf': round(
                    chrf.corpus_score(predicted_all_splits, references).score, 2),
                'exact': round(100 * exact_hits / total, 2),
            }
        else:
            # No rows survived in any file; there is nothing to score.
            metrics = {'bleu': 0.0, 'chrf': 0.0, 'exact': 0.0}
        print(metrics)
# Report on the whole dataset first. The "in train" / "not in train"
# breakdown is added only when a membership column has been configured.
_report_sections = [('WHOLE DATASET:', None)]
if NAMES_IN_TRAIN_COLUMN_NAME:
    _report_sections.append(('NAMES IN TRAIN:', True))
    _report_sections.append(('NAMES NOT IN TRAIN:', False))

for _section_title, _in_train_flag in _report_sections:
    print(_section_title)
    main(names_in_train=_in_train_flag)
    print()
2024-08-04 14:02:57 +02:00