Zaktualizuj 'ml_pytroch_sacred.py'
This commit is contained in:
parent
6ec74ae14b
commit
8d11e06d9f
@ -100,21 +100,22 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
|||||||
history.append(result)
|
history.append(result)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
@ex.automain
|
|
||||||
def my_main(epochs):
|
input_size = len(input_cols)
|
||||||
input_size = len(input_cols)
|
output_size = len(output_cols)
|
||||||
output_size = len(output_cols)
|
model=Model_xPosition()
|
||||||
model=Model_xPosition()
|
epochs = 1000
|
||||||
epochs = 2000
|
lr = 1e-5
|
||||||
lr = 1e-5
|
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
|
||||||
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
|
def predict_single(input, target, model):
|
||||||
def predict_single(input, target, model):
|
|
||||||
inputs = input.unsqueeze(0)
|
inputs = input.unsqueeze(0)
|
||||||
predictions = model(inputs)
|
predictions = model(inputs)
|
||||||
prediction = predictions[0].detach()
|
prediction = predictions[0].detach()
|
||||||
|
|
||||||
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
|
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
|
def my_main(epochs):
|
||||||
for i in random.sample(range(0, len(val_ds)), 10):
|
for i in random.sample(range(0, len(val_ds)), 10):
|
||||||
input_, target = val_ds[i]
|
input_, target = val_ds[i]
|
||||||
print(predict_single(input_, target, model),end="")
|
print(predict_single(input_, target, model),end="")
|
||||||
|
Loading…
Reference in New Issue
Block a user