"""Train a fully connected image classifier on an ImageFolder dataset.

Weights are resumed from ``./trained`` when that file exists and are saved
back to it after training.
"""
import pathlib

import torch
from torch import nn
from torch.utils.data import DataLoader

# Everything runs on the CPU; switch to torch.device("cuda") to use a GPU.
device = torch.device('cpu')

# File the model weights are loaded from (if present) and saved back to.
CHECKPOINT_PATH = pathlib.Path("./trained")
# Root of the ImageFolder-style dataset (one sub-directory per class).
DATASET_ROOT = "./datasets"
# Images are resized to IMAGE_SIZE x IMAGE_SIZE; with 3 colour channels this
# yields 3 * 100 * 100 = 30000 features, matching the first Linear layer.
IMAGE_SIZE = 100
NUM_CLASSES = 4


def train(model, dataset, n_iter=100, batch_size=2560000, *, lr=0.01, shuffle=True):
    """Train *model* on *dataset* with plain SGD and negative log-likelihood loss.

    Args:
        model: module mapping a batch of inputs to log-probabilities (e.g.
            ending in ``nn.LogSoftmax``), as ``nn.NLLLoss`` expects.
        dataset: map-style dataset yielding ``(input, target)`` pairs.
        n_iter: number of epochs (full passes over *dataset*).
        batch_size: mini-batch size. The default is larger than most datasets,
            so each epoch is then a single full-batch gradient step.
        lr: SGD learning rate.
        shuffle: reshuffle samples every epoch (only matters when
            ``batch_size`` is smaller than the dataset).

    Returns:
        The mean per-sample loss of the final epoch (``nan`` if ``n_iter`` is 0).

    Raises:
        ValueError: if *dataset* yields no samples.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    model.train()
    epoch_loss = float("nan")
    for epoch in range(n_iter):
        total_loss = 0.0
        n_samples = 0
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            # NLLLoss averages over the batch; re-weight by batch size so the
            # epoch figure is a true per-sample mean even with a short last batch.
            total_loss += loss.item() * targets.size(0)
            n_samples += targets.size(0)
        if n_samples == 0:
            raise ValueError("dataset is empty")
        epoch_loss = total_loss / n_samples
        if epoch % 10 == 0:
            print(f"epoch: {epoch:3d} loss: {epoch_loss:.4f}")
    return epoch_loss


def build_model(num_classes=NUM_CLASSES, hidden=10000, image_size=IMAGE_SIZE):
    """Return the MLP classifier for flattened 3-channel images, moved to ``device``.

    With the defaults this has roughly 5e8 parameters (about 2 GB as float32).
    """
    in_features = 3 * image_size * image_size
    return nn.Sequential(
        nn.Linear(in_features, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        # The original listed ``nn.Linear(10000, 0000)`` here: a zero-width
        # layer whose output could not feed the next Linear(10000, ...), so
        # every forward pass failed. Presumably a typo for 10000; a ReLU is
        # added because stacked Linear layers without one collapse into one.
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, num_classes),
        nn.LogSoftmax(dim=-1),
    ).to(device)


def load_dataset(root=DATASET_ROOT, image_size=IMAGE_SIZE):
    """Return an ImageFolder dataset of resized, randomly flipped, flattened images."""
    # torchvision is only needed to read images; importing it here keeps
    # ``train`` and ``build_model`` importable without it.
    from torchvision import datasets, transforms

    data_transform = transforms.Compose([
        transforms.Resize(size=(image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        # The model is fully connected, so each (C, H, W) image becomes a vector.
        transforms.Lambda(lambda x: x.flatten()),
    ])
    return datasets.ImageFolder(root=root, transform=data_transform, target_transform=None)


def main():
    """Load the data, resume from the checkpoint if present, train, and save."""
    train_data = load_dataset()
    model = build_model()
    # On the very first run there is no checkpoint yet; the original
    # unconditional torch.load crashed in that case.
    # NOTE(review): load_state_dict is strict, so a checkpoint saved from a
    # different layer layout will fail to load.
    if CHECKPOINT_PATH.exists():
        # torch.load unpickles the file -- only load checkpoints you trust.
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    train(model, train_data)
    torch.save(model.state_dict(), CHECKPOINT_PATH)


if __name__ == "__main__":
    main()