diff --git a/model/best_model.pth b/model/best_model.pth new file mode 100644 index 0000000..c308654 Binary files /dev/null and b/model/best_model.pth differ diff --git a/model/data/train/bat/0CUH9CM2IF4Z.jpg b/model/data/train/bat/0CUH9CM2IF4Z.jpg new file mode 100644 index 0000000..966ac67 Binary files /dev/null and b/model/data/train/bat/0CUH9CM2IF4Z.jpg differ diff --git a/model/data/train/bear/0BVFXHLSFYEM.jpg b/model/data/train/bear/0BVFXHLSFYEM.jpg new file mode 100644 index 0000000..dc3d9be Binary files /dev/null and b/model/data/train/bear/0BVFXHLSFYEM.jpg differ diff --git a/model/data/train/elephant/Elephant_0.jpg b/model/data/train/elephant/Elephant_0.jpg new file mode 100644 index 0000000..cc40bd2 Binary files /dev/null and b/model/data/train/elephant/Elephant_0.jpg differ diff --git a/model/data/train/giraffe/Giraffe_0.jpg b/model/data/train/giraffe/Giraffe_0.jpg new file mode 100644 index 0000000..c2a0024 Binary files /dev/null and b/model/data/train/giraffe/Giraffe_0.jpg differ diff --git a/model/data/train/owl/0AL42OEB782C.jpg b/model/data/train/owl/0AL42OEB782C.jpg new file mode 100644 index 0000000..42b4a28 Binary files /dev/null and b/model/data/train/owl/0AL42OEB782C.jpg differ diff --git a/model/data/train/parrot/Parrot_Download_train_0.jpg b/model/data/train/parrot/Parrot_Download_train_0.jpg new file mode 100644 index 0000000..7f0dd8d Binary files /dev/null and b/model/data/train/parrot/Parrot_Download_train_0.jpg differ diff --git a/model/data/train/penguin/Penguin_Download_1_train_0.jpg b/model/data/train/penguin/Penguin_Download_1_train_0.jpg new file mode 100644 index 0000000..857e00d Binary files /dev/null and b/model/data/train/penguin/Penguin_Download_1_train_0.jpg differ diff --git a/model/data/val/bat/0BN7W0OQC1M1.jpg b/model/data/val/bat/0BN7W0OQC1M1.jpg new file mode 100644 index 0000000..7b3c883 Binary files /dev/null and b/model/data/val/bat/0BN7W0OQC1M1.jpg differ diff --git a/model/data/val/bear/0I31CLWNAVJV.jpg b/model/data/val/bear/0I31CLWNAVJV.jpg new file mode 100644 index 0000000..a1fb299 Binary files /dev/null and b/model/data/val/bear/0I31CLWNAVJV.jpg differ diff --git a/model/data/val/elephant/Elephant_10.jpg b/model/data/val/elephant/Elephant_10.jpg new file mode 100644 index 0000000..b02232b Binary files /dev/null and b/model/data/val/elephant/Elephant_10.jpg differ diff --git a/model/data/val/giraffe/Giraffe_2.jpg b/model/data/val/giraffe/Giraffe_2.jpg new file mode 100644 index 0000000..0a05324 Binary files /dev/null and b/model/data/val/giraffe/Giraffe_2.jpg differ diff --git a/model/data/val/owl/0AD8FZAIF9BZ.jpg b/model/data/val/owl/0AD8FZAIF9BZ.jpg new file mode 100644 index 0000000..4b10f95 Binary files /dev/null and b/model/data/val/owl/0AD8FZAIF9BZ.jpg differ diff --git a/model/data/val/parrot/Parrot_Download_train_9.jpg b/model/data/val/parrot/Parrot_Download_train_9.jpg new file mode 100644 index 0000000..42fdbb3 Binary files /dev/null and b/model/data/val/parrot/Parrot_Download_train_9.jpg differ diff --git a/model/data/val/penguin/Penguin_Download_1_train_4.jpg b/model/data/val/penguin/Penguin_Download_1_train_4.jpg new file mode 100644 index 0000000..2001ddd Binary files /dev/null and b/model/data/val/penguin/Penguin_Download_1_train_4.jpg differ diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..b36904a --- /dev/null +++ b/model/model.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision.datasets +from torchvision import datasets, transforms, models +from torch.utils.data import DataLoader + + +def set_device(): + if torch.cuda.is_available(): + device = 'cuda' + else: + device = 'cpu' + return torch.device(device) + + +train_dataset_path = './data/train' +test_dataset_path = './data/val' +number_of_classes = 7 + +SIZE = 224 +mean = [0.5164, 0.5147, 0.4746] +std = [0.2180, 0.2126, 0.2172] + +train_transforms = transforms.Compose([ + transforms.Resize((SIZE, SIZE)), + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(10), + transforms.ToTensor(), + transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)) +]) + +test_transforms = transforms.Compose([ + transforms.Resize((SIZE, SIZE)), + transforms.ToTensor(), + transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)) +]) + +train_dataset = torchvision.datasets.ImageFolder(root=train_dataset_path, transform=train_transforms) +test_dataset = torchvision.datasets.ImageFolder(root=test_dataset_path, transform=test_transforms) + +train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) +test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False) + +resnet18_model = models.resnet18(weights=None) +num_ftrs = resnet18_model.fc.in_features +resnet18_model.fc = nn.Linear(num_ftrs, number_of_classes) +device = set_device() +resnet18_model = resnet18_model.to(device) +loss_fn = nn.CrossEntropyLoss() + +optimizer = optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.003) + + +def save_checkpoint(model, epoch, optimizer, best_acc): + state = { + 'epoch': epoch + 1, + 'model': model.state_dict(), + 'best accuracy': best_acc, + 'optimizer': optimizer.state_dict() + } + torch.save(state, 'model_best_checkpoint.pth.tar') +def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs): + device = set_device() + best_acc = 0 + + for epoch in range(n_epochs): + print("Epoch number %d " % (epoch + 1)) + model.train() + running_loss = 0.0 + running_correct = 0.0 + total = 0 + + for data in train_loader: + images, labels = data + images = images.to(device) + labels = labels.to(device) + total += labels.size(0) + + optimizer.zero_grad() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_correct += (labels == predicted).sum().item() + + epoch_loss = running_loss/len(train_loader) + epoch_acc = 100 * running_correct / total + print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}") + + test_data_acc = evaluate_model_on_test_set(model, test_loader) + + if test_data_acc > best_acc: + best_acc = test_data_acc + save_checkpoint(model, epoch, optimizer, best_acc) + + print("Finished") + return model +def evaluate_model_on_test_set(model, test_loader): + model.eval() + predicted_correctly_on_epoch = 0 + total = 0 + device = set_device() + + with torch.no_grad(): + for data in test_loader: + images, labels = data + images = images.to(device) + labels = labels.to(device) + total += labels.size(0) + + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + + predicted_correctly_on_epoch += (predicted == labels).sum().item() + + epoch_acc = 100 * predicted_correctly_on_epoch / total + print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})") + return epoch_acc + + +train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, n_epochs=30) + +checkpoint = torch.load('model_best_checkpoint.pth.tar') +resnet18_model.load_state_dict(checkpoint['model']) +torch.save(resnet18_model, 'best_model.pth') \ No newline at end of file diff --git a/model/model_best_checkpoint.pth.tar b/model/model_best_checkpoint.pth.tar new file mode 100644 index 0000000..e00de45 Binary files /dev/null and b/model/model_best_checkpoint.pth.tar differ diff --git a/tree.png b/tree.png index ba0161b..d608b6e 100644 Binary files a/tree.png and b/tree.png differ