neural_network #4
@ -1,68 +1,22 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import glob
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torchvision.transforms import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torch.autograd import Variable
|
||||
import torchvision
|
||||
import pathlib
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from net import Net
|
||||
from machine_learning.neural_network.net import Net
|
||||
from machine_learning.neural_network.helpers import main_path, train_path, test_path, prediction_path, transformer
|
||||
|
||||
|
||||
temp_path = os.path.abspath('../../..')
|
||||
DIR = ''
|
||||
train_dir = r'images\learning\training'
|
||||
test_dir = r'images\learning\test'
|
||||
|
||||
train_dir = os.path.join(temp_path, train_dir)
|
||||
test_dir = os.path.join(temp_path, test_dir)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
transformer = transforms.Compose([
|
||||
transforms.Resize((150,150)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5,0.5,0.5],
|
||||
[0.5,0.5,0.5])
|
||||
])
|
||||
|
||||
train_path = r'C:\Users\User\PycharmProjects\Super-Saper222\images\learning\training\training'
|
||||
test_path = r'C:\Users\User\PycharmProjects\Super-Saper222\images\learning\test\test'
|
||||
pred_path = r'C:\Users\User\PycharmProjects\Super-Saper222\images\learning\prediction\prediction'
|
||||
classes = ['mine', 'rock']
|
||||
|
||||
|
||||
train_loader = DataLoader(
|
||||
torchvision.datasets.ImageFolder(train_path, transform=transformer),
|
||||
batch_size=64, shuffle=True
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
torchvision.datasets.ImageFolder(test_path, transform=transformer),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
|
||||
root=pathlib.Path(train_path)
|
||||
classes = sorted([j.name.split('/')[-1] for j in root.iterdir()])
|
||||
|
||||
model = Net(num_classes=6).to(device)
|
||||
|
||||
optimizer = Adam(model.parameters(),lr=1e-3,weight_decay=0.0001)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
num_epochs = 10
|
||||
|
||||
train_count = len(glob.glob(train_path+'/**/*.*'))
|
||||
test_count = len(glob.glob(test_path+'/**/*.*'))
|
||||
|
||||
print(train_count,test_count)
|
||||
|
||||
best_accuracy = 0.0
|
||||
|
||||
def train(dataloader, model, loss_fn, optimizer):
|
||||
def train(dataloader, model: Net, optimizer: Adam, loss_fn: nn.CrossEntropyLoss):
|
||||
size = len(dataloader.dataset)
|
||||
for batch, (X, y) in enumerate(dataloader):
|
||||
X, y = X.to(device), y.to(device)
|
||||
@ -74,12 +28,12 @@ def train(dataloader, model, loss_fn, optimizer):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if batch % 100 == 0:
|
||||
if batch % 5 == 0:
|
||||
loss, current = loss.item(), batch * len(X)
|
||||
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||
|
||||
|
||||
def test(dataloader, model):
|
||||
def test(dataloader, model: Net, loss_fn: nn.CrossEntropyLoss):
|
||||
size = len(dataloader.dataset)
|
||||
model.eval()
|
||||
test_loss, correct = 0, 0
|
||||
@ -96,7 +50,7 @@ def test(dataloader, model):
|
||||
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
||||
|
||||
|
||||
def prediction1(classes, img_path, model, transformer):
|
||||
def prediction(img_path, model: Net):
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
image_tensor = transformer(image).float()
|
||||
image_tensor = image_tensor.unsqueeze_(0)
|
||||
@ -110,41 +64,45 @@ def prediction1(classes, img_path, model, transformer):
|
||||
pred = classes[index]
|
||||
return pred
|
||||
|
||||
transformer1 = transforms.Compose([transforms.Resize((150, 150)),
|
||||
transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
||||
|
||||
#creating new model
|
||||
def test_prediction_set():
|
||||
checkpoint = torch.load(f'{main_path}/mine_recognizer.model')
|
||||
model = Net(num_classes=2)
|
||||
model.load_state_dict(checkpoint)
|
||||
model.eval()
|
||||
|
||||
# for t in range(9):
|
||||
# print(f"Epoch {t+1}\n-------------------------------")
|
||||
# train(train_loader, model, loss_fn, optimizer)
|
||||
# test(test_loader, model)
|
||||
# print("Done!")
|
||||
# torch.save(model.state_dict(), 'mine_recognizer.model')
|
||||
pred_dict = {}
|
||||
|
||||
for file in os.listdir(prediction_path):
|
||||
pred_dict[file[file.rfind('/') + 1:]] = prediction(f'{prediction_path}/{file}', model)
|
||||
|
||||
print(pred_dict)
|
||||
|
||||
|
||||
#checking work of new model
|
||||
def main():
|
||||
num_epochs = 50
|
||||
|
||||
checkpoint = torch.load(os.path.join('.', 'mine_recognizer.model'))
|
||||
model = Net(num_classes=6)
|
||||
model.load_state_dict(checkpoint)
|
||||
model.eval()
|
||||
train_loader = DataLoader(
|
||||
torchvision.datasets.ImageFolder(train_path, transform=transformer),
|
||||
batch_size=64, shuffle=True
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
torchvision.datasets.ImageFolder(test_path, transform=transformer),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
|
||||
transformer1 = transforms.Compose([transforms.Resize((150, 150)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
||||
images_path = glob.glob(pred_path+'/*.*')
|
||||
pred_dict = {}
|
||||
model = Net(2).to(device)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
for i in images_path:
|
||||
pred_dict[i[i.rfind('/') + 1:]] = prediction1(classes, i, model, transformer1)
|
||||
print(pred_dict)
|
||||
for t in range(num_epochs):
|
||||
print(f"Epoch {t + 1}\n-------------------------------")
|
||||
train(train_loader, model, optimizer, loss_fn)
|
||||
test(test_loader, model, loss_fn)
|
||||
print("Done!")
|
||||
torch.save(model.state_dict(), f'{main_path}/mine_recognizer.model')
|
||||
test_prediction_set()
|
||||
|
||||
model = Net(num_classes=6)
|
||||
model.load_state_dict(checkpoint)
|
||||
model.eval()
|
||||
|
||||
images_path = glob.glob(pred_path + '/*.*')
|
||||
pred_dict = {}
|
||||
for i in images_path:
|
||||
pred_dict[i[i.rfind('/') + 1:]] = prediction1(classes, i, model, transformer1)
|
||||
print(pred_dict)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -4,39 +4,21 @@ import torch.nn as nn
|
||||
class Net(nn.Module):
|
||||
def __init__(self, num_classes=6):
|
||||
super(Net, self).__init__()
|
||||
|
||||
# Output size after convolution filter
|
||||
# ((w-f+2P)/s) +1
|
||||
|
||||
# Input shape= (256,3,150,150)
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
|
||||
# Shape= (256,12,150,150)
|
||||
self.bn1 = nn.BatchNorm2d(num_features=12)
|
||||
# Shape= (256,12,150,150)
|
||||
self.relu1 = nn.ReLU()
|
||||
# Shape= (256,12,150,150)
|
||||
|
||||
self.pool = nn.MaxPool2d(kernel_size=2)
|
||||
# Reduce the image size be factor 2
|
||||
# Shape= (256,12,75,75)
|
||||
|
||||
self.conv2 = nn.Conv2d(in_channels=12, out_channels=20, kernel_size=3, stride=1, padding=1)
|
||||
# Shape= (256,20,75,75)
|
||||
self.relu2 = nn.ReLU()
|
||||
# Shape= (256,20,75,75)
|
||||
|
||||
self.conv3 = nn.Conv2d(in_channels=20, out_channels=32, kernel_size=3, stride=1, padding=1)
|
||||
# Shape= (256,32,75,75)
|
||||
self.bn3 = nn.BatchNorm2d(num_features=32)
|
||||
# Shape= (256,32,75,75)
|
||||
self.relu3 = nn.ReLU()
|
||||
# Shape= (256,32,75,75)
|
||||
|
||||
self.fc = nn.Linear(in_features=75 * 75 * 32, out_features=num_classes)
|
||||
|
||||
# Feed forwad function
|
||||
|
||||
def forward(self, input):
|
||||
output = self.conv1(input)
|
||||
output = self.bn1(output)
|
||||
@ -51,10 +33,7 @@ class Net(nn.Module):
|
||||
output = self.bn3(output)
|
||||
output = self.relu3(output)
|
||||
|
||||
# Above output will be in matrix form, with shape (256,32,75,75)
|
||||
|
||||
output = output.view(-1, 32 * 75 * 75)
|
||||
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
Loading…
Reference in New Issue
Block a user