Move script into main fn
This commit is contained in:
parent
46d7831b98
commit
fe63ef269c
@ -12,6 +12,9 @@ default_batch_size = 64
|
||||
default_epochs = 5
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def hour_to_int(text: str):
|
||||
return float(text.replace(':', ''))
|
||||
|
||||
@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn):
|
||||
pred = model(X)
|
||||
test_loss += loss_fn(pred, y).item()
|
||||
test_loss /= num_batches
|
||||
print(f"Avg loss: {test_loss:>8f} \n")
|
||||
print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
|
||||
|
||||
|
||||
def setup_args():
|
||||
@ -97,7 +100,7 @@ def setup_args():
|
||||
return args_parser.parse_args()
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def main():
|
||||
print(f"Using {device} device")
|
||||
|
||||
args = setup_args()
|
||||
@ -129,3 +132,6 @@ print("Done!")
|
||||
|
||||
torch.save(model.state_dict(), './model_out')
|
||||
print("Model saved in ./model_out file.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user