neural_networks2
commit 9ca08c7afb (parent 59945fb8de)
@@ -15,20 +15,10 @@ dig_train_images, dig_train_labels = extract_training_samples('digits')
dig_test_images, dig_test_labels = extract_test_samples('digits')
let_train_images, let_train_labels = extract_training_samples('letters')
let_test_images, let_test_labels = extract_test_samples('letters')
#print(dig_train_images.shape)

#def plotdigit(image):
#    img = np.reshape(image, (-1, 28))
#    imshow(img, cmap='Greys', vmin=0, vmax=255)
print(dig_train_images.shape)
"""
dig_train_images = dig_train_images / 255
dig_test_images = dig_test_images / 255
let_train_images = let_train_images / 255
let_test_images = let_test_images / 255

dig_train_images = [torch.tensor(image, dtype=torch.float32) for image in dig_train_images]
"""


#print(dig_train_images[0])
dig_train_images = dig_train_images.reshape(len(dig_train_images), 28*28)
d_train = dig_train_images[:1000]
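# Added note (not in the original commit): each 28x28 EMNIST image is flattened
# to a 784-element vector so it matches the network's 784 input nodes; only the
# first 1000 samples are used for training here.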
@@ -45,13 +35,16 @@ print(d_labelstest)


class NeuralNetwork:
    def __init__(self, inputNodes, hiddenNodes, outputNodes, learningGrade):
    def __init__(self, inputNodes, hiddenNodes, outputNodes, learningGrade, fileWeight, fileHidden):
        self.inodes = inputNodes
        self.hnodes = hiddenNodes
        self.onodes = outputNodes

        self.weights = (np.random.rand(self.hnodes, self.inodes) - 0.5)
        self.hidden = (np.random.rand(self.onodes, self.hnodes) - 0.5)
        """use the first two (random init) when training from scratch, afterwards always keep these 2 (the np.load lines)"""
        #self.weights = (np.random.rand(self.hnodes, self.inodes) - 0.5)
        #self.hidden = (np.random.rand(self.onodes, self.hnodes) - 0.5)
        self.weights = np.load(fileWeight)
        self.hidden = np.load(fileHidden)
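        # Added note (not in the original commit): the two np.load calls restore the
        # weight matrices that saveTraining() wrote on a previous run, so the network
        # starts from its trained state instead of a fresh random initialisation.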

        #print( 'Matrix1 \n', self.weights)
        #print( 'Matrix2 \n', self.hidden)
@@ -89,6 +82,10 @@ class NeuralNetwork:

        pass

    def saveTraining(self, fileWeight, fileHidden):
        np.save(fileWeight, self.weights)
        np.save(fileHidden, self.hidden)
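        # Added note (not in the original commit): np.save writes the current weight
        # matrices to the given .npy files; a later run can then pass the same
        # fileWeight/fileHidden to the constructor and skip retraining.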

    def query(self, inputsList):

        inputs = np.array(inputsList, ndmin=2).T
@@ -102,29 +99,20 @@ class NeuralNetwork:

        return finalOutputs


"""
def getAccuracy(predictions, Y):
    print(predictions, Y)
    return np.sum(predictions == Y) / Y.size

def getPredictions(A2):
    return np.argmax(A2, 0)
"""


#n = NeuralNetwork(inputNodes=3, hiddenNodes=5, outputNodes=2, learningGrade=0.2)
n = NeuralNetwork(inputNodes=784, hiddenNodes=200, outputNodes=10, learningGrade=0.1)

def trainNetwork(n):
"""add the letters array"""
#n = NeuralNetwork(inputNodes=3, hiddenNodes=5, outputNodes=2, learningGrade=0.2)
digitNetwork = NeuralNetwork(inputNodes=784, hiddenNodes=200, outputNodes=10, learningGrade=0.1, fileWeight="Dweights.npy", fileHidden="Dhidden.npy")
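# Added note (not in the original commit): 784 input nodes correspond to the
# flattened 28x28 pixels, the 10 output nodes to the digit classes 0-9, and the
# two file arguments name the .npy files used to persist and restore the weights.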

def trainNetwork(n, fWeight, fHidden, trainingSamples):
    epochs = 10
    outputNodes = 10
    for e in range(epochs):
        m = 0
        print('Epoch', e+1)

        for record in d_train:
        for record in trainingSamples:
            inputs = (np.asfarray(record[0:])/255 * 0.99) + 0.01
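            # Added note (not in the original commit): dividing by 255 maps pixels to
            # [0, 1]; the 0.99 scale and 0.01 offset keep every input strictly positive,
            # since in a simple sigmoid network like this the input-to-hidden weight
            # updates are proportional to the input, so zero pixels would never be learned from.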
            #print(inputs.shape)

@@ -136,15 +124,17 @@ def trainNetwork(n):
            m += 1
            pass
        pass
    n.saveTraining(fileWeight=fWeight, fileHidden=fHidden)


trainNetwork(n)

record = d_test[0]
##################################### RUN TRAINING
#trainNetwork(digitNetwork, "Dweights.npy", "Dhidden.npy", d_train)
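# Hedged sketch (added, not part of the original commit): the "add the letters
# array" note above suggests a matching letters model. Assuming the letter data
# is flattened like d_train, and noting that the committed constructor np.loads
# its weights (so the random-init lines, or existing .npy files, are needed for
# a first run), it could reuse the same routine roughly like this:
#
# l_train = let_train_images.reshape(len(let_train_images), 28*28)[:1000]
# letterNetwork = NeuralNetwork(inputNodes=784, hiddenNodes=200, outputNodes=26,
#                               learningGrade=0.1, fileWeight="Lweights.npy",
#                               fileHidden="Lhidden.npy")
# trainNetwork(letterNetwork, "Lweights.npy", "Lhidden.npy", l_train)
# (EMNIST letter labels run 1-26, so the target encoding would need a -1 shift.)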
#record = d_test[0]
#print('Label', d_labelstest[0])
inputs = np.asfarray(record[0:])/ 255 * 0.99 + 0.01
#inputs = np.asfarray(record[0:])/ 255 * 0.99 + 0.01
#print(n.query(inputs))
#testing