si23traktor/neural_network/nn.py

112 lines
3.9 KiB
Python
Raw Permalink Normal View History

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import multiprocessing
def main(
    data_dir="neural_network/dataset/vegetables",
    save_path="neural_network/save/trained_model.pth",
    num_epochs=2,
):
    """Fine-tune a pre-trained ResNet-18 on an image-folder dataset.

    Images are read from ``{data_dir}/train`` and ``{data_dir}/validation``
    with ``torchvision.datasets.ImageFolder``. The final fully connected layer
    of the network is replaced so that it has one output per class found in
    the training folder. Whenever the validation accuracy improves, the model
    ``state_dict`` is written to ``save_path``.

    Args:
        data_dir: Directory holding the ``train`` and ``validation`` folders.
        save_path: File that receives the best ``state_dict``.
        num_epochs: Number of passes over the training and validation sets.
    """
    # Function-local import so this block does not depend on the file header.
    import os

    phases = ("train", "validation")

    # Use the GPU if available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Per-channel mean / std used for input normalization (the values
    # torchvision documents for its pre-trained classification models).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Training uses random crops and flips for augmentation; validation uses
    # a deterministic resize + center crop.
    data_transforms = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]),
        "validation": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]),
    }

    # Load both splits from their sub-folders.
    image_datasets = {
        phase: datasets.ImageFolder(f"{data_dir}/{phase}", data_transforms[phase])
        for phase in phases
    }
    dataloaders = {
        phase: torch.utils.data.DataLoader(
            image_datasets[phase],
            batch_size=4,
            # Only the training split benefits from shuffling; evaluation
            # metrics do not depend on batch order.
            shuffle=(phase == "train"),
            num_workers=multiprocessing.cpu_count(),
        )
        for phase in phases
    }
    dataset_sizes = {phase: len(image_datasets[phase]) for phase in phases}

    class_names = image_datasets["train"].classes
    print(class_names)
    num_classes = len(class_names)
    print(num_classes)

    # Load a pre-trained ResNet-18 and swap its classification head for one
    # sized to this dataset.
    # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept as-is so older installs still work.
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    model = model.to(device)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # To resume from an earlier run, load the saved state before training:
    #   model.load_state_dict(torch.load(save_path))

    # Create the checkpoint directory up front so a missing folder is
    # detected before any training time is spent.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def train_model(model, criterion, optimizer, num_epochs=2):
        """Run the train / validation loop and checkpoint the best model.

        Args:
            model: Network to optimize; already moved to ``device``.
            criterion: Loss callable applied to ``(outputs, labels)``.
            optimizer: Optimizer stepping ``model``'s parameters.
            num_epochs: Number of epochs to run.
        """
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")
            print("-" * 10)

            for phase in phases:
                is_train = phase == "train"
                # train(True) enables training mode, train(False) is eval().
                model.train(is_train)

                running_loss = 0.0
                running_corrects = 0

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    # Gradients are only tracked while training.
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    # The loss is a batch mean, so weight it by the batch size
                    # to get a per-sample average at the end of the epoch.
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]
                print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

                if phase == "validation" and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    # state_dict() refers to the live parameters, which keep
                    # changing in later epochs, so persist the snapshot right
                    # away instead of holding a reference in memory.
                    torch.save(model.state_dict(), save_path)

    # Start training.
    train_model(model, criterion, optimizer, num_epochs=num_epochs)
if __name__ == "__main__":
    # Select the "spawn" start method for multiprocessing before main()
    # runs; main() builds DataLoaders with num_workers > 0.
    multiprocessing.set_start_method("spawn")
    main()