diff --git a/survival/learning_utils.py b/survival/learning_utils.py index a7cbebf..21edf0b 100644 --- a/survival/learning_utils.py +++ b/survival/learning_utils.py @@ -14,6 +14,7 @@ class LearningUtils: self.plot_mean_scores = [] self.total_score = 0 self.last_actions: [Action, [int, int]] = [] + self.plots = 0 def add_scores(self, learning: LearningComponent, games_count: int): self.plot_scores.append(learning.score) @@ -25,14 +26,16 @@ class LearningUtils: display.clear_output(wait=True) display.display(plt.gcf()) plt.clf() - plt.title('Training...') + plt.title('Results') plt.xlabel('Number of Games') plt.ylabel('Score') plt.plot(self.plot_scores) - # plt.plot(self.plot_mean_scores) + plt.plot(self.plot_mean_scores) plt.ylim(ymin=0) plt.text(len(self.plot_scores) - 1, self.plot_scores[-1], str(self.plot_scores[-1])) - # plt.text(len(self.plot_mean_scores) - 1, self.plot_mean_scores[-1], str(self.plot_mean_scores[-1])) + plt.text(len(self.plot_mean_scores) - 1, self.plot_mean_scores[-1], str(self.plot_mean_scores[-1])) + self.plots += 1 + plt.savefig(f'model/plots/{self.plots}.png') plt.show(block=False) plt.pause(.1)