ium_s451499/ium_05/learning.ipynb
2024-04-24 02:37:19 +02:00

4.8 KiB

PyTorch train model

Wczytanie niezbędnych bibliotek

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from sklearn.preprocessing import LabelEncoder

Wczytanie danych z pliku

data = pd.read_csv('../data/btc_train.csv')
data = pd.DataFrame(data)

Przygotowanie danych

Powinienembył zrobić to w zadaniu 1

le = LabelEncoder()
data['date'] = le.fit_transform(data['date'])
data['hour'] = le.fit_transform(data['hour'])
data['Volume BTC'] = data['Volume BTC']/10

# Przekształć łańcuchy znaków na liczby aby zapobiec 'TypeError: can't convert np.ndarray of type numpy.object_.'
for col in data.columns:
    data[col] = pd.to_numeric(data[col], errors='coerce')

# # Zamień brakujące wartości na 0 aby zapobiec 'IndexError: Target -9223372036854775808 is out of bounds.'
data = data.fillna(0)

Przygotowanie inputs oraz targets

# Przekształć dane na tensory PyTorch
inputs = torch.tensor(data[['date', 'hour', 'Volume BTC']].values, dtype=torch.float32)
targets = torch.tensor(data['Volume USD'].values, dtype=torch.float32).view(-1, 1) # zmieniono z torch.float32 na torch.long aby zapobiec RuntimeError: expected scalar type Long but found Float

Utwórz DataLoader

data_set = TensorDataset(inputs, targets)
data_loader = DataLoader(data_set, batch_size=64)

Model

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(inputs.shape[1], 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)

Funkcja straty i optymalizator

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

Trenowanie modelu

for epoch in range(10):
    for X, y in data_loader:
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print("Model został wytrenowany.")
Model został wytrenowany.

Zapis modelu do pliku

torch.save(model.state_dict(), "model.pth")
print("Model został zapisany do pliku 'model.pth'.")
Model został zapisany do pliku 'model.pth'.