"""
Person and object tracker using MediaPipe Pose and OpenCV.
"""

import cv2
import mediapipe as mp
import numpy as np
from enum import Enum, auto


class TrackingState(Enum):
    """Lifecycle of the tracker's lock on its current target."""

    SEARCHING = 1  # no target acquired yet (or recently dropped)
    TRACKING = 2   # target located in the current frame
    LOST = 3       # target missing longer than the lost-frame budget


class TargetType(Enum):
    """What kind of target the tracker is currently following."""

    PERSON = 1  # tracked via MediaPipe pose landmarks
    OBJECT = 2  # tracked via an OpenCV correlation tracker


class TrackerResult:
    """Snapshot of one tracking update.

    Holds the raw pixel-space measurements (center, bbox, area, frame
    dimensions) and exposes normalized views of them as properties.
    """

    def __init__(self):
        self.state = TrackingState.SEARCHING
        self.target_type = TargetType.PERSON
        self.cx = 0          # target center x, pixels
        self.cy = 0          # target center y, pixels
        self.bbox = None     # (x, y, w, h) in pixels, or None
        self.area = 0        # bbox area, pixels^2
        self.frame_w = 0     # source frame width, pixels
        self.frame_h = 0     # source frame height, pixels

    @property
    def offset_x(self):
        """Horizontal offset of the center from frame middle, in [-1, 1]."""
        half_w = self.frame_w / 2
        return (self.cx - half_w) / half_w if half_w else 0.0

    @property
    def offset_y(self):
        """Vertical offset of the center from frame middle, in [-1, 1]."""
        half_h = self.frame_h / 2
        return (self.cy - half_h) / half_h if half_h else 0.0

    @property
    def relative_size(self):
        """Bbox area as a fraction of the frame area (0.0 if unknown frame)."""
        frame_area = self.frame_w * self.frame_h
        return self.area / frame_area if frame_area else 0.0


class PersonObjectTracker:
    """Track a person (MediaPipe Pose) or a generic object (OpenCV CSRT).

    Two modes, selected by ``target_type``: PERSON mode derives a bounding
    box from sufficiently visible pose landmarks; OBJECT mode follows a box
    handed to :meth:`init_object_tracking` using a CSRT correlation tracker.
    Both modes share a lost-frame budget before the state degrades from
    SEARCHING to LOST.
    """

    # Suggested on-screen target size bounds as a fraction of frame area.
    # Not referenced inside this class; exposed for callers.
    TARGET_SIZE_MIN = 0.03
    TARGET_SIZE_MAX = 0.25

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.4,
        )
        self.state = TrackingState.SEARCHING
        self.target_type = TargetType.PERSON
        self._cv_tracker = None      # OpenCV tracker instance (OBJECT mode only)
        self._lost_frames = 0        # consecutive frames without a detection
        self._max_lost_frames = 30   # budget before state becomes LOST

    def track_person(self, frame_rgb) -> TrackerResult:
        """Detect a person in an RGB frame via pose landmarks.

        Returns a TrackerResult whose bbox covers all landmarks with
        visibility > 0.5, padded by 10% per side and clamped to the frame.
        Falls back to SEARCHING/LOST bookkeeping when no usable pose is
        found this frame.
        """
        result = TrackerResult()
        h, w, _ = frame_rgb.shape
        result.frame_w = w
        result.frame_h = h

        pose_results = self.pose.process(frame_rgb)

        if pose_results.pose_landmarks:
            landmarks = pose_results.pose_landmarks.landmark
            # Keep only landmarks MediaPipe considers reliably visible.
            visible = [lm for lm in landmarks if lm.visibility > 0.5]
            xs = [lm.x * w for lm in visible]
            ys = [lm.y * h for lm in visible]

            # Require at least 5 visible landmarks for a stable box.
            if len(visible) > 4:
                x_min, x_max = int(min(xs)), int(max(xs))
                y_min, y_max = int(min(ys)), int(max(ys))

                # Pad the landmark extent by 10% and clamp to the frame.
                pad_x = int((x_max - x_min) * 0.1)
                pad_y = int((y_max - y_min) * 0.1)
                x_min = max(0, x_min - pad_x)
                y_min = max(0, y_min - pad_y)
                x_max = min(w, x_max + pad_x)
                y_max = min(h, y_max + pad_y)

                result.cx = (x_min + x_max) // 2
                result.cy = (y_min + y_max) // 2
                result.bbox = (x_min, y_min, x_max - x_min, y_max - y_min)
                result.area = (x_max - x_min) * (y_max - y_min)
                result.state = TrackingState.TRACKING
                result.target_type = TargetType.PERSON
                self.state = TrackingState.TRACKING
                self._lost_frames = 0
                return result

        # No usable detection this frame: degrade to LOST only after the
        # lost-frame budget is exhausted.
        self._lost_frames += 1
        if self._lost_frames > self._max_lost_frames:
            result.state = TrackingState.LOST
            self.state = TrackingState.LOST
        else:
            result.state = TrackingState.SEARCHING
            self.state = TrackingState.SEARCHING

        return result

    def init_object_tracking(self, frame_bgr, bbox):
        """Start tracking the object inside ``bbox`` (x, y, w, h) of a BGR frame."""
        # OpenCV >= 4.5.1 moved the contrib trackers under cv2.legacy;
        # fall back so both older and newer builds work.
        if hasattr(cv2, "TrackerCSRT_create"):
            self._cv_tracker = cv2.TrackerCSRT_create()
        else:
            self._cv_tracker = cv2.legacy.TrackerCSRT_create()
        self._cv_tracker.init(frame_bgr, bbox)
        self.target_type = TargetType.OBJECT
        self.state = TrackingState.TRACKING
        self._lost_frames = 0

    def track_object(self, frame_bgr) -> TrackerResult:
        """Advance the CSRT tracker one BGR frame and report the result."""
        result = TrackerResult()
        h, w = frame_bgr.shape[:2]
        result.frame_w = w
        result.frame_h = h

        if self._cv_tracker is None:
            # init_object_tracking() was never called (or tracker was reset).
            result.state = TrackingState.LOST
            return result

        success, bbox = self._cv_tracker.update(frame_bgr)

        if success:
            x, y, bw, bh = (int(v) for v in bbox)
            result.cx = x + bw // 2
            result.cy = y + bh // 2
            result.bbox = (x, y, bw, bh)
            result.area = bw * bh
            result.state = TrackingState.TRACKING
            result.target_type = TargetType.OBJECT
            # Keep self.state in sync with the result, mirroring
            # track_person() (was previously left stale on success).
            self.state = TrackingState.TRACKING
            self._lost_frames = 0
        else:
            self._lost_frames += 1
            if self._lost_frames > self._max_lost_frames:
                result.state = TrackingState.LOST
                self.state = TrackingState.LOST
            else:
                result.state = TrackingState.SEARCHING
                self.state = TrackingState.SEARCHING  # mirror track_person()

        return result

    def track(self, frame_rgb, frame_bgr) -> TrackerResult:
        """Dispatch to person or object tracking based on the current target type."""
        if self.target_type == TargetType.PERSON:
            return self.track_person(frame_rgb)
        return self.track_object(frame_bgr)

    def switch_to_person(self):
        """Drop any object tracker and return to person-search mode."""
        self.target_type = TargetType.PERSON
        self._cv_tracker = None
        self._lost_frames = 0
        self.state = TrackingState.SEARCHING

    def find_object_in_direction(self, frame_bgr, direction: str):
        """Find the largest edge-blob in the left or right half of the frame.

        Any ``direction`` other than "left" is treated as "right".
        Returns an (x, y, w, h) bbox in full-frame coordinates, or None
        when no sufficiently large contour is found.
        """
        h, w = frame_bgr.shape[:2]

        if direction == "left":
            roi = frame_bgr[:, : w // 2]
            x_offset = 0
        else:
            roi = frame_bgr[:, w // 2 :]
            x_offset = w // 2

        # Edge-based blob detection: blur -> Canny -> dilate, so nearby
        # edges merge into contiguous contours.
        gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (11, 11), 0)
        edges = cv2.Canny(blurred, 30, 100)

        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(edges, kernel, iterations=2)

        contours, _ = cv2.findContours(
            dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        if not contours:
            return None

        # Reject tiny contours. NOTE(review): threshold is 2% of a
        # quarter-frame, i.e. 1% of the half-frame ROI — confirm this
        # was the intended scale rather than 2% of the ROI.
        min_area = (h * w // 4) * 0.02
        valid = [c for c in contours if cv2.contourArea(c) > min_area]

        if not valid:
            return None

        largest = max(valid, key=cv2.contourArea)
        x, y, bw, bh = cv2.boundingRect(largest)

        # Translate ROI-local coordinates back into full-frame coordinates.
        return (x + x_offset, y, bw, bh)

    def close(self):
        """Release MediaPipe Pose resources."""
        self.pose.close()
