diff --git a/evaluate.py b/evaluate.py index 247b1b5..112e821 100644 --- a/evaluate.py +++ b/evaluate.py @@ -32,7 +32,8 @@ if os.path.exists(metrics_file): metrics_df = pd.read_csv(metrics_file) else: metrics_df = pd.DataFrame(columns=['top_1_accuracy', 'top_5_accuracy']) -new_row = pd.DataFrame({'top_1_accuracy': np.mean(top_1_accuracy), 'top_5_accuracy': np.mean(top_5_accuracy)}, index=[0]) + +new_row = pd.DataFrame([{'top_1_accuracy': np.mean(top_1_accuracy.numpy()), 'top_5_accuracy': np.mean(top_5_accuracy.numpy())}]) metrics_df = pd.concat([metrics_df, new_row], ignore_index=True) metrics_df.to_csv(metrics_file, index=False)