"""Cat detection built on a pre-trained torchvision ResNet-50 classifier."""
import logging
from io import BytesIO

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.transforms import transforms

logger = logging.getLogger(__name__)

# Inclusive range of class indices that are reported as cats.
CAT_CLASS_MIN = 281
CAT_CLASS_MAX = 285

# Load the pre-trained model once, at import time
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Switch to evaluation mode: the model is only used for prediction here
model.eval()

# Define the image transformations
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def is_cat(image):
    """Classify an uploaded image and report whether its top class is a cat.

    Args:
        image: Binary file-like object exposing ``read()`` (for example the
            uploaded file object handed over by the web view).

    Returns:
        Always a 2-tuple ``(max_class, max_prob)`` so callers can unpack it:

        * ``(class_index, probability)`` when the most likely class index
          lies in ``CAT_CLASS_MIN..CAT_CLASS_MAX`` (281-285);
        * ``(class_index, None)`` when the most likely class is outside
          that range;
        * ``(None, None)`` when the image could not be read or classified.
          The error is logged with its traceback instead of being raised.
    """
    try:
        img = Image.open(BytesIO(image.read()))

        # Normalize() above is configured for exactly three channels, so
        # grayscale, palette and RGBA uploads must be converted first.
        img = img.convert('RGB')

        # Preprocess the image and add the batch dimension
        img_t = preprocess(img)
        batch_t = torch.unsqueeze(img_t, 0)

        # Make the prediction; gradients are never needed for inference
        with torch.no_grad():
            out = model(batch_t)

            # Apply softmax to get probabilities
            probabilities = F.softmax(out, dim=1)

        # Get the maximum predicted class and its probability
        max_prob, max_class = torch.max(probabilities, dim=1)
        max_prob = max_prob.item()
        max_class = max_class.item()
    except Exception:
        # Deliberate best-effort boundary: a bad upload must not crash the
        # caller. Keep the same tuple shape as the success path.
        logger.exception("Error while processing the image")
        return None, None

    # Only expose the probability when the top class is one of the cat classes
    if CAT_CLASS_MIN <= max_class <= CAT_CLASS_MAX:
        return max_class, max_prob
    return max_class, None
import os
import secrets

from flask import Flask, request, jsonify, session

from cat_detection import is_cat

# Define flask app
app = Flask(__name__)

# NOTE(security): the session-signing key used to be the literal 'secret_key',
# which lets anyone forge session cookies. Set the SECRET_KEY environment
# variable for a stable key; otherwise a random per-process key is generated
# (sessions then do not survive a restart).
app.secret_key = os.environ.get('SECRET_KEY') or secrets.token_hex(32)

# Human-readable names for the class indices that is_cat() reports as cats
CAT_CLASS_NAMES = {
    281: "tabby cat",
    282: "tiger cat",
    283: "persian cat",
    284: "siamese cat",
    285: "egyptian cat",
}


@app.route('/detect-cat', methods=['POST'])
def upload_file():
    """Handle ``POST /detect-cat``: classify the uploaded image.

    API usage (as listed in docs.md: port 5000, endpoint ``/detect-cat``):
    send the picture as a file field of the request body. The field may be
    named ``image`` or ``Image``.

    Returns:
        A ``(json_response, status_code)`` pair:

        * ``200`` with ``{'result': <message>}`` when the image was classified;
        * ``400`` with ``{'error': <message>}`` when no file was sent, the
          file has no name, or the upload could not be processed as an image.
    """
    # docs.md names the key 'Image' while this view originally read 'image';
    # accept both. Compare against None explicitly: an upload with an empty
    # filename is falsy and must still reach the filename check below.
    file = request.files.get('image')
    if file is None:
        file = request.files.get('Image')
    if file is None:
        return jsonify({'error': "No file uploaded. Send the image under the key 'image'."}), 400
    if file.filename == '':
        return jsonify({'error': "File name is empty. Please name a file."}), 400

    # is_cat() signals failure with None or (None, None); never unpack blindly
    outcome = is_cat(file)
    if outcome is None:
        max_class, max_prob = None, None
    else:
        max_class, max_prob = outcome
    if max_class is None:
        return jsonify({'error': "The uploaded file could not be processed as an image."}), 400

    # Save result in session
    session['result'] = max_class, max_prob

    # Build the message based on the result of the image analysis
    if max_prob is not None:
        result = f"The image is recognized as '{CAT_CLASS_NAMES[max_class]}' with a probability of {round(max_prob * 100, 2)}%"
    else:
        result = f"The image is not recognized as a class within the range 281-285 ({max_class})"

    return jsonify({'result': result}), 200
if __name__ == '__main__':
    import os

    # debug=True used to be hard-coded here. Flask's debug mode serves the
    # interactive Werkzeug debugger, which allows arbitrary code execution by
    # anyone who can reach the server, so it is now opt-in via FLASK_DEBUG
    # (any value other than empty, "0", "false" or "no" enables it).
    debug_flag = os.environ.get('FLASK_DEBUG', '')
    debug_enabled = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
    app.run(debug=debug_enabled)