From 4ba5294f0ba301eecc6c1a1c74d60204f1073796 Mon Sep 17 00:00:00 2001 From: eugene Date: Wed, 7 Jun 2023 00:44:36 +0200 Subject: [PATCH] upd --- script7.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/script7.py b/script7.py index 5cde5c4..da31889 100644 --- a/script7.py +++ b/script7.py @@ -98,4 +98,24 @@ def train_model(): rating_prediction = model(new_writer_encoded) print("Predicted rating for the writer 'Jim Cash':", rating_prediction.item()) + # Create dataloaders for evaluation + test_dataset = CustomDataset(X_test, y_test) + test_dataloader = DataLoader(test_dataset, batch_size=64) + + # Evaluate the model + model.eval() + predictions = [] + targets = [] + for inputs, targets_batch in test_dataloader: + outputs = model(inputs) + predictions.extend(outputs.tolist()) + targets.extend(targets_batch.tolist()) + + # Calculate evaluation metrics + predictions = torch.FloatTensor(predictions).squeeze() + targets = torch.FloatTensor(targets).squeeze() + + rmse = mean_squared_error(targets, predictions, squared=False) + + ex.run()