Move script into main fn
This commit is contained in:
parent
46d7831b98
commit
fe63ef269c
@ -12,6 +12,9 @@ default_batch_size = 64
|
|||||||
default_epochs = 5
|
default_epochs = 5
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
def hour_to_int(text: str):
|
def hour_to_int(text: str):
|
||||||
return float(text.replace(':', ''))
|
return float(text.replace(':', ''))
|
||||||
|
|
||||||
@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn):
|
|||||||
pred = model(X)
|
pred = model(X)
|
||||||
test_loss += loss_fn(pred, y).item()
|
test_loss += loss_fn(pred, y).item()
|
||||||
test_loss /= num_batches
|
test_loss /= num_batches
|
||||||
print(f"Avg loss: {test_loss:>8f} \n")
|
print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
|
||||||
|
|
||||||
|
|
||||||
def setup_args():
|
def setup_args():
|
||||||
@ -97,7 +100,7 @@ def setup_args():
|
|||||||
return args_parser.parse_args()
|
return args_parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
def main():
|
||||||
print(f"Using {device} device")
|
print(f"Using {device} device")
|
||||||
|
|
||||||
args = setup_args()
|
args = setup_args()
|
||||||
@ -129,3 +132,6 @@ print("Done!")
|
|||||||
|
|
||||||
torch.save(model.state_dict(), './model_out')
|
torch.save(model.state_dict(), './model_out')
|
||||||
print("Model saved in ./model_out file.")
|
print("Model saved in ./model_out file.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user