import os
import sys

import pandas as pd  # pip install sacrebleu pandas
from sacrebleu.metrics import BLEU, CHRF

# Example usage, one file:
#   python 1.py model_cv_1_0_preds.csv
# Example usage, multiple files:
#   python 1.py model_cv_1_0_preds.csv model_cv_1_1_preds.csv model_cv_1_2_preds.csv

# PREDICTED_COLUMN_NAME = 'query_annot'
# LABEL_COLUMN_NAME = 'target_annot'
# COLUMN_SEPARATOR = '\t'
# NAMES_IN_TRAIN_COLUMN_NAME = 'annot_present_in_target'

PREDICTED_COLUMN_NAME = 'plm_names'
LABEL_COLUMN_NAME = 'targets'
COLUMN_SEPARATOR = ','
NAMES_IN_TRAIN_COLUMN_NAME = ''  # leave empty ('') if the input has no such column

bleu = BLEU()
bleu_one_sentence = BLEU(effective_order=True)  # effective_order avoids zero BLEU on short sentences
chrf = CHRF()


def get_statistics(r):
    """Add per-row sentence scores to the dataframe and return corpus-level metrics."""
    metrics = dict()
    r['score_bleu'] = r.apply(
        lambda row: round(bleu_one_sentence.sentence_score(
            row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
    r['score_chrf'] = r.apply(
        lambda row: round(chrf.sentence_score(
            row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
    r['score_exact_match'] = r.apply(
        lambda row: 1 if row[PREDICTED_COLUMN_NAME] == row[LABEL_COLUMN_NAME] else 0, axis=1)
    hyps = r[PREDICTED_COLUMN_NAME].tolist()
    references = [r[LABEL_COLUMN_NAME].tolist()]  # sacrebleu expects a list of reference streams
    metrics['bleu'] = round(bleu.corpus_score(hyps, references).score, 2)
    metrics['chrf'] = round(chrf.corpus_score(hyps, references).score, 2)
    metrics['exact'] = round(float(100 * r['score_exact_match'].mean()), 2)
    return r, metrics


def main(names_in_train=None):
    assert names_in_train in (True, False, None)
    predicted_all_splits = list()
    label_all_splits = list()
    exact_all_splits = list()
    for file_path in sys.argv[1:]:
        r = pd.read_csv(file_path, sep=COLUMN_SEPARATOR)
        if names_in_train is True:
            r = r[r[NAMES_IN_TRAIN_COLUMN_NAME] == True].copy()
        elif names_in_train is False:
            r = r[r[NAMES_IN_TRAIN_COLUMN_NAME] == False].copy()
        print(file_path + ':')
        report_with_metrics, metrics = get_statistics(r)
        predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list())
        label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list())
        exact_all_splits.extend(r['score_exact_match'].to_list())
        print('samples:', len(r))
        print(metrics)
        report_with_metrics = report_with_metrics.sort_values(by='score_chrf', ascending=False)[
            [LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME,
             'score_bleu', 'score_chrf', 'score_exact_match']].drop_duplicates()
        root, ext = os.path.splitext(file_path)  # robust when the path contains other dots
        report_with_metrics.to_csv(root + '_metrics' + ext, sep=COLUMN_SEPARATOR, index=False)
    if len(sys.argv) > 2:
        print('ALL SPLITS:')
        print('samples:', len(predicted_all_splits))
        metrics = dict()
        metrics['bleu'] = round(bleu.corpus_score(predicted_all_splits, [label_all_splits]).score, 2)
        metrics['chrf'] = round(chrf.corpus_score(predicted_all_splits, [label_all_splits]).score, 2)
        # Aggregate exact-match over all splits, not just the last file read.
        metrics['exact'] = round(float(100 * sum(exact_all_splits) / len(exact_all_splits)), 2)
        print(metrics)


print('WHOLE DATASET:')
main()
print()
if len(NAMES_IN_TRAIN_COLUMN_NAME) > 0:
    print('NAMES IN TRAIN:')
    main(names_in_train=True)
    print()
    print('NAMES NOT IN TRAIN:')
    main(names_in_train=False)
    print()
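
# A minimal sketch of the expected input layout, assuming the default config
# above (the column values below are hypothetical; real files come from the
# prediction pipeline). Each row pairs a model prediction with its reference:
#
#   plm_names,targets
#   protein kinase A,protein kinase A
#   ATP synthase subunit b,ATP synthase subunit beta
#
# For such a file the script prints per-file BLEU/chrF/exact-match scores
# (plus an ALL SPLITS aggregate when several files are given) and writes a
# *_metrics.csv report next to each input, sorted by sentence-level chrF.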