Dodanie 'neural_network.py'
This commit is contained in:
parent
4a1a3293fc
commit
042e993137
103
neural_network.py
Normal file
103
neural_network.py
Normal file
@ -0,0 +1,103 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as f
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from matplotlib.pyplot import imshow
|
||||
import os
|
||||
import PIL
|
||||
import numpy as np
|
||||
from matplotlib.pyplot import imshow
|
||||
|
||||
def to_negative(img):
|
||||
img = PIL.ImageOps.invert(img)
|
||||
return img
|
||||
|
||||
class Negative(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, img):
|
||||
return to_negative(img)
|
||||
|
||||
def plotdigit(image):
|
||||
img = np.reshape(image, (-1, 100))
|
||||
imshow(img, cmap='Greys')
|
||||
|
||||
transform = transforms.Compose([Negative(), transforms.ToTensor()])
|
||||
train_set = torchvision.datasets.ImageFolder(root='train', transform=transform)
|
||||
classes = ("apple", "potato")
|
||||
|
||||
BATCH_SIZE = 2
|
||||
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear_relu_stack = nn.Sequential(
|
||||
nn.Linear(3*100*100, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 2),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.linear_relu_stack = self.linear_relu_stack.to(device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.flatten(x).to(device)
|
||||
logits = self.linear_relu_stack(x).to(device)
|
||||
return logits
|
||||
|
||||
def training_network():
|
||||
net = Net()
|
||||
net = net.to(device)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
||||
|
||||
for epoch in range(4):
|
||||
running_loss = 0.0
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
inputs, labels = data[0].to(device), data[1].to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = net(inputs.to(device))
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
if i % 2000 == 1999:
|
||||
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss))
|
||||
running_loss = 0.0
|
||||
|
||||
print("Finished training")
|
||||
save_network_to_file(net)
|
||||
|
||||
|
||||
def result_from_network(net, loaded_image):
|
||||
image = PIL.Image.open(loaded_image)
|
||||
pil_to_tensor = transforms.ToTensor()(image.convert("RGB")).unsqueeze_(0)
|
||||
outputs = net(pil_to_tensor.to(device))
|
||||
|
||||
return classes[torch.max(outputs, 1)[1]]
|
||||
|
||||
|
||||
def save_network_to_file(network):
|
||||
torch.save(network.state_dict(), 'network_model.pth')
|
||||
print("Network saved to file")
|
||||
|
||||
|
||||
def load_network_from_structure(network):
|
||||
network.load_state_dict(torch.load('network_model.pth'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(torch.cuda.is_available())
|
||||
training_network()
|
||||
|
Loading…
Reference in New Issue
Block a user