rm pyplot
This commit is contained in:
parent
b1aa99772c
commit
5a1550c3fc
@ -10,7 +10,6 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import re
|
import re
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
@ -92,7 +91,6 @@ def dist(a: [str], b: [str]):
|
|||||||
|
|
||||||
|
|
||||||
def train_model(model):
|
def train_model(model):
|
||||||
plt.ion()
|
|
||||||
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
|
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
|
||||||
lr=LEARNING_RATE)
|
lr=LEARNING_RATE)
|
||||||
loss_snapshots = []
|
loss_snapshots = []
|
||||||
@ -132,10 +130,6 @@ def train_model(model):
|
|||||||
total_loss += loss_scalar
|
total_loss += loss_scalar
|
||||||
inner_bar.set_description("loss %.2f" % loss_scalar)
|
inner_bar.set_description("loss %.2f" % loss_scalar)
|
||||||
loss_snapshots.append(total_loss / len(DATA) * 3)
|
loss_snapshots.append(total_loss / len(DATA) * 3)
|
||||||
plt.clf()
|
|
||||||
plt.plot(loss_snapshots, label="Avg loss ")
|
|
||||||
plt.legend(loc="upper left")
|
|
||||||
plt.pause(interval=0.01)
|
|
||||||
# print()
|
# print()
|
||||||
# print("Total epoch loss:", total_loss)
|
# print("Total epoch loss:", total_loss)
|
||||||
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
|
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
|
||||||
@ -145,7 +139,6 @@ def train_model(model):
|
|||||||
# print("Evaluation snapshots(%):", eval_snapshots_percentage)
|
# print("Evaluation snapshots(%):", eval_snapshots_percentage)
|
||||||
outer_bar.set_description("Epochs")
|
outer_bar.set_description("Epochs")
|
||||||
outer_bar.update(1)
|
outer_bar.update(1)
|
||||||
plt.ioff()
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_monte_carlo(model, repeats):
|
def evaluate_monte_carlo(model, repeats):
|
||||||
|
Loading…
Reference in New Issue
Block a user