This commit is contained in:
eugene 2023-06-07 00:44:36 +02:00
parent 14b6b95631
commit 4ba5294f0b

View File

@ -98,4 +98,24 @@ def train_model():
rating_prediction = model(new_writer_encoded)
print("Predicted rating for the writer 'Jim Cash':", rating_prediction.item())
# Create dataloaders for evaluation
test_dataset = CustomDataset(X_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size=64)
# Evaluate the model
model.eval()
predictions = []
targets = []
for inputs, targets_batch in test_dataloader:
outputs = model(inputs)
predictions.extend(outputs.tolist())
targets.extend(targets_batch.tolist())
# Calculate evaluation metrics
predictions = torch.FloatTensor(predictions).squeeze()
targets = torch.FloatTensor(targets).squeeze()
rmse = mean_squared_error(targets, predictions, squared=False)
ex.run()