ium_151636/script5_4.py
2023-06-06 22:38:32 +02:00

81 lines
2.1 KiB
Python

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import mean_squared_error
import pickle
# Define the neural network model
class Model(nn.Module):
    """Single-hidden-layer regressor: Linear(1, 64) -> ReLU -> Linear(64, 1).

    The submodule attribute names (``fc1``, ``fc2``, ``relu``) become the
    ``state_dict`` keys, so they must stay unchanged for saved weights to be
    restorable with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 64
        # fc1 is created before fc2 so default weight initialisation draws
        # from the RNG in the same order.
        self.fc1 = nn.Linear(1, hidden_units)
        self.fc2 = nn.Linear(hidden_units, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map inputs whose last dimension is 1 (fc1's in_features) to one output each."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Define a custom dataset
class CustomDataset(Dataset):
    """Map-style dataset pairing a single feature column with a single target.

    Both ``X`` and ``y`` are read through ``.values`` (so a pandas Series is
    assumed) and stored as ``(N, 1)`` float32 tensors; each item is therefore
    a ``(feature, target)`` pair of shape-``(1,)`` tensors.
    """

    def __init__(self, X, y):
        def to_column(series):
            # Flatten the underlying array into one column, then cast to float32.
            return torch.tensor(series.values.reshape(-1, 1), dtype=torch.float32)

        self.X = to_column(X)
        self.y = to_column(y)

    def __len__(self):
        """Number of rows (first dimension of the feature tensor)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(feature, target)`` row at position ``idx``."""
        features, target = self.X[idx], self.y[idx]
        return features, target
# Load the dataset
df = pd.read_csv('data.csv')
# Keep only the target ('Rating') and feature ('Writer') columns and drop
# incomplete rows. The explicit .copy() makes `data` an independent frame so
# the column assignment below does not raise pandas' SettingWithCopyWarning.
data = df[['Rating', 'Writer']].dropna().copy()

# Label-encode writer names into integer ids.
# NOTE(review): the encoder is re-fitted here instead of being loaded from the
# training run; this reproduces the training-time mapping only if training
# fitted it on exactly the same rows — confirm (the unused `pickle` import
# suggests persisting the encoder may have been intended).
encoder = LabelEncoder()
data['Writer'] = encoder.fit_transform(data['Writer'])

# Re-create the train/test split.
# NOTE(review): presumably this mirrors the training script's split (same
# test_size and random_state) so only held-out rows are evaluated — verify.
X = data['Writer']
y = data['Rating']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Restore the trained weights. The model is created on the CPU, so
# map_location='cpu' lets a checkpoint saved on a GPU load on CPU-only hosts.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent torch versions also accept weights_only=True).
model = Model()
model.load_state_dict(torch.load('model.pth', map_location='cpu'))

test_dataset = CustomDataset(X_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# Run inference in eval mode without building an autograd graph.
model.eval()
prediction_batches = []
target_batches = []
with torch.no_grad():
    for inputs, targets_batch in test_dataloader:
        outputs = model(inputs)
        prediction_batches.append(outputs)
        target_batches.append(targets_batch)

# Flatten the stacked (N, 1) batches to 1-D. view(-1), unlike squeeze(),
# keeps a single-sample test set as shape (1,) instead of a 0-d scalar,
# which scikit-learn rejects.
predictions = torch.cat(prediction_batches).view(-1)
targets = torch.cat(target_batches).view(-1)

# RMSE = sqrt(MSE). Taking the root explicitly avoids `squared=False`, which
# scikit-learn deprecated in 1.4 and removed in 1.6.
rmse = mean_squared_error(targets.numpy(), predictions.numpy()) ** 0.5
print("RMSE:", rmse)