psi/src/neural_network.py

244 lines
7.5 KiB
Python
Raw Normal View History

2021-06-20 16:56:42 +02:00
from emnist import extract_test_samples
from emnist import extract_training_samples
import numpy as np
import scipy.special
import glob
import imageio
# Load the EMNIST "digits" and "letters" splits through the `emnist` package;
# every extract_* call yields an (images, labels) pair.
dig_train_images, dig_train_labels = extract_training_samples('digits')
dig_test_images, dig_test_labels = extract_test_samples('digits')
let_train_images, let_train_labels = extract_training_samples('letters')
let_test_images, let_test_labels = extract_test_samples('letters')

# Flatten every 28x28 image into a single row, giving arrays of shape
# (number_of_images, 784) that match the networks' 784 input nodes.
_IMAGE_PIXELS = 28 * 28
dig_train_images = dig_train_images.reshape(dig_train_images.shape[0], _IMAGE_PIXELS)
dig_test_images = dig_test_images.reshape(dig_test_images.shape[0], _IMAGE_PIXELS)
let_train_images = let_train_images.reshape(let_train_images.shape[0], _IMAGE_PIXELS)
let_test_images = let_test_images.reshape(let_test_images.shape[0], _IMAGE_PIXELS)
class NeuralNetwork:
    """Fully connected network with one sigmoid hidden layer and sigmoid outputs.

    Weights are updated one sample at a time by ``train`` using
    backpropagation of the output error through both layers.

    Attributes:
        inodes, hnodes, onodes: sizes of the input, hidden and output layers.
        weights: input -> hidden weight matrix, shape (hnodes, inodes).
        hidden: hidden -> output weight matrix, shape (onodes, hnodes).
        lr: learning rate applied to every weight update.
        activationFunction: element-wise sigmoid (scipy.special.expit).
    """

    def __init__(self, inputNodes, hiddenNodes, outputNodes, learningGrade,
                 fileWeight, fileHidden, loadFromFile=False):
        """Create the network.

        Args:
            inputNodes: number of input values per sample.
            hiddenNodes: number of hidden neurons.
            outputNodes: number of output neurons.
            learningGrade: learning rate used by ``train``.
            fileWeight: ``.npy`` file holding a saved ``weights`` matrix.
            fileHidden: ``.npy`` file holding a saved ``hidden`` matrix.
            loadFromFile: if True, load both matrices from ``fileWeight`` and
                ``fileHidden`` (as written by ``saveTraining``) instead of
                initialising them randomly. Defaults to False, i.e. a fresh
                network ready for training.

        Raises:
            ValueError: if loaded matrices do not match the requested layer sizes.
        """
        self.inodes = inputNodes
        self.hnodes = hiddenNodes
        self.onodes = outputNodes
        self.lr = learningGrade

        if loadFromFile:
            # Reuse a previously trained network.
            self.weights = np.load(fileWeight)
            self.hidden = np.load(fileHidden)
            if (self.weights.shape != (self.hnodes, self.inodes)
                    or self.hidden.shape != (self.onodes, self.hnodes)):
                raise ValueError(
                    "saved weights have shapes %s and %s, expected %s and %s"
                    % (self.weights.shape, self.hidden.shape,
                       (self.hnodes, self.inodes), (self.onodes, self.hnodes)))
        else:
            # Uniform random weights in [-0.5, 0.5) for training from scratch.
            self.weights = np.random.rand(self.hnodes, self.inodes) - 0.5
            self.hidden = np.random.rand(self.onodes, self.hnodes) - 0.5

        # Sigmoid activation: expit(x) = 1 / (1 + exp(-x)).
        self.activationFunction = scipy.special.expit

    def _forward(self, inputs):
        """Run a forward pass on a column vector; return (hiddenOutputs, finalOutputs)."""
        hiddenOutputs = self.activationFunction(np.dot(self.weights, inputs))
        finalOutputs = self.activationFunction(np.dot(self.hidden, hiddenOutputs))
        return hiddenOutputs, finalOutputs

    def train(self, inputsList, targetsList):
        """Perform one backpropagation step on a single sample.

        Args:
            inputsList: ``inodes`` input values.
            targetsList: ``onodes`` desired output values.
        """
        # Convert the flat lists into column vectors.
        inputs = np.array(inputsList, ndmin=2).T
        targets = np.array(targetsList, ndmin=2).T

        hiddenOutputs, finalOutputs = self._forward(inputs)

        outputErrors = targets - finalOutputs
        # Propagate the error back through the *current* hidden->output
        # weights; this must happen before self.hidden is updated below.
        hiddenErrors = np.dot(self.hidden.T, outputErrors)

        # The sigmoid derivative is s * (1 - s), where s is the layer's output.
        self.hidden += self.lr * np.dot(
            outputErrors * finalOutputs * (1.0 - finalOutputs), hiddenOutputs.T)
        self.weights += self.lr * np.dot(
            hiddenErrors * hiddenOutputs * (1.0 - hiddenOutputs), inputs.T)

    def saveTraining(self, fileWeight, fileHidden):
        """Save both weight matrices with ``np.save`` (reload via ``loadFromFile=True``)."""
        np.save(fileWeight, self.weights)
        np.save(fileHidden, self.hidden)

    def query(self, inputsList):
        """Return the network's outputs for one sample as an (onodes, 1) column."""
        inputs = np.array(inputsList, ndmin=2).T
        return self._forward(inputs)[1]
# Network instances. Both share the input/hidden layout and learning rate and
# differ only in the number of outputs: 10 for digits, 27 for letters
# (recognizeLet maps output index 0 to an empty string, 1-26 to 'a'-'z').
_sharedLayout = dict(inputNodes=784, hiddenNodes=200, learningGrade=0.1)

digitNetwork = NeuralNetwork(outputNodes=10, fileWeight="Dweights.npy",
                             fileHidden="Dhidden.npy", **_sharedLayout)
letterNetwork = NeuralNetwork(outputNodes=27, fileWeight="Lweights.npy",
                              fileHidden="Lhidden.npy", **_sharedLayout)
def trainNetwork(n, fWeight, fHidden, trainingSamples, trainingLabels, epochs=10):
    """Train network ``n`` on labelled samples, then save its weights.

    Example::

        trainNetwork(digitNetwork, "Dweights_test.npy", "Dhidden_test.npy",
                     dig_train_images, dig_train_labels)

    Args:
        n: network to train; must provide ``train``, ``saveTraining`` and
            ``onodes`` (e.g. a ``NeuralNetwork``).
        fWeight: file the input->hidden weights are saved to.
        fHidden: file the hidden->output weights are saved to.
        trainingSamples: sequence of flattened images, pixel values in [0, 255].
        trainingLabels: class index of each sample, aligned with ``trainingSamples``.
        epochs: number of passes over the whole training set (default 10).

    Raises:
        ValueError: if samples and labels differ in length.
    """
    if len(trainingSamples) != len(trainingLabels):
        raise ValueError("got %d samples but %d labels"
                         % (len(trainingSamples), len(trainingLabels)))

    # Size the one-hot target to the network itself; a fixed 27 would not
    # match the 10-output digit network.
    outputNodes = n.onodes

    for e in range(epochs):
        print('Epoch', e + 1)
        for record, label in zip(trainingSamples, trainingLabels):
            # Rescale pixels from [0, 255] to [0.01, 1.0] so no input is exactly 0.
            inputs = np.asarray(record, dtype=float) / 255 * 0.99 + 0.01
            # Soft one-hot target: 0.99 for the true class, 0.01 elsewhere.
            targets = np.full(outputNodes, 0.01)
            targets[label] = 0.99
            n.train(inputs, targets)

    n.saveTraining(fileWeight=fWeight, fileHidden=fHidden)
def testing(n, testingSamples, testingLabels):
scorecard = []
k = 0
for record in testingSamples:
inputs = (np.asfarray(record[0:])/255 * 0.99) + 0.01
correctLabels = testingLabels[k]
outputs = n.query(inputs)
label = np.argmax(outputs)
if(label == correctLabels):
scorecard.append(1)
else:
scorecard.append(0)
k+=1
scorecardArray = np.asfarray(scorecard)
print('Performance', scorecardArray.sum() / scorecardArray.size)
# Report each network's accuracy on its EMNIST test split.
# NOTE(review): both networks are randomly initialised at construction (the
# weight-loading path is not used here), so unless they were trained first
# these numbers reflect untrained networks.
testing(n=digitNetwork, testingSamples=dig_test_images, testingLabels=dig_test_labels)
testing(n=letterNetwork, testingSamples=let_test_images, testingLabels=let_test_labels)
# Module-level scratch state. In the visible code these names are only
# referenced from commented-out lines.
li, ourOwnDataset = [], []
record_cache = None
def _loadImageData(imageFileName):
    """Read a 28x28 greyscale image file and scale it like the training data."""
    imgArray = imageio.imread(imageFileName, as_gray=True)
    # Invert grey levels (255 - value). Presumably the photographed characters
    # are dark-on-light while the training images are light-on-dark — confirm.
    imgData = 255 - imgArray.reshape(784)
    # Same [0, 255] -> [0.01, 1.0] scaling used in training and testing.
    return (imgData / 255 * 0.99) + 0.01


def testCase(inputWord, digitCount=2):
    """Recognise a word from per-character image files, print and return it.

    The original version crashed on every call: assigning ``len`` shadowed
    the builtin (UnboundLocalError), and it read from an undefined
    ``imageFileName`` instead of the (commented-out) ``inputWord[i]``.

    Args:
        inputWord: sequence of image file paths, one per character. All but
            the last ``digitCount`` entries are classified as letters, the
            remaining ones as digits.
        digitCount: number of trailing characters treated as digits (default 2).

    Returns:
        The recognised string.
    """
    # Slicing (rather than index ranges) keeps inputs shorter than digitCount
    # from reading the same image twice.
    split = max(len(inputWord) - digitCount, 0)
    chars = [recognizeLet(letterNetwork, _loadImageData(path))
             for path in inputWord[:split]]
    chars += [recognizeNum(digitNetwork, _loadImageData(path))
              for path in inputWord[split:]]
    word = "".join(chars)
    print('slowo: ', word)
    return word
def recognizeLet(n, imgData):
    """Classify one preprocessed image with a letter network.

    Output index 0 maps to an empty string and indices 1-26 map to 'a'-'z'.
    """
    alphabet = ('',) + tuple('abcdefghijklmnopqrstuvwxyz')
    scores = n.query(imgData)
    best = int(np.argmax(scores))
    return alphabet[best]
def recognizeNum(n, imgData):
    """Classify one preprocessed image with a digit network.

    Returns the index of the strongest output neuron as a string
    ('0'-'9' for the 10-output digit network).
    """
    scores = n.query(imgData)
    return str(np.argmax(scores))
# NOTE(review): the triple-quoted string below is an earlier, disabled version
# of the digit demo: it loaded 'cyfry/?.png', queried a network and printed the
# recognised digits. Python evaluates it as a bare string expression and
# discards it, so none of it runs. In the visible code it refers to a network
# `n` that is only defined as a function parameter elsewhere, and it contains
# pasted-in timestamp lines. Repair it before re-enabling, or delete it.
"""
2021-06-21 01:43:32 +02:00
li = []
2021-06-21 02:49:08 +02:00
#ourOwnDataset = np.asfarray(ood)
2021-06-21 01:43:32 +02:00
ourOwnDataset = []
record_cache = None
2021-06-21 02:49:08 +02:00
for imageFileName in glob.glob('cyfry/?.png'):
2021-06-21 01:43:32 +02:00
label = int(imageFileName[-5:-4])
print('loading...', imageFileName)
imgArray = imageio.imread(imageFileName, as_gray=True)
#print(' imgArray: ', imgArray)
imgData = 255 - imgArray.reshape(784)
#print('imgData1: ',imgData)
imgData = (imgData/255 * 0.99) + 0.01
#print('imgData2: ',imgData)
#print(np.min(imgData))
#print(np.max(imgData))
record = np.append(label,imgData)
#print('Record: ',record)
ourOwnDataset.append(record)
if record_cache is None:
record_cache = record
#print(ood[0])
li.append(label)
pass
assert record_cache.shape == ourOwnDataset[0].shape
labelInput = np.asfarray(li)
#print(labelInput)
word = ""
2021-06-21 02:49:08 +02:00
for item in range(0,9):
2021-06-21 01:43:32 +02:00
correctLabels = labelInput[item]
2021-06-21 02:49:08 +02:00
outputs = n.query(ourOwnDataset[item][1:])
2021-06-21 01:43:32 +02:00
print(outputs)
label = np.argmax(outputs)
#print('Network says: ', label)
#labelString = np.array_str(label)
word = word + str(label)
2021-06-20 16:56:42 +02:00
2021-06-21 01:43:32 +02:00
print('slowo: ', word)
2021-06-21 02:49:08 +02:00
"""
2021-06-20 17:59:01 +02:00
2021-06-20 16:56:42 +02:00
2021-06-21 01:43:32 +02:00
##################################### URUCHOMIENIE TRENINGU
2021-06-21 11:16:59 +02:00
#trainNetwork(letterNetwork, "Lweights_test.npy", "Lhidden_test.npy", let_train_images, let_train_labels)
2021-06-21 01:43:32 +02:00
# trainNetwork(digitNetwork, "Dweights_test.npy", "Dhidden_test.npy", dig_train_images, dig_train_labels)