upd
This commit is contained in:
parent
14b6b95631
commit
4ba5294f0b
20
script7.py
20
script7.py
@ -98,4 +98,24 @@ def train_model():
|
|||||||
rating_prediction = model(new_writer_encoded)
|
rating_prediction = model(new_writer_encoded)
|
||||||
print("Predicted rating for the writer 'Jim Cash':", rating_prediction.item())
|
print("Predicted rating for the writer 'Jim Cash':", rating_prediction.item())
|
||||||
|
|
||||||
|
# Create dataloaders for evaluation
|
||||||
|
test_dataset = CustomDataset(X_test, y_test)
|
||||||
|
test_dataloader = DataLoader(test_dataset, batch_size=64)
|
||||||
|
|
||||||
|
# Evaluate the model
|
||||||
|
model.eval()
|
||||||
|
predictions = []
|
||||||
|
targets = []
|
||||||
|
for inputs, targets_batch in test_dataloader:
|
||||||
|
outputs = model(inputs)
|
||||||
|
predictions.extend(outputs.tolist())
|
||||||
|
targets.extend(targets_batch.tolist())
|
||||||
|
|
||||||
|
# Calculate evaluation metrics
|
||||||
|
predictions = torch.FloatTensor(predictions).squeeze()
|
||||||
|
targets = torch.FloatTensor(targets).squeeze()
|
||||||
|
|
||||||
|
rmse = mean_squared_error(targets, predictions, squared=False)
|
||||||
|
|
||||||
|
|
||||||
ex.run()
|
ex.run()
|
||||||
|
Loading…
Reference in New Issue
Block a user