5.6 KiB
5.6 KiB
PyTorch train model
Wczytanie niezbędnych bibliotek
import pandas as pd
import torch
from torch import nn
from sklearn.preprocessing import LabelEncoder
Wczytanie danych z pliku
data = pd.read_csv('../data/btc_test.csv')
data = pd.DataFrame(data)
Przygotowanie danych
Powinienembył zrobić to w zadaniu 1
le = LabelEncoder()
data['date'] = le.fit_transform(data['date'])
data['hour'] = le.fit_transform(data['hour'])
data['Volume BTC'] = data['Volume BTC']/10
# Przekształć łańcuchy znaków na liczby aby zapobiec 'TypeError: can't convert np.ndarray of type numpy.object_.'
for col in data.columns:
data[col] = pd.to_numeric(data[col], errors='coerce')
# Zamień brakujące wartości na 0 aby zapobiec 'IndexError: Target -9223372036854775808 is out of bounds.'
data = data.fillna(0)
Przygotowanie inputs oraz targets
# Przekształć dane na tensory PyTorch
inputs = torch.tensor(data[['date', 'hour', 'Volume BTC']].values, dtype=torch.float32)
Model
model = nn.Sequential(
nn.Flatten(),
nn.Linear(inputs.shape[1], 64),
nn.ReLU(),
nn.Linear(64, 1),
)
Wczytanie wytrenowanego modelu
model.load_state_dict(torch.load("model.pth"))
model.eval()
Sequential( (0): Flatten(start_dim=1, end_dim=-1) (1): Linear(in_features=3, out_features=64, bias=True) (2): ReLU() (3): Linear(in_features=64, out_features=1, bias=True) )
Predykcja modelu
predictions = model(inputs)
predicted_data = (predictions.float() * 10)
print(predicted_data)
tensor([[772837.5000], [772837.5000], [772837.5000], ..., [772837.5000], [772837.5000], [772837.5000]], grad_fn=<MulBackward0>)
Zapis danych do pliku csv
predicted_data_df = pd.DataFrame(torch.detach(predicted_data).numpy())
predicted_data_df.to_csv("predict_result.csv", index=False)
[1;31m---------------------------------------------------------------------------[0m [1;31mTypeError[0m Traceback (most recent call last) Cell [1;32mIn[300], line 1[0m [1;32m----> 1[0m predicted_data_df [38;5;241m=[39m pd[38;5;241m.[39mDataFrame([43mtorch[49m[38;5;241;43m.[39;49m[43mdetach[49m[43m([49m[43m)[49m[38;5;241m.[39mnumpy(predicted_data)) [0;32m 2[0m predicted_data_df[38;5;241m.[39mto_csv([38;5;124m"[39m[38;5;124mpredict_result.csv[39m[38;5;124m"[39m, index[38;5;241m=[39m[38;5;28;01mFalse[39;00m) [1;31mTypeError[0m: detach() missing 1 required positional arguments: "input"