import cv2
import numpy as np
from ultralytics import YOLO
import time
from collections import deque
import asyncio
from bleak import BleakClient, BleakScanner
from bleak.backends.characteristic import BleakGATTCharacteristic
from bleak.backends.device import BLEDevice
from bleak.backends.scanner import AdvertisementData
import sys
from typing import Iterator
from itertools import count, takewhile

# BLE Configuration
# Nordic UART Service (NUS) UUIDs. connect_ble looks for a device with
# BLE_DEVICE_MAC that advertises UART_SERVICE_UUID; send_ble_command writes
# commands to the RX characteristic.
UART_SERVICE_UUID = "6E400001-B5A3-F393-E0A9-E50E24DCCA9E"
UART_RX_CHAR_UUID = "6E400002-B5A3-F393-E0A9-E50E24DCCA9E"
# NOTE(review): the TX characteristic is not referenced in the visible code.
UART_TX_CHAR_UUID = "6E400003-B5A3-F393-E0A9-E50E24DCCA9E"
BLE_DEVICE_MAC = "2C:CF:67:B0:D5:97"  # Replace with your Pi's BT MAC

# Game parameters
# Number of bullseye detections averaged before the bullseye position is locked.
SEGMENTATION_FRAMES = 5
# Seconds without a dart detection before dart tracking is considered ended.
DART_DETECTION_TIMEOUT = 3.0
# Minimum number of seconds between live score prints.
PRINT_INTERVAL = 0.5
# Accepted range (exclusive, in pixels) for the tip-to-tail keypoint distance.
MAX_DART_LENGTH = 300
MIN_DART_LENGTH = 20
# Darts per turn before play passes to the other player.
DARTS_PER_PLAYER = 3
# Fixed board diameter in pixels, used to derive ring radii (calculate_ring_sizes).
BASE_BOARD_DIAMETER_PX = 600

# Perspective correction: measured pixel distances from the bullseye to the
# triple ring edge in each direction. get_perspective_transform maps a circle
# whose radius is the mean of these values onto these four measured points.
TRIPLE_RING_OFFSETS = {'top': 140, 'right': 180, 'bottom': 200, 'left': 180}

def calculate_ring_sizes(board_diameter_px):
    """Return the radius in pixels of every scoring ring.

    Each radius is a fixed fraction of the board radius
    (``board_diameter_px / 2``), truncated to an int.

    Args:
        board_diameter_px: Board diameter in pixels.

    Returns:
        dict mapping 'bullseye', 'outer_bull', 'triple_inner',
        'triple_outer', 'double_inner' and 'double_outer' (in that order)
        to an int radius.
    """
    half_diameter = board_diameter_px / 2
    ring_fractions = (
        ('bullseye', 0.028),
        ('outer_bull', 0.071),
        ('triple_inner', 0.444),
        ('triple_outer', 0.488),
        ('double_inner', 0.754),
        ('double_outer', 0.798),
    )
    sizes = {}
    for ring_name, fraction in ring_fractions:
        sizes[ring_name] = int(fraction * half_diameter)
    return sizes

def get_perspective_transform(center, triple_ring_offsets):
    """Build a 3x3 perspective matrix that warps an ideal circle onto measurements.

    The source points are the top/right/bottom/left points of a circle around
    ``center`` whose radius is the mean of the offsets. The destination points
    are the same four directions at the individually measured offsets.

    Args:
        center: (cx, cy) of the bullseye in image coordinates.
        triple_ring_offsets: dict with 'top', 'right', 'bottom', 'left'
            pixel distances from the centre.

    Returns:
        The matrix returned by ``cv2.getPerspectiveTransform``.
    """
    cx, cy = center
    mean_radius = np.mean(list(triple_ring_offsets.values()))
    top = triple_ring_offsets['top']
    right = triple_ring_offsets['right']
    bottom = triple_ring_offsets['bottom']
    left = triple_ring_offsets['left']
    # Ideal (undistorted) diamond: all four points at the mean radius.
    ideal_diamond = np.array([
        [cx, cy - mean_radius],
        [cx + mean_radius, cy],
        [cx, cy + mean_radius],
        [cx - mean_radius, cy],
    ], dtype=np.float32)
    # Measured diamond: each direction at its own observed distance.
    measured_diamond = np.array([
        [cx, cy - top],
        [cx + right, cy],
        [cx, cy + bottom],
        [cx - left, cy],
    ], dtype=np.float32)
    return cv2.getPerspectiveTransform(ideal_diamond, measured_diamond)

def apply_perspective_to_circle(center, radius, M, num_points=360):
    """Sample a circle and warp it with a perspective matrix.

    Args:
        center: (cx, cy) circle centre.
        radius: circle radius in pixels.
        M: 3x3 perspective matrix (e.g. from get_perspective_transform).
        num_points: number of samples over [0, 2*pi] (both ends included).

    Returns:
        (num_points, 2) int array of warped points, suitable for cv2.polylines.
        Coordinates are truncated (not rounded) to ints.
    """
    cx, cy = center
    theta = np.linspace(0, 2 * np.pi, num_points)
    xs = cx + radius * np.cos(theta)
    ys = cy + radius * np.sin(theta)
    # cv2.perspectiveTransform expects an (N, 1, 2) float32 array.
    ring = np.column_stack((xs, ys)).astype(np.float32).reshape(-1, 1, 2)
    warped = cv2.perspectiveTransform(ring, M)
    return warped.astype(int).reshape(-1, 2)

# Load models
# seg_model is used by find_bullseye_and_board to locate the bullseye;
# pose_model is used by detect_dart (keypoint 0 = tip, keypoint 1 = tail).
# NOTE(review): absolute, machine-specific weight paths; consider making them configurable.
seg_model = YOLO(r"C:\Users\trist\Documents\School\2024-2025\Semester_2\Project one\2024-2025-projectone-ctai-DebrabandereTristan\runs\yolov8l_seg_darts2\weights\best.pt")
pose_model = YOLO(r"C:\Users\trist\Documents\School\2024-2025\Semester_2\Project one\2024-2025-projectone-ctai-DebrabandereTristan\runs\detect\yolov8m_dart_pose4\weights\best.pt")

# Initialize trackers
# Recent bullseye centre detections; averaged once SEGMENTATION_FRAMES are collected.
bullseye_history = deque(maxlen=SEGMENTATION_FRAMES)
# Recent dart tip positions; averaged to smooth the tracked tip.
dart_history = deque(maxlen=5)

# Global variables (mutable game/UI state, mostly updated by main_async)
bullseye_position = None  # locked bullseye (x, y), or None while calibrating
# NOTE(review): declared global in main_async but never reassigned in the visible code.
board_diameter_px = BASE_BOARD_DIAMETER_PX
last_dart_time = 0  # time.time() of the most recent dart detection
last_score_time = 0
# NOTE(review): score_display_duration is not read in the visible code.
score_display_duration = 3
last_score_info = None
processing_stage = "calibrating"  # "calibrating" or "tracking"
last_print_time = 0
current_dart_score = None  # latest (score, zone, distance_px, angle) or None
ble_client = None  # connected BleakClient, or None (also reset on disconnect)
dart_tracking_ended_shown = False
debug_mode = False  # toggled with 'd'; draws raw pose keypoints
current_player = 1
darts_thrown = 0
player_scores = {1: 501, 2: 501}
p_presses = 0  # number of 'p' presses in the current turn
ring_sizes = calculate_ring_sizes(BASE_BOARD_DIAMETER_PX)
# NOTE(review): pending_score is reset in several places but never read.
pending_score = 0
end_message = ""  # status text drawn near the bottom of the frame

def get_sector_score(angle, multiplier):
    """Map an image-space angle to a dartboard sector value times ``multiplier``.

    ``angle`` is in degrees as produced by ``arctan2(dy, dx)`` in image
    coordinates (y grows downward). The 90 degree shift puts sector 20 at the
    top of the image. Each of the 20 wedges is 18 degrees wide and centred on
    its direction.

    Args:
        angle: angle in degrees (any real value; reduced modulo 360).
        multiplier: 1 for single, 2 for double, 3 for triple.

    Returns:
        sector number * multiplier.
    """
    board_order = (20, 1, 18, 4, 13, 6, 10, 15, 2, 17, 3, 19, 7, 16, 8, 11, 14, 9, 12, 5)
    wedge = 360 / len(board_order)
    shifted = (angle + 90) % 360
    # Add half a wedge so each sector's boundaries sit +/- 9 degrees around it.
    wedge_index = int((shifted + wedge / 2) // wedge) % 20
    return board_order[wedge_index] * multiplier

def find_bullseye_and_board(frame):
    """Locate the bullseye centre with the segmentation model.

    Each predicted mask polygon is simplified with ``cv2.approxPolyDP``. The
    simplified polygon with the highest circularity
    (4*pi*area / perimeter**2) is chosen, and its centroid (from image
    moments) is returned.

    Degenerate polygons are skipped. These are polygons with fewer than three
    points or a zero perimeter. Previously they could raise ZeroDivisionError
    in the circularity computation.

    Args:
        frame: camera frame passed to ``seg_model``.

    Returns:
        ``(center, BASE_BOARD_DIAMETER_PX)``, where ``center`` is a length-2
        numpy array (x, y). Returns ``(None, None)`` when no usable mask is
        found.
    """
    results = seg_model(frame, verbose=False)[0]
    if results.masks is None:
        return None, None

    best_circle = None
    best_circularity = None
    for polygon in results.masks.xy:
        if len(polygon) < 3:
            continue
        # cv2 contour functions accept float32 (or int32) point arrays only.
        polygon = np.asarray(polygon, dtype=np.float32)
        contour = cv2.approxPolyDP(polygon, 0.02 * cv2.arcLength(polygon, True), True)
        perimeter = cv2.arcLength(contour, True)
        if perimeter <= 0:
            continue
        circularity = (4 * np.pi * cv2.contourArea(contour)) / (perimeter ** 2)
        # Strict '>' keeps the first of equally circular candidates, like max().
        if best_circularity is None or circularity > best_circularity:
            best_circularity = circularity
            best_circle = contour

    if best_circle is None:
        return None, None
    M = cv2.moments(best_circle)
    if M["m00"] == 0:
        return None, None
    center = np.array([M["m10"] / M["m00"], M["m01"] / M["m00"]])
    return center, BASE_BOARD_DIAMETER_PX

def detect_dart(frame):
    """Return the tip of the longest plausible dart in ``frame``.

    Runs ``pose_model`` and, for every detection with at least two keypoints,
    treats keypoint 0 as the tip and keypoint 1 as the tail. Detections whose
    tip-to-tail distance lies strictly between MIN_DART_LENGTH and
    MAX_DART_LENGTH are candidates. The longest one wins, and the first one
    wins on ties.

    Returns:
        The winning tip as a numpy (x, y) array, or None if there is no
        candidate.
    """
    results = pose_model(frame, verbose=False)[0]
    if results.keypoints is None or len(results.keypoints.xy) == 0:
        return None

    candidates = []
    for dart_keypoints in results.keypoints.xy:
        points = dart_keypoints.cpu().numpy()
        if len(points) < 2:
            continue
        tip, tail = points[0], points[1]
        length = np.linalg.norm(tip - tail)
        if MIN_DART_LENGTH < length < MAX_DART_LENGTH:
            candidates.append((length, tip))

    if not candidates:
        return None
    return max(candidates, key=lambda candidate: candidate[0])[1]

async def connect_ble():
    """Scan for the configured peripheral and connect to it, retrying up to 3 times.

    A device matches when its address equals BLE_DEVICE_MAC and it advertises
    UART_SERVICE_UUID. On success the module-level ``ble_client`` holds the
    connected client, and a disconnect callback resets it to None. Each failed
    attempt prints a message and waits one second.

    Returns:
        True once connected, False after all attempts fail.
    """
    global ble_client

    def match_nus_uuid(device: BLEDevice, adv: AdvertisementData):
        # Both operands are bools, so the expression itself is True/False.
        return device.address == BLE_DEVICE_MAC and UART_SERVICE_UUID.lower() in adv.service_uuids

    def handle_disconnect(_: BleakClient):
        global ble_client
        print("Device disconnected")
        ble_client = None

    for attempt in range(1, 4):
        try:
            device = await BleakScanner.find_device_by_filter(match_nus_uuid)
            if device is None:
                print(f"Attempt {attempt}: No matching device found")
                await asyncio.sleep(1)
                continue
            ble_client = BleakClient(device, disconnected_callback=handle_disconnect)
            await ble_client.connect()
            print("Connected to BLE device")
            return True
        except Exception as e:
            print(f"Attempt {attempt} failed: {e}")
            await asyncio.sleep(1)
    return False

async def send_ble_command(command):
    """Write ``command`` to the UART RX characteristic of the connected peripheral.

    Does nothing when there is no client. The encoded command is split into
    chunks no larger than the characteristic's
    ``max_write_without_response_size``, and each chunk is written without
    response. Errors are printed, not raised.
    """
    if ble_client is None:
        return
    try:
        uart_service = ble_client.services.get_service(UART_SERVICE_UUID)
        rx_characteristic = uart_service.get_characteristic(UART_RX_CHAR_UUID)
        chunk_limit = rx_characteristic.max_write_without_response_size
        for chunk in sliced(command.encode(), chunk_limit):
            await ble_client.write_gatt_char(rx_characteristic, chunk, response=False)
        print(f"Sent via BLE: {command}")
    except Exception as e:
        print(f"Error sending command via BLE: {e}")

def _format_prediction(score, zone, sector):
    """Build the "PRED ..." BLE message for one dart.

    Args:
        score: points awarded for the dart (0 for a miss).
        zone: zone label as returned by calculate_score.
        sector: bare (un-multiplied) sector number, e.g. 20.

    Returns:
        "PRED 50", "PRED 25", "PRED T<sector>", "PRED D<sector>",
        "PRED <sector>" or "PRED 0".
    """
    if zone == "Bullseye":
        return "PRED 50"
    if zone == "Outer bull":
        return "PRED 25"
    if zone == "Triple":
        return f"PRED T{sector}"
    if zone == "Double":
        return f"PRED D{sector}"
    return f"PRED {sector}" if score > 0 else "PRED 0"


async def send_score_via_ble(score_info):
    """Send the prediction message for a scored dart over BLE.

    Args:
        score_info: ``(score, zone, dist, angle)`` tuple from calculate_score,
            or None. Nothing is sent for None or for the all-None tuple
            calculate_score returns when it cannot score.
    """
    if score_info is None:
        return
    score, zone, _dist, angle = score_info
    if angle is None:
        return
    # Multiplier 1 yields the bare sector number, so it must not be divided
    # again for triples/doubles (a T20 is reported as "T20", not "T6").
    sector = get_sector_score(angle, 1)
    await send_ble_command(_format_prediction(score, zone, sector))

def sliced(data: bytes, n: int) -> Iterator[bytes]:
    """Lazily yield consecutive ``n``-byte slices of ``data``.

    The final slice may be shorter. Iteration stops at the first empty slice,
    so an empty ``data`` or ``n == 0`` yields nothing.
    """
    offset = 0
    while True:
        chunk = data[offset:offset + n]
        if not chunk:
            return
        yield chunk
        offset += n

async def switch_player():
    """Notify the peer with "SWITCH", then pass the turn to the other player.

    Player 1 becomes 2, and any other value becomes 1. The per-turn counters
    (darts_thrown, p_presses, pending_score) are reset to 0.
    """
    global current_player, darts_thrown, p_presses, pending_score
    await send_ble_command("SWITCH")
    current_player = {1: 2}.get(current_player, 1)
    darts_thrown = p_presses = pending_score = 0
    print(f"\nSwitched to Player {current_player}'s turn")

def print_dart_score(score_info):
    """Overwrite the current console line with the live dart score.

    ``score_info`` is the ``(score, zone, dist, angle)`` tuple from
    calculate_score. Nothing is printed when it is None.
    """
    if score_info is None:
        return
    score, zone, dist, angle = score_info
    status_line = (
        f"\rPlayer {current_player} - Dart {darts_thrown+1}: {score} ({zone}) "
        f"| Distance: {dist:.1f}px | Angle: {angle:.1f}° | Sector: {get_sector_score(angle, 1)}"
    )
    print(status_line, end="", flush=True)

def draw_perspective_rings(frame, center, ring_sizes, triple_ring_offsets):
    """Blend perspective-warped scoring rings onto ``frame`` in place.

    Double rings are drawn 40 px and triple rings 20 px smaller than
    ``ring_sizes`` (calculate_score uses the same reductions). The rings are
    drawn on a copy, which is blended back with 30 % opacity. Nothing is done
    when ``center`` is None or ``ring_sizes`` is empty.
    """
    if center is None or not ring_sizes:
        return
    warp = get_perspective_transform(center, triple_ring_offsets)
    overlay = frame.copy()
    ring_styles = (
        ('double_outer', (255, 0, 255), 40),
        ('double_inner', (255, 0, 255), 40),
        ('triple_outer', (0, 255, 255), 20),
        ('triple_inner', (0, 255, 255), 20),
        ('outer_bull', (0, 165, 255), 0),
        ('bullseye', (0, 0, 255), 0),
    )
    for ring_name, ring_color, shrink in ring_styles:
        outline = apply_perspective_to_circle(center, ring_sizes[ring_name] - shrink, warp)
        cv2.polylines(overlay, [outline], isClosed=True, color=ring_color, thickness=2)
    opacity = 0.3
    cv2.addWeighted(overlay, opacity, frame, 1 - opacity, 0, frame)

def calculate_score(dart_tip, bullseye_pos, correct_perspective=True):
    """Score a dart tip against the board rings.

    Ring radii come from the module-level ``ring_sizes``. The triple rings
    are pulled in by 20 px and the double rings by 40 px, the same amounts
    draw_perspective_rings subtracts when drawing them.

    The on-screen rings are drawn through
    ``get_perspective_transform(bullseye, TRIPLE_RING_OFFSETS)``. So when
    ``correct_perspective`` is true, the tip is first mapped through the
    inverse of that transform. This makes the scoring regions coincide with
    the drawn rings; without it, scoring used plain circles that disagreed
    with the display.

    Args:
        dart_tip: (x, y) image coordinates of the dart tip, or None.
        bullseye_pos: (x, y) image coordinates of the bullseye, or None.
        correct_perspective: set False to score in raw image coordinates.

    Returns:
        ``(score, zone, distance_px, angle_deg)``. ``zone`` is one of
        "Bullseye", "Outer bull", "Single", "Triple", "Double" or "Miss".
        Distance and angle are measured in the (corrected) board plane, and
        the angle is ``arctan2(dy, dx)`` in degrees reduced modulo 360.
        Returns ``(None, None, None, None)`` if either position is missing or
        ``ring_sizes`` is empty.
    """
    if dart_tip is None or bullseye_pos is None or not ring_sizes:
        return None, None, None, None

    tip_x, tip_y = float(dart_tip[0]), float(dart_tip[1])
    if correct_perspective:
        # The transform maps ideal board coordinates to image coordinates, so
        # its inverse takes the observed tip back to the ideal board plane.
        # The bullseye itself is a fixed point of the transform.
        warp = get_perspective_transform(bullseye_pos, TRIPLE_RING_OFFSETS)
        try:
            x_h, y_h, w_h = np.linalg.inv(warp) @ np.array([tip_x, tip_y, 1.0])
        except np.linalg.LinAlgError:
            w_h = 0.0
        # w_h == 0 would send the point to infinity; keep raw coordinates then.
        if abs(w_h) > 1e-12:
            tip_x, tip_y = x_h / w_h, y_h / w_h

    dx = tip_x - bullseye_pos[0]
    dy = tip_y - bullseye_pos[1]
    distance_px = np.sqrt(dx**2 + dy**2)
    angle = np.degrees(np.arctan2(dy, dx)) % 360

    triple_shrink = 20  # keep in sync with draw_perspective_rings
    double_shrink = 40  # keep in sync with draw_perspective_rings
    if distance_px <= ring_sizes['bullseye']:
        return 50, "Bullseye", distance_px, angle
    elif distance_px <= ring_sizes['outer_bull']:
        return 25, "Outer bull", distance_px, angle
    elif distance_px <= ring_sizes['triple_inner'] - triple_shrink:
        return get_sector_score(angle, 1), "Single", distance_px, angle
    elif distance_px <= ring_sizes['triple_outer'] - triple_shrink:
        return get_sector_score(angle, 3), "Triple", distance_px, angle
    elif distance_px <= ring_sizes['double_inner'] - double_shrink:
        return get_sector_score(angle, 1), "Single", distance_px, angle
    elif distance_px <= ring_sizes['double_outer'] - double_shrink:
        return get_sector_score(angle, 2), "Double", distance_px, angle
    else:
        return 0, "Miss", distance_px, angle

async def main_async():
    """Run the interactive dart-tracking loop.

    Connects over BLE first and returns immediately if that fails. Then reads
    camera frames until the camera closes, a frame cannot be read, or 'q' is
    pressed.

    Keys:

    * ``c``: recalibrate. Forget the locked bullseye and all buffered
      bullseye/dart samples.
    * ``p``: commit the currently tracked dart to the current player's score.
      Double-out rules apply: a bust or a leftover of 1 is not counted, but
      the dart is still used up.
    * ``d``: toggle drawing of raw pose keypoints.
    * ``r``: reset scores and turn.
    * ``q``: quit.

    While calibrating, bullseye detections are buffered until
    SEGMENTATION_FRAMES are collected, and their mean is locked. While
    tracking, detected tips are averaged over ``dart_history`` and scored.
    The camera, windows and BLE connection are released on exit.
    """
    global processing_stage, bullseye_position, board_diameter_px, ring_sizes
    global current_dart_score, last_dart_time, last_print_time
    global last_score_time, last_score_info, dart_tracking_ended_shown
    global debug_mode, current_player, darts_thrown, player_scores, p_presses, pending_score, end_message

    if not await connect_ble():
        return
    cap = cv2.VideoCapture(0)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            current_time = time.time()
            # Keep only the low byte: some backends set extra bits in the
            # returned key code, which would break the ord() comparisons.
            key = cv2.waitKey(1) & 0xFF
            if key == ord('c'):
                processing_stage = "calibrating"
                bullseye_position = None
                # Without clearing, the still-full deque would re-lock on the
                # very next detection using mostly stale positions.
                bullseye_history.clear()
                dart_history.clear()
                current_dart_score = None
                print("\nRecalibrating bullseye position...")
            elif key == ord('p'):
                if current_dart_score is not None:
                    score, zone, dist, angle = current_dart_score
                    new_score = player_scores[current_player] - score
                    # Double-out: only a double may bring the score to exactly 0.
                    valid_finish = (new_score == 0 and zone == "Double")
                    bust = (new_score < 0) or (new_score == 0 and zone != "Double")
                    # A remaining 1 can never be finished on a double.
                    one_left = (new_score == 1)
                    if valid_finish:
                        player_scores[current_player] = 0
                        print(f"\nPlayer {current_player} wins with a double out!")
                        end_message = f"Player {current_player} wins!"
                        await send_ble_command("WIN")
                        # Draw the message on a copy so it is visible during
                        # the pause (the raw frame carries no overlay).
                        win_frame = frame.copy()
                        cv2.putText(win_frame, end_message, (10, win_frame.shape[0] - 60),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3, cv2.LINE_AA)
                        cv2.imshow('Dart Tracker', win_frame)
                        cv2.waitKey(2000)
                        player_scores = {1: 501, 2: 501}
                        current_player = 1
                        darts_thrown = 0
                        p_presses = 0
                        pending_score = 0
                        end_message = ""
                        await send_ble_command("RESET")
                        continue
                    elif bust or one_left:
                        if one_left:
                            print(f"\nScore of 1 is not finishable. Turn does not count.")
                            end_message = f"Score 1 not allowed! Score: {player_scores[current_player]}"
                        else:
                            print(f"\nNeed to end with double. Score needed: {player_scores[current_player]}")
                            end_message = f"Need double! Score: {player_scores[current_player]}"
                        darts_thrown += 1
                        p_presses += 1
                        if darts_thrown >= DARTS_PER_PLAYER or p_presses >= 3:
                            await switch_player()
                        continue
                    else:
                        player_scores[current_player] = new_score
                        print(f"\nPlayer {current_player} scored {score} ({zone}), new score: {player_scores[current_player]}")
                        end_message = ""
                        darts_thrown += 1
                        p_presses += 1
                        await send_score_via_ble(current_dart_score)
                        if darts_thrown >= DARTS_PER_PLAYER or p_presses >= 3:
                            await switch_player()
                else:
                    print("\nNo dart currently being tracked")
            elif key == ord('d'):
                debug_mode = not debug_mode
                print(f"\nDebug mode {'enabled' if debug_mode else 'disabled'}")
            elif key == ord('q'):
                break
            elif key == ord('r'):
                player_scores = {1: 501, 2: 501}
                current_player = 1
                darts_thrown = 0
                p_presses = 0
                pending_score = 0
                end_message = ""
                print("\nGame reset! Scores and turns cleared.")
                await send_ble_command("RESET")

            if processing_stage == "calibrating":
                bullseye_pos, _ = find_bullseye_and_board(frame)
                if bullseye_pos is not None:
                    bullseye_history.append(bullseye_pos)
                    if len(bullseye_history) == SEGMENTATION_FRAMES:
                        bullseye_position = np.mean(bullseye_history, axis=0)
                        processing_stage = "tracking"
                        print(f"\nBullseye locked at {bullseye_position}")
                        print(f"Using fixed board diameter: {BASE_BOARD_DIAMETER_PX}px")
                        print("Ring sizes:", ring_sizes)
            elif processing_stage == "tracking":
                dart_tip = detect_dart(frame)
                if dart_tip is not None:
                    last_dart_time = current_time
                    dart_history.append(dart_tip)
                    dart_tracking_ended_shown = False
                    # Average recent tips to suppress per-frame keypoint jitter.
                    smoothed_tip = np.mean(dart_history, axis=0)
                    score_info = calculate_score(smoothed_tip, bullseye_position)
                    current_dart_score = score_info
                    if current_time - last_print_time > PRINT_INTERVAL:
                        print_dart_score(score_info)
                        last_print_time = current_time
                    if (last_score_info is None or
                            np.linalg.norm(smoothed_tip - dart_history[-1]) > 10):
                        last_score_info = score_info
                        last_score_time = current_time
                elif current_time - last_dart_time > DART_DETECTION_TIMEOUT:
                    dart_history.clear()
                    current_dart_score = None
                    if not dart_tracking_ended_shown:
                        print("\nDart tracking ended")
                        dart_tracking_ended_shown = True
            display_frame = frame.copy()
            if bullseye_position is not None:
                center = tuple(map(int, bullseye_position))
                draw_perspective_rings(display_frame, bullseye_position, ring_sizes, TRIPLE_RING_OFFSETS)
                cv2.circle(display_frame, center, 5, (0, 255, 0), -1)
                if len(dart_history) > 0 and current_dart_score is not None:
                    tip_pos = tuple(map(int, dart_history[-1]))
                    cv2.circle(display_frame, tip_pos, 10, (0, 0, 255), -1)
                    cv2.line(display_frame, center, tip_pos, (0, 255, 255), 2)
                    score_text = f"P{current_player}: {current_dart_score[0]} ({current_dart_score[1]})"
                    cv2.putText(display_frame, score_text,
                                (tip_pos[0] + 15, tip_pos[1] - 15),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                    info_text = f"{current_dart_score[2]:.1f}px, {current_dart_score[3]:.1f}°"
                    cv2.putText(display_frame, info_text,
                                (tip_pos[0] + 15, tip_pos[1] + 15),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
            if debug_mode:
                results = pose_model(frame, verbose=False)[0]
                if results.keypoints is not None:
                    # Iterate all detections; indexing [0] crashed when there were none.
                    for dart_keypoints in results.keypoints.xy.cpu().numpy():
                        for kp in dart_keypoints:
                            cv2.circle(display_frame, (int(kp[0]), int(kp[1])), 5, (255, 0, 0), -1)
            player_info = f"Player {current_player} - Dart {darts_thrown+1}/{DARTS_PER_PLAYER}"
            cv2.putText(display_frame, player_info,
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            scores_text = f"P1: {player_scores[1]} | P2: {player_scores[2]}"
            cv2.putText(display_frame, scores_text,
                        (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            press_info = f"Presses: {p_presses}/3"
            cv2.putText(display_frame, press_info,
                        (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 165, 255), 2)
            # Draw button help text at the bottom
            button_text = "C: Calibrate  P: Score  D: Debug  Q: Quit  R: Reset"
            cv2.putText(display_frame, button_text, (10, display_frame.shape[0] - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
            # Show end message if needed
            if end_message:
                cv2.putText(display_frame, end_message, (10, display_frame.shape[0] - 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3, cv2.LINE_AA)
            cv2.imshow('Dart Tracker', display_frame)
    finally:
        cap.release()
        cv2.destroyAllWindows()
        if ble_client and ble_client.is_connected:
            await ble_client.disconnect()
        print("\nFinal Scores:")
        print(f"Player 1: {player_scores[1]}")
        print(f"Player 2: {player_scores[2]}")
        print("Program terminated")

def main():
    """Entry point: run main_async on a fresh asyncio event loop until it finishes."""
    return asyncio.run(main_async())


if __name__ == "__main__":
    main()
