inteligentny-traktor/NN/trainer.py

49 lines
1.5 KiB
Python
Raw Normal View History

2023-06-01 11:10:14 +02:00
import pathlib
import random
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Lambda
2023-06-15 14:03:24 +02:00
# Compute device shared by the script: train() sends every batch here and the
# model is moved here with .to(device). Currently pinned to the CPU.
device = torch.device("cpu")
def train(model, dataset, n_iter=100, batch_size=2560000, lr=0.01, log_every=10):
    """Train ``model`` in place on ``dataset`` with SGD and negative log-likelihood loss.

    ``nn.NLLLoss`` expects log-probabilities, so the model should end in
    ``nn.LogSoftmax`` (or an equivalent).

    Args:
        model: module to optimize. Only the batches are moved to the
            module-level ``device``; the model itself must already be there.
        dataset: dataset yielding ``(input, target)`` pairs accepted by ``DataLoader``.
        n_iter: number of epochs (full passes over ``dataset``).
        batch_size: samples per mini-batch. The default is very large, so for
            most datasets each epoch is a single full-batch update.
        lr: SGD learning rate.
        log_every: print the mean epoch loss every ``log_every`` epochs,
            starting with epoch 0. Must be a positive integer.

    Raises:
        ValueError: if ``dataset`` produces no batches.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()
    loader = DataLoader(dataset, batch_size=batch_size)
    model.train()
    for epoch in range(n_iter):
        total_loss = 0.0
        n_batches = 0
        for images, targets in loader:
            optimizer.zero_grad()
            out = model(images.to(device))
            loss = criterion(out, targets.to(device))
            loss.backward()
            optimizer.step()
            # .item() detaches the value so the running sum keeps no autograd graph.
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # Without this guard the print below would hit an unbound `loss`.
            raise ValueError("train() received a dataset that yields no batches")
        if epoch % log_every == 0:
            print('epoch: %3d loss: %.4f' % (epoch, total_loss / n_batches))
# Training images come from ./datasets/<class_name>/... via ImageFolder. Each
# image is resized to 100x100, randomly mirrored for augmentation, converted to
# a tensor and flattened to a 1-D vector. The first Linear layer below expects
# 30000 inputs (3 * 100 * 100), i.e. 3-channel images — TODO confirm the data is RGB.
data_transform = transforms.Compose([
    transforms.Resize(size=(100, 100)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    Lambda(lambda x: x.flatten()),
])
train_data = datasets.ImageFolder(root="./datasets",
                                  transform=data_transform,
                                  target_transform=None)

# Fully connected classifier with 4 outputs (presumably one per class folder
# under ./datasets — verify). The parameterized layers sit at Sequential
# indices 0, 2, 4 and 5; keep that layout so existing checkpoints still load.
model1 = nn.Sequential(
    nn.Linear(30000, 10000), nn.ReLU(),
    nn.Linear(10000, 10000), nn.ReLU(),
    # Must output 10000 features to match the next layer's input size.
    nn.Linear(10000, 10000),
    nn.Linear(10000, 4),
    nn.LogSoftmax(dim=-1),
).to(device)

# Resume from the previous run's weights when a checkpoint exists; otherwise
# start from fresh initialization. map_location lets a checkpoint saved on
# another device load onto the current one.
checkpoint_path = pathlib.Path("./trained")
if checkpoint_path.exists():
    model1.load_state_dict(torch.load(checkpoint_path, map_location=device))

train(model1, train_data)

torch.save(model1.state_dict(), checkpoint_path)