Append to a file instead of overwriting it
This commit is contained in:
parent
28585de7c3
commit
9e1e5ae3b9
@ -2,7 +2,6 @@ import torch
|
|||||||
import sys
|
import sys
|
||||||
from train_model import MLP, PlantsDataset, test
|
from train_model import MLP, PlantsDataset, test
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from contextlib import redirect_stdout
|
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
@ -22,9 +21,9 @@ def main():
|
|||||||
|
|
||||||
loss_fn = torch.nn.MSELoss()
|
loss_fn = torch.nn.MSELoss()
|
||||||
|
|
||||||
with open('evaluation_results.txt', 'w') as f:
|
loss = test(dataloader, model, loss_fn)
|
||||||
with redirect_stdout(f):
|
with open('evaluation_results.txt', 'a+') as f:
|
||||||
test(dataloader, model, loss_fn)
|
f.write(f'{str(loss)}\n')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -90,6 +90,7 @@ def test(dataloader, model, loss_fn):
|
|||||||
test_loss += loss_fn(pred, y).item()
|
test_loss += loss_fn(pred, y).item()
|
||||||
test_loss /= num_batches
|
test_loss /= num_batches
|
||||||
print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
|
print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
|
||||||
|
return test_loss
|
||||||
|
|
||||||
|
|
||||||
def setup_args():
|
def setup_args():
|
||||||
|
Loading…
Reference in New Issue
Block a user