diff --git a/eval_model.py b/eval_model.py index cac9a81..a14eea4 100644 --- a/eval_model.py +++ b/eval_model.py @@ -2,7 +2,6 @@ import torch import sys from train_model import MLP, PlantsDataset, test from torch.utils.data import DataLoader -from contextlib import redirect_stdout def load_model(): @@ -22,9 +21,9 @@ def main(): loss_fn = torch.nn.MSELoss() - with open('evaluation_results.txt', 'w') as f: - with redirect_stdout(f): - test(dataloader, model, loss_fn) + loss = test(dataloader, model, loss_fn) + with open('evaluation_results.txt', 'a+') as f: + f.write(f'{str(loss)}\n') if __name__ == "__main__": diff --git a/train_model.py b/train_model.py index ecc27ef..cdaedeb 100644 --- a/train_model.py +++ b/train_model.py @@ -90,6 +90,7 @@ def test(dataloader, model, loss_fn): test_loss += loss_fn(pred, y).item() test_loss /= num_batches print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n") + return test_loss def setup_args():