commit dd80f72b312cdc9b3c2aabb766f6e1073fb3dac2 Author: kubapok Date: Sun Aug 4 14:02:57 2024 +0200 add 1 diff --git a/1.py b/1.py new file mode 100644 index 0000000..65aa3d8 --- /dev/null +++ b/1.py @@ -0,0 +1,64 @@ +import copy +import sys +from sacrebleu.metrics import BLEU, CHRF, TER +import pandas as pd + + +# pip install sacrebleu pandas +# example usage one arg python 1.py model_cv_1_0_preds.csv +# example usage mulitple args python 1.py model_cv_1_0_preds.csv model_cv_1_1_preds.csv model_cv_1_2_preds.csv + +PREDICTED_COLUMN_NAME = 'query_annot' +LABEL_COLUMN_NAME = 'target_annot' +PREDICTED_COLUMN_NAME = 'plm_names' +LABEL_COLUMN_NAME = 'targets' +COLUMN_SEPARATOR = ',' + + +predicted_all_splits = list() +label_all_splits = list() + +bleu = BLEU(effective_order=True) +chrf = CHRF() +def get_statistics(r): + metrics = dict() + r['score_bleu'] = r.apply( + lambda row: round(bleu.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1) + r['score_chrf'] = r.apply( + lambda row: round(chrf.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1) + r['score_exact_match'] = r.apply(lambda row: 1 if row[PREDICTED_COLUMN_NAME] == row[LABEL_COLUMN_NAME] else 0, + axis=1) + + hyps = r[PREDICTED_COLUMN_NAME].tolist() + references = [r[LABEL_COLUMN_NAME].tolist(), ] + + metrics['bleu'] = round(bleu.corpus_score(hyps, references).score, 2) + metrics['chrf'] = round(chrf.corpus_score(hyps, references).score, 2) + metrics['exact'] = round(float(100 * r['score_exact_match'].mean()), 2) + + return r, metrics + + +for FILE_PATH in sys.argv[1:]: + r = pd.read_csv(FILE_PATH,sep = COLUMN_SEPARATOR) + + print(FILE_PATH + ':') + report_with_metrics, metrics = get_statistics(r) + + predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list()) + label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list()) + + print(metrics) + + report_with_metrics = report_with_metrics.sort_values(by='score_chrf', ascending=False)[ + [LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME, 'score_bleu', 'score_chrf', 'score_exact_match']].drop_duplicates() + report_with_metrics.to_csv(FILE_PATH.replace('.', '_metrics.'), sep=COLUMN_SEPARATOR, index=False) + +if len(sys.argv) > 2: + print('ALL SPLITS:') + label_all_splits = [label_all_splits, ] + metrics = dict() + metrics['bleu'] = round(bleu.corpus_score(predicted_all_splits, label_all_splits).score, 2) + metrics['chrf'] = round(chrf.corpus_score(predicted_all_splits, label_all_splits).score, 2) + metrics['exact'] = round(float(100 * r['score_exact_match'].mean()), 2) + print(metrics) \ No newline at end of file