rm pyplot
This commit is contained in:
parent
b1aa99772c
commit
5a1550c3fc
@ -10,7 +10,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import matplotlib.pyplot as plt
|
||||
import re
|
||||
import random
|
||||
import os
|
||||
@ -92,7 +91,6 @@ def dist(a: [str], b: [str]):
|
||||
|
||||
|
||||
def train_model(model):
|
||||
plt.ion()
|
||||
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
|
||||
lr=LEARNING_RATE)
|
||||
loss_snapshots = []
|
||||
@ -132,10 +130,6 @@ def train_model(model):
|
||||
total_loss += loss_scalar
|
||||
inner_bar.set_description("loss %.2f" % loss_scalar)
|
||||
loss_snapshots.append(total_loss / len(DATA) * 3)
|
||||
plt.clf()
|
||||
plt.plot(loss_snapshots, label="Avg loss ")
|
||||
plt.legend(loc="upper left")
|
||||
plt.pause(interval=0.01)
|
||||
# print()
|
||||
# print("Total epoch loss:", total_loss)
|
||||
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
|
||||
@ -145,7 +139,6 @@ def train_model(model):
|
||||
# print("Evaluation snapshots(%):", eval_snapshots_percentage)
|
||||
outer_bar.set_description("Epochs")
|
||||
outer_bar.update(1)
|
||||
plt.ioff()
|
||||
|
||||
|
||||
def evaluate_monte_carlo(model, repeats):
|
||||
|
Loading…
Reference in New Issue
Block a user