digit data

This commit is contained in:
s452693 2021-06-20 22:27:23 +02:00
parent 76e3aa2cb2
commit b97c4ba4b1
3 changed files with 6 additions and 11 deletions

BIN
src/Dhidden.npy Normal file

Binary file not shown.

BIN
src/Dweights.npy Normal file

Binary file not shown.

View File

@ -21,14 +21,9 @@ let_test_images, let_test_labels = extract_test_samples('letters')
#print(dig_train_images[0])
dig_train_images = dig_train_images.reshape(len(dig_train_images),28*28)
d_train = dig_train_images[:1000]
d_labels = dig_train_labels[:1000]
dig_test_images = dig_test_images.reshape(len(dig_test_images),28*28)
d_test = dig_test_images[:600]
d_labelstest = dig_test_labels[:600]
print(d_test.shape)
#print(d_test.shape)
print(d_labelstest)
#print(dig_train_images[0])
#print(dig_train_images.shape)
@ -61,10 +56,10 @@ class NeuralNetwork:
targets = np.array(targetsList,ndmin=2).T
#forward pass
hiddenInputs = np.dot(self.weights, inputs) + 2
hiddenInputs = np.dot(self.weights, inputs)
hiddenOutputs = self.activationFunction(hiddenInputs)
finalInputs = np.dot(self.hidden, hiddenOutputs) + 1
finalInputs = np.dot(self.hidden, hiddenOutputs)
finalOutputs = self.activationFunction(finalInputs)
outputErrors = targets - finalOutputs
@ -105,7 +100,7 @@ class NeuralNetwork:
#n = NeuralNetwork(inputNodes=3, hiddenNodes=5, outputNodes=2, learningGrade=0.2)
digitNetwork = NeuralNetwork(inputNodes=784, hiddenNodes=200, outputNodes=10, learningGrade=0.1, fileWeight="Dweights.npy", fileHidden="Dhidden.npy")
def trainNetwork(n, fWeight, fHidden, trainingSamples):
def trainNetwork(n, fWeight, fHidden, trainingSamples, trainingLabels):
epochs = 10
outputNodes = 10
for e in range(epochs):
@ -117,7 +112,7 @@ def trainNetwork(n, fWeight, fHidden, trainingSamples):
#print(inputs.shape)
targets = np.zeros(outputNodes) + 0.01
targets[d_labels[m]] = 0.99
targets[trainingLabels[m]] = 0.99
#print(targets)
n.train(inputs,targets)
@ -129,7 +124,7 @@ def trainNetwork(n, fWeight, fHidden, trainingSamples):
##################################### ODPALANIE TRAINING
#trainNetwork(digitNetwork, "Dweights.npy", "Dhidden.npy", d_train)
#trainNetwork(digitNetwork, "Dweights.npy", "Dhidden.npy", dig_train_images, dig_train_labels)
#record = d_test[0]
#print('Label', d_labelstest[0])