{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch train model\n",
    "\n",
    "Trening prostej sieci regresyjnej (MLP), która przewiduje `Volume USD` na podstawie kolumn `date`, `hour` i `Volume BTC` z pliku `../data/btc_train.csv`. Wytrenowane wagi są zapisywane do `model.pth`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wczytanie niezbędnych bibliotek"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, TensorDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Konfiguracja\n",
    "Wszystkie parametry, które warto stroić, w jednym miejscu."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to tune lives in this cell.\n",
    "DATA_PATH = '../data/btc_train.csv'\n",
    "MODEL_PATH = 'model.pth'\n",
    "FEATURE_COLUMNS = ['date', 'hour', 'Volume BTC']\n",
    "TARGET_COLUMN = 'Volume USD'\n",
    "VOLUME_BTC_DIVISOR = 10  # 'Volume BTC' is divided by this before training (the author's original scaling)\n",
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 1e-3\n",
    "EPOCHS = 10\n",
    "SEED = 42\n",
    "\n",
    "# Seed torch so weight initialisation is reproducible across runs.\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wczytanie danych z pliku"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_raw = pd.read_csv(DATA_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Przygotowanie danych\n",
    "Powinienem był zrobić to w zadaniu 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the 'date' and 'hour' columns (presumably strings) as integer labels. A separate encoder\n",
    "# per column keeps each fitted mapping available, e.g. for .transform() on other data.\n",
    "date_encoder = LabelEncoder()\n",
    "hour_encoder = LabelEncoder()\n",
    "\n",
    "data_clean = data_raw.copy()\n",
    "data_clean['date'] = date_encoder.fit_transform(data_clean['date'])\n",
    "data_clean['hour'] = hour_encoder.fit_transform(data_clean['hour'])\n",
    "\n",
    "# Convert every column to numbers to prevent 'TypeError: can't convert np.ndarray of type numpy.object_.'\n",
    "# in torch.tensor(); values that cannot be parsed become NaN.\n",
    "data_clean = data_clean.apply(pd.to_numeric, errors='coerce')\n",
    "\n",
    "# Scale only after the numeric coercion, so a text-typed 'Volume BTC' column cannot raise a TypeError here.\n",
    "data_clean['Volume BTC'] = data_clean['Volume BTC'] / VOLUME_BTC_DIVISOR\n",
    "\n",
    "# Replace missing values with 0: a single NaN in a feature or the target would make the MSE loss NaN.\n",
    "data_clean = data_clean.fillna(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Przygotowanie inputs oraz targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the prepared data to PyTorch tensors.\n",
    "inputs = torch.tensor(data_clean[FEATURE_COLUMNS].values, dtype=torch.float32)\n",
    "# The target stays float32 with shape (N, 1) to match the model output: MSELoss needs float targets, not torch.long.\n",
    "targets = torch.tensor(data_clean[TARGET_COLUMN].values, dtype=torch.float32).view(-1, 1)\n",
    "\n",
    "assert len(inputs) > 0, 'Brak danych treningowych'\n",
    "assert inputs.shape == (len(data_clean), len(FEATURE_COLUMNS))\n",
    "assert torch.isfinite(inputs).all() and torch.isfinite(targets).all(), 'W danych treningowych pozostały wartości NaN/inf'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utwórz DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_set = TensorDataset(inputs, targets)\n",
    "# Shuffle each epoch: the rows are presumably ordered by date, so unshuffled batches would feed SGD\n",
    "# strongly correlated samples. A seeded generator keeps the shuffle order reproducible.\n",
    "data_loader = DataLoader(\n",
    "    data_set,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    generator=torch.Generator().manual_seed(SEED),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The layer order fixes the state_dict keys ('1.weight', '1.bias', '3.weight', '3.bias'), so it is kept\n",
    "# unchanged for any code that loads model.pth into the same architecture. nn.Flatten is a no-op for the\n",
    "# 2-D (batch, features) input and owns no parameters, but removing it would renumber those keys.\n",
    "model = nn.Sequential(\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(inputs.shape[1], 64),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(64, 1),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Funkcja straty i optymalizator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = nn.MSELoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trenowanie modelu\n",
    "Dane nie są normalizowane, a `Volume USD` może przyjmować duże wartości, więc przy SGD strata może się rozbiec (NaN/inf). Dlatego po każdej epoce wypisywana jest średnia strata, a trening jest przerywany, gdy przestanie ona być skończona — zamiast zapisywać bezużyteczne wagi."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train()\n",
    "for epoch in range(EPOCHS):\n",
    "    epoch_loss = 0.0\n",
    "    for X, y in data_loader:\n",
    "        pred = model(X)\n",
    "        loss = loss_fn(pred, y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Weight each batch by its size so a smaller final batch does not skew the epoch mean.\n",
    "        epoch_loss += loss.item() * len(X)\n",
    "\n",
    "    epoch_loss /= len(data_set)\n",
    "    print(f'Epoka {epoch + 1}/{EPOCHS}: średnia strata (MSE) = {epoch_loss:.4e}')\n",
    "\n",
    "    # A non-finite loss means training has diverged; stop instead of reporting success.\n",
    "    if not math.isfinite(epoch_loss):\n",
    "        raise RuntimeError(\n",
    "            'Strata przestała być skończona (NaN/inf) - trening się rozbiegł. '\n",
    "            'Znormalizuj dane lub zmniejsz LEARNING_RATE.'\n",
    "        )\n",
    "\n",
    "print(\"Model został wytrenowany.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zapis modelu do pliku"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), MODEL_PATH)\n",
    "print(f\"Model został zapisany do pliku '{MODEL_PATH}'.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}