Zaktualizuj 'ml_pytroch_sacred.py'
This commit is contained in:
parent
6ec74ae14b
commit
8d11e06d9f
@ -100,21 +100,22 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
||||
history.append(result)
|
||||
return history
|
||||
|
||||
@ex.automain
|
||||
def my_main(epochs):
|
||||
input_size = len(input_cols)
|
||||
output_size = len(output_cols)
|
||||
model=Model_xPosition()
|
||||
epochs = 2000
|
||||
lr = 1e-5
|
||||
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
|
||||
def predict_single(input, target, model):
|
||||
inputs = input.unsqueeze(0)
|
||||
predictions = model(inputs)
|
||||
prediction = predictions[0].detach()
|
||||
|
||||
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
|
||||
input_size = len(input_cols)
|
||||
output_size = len(output_cols)
|
||||
model=Model_xPosition()
|
||||
epochs = 1000
|
||||
lr = 1e-5
|
||||
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
|
||||
def predict_single(input, target, model):
|
||||
inputs = input.unsqueeze(0)
|
||||
predictions = model(inputs)
|
||||
prediction = predictions[0].detach()
|
||||
|
||||
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
|
||||
|
||||
@ex.automain
|
||||
def my_main(epochs):
|
||||
for i in random.sample(range(0, len(val_ds)), 10):
|
||||
input_, target = val_ds[i]
|
||||
print(predict_single(input_, target, model),end="")
|
||||
|
Loading…
Reference in New Issue
Block a user