Merge pull request 'Flask API' (#2) from flask-ML into dev
Reviewed-on: #2
This commit is contained in:
commit
caf98c54cd
48
cat_detection.py
Normal file
48
cat_detection.py
Normal file
@ -0,0 +1,48 @@
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torchvision.models.resnet import resnet50, ResNet50_Weights
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Define the image transformations
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
|
||||
def is_cat(image):
|
||||
try:
|
||||
img = Image.open(BytesIO(image.read()))
|
||||
|
||||
# Preprocess the image
|
||||
img_t = preprocess(img)
|
||||
batch_t = torch.unsqueeze(img_t, 0)
|
||||
|
||||
# Make the prediction
|
||||
out = model(batch_t)
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probabilities = F.softmax(out, dim=1)
|
||||
|
||||
# Get the maximum predicted class and its probability
|
||||
max_prob, max_class = torch.max(probabilities, dim=1)
|
||||
max_prob = max_prob.item()
|
||||
max_class = max_class.item()
|
||||
|
||||
# Check if the maximum predicted class is within the range 281-285
|
||||
if 281 <= max_class <= 285:
|
||||
return max_class, max_prob
|
||||
else:
|
||||
return max_class, None
|
||||
except Exception as e:
|
||||
print("Error while processing the image:", e)
|
||||
return None
|
9
docs.md
Normal file
9
docs.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Api
|
||||
|
||||
Port -> 5000
|
||||
|
||||
endpoint -> /detect-cat
|
||||
|
||||
Key -> 'Image'
|
||||
|
||||
Value -> {UPLOADED_FILE}
|
79
main.py
79
main.py
@ -1,59 +1,38 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.models.resnet import resnet50, ResNet50_Weights
|
||||
from torchvision.transforms import transforms
|
||||
from flask import Flask, request, jsonify, session
|
||||
|
||||
# Load the pre-trained model
|
||||
model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
||||
from cat_detection import is_cat
|
||||
|
||||
model.eval()
|
||||
|
||||
# Define the image transformations
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
# Define flask app
|
||||
app = Flask(__name__)
|
||||
app.secret_key = 'secret_key'
|
||||
|
||||
|
||||
def is_cat(image_path):
|
||||
# Open the image
|
||||
img = Image.open(image_path)
|
||||
@app.route('/detect-cat', methods=['POST'])
|
||||
def upload_file():
|
||||
# 'Key' in body should be named as 'image'. Type should be 'File' and in 'Value' we should upload image from disc.
|
||||
file = request.files['image']
|
||||
if file.filename == '':
|
||||
return jsonify({'error': "File name is empty. Please name a file."}), 400
|
||||
max_class, max_prob = is_cat(file)
|
||||
|
||||
# Preprocess the image
|
||||
img_t = preprocess(img)
|
||||
batch_t = torch.unsqueeze(img_t, 0)
|
||||
# Save result in session
|
||||
session['result'] = max_class, max_prob
|
||||
|
||||
# Make the prediction
|
||||
out = model(batch_t)
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probabilities = F.softmax(out, dim=1)
|
||||
|
||||
# Get the maximum predicted class and its probability
|
||||
max_prob, max_class = torch.max(probabilities, dim=1)
|
||||
max_prob = max_prob.item()
|
||||
max_class = max_class.item()
|
||||
|
||||
# Check if the maximum predicted class is within the range 281-285
|
||||
if 281 <= max_class <= 285:
|
||||
return max_class, max_prob
|
||||
# Tworzenie komunikatu na podstawie wyniku analizy zdjęcia
|
||||
translator = {
|
||||
281: "tabby cat",
|
||||
282: "tiger cat",
|
||||
283: "persian cat",
|
||||
284: "siamese cat",
|
||||
285: "egyptian cat"
|
||||
}
|
||||
if max_prob is not None:
|
||||
result = f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%"
|
||||
else:
|
||||
return max_class, None
|
||||
result = f"The image is not recognized as a class within the range 281-285 ({max_class})"
|
||||
|
||||
return jsonify({'result': result}), 200
|
||||
|
||||
|
||||
image_path = 'wolf.jpg'
|
||||
max_class, max_prob = is_cat(image_path)
|
||||
translator = {
|
||||
281: "tabby cat",
|
||||
282: "tiger cat",
|
||||
283: "persian cat",
|
||||
284: "siamese cat",
|
||||
285: "egyptian cat"
|
||||
}
|
||||
if max_prob is not None:
|
||||
print(f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%")
|
||||
else:
|
||||
print(f"The image is not recognized as a class within the range 281-285 ({max_class})")
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True)
|
||||
|
Loading…
Reference in New Issue
Block a user