import pathlib
import random

import torch
from PIL import Image  # Pillow; used below to open individual image files
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Lambda

# Use the GPU when available, but fall back to CPU so the script still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model, dataset, n_iter=100, batch_size=2560000):
    """Train `model` on `dataset` for `n_iter` epochs with plain SGD."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.NLLLoss()  # pairs with the LogSoftmax output layer to give cross-entropy
    # NOTE: batch_size is so large that the whole dataset normally fits in a single batch.
    dl = DataLoader(dataset, batch_size=batch_size)
    model.train()
    for epoch in range(n_iter):
        for images, targets in dl:
            optimizer.zero_grad()
            out = model(images.to(device))
            loss = criterion(out, targets.to(device))
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            # .item() pulls the scalar loss back to the CPU for printing
            print('epoch: %3d loss: %.4f' % (epoch, loss.item()))

# Gather image paths two directory levels below the working directory
# (e.g. datasets/<class_name>/<image>.png).
image_path_list = list(pathlib.Path('./').glob("*/*/*.png"))

# Pick one image at random; it is only used for the quick sanity check below.
random_image_path = random.choice(image_path_list)

# Preprocessing pipeline: resize to 100x100, random horizontal flip for augmentation,
# convert to a tensor, then flatten to a 1-D vector of 3 * 100 * 100 = 30000 values.
data_transform = transforms.Compose([
    transforms.Resize(size=(100, 100)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    Lambda(lambda x: x.flatten())
])
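
# Quick sanity check (illustrative sketch): open the randomly chosen image and confirm the
# transform yields the flat 30000-element vector the first Linear layer below expects.
# ImageFolder loads images as RGB, so convert here as well.
with Image.open(random_image_path) as img:
    sample = data_transform(img.convert("RGB"))
print('sample %s -> transformed shape %s' % (random_image_path, tuple(sample.shape)))  # expect (30000,)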

# ImageFolder expects ./datasets/<class_name>/<image>.png and labels each image by its folder name.
train_data = datasets.ImageFolder(root="./datasets",
                                  transform=data_transform,
                                  target_transform=None)

# Fully connected classifier: 30000 flattened pixel values in, log-probabilities over 4 classes out.
model1 = nn.Sequential(
    nn.Linear(30000, 10000), nn.ReLU(),
    nn.Linear(10000, 10000), nn.ReLU(),
    nn.Linear(10000, 10000),
    nn.Linear(10000, 4),
    nn.LogSoftmax(dim=-1),
).to(device)

# Resume from the previously saved weights, continue training, then save the result again.
model1.load_state_dict(torch.load("./trained", map_location=device))
train(model1, train_data)

torch.save(model1.state_dict(), "./trained")
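
# Illustrative usage sketch (assumed workflow, not spelled out above): classify a single
# image with the freshly trained weights. Class names come from the ImageFolder
# subdirectory names; data_transform is reused for simplicity, even though its random
# flip is only meant as training-time augmentation.
model1.eval()
with torch.no_grad():
    with Image.open(random_image_path) as img:
        x = data_transform(img.convert("RGB")).unsqueeze(0).to(device)  # shape (1, 30000)
    predicted_idx = model1(x).argmax(dim=-1).item()
print('%s -> predicted class: %s' % (random_image_path, train_data.classes[predicted_idx]))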