API improvements
This commit is contained in:
parent
bf70d8eb9c
commit
663ef6d80d
@ -1,48 +1,58 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import torch
|
import numpy as np
|
||||||
import torch.nn.functional as F
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.models.resnet import resnet50, ResNet50_Weights
|
from keras.src.applications.resnet import preprocess_input, decode_predictions
|
||||||
from torchvision.transforms import transforms
|
from keras.applications.resnet import ResNet50
|
||||||
|
|
||||||
model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# Define the image transformations
|
|
||||||
preprocess = transforms.Compose([
|
|
||||||
transforms.Resize(256),
|
|
||||||
transforms.CenterCrop(224),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
def is_cat(image):
|
"""
|
||||||
|
Recognition file.
|
||||||
|
Model is ResNet50. Pretrained model to image recognition.
|
||||||
|
If model recognize cat then returns response with first ten CAT predictions.
|
||||||
|
If first prediction is not a cat then returns False.
|
||||||
|
If prediction is not a cat (is not within list_of_labels) then skips this prediction.
|
||||||
|
Format of response:
|
||||||
|
{
|
||||||
|
'label': {label}
|
||||||
|
'score': {score}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
model = ResNet50(weights='imagenet')
|
||||||
|
|
||||||
|
|
||||||
|
# PRIVATE Preprocess image method
|
||||||
|
def _preprocess_image(image):
|
||||||
try:
|
try:
|
||||||
img = Image.open(BytesIO(image.read()))
|
img = Image.open(BytesIO(image.read()))
|
||||||
|
img = img.resize((224, 224))
|
||||||
# Preprocess the image
|
img_array = np.array(img)
|
||||||
img_t = preprocess(img)
|
img_array = np.expand_dims(img_array, axis=0)
|
||||||
batch_t = torch.unsqueeze(img_t, 0)
|
img_array = preprocess_input(img_array)
|
||||||
|
return img_array
|
||||||
# Make the prediction
|
|
||||||
out = model(batch_t)
|
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
|
||||||
probabilities = F.softmax(out, dim=1)
|
|
||||||
|
|
||||||
# Get the maximum predicted class and its probability
|
|
||||||
max_prob, max_class = torch.max(probabilities, dim=1)
|
|
||||||
max_prob = max_prob.item()
|
|
||||||
max_class = max_class.item()
|
|
||||||
|
|
||||||
# Check if the maximum predicted class is within the range 281-285
|
|
||||||
if 281 <= max_class <= 285:
|
|
||||||
return max_class, max_prob
|
|
||||||
else:
|
|
||||||
return max_class, None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error while processing the image:", e)
|
print(f"Error preprocessing image: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Generate response
|
||||||
|
def _generate_response(decoded_predictions, list_of_labels):
|
||||||
|
results = {}
|
||||||
|
for i, (imagenet_id, label, score) in enumerate(decoded_predictions):
|
||||||
|
if i == 0 and label not in list_of_labels:
|
||||||
|
return None
|
||||||
|
if score < 0.01:
|
||||||
|
break
|
||||||
|
if label in list_of_labels:
|
||||||
|
results[len(results) + 1] = {"label": label, "score": round(float(score), 2)}
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# Cat detection
|
||||||
|
def detect_cat(image_file, list_of_labels):
|
||||||
|
img_array = _preprocess_image(image_file)
|
||||||
|
prediction = model.predict(img_array)
|
||||||
|
decoded_predictions = decode_predictions(prediction, top=10)[0]
|
||||||
|
return _generate_response(decoded_predictions, list_of_labels)
|
||||||
|
9
docs.md
9
docs.md
@ -1,9 +0,0 @@
|
|||||||
# Api
|
|
||||||
|
|
||||||
Port -> 5000
|
|
||||||
|
|
||||||
endpoint -> /detect-cat
|
|
||||||
|
|
||||||
Key -> 'Image'
|
|
||||||
|
|
||||||
Value -> {UPLOADED_FILE}
|
|
53
docs/docs.md
Normal file
53
docs/docs.md
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Api
|
||||||
|
|
||||||
|
Port -> 5000
|
||||||
|
|
||||||
|
endpoint -> api/v1/detect-cat
|
||||||
|
|
||||||
|
Key -> 'Image'
|
||||||
|
|
||||||
|
Value -> {UPLOADED_FILE}
|
||||||
|
|
||||||
|
Flask Rest API application to cat recognition.
|
||||||
|
If request is valid then send response with results of recognition.
|
||||||
|
If key named 'Image' in body does not occur then returns 400 (BAD REQUEST).
|
||||||
|
Otherwise, returns 200 with results of recognition.
|
||||||
|
Format of response:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"lang": "{users_lang}",
|
||||||
|
"results": {
|
||||||
|
"{filename}": {
|
||||||
|
"isCat": "{is_cat}",
|
||||||
|
"results": {
|
||||||
|
"1": "{result}",
|
||||||
|
"2": "{result}",
|
||||||
|
"3": "{result}",
|
||||||
|
"4": "{result}",
|
||||||
|
"5": "{result}",
|
||||||
|
"6": "{result}",
|
||||||
|
"7": "{result}",
|
||||||
|
"8": "{result}",
|
||||||
|
"9": "{result}",
|
||||||
|
"10": "{result}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"errors": [
|
||||||
|
"{error_message}",
|
||||||
|
"{error_message}"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
Format of result:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"label": "{label}",
|
||||||
|
"score": "{score}"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example response:
|
||||||
|
```json
|
||||||
|
|
||||||
|
```
|
Before Width: | Height: | Size: 204 KiB After Width: | Height: | Size: 204 KiB |
Before Width: | Height: | Size: 105 KiB After Width: | Height: | Size: 105 KiB |
Before Width: | Height: | Size: 7.1 KiB After Width: | Height: | Size: 7.1 KiB |
26
language_label_mapper.py
Normal file
26
language_label_mapper.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from jproperties import Properties
|
||||||
|
|
||||||
|
"""
|
||||||
|
Translator method.
|
||||||
|
If everything fine then returns translated labels.
|
||||||
|
Else throws an Exception and returns untranslated labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def translate(to_translate, lang):
|
||||||
|
try:
|
||||||
|
config = Properties()
|
||||||
|
|
||||||
|
# Load properties file for given lang
|
||||||
|
with open(f"resources/{lang}.properties", 'rb') as config_file:
|
||||||
|
config.load(config_file, encoding='UTF-8')
|
||||||
|
|
||||||
|
# Translate labels for given to_translate dictionary
|
||||||
|
for index, label_info in to_translate.items():
|
||||||
|
label = label_info.get("label")
|
||||||
|
to_translate[index]["label"] = config.get(label).data
|
||||||
|
return to_translate, None
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"Error translating labels: {e}"
|
||||||
|
print(error_message)
|
||||||
|
return to_translate, error_message
|
114
main.py
114
main.py
@ -1,37 +1,105 @@
|
|||||||
from flask import Flask, request, jsonify, session
|
from flask import Flask, request, Response, json
|
||||||
|
from cat_detection import detect_cat
|
||||||
|
from language_label_mapper import translate
|
||||||
|
from validator import validate
|
||||||
|
|
||||||
|
"""
|
||||||
|
Flask Rest API application to cat recognition.
|
||||||
|
If request is valid then send response with results of recognition.
|
||||||
|
If key named 'Image' in body does not occurred then returns 400 (BAD REQUEST).
|
||||||
|
Otherwise returns 200 with results of recognition.
|
||||||
|
Format of response:
|
||||||
|
{
|
||||||
|
"lang": {users_lang},
|
||||||
|
"results": {
|
||||||
|
{filename}: {
|
||||||
|
"isCat": {is_cat},
|
||||||
|
"results": {
|
||||||
|
"1": {result}
|
||||||
|
"2": {result}
|
||||||
|
"3": {result}
|
||||||
|
...
|
||||||
|
"10" {result}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
},
|
||||||
|
errors[
|
||||||
|
{error_message},
|
||||||
|
{error_message},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
To see result format -> cat_detection.py
|
||||||
|
"""
|
||||||
|
|
||||||
from cat_detection import is_cat
|
|
||||||
|
|
||||||
# Define flask app
|
# Define flask app
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.secret_key = 'secret_key'
|
app.secret_key = 'secret_key'
|
||||||
|
|
||||||
|
# Available cats
|
||||||
|
list_of_labels = [
|
||||||
|
'lynx',
|
||||||
|
'lion',
|
||||||
|
'tiger',
|
||||||
|
'cheetah',
|
||||||
|
'leopard',
|
||||||
|
'jaguar',
|
||||||
|
'tabby',
|
||||||
|
'Egyptian_cat',
|
||||||
|
'cougar',
|
||||||
|
'Persian_cat',
|
||||||
|
'Siamese_cat',
|
||||||
|
'snow_leopard',
|
||||||
|
'tiger_cat'
|
||||||
|
]
|
||||||
|
|
||||||
@app.route('/detect-cat', methods=['POST'])
|
# Available languages
|
||||||
|
languages = {'pl', 'en'}
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/v1/detect-cat', methods=['POST'])
|
||||||
def upload_file():
|
def upload_file():
|
||||||
# 'Key' in body should be named as 'image'. Type should be 'File' and in 'Value' we should upload image from disc.
|
# Validate request
|
||||||
file = request.files['image']
|
error_messages = validate(request)
|
||||||
if file.filename == '':
|
|
||||||
return jsonify({'error': "File name is empty. Please name a file."}), 400
|
|
||||||
max_class, max_prob = is_cat(file)
|
|
||||||
|
|
||||||
# Save result in session
|
# If any errors occurred, return 400 (BAD REQUEST)
|
||||||
session['result'] = max_class, max_prob
|
if len(error_messages) > 0:
|
||||||
|
errors = json.dumps(
|
||||||
# Tworzenie komunikatu na podstawie wyniku analizy zdjęcia
|
{
|
||||||
translator = {
|
'errors': error_messages
|
||||||
281: "tabby cat",
|
|
||||||
282: "tiger cat",
|
|
||||||
283: "persian cat",
|
|
||||||
284: "siamese cat",
|
|
||||||
285: "egyptian cat"
|
|
||||||
}
|
}
|
||||||
if max_prob is not None:
|
)
|
||||||
result = f"The image is recognized as '{translator[max_class]}' with a probability of {round(max_prob * 100, 2)}%"
|
return Response(errors, status=400, mimetype='application/json')
|
||||||
else:
|
|
||||||
result = f"The image is not recognized as a class within the range 281-285 ({max_class})"
|
|
||||||
|
|
||||||
return jsonify({'result': result}), 200
|
# Get files from request
|
||||||
|
files = request.files.getlist('image')
|
||||||
|
|
||||||
|
# Get user's language (Value in header 'Accept-Language'). Default value is English
|
||||||
|
lang = request.accept_languages.best_match(languages, default='en')
|
||||||
|
|
||||||
|
# Define JSON structure for results
|
||||||
|
results = {
|
||||||
|
'lang': lang,
|
||||||
|
'results': {},
|
||||||
|
'errors': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate results
|
||||||
|
for file in files:
|
||||||
|
predictions = detect_cat(file, list_of_labels)
|
||||||
|
if predictions is not None:
|
||||||
|
predictions, error_messages = translate(predictions, lang)
|
||||||
|
results['results'][file.filename] = {
|
||||||
|
'isCat': False if not predictions else True,
|
||||||
|
**({'predictions': predictions} if predictions is not None else {})
|
||||||
|
}
|
||||||
|
if error_messages is not None and predictions is None:
|
||||||
|
results['errors'].append(error_messages)
|
||||||
|
|
||||||
|
# Send response with 200 (Success)
|
||||||
|
return Response(json.dumps(results), status=200, mimetype='application/json')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
14
resources/en.properties
Normal file
14
resources/en.properties
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# EN
|
||||||
|
lynx=lynx
|
||||||
|
lion=lion
|
||||||
|
tiger=tiger
|
||||||
|
cheetah=cheetah
|
||||||
|
leopard=leopard
|
||||||
|
jaguar=jaguar
|
||||||
|
tabby=tabby
|
||||||
|
Egyptian_cat=Egyptian cat
|
||||||
|
cougar=cougar
|
||||||
|
Persian_cat=Persian cat
|
||||||
|
Siamese_cat=Siamese cat
|
||||||
|
snow_leopard=snow leopard
|
||||||
|
tiger_cat=tiger cat
|
14
resources/pl.properties
Normal file
14
resources/pl.properties
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# PL
|
||||||
|
lynx=ryś
|
||||||
|
lion=lew
|
||||||
|
tiger=tygrys
|
||||||
|
cheetah=gepard
|
||||||
|
leopard=lampart
|
||||||
|
jaguar=jaguar
|
||||||
|
tabby=kot pręgowany
|
||||||
|
Egyptian_cat=kot egipski
|
||||||
|
cougar=puma
|
||||||
|
Persian_cat=kot perski
|
||||||
|
Siamese_cat=kot syjamski
|
||||||
|
snow_leopard=lampart śnieżny
|
||||||
|
tiger_cat=kot tygrysi
|
33
validator.py
Normal file
33
validator.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import imghdr
|
||||||
|
|
||||||
|
"""
|
||||||
|
Validation method.
|
||||||
|
If everything fine then returns empty list.
|
||||||
|
Else returns list of error messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Allowed extensions
|
||||||
|
allowed_extensions = {'jpg', 'jpeg', 'png'}
|
||||||
|
|
||||||
|
|
||||||
|
def validate(request):
|
||||||
|
errors = []
|
||||||
|
try:
|
||||||
|
images = request.files.getlist('image')
|
||||||
|
|
||||||
|
# Case 1 - > request has no 'Image' Key in body
|
||||||
|
if images is None:
|
||||||
|
raise KeyError("'Image' key not found in request.")
|
||||||
|
|
||||||
|
# Case 2 - > if some of the images has no filename
|
||||||
|
if not images or all(img.filename == '' for img in images):
|
||||||
|
raise ValueError("Value of 'Image' key is empty.")
|
||||||
|
|
||||||
|
# Case 3 -> if some of the images has wrong extension
|
||||||
|
for img in images:
|
||||||
|
if imghdr.what(img) not in allowed_extensions:
|
||||||
|
raise ValueError(f"Given file '{img.filename}' has no allowed extension. "
|
||||||
|
f"Allowed extensions: {allowed_extensions}.")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e.args[0])
|
||||||
|
return errors
|
Loading…
Reference in New Issue
Block a user