poprawki
This commit is contained in:
parent
b6ba817d55
commit
8d73a85707
@ -13,44 +13,6 @@ from pytorch_lightning.callbacks import EarlyStopping
|
|||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
def train(model):
|
|
||||||
model = model.to(DEVICE)
|
|
||||||
model.train()
|
|
||||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
|
||||||
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
|
||||||
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
|
||||||
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True)
|
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
|
||||||
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
|
|
||||||
|
|
||||||
for epoch in range(NUM_EPOCHS):
|
|
||||||
for batch_idx, (data, targets) in enumerate(train_loader):
|
|
||||||
data = data.to(device=DEVICE)
|
|
||||||
targets = targets.to(device=DEVICE)
|
|
||||||
|
|
||||||
scores = model(data)
|
|
||||||
loss = criterion(scores, targets)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
if batch_idx % 4 == 0:
|
|
||||||
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
|
|
||||||
|
|
||||||
print("FINISHED TRAINING!")
|
|
||||||
torch.save(model.state_dict(), "./learnednetwork.pth")
|
|
||||||
|
|
||||||
print("Checking accuracy for the train set.")
|
|
||||||
check_accuracy(train_loader)
|
|
||||||
print("Checking accuracy for the test set.")
|
|
||||||
check_accuracy(test_loader)
|
|
||||||
print("Checking accuracy for the tiles.")
|
|
||||||
check_accuracy_tiles()
|
|
||||||
|
|
||||||
|
|
||||||
def check_accuracy_tiles():
|
def check_accuracy_tiles():
|
||||||
answer = 0
|
answer = 0
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
|
Loading…
Reference in New Issue
Block a user