file namechange

- changed name of files to represent inside
- added main to seperate processes
- changed raw code to functions
This commit is contained in:
Lewy 2021-06-15 01:36:18 +02:00
parent f993a577f4
commit 328b171e5a
4 changed files with 70 additions and 59 deletions

64
AI/NN_accuracy.py Normal file
View File

@ -0,0 +1,64 @@
import PIL
import torchvision.transforms as transforms
from AI import neural_network
# wcześniej grinder.py
# Get accuracy for neural_network model 'network_model.pth'
def NN_accuracy():
# Create the model
model = neural_network.Net()
# Load state_dict
neural_network.load_network_from_structure(model)
# Create the preprocessing transformation here
transform = transforms.Compose([neural_network.Negative(), transforms.ToTensor()])
# load your image(s)
img = PIL.Image.open('../src/test/0_100.jpg')
img2 = PIL.Image.open('../src/test/1_100.jpg')
img3 = PIL.Image.open('../src/test/4_100.jpg')
img4 = PIL.Image.open('../src/test/5_100.jpg')
# Transform
input = transform(img)
input2 = transform(img2)
input3 = transform(img3)
input4 = transform(img4)
# unsqueeze batch dimension, in case you are dealing with a single image
input = input.unsqueeze(0)
input2 = input2.unsqueeze(0)
input3 = input3.unsqueeze(0)
input4 = input4.unsqueeze(0)
# Set model to eval
model.eval()
# Get prediction
output = model(input)
output2 = model(input2)
output3 = model(input3)
output4 = model(input4)
print(output)
index = output.cpu().data.numpy().argmax()
print(index)
print(output2)
index = output2.cpu().data.numpy().argmax()
print(index)
print(output3)
index = output3.cpu().data.numpy().argmax()
print(index)
print(output4)
index = output4.cpu().data.numpy().argmax()
print(index)
if __name__ == "__main__":
NN_accuracy()

View File

@ -1,56 +0,0 @@
import PIL
import torchvision.transforms as transforms
from AI import neural_network
# Create the model
model = neural_network.Net()
# Load state_dict
neural_network.load_network_from_structure(model)
# Create the preprocessing transformation here
transform = transforms.Compose([neural_network.Negative(), transforms.ToTensor()])
# load your image(s)
img = PIL.Image.open('../src/test/0_100.jpg')
img2 = PIL.Image.open('../src/test/1_100.jpg')
img3 = PIL.Image.open('../src/test/4_100.jpg')
img4 = PIL.Image.open('../src/test/5_100.jpg')
# Transform
input = transform(img)
input2 = transform(img2)
input3 = transform(img3)
input4 = transform(img4)
# unsqueeze batch dimension, in case you are dealing with a single image
input = input.unsqueeze(0)
input2 = input2.unsqueeze(0)
input3 = input3.unsqueeze(0)
input4 = input4.unsqueeze(0)
# Set model to eval
model.eval()
# Get prediction
output = model(input)
output2 = model(input2)
output3 = model(input3)
output4 = model(input4)
print(output)
index = output.cpu().data.numpy().argmax()
print(index)
print(output2)
index = output2.cpu().data.numpy().argmax()
print(index)
print(output3)
index = output3.cpu().data.numpy().argmax()
print(index)
print(output4)
index = output4.cpu().data.numpy().argmax()
print(index)

View File

@ -120,5 +120,7 @@ def pretty_print(root, n):
pretty_print(child[0], n + 1) pretty_print(child[0], n + 1)
# Get view of decision_tree.py
if __name__ == "__main__":
tree = treelearn(cases, attributes, 0) tree = treelearn(cases, attributes, 0)
pretty_print(tree, 0) pretty_print(tree, 0)

View File

@ -27,7 +27,7 @@ def plotdigit(image):
transform = transforms.Compose([Negative(), transforms.ToTensor()]) transform = transforms.Compose([Negative(), transforms.ToTensor()])
train_set = torchvision.datasets.ImageFolder(root='train', transform=transform) train_set = torchvision.datasets.ImageFolder(root='../src/train', transform=transform)
classes = ("apple", "potato") classes = ("apple", "potato")
BATCH_SIZE = 2 BATCH_SIZE = 2
@ -99,6 +99,7 @@ def load_network_from_structure(network):
network.load_state_dict(torch.load('network_model.pth')) network.load_state_dict(torch.load('network_model.pth'))
# Create network_model.pth
if __name__ == "__main__": if __name__ == "__main__":
print(torch.cuda.is_available()) print(torch.cuda.is_available())
training_network() training_network()