names in/not in train splits
This commit is contained in:
parent
77a1e13c8f
commit
bdb188e609
40
1.py
40
1.py
@ -11,22 +11,27 @@ import pandas as pd
|
|||||||
# PREDICTED_COLUMN_NAME = 'query_annot'
|
# PREDICTED_COLUMN_NAME = 'query_annot'
|
||||||
# LABEL_COLUMN_NAME = 'target_annot'
|
# LABEL_COLUMN_NAME = 'target_annot'
|
||||||
# COLUMN_SEPARATOR = '\t'
|
# COLUMN_SEPARATOR = '\t'
|
||||||
|
# NAMES_IN_TRAIN_COLUMN_NAME = 'annot_present_in_target' # or leave empty NAMES_IN_TRAIN_COLUMN_NAME = '' if no such column
|
||||||
|
|
||||||
|
|
||||||
PREDICTED_COLUMN_NAME = 'plm_names'
|
PREDICTED_COLUMN_NAME = 'plm_names'
|
||||||
LABEL_COLUMN_NAME = 'targets'
|
LABEL_COLUMN_NAME = 'targets'
|
||||||
COLUMN_SEPARATOR = ','
|
COLUMN_SEPARATOR = ','
|
||||||
|
NAMES_IN_TRAIN_COLUMN_NAME = '' # or leave empty NAMES_IN_TRAIN_COLUMN_NAME = '' if no such column
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
predicted_all_splits = list()
|
|
||||||
label_all_splits = list()
|
|
||||||
|
|
||||||
bleu = BLEU()
|
bleu = BLEU()
|
||||||
bleu_one_sentence = BLEU(effective_order=True)
|
bleu_one_sentence = BLEU(effective_order=True)
|
||||||
chrf = CHRF()
|
chrf = CHRF()
|
||||||
|
|
||||||
|
|
||||||
def get_statistics(r):
|
def get_statistics(r):
|
||||||
metrics = dict()
|
metrics = dict()
|
||||||
r['score_bleu'] = r.apply(
|
r['score_bleu'] = r.apply(
|
||||||
lambda row: round(bleu_one_sentence.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
|
lambda row: round(bleu_one_sentence.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score,
|
||||||
|
2), axis=1)
|
||||||
r['score_chrf'] = r.apply(
|
r['score_chrf'] = r.apply(
|
||||||
lambda row: round(chrf.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
|
lambda row: round(chrf.sentence_score(row[PREDICTED_COLUMN_NAME], [row[LABEL_COLUMN_NAME]]).score, 2), axis=1)
|
||||||
r['score_exact_match'] = r.apply(lambda row: 1 if row[PREDICTED_COLUMN_NAME] == row[LABEL_COLUMN_NAME] else 0,
|
r['score_exact_match'] = r.apply(lambda row: 1 if row[PREDICTED_COLUMN_NAME] == row[LABEL_COLUMN_NAME] else 0,
|
||||||
@ -42,8 +47,18 @@ def get_statistics(r):
|
|||||||
return r, metrics
|
return r, metrics
|
||||||
|
|
||||||
|
|
||||||
for FILE_PATH in sys.argv[1:]:
|
def main(names_in_train = None):
|
||||||
|
assert names_in_train in (True, False, None)
|
||||||
|
predicted_all_splits = list()
|
||||||
|
label_all_splits = list()
|
||||||
|
|
||||||
|
|
||||||
|
for FILE_PATH in sys.argv[1:]:
|
||||||
r = pd.read_csv(FILE_PATH,sep = COLUMN_SEPARATOR)
|
r = pd.read_csv(FILE_PATH,sep = COLUMN_SEPARATOR)
|
||||||
|
if names_in_train == True:
|
||||||
|
r= r[r[NAMES_IN_TRAIN_COLUMN_NAME] == True]
|
||||||
|
elif names_in_train == False:
|
||||||
|
r = r[r[NAMES_IN_TRAIN_COLUMN_NAME] == False]
|
||||||
|
|
||||||
print(FILE_PATH + ':')
|
print(FILE_PATH + ':')
|
||||||
report_with_metrics, metrics = get_statistics(r)
|
report_with_metrics, metrics = get_statistics(r)
|
||||||
@ -51,17 +66,32 @@ for FILE_PATH in sys.argv[1:]:
|
|||||||
predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list())
|
predicted_all_splits.extend(r[PREDICTED_COLUMN_NAME].to_list())
|
||||||
label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list())
|
label_all_splits.extend(r[LABEL_COLUMN_NAME].to_list())
|
||||||
|
|
||||||
|
print('samples:', len(r))
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
report_with_metrics = report_with_metrics.sort_values(by='score_chrf', ascending=False)[
|
report_with_metrics = report_with_metrics.sort_values(by='score_chrf', ascending=False)[
|
||||||
[LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME, 'score_bleu', 'score_chrf', 'score_exact_match']].drop_duplicates()
|
[LABEL_COLUMN_NAME, PREDICTED_COLUMN_NAME, 'score_bleu', 'score_chrf', 'score_exact_match']].drop_duplicates()
|
||||||
report_with_metrics.to_csv(FILE_PATH.replace('.', '_metrics.'), sep=COLUMN_SEPARATOR, index=False)
|
report_with_metrics.to_csv(FILE_PATH.replace('.', '_metrics.'), sep=COLUMN_SEPARATOR, index=False)
|
||||||
|
|
||||||
if len(sys.argv) > 2:
|
if len(sys.argv) > 2:
|
||||||
print('ALL SPLITS:')
|
print('ALL SPLITS:')
|
||||||
label_all_splits = [label_all_splits, ]
|
label_all_splits = [label_all_splits, ]
|
||||||
metrics = dict()
|
metrics = dict()
|
||||||
|
print('samples:', len(label_all_splits))
|
||||||
metrics['bleu'] = round(bleu.corpus_score(predicted_all_splits, label_all_splits).score, 2)
|
metrics['bleu'] = round(bleu.corpus_score(predicted_all_splits, label_all_splits).score, 2)
|
||||||
metrics['chrf'] = round(chrf.corpus_score(predicted_all_splits, label_all_splits).score, 2)
|
metrics['chrf'] = round(chrf.corpus_score(predicted_all_splits, label_all_splits).score, 2)
|
||||||
metrics['exact'] = round(float(100 * r['score_exact_match'].mean()), 2)
|
metrics['exact'] = round(float(100 * r['score_exact_match'].mean()), 2)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
|
|
||||||
|
print('WHOLE DATASET:')
|
||||||
|
main()
|
||||||
|
print()
|
||||||
|
if len(NAMES_IN_TRAIN_COLUMN_NAME) > 0:
|
||||||
|
print('NAMES IN TRAIN:')
|
||||||
|
main(names_in_train=True)
|
||||||
|
print()
|
||||||
|
print('NAMES NOT IN TRAIN:')
|
||||||
|
main(names_in_train=False)
|
||||||
|
print()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user