diff --git a/config.yaml b/config.yaml index cb7e730..4ff337a 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,8 @@ -pretrained_processor_path: "Zombely/plwiki-proto-fine-tuned-v2" -pretrained_model_path: "Zombely/plwiki-proto-fine-tuned-v2" +pretrained_processor_path: "naver-clova-ix/donut-proto" +pretrained_model_path: "naver-clova-ix/donut-proto" validation_dataset_path: "Zombely/diachronia-ocr" validation_dataset_split: "train" has_metadata: False print_output: True -output_file_dir: "../../gonito-outs" \ No newline at end of file +output_file_dir: "../../gonito-outs" +test_name: "proto-test" diff --git a/donut-eval.py b/donut-eval.py index 23efe11..5754483 100644 --- a/donut-eval.py +++ b/donut-eval.py @@ -63,7 +63,7 @@ def main(config): output_list.append(seq) if config.output_file_dir: df = pd.DataFrame(map(lambda x: x.get('text_sequence', ''), output_list)) - df.to_csv(f'{config.output_file_dir}/{config.pretrained_processor_path}-out.tsv', sep='\t', header=False, index=False) + df.to_csv(f'{config.output_file_dir}/{config.test_name}-out.tsv', sep='\t', header=False, index=False) if config.has_metadata: scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)} @@ -77,4 +77,4 @@ if __name__ == "__main__": config = Config(args.config) config.argv_update(left_argv) - main(config) \ No newline at end of file + main(config)