diff --git a/NeuralNetwork/prediction.py b/NeuralNetwork/prediction.py index d6e27ee..6ee634a 100644 --- a/NeuralNetwork/prediction.py +++ b/NeuralNetwork/prediction.py @@ -5,9 +5,9 @@ from PIL import Image from NeuralNetwork import NeuralNetwork def getPrediction(img_path): - + # Inicjacja sieci neuronowej - neural_net = NeuralNetwork() + neural_net = NeuralNetwork.NeuralNetwork() PATH = './trained_nn.pth' img = Image.open(img_path) transform_tensor = transforms.ToTensor()(img).unsqueeze_(0) diff --git a/NeuralNetwork/train_nn.py b/NeuralNetwork/train_nn.py index 22cd7ba..fc52ae5 100644 --- a/NeuralNetwork/train_nn.py +++ b/NeuralNetwork/train_nn.py @@ -20,7 +20,7 @@ def trainNeuralNetwork(): criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(neural_net.parameters(), lr=0.001, momentum=0.9) - epoch_num = 4 # najlepiej 10, dla lepszej wiarygodności + epoch_num = 10 # najlepiej 10, dla lepszej wiarygodności for epoch in range(epoch_num): measure_loss = 0.0 for i, data in enumerate(trainloader, 0): diff --git a/test_nn.py b/test_nn.py new file mode 100644 index 0000000..89ec886 --- /dev/null +++ b/test_nn.py @@ -0,0 +1,7 @@ +from NeuralNetwork import prediction + +def main(): + print(prediction.getPrediction("./resources/trash_dataset/test/paper/google-image(0639).jpeg")) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/resources/trained_nn.pth b/trained_nn.pth similarity index 76% rename from resources/trained_nn.pth rename to trained_nn.pth index 14c1442..dcdfa8b 100644 Binary files a/resources/trained_nn.pth and b/trained_nn.pth differ