diff --git a/neural_network/datasets.py b/neural_network/datasets.py index 33c7f162..166268c5 100644 --- a/neural_network/datasets.py +++ b/neural_network/datasets.py @@ -29,9 +29,9 @@ valid_transform = transforms.Compose([ ) ]) -train_dataset = torchvision.datasets.ImageFolder(root='./Vegetable Images/train', transform=train_transform) +train_dataset = torchvision.datasets.ImageFolder(root='./images/train', transform=train_transform) -validation_dataset = torchvision.datasets.ImageFolder(root='./Vegetable Images/validation', transform=valid_transform) +validation_dataset = torchvision.datasets.ImageFolder(root='./images/validation', transform=valid_transform) train_loader = DataLoader( train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True