ium_s451499/ium_05/learning.ipynb
2024-04-24 02:37:19 +02:00

229 lines
4.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## PyTorch train model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Wczytanie niezbędnych bibliotek"
]
},
{
"cell_type": "code",
"execution_count": 233,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"import pandas as pd\n",
"from sklearn.preprocessing import LabelEncoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Wczytanie danych z pliku"
]
},
{
"cell_type": "code",
"execution_count": 234,
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = '../data/btc_train.csv'\n",
"\n",
"# pd.read_csv already returns a DataFrame, so wrapping it again in pd.DataFrame(...) is redundant.\n",
"data = pd.read_csv(DATA_PATH)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Przygotowanie danych\n",
"Powinienem był zrobić to już w zadaniu 1."
]
},
{
"cell_type": "code",
"execution_count": 235,
"metadata": {},
"outputs": [],
"source": [
"# Encode the date and hour strings as integer codes.\n",
"# NOTE: fit_transform refits `le` on each call, so afterwards it only holds the 'hour' classes.\n",
"le = LabelEncoder()\n",
"data['date'] = le.fit_transform(data['date'])\n",
"data['hour'] = le.fit_transform(data['hour'])\n",
"\n",
"# Convert every column to numbers *before* doing arithmetic on any of them, so a column that was\n",
"# read as strings cannot raise TypeError on '/'; entries that cannot be parsed become NaN.\n",
"data = data.apply(pd.to_numeric, errors='coerce')\n",
"data['Volume BTC'] = data['Volume BTC'] / 10\n",
"\n",
"# Replace missing values (including those produced by the coercion above) with 0.\n",
"data = data.fillna(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Przygotowanie inputs oraz targets"
]
},
{
"cell_type": "code",
"execution_count": 236,
"metadata": {},
"outputs": [],
"source": [
"# Convert the data to PyTorch tensors: a 2-D float32 feature matrix with 3 columns...\n",
"inputs = torch.tensor(data[['date', 'hour', 'Volume BTC']].to_numpy(), dtype=torch.float32)\n",
"# ...and a float32 regression target reshaped to (N, 1), matching the single-output model that nn.MSELoss compares against.\n",
"targets = torch.tensor(data['Volume USD'].to_numpy(), dtype=torch.float32).view(-1, 1)\n",
"\n",
"# NOTE(review): neither features nor targets are scaled. If 'Volume USD' values are large, plain SGD\n",
"# with lr=1e-3 may diverge to NaN; consider standardising (and applying the same transform at inference)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Utwórz DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 237,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 64\n",
"SEED = 42\n",
"\n",
"data_set = TensorDataset(inputs, targets)\n",
"# Reshuffle every epoch so SGD mini-batches are not always drawn in file order;\n",
"# the seeded generator keeps the shuffling reproducible across runs.\n",
"data_loader = DataLoader(\n",
"    data_set,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=True,\n",
"    generator=torch.Generator().manual_seed(SEED),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model"
]
},
{
"cell_type": "code",
"execution_count": 238,
"metadata": {},
"outputs": [],
"source": [
"# Seed weight initialisation so repeated runs start from the same parameters.\n",
"torch.manual_seed(42)\n",
"\n",
"# Small MLP regressor: n_features -> 64 -> 1.\n",
"# nn.Flatten is a no-op on 2-D (batch, features) input, but it is kept so the state_dict keys\n",
"# ('1.weight', '3.weight', ...) stay the same for any code that rebuilds this model to load model.pth.\n",
"model = nn.Sequential(\n",
"    nn.Flatten(),\n",
"    nn.Linear(inputs.shape[1], 64),\n",
"    nn.ReLU(),\n",
"    nn.Linear(64, 1),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Funkcja straty i optymalizator"
]
},
{
"cell_type": "code",
"execution_count": 239,
"metadata": {},
"outputs": [],
"source": [
"LEARNING_RATE = 1e-3\n",
"\n",
"# Mean squared error between predicted and true targets, minimised with plain stochastic gradient descent.\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Trenowanie modelu"
]
},
{
"cell_type": "code",
"execution_count": 240,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model został wytrenowany.\n"
]
}
],
"source": [
"EPOCHS = 10\n",
"\n",
"model.train()\n",
"for epoch in range(EPOCHS):\n",
"    epoch_loss = 0.0\n",
"    for X, y in data_loader:\n",
"        pred = model(X)\n",
"        loss = loss_fn(pred, y)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # Weight each batch by its size so the epoch figure is a per-sample mean,\n",
"        # even when the last batch is smaller than the others.\n",
"        epoch_loss += loss.item() * X.shape[0]\n",
"\n",
"    epoch_loss /= len(data_loader.dataset)\n",
"    print(f\"Epoch {epoch + 1}/{EPOCHS}: MSE = {epoch_loss:.4f}\")\n",
"    # A NaN/inf loss means training diverged and the saved weights would be unusable.\n",
"    if not torch.isfinite(torch.tensor(epoch_loss)):\n",
"        print(\"WARNING: loss is not finite - training diverged.\")\n",
"\n",
"print(\"Model został wytrenowany.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Zapis modelu do pliku"
]
},
{
"cell_type": "code",
"execution_count": 241,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model został zapisany do pliku 'model.pth'.\n"
]
}
],
"source": [
"MODEL_PATH = \"model.pth\"\n",
"\n",
"# Persist only the learned parameters (state_dict); the architecture must be rebuilt in code before loading.\n",
"torch.save(model.state_dict(), MODEL_PATH)\n",
"print(f\"Model został zapisany do pliku '{MODEL_PATH}'.\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}