From bdb188e60935c600503163dc2d93728263243afe Mon Sep 17 00:00:00 2001
From: kubapok
Date: Thu, 8 Aug 2024 17:05:41 +0200
Subject: [PATCH] names in/not in train splits

---
 1.py | 101 +++++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 75 insertions(+), 26 deletions(-)

Reviewer notes (this section is ignored by git am / git apply):
 * Line breaks and indentation were restored from a whitespace-collapsed copy
   of this patch. Blank context lines were placed by inference from the hunk
   line counts; regenerate against the real 1.py if the patch does not apply.
 * The post-image blob id on the index line predates the changes listed here.
 * main(): the 'ALL SPLITS' sample count is taken from the flat label list
   (it was printed after the list had been wrapped, so it was always 1).
 * main(): the 'ALL SPLITS' exact-match score is computed over every
   accumulated split instead of the last file's frame only.
 * main(): the report name gets its suffix before the extension only, and each
   subset (all / in train / not in train) writes its own report file, so the
   three runs do not overwrite each other.
 * main(): the filtered frame is copied before get_statistics() adds columns,
   an empty subset is skipped, and the argument check raises ValueError
   instead of relying on assert.
 * The driver statements are guarded by __name__ == '__main__'.

diff --git a/1.py b/1.py
index 721c8be..38c8777 100644
--- a/1.py
+++ b/1.py
@@ -11,22 +11,27 @@ import pandas as pd
 # PREDICTED_COLUMN_NAME = 'query_annot'
 # LABEL_COLUMN_NAME = 'target_annot'
 # COLUMN_SEPARATOR = '\t'
+# NAMES_IN_TRAIN_COLUMN_NAME = 'annot_present_in_target' # or leave empty NAMES_IN_TRAIN_COLUMN_NAME = '' if no such column
+
+
 
 PREDICTED_COLUMN_NAME = 'plm_names'
 LABEL_COLUMN_NAME = 'targets'
 COLUMN_SEPARATOR = ','
+NAMES_IN_TRAIN_COLUMN_NAME = '' # or leave empty NAMES_IN_TRAIN_COLUMN_NAME = '' if no such column
 
-predicted_all_splits = list()
-label_all_splits = list()
 
 bleu = BLEU()
 bleu_one_sentence = BLEU(effective_order=True)
 chrf = CHRF()
 
+
+
 def get_statistics(r):
     metrics = dict()
     r['score_bleu'] = r.apply(
-        lambda row: round(bleu_one_sentence.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
+        lambda row: round(bleu_one_sentence.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score,
+                          2), axis=1)
     r['score_chrf'] = r.apply(
         lambda row: round(chrf.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
     r['score_exact_match'] = r.apply(lambda row: 1 if row[PREDICTED_COLUMN_NAME] == row[LABEL_COLUMN_NAME] else 0,
@@ -42,26 +47,70 @@ def get_statistics(r):
     return r, metrics
 
 
-for FILE_PATH in sys.argv[1:]:
-    r = pd.read_csv(FILE_PATH,sep = COLUMN_SEPARATOR)
-
-    print(FILE_PATH + ':')
-    report_with_metrics, metrics = get_statistics(r)
-    predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list())
-    label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list())
-
-    print(metrics)
-
-    report_with_metrics = report_with_metrics.sort_values(by='score_chrf', ascending=False)[
-        [LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME, 'score_bleu', 'score_chrf', 'score_exact_match']].drop_duplicates()
-    report_with_metrics.to_csv(FILE_PATH.replace('.', '_metrics.'), sep=COLUMN_SEPARATOR, index=False)
-
-
-if len(sys.argv) > 2:
-    print('ALL SPLITS:')
-    label_all_splits = [label_all_splits, ]
-    metrics = dict()
-    metrics['bleu'] = round(bleu.corpus_score(predicted_all_splits, label_all_splits).score, 2)
-    metrics['chrf'] = round(chrf.corpus_score(predicted_all_splits, label_all_splits).score, 2)
-    metrics['exact'] = round(float(100 * r['score_exact_match'].mean()), 2)
-    print(metrics)
\ No newline at end of file
+def main(names_in_train=None):
+    """Score every CSV given on the command line and print its metrics.
+
+    names_in_train=None scores all rows; True / False keeps only the rows whose
+    NAMES_IN_TRAIN_COLUMN_NAME value equals that flag.
+    """
+    import os  # local import: the top of 1.py is outside this patch
+
+    if names_in_train not in (True, False, None):
+        raise ValueError('names_in_train must be True, False or None')
+
+    # One report file per subset, so the three runs below do not overwrite each other.
+    report_suffix = {None: '_metrics', True: '_metrics_in_train', False: '_metrics_not_in_train'}[names_in_train]
+    predicted_all_splits = list()
+    label_all_splits = list()
+
+    for FILE_PATH in sys.argv[1:]:
+        r = pd.read_csv(FILE_PATH, sep=COLUMN_SEPARATOR)
+        if names_in_train is not None:
+            # .copy(): get_statistics() adds columns to r, which must not be a view of the unfiltered frame.
+            r = r[r[NAMES_IN_TRAIN_COLUMN_NAME] == names_in_train].copy()
+
+        print(FILE_PATH + ':')
+        if r.empty:
+            # Nothing to score in this subset of the split.
+            print('samples:', 0)
+            continue
+
+        report_with_metrics, metrics = get_statistics(r)
+
+        predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list())
+        label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list())
+
+        print('samples:', len(r))
+        print(metrics)
+
+        report_with_metrics = report_with_metrics.sort_values(by='score_chrf', ascending=False)[
+            [LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME, 'score_bleu', 'score_chrf', 'score_exact_match']].drop_duplicates()
+        # Add the suffix before the extension only; str.replace('.', ...) rewrote every dot in the path.
+        root, ext = os.path.splitext(FILE_PATH)
+        report_with_metrics.to_csv(root + report_suffix + ext, sep=COLUMN_SEPARATOR, index=False)
+
+    if len(sys.argv) > 2 and predicted_all_splits:
+        print('ALL SPLITS:')
+        metrics = dict()
+        # Count the samples on the flat list; corpus_score below gets the labels wrapped in a one-element list.
+        print('samples:', len(label_all_splits))
+        exact_matches = sum(
+            predicted == label for predicted, label in zip(predicted_all_splits, label_all_splits))
+        metrics['bleu'] = round(bleu.corpus_score(predicted_all_splits, [label_all_splits]).score, 2)
+        metrics['chrf'] = round(chrf.corpus_score(predicted_all_splits, [label_all_splits]).score, 2)
+        # Exact match over every accumulated split, not just the last file read.
+        metrics['exact'] = round(100 * exact_matches / len(label_all_splits), 2)
+        print(metrics)
+
+
+if __name__ == '__main__':
+    print('WHOLE DATASET:')
+    main()
+    print()
+    if NAMES_IN_TRAIN_COLUMN_NAME:
+        print('NAMES IN TRAIN:')
+        main(names_in_train=True)
+        print()
+        print('NAMES NOT IN TRAIN:')
+        main(names_in_train=False)
+        print()