Flask API

This commit is contained in:
Michael 2024-01-04 21:07:52 +01:00
parent 68ce008dea
commit bf70d8eb9c
No known key found for this signature in database
GPG Key ID: 066ED7D431960C9B
3 changed files with 86 additions and 50 deletions

48
cat_detection.py Normal file
View File

@ -0,0 +1,48 @@
from io import BytesIO
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()
# Define the image transformations
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def is_cat(image):
try:
img = Image.open(BytesIO(image.read()))
# Preprocess the image
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
# Make the prediction
out = model(batch_t)
# Apply softmax to get probabilities
probabilities = F.softmax(out, dim=1)
# Get the maximum predicted class and its probability
max_prob, max_class = torch.max(probabilities, dim=1)
max_prob = max_prob.item()
max_class = max_class.item()
# Check if the maximum predicted class is within the range 281-285
if 281 <= max_class <= 285:
return max_class, max_prob
else:
return max_class, None
except Exception as e:
print("Error while processing the image:", e)
return None

9
docs.md Normal file
View File

@ -0,0 +1,9 @@
# Api
Port -> 5000
endpoint -> /detect-cat
Key -> 'Image'
Value -> {UPLOADED_FILE}

79
main.py
View File

@ -1,59 +1,38 @@
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
from flask import Flask, request, jsonify, session
# Load the pre-trained model
model = resnet50(weights=ResNet50_Weights.DEFAULT)
from cat_detection import is_cat
model.eval()
# Define the image transformations
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Define flask app
app = Flask(__name__)
app.secret_key = 'secret_key'
def is_cat(image_path):
# Open the image
img = Image.open(image_path)
@app.route('/detect-cat', methods=['POST'])
def upload_file():
# 'Key' in body should be named as 'image'. Type should be 'File' and in 'Value' we should upload image from disc.
file = request.files['image']
if file.filename == '':
return jsonify({'error': "File name is empty. Please name a file."}), 400
max_class, max_prob = is_cat(file)
# Preprocess the image
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
# Save result in session
session['result'] = max_class, max_prob
# Make the prediction
out = model(batch_t)
# Apply softmax to get probabilities
probabilities = F.softmax(out, dim=1)
# Get the maximum predicted class and its probability
max_prob, max_class = torch.max(probabilities, dim=1)
max_prob = max_prob.item()
max_class = max_class.item()
# Check if the maximum predicted class is within the range 281-285
if 281 <= max_class <= 285:
return max_class, max_prob
# Tworzenie komunikatu na podstawie wyniku analizy zdjęcia
translator = {
281: "tabby cat",
282: "tiger cat",
283: "persian cat",
284: "siamese cat",
285: "egyptian cat"
}
if max_prob is not None:
result = f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%"
else:
return max_class, None
result = f"The image is not recognized as a class within the range 281-285 ({max_class})"
return jsonify({'result': result}), 200
image_path = 'wolf.jpg'
max_class, max_prob = is_cat(image_path)
translator = {
281: "tabby cat",
282: "tiger cat",
283: "persian cat",
284: "siamese cat",
285: "egyptian cat"
}
if max_prob is not None:
print(f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%")
else:
print(f"The image is not recognized as a class within the range 281-285 ({max_class})")
if __name__ == '__main__':
app.run(debug=True)