Projekt_IO/cat_detection.py

59 lines
1.7 KiB
Python
Raw Normal View History

2024-01-04 21:07:52 +01:00
from io import BytesIO
2024-01-14 14:53:54 +01:00
import numpy as np
2024-01-04 21:07:52 +01:00
from PIL import Image
2024-01-14 14:53:54 +01:00
from keras.src.applications.resnet import preprocess_input, decode_predictions
from keras.applications.resnet import ResNet50
2024-01-04 21:07:52 +01:00
2024-01-14 14:53:54 +01:00
"""
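Cat recognition module.
Uses ResNet50, a model pretrained on ImageNet for image classification.
The model's top ten predictions are examined. If the first (highest-scoring)
prediction is a cat (its label is in list_of_labels), the response contains every
cat prediction among them, stopping at the first prediction whose score is below 0.01.
If the first prediction is not a cat, None is returned.
Predictions that are not cats (labels not in list_of_labels) are skipped.
Format of the response (keys are consecutive ranks starting at 1, scores rounded to 2 decimals):
{
    1: {'label': label, 'score': score},
    ...
}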
"""
2024-01-04 21:07:52 +01:00
2024-01-14 14:53:54 +01:00
# Shared ResNet50 classifier with ImageNet weights. It is created once, at
# import time, and reused by every detect_cat() call.
model = ResNet50(weights="imagenet")
# PRIVATE Preprocess image method
def _preprocess_image(image):
    """Turn an uploaded image file into a ResNet50-ready input batch.

    Args:
        image: File-like object whose ``read()`` returns the raw encoded image
            bytes (presumably an uploaded file object -- confirm against callers).

    Returns:
        A ``(1, 224, 224, 3)`` array transformed by ``preprocess_input``, or
        ``None`` if the image could not be decoded or processed (the error is
        printed, not raised).
    """
    try:
        # Decode from an in-memory buffer; the context manager releases PIL's
        # decoder resources once the converted copy has been produced.
        with Image.open(BytesIO(image.read())) as source:
            # ResNet50 expects exactly 3 channels. Grayscale ("L"), palette ("P")
            # and RGBA (e.g. PNG with alpha) images would otherwise produce
            # arrays of the wrong shape and fail later in model.predict().
            rgb = source.convert("RGB").resize((224, 224))
        # Writable float32 copy: preprocess_input normalizes the values.
        img_array = np.array(rgb, dtype=np.float32)
        # Add the batch dimension: (224, 224, 3) -> (1, 224, 224, 3).
        img_array = np.expand_dims(img_array, axis=0)
        return preprocess_input(img_array)
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and signal
        # failure with None so the caller can decide how to react.
        print(f"Error preprocessing image: {e}")
        return None
# Generate response
def _generate_response(decoded_predictions, list_of_labels):
results = {}
for i, (imagenet_id, label, score) in enumerate(decoded_predictions):
if i == 0 and label not in list_of_labels:
return None
if score < 0.01:
break
if label in list_of_labels:
results[len(results) + 1] = {"label": label, "score": round(float(score), 2)}
return results
# Cat detection
def detect_cat(image_file, list_of_labels, top=10):
    """Classify an image with ResNet50 and report any cat predictions.

    Args:
        image_file: File-like object with the raw image bytes; handed to
            ``_preprocess_image``.
        list_of_labels: Labels (class names as produced by
            ``decode_predictions``) that count as a cat.
        top: Number of highest-scoring predictions to inspect (default 10).

    Returns:
        The result of ``_generate_response``: ``None`` when the top prediction
        is not a cat, otherwise a dict of rank -> ``{"label": ..., "score": ...}``.

    Raises:
        ValueError: If the image could not be decoded or preprocessed.
    """
    img_array = _preprocess_image(image_file)
    if img_array is None:
        # _preprocess_image swallows the error and returns None; passing None to
        # model.predict() would fail with an unrelated, confusing message.
        raise ValueError("Could not preprocess image for cat detection")
    prediction = model.predict(img_array)
    # One decoded list per batch item; the batch holds a single image.
    decoded_predictions = decode_predictions(prediction, top=top)[0]
    return _generate_response(decoded_predictions, list_of_labels)