25 KiB
25 KiB
AITech — Uczenie maszynowe — laboratoria
4. Sieci neuronowe (PyTorch)
Przykład implementacji sieci neuronowej do rozpoznawania cyfr ze zbioru MNIST, według https://github.com/pytorch/examples/tree/master/mnist
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
class Net(nn.Module):
    """Convolutional network for MNIST digit classification.

    In PyTorch a neural network is defined as a class that
    inherits from nn.Module.
    """

    def __init__(self):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout layers
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Fully connected layers
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Forward pass: apply the layers to the input x in sequence."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        hidden = self.dropout1(F.max_pool2d(hidden, 2))
        hidden = torch.flatten(hidden, 1)
        hidden = self.dropout2(F.relu(self.fc1(hidden)))
        # Log-probabilities over the 10 digit classes
        return F.log_softmax(self.fc2(hidden), dim=1)
def train(model, device, train_loader, optimizer, epoch, log_interval, dry_run):
    """Run one training epoch over train_loader.

    Logs the loss every log_interval batches; with dry_run=True the
    epoch stops right after the first logged batch.
    """
    model.train()
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Move the batch to the target device (GPU if applicable)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()                  # reset accumulated gradients
        log_probs = model(inputs)              # forward pass
        loss = F.nll_loss(log_probs, labels)   # compute the loss
        loss.backward()                        # backpropagation
        optimizer.step()                       # optimizer step
        if batch_idx % log_interval == 0:
            seen = batch_idx * len(inputs)
            total = len(train_loader.dataset)
            percent = 100. * batch_idx / len(train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total, percent, loss.item()))
            if dry_run:
                break
def test(model, device, test_loader):
"""Testowanie modelu"""
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device) # wrzucenie danych na kartę graficzną (jeśli dotyczy)
output = model(data) # przejście "do przodu"
test_loss += F.nll_loss(output, target, reduction='sum').item() # suma kosztów z każdego batcha
pred = output.argmax(dim=1, keepdim=True) # predykcja na podstawie maks. logarytmu prawdopodobieństwa
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset) # obliczenie kosztu na zbiorze testowym
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def run(
    batch_size=64,
    test_batch_size=1000,
    epochs=14,
    lr=1.0,
    gamma=0.7,
    no_cuda=False,
    dry_run=False,
    seed=1,
    log_interval=10,
    save_model=False,
):
    """Main training function.

    Arguments:
        batch_size - training batch size (default: 64)
        test_batch_size - test batch size (default: 1000)
        epochs - number of training epochs (default: 14)
        lr - learning rate (default: 1.0)
        gamma - decay factor for the LR scheduler (default: 0.7)
        no_cuda - disables training on the GPU (default: False)
        dry_run - quickly check a single pass (default: False)
        seed - random number generator seed (default: 1)
        log_interval - how often to log training status (default: 10)
        save_model - save the trained model to disk (default: False)
    """
    # BUG FIX: the original read `no_cuda and torch.cuda.is_available()`,
    # which requested CUDA only when the caller asked to DISABLE it and
    # never otherwise. Use the GPU when available unless explicitly off.
    use_cuda = not no_cuda and torch.cuda.is_available()
    torch.manual_seed(seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    train_kwargs = {'batch_size': batch_size}
    test_kwargs = {'batch_size': test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)
    # Normalization constants are the standard MNIST mean/std
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    # Multiply the learning rate by `gamma` after every epoch
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval, dry_run)
        test(model, device, test_loader)
        scheduler.step()
    if save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
Uwaga: uruchomienie tego przykładu długo trwa. Żeby trwało krócej, można zmniejszyć liczbę epok.
# Train for only 5 epochs (instead of the default 14) to keep runtime short.
run(epochs=5)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
[1;31m---------------------------------------------------------------------------[0m [1;31mHTTPError[0m Traceback (most recent call last) [1;32m<ipython-input-3-ac97a8e9c0c6>[0m in [0;36m<module>[1;34m[0m [1;32m----> 1[1;33m [0mrun[0m[1;33m([0m[0mepochs[0m[1;33m=[0m[1;36m5[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m [1;32m<ipython-input-1-173aee26f027>[0m in [0;36mrun[1;34m(batch_size, test_batch_size, epochs, lr, gamma, no_cuda, dry_run, seed, log_interval, save_model)[0m [0;32m 127[0m [0mtransforms[0m[1;33m.[0m[0mNormalize[0m[1;33m([0m[1;33m([0m[1;36m0.1307[0m[1;33m,[0m[1;33m)[0m[1;33m,[0m [1;33m([0m[1;36m0.3081[0m[1;33m,[0m[1;33m)[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 128[0m ]) [1;32m--> 129[1;33m dataset1 = datasets.MNIST('../data', train=True, download=True, [0m[0;32m 130[0m transform=transform) [0;32m 131[0m dataset2 = datasets.MNIST('../data', train=False, [1;32m~\anaconda3\lib\site-packages\torchvision\datasets\mnist.py[0m in [0;36m__init__[1;34m(self, root, train, transform, target_transform, download)[0m [0;32m 77[0m [1;33m[0m[0m [0;32m 78[0m [1;32mif[0m [0mdownload[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m---> 79[1;33m [0mself[0m[1;33m.[0m[0mdownload[0m[1;33m([0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 80[0m [1;33m[0m[0m [0;32m 81[0m [1;32mif[0m [1;32mnot[0m [0mself[0m[1;33m.[0m[0m_check_exists[0m[1;33m([0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\torchvision\datasets\mnist.py[0m in [0;36mdownload[1;34m(self)[0m [0;32m 144[0m [1;32mfor[0m [0murl[0m[1;33m,[0m [0mmd5[0m [1;32min[0m [0mself[0m[1;33m.[0m[0mresources[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 145[0m [0mfilename[0m [1;33m=[0m [0murl[0m[1;33m.[0m[0mrpartition[0m[1;33m([0m[1;34m'/'[0m[1;33m)[0m[1;33m[[0m[1;36m2[0m[1;33m][0m[1;33m[0m[1;33m[0m[0m [1;32m--> 146[1;33m [0mdownload_and_extract_archive[0m[1;33m([0m[0murl[0m[1;33m,[0m [0mdownload_root[0m[1;33m=[0m[0mself[0m[1;33m.[0m[0mraw_folder[0m[1;33m,[0m [0mfilename[0m[1;33m=[0m[0mfilename[0m[1;33m,[0m 
[0mmd5[0m[1;33m=[0m[0mmd5[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 147[0m [1;33m[0m[0m [0;32m 148[0m [1;31m# process and save as torch files[0m[1;33m[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\torchvision\datasets\utils.py[0m in [0;36mdownload_and_extract_archive[1;34m(url, download_root, extract_root, filename, md5, remove_finished)[0m [0;32m 254[0m [0mfilename[0m [1;33m=[0m [0mos[0m[1;33m.[0m[0mpath[0m[1;33m.[0m[0mbasename[0m[1;33m([0m[0murl[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 255[0m [1;33m[0m[0m [1;32m--> 256[1;33m [0mdownload_url[0m[1;33m([0m[0murl[0m[1;33m,[0m [0mdownload_root[0m[1;33m,[0m [0mfilename[0m[1;33m,[0m [0mmd5[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 257[0m [1;33m[0m[0m [0;32m 258[0m [0marchive[0m [1;33m=[0m [0mos[0m[1;33m.[0m[0mpath[0m[1;33m.[0m[0mjoin[0m[1;33m([0m[0mdownload_root[0m[1;33m,[0m [0mfilename[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\torchvision\datasets\utils.py[0m in [0;36mdownload_url[1;34m(url, root, filename, md5)[0m [0;32m 82[0m ) [0;32m 83[0m [1;32melse[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m---> 84[1;33m [1;32mraise[0m [0me[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 85[0m [1;31m# check integrity of downloaded file[0m[1;33m[0m[1;33m[0m[1;33m[0m[0m [0;32m 86[0m [1;32mif[0m [1;32mnot[0m [0mcheck_integrity[0m[1;33m([0m[0mfpath[0m[1;33m,[0m [0mmd5[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\torchvision\datasets\utils.py[0m in [0;36mdownload_url[1;34m(url, root, filename, md5)[0m [0;32m 68[0m [1;32mtry[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 69[0m [0mprint[0m[1;33m([0m[1;34m'Downloading '[0m [1;33m+[0m [0murl[0m [1;33m+[0m [1;34m' to '[0m [1;33m+[0m [0mfpath[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m---> 70[1;33m urllib.request.urlretrieve( [0m[0;32m 71[0m [0murl[0m[1;33m,[0m [0mfpath[0m[1;33m,[0m[1;33m[0m[1;33m[0m[0m [0;32m 72[0m [0mreporthook[0m[1;33m=[0m[0mgen_bar_updater[0m[1;33m([0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m 
[1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36murlretrieve[1;34m(url, filename, reporthook, data)[0m [0;32m 245[0m [0murl_type[0m[1;33m,[0m [0mpath[0m [1;33m=[0m [0m_splittype[0m[1;33m([0m[0murl[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 246[0m [1;33m[0m[0m [1;32m--> 247[1;33m [1;32mwith[0m [0mcontextlib[0m[1;33m.[0m[0mclosing[0m[1;33m([0m[0murlopen[0m[1;33m([0m[0murl[0m[1;33m,[0m [0mdata[0m[1;33m)[0m[1;33m)[0m [1;32mas[0m [0mfp[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 248[0m [0mheaders[0m [1;33m=[0m [0mfp[0m[1;33m.[0m[0minfo[0m[1;33m([0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 249[0m [1;33m[0m[0m [1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36murlopen[1;34m(url, data, timeout, cafile, capath, cadefault, context)[0m [0;32m 220[0m [1;32melse[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 221[0m [0mopener[0m [1;33m=[0m [0m_opener[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 222[1;33m [1;32mreturn[0m [0mopener[0m[1;33m.[0m[0mopen[0m[1;33m([0m[0murl[0m[1;33m,[0m [0mdata[0m[1;33m,[0m [0mtimeout[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 223[0m [1;33m[0m[0m [0;32m 224[0m [1;32mdef[0m [0minstall_opener[0m[1;33m([0m[0mopener[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36mopen[1;34m(self, fullurl, data, timeout)[0m [0;32m 529[0m [1;32mfor[0m [0mprocessor[0m [1;32min[0m [0mself[0m[1;33m.[0m[0mprocess_response[0m[1;33m.[0m[0mget[0m[1;33m([0m[0mprotocol[0m[1;33m,[0m [1;33m[[0m[1;33m][0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 530[0m [0mmeth[0m [1;33m=[0m [0mgetattr[0m[1;33m([0m[0mprocessor[0m[1;33m,[0m [0mmeth_name[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 531[1;33m [0mresponse[0m [1;33m=[0m [0mmeth[0m[1;33m([0m[0mreq[0m[1;33m,[0m [0mresponse[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 532[0m [1;33m[0m[0m [0;32m 533[0m [1;32mreturn[0m [0mresponse[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36mhttp_response[1;34m(self, request, response)[0m [0;32m 638[0m [1;31m# request was successfully received, 
understood, and accepted.[0m[1;33m[0m[1;33m[0m[1;33m[0m[0m [0;32m 639[0m [1;32mif[0m [1;32mnot[0m [1;33m([0m[1;36m200[0m [1;33m<=[0m [0mcode[0m [1;33m<[0m [1;36m300[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 640[1;33m response = self.parent.error( [0m[0;32m 641[0m 'http', request, response, code, msg, hdrs) [0;32m 642[0m [1;33m[0m[0m [1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36merror[1;34m(self, proto, *args)[0m [0;32m 567[0m [1;32mif[0m [0mhttp_err[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 568[0m [0margs[0m [1;33m=[0m [1;33m([0m[0mdict[0m[1;33m,[0m [1;34m'default'[0m[1;33m,[0m [1;34m'http_error_default'[0m[1;33m)[0m [1;33m+[0m [0morig_args[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 569[1;33m [1;32mreturn[0m [0mself[0m[1;33m.[0m[0m_call_chain[0m[1;33m([0m[1;33m*[0m[0margs[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 570[0m [1;33m[0m[0m [0;32m 571[0m [1;31m# XXX probably also want an abstract factory that knows when it makes[0m[1;33m[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36m_call_chain[1;34m(self, chain, kind, meth_name, *args)[0m [0;32m 500[0m [1;32mfor[0m [0mhandler[0m [1;32min[0m [0mhandlers[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 501[0m [0mfunc[0m [1;33m=[0m [0mgetattr[0m[1;33m([0m[0mhandler[0m[1;33m,[0m [0mmeth_name[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 502[1;33m [0mresult[0m [1;33m=[0m [0mfunc[0m[1;33m([0m[1;33m*[0m[0margs[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 503[0m [1;32mif[0m [0mresult[0m [1;32mis[0m [1;32mnot[0m [1;32mNone[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 504[0m [1;32mreturn[0m [0mresult[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\urllib\request.py[0m in [0;36mhttp_error_default[1;34m(self, req, fp, code, msg, hdrs)[0m [0;32m 647[0m [1;32mclass[0m [0mHTTPDefaultErrorHandler[0m[1;33m([0m[0mBaseHandler[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 648[0m [1;32mdef[0m [0mhttp_error_default[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mreq[0m[1;33m,[0m [0mfp[0m[1;33m,[0m [0mcode[0m[1;33m,[0m 
[0mmsg[0m[1;33m,[0m [0mhdrs[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 649[1;33m [1;32mraise[0m [0mHTTPError[0m[1;33m([0m[0mreq[0m[1;33m.[0m[0mfull_url[0m[1;33m,[0m [0mcode[0m[1;33m,[0m [0mmsg[0m[1;33m,[0m [0mhdrs[0m[1;33m,[0m [0mfp[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 650[0m [1;33m[0m[0m [0;32m 651[0m [1;32mclass[0m [0mHTTPRedirectHandler[0m[1;33m([0m[0mBaseHandler[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;31mHTTPError[0m: HTTP Error 503: Service Unavailable