created dataset, module trained 95/90 accuracy
BIN
model/best_model.pth
Normal file
BIN
model/data/train/bat/0CUH9CM2IF4Z.jpg
Normal file
After Width: | Height: | Size: 28 KiB |
BIN
model/data/train/bear/0BVFXHLSFYEM.jpg
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
model/data/train/elephant/Elephant_0.jpg
Normal file
After Width: | Height: | Size: 173 KiB |
BIN
model/data/train/giraffe/Giraffe_0.jpg
Normal file
After Width: | Height: | Size: 107 KiB |
BIN
model/data/train/owl/0AL42OEB782C.jpg
Normal file
After Width: | Height: | Size: 14 KiB |
BIN
model/data/train/parrot/Parrot_Download_train_0.jpg
Normal file
After Width: | Height: | Size: 19 KiB |
BIN
model/data/train/penguin/Penguin_Download_1_train_0.jpg
Normal file
After Width: | Height: | Size: 38 KiB |
BIN
model/data/val/bat/0BN7W0OQC1M1.jpg
Normal file
After Width: | Height: | Size: 53 KiB |
BIN
model/data/val/bear/0I31CLWNAVJV.jpg
Normal file
After Width: | Height: | Size: 77 KiB |
BIN
model/data/val/elephant/Elephant_10.jpg
Normal file
After Width: | Height: | Size: 173 KiB |
BIN
model/data/val/giraffe/Giraffe_2.jpg
Normal file
After Width: | Height: | Size: 158 KiB |
BIN
model/data/val/owl/0AD8FZAIF9BZ.jpg
Normal file
After Width: | Height: | Size: 126 KiB |
BIN
model/data/val/parrot/Parrot_Download_train_9.jpg
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
model/data/val/penguin/Penguin_Download_1_train_4.jpg
Normal file
After Width: | Height: | Size: 30 KiB |
129
model/model.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torchvision.datasets
|
||||||
|
from torchvision import datasets, transforms, models
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
def set_device():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = 'cuda'
|
||||||
|
else:
|
||||||
|
device = 'cpu'
|
||||||
|
return torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
|
train_dataset_path = './data/train'
|
||||||
|
test_dataset_path = './data/val'
|
||||||
|
number_of_classes = 7
|
||||||
|
|
||||||
|
SIZE = 224
|
||||||
|
mean = [0.5164, 0.5147, 0.4746]
|
||||||
|
std = [0.2180, 0.2126, 0.2172]
|
||||||
|
|
||||||
|
train_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((SIZE, SIZE)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.RandomRotation(10),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
|
||||||
|
])
|
||||||
|
|
||||||
|
test_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((SIZE, SIZE)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
|
||||||
|
])
|
||||||
|
|
||||||
|
train_dataset = torchvision.datasets.ImageFolder(root=train_dataset_path, transform=train_transforms)
|
||||||
|
test_dataset = torchvision.datasets.ImageFolder(root=test_dataset_path, transform=test_transforms)
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
|
||||||
|
|
||||||
|
resnet18_model = models.resnet18(weights=None)
|
||||||
|
num_ftrs = resnet18_model.fc.in_features
|
||||||
|
resnet18_model.fc = nn.Linear(num_ftrs, number_of_classes)
|
||||||
|
device = set_device()
|
||||||
|
resnet18_model = resnet18_model.to(device)
|
||||||
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.003)
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(model, epoch, optimizer, best_acc):
|
||||||
|
state = {
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'model': model.state_dict(),
|
||||||
|
'best accuracy': best_acc,
|
||||||
|
'optimizer': optimizer.state_dict()
|
||||||
|
}
|
||||||
|
torch.save(state, 'model_best_checkpoint.pth.tar')
|
||||||
|
def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
|
||||||
|
device = set_device()
|
||||||
|
best_acc = 0
|
||||||
|
|
||||||
|
for epoch in range(n_epochs):
|
||||||
|
print("Epoch number %d " % (epoch + 1))
|
||||||
|
model.train()
|
||||||
|
running_loss = 0.0
|
||||||
|
running_correct = 0.0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
for data in train_loader:
|
||||||
|
images, labels = data
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
total += labels.size(0)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(images)
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
running_correct += (labels == predicted).sum().item()
|
||||||
|
|
||||||
|
epoch_loss = running_loss/len(train_loader)
|
||||||
|
epoch_acc = 100 * running_correct / total
|
||||||
|
print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}")
|
||||||
|
|
||||||
|
test_data_acc = evaluate_model_on_test_set(model, test_loader)
|
||||||
|
|
||||||
|
if test_data_acc > best_acc:
|
||||||
|
best_acc = test_data_acc
|
||||||
|
save_checkpoint(model, epoch, optimizer, best_acc)
|
||||||
|
|
||||||
|
print("Finished")
|
||||||
|
return model
|
||||||
|
def evaluate_model_on_test_set(model, test_loader):
|
||||||
|
model.eval()
|
||||||
|
predicted_correctly_on_epoch = 0
|
||||||
|
total = 0
|
||||||
|
device = set_device()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in test_loader:
|
||||||
|
images, labels = data
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
total += labels.size(0)
|
||||||
|
|
||||||
|
outputs = model(images)
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
|
||||||
|
predicted_correctly_on_epoch += (predicted == labels).sum().item()
|
||||||
|
|
||||||
|
epoch_acc = 100 * predicted_correctly_on_epoch / total
|
||||||
|
print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})")
|
||||||
|
return epoch_acc
|
||||||
|
|
||||||
|
|
||||||
|
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, n_epochs=30)
|
||||||
|
|
||||||
|
checkpoint = torch.load('model_best_checkpoint.pth.tar')
|
||||||
|
resnet18_model.load_state_dict(checkpoint['model'])
|
||||||
|
torch.save(resnet18_model, 'best_model.pth')
|