36 lines
1002 B
Python
36 lines
1002 B
Python
from torchvision import datasets
|
|
from torch.utils.data import DataLoader
|
|
import torch.nn as nn
|
|
import torch
|
|
|
|
from settings import MODEL_FILENAME, get_balancing_sampler, RAW_DATASET_PATH, OurCNN
|
|
from preprocess import TRANSFORM
|
|
|
|
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
) -> float:
    """Run one optimisation pass over *dataloader* and return the epoch's mean loss.

    The returned value is the per-sample average over the whole epoch: each
    batch's loss is weighted by its batch size. This differs from reporting the
    loss of whichever batch happened to come last.

    Args:
        model: Network to train; it is switched to training mode before the pass.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        criterion: Loss module. Assumed to return the *mean* loss over the batch
            (the default ``reduction="mean"`` of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer bound to ``model``'s parameters.

    Returns:
        Sample-weighted mean of the batch losses for this epoch.

    Raises:
        ValueError: If *dataloader* yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    for images, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Re-weight the batch mean by the batch size so that a short final
        # batch does not skew the epoch average.
        n_samples = labels.size(0)
        total_loss += loss.item() * n_samples
        total_samples += n_samples

    if total_samples == 0:
        raise ValueError("dataloader yielded no batches; is the dataset empty?")
    return total_loss / total_samples


if __name__ == "__main__":
    # Training hyper-parameters.
    batch_size = 32
    num_workers = 16
    learning_rate = 0.001
    num_epochs = 10

    # NOTE(review): plain string concatenation assumes RAW_DATASET_PATH ends
    # with a path separator (e.g. "data/") -- confirm in settings.
    dataset = datasets.ImageFolder(root=RAW_DATASET_PATH + "train", transform=TRANSFORM)

    # Presumably the sampler re-balances classes -- verify in settings.
    # Because a sampler is supplied, DataLoader's `shuffle` must stay at its
    # default (False); PyTorch rejects sampler + shuffle=True.
    sampler = get_balancing_sampler(dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
    )

    model = OurCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        # Report the mean loss over the whole epoch, not just the last batch.
        epoch_loss = train_one_epoch(model, dataloader, criterion, optimizer)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), MODEL_FILENAME)