diff --git a/ml/assets/fonts/arial.ttf b/ml/assets/fonts/arial.ttf new file mode 100644 index 0000000..ff0815c Binary files /dev/null and b/ml/assets/fonts/arial.ttf differ diff --git a/ml/assets/weights/yolov8-plate.pt b/ml/assets/weights/yolov8-plate.pt new file mode 100644 index 0000000..9991e51 Binary files /dev/null and b/ml/assets/weights/yolov8-plate.pt differ diff --git a/ml/assets/weights/yolov8n-face.pt b/ml/assets/weights/yolov8n-face.pt new file mode 100644 index 0000000..0798a1f Binary files /dev/null and b/ml/assets/weights/yolov8n-face.pt differ diff --git a/ml/element_detection.py b/ml/element_detection.py new file mode 100644 index 0000000..071cb77 --- /dev/null +++ b/ml/element_detection.py @@ -0,0 +1,69 @@ +import os +from typing import List, Tuple + +import torch +from PIL import Image, ImageDraw, ImageFont + +from ultralytics import YOLO + +DIR_PATH = os.path.dirname(os.path.realpath(__file__)) +FACES_PATH = os.path.join(DIR_PATH, "assets/weights/yolov8n-face.pt") +PLATES_PATH = os.path.join(DIR_PATH, "assets/weights/yolov8-plate.pt") +FACES_MODEL = YOLO(FACES_PATH) +PLATES_MODEL = YOLO(PLATES_PATH) +CONF_THRESH = 0.3 +IOU_THRESH = 0.5 + + +class BoundBox: + def __init__(self, x1, y1, x2, y2): + self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2 + self.selected = True + + def select(self): + self.selected = True + + def unselect(self): + self.selected = False + + def get_params(self) -> Tuple[int, int, int, int]: + return self.x1, self.y1, self.x2, self.y2 + + +def detect(image_path: str) -> List[BoundBox]: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + faces = FACES_MODEL.predict( + source=image_path, conf=CONF_THRESH, iou=IOU_THRESH, device=device + ) + faces = faces[0].cpu().numpy().boxes + plates = PLATES_MODEL.predict( + source=image_path, conf=CONF_THRESH, iou=IOU_THRESH, device=device + ) + plates = plates[0].cpu().numpy().boxes + bounding_boxes = [] + for boxes in [faces, plates]: + for box in boxes: + xyxyn = box.xyxy[0] + x1 = int(xyxyn[0]) + y1 = int(xyxyn[1]) + x2 = int(xyxyn[2]) + y2 = int(xyxyn[3]) + bounding_boxes.append(BoundBox(x1, y1, x2, y2)) + return bounding_boxes + + +def show_image_with_boxes( + in_image_path: str, bounding_boxes: List[BoundBox], out_image_path: str = None +): + img = Image.open(in_image_path) + draw = ImageDraw.Draw(img) + font_path = DIR_PATH + "/assets/fonts/arial.ttf" + font = ImageFont.truetype(font_path, 25) + for i, box in enumerate(bounding_boxes): + draw.rectangle(box.get_params(), outline="red", width=2, fill=None) + draw.text((box.x1 + 5, box.y1 + 5), str(i+1), fill="red", font=font) + if not out_image_path: + out_image_path = ( + in_image_path.split(".")[0] + "_out." + in_image_path.split(".")[1] + ) + img.save(out_image_path) diff --git a/requirements.txt b/requirements.txt index 2a58fd0..b0a1756 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,93 @@ - +alembic==1.12.1 +anyio==4.0.0 +async-timeout==4.0.3 +bidict==0.22.1 +certifi==2023.11.17 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==2.2.1 +contourpy==1.2.0 +cycler==0.12.1 +distro==1.8.0 +exceptiongroup==1.1.3 +fastapi==0.96.1 +filelock==3.13.1 +fonttools==4.44.3 +fsspec==2023.10.0 +greenlet==3.0.1 +gunicorn==20.1.0 +h11==0.14.0 +httpcore==0.17.3 +httpx==0.24.1 +idna==3.4 +Jinja2==3.1.2 +kiwisolver==1.4.5 +Mako==1.3.0 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.2 +mdurl==0.1.2 +mpmath==1.3.0 +networkx==3.2.1 +numpy==1.26.2 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu12==12.1.105 +opencv-python==4.8.1.78 +packaging==23.2 +pandas==2.1.3 +Pillow==10.1.0 +pipdeptree==2.13.1 +platformdirs==3.11.0 +psutil==5.9.6 +py-cpuinfo==9.0.0 +pydantic==1.10.13 +Pygments==2.16.1 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-engineio==4.8.0 +python-multipart==0.0.5 +python-socketio==5.10.0 +pytz==2023.3.post1 +PyYAML==6.0.1 +redis==4.6.0 reflex==0.3.2 +requests==2.31.0 +rich==13.7.0 +scipy==1.11.3 +seaborn==0.13.0 +simple-websocket==1.0.0 +six==1.16.0 +sniffio==1.3.0 +SQLAlchemy==1.4.41 +sqlalchemy2-stubs==0.0.2a37 +sqlmodel==0.0.8 +starlette==0.27.0 +starlette-admin==0.9.0 +sympy==1.12 +tabulate==0.9.0 +thop==0.1.1.post2209072238 +torch==2.1.1 +torchvision==0.16.1 +tqdm==4.66.1 +triton==2.1.0 +typer==0.9.0 +typing_extensions==4.8.0 +tzdata==2023.3 +ultralytics==8.0.212 +urllib3==2.1.0 +uvicorn==0.20.0 +watchdog==2.3.1 +watchfiles==0.19.0 +websockets==10.4 +wrapt==1.16.0 +wsproto==1.2.0