Laboratoria 4. Sieci neuronowe
This commit is contained in:
parent
e9444c4cc8
commit
c3106e188f
219
lab/4_Sieci_neuronowe_PyTorch.ipynb
Normal file
219
lab/4_Sieci_neuronowe_PyTorch.ipynb
Normal file
@ -0,0 +1,219 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"slideshow": {
|
||||
"slide_type": "-"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"### AITech — Uczenie maszynowe — laboratoria\n",
|
||||
"# 4. Sieci neuronowe (PyTorch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Przykład implementacji sieci neuronowej do rozpoznawania cyfr ze zbioru MNIST, według https://github.com/pytorch/examples/tree/master/mnist"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torchvision import datasets, transforms\n",
|
||||
"from torch.optim.lr_scheduler import StepLR\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Net(nn.Module):\n",
|
||||
" \"\"\"W PyTorchu tworzenie sieci neuronowej\n",
|
||||
" polega na zdefiniowaniu klasy, która dziedziczy z nn.Module.\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" \n",
|
||||
" # Warstwy splotowe\n",
|
||||
" self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
|
||||
" self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
|
||||
" \n",
|
||||
" # Warstwy dropout\n",
|
||||
" self.dropout1 = nn.Dropout(0.25)\n",
|
||||
" self.dropout2 = nn.Dropout(0.5)\n",
|
||||
" \n",
|
||||
" # Warstwy liniowe\n",
|
||||
" self.fc1 = nn.Linear(9216, 128)\n",
|
||||
" self.fc2 = nn.Linear(128, 10)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" \"\"\"Definiujemy przechodzenie \"do przodu\" jako kolejne przekształcenia wejścia x\"\"\"\n",
|
||||
" x = self.conv1(x)\n",
|
||||
" x = F.relu(x)\n",
|
||||
" x = self.conv2(x)\n",
|
||||
" x = F.relu(x)\n",
|
||||
" x = F.max_pool2d(x, 2)\n",
|
||||
" x = self.dropout1(x)\n",
|
||||
" x = torch.flatten(x, 1)\n",
|
||||
" x = self.fc1(x)\n",
|
||||
" x = F.relu(x)\n",
|
||||
" x = self.dropout2(x)\n",
|
||||
" x = self.fc2(x)\n",
|
||||
" output = F.log_softmax(x, dim=1)\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def train(model, device, train_loader, optimizer, epoch, log_interval, dry_run):\n",
|
||||
" \"\"\"Uczenie modelu\"\"\"\n",
|
||||
" model.train()\n",
|
||||
" for batch_idx, (data, target) in enumerate(train_loader):\n",
|
||||
" data, target = data.to(device), target.to(device) # wrzucenie danych na kartę graficzną (jeśli dotyczy)\n",
|
||||
" optimizer.zero_grad() # wyzerowanie gradientu\n",
|
||||
" output = model(data) # przejście \"do przodu\"\n",
|
||||
" loss = F.nll_loss(output, target) # obliczenie funkcji kosztu\n",
|
||||
" loss.backward() # propagacja wsteczna\n",
|
||||
" optimizer.step() # krok optymalizatora\n",
|
||||
" if batch_idx % log_interval == 0:\n",
|
||||
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
|
||||
" epoch, batch_idx * len(data), len(train_loader.dataset),\n",
|
||||
" 100. * batch_idx / len(train_loader), loss.item()))\n",
|
||||
" if dry_run:\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def test(model, device, test_loader):\n",
|
||||
" \"\"\"Testowanie modelu\"\"\"\n",
|
||||
" model.eval()\n",
|
||||
" test_loss = 0\n",
|
||||
" correct = 0\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" data, target = data.to(device), target.to(device) # wrzucenie danych na kartę graficzną (jeśli dotyczy)\n",
|
||||
" output = model(data) # przejście \"do przodu\"\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction='sum').item() # suma kosztów z każdego batcha\n",
|
||||
" pred = output.argmax(dim=1, keepdim=True) # predykcja na podstawie maks. logarytmu prawdopodobieństwa\n",
|
||||
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
|
||||
"\n",
|
||||
" test_loss /= len(test_loader.dataset) # obliczenie kosztu na zbiorze testowym\n",
|
||||
"\n",
|
||||
" print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
|
||||
" test_loss, correct, len(test_loader.dataset),\n",
|
||||
" 100. * correct / len(test_loader.dataset)))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def run(\n",
|
||||
" batch_size=64,\n",
|
||||
" test_batch_size=1000,\n",
|
||||
" epochs=14,\n",
|
||||
" lr=1.0,\n",
|
||||
" gamma=0.7,\n",
|
||||
" no_cuda=False,\n",
|
||||
" dry_run=False,\n",
|
||||
" seed=1,\n",
|
||||
" log_interval=10,\n",
|
||||
" save_model=False,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Main training function.\n",
|
||||
" \n",
|
||||
" Arguments:\n",
|
||||
" batch_size - wielkość batcha podczas uczenia (default: 64),\n",
|
||||
" test_batch_size - wielkość batcha podczas testowania (default: 1000)\n",
|
||||
" epochs - liczba epok uczenia (default: 14)\n",
|
||||
" lr - współczynnik uczenia (learning rate) (default: 1.0)\n",
|
||||
" gamma - współczynnik gamma (dla optymalizatora) (default: 0.7)\n",
|
||||
" no_cuda - wyłącza uczenie na karcie graficznej (default: False)\n",
|
||||
" dry_run - szybko (\"na sucho\") sprawdza pojedyncze przejście (default: False)\n",
|
||||
" seed - ziarno generatora liczb pseudolosowych (default: 1)\n",
|
||||
" log_interval - interwał logowania stanu uczenia (default: 10)\n",
|
||||
" save_model - zapisuje bieżący model (default: False)\n",
|
||||
" \"\"\"\n",
|
||||
" use_cuda = no_cuda and torch.cuda.is_available()\n",
|
||||
"\n",
|
||||
" torch.manual_seed(seed)\n",
|
||||
"\n",
|
||||
" device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
|
||||
"\n",
|
||||
" train_kwargs = {'batch_size': batch_size}\n",
|
||||
" test_kwargs = {'batch_size': test_batch_size}\n",
|
||||
" if use_cuda:\n",
|
||||
" cuda_kwargs = {'num_workers': 1,\n",
|
||||
" 'pin_memory': True,\n",
|
||||
" 'shuffle': True}\n",
|
||||
" train_kwargs.update(cuda_kwargs)\n",
|
||||
" test_kwargs.update(cuda_kwargs)\n",
|
||||
"\n",
|
||||
" transform=transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
" dataset1 = datasets.MNIST('../data', train=True, download=True,\n",
|
||||
" transform=transform)\n",
|
||||
" dataset2 = datasets.MNIST('../data', train=False,\n",
|
||||
" transform=transform)\n",
|
||||
" train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)\n",
|
||||
" test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n",
|
||||
"\n",
|
||||
" model = Net().to(device)\n",
|
||||
" optimizer = optim.Adadelta(model.parameters(), lr=lr)\n",
|
||||
"\n",
|
||||
" scheduler = StepLR(optimizer, step_size=1, gamma=gamma)\n",
|
||||
" for epoch in range(1, epochs + 1):\n",
|
||||
" train(model, device, train_loader, optimizer, epoch, log_interval, dry_run)\n",
|
||||
" test(model, device, test_loader)\n",
|
||||
" scheduler.step()\n",
|
||||
"\n",
|
||||
" if save_model:\n",
|
||||
" torch.save(model.state_dict(), \"mnist_cnn.pt\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Uwaga**: uruchomienie tego przykładu długo trwa. Żeby trwało krócej, można zmniejszyć liczbę epok."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"run(epochs=5)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"celltoolbar": "Slideshow",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
},
|
||||
"livereveal": {
|
||||
"start_slideshow_at": "selected",
|
||||
"theme": "amu"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
Loading…
Reference in New Issue
Block a user