from __future__ import annotations

import os
import threading
import time

import cv2
import numpy as np

from ai.roi_lane_engine import map_to_lanes, decide_lane, compute_green_time

# Optional persistence hook. When database_logger cannot be imported the name
# is bound to None, and the processing loop checks for None before attempting
# any insert, so the controller keeps running without a database.
try:
    from database_logger import insert_traffic_event
except ImportError:
    print("[WARN] database_logger.py not found in RPi folder. Database inserts will be simulated.")
    insert_traffic_event = None


# Detector labels counted as vehicles. Labels are stripped and lower-cased
# before being matched against this set.
VEHICLE_CLASSES = {"car", "bus", "motorcycle", "motorbike"}
# Detector labels treated as pedestrians; matching detections are excluded
# from the per-lane vehicle density.
PEDESTRIAN_CLASSES = {"pedestrian", "person"}


class HeadlessLoop:
    def __init__(self, ctx, traffic, sensors) -> None:
        self._ctx = ctx
        self._traffic = traffic
        self._sensors = sensors

        if os.environ.get("AUTORUN") == "1":
            self._system_active = True
            print("[SYSTEM] 💻 Laptop/Development mode detected via environment flag. Auto-activating tracking...")
        else:
            self._system_active = False
            print("[SYSTEM] 🔋 Standalone mode. Waiting for physical hardware button event.")

        self._running = True
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()

        self._latest_frame: np.ndarray | None = None
        self._latest_annotated: np.ndarray | None = None
        self._latest_counts: dict = {}
        self._latest_detection_count: int = 0
        self._latest_vehicle_count: int = 0
        self._latest_active_lane: str = "north"
        self._latest_green_time = 10
        self._latest_distances: dict = {}
        self._latest_status: str = "System Idle. Press button to start."

        self._last_green_axis = "NS"
        self._current_phase_axis = "none"
        self._next_target_axis = "none"
        self._phase_start_time = 0.0
        self._phase_duration = 0.0
        self._in_yellow_transition = False
        self._short_phase_active = False

        self._yellow_duration = 3.0
        self._idle_switch_duration = 8.0  # Automatic cycle timing when lanes are tied

        try:
            import RPi.GPIO as GPIO
            GPIO.setmode(GPIO.BCM)
            GPIO.setup(26, GPIO.IN, pull_up_down=GPIO.PUD_UP)

            GPIO.add_event_detect(
                26,
                GPIO.FALLING,
                callback=lambda channel: self.on_button_press(),
                bouncetime=300
            )
            print("[HARDWARE] ✅ Physical button edge interrupt registered on GPIO pin 26.")
        except Exception as gpio_err:
            print(f"[HARDWARE WARN] GPIO hardware listener failed (Expected if testing on a laptop): {gpio_err}")

        if not self._system_active:
            self._ctx.lcd.show_lines("SYSTEM CLOSED", "PRESS TO START")
        else:
            self._ctx.lcd.show_lines("STARTING...", "Auto Laptop Mode")

        print("[AUTO-START] Spawning core processing worker thread...")
        self._thread = threading.Thread(
            target=self._loop, daemon=True, name="headless-loop"
        )
        self._thread.start()

    def on_button_press(self) -> None:
        with self._lock:
            self._system_active = not self._system_active

            if self._system_active:
                print("[BTN] Toggle clicked -> STARTING Traffic Management System...")
                self._ctx.lcd.show_lines("STARTING...", "Initializing Cam")
                self._current_phase_axis = "none"
                self._in_yellow_transition = False
                self._short_phase_active = False
                self._phase_start_time = time.time()
            else:
                print("[BTN] Toggle clicked -> CLOSING Traffic Management System...")
                self._ctx.lcd.show_lines("SYSTEM CLOSED", "PRESS TO START")
                self._traffic.all_red()

    def on_hold_press(self) -> None:
        """Handle a button hold exactly like a normal press (toggles the system)."""
        self.on_button_press()

    def stop(self) -> None:
        """Ask the worker loop to exit.

        Clears the running flag and sets the stop event; the loop checks both
        at the top of every iteration. The worker thread is not joined here.
        """
        self._running = False
        self._stop_event.set()

    def is_running(self) -> bool:
        """Return the running flag (cleared by ``stop()`` and when the loop exits)."""
        return self._running

    def get_latest(self) -> dict:
        with self._lock:
            return {
                "annotated": self._latest_annotated,
                "counts": dict(self._latest_counts),
                "detection_count": self._latest_detection_count,
                "vehicle_count": self._latest_vehicle_count,
                "active_lane": self._latest_active_lane,
                "green_time": self._latest_green_time,
                "distances": dict(self._latest_distances),
                "status": self._latest_status,
            }

    def _get_frame(self) -> np.ndarray | None:
        from inputs.registry import INPUT_RPI_CAMERA
        provider = self._ctx.providers.get(INPUT_RPI_CAMERA)
        if provider is None or provider._cap is None:
            return None

        if hasattr(provider, "_lock"):
            with provider._lock:
                ret, frame = provider._cap.read()
        else:
            ret, frame = provider._cap.read()

        if not ret or frame is None:
            return None

        try:
            frame = cv2.flip(frame, -1)
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        except Exception:
            return None

    def _loop(self) -> None:
        self._running = True
        cfg = self._ctx.cfg
        print("[LOOP] Thread active and monitoring toggle status.")

        while self._running and not self._stop_event.is_set():
            try:
                if not self._system_active:
                    with self._lock:
                        self._latest_status = "System Idle. Waiting for activation click."
                    time.sleep(0.2)
                    continue

                frame = self._get_frame()
                if frame is None:
                    if self._running and self._system_active:
                        self._ctx.lcd.show_lines("CAM ERROR", "No frame")
                    with self._lock:
                        self._latest_status = "Camera error — no frame"
                    time.sleep(1)
                    continue

                # 📐 Calculate geometric bounds
                h, w = frame.shape[:2]
                
                # 🛡️ 1. Corner ROI Masks (Removes background structures/sidewalks)
                cv2.rectangle(frame, (0, 0), (int(w * 0.35), int(h * 0.35)), (0, 0, 0), -1)
                cv2.rectangle(frame, (0, int(h * 0.64)), (int(w * 0.34), h), (0, 0, 0), -1)
                cv2.rectangle(frame, (int(w * 0.56), 0), (w, int(h * 0.33)), (0, 0, 0), -1)
                cv2.rectangle(frame, (int(w * 0.58), int(h * 0.60)), (w, h), (0, 0, 0), -1)

                # 🎯 2. Middle Region Mask (Blocks center zone to ignore non-road objects/crosswalk interference)
                start_x = int(w * 0.35)
                end_x = int(w * 0.55)
                start_y = int(h * 0.35)
                end_y = int(h * 0.60)
                cv2.rectangle(frame, (start_x, start_y), (end_x, end_y), (0, 0, 0), -1)

                # Execute YOLO model tracking on the fully structured masked frame
                result = self._ctx.models["yolo"].run(frame, conf=0.25, iou=0.45)
                counts = result.payload.get("counts", {})
                detection_count = result.payload.get("detection_count", 0)
                detections = result.payload.get("detections", [])

                distances = self._sensors.get_distances()

                valid_vehicles = []
                pedestrian_count = 0

                for d in detections:
                    if d.get("bbox") is None:
                        continue

                    raw_name = d.get("class_name") or d.get("class") or d.get("label") or ""
                    class_name_clean = str(raw_name).strip().lower()
                    class_id = d.get("class_id")

                    is_pedestrian = class_name_clean in PEDESTRIAN_CLASSES or class_id == 0
                    is_vehicle = class_name_clean in VEHICLE_CLASSES or class_id in {1, 2, 3, 5, 7}

                    if is_pedestrian:
                        pedestrian_count += 1
                        continue

                    if is_vehicle:
                        valid_vehicles.append(d)

                if valid_vehicles:
                    raw_density = map_to_lanes(valid_vehicles, frame.shape)
                    camera_density = {
                        "north": raw_density.get("east", 0),
                        "south": raw_density.get("west", 0),
                        "east": raw_density.get("south", 0),
                        "west": raw_density.get("north", 0),
                    }
                else:
                    camera_density = {"north": 0, "south": 0, "east": 0, "west": 0}

                lane_density = {"north": 0, "south": 0, "east": 0, "west": 0}
                for lane_key in lane_density.keys():
                    camera_count = camera_density.get(lane_key, 0)
                    sensor_reading = distances.get(lane_key, -1)

                    if camera_count > 0:
                        if sensor_reading == -1 or sensor_reading < 150:
                            lane_density[lane_key] = camera_count
                    else:
                        lane_density[lane_key] = 0

                north_south_total = lane_density["north"] + lane_density["south"]
                east_west_total = lane_density["east"] + lane_density["west"]

                now = time.time()
                elapsed_time = now - self._phase_start_time

                if self._in_yellow_transition:
                    remaining_yellow = max(0, int(self._phase_duration - elapsed_time))
                    display_phase = "YELLOW CAUTION"
                    status_msg = f"Transitioning... Yellow active: {remaining_yellow}s"

                    if elapsed_time >= self._phase_duration:
                        self._in_yellow_transition = False
                        self._current_phase_axis = self._next_target_axis
                        self._phase_start_time = time.time()
                        elapsed_time = 0.0

                        if self._current_phase_axis == "none":
                            self._current_phase_axis = "NS"

                        active_lane = "north" if self._current_phase_axis == "NS" else "east"
                        print(f"[SCHEDULER] Yellow finished. Activating {self._current_phase_axis} GREEN.")
                        self._traffic.apply_traffic(active_lane)
                    else:
                        if self._running and self._system_active:
                            self._ctx.lcd.show_lines(f"{display_phase} {remaining_yellow}s", "CAUTION CLEARING")
                        with self._lock:
                            self._latest_status = status_msg
                            self._latest_green_time = remaining_yellow
                        time.sleep(1.0 / cfg.default_target_fps)
                        continue

                # 💡 MAIN STATE SCHEDULER LOGIC
                if not self._in_yellow_transition:
                    if self._current_phase_axis == "none":
                        # Tie break logic check at initialization
                        if north_south_total == east_west_total:
                            print(f"[SCHEDULER] Both axes have equal load ({north_south_total}). Automatic timed cycle triggered.")
                            self._current_phase_axis = "NS"
                            self._traffic.apply_traffic("north")
                        elif north_south_total > east_west_total:
                            self._current_phase_axis = "NS"
                            self._traffic.apply_traffic("north")
                        else:
                            self._current_phase_axis = "EW"
                            self._traffic.apply_traffic("east")
                        self._phase_start_time = time.time()
                        elapsed_time = 0.0
                        self._short_phase_active = False

                    current_axis = self._current_phase_axis

                    if current_axis == "NS":
                        active_vehicles = north_south_total
                        opposing_vehicles = east_west_total
                        next_axis_candidate = "EW"
                        apply_yellow_lane = "north"
                    else:
                        active_vehicles = east_west_total
                        opposing_vehicles = north_south_total
                        next_axis_candidate = "NS"
                        apply_yellow_lane = "east"

                    should_switch = False
                    lcd_time_str = "INF"

                    # RULE: Both lanes have same number of vehicles (or both completely empty) -> Automatic switching loop
                    if active_vehicles == opposing_vehicles:
                        lcd_time_str = f"{max(0, int(self._idle_switch_duration - elapsed_time))}s"
                        if elapsed_time >= self._idle_switch_duration:
                            print(f"[SCHEDULER] Lanes equal ({active_vehicles} each) for {self._idle_switch_duration}s. Auto-cycling axis.")
                            should_switch = True
                            self._short_phase_active = False

                    # RULE: Active lane has cars, opposing side is clean -> Run infinitely
                    elif active_vehicles > 0 and opposing_vehicles == 0:
                        should_switch = False
                        self._short_phase_active = False
                        lcd_time_str = "INF"

                    # RULE: Active lane cleared but opposing lane has traffic waiting -> Switch instantly
                    elif active_vehicles == 0 and opposing_vehicles > 0:
                        print("[SCHEDULER] Active axis cleared while vehicles wait on opposing axis. Switching instantly.")
                        should_switch = True
                        self._short_phase_active = False

                    # RULE: Active traffic vs Opposing traffic processing rule trees
                    elif active_vehicles > 0 and opposing_vehicles > 0:
                        if self._short_phase_active:
                            # Running the 10-second priority extension phase cap
                            if elapsed_time >= 10.0:
                                print("[SCHEDULER] 10s shorter rotation window complete. Rotating phase axis back.")
                                should_switch = True
                                self._short_phase_active = False
                            else:
                                lcd_time_str = f"{max(0, int(10.0 - elapsed_time))}s"
                        else:
                            # ⏱️ The 10-Second Mid Phase Verification Audit
                            if elapsed_time >= 10.0 and elapsed_time < 30.0:
                                if opposing_vehicles > active_vehicles:
                                    print(f"[SCHEDULER] Opposing side has a larger load ({opposing_vehicles} vs {active_vehicles}) at 10s check. Switching early.")
                                    should_switch = True
                                    self._short_phase_active = False
                                else:
                                    # Opposing load is less or equal -> Proceed up to 30 second ceiling limit safely
                                    lcd_time_str = f"{max(0, int(30.0 - elapsed_time))}s"

                            elif elapsed_time >= 30.0:
                                print(f"[SCHEDULER] 30s ceiling cap achieved. Opposing traffic detected ({opposing_vehicles}).")
                                should_switch = True

                                # If opposing load was smaller/equal after 30s max run, grant a 10s rotation window
                                if opposing_vehicles <= active_vehicles:
                                    print("[SCHEDULER] Opposing side has less or equal vehicles. Granting 10s short phase window.")
                                    self._short_phase_active = True
                                else:
                                    self._short_phase_active = False
                            else:
                                # Prior to passing the initial 10s checklist point
                                lcd_time_str = f"{max(0, int(30.0 - elapsed_time))}s"

                    if should_switch:
                        print(f"[SCHEDULER] Initiating Phase Switch Sequence: {current_axis} ➔ {next_axis_candidate}")
                        self._traffic.apply_yellow(apply_yellow_lane)

                        self._in_yellow_transition = True
                        self._next_target_axis = next_axis_candidate
                        self._phase_duration = self._yellow_duration
                        self._phase_start_time = time.time()
                        elapsed_time = 0.0
                        display_phase = "YELLOW CAUTION"
                        status_msg = f"Yellow active transition to {next_axis_candidate}"
                    else:
                        display_phase = "N/S GREEN" if current_axis == "NS" else "E/W GREEN"
                        status_msg = f"Phase: {current_axis} | Active Vehicles: {active_vehicles} | Opposing: {opposing_vehicles}"
                        self._last_green_axis = current_axis

                if self._running and self._system_active:
                    self._ctx.lcd.show_lines(
                        f"{display_phase} {lcd_time_str}",
                        f"N:{lane_density['north']} S:{lane_density['south']} E:{lane_density['east']} W:{lane_density['west']}"
                    )

                active_lane_out = "none" if self._current_phase_axis == "none" else ("north" if self._current_phase_axis == "NS" else "east")
                sensor_str = " ".join(f"{k[0].upper()}:{v}cm" for k, v in distances.items())
                print(f"[DETECT] {status_msg} | N:{lane_density['north']} S:{lane_density['south']} E:{lane_density['east']} W:{lane_density['west']} | Sensors: {sensor_str}")

                with self._lock:
                    self._latest_annotated = result.annotated_rgb
                    self._latest_counts = counts
                    self._latest_detection_count = detection_count
                    self._latest_vehicle_count = north_south_total + east_west_total
                    self._latest_active_lane = active_lane_out
                    self._latest_green_time = 999 if lcd_time_str == "INF" else (int(lcd_time_str.replace("s", "")) if "s" in lcd_time_str else 0)
                    self._latest_distances = distances
                    self._latest_status = status_msg

                if insert_traffic_event is not None and active_lane_out != "none" and not self._in_yellow_transition:
                    current_counts = {
                        "car": counts.get("car", 0) or counts.get("Car", 0),
                        "pedestrian": counts.get("pedestrian", 0) or counts.get("Pedestrian", 0),
                        "bus": counts.get("bus", 0) or counts.get("Bus", 0),
                        "motorcycle": counts.get("motorcycle", 0) or counts.get("motorbike", 0) or counts.get("Motorcycle", 0),
                    }
                    try:
                        insert_traffic_event(
                            lane_name=active_lane_out,
                            density_score=float(lane_density[active_lane_out]),
                            green_duration=10,
                            signal_state_name="green",
                            class_counts=current_counts
                        )
                    except Exception as db_err:
                        print(f"[DB ENGINE ERROR] Database log execution failed: {db_err}")

                time.sleep(1.0 / cfg.default_target_fps)

            except Exception as e:
                print(f"[ERROR] Loop crash event: {e}")
                if self._running and self._system_active:
                    self._ctx.lcd.show_lines("ERROR", str(e)[:20])
                with self._lock:
                    self._latest_status = f"Error: {e}"
                time.sleep(1)

        self._traffic.all_red()
        self._running = False
        print("[LOOP] Stopped")