neural_network #4
0
src/machine_learning/decision_tree/__init__.py
Normal file
0
src/machine_learning/decision_tree/__init__.py
Normal file
@ -5,7 +5,7 @@ from collections import Counter
|
||||
|
||||
from anytree import Node, RenderTree
|
||||
|
||||
from machine_learning.data_set import training_set, test_set, attributes as attribs
|
||||
from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs
|
||||
|
||||
|
||||
def calculate_entropy(p: int, n: int) -> float:
|
25
src/machine_learning/neural_network/Dataset.py
Normal file
25
src/machine_learning/neural_network/Dataset.py
Normal file
@ -0,0 +1,25 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.io import read_image
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
|
||||
self.img_labels = pd.read_csv(annotations_file)
|
||||
self.img_dir = img_dir
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_labels)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
|
||||
image = read_image(img_path)
|
||||
label = self.img_labels.iloc[idx, 1]
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
if self.target_transform:
|
||||
label = self.target_transform(label)
|
||||
return image, label
|
0
src/machine_learning/neural_network/__init__.py
Normal file
0
src/machine_learning/neural_network/__init__.py
Normal file
6
src/machine_learning/neural_network/annotations.csv
Normal file
6
src/machine_learning/neural_network/annotations.csv
Normal file
@ -0,0 +1,6 @@
|
||||
mine1.jpg, 0
|
||||
mine2.jpg, 0
|
||||
mine3.jpg. 0
|
||||
other1.jpg, 1
|
||||
other2.jpg, 1
|
||||
other3.jpg, 1
|
|
74
src/machine_learning/neural_network/learning.py
Normal file
74
src/machine_learning/neural_network/learning.py
Normal file
@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from net import NeuralNetwork
|
||||
from database import create_training_data
|
||||
from Dataset import ImageDataset
|
||||
|
||||
|
||||
training_data = create_training_data()
|
||||
# dataset = ImageDataset(annonation_file, img_dir)
|
||||
# trainset, valset = random_split(dataset, [15593, 3898])
|
||||
|
||||
batch_size = 64
|
||||
|
||||
# Create data loaders.
|
||||
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
|
||||
# test_dataloader = DataLoader(test_data, batch_size=batch_size)
|
||||
|
||||
# for X, y in test_dataloader:
|
||||
# print("Shape of X [N, C, H, W]: ", X.shape)
|
||||
# print("Shape of y: ", y.shape, y.dtype)
|
||||
# break
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print("Using {} device".format(device))
|
||||
|
||||
model = NeuralNetwork().to(device)
|
||||
print(model)
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
|
||||
|
||||
def train(dataloader, model, loss_fn, optimizer):
|
||||
size = len(dataloader.dataset)
|
||||
for batch, (X, y) in enumerate(dataloader):
|
||||
X, y = X.to(device), y.to(device)
|
||||
|
||||
# Compute prediction error
|
||||
pred = model(X)
|
||||
loss = loss_fn(pred, y)
|
||||
|
||||
# Backpropagation
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if batch % 100 == 0:
|
||||
loss, current = loss.item(), batch * len(X)
|
||||
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||
|
||||
|
||||
def test(dataloader, model):
|
||||
size = len(dataloader.dataset)
|
||||
model.eval()
|
||||
test_loss, correct = 0, 0
|
||||
with torch.no_grad():
|
||||
for X, y in dataloader:
|
||||
X, y = X.to(device), y.to(device)
|
||||
pred = model(X)
|
||||
test_loss += loss_fn(pred, y).item()
|
||||
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
|
||||
test_loss /= size
|
||||
correct /= size
|
||||
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
||||
|
||||
|
||||
epochs = 5
|
||||
for t in range(epochs):
|
||||
print(f"Epoch {t+1}\n-------------------------------")
|
||||
train(train_dataloader, model, loss_fn, optimizer)
|
||||
test(test_dataloader, model)
|
||||
print("Done!")
|
@ -30,6 +30,7 @@ if __name__ == '__main__':
|
||||
|
||||
for epoch in range(n_iter):
|
||||
for image, label in zip(train_images, train_labels):
|
||||
print(image.shape)
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(image)
|
19
src/machine_learning/neural_network/net.py
Normal file
19
src/machine_learning/neural_network/net.py
Normal file
@ -0,0 +1,19 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class NeuralNetwork(nn.Module):
|
||||
def __init__(self):
|
||||
super(NeuralNetwork, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear_relu_stack = nn.Sequential(
|
||||
nn.Linear(28*28, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 10),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.flatten(x)
|
||||
return self.linear_relu_stack(x)
|
@ -7,8 +7,8 @@ from tilesFactory import TilesFactory
|
||||
from const import ICON, IMAGES, TREE_ROOT
|
||||
from src.search_algoritms.a_star import a_star
|
||||
from search_algoritms.BFS import breadth_first_search
|
||||
from machine_learning.decision_tree import get_decision
|
||||
from machine_learning.helpers import get_dataset_from_tile
|
||||
from machine_learning.decision_tree.decision_tree import get_decision
|
||||
from machine_learning.decision_tree.helpers import get_dataset_from_tile
|
||||
|
||||
|
||||
def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
|
||||
|
@ -4,7 +4,7 @@ from importlib import import_module
|
||||
|
||||
from tile import Tile
|
||||
from const import IMAGES
|
||||
from machine_learning.data_set import visibility, stability, armed
|
||||
from machine_learning.decision_tree.data_set import visibility, stability, armed
|
||||
|
||||
|
||||
class TilesFactory:
|
||||
|
Loading…
Reference in New Issue
Block a user