import cv2
import numpy as np
import time
from collections import deque
import threading
import queue
import socket

# ==============================
# CONFIGURABLE CONSTANTS
# ==============================
# Every captured frame is resized to WIDTH x HEIGHT (cv2.resize takes
# (width, height)), and the steering error is measured against WIDTH // 2.
WIDTH = 480
HEIGHT = 640
# Serial-link settings. NOTE(review): none of these three is referenced in
# this file (commands go over TCP); presumably leftovers — confirm before removing.
PORT = 'COM3'
BAUD = 115200
TIMEOUT = 0.1
# Camera stream opened with cv2.VideoCapture.
VIDEO_STREAM_URL = "http://192.168.1.11:8080/video"

# TCP endpoint that serial_writer connects to for every command.
esp_ip = '192.168.1.10'  # ESP32's IP
esp_port = 80

# Minimum number of dark (line) pixels, as counted by cv2.countNonZero on the
# inverted binary image, for each sensing region to report "line present".
LEFT_THRESHOLD = 7000
RIGHT_THRESHOLD = 7000
FORWARD_THRESHOLD = 7000   # top-centre region
BOTTOM_THRESHOLD = 7000    # bottom-centre region (also used for the centroid)
TOP_RIGHT_THRESHOLD = 6000  # top-right region, reported as the "stop" flag

# Delays in seconds: each is passed to time.sleep in serial_writer while
# is_busy is True. FORWARD_INTERVAL / ADJUSTMENT_INTERVAL double as the
# minimum gap between repeated forward / adjust commands.
FORWARD_INTERVAL = 1.0
ADJUSTMENT_INTERVAL = 2.0
TURN_DELAY = 30  # old/alternative value: 20
UTURN_DELAY = 40  # old/alternative value: 50

# Proportional steering: KP * deviation (pixels from the frame centre) is
# clamped to [MIN_INPUT, MAX_INPUT] and mapped by scale_pwm onto
# [MIN_PWM, MAX_PWM]. A |deviation| below CENTER_TOLERANCE counts as centred.
KP = 0.8
CENTER_TOLERANCE = 15
MIN_INPUT = -50
MAX_INPUT = 50
MIN_PWM = 1
MAX_PWM = 50

# Number of consecutive identical direction readings required before acting.
DIRECTION_STABILITY_FRAMES = 3

# Post-turn recovery sweep budget. ATTEMPTFORLR is decremented per sweep step
# and reset to ATTEMPTFORLRCOPY when the line is found again.
ATTEMPTFORLR = 50
ATTEMPTFORLRCOPY = ATTEMPTFORLR
# NOTE(review): ATTEMPTFORU is not referenced in this file.
ATTEMPTFORU = 100

# (cmd, delay) jobs consumed by the serial_writer thread.
cmd_queue = queue.Queue()
# Set by the writer thread while a command is being sent / waited on.
is_busy = False
# True after a turn until the line is reacquired by the recovery sweep.
recovering = False
# Last turn taken ("left", "right", "u" or "stop"); chooses the sweep direction.
last_turn = "left"
# Progress of the "turn overshoot" detection done while a turn is executing.
recovery_sequence = 0

# Set once a stop decision has been sent.
robot_stopped = False
# Turns recorded during exploration, and their reduced form for replay.
traversed_path = []
short_path = []
# Index of the next short_path entry to look for while backtracking.
short_path_index = 0
# A "forward" is appended to traversed_path at most once per
# FORWARD_APPEND_COOLDOWN seconds.
last_path_append_time = 0
FORWARD_APPEND_COOLDOWN = 5.0

# True while replaying short_path (started with the 'e' key).
is_backtracking = False

def serial_writer(connect_timeout=5.0):
    """Forward queued commands to the ESP32; runs forever on a daemon thread.

    Despite the name, no serial port is used. For every ``(cmd, delay)`` job
    taken from ``cmd_queue``, a new TCP connection is opened to
    ``(esp_ip, esp_port)`` and ``cmd`` is sent as encoded bytes. After a
    successful send the thread sleeps ``delay`` seconds. ``is_busy`` is True
    from the moment a job is taken until the queue has been drained. The main
    loop checks that flag before issuing new decisions.

    Args:
        connect_timeout: Seconds allowed for connecting to the ESP32 and for
            sending the command before the attempt is abandoned.
    """
    global is_busy
    while True:
        cmd, delay = cmd_queue.get()
        is_busy = True
        print(f"[Thread] Sending command: {cmd}")
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                # Bound connect/send so an unreachable ESP32 cannot block forever.
                s.settimeout(connect_timeout)
                s.connect((esp_ip, esp_port))
                s.sendall(cmd.encode())
        except OSError as err:
            # Without this handler a network error would end the thread with
            # is_busy stuck at True, so the main loop would never send again.
            print(f"[Thread] Failed to send command {cmd!r}: {err}")
        else:
            # Wait `delay` seconds before accepting the next command.
            time.sleep(delay)
        # Stay busy while a burst of commands (e.g. repeated "f") is still
        # queued, so the main loop cannot slip its own commands in between.
        if cmd_queue.empty():
            is_busy = False

# Start the background command sender before the camera stream is opened.
_writer_thread = threading.Thread(target=serial_writer, daemon=True)
_writer_thread.start()

cap = cv2.VideoCapture(VIDEO_STREAM_URL)  # frame source for the main loop

# Sliding window of recent direction readings; the main loop only treats a
# reading as stable when every entry in this window is identical.
recent_moves = deque(maxlen=DIRECTION_STABILITY_FRAMES)
# Type and timestamp of the last forward/adjust command, used to rate-limit them.
last_move_type, last_move_time = None, 0

def send_command(cmd, delay=0):
    """Hand *cmd* to the background sender thread.

    ``cmd_queue`` is unbounded, so this never blocks. ``delay`` is how long,
    in seconds, the sender waits after transmitting ``cmd`` before it picks up
    the next queued command.
    """
    job = (cmd, delay)
    cmd_queue.put(job)

def scale_pwm(value):
    """Convert a signed steering error into an unsigned PWM magnitude.

    The error is clamped to [MIN_INPUT, MAX_INPUT]. Its magnitude is then
    mapped linearly so that 0 gives MIN_PWM and MAX_INPUT gives MAX_PWM.
    The result is truncated to an int, and the sign of *value* is discarded.
    """
    clamped = max(MIN_INPUT, min(value, MAX_INPUT))
    fraction = abs(clamped) / MAX_INPUT
    pwm_span = MAX_PWM - MIN_PWM
    return int(fraction * pwm_span + MIN_PWM)

def _roi(image, box):
    """Return the view of *image* covered by ``box = (x, y, w, h)``."""
    x, y, w, h = box
    return image[y:y + h, x:x + w]


def _outline(frame, box, color):
    """Draw a 2 px rectangle around ``box = (x, y, w, h)`` on *frame* in place."""
    x, y, w, h = box
    cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)


def get_direction_and_centroid(frame):
    """Detect which sensing regions see the dark line, and where it sits.

    Processing steps:

    1. The frame is converted to grayscale and blurred.
    2. It is inverse-thresholded at 60, so pixels at or below 60 become 255.
    3. Five regions, each sized as a fraction of the frame, are sampled:
       left / right (sides, 55-80 % of the height), bottom (lower centre),
       top (upper centre) and top-right (upper-right corner).
    4. The dark-pixel count of each region is compared with its threshold
       constant.

    A debug overlay is drawn directly onto *frame*: region outlines, counts,
    the centre line and the centroid dot.

    Args:
        frame: BGR image. The caller is expected to have resized it to
            WIDTH x HEIGHT, because the caller measures ``cx`` against
            ``WIDTH // 2``.

    Returns:
        ``(direction_flags, cx, frame)``, where:

        - ``direction_flags`` maps "left", "right", "forward", "bottom" and
          "stop" to bools.
        - ``cx`` is the x pixel of the line's centroid inside the bottom
          region, or the frame centre when that region has no dark pixels.
        - ``frame`` is the same, now annotated, image.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # Inverted threshold: dark line pixels become non-zero so countNonZero counts them.
    _, binary = cv2.threshold(blur, 60, 255, cv2.THRESH_BINARY_INV)

    height, width = binary.shape

    # Regions as (x, y, w, h), all relative to the actual frame size.
    left_box = (int(width * 0.05), int(height * 0.55), int(width * 0.15), int(height * 0.25))
    right_box = (int(width * 0.8), int(height * 0.55), int(width * 0.15), int(height * 0.25))
    bottom_box = (int(width * 0.3), int(height * 0.85), int(width * 0.4), int(height * 0.15))
    top_box = (int(width * 0.3), int(height * 0.05), int(width * 0.4), int(height * 0.15))
    top_right_box = (int(width * 0.8), int(height * 0.05), int(width * 0.15), int(height * 0.15))

    bottom_roi = _roi(binary, bottom_box)
    left_black = cv2.countNonZero(_roi(binary, left_box))
    right_black = cv2.countNonZero(_roi(binary, right_box))
    bottom_black = cv2.countNonZero(bottom_roi)
    top_black = cv2.countNonZero(_roi(binary, top_box))
    top_right_black = cv2.countNonZero(_roi(binary, top_right_box))

    # Use this frame's own width (not the WIDTH constant) so the fallback and
    # the drawn centre line agree with the boxes computed above.
    center_x = width // 2
    moments = cv2.moments(bottom_roi)
    if moments["m00"] != 0:
        # Centroid is relative to the ROI; shift it back into frame coordinates.
        cx = int(moments["m10"] / moments["m00"]) + bottom_box[0]
    else:
        cx = center_x

    _outline(frame, left_box, (0, 255, 0))
    _outline(frame, right_box, (0, 255, 0))
    _outline(frame, bottom_box, (0, 255, 255))
    _outline(frame, top_box, (0, 0, 255))
    _outline(frame, top_right_box, (255, 0, 255))

    bottom_top = bottom_box[1]
    bottom_end = bottom_box[1] + bottom_box[3]
    cv2.line(frame, (center_x, bottom_top), (center_x, bottom_end), (0, 0, 255), 2)
    cv2.circle(frame, (cx, bottom_top + bottom_box[3] // 2), 5, (0, 0, 255), -1)

    for label, count, box, color in (
        ("L", left_black, left_box, (0, 255, 0)),
        ("B", bottom_black, bottom_box, (0, 255, 255)),
        ("R", right_black, right_box, (0, 255, 0)),
        ("T", top_black, top_box, (0, 0, 255)),
    ):
        cv2.putText(frame, f"{label}:{count}", (box[0], box[1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    direction_flags = {
        # "left" also fires when the left region holds some line (> 2000 px)
        # but less than the right region.
        "left": (left_black > LEFT_THRESHOLD) or (right_black > left_black > 2000),
        "right": right_black > RIGHT_THRESHOLD,
        "forward": top_black > FORWARD_THRESHOLD,
        "bottom": bottom_black > BOTTOM_THRESHOLD,
        "stop": top_right_black > TOP_RIGHT_THRESHOLD,
    }
    return direction_flags, cx, frame

def get_shortest_path(traversed_path):
    """Collapse dead-end detours out of a recorded exploration path.

    Moves are pushed one at a time onto a new list. Whenever the last three
    entries match a known ``(move, 'u-turn', move)`` pattern, they are replaced
    by the single equivalent move. A replacement can expose another matching
    tail, so matching repeats until the tail no longer matches.

    Args:
        traversed_path: Sequence of moves ('left', 'right', 'forward',
            'u-turn'). It is not modified.

    Returns:
        A new list holding the reduced path.
    """
    reductions = {
        ('left', 'u-turn', 'right'): 'u-turn',
        ('left', 'u-turn', 'forward'): 'right',
        ('right', 'u-turn', 'left'): 'u-turn',
        ('forward', 'u-turn', 'left'): 'right',
        ('forward', 'u-turn', 'forward'): 'u-turn',
        ('left', 'u-turn', 'left'): 'forward',
    }
    reduced = []
    for step in traversed_path:
        reduced.append(step)
        while len(reduced) >= 3 and tuple(reduced[-3:]) in reductions:
            reduced[-3:] = [reductions[tuple(reduced[-3:])]]
    return reduced

# ==============================
# MAIN LOOP
# ==============================

while True:
    # ---- Frame acquisition ------------------------------------------------
    ret, frame = cap.read()
    if not ret or frame is None:
        # No retry is attempted: release the stream and the window, then exit.
        print("Failed to grab frame. Exiting.")
        cap.release()
        cv2.destroyAllWindows()
        break

    # The steering maths below measures cx against WIDTH // 2, so every frame
    # is brought to WIDTH x HEIGHT first.
    frame = cv2.resize(frame, (WIDTH, HEIGHT))
    current_time = time.time()

    # ---- Keyboard: 'q' quits, 'e' starts replaying short_path --------------
    k = cv2.waitKey(1)
    if k & 0xFF == ord('q'):
        cap.release()
        cv2.destroyAllWindows()
        break
    if k & 0xFF == ord('e'):
        print("Shortest Path: ", short_path, len(short_path))
        short_path_index = 0
        is_backtracking = True
        robot_stopped = False

    # ---- Perception -------------------------------------------------------
    directions, cx, frame = get_direction_and_centroid(frame)
    recent_moves.append(tuple(directions.items()))

    # Act only once the last DIRECTION_STABILITY_FRAMES readings all agree.
    stable = all(m == recent_moves[0] for m in recent_moves)

    move = None
    move_type = None

    left = directions['left']
    right = directions['right']
    forward = directions['forward']
    bottom = directions['bottom']
    stop = directions['stop']

    # ---- Exploration: follow the line and record every turn ---------------
    # Priority: stop marker > left > straight > right > plain line > u-turn.
    if stable and not is_busy and not recovering and not is_backtracking and not robot_stopped:
        print("Traversed Path: ", traversed_path)
        move = None
        move_type = None
        if stop and left and bottom:
            # Goal marker seen: stop and reduce the recorded path for replay.
            move = "stop"
            move_type = "stop"
            print("Decision: Stop")
            send_command("stop", 10)
            last_turn = "stop"
            print("Traversed Path: ", traversed_path)
            short_path = get_shortest_path(traversed_path)
            print("Shortest Path: ", short_path, len(short_path))
            robot_stopped = True

        elif left:
            move = "left"
            move_type = "turn"
            traversed_path.append(move)
            print("Decision: Left turn")
            send_command("l", TURN_DELAY)
            last_turn = "left"
            recovery_sequence = 0
            recovering = True

        elif bottom and forward:
            deviation = cx - WIDTH // 2
            if abs(deviation) < CENTER_TOLERANCE:
                move = "forward"
                move_type = "forward"
                if right:
                    # Going straight past a right branch is a junction decision;
                    # record it at most once per FORWARD_APPEND_COOLDOWN seconds.
                    if current_time - last_path_append_time > FORWARD_APPEND_COOLDOWN:
                        traversed_path.append("forward")
                        last_path_append_time = current_time
                if move_type != last_move_type or (current_time - last_move_time) >= FORWARD_INTERVAL:
                    print("Decision: forward")
                    send_command("f", FORWARD_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type
            else:
                # Off centre: proportional correction toward the line.
                direction = "right" if deviation > 0 else "left"
                pwm_value = scale_pwm(KP * deviation)
                move = f"adjust {direction} ({pwm_value})"
                move_type = "adjust"
                if move_type != last_move_type or (current_time - last_move_time) >= ADJUSTMENT_INTERVAL:
                    print("Decision:", move)
                    command = f"ar{pwm_value}" if direction == "right" else f"al{pwm_value}"
                    send_command(command, ADJUSTMENT_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type

        elif right:
            move = "right"
            move_type = "turn"
            traversed_path.append(move)
            print("Decision: Right turn")
            send_command("r", TURN_DELAY)
            last_turn = "right"
            recovery_sequence = 0
            recovering = True

        elif bottom:
            deviation = cx - WIDTH // 2
            if abs(deviation) < CENTER_TOLERANCE:
                move = "forward"
                move_type = "forward"
                if move_type != last_move_type or (current_time - last_move_time) >= FORWARD_INTERVAL:
                    print("Decision: forward")
                    send_command("f", FORWARD_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type
            else:
                direction = "right" if deviation > 0 else "left"
                pwm_value = scale_pwm(KP * deviation)
                move = f"adjust {direction} ({pwm_value})"
                move_type = "adjust"
                if move_type != last_move_type or (current_time - last_move_time) >= ADJUSTMENT_INTERVAL:
                    print("Decision:", move)
                    command = f"ar{pwm_value}" if direction == "right" else f"al{pwm_value}"
                    send_command(command, ADJUSTMENT_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type

        elif not forward and not left and not right and not bottom:
            # Dead end: no line in any region.
            move = "u-turn"
            move_type = "uturn"
            traversed_path.append(move)
            print("Decision: U-turn")
            send_command("u", UTURN_DELAY)
            last_turn = "u"
            recovery_sequence = 0
            recovering = True

        recovery_loop_last_time = time.time()

    # ---- Recovery: after a turn, sweep until the line is reacquired --------
    # ATTEMPTFORLR counts down from ATTEMPTFORLRCOPY to 0 during the first
    # sweep, then on to -2 * ATTEMPTFORLRCOPY during the opposite sweep.
    elif recovering and not is_busy:
        recovery_every_sec = 1
        recovery_current_time = time.time()
        if ATTEMPTFORLR < ATTEMPTFORLRCOPY * -2:
            # Both sweeps exhausted. Neither sweep below would run again, so
            # without this branch the robot would stay in recovery forever.
            recovering = False
            ATTEMPTFORLR = ATTEMPTFORLRCOPY
            if not (directions['forward'] and directions['bottom']):
                print("Recovery failed. Stopping robot")
                send_command("stop", 10)
                robot_stopped = True
        elif last_turn == "left" or last_turn == "u":
            # Sweep left first, then back to the right.
            # NOTE(review): each step sends the adjustment before checking
            # alignment, so one extra step is issued even when already aligned.
            if (recovery_current_time - recovery_loop_last_time > recovery_every_sec) and not is_busy and ATTEMPTFORLR >= 0:
                send_command("al30", ADJUSTMENT_INTERVAL)
                if directions['forward'] and directions['bottom']:
                    recovering = False
                    # Full budget for the next recovery (previously the reset was
                    # immediately followed by a decrement).
                    ATTEMPTFORLR = ATTEMPTFORLRCOPY
                else:
                    ATTEMPTFORLR -= 1
                recovery_loop_last_time = recovery_current_time
            if (recovery_current_time - recovery_loop_last_time > recovery_every_sec) and not is_busy and ATTEMPTFORLR >= ATTEMPTFORLRCOPY * -2:
                send_command("ar30", ADJUSTMENT_INTERVAL)
                if directions['forward'] and directions['bottom']:
                    recovering = False
                    ATTEMPTFORLR = ATTEMPTFORLRCOPY
                else:
                    ATTEMPTFORLR -= 1
                recovery_loop_last_time = recovery_current_time
        elif last_turn == "right":
            # Mirror image: sweep right first, then back to the left.
            if (recovery_current_time - recovery_loop_last_time > recovery_every_sec) and not is_busy and ATTEMPTFORLR >= 0:
                send_command("ar30", ADJUSTMENT_INTERVAL)
                if directions['forward'] and directions['bottom']:
                    recovering = False
                    ATTEMPTFORLR = ATTEMPTFORLRCOPY
                else:
                    ATTEMPTFORLR -= 1
                recovery_loop_last_time = recovery_current_time
            if (recovery_current_time - recovery_loop_last_time > recovery_every_sec) and not is_busy and ATTEMPTFORLR >= ATTEMPTFORLRCOPY * -2:
                send_command("al30", ADJUSTMENT_INTERVAL)
                if directions['forward'] and directions['bottom']:
                    recovering = False
                    ATTEMPTFORLR = ATTEMPTFORLRCOPY
                else:
                    ATTEMPTFORLR -= 1
                recovery_loop_last_time = recovery_current_time

    # ---- While a turn is still executing: detect a likely overshoot --------
    # Once every region goes blank (sequence 0 -> 1), a line appearing on the
    # opposite side flips last_turn, and with it the first sweep direction.
    elif recovering and is_busy:
        if not directions['bottom'] and not directions['forward'] and not directions['left'] and not directions['right'] and recovery_sequence == 0:
            print("Starting Recovery Switch Detection")
            recovery_sequence += 1
        elif recovery_sequence == 1 and (last_turn == "left" or last_turn == "u") and directions['right']:
            last_turn = "right"
            print("Recovery Turn Switched")
            recovery_sequence += 1
        elif recovery_sequence == 1 and last_turn == "right" and directions['left']:
            last_turn = "left"
            print("Recovery Turn Switched")
            recovery_sequence += 1

    # ---- Backtracking: replay short_path, one junction decision at a time --
    elif is_backtracking and not robot_stopped and not recovering and not is_busy:
        if short_path_index < len(short_path):
            looking_for = short_path[short_path_index]
        else:
            looking_for = "stop"
        move = None
        move_type = None

        if stop and left and bottom:
            print("Decision: Stop")
            send_command("stop", 10)
            robot_stopped = True

        elif looking_for == "forward" and bottom and forward and (left or right):
            deviation = cx - WIDTH // 2
            if abs(deviation) < CENTER_TOLERANCE:
                move = "forward"
                move_type = "forward"
                if move_type != last_move_type or (current_time - last_move_time) >= FORWARD_INTERVAL:
                    print("Looking For: ", looking_for, " Decision: forward", short_path_index)
                    # Queue a burst of forward steps to drive through the junction.
                    for _ in range(10):
                        send_command("f", FORWARD_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type
                    short_path_index += 1
            else:
                direction = "right" if deviation > 0 else "left"
                pwm_value = scale_pwm(KP * deviation)
                move = f"adjust {direction} ({pwm_value})"
                move_type = "adjust"
                if move_type != last_move_type or (current_time - last_move_time) >= ADJUSTMENT_INTERVAL:
                    command = f"ar{pwm_value}" if direction == "right" else f"al{pwm_value}"
                    send_command(command, ADJUSTMENT_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type

        elif looking_for == "left" and left and bottom:
            print("Looking For:", looking_for, " Decision: Left turn", short_path_index)
            send_command("l", TURN_DELAY)
            last_turn = "left"
            short_path_index += 1
            recovery_sequence = 0
            recovering = True

        elif looking_for == "right" and right and bottom:
            print("Looking For:", looking_for, " Decision: Right turn", short_path_index)
            send_command("r", TURN_DELAY)
            last_turn = "right"
            short_path_index += 1
            recovery_sequence = 0
            recovering = True

        elif bottom:
            # Not at the wanted junction yet: keep following the line.
            deviation = cx - WIDTH // 2
            if abs(deviation) < CENTER_TOLERANCE:
                move = "forward"
                move_type = "forward"
                if move_type != last_move_type or (current_time - last_move_time) >= FORWARD_INTERVAL:
                    send_command("f", FORWARD_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type
            else:
                direction = "right" if deviation > 0 else "left"
                pwm_value = scale_pwm(KP * deviation)
                move = f"adjust {direction} ({pwm_value})"
                move_type = "adjust"
                if move_type != last_move_type or (current_time - last_move_time) >= ADJUSTMENT_INTERVAL:
                    command = f"ar{pwm_value}" if direction == "right" else f"al{pwm_value}"
                    send_command(command, ADJUSTMENT_INTERVAL)
                    last_move_time = current_time
                    last_move_type = move_type

        else:
            # No line under the robot and the expected move is not available.
            print("Robot Stopped As Desired Path Not Found")
            send_command("stop", 10)
            robot_stopped = True
        recovery_loop_last_time = time.time()

    cv2.imshow("Maze View", frame)