import cv2
import numpy as np
import socket
import time
import mediapipe as mp
import face_recognition
from math import radians, cos, sin, tan

# Network address of the Wemos board driving the servos (see sc() below).
WEMOS_IP = "192.168.233.158"  # Update this with the actual IP
PORT = 80  # TCP port the board listens on for servo commands

def sc(channel, angle):
    """Send a servo command to the Wemos controller over TCP.

    The wire format is "C<channel>:<angle>" followed by a newline.

    Args:
        channel: PWM channel on the controller, 0-15.
        angle: Target servo angle in degrees, 0-180.

    Returns:
        The controller's response string on success, or None when the
        arguments are out of range or the connection fails.
    """
    if not (0 <= channel <= 15 and 0 <= angle <= 180):
        print("Invalid channel or angle. Channel must be 0-15 and angle must be 0-180.")
        return None

    command = f"C{channel}:{angle}\n"
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket:
        try:
            # Bound connect/send/recv so an unreachable controller cannot
            # hang the video loop indefinitely (it previously blocked forever).
            client_socket.settimeout(2.0)
            client_socket.connect((WEMOS_IP, PORT))
            client_socket.sendall(command.encode())
            return client_socket.recv(1024).decode()
        except (OSError, UnicodeDecodeError) as e:
            # Best-effort: log and carry on so the main loop keeps running.
            print("Connection error:", e)
            return None
# Park the arm on startup.
# NOTE(review): channel roles (13/15) are not documented anywhere in this
# file — confirm against the actual servo wiring.
sc(13, 170)
sc(15, 110)

# Arm configuration
n_links = 3
gripper_angles = [0, -45, -90, 45, 90]  # selectable gripper orientations (degrees)
link_lengths = [0, 160, 145, 1]  # index 0 unused; [1]/[2] arm links, [3] gripper offset (presumably mm — confirm)
curr_gripper = 0  # index into gripper_angles currently in use
origin = (0, 0)

# Arm joint positions (initialize as floats)
w = [0.0] * n_links  # horizontal coordinate of each joint
z = [0.0] * n_links  # vertical coordinate of each joint
a = [0.0] * n_links  # joint angles in radians (filled by calc_p1)

# Target coordinates
tw, tz = 180, 180
tw0, tz0 = tw, tz  # last reachable target, restored when a target is out of reach

# Socket connection setup

# Functions for inverse kinematics calculations
def calc_p2():
    """Locate joint 2 by backing the gripper offset off the target.

    Writes the globals ``w[2]``/``z[2]`` (joint-2 coordinates) and
    ``l12`` (distance from the origin to joint 2).
    """
    global w, z, l12
    grip_rad = radians(gripper_angles[curr_gripper])
    w[2] = tw - link_lengths[3] * cos(grip_rad)
    z[2] = tz - link_lengths[3] * sin(grip_rad)
    l12 = np.sqrt(w[2] ** 2 + z[2] ** 2)

def calc_p1():
    """Solve the two-link inverse kinematics and drive the arm servos.

    Reads the joint-2 position (``w[2]``/``z[2]``) and reach ``l12``
    computed by calc_p2(); writes joint-1 position and the link angles,
    then sends the resulting angles to servo channels 0 and 12.
    """
    global w, z, a
    base_angle = np.arctan2(z[2], w[2])
    # Law of cosines for the shoulder (elbow-up solution), rotated by the
    # direction toward joint 2.
    cos_numerator = link_lengths[1] ** 2 + l12 ** 2 - link_lengths[2] ** 2
    a[1] = np.arccos(cos_numerator / (2 * link_lengths[1] * l12)) + base_angle
    w[1] = np.cos(a[1]) * link_lengths[1]
    z[1] = np.sin(a[1]) * link_lengths[1]
    a[2] = np.arctan2(z[2] - z[1], w[2] - w[1]) - a[1]

    deg1 = np.degrees(a[1])
    deg2 = np.degrees(a[2])
    print(f"Link 1 angle: {deg1:.2f}°")
    print(f"Link 2 angle: {180 - abs(deg2):.2f}°")
    sc(0, int(deg1))
    # Map the elbow angle onto the servo's mounting orientation; the -10
    # looks like a hardware trim offset — confirm on the rig.
    sc(12, int(180 - (90 - abs(deg2))) - 10)

def process_arm(cord):
    """Aim the arm at horizontal coordinate *cord* (height fixed at 140).

    Falls back to the previously reachable target when the requested
    point lies beyond the arm's reach, then solves and drives the joints.
    """
    global tw, tz, tw0, tz0
    tw = int(cord)
    print(f"Received target position: {tw}")
    tz = 140
    print(tz)

    calc_p2()
    # The two links can reach at most their combined length.
    max_reach = link_lengths[1] + link_lengths[2]
    if l12 > max_reach:
        print("No Solution. Move target closer to origin.")
        tw, tz = tw0, tz0
        calc_p2()

    calc_p1()
    tw0, tz0 = tw, tz

# Camera and grid configuration
num_rows = 20
num_cols = 20
# NOTE(review): `height` is rebound to the frame's PIXEL height inside the
# main loop (frame.shape), so this 40 cm value only feeds `gap` below —
# confirm that shadowing is intentional.
height = 40  # cm
fov = 80  # degrees — camera field of view
theta = 50  # degrees — camera tilt (presumably from vertical; confirm rig)
gap_angle = fov / num_rows  # degrees of view per grid row
gap = 2 * height * tan(radians(gap_angle / 2))  # distance spanned by one grid row

# Open video capture (device 0 = default camera)
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    print("Error: Could not open the video source.")
    exit()

# Initialize variables for persistence check
last_grid_index = None  # grid cell the object occupied on the previous frame
persist_count = 0  # consecutive frames the object has stayed in that cell
required_persist_frames = 5  # Adjust this for sensitivity (e.g., 10 frames)

while True:
    ret, frame = cap.read()
    if not ret:
        print("Error: Could not read frame.")
        break

    # Frame dimensions in pixels. NOTE(review): this rebinds `height`,
    # clobbering the 40 cm camera height configured above — verify the
    # projection formula further down really wants the pixel height.
    height, width, _ = frame.shape

    # Pixel spacing between grid lines
    row_height = height // num_rows
    col_width = width // num_cols

    # Grid intersections as ((pixel_x, pixel_y), (row_idx, col_idx)),
    # indices 1-based
    intersection_points = [
        ((col * col_width, row * row_height), (row + 1, col + 1))  # Store both pixel and grid indices
        for row in range(num_rows + 1)
        for col in range(num_cols + 1)
    ]

    # Convert frame to HSV for color thresholding
    hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

    # HSV range for the tracked object. NOTE(review): the names say
    # "yellow" and an earlier comment said "blue", but hue 140-170 in
    # OpenCV is pink/magenta — confirm which color is actually intended.
    lower_yellow = np.array([140, 50, 50])  # Adjust the values as needed
    upper_yellow = np.array([170, 255, 255])  # Adjust the values as needed

    # Binary mask of pixels inside the HSV range
    mask = cv2.inRange(hsv_frame, lower_yellow, upper_yellow)

    # Find contours of the detected object
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    object_center = None

    if contours:
        # Track only the largest blob to ignore speckle noise
        largest_contour = max(contours, key=cv2.contourArea)

        # Centroid of the largest contour via image moments
        M = cv2.moments(largest_contour)
        if M["m00"] != 0:  # zero area would divide by zero
            cx = int(M["m10"] / M["m00"])  # x-coordinate of the center
            cy = int(M["m01"] / M["m00"])  # y-coordinate of the center
            object_center = (cx, cy)

            # Draw the contour and center on the frame
            cv2.drawContours(frame, [largest_contour], -1, (0, 255, 0), 2)
            cv2.circle(frame, object_center, 5, (0, 0, 255), -1)

    # Snap the detected center to the nearest grid intersection
    if object_center:
        closest_point, closest_index = min(
            intersection_points,
            key=lambda point: (point[0][0] - object_center[0]) ** 2 + (point[0][1] - object_center[1]) ** 2,
        )

        # Highlight the closest intersection point
        cv2.circle(frame, closest_point, 10, (255, 0, 0), -1)

        # Debounce: only fire when the object has stayed in the same grid
        # cell for required_persist_frames consecutive frames
        if closest_index == last_grid_index:
            persist_count += 1
            if persist_count >= required_persist_frames:
                sc(14, 50)
                time.sleep(5)
                sc(14, 180)  # Trigger servo command
                persist_count = 0  # Reset the counter
        else:
            persist_count = 0  # Reset persistence count for a new grid index

        last_grid_index = closest_index

        # Continue processing arm movement: map the 1-based row index to a
        # workspace coordinate via the camera geometry (gap, tilt theta).
        # NOTE(review): `height` here is the frame's pixel height (rebound
        # above), not the 40 cm camera height — verify the units in this
        # projection before trusting x_cord.
        cam_x_cord = (num_rows) - (closest_index[0]) + 1
        x_cord = 18 + ((((height * cam_x_cord * gap * cos(radians(theta))) +
                         (gap * cam_x_cord * sin(radians(theta)) * height * tan(radians(theta - (fov / 2)))))) /
                        ((height) - (gap * cam_x_cord * sin(radians(theta)))))
        print(x_cord)
        process_arm(x_cord * 5)
        time.sleep(0.5)

    # Display the frame with the grid and detection
    cv2.imshow("Video with Grid and Object Detection", frame)

    # Break on 'q' key press
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the capture and close windows
cap.release()
cv2.destroyAllWindows()
