rm pyplot

This commit is contained in:
Alagris 2021-04-26 08:07:56 +02:00
parent b1aa99772c
commit 5a1550c3fc

View File

@ -10,7 +10,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import re import re
import random import random
import os import os
@ -92,7 +91,6 @@ def dist(a: [str], b: [str]):
def train_model(model): def train_model(model):
plt.ion()
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()), optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
lr=LEARNING_RATE) lr=LEARNING_RATE)
loss_snapshots = [] loss_snapshots = []
@ -132,10 +130,6 @@ def train_model(model):
total_loss += loss_scalar total_loss += loss_scalar
inner_bar.set_description("loss %.2f" % loss_scalar) inner_bar.set_description("loss %.2f" % loss_scalar)
loss_snapshots.append(total_loss / len(DATA) * 3) loss_snapshots.append(total_loss / len(DATA) * 3)
plt.clf()
plt.plot(loss_snapshots, label="Avg loss ")
plt.legend(loc="upper left")
plt.pause(interval=0.01)
# print() # print()
# print("Total epoch loss:", total_loss) # print("Total epoch loss:", total_loss)
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN) # print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
@ -145,7 +139,6 @@ def train_model(model):
# print("Evaluation snapshots(%):", eval_snapshots_percentage) # print("Evaluation snapshots(%):", eval_snapshots_percentage)
outer_bar.set_description("Epochs") outer_bar.set_description("Epochs")
outer_bar.update(1) outer_bar.update(1)
plt.ioff()
def evaluate_monte_carlo(model, repeats): def evaluate_monte_carlo(model, repeats):