BlurMe/ml/element_detection.py

70 lines
2.1 KiB
Python
Raw Normal View History

import os
from typing import List, Tuple
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
# Resolve bundled weight files relative to this source file (not the CWD).
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
FACES_PATH = os.path.join(DIR_PATH, "assets/weights/yolov8n-face.pt")
PLATES_PATH = os.path.join(DIR_PATH, "assets/weights/yolov8-plate.pt")

# Passed as ``conf=`` and ``iou=`` to ``YOLO.predict`` in ``detect``
# (per the ultralytics docs: minimum detection confidence, and the IoU
# threshold used during non-maximum suppression).
CONF_THRESH = 0.3
IOU_THRESH = 0.5

# Both detectors are constructed once, when this module is imported, and
# are shared by every call made through this module.
FACES_MODEL = YOLO(FACES_PATH)
PLATES_MODEL = YOLO(PLATES_PATH)
class BoundBox:
    """One rectangular region on an image, plus a "selected" flag.

    The four coordinates are stored exactly as given, in ``(x1, y1, x2, y2)``
    order, which is the order ``detect`` fills them from a YOLO ``xyxy``
    box and the order ``ImageDraw.rectangle`` accepts. Presumably (x1, y1)
    is the top-left corner and (x2, y2) the bottom-right; this is not
    validated here.
    """

    def __init__(self, x1, y1, x2, y2):
        self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2
        # Every new box starts out selected; callers toggle it explicitly.
        self.selected = True

    def select(self):
        """Mark this box as selected."""
        self.selected = True

    def unselect(self):
        """Mark this box as not selected."""
        self.selected = False

    def get_params(self) -> Tuple[int, int, int, int]:
        """Return the coordinates as an ``(x1, y1, x2, y2)`` tuple."""
        return self.x1, self.y1, self.x2, self.y2

    def __repr__(self) -> str:
        # Readable form for logs and debugging instead of the default
        # "<BoundBox object at 0x...>".
        return (
            f"{type(self).__name__}(x1={self.x1!r}, y1={self.y1!r}, "
            f"x2={self.x2!r}, y2={self.y2!r}, selected={self.selected!r})"
        )
def detect(image_path: str) -> List[BoundBox]:
    """Run the face model and then the plate model on one image.

    Args:
        image_path: Image location, handed to ``YOLO.predict`` as ``source``.

    Returns:
        One ``BoundBox`` per detection: all face boxes first, then all plate
        boxes. Each coordinate is the model's ``xyxy`` value passed through
        ``int()``, so fractional parts are truncated.
    """
    # Use the GPU when PyTorch reports one as available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Run both models up front (faces, then plates) before converting.
    # Only one source is passed, so only the first Results entry is used.
    per_model = [
        model.predict(
            source=image_path, conf=CONF_THRESH, iou=IOU_THRESH, device=device
        )[0].cpu().numpy().boxes
        for model in (FACES_MODEL, PLATES_MODEL)
    ]

    found: List[BoundBox] = []
    for detections in per_model:
        for detection in detections:
            corners = detection.xyxy[0]
            found.append(BoundBox(*(int(corners[k]) for k in range(4))))
    return found
def show_image_with_boxes(
    in_image_path: str, bounding_boxes: List[BoundBox], out_image_path: str = None
):
    """Draw numbered red rectangles for ``bounding_boxes`` and save the result.

    Despite the name, nothing is displayed: the annotated copy is written to
    disk and the input file itself is not modified.

    Args:
        in_image_path: Image to annotate.
        bounding_boxes: Boxes to draw. The box at index ``i`` is outlined in
            red and labelled ``i + 1`` just inside its ``(x1, y1)`` corner.
        out_image_path: Destination file. When falsy, it defaults to the
            input path with ``_out`` inserted before the extension
            (``photo.jpg`` -> ``photo_out.jpg``).
    """
    if not out_image_path:
        # splitext only inspects the final path component, so dots in
        # directory names ("./imgs/a.jpg") or extra dots in the file name
        # ("a.v2.jpg") no longer produce a corrupted destination path.
        root, ext = os.path.splitext(in_image_path)
        out_image_path = f"{root}_out{ext}"

    font = ImageFont.truetype(os.path.join(DIR_PATH, "assets/fonts/arial.ttf"), 25)

    # The context manager guarantees the source file handle is released,
    # even if drawing or saving raises.
    with Image.open(in_image_path) as img:
        draw = ImageDraw.Draw(img)
        for label, box in enumerate(bounding_boxes, start=1):
            draw.rectangle(box.get_params(), outline="red", width=2, fill=None)
            draw.text(
                (box.x1 + 5, box.y1 + 5), str(label), fill="red", font=font
            )
        img.save(out_image_path)