Flask API
This commit is contained in:
parent
68ce008dea
commit
bf70d8eb9c
48
cat_detection.py
Normal file
48
cat_detection.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.models.resnet import resnet50, ResNet50_Weights
|
||||||
|
from torchvision.transforms import transforms
|
||||||
|
|
||||||
|
model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Define the image transformations
|
||||||
|
preprocess = transforms.Compose([
|
||||||
|
transforms.Resize(256),
|
||||||
|
transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def is_cat(image):
|
||||||
|
try:
|
||||||
|
img = Image.open(BytesIO(image.read()))
|
||||||
|
|
||||||
|
# Preprocess the image
|
||||||
|
img_t = preprocess(img)
|
||||||
|
batch_t = torch.unsqueeze(img_t, 0)
|
||||||
|
|
||||||
|
# Make the prediction
|
||||||
|
out = model(batch_t)
|
||||||
|
|
||||||
|
# Apply softmax to get probabilities
|
||||||
|
probabilities = F.softmax(out, dim=1)
|
||||||
|
|
||||||
|
# Get the maximum predicted class and its probability
|
||||||
|
max_prob, max_class = torch.max(probabilities, dim=1)
|
||||||
|
max_prob = max_prob.item()
|
||||||
|
max_class = max_class.item()
|
||||||
|
|
||||||
|
# Check if the maximum predicted class is within the range 281-285
|
||||||
|
if 281 <= max_class <= 285:
|
||||||
|
return max_class, max_prob
|
||||||
|
else:
|
||||||
|
return max_class, None
|
||||||
|
except Exception as e:
|
||||||
|
print("Error while processing the image:", e)
|
||||||
|
return None
|
9
docs.md
Normal file
9
docs.md
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# Api
|
||||||
|
|
||||||
|
Port -> 5000
|
||||||
|
|
||||||
|
endpoint -> /detect-cat
|
||||||
|
|
||||||
|
Key -> 'Image'
|
||||||
|
|
||||||
|
Value -> {UPLOADED_FILE}
|
67
main.py
67
main.py
@ -1,51 +1,24 @@
|
|||||||
from PIL import Image
|
from flask import Flask, request, jsonify, session
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torchvision.models.resnet import resnet50, ResNet50_Weights
|
|
||||||
from torchvision.transforms import transforms
|
|
||||||
|
|
||||||
# Load the pre-trained model
|
from cat_detection import is_cat
|
||||||
model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
|
||||||
|
|
||||||
model.eval()
|
# Define flask app
|
||||||
|
app = Flask(__name__)
|
||||||
# Define the image transformations
|
app.secret_key = 'secret_key'
|
||||||
preprocess = transforms.Compose([
|
|
||||||
transforms.Resize(256),
|
|
||||||
transforms.CenterCrop(224),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
def is_cat(image_path):
|
@app.route('/detect-cat', methods=['POST'])
|
||||||
# Open the image
|
def upload_file():
|
||||||
img = Image.open(image_path)
|
# 'Key' in body should be named as 'image'. Type should be 'File' and in 'Value' we should upload image from disc.
|
||||||
|
file = request.files['image']
|
||||||
|
if file.filename == '':
|
||||||
|
return jsonify({'error': "File name is empty. Please name a file."}), 400
|
||||||
|
max_class, max_prob = is_cat(file)
|
||||||
|
|
||||||
# Preprocess the image
|
# Save result in session
|
||||||
img_t = preprocess(img)
|
session['result'] = max_class, max_prob
|
||||||
batch_t = torch.unsqueeze(img_t, 0)
|
|
||||||
|
|
||||||
# Make the prediction
|
# Tworzenie komunikatu na podstawie wyniku analizy zdjęcia
|
||||||
out = model(batch_t)
|
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
|
||||||
probabilities = F.softmax(out, dim=1)
|
|
||||||
|
|
||||||
# Get the maximum predicted class and its probability
|
|
||||||
max_prob, max_class = torch.max(probabilities, dim=1)
|
|
||||||
max_prob = max_prob.item()
|
|
||||||
max_class = max_class.item()
|
|
||||||
|
|
||||||
# Check if the maximum predicted class is within the range 281-285
|
|
||||||
if 281 <= max_class <= 285:
|
|
||||||
return max_class, max_prob
|
|
||||||
else:
|
|
||||||
return max_class, None
|
|
||||||
|
|
||||||
|
|
||||||
image_path = 'wolf.jpg'
|
|
||||||
max_class, max_prob = is_cat(image_path)
|
|
||||||
translator = {
|
translator = {
|
||||||
281: "tabby cat",
|
281: "tabby cat",
|
||||||
282: "tiger cat",
|
282: "tiger cat",
|
||||||
@ -54,6 +27,12 @@ translator = {
|
|||||||
285: "egyptian cat"
|
285: "egyptian cat"
|
||||||
}
|
}
|
||||||
if max_prob is not None:
|
if max_prob is not None:
|
||||||
print(f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%")
|
result = f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%"
|
||||||
else:
|
else:
|
||||||
print(f"The image is not recognized as a class within the range 281-285 ({max_class})")
|
result = f"The image is not recognized as a class within the range 281-285 ({max_class})"
|
||||||
|
|
||||||
|
return jsonify({'result': result}), 200
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(debug=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user