upd
This commit is contained in:
parent
14b6b95631
commit
4ba5294f0b
20
script7.py
20
script7.py
@ -98,4 +98,24 @@ def train_model():
|
||||
rating_prediction = model(new_writer_encoded)
|
||||
print("Predicted rating for the writer 'Jim Cash':", rating_prediction.item())
|
||||
|
||||
# Create dataloaders for evaluation
|
||||
test_dataset = CustomDataset(X_test, y_test)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=64)
|
||||
|
||||
# Evaluate the model
|
||||
model.eval()
|
||||
predictions = []
|
||||
targets = []
|
||||
for inputs, targets_batch in test_dataloader:
|
||||
outputs = model(inputs)
|
||||
predictions.extend(outputs.tolist())
|
||||
targets.extend(targets_batch.tolist())
|
||||
|
||||
# Calculate evaluation metrics
|
||||
predictions = torch.FloatTensor(predictions).squeeze()
|
||||
targets = torch.FloatTensor(targets).squeeze()
|
||||
|
||||
rmse = mean_squared_error(targets, predictions, squared=False)
|
||||
|
||||
|
||||
ex.run()
|
||||
|
Loading…
Reference in New Issue
Block a user