rm pyplot

This commit is contained in:
Alagris 2021-04-26 08:07:56 +02:00
parent b1aa99772c
commit 5a1550c3fc

View File

@ -10,7 +10,6 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import re
import random
import os
@ -92,7 +91,6 @@ def dist(a: [str], b: [str]):
def train_model(model):
plt.ion()
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
lr=LEARNING_RATE)
loss_snapshots = []
@ -132,10 +130,6 @@ def train_model(model):
total_loss += loss_scalar
inner_bar.set_description("loss %.2f" % loss_scalar)
loss_snapshots.append(total_loss / len(DATA) * 3)
plt.clf()
plt.plot(loss_snapshots, label="Avg loss ")
plt.legend(loc="upper left")
plt.pause(interval=0.01)
# print()
# print("Total epoch loss:", total_loss)
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
@ -145,7 +139,6 @@ def train_model(model):
# print("Evaluation snapshots(%):", eval_snapshots_percentage)
outer_bar.set_description("Epochs")
outer_bar.update(1)
plt.ioff()
def evaluate_monte_carlo(model, repeats):