2024-01-04 21:07:52 +01:00
|
|
|
from io import BytesIO
|
|
|
|
|
2024-01-14 14:53:54 +01:00
|
|
|
import numpy as np
|
2024-01-04 21:07:52 +01:00
|
|
|
from PIL import Image
|
2024-01-14 14:53:54 +01:00
|
|
|
from keras.src.applications.resnet import preprocess_input, decode_predictions
|
|
|
|
from keras.applications.resnet import ResNet50
|
2024-01-04 21:07:52 +01:00
|
|
|
|
|
|
|
|
2024-01-14 14:53:54 +01:00
|
|
|
"""
|
|
|
|
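Recognition file.

Uses a ResNet50 model pretrained on ImageNet for image recognition.

If the top (first) prediction's label is in list_of_labels (the cat labels),
returns a response with the cat predictions found among the model's top ten
predictions.
If the top prediction is not a cat, returns None.
Predictions whose label is not in list_of_labels are skipped, and the scan
stops at the first prediction scoring below 0.01.

Format of response (keys are 1-based ranks among the kept cat predictions):

{
    1: {'label': <label>, 'score': <score rounded to 2 decimals>},
    2: {'label': <label>, 'score': <score rounded to 2 decimals>},
    ...
}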
|
|
|
|
"""
|
2024-01-04 21:07:52 +01:00
|
|
|
|
|
|
|
|
2024-01-14 14:53:54 +01:00
|
|
|
# Single ResNet50 instance with ImageNet weights, shared by every detect_cat()
# call. It is built at module import time, so importing this module pays the
# model-construction cost once instead of on every request.
model = ResNet50(weights='imagenet')
|
2024-01-04 21:07:52 +01:00
|
|
|
|
2024-01-14 14:53:54 +01:00
|
|
|
|
|
|
|
def _preprocess_image(image):
    """Turn an image file into a single-image ResNet50 input batch.

    Args:
        image: File-like object whose ``read()`` returns the raw encoded image
            bytes (presumably an uploaded file object; confirm against callers).

    Returns:
        A NumPy array of shape ``(1, 224, 224, 3)`` passed through ResNet's
        ``preprocess_input``, or None if any step fails. On failure the error
        is printed.
    """
    try:
        img = Image.open(BytesIO(image.read()))
        # ResNet50 needs exactly three colour channels. Without this step,
        # RGBA images (4 channels), grayscale "L" images (no channel axis) and
        # palette "P" images (channel holds palette indices) yield arrays the
        # model rejects or misreads.
        img = img.convert("RGB")
        img = img.resize((224, 224))
        # Add the leading batch axis: one image -> batch of size 1.
        img_array = np.expand_dims(np.array(img), axis=0)
        return preprocess_input(img_array)
    except Exception as e:
        # Best-effort: report the problem and signal failure with None.
        print(f"Error preprocessing image: {e}")
        return None
|
2024-01-14 14:53:54 +01:00
|
|
|
|
|
|
|
|
|
|
|
def _generate_response(decoded_predictions, list_of_labels, min_score=0.01):
    """Build the cat-prediction response from decoded model predictions.

    Args:
        decoded_predictions: Iterable of ``(imagenet_id, label, score)`` tuples.
            The early stop on a low score assumes descending score order
            (``decode_predictions`` returns its top-k in that order).
        list_of_labels: Labels that count as "cat"; checked with ``in``.
        min_score: The scan stops at the first prediction whose score is
            below this value. Defaults to 0.01, the original fixed threshold.

    Returns:
        None if the first prediction's label is not in ``list_of_labels``.
        Otherwise a dict mapping 1-based rank among the kept predictions to
        ``{"label": label, "score": score rounded to 2 decimals}``. The dict
        is empty when there are no predictions, or when the top one is a cat
        but scores below ``min_score``.
    """
    results = {}
    for position, (_imagenet_id, label, score) in enumerate(decoded_predictions):
        # The top prediction decides whether the image is a cat at all.
        if position == 0 and label not in list_of_labels:
            return None
        # Predictions are ordered by score, so nothing later can pass either.
        if score < min_score:
            break
        if label in list_of_labels:
            results[len(results) + 1] = {"label": label, "score": round(float(score), 2)}
    return results
|
|
|
|
|
|
|
|
|
|
|
|
def detect_cat(image_file, list_of_labels):
    """Classify an image with ResNet50 and report its cat predictions.

    Args:
        image_file: File-like object whose ``read()`` returns the image bytes.
        list_of_labels: ImageNet labels that are treated as cats.

    Returns:
        None when the top prediction is not in ``list_of_labels``. Otherwise a
        dict mapping 1-based rank to ``{"label": ..., "score": ...}``, built
        from the model's top-10 decoded predictions by ``_generate_response``.

    Raises:
        ValueError: If the image could not be decoded or preprocessed.
    """
    img_array = _preprocess_image(image_file)
    if img_array is None:
        # _preprocess_image signals failure with None. Passing None to
        # model.predict would fail with an unrelated, confusing error.
        raise ValueError("Could not preprocess the uploaded image")
    prediction = model.predict(img_array)
    # decode_predictions yields one list per batch item; the batch holds one image.
    decoded_predictions = decode_predictions(prediction, top=10)[0]
    return _generate_response(decoded_predictions, list_of_labels)
|