forked from s464965/WMICraft
poprawki
This commit is contained in:
parent
b6ba817d55
commit
8d73a85707
@ -13,44 +13,6 @@ from pytorch_lightning.callbacks import EarlyStopping
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
|
||||
def train(model):
|
||||
model = model.to(DEVICE)
|
||||
model.train()
|
||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
for batch_idx, (data, targets) in enumerate(train_loader):
|
||||
data = data.to(device=DEVICE)
|
||||
targets = targets.to(device=DEVICE)
|
||||
|
||||
scores = model(data)
|
||||
loss = criterion(scores, targets)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if batch_idx % 4 == 0:
|
||||
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
|
||||
|
||||
print("FINISHED TRAINING!")
|
||||
torch.save(model.state_dict(), "./learnednetwork.pth")
|
||||
|
||||
print("Checking accuracy for the train set.")
|
||||
check_accuracy(train_loader)
|
||||
print("Checking accuracy for the test set.")
|
||||
check_accuracy(test_loader)
|
||||
print("Checking accuracy for the tiles.")
|
||||
check_accuracy_tiles()
|
||||
|
||||
|
||||
def check_accuracy_tiles():
|
||||
answer = 0
|
||||
for i in range(100):
|
||||
|
Loading…
Reference in New Issue
Block a user