# Face Recognition with Servo Tracking for Raspberry Pi 5
# Based on your existing facial_recognition.py with added servo tracking capabilities
# UPDATED: Added camera flip for upside-down mounting + Audio alerts for unknown faces

import pickle
import time
import traceback
from time import sleep

import face_recognition
import cv2
import numpy as np
from picamera2 import Picamera2
from gpiozero import AngularServo
from gpiozero.pins.pigpio import PiGPIOFactory
import pygame  # Added for audio playback

# Load pre-trained face encodings: a pickled dict with "encodings" and "names" keys.
# NOTE: unpickling can execute arbitrary code - only load an encodings file you trust.
print("[INFO] loading encodings...")
with open("encodings.pickle", "rb") as f:
    data = pickle.load(f)
known_face_encodings = data["encodings"]
known_face_names = data["names"]

# ===== AUDIO CONFIGURATION =====
AUDIO_FILE = "ALERT ALERT ALERT Un.wav"  # Your audio file
AUDIO_COOLDOWN = 5  # Seconds between audio alerts (prevents spam)

# Initialize pygame mixer for audio. Any failure leaves the script running
# with alerts disabled (audio_enabled = False, unknown_alert_sound = None).
unknown_alert_sound = None
try:
    pygame.mixer.init(frequency=22050, size=-16, channels=2, buffer=512)
    unknown_alert_sound = pygame.mixer.Sound(AUDIO_FILE)
    print(f"[INFO] Audio system initialized - Alert file: {AUDIO_FILE}")
    audio_enabled = True
except Exception as e:
    print(f"[WARNING] Audio initialization failed: {e}")
    print("[WARNING] Continuing without audio alerts")
    audio_enabled = False
    # The mixer may have started before loading the sound file failed;
    # release it now instead of leaving it open while audio is disabled.
    if pygame.mixer.get_init():
        pygame.mixer.quit()

# Audio timing variables
last_unknown_alert_time = 0  # time.time() of the last alert that played
unknown_face_detected = False  # set by process_frame() for the current frame

# ===== UPDATED SERVO CONFIGURATION =====
# GPIO pins the two servos are attached to
PAN_PIN = 18   # Base servo pin (left/right)
TILT_PIN = 19  # Top servo pin (up/down)

# Angle limits used to clamp tracking moves, and the resting (centre) angles
SERVO_MIN = 0         # Minimum servo angle
SERVO_MAX = 180       # Maximum servo angle
PAN_CENTER = 130      # Bottom servo centre angle
TILT_CENTER = 80      # Top servo centre angle

# ===== CAMERA FLIP CONFIGURATION =====
CAMERA_UPSIDE_DOWN = True  # True when the camera is physically mounted upside down

# Drive the servos through the pigpio pin factory. If anything fails, both
# servo handles become None and the servo helpers skip all movement.
# NOTE(review): AngularServo is created with gpiozero's default pulse widths;
# confirm they actually cover the servos' physical 0-180 degree travel.
try:
    factory = PiGPIOFactory()
    pan_servo = AngularServo(
        PAN_PIN,
        min_angle=0,
        max_angle=180,
        pin_factory=factory,
    )
    tilt_servo = AngularServo(
        TILT_PIN,
        min_angle=0,
        max_angle=180,
        pin_factory=factory,
    )

    print("[INFO] Servos initialized with full range (0-180°)")
    print(f"[INFO] Center positions - Pan: {PAN_CENTER}°, Tilt: {TILT_CENTER}°")

except Exception as exc:
    print(f"[ERROR] Failed to initialize servos: {exc}")
    pan_servo = None
    tilt_servo = None

# Frame geometry. The capture size is defined once and passed to the camera
# below, so the frame-centre values used for tracking always match the
# resolution the camera is configured to deliver.
FRAME_WIDTH = 1920
FRAME_HEIGHT = 1080
FRAME_CENTER_X = FRAME_WIDTH // 2
FRAME_CENTER_Y = FRAME_HEIGHT // 2

# Initialize the camera - keeping your original settings
picam2 = Picamera2()
picam2.configure(picam2.create_preview_configuration(
    main={"format": 'XRGB8888', "size": (FRAME_WIDTH, FRAME_HEIGHT)}))
picam2.start()

# Tracking sensitivity (adjust these values to fine-tune tracking)
TRACKING_THRESHOLD = 50  # Minimum pixel difference to trigger servo movement
SERVO_STEP = 2          # Degrees to move servo per adjustment
TARGET_NAME = "Ankur"   # Compared case-insensitively; should match a name in the training data

# Initialize our variables - keeping your original settings
cv_scaler = 4           # Detection runs on a frame downscaled by this factor
face_locations = []     # Per-frame results written by process_frame()
face_encodings = []
face_names = []
frame_count = 0         # FPS bookkeeping used by calculate_fps()
start_time = time.time()
fps = 0

# Current servo positions (last commanded angles)
current_pan_angle = PAN_CENTER
current_tilt_angle = TILT_CENTER

def play_unknown_alert():
    """Play the unknown-person alert sound, rate-limited by AUDIO_COOLDOWN.

    Returns without doing anything when audio is disabled, or when fewer than
    AUDIO_COOLDOWN seconds have passed since the last alert that played.
    Playback errors are printed rather than raised.
    """
    global last_unknown_alert_time

    now = time.time()

    if not audio_enabled:
        return
    if now - last_unknown_alert_time < AUDIO_COOLDOWN:
        return

    try:
        unknown_alert_sound.play()
        # Only restart the cooldown once playback has actually been started
        last_unknown_alert_time = now
        print("[AUDIO] Playing unknown person alert")
    except Exception as err:
        print(f"[ERROR] Failed to play audio alert: {err}")

def flip_frame(frame):
    """Return *frame* rotated by 180 degrees if CAMERA_UPSIDE_DOWN is set, otherwise unchanged."""
    if not CAMERA_UPSIDE_DOWN:
        return frame
    # Undo the upside-down mounting so the picture is upright
    return cv2.rotate(frame, cv2.ROTATE_180)

def move_servos_to_center():
    """Command both servos to PAN_CENTER / TILT_CENTER and record those angles.

    Prints an error and returns if either servo failed to initialize. After
    commanding the move, it waits 0.5 s. Servo errors are printed, not raised.
    """
    global current_pan_angle, current_tilt_angle

    servos_ready = pan_servo is not None and tilt_servo is not None
    if not servos_ready:
        print("[ERROR] Servos not initialized")
        return

    print(f"Centering servos - Pan: {PAN_CENTER}°, Tilt: {TILT_CENTER}°")
    try:
        pan_servo.angle = PAN_CENTER
        tilt_servo.angle = TILT_CENTER
        current_pan_angle, current_tilt_angle = PAN_CENTER, TILT_CENTER
        sleep(0.5)  # give the servos time to reach the centre position
    except Exception as err:
        print(f"Error centering servos: {err}")

def calculate_servo_angles(face_center_x, face_center_y):
    """Compute the pan/tilt angles that nudge the face toward the frame centre.

    An axis is adjusted only when the face is more than TRACKING_THRESHOLD
    pixels off centre on that axis. The change is proportional to the pixel
    offset: (offset / 30) * SERVO_STEP degrees. Its sign is reversed when
    CAMERA_UPSIDE_DOWN is set, and the result is clamped to
    [SERVO_MIN, SERVO_MAX]. This function does not move the servos.

    Args:
        face_center_x: face centre x in full-resolution frame pixels.
        face_center_y: face centre y in full-resolution frame pixels.

    Returns:
        (new_pan_angle, new_tilt_angle). An axis that needs no correction
        keeps its current angle.
    """
    # current_pan_angle / current_tilt_angle are only read here, so no
    # ``global`` declaration is needed.
    diff_x = face_center_x - FRAME_CENTER_X
    diff_y = face_center_y - FRAME_CENTER_Y

    new_pan_angle = current_pan_angle
    new_tilt_angle = current_tilt_angle

    if abs(diff_x) > TRACKING_THRESHOLD:
        # Without the flip, a face right of centre (diff_x > 0) lowers the pan
        # angle and a face left of centre raises it. CAMERA_UPSIDE_DOWN
        # reverses that direction.
        pan_adjustment = -diff_x / 30  # 30 px of offset per SERVO_STEP unit
        if CAMERA_UPSIDE_DOWN:
            pan_adjustment = -pan_adjustment

        new_pan_angle = current_pan_angle + pan_adjustment * SERVO_STEP
        new_pan_angle = max(SERVO_MIN, min(SERVO_MAX, new_pan_angle))

    if abs(diff_y) > TRACKING_THRESHOLD:
        # Without the flip, a face below centre (diff_y > 0) raises the tilt
        # angle and a face above centre (diff_y < 0) lowers it.
        # CAMERA_UPSIDE_DOWN reverses that direction.
        tilt_adjustment = diff_y / 30  # 30 px of offset per SERVO_STEP unit
        if CAMERA_UPSIDE_DOWN:
            tilt_adjustment = -tilt_adjustment

        new_tilt_angle = current_tilt_angle + tilt_adjustment * SERVO_STEP
        new_tilt_angle = max(SERVO_MIN, min(SERVO_MAX, new_tilt_angle))

    return new_pan_angle, new_tilt_angle

def move_servos(pan_angle, tilt_angle):
    """Send new target angles to the pan and tilt servos.

    Each servo is commanded straight to its target (no intermediate steps).
    A servo is only commanded when its target differs from the last commanded
    angle by more than 0.5 degrees, and current_pan_angle / current_tilt_angle
    are updated to match. The resulting angles are printed on every call.
    Returns immediately if either servo failed to initialize. Servo errors
    are printed, not raised.
    """
    global current_pan_angle, current_tilt_angle

    if any(servo is None for servo in (pan_servo, tilt_servo)):
        return

    try:
        pan_changed = abs(pan_angle - current_pan_angle) > 0.5
        if pan_changed:
            pan_servo.angle = pan_angle
            current_pan_angle = pan_angle

        tilt_changed = abs(tilt_angle - current_tilt_angle) > 0.5
        if tilt_changed:
            tilt_servo.angle = tilt_angle
            current_tilt_angle = tilt_angle

        print(f"Servo tracking - Pan: {current_pan_angle:.1f}°, Tilt: {current_tilt_angle:.1f}°")
    except Exception as err:
        print(f"Servo movement error: {err}")

def process_frame(frame):
    """Detect and identify faces in one camera frame.

    The frame is passed through flip_frame(), then downscaled by cv_scaler
    and converted BGR->RGB before detection and encoding. Each face is
    labelled with the closest known name that compare_faces() accepts, or
    "Unknown" otherwise.

    Side effects: the module-level face_locations, face_encodings, face_names
    and unknown_face_detected are overwritten. play_unknown_alert() is called
    when any face is unrecognised.

    Args:
        frame: image array captured from the camera.

    Returns:
        (frame, target_face_center):
            frame: the (possibly rotated) full-resolution frame.
            target_face_center: the (x, y) centre of the TARGET_NAME face in
                full-resolution pixels, or None if that person was not seen.
                If several faces match the target, the last one wins.
    """
    global face_locations, face_encodings, face_names, unknown_face_detected

    # Flip frame if camera is mounted upside down
    frame = flip_frame(frame)

    # Downscale by cv_scaler to make detection cheaper
    resized_frame = cv2.resize(frame, (0, 0), fx=(1/cv_scaler), fy=(1/cv_scaler))

    # Convert the image from BGR to RGB colour space
    rgb_resized_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)

    # Find all the faces and face encodings in the current frame of video
    face_locations = face_recognition.face_locations(rgb_resized_frame)
    face_encodings = face_recognition.face_encodings(rgb_resized_frame, face_locations, model='large')

    face_names = []
    target_face_center = None
    unknown_face_detected = False

    # face_encodings holds one encoding per entry of face_locations, in order
    for (top, right, bottom, left), face_encoding in zip(face_locations, face_encodings):
        name = "Unknown"
        matched = False

        # np.argmin raises ValueError on an empty array, so matching is only
        # attempted when at least one known encoding was loaded.
        if len(known_face_encodings) > 0:
            matches = face_recognition.compare_faces(known_face_encodings, face_encoding)
            # Use the known face with the smallest distance to the new face
            face_distances = face_recognition.face_distance(known_face_encodings, face_encoding)
            best_match_index = int(np.argmin(face_distances))
            matched = bool(matches[best_match_index])
            if matched:
                name = known_face_names[best_match_index]

        if not matched:
            # Unknown face detected - trigger audio alert after the loop
            unknown_face_detected = True

        face_names.append(name)

        # If this is our target person, compute the face centre for servo tracking
        if name.lower() == TARGET_NAME.lower():
            # Scale back up to original frame size
            top *= cv_scaler
            right *= cv_scaler
            bottom *= cv_scaler
            left *= cv_scaler

            face_center_x = left + (right - left) // 2
            face_center_y = top + (bottom - top) // 2
            target_face_center = (face_center_x, face_center_y)

            print(f"Tracking {name} at center: ({face_center_x}, {face_center_y})")

    # Play audio alert if unknown face detected (rate-limited inside the helper)
    if unknown_face_detected:
        play_unknown_alert()

    return frame, target_face_center

def draw_results(frame, target_face_center=None):
    """Draw the tracking overlay onto *frame* and return it.

    Draws the following, modifying *frame* in place:
    - a crosshair at the frame centre;
    - for every face in the module-level face_locations / face_names (scaled
      back up by cv_scaler), a coloured box, the name, and a status label;
    - a marker and offset readout for the tracked face;
    - status text for camera flip, audio and unknown-face detection.

    Args:
        frame: full-resolution frame returned by process_frame().
        target_face_center: (x, y) of the tracked face, or None.

    Returns:
        The same frame object, annotated.
    """
    # Colours below are BGR tuples, e.g. (0, 0, 255) is red.
    font = cv2.FONT_HERSHEY_DUPLEX

    # Draw crosshairs at frame center
    cv2.line(frame, (FRAME_CENTER_X - 30, FRAME_CENTER_Y),
             (FRAME_CENTER_X + 30, FRAME_CENTER_Y), (255, 255, 255), 2)
    cv2.line(frame, (FRAME_CENTER_X, FRAME_CENTER_Y - 30),
             (FRAME_CENTER_X, FRAME_CENTER_Y + 30), (255, 255, 255), 2)

    # Display the results
    for (top, right, bottom, left), name in zip(face_locations, face_names):
        # Scale back up face locations since the frame we detected in was scaled
        top *= cv_scaler
        right *= cv_scaler
        bottom *= cv_scaler
        left *= cv_scaler

        # Choose color based on if this is our target or unknown
        if name.lower() == TARGET_NAME.lower():
            color = (0, 255, 0)  # Green for target
            label_text = "TRACKING TARGET"
        elif name == "Unknown":
            color = (0, 0, 255)  # Red for unknown
            label_text = "UNKNOWN - ALERT TRIGGERED!"
        else:
            color = (0, 255, 255)  # Yellow (in BGR) for known but not target
            label_text = "KNOWN PERSON"

        # Draw a box around the face
        cv2.rectangle(frame, (left, top), (right, bottom), color, 3)

        # Filled name banner just above the box, with the name in white
        cv2.rectangle(frame, (left - 3, top - 35), (right + 3, top), color, cv2.FILLED)
        cv2.putText(frame, name, (left + 6, top - 6), font, 1.0, (255, 255, 255), 1)

        # Status label below the box
        cv2.putText(frame, label_text, (left + 6, bottom + 25),
                    font, 0.6, color, 2)

    # Draw target face center if available
    if target_face_center:
        cv2.circle(frame, target_face_center, 8, (0, 255, 0), -1)
        # Draw line from center to face
        cv2.line(frame, (FRAME_CENTER_X, FRAME_CENTER_Y), target_face_center, (0, 255, 0), 2)

        # Show distance from center
        distance_x = abs(target_face_center[0] - FRAME_CENTER_X)
        distance_y = abs(target_face_center[1] - FRAME_CENTER_Y)
        cv2.putText(frame, f"Distance: X={distance_x} Y={distance_y}",
                    (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

    # Add camera flip indicator
    if CAMERA_UPSIDE_DOWN:
        cv2.putText(frame, "CAMERA FLIPPED", (10, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)

    # Add audio status indicator
    audio_status = "AUDIO: ENABLED" if audio_enabled else "AUDIO: DISABLED"
    audio_color = (0, 255, 0) if audio_enabled else (0, 0, 255)
    cv2.putText(frame, audio_status, (10, 150),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, audio_color, 2)

    # Show unknown face detection status
    if unknown_face_detected:
        cv2.putText(frame, "UNKNOWN FACE DETECTED - ALERT PLAYING!",
                    (10, 180), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    return frame

def calculate_fps():
    """Count one frame and return the current frames-per-second estimate.

    Once more than one second has passed since the current measurement window
    began, fps is recomputed as frames / elapsed seconds and a new window is
    started. Between updates the previous value is returned unchanged.
    """
    global frame_count, start_time, fps

    frame_count += 1
    window = time.time() - start_time
    if window <= 1:
        return fps

    fps = frame_count / window
    frame_count = 0
    start_time = time.time()
    return fps

# ----- Startup: report the configuration, then park the servos at centre -----
print(f"[INFO] Initializing face tracking turret for target: {TARGET_NAME}")
print(f"[INFO] Servo range: {SERVO_MIN}° to {SERVO_MAX}°")
print(f"[INFO] Servo pins - Pan: GPIO {PAN_PIN}, Tilt: GPIO {TILT_PIN}")
print("[INFO] Audio alerts: " + ("ENABLED" if audio_enabled else "DISABLED"))
if CAMERA_UPSIDE_DOWN:
    print("[INFO] Camera is mounted upside down - auto-flipping feed\n"
          "[INFO] Servo directions adjusted for flipped camera")

move_servos_to_center()
print("[INFO] Servo tracking initialized. Starting face recognition...")
print("[INFO] Controls: 'q' = quit, 'c' = center servos, 'f' = toggle camera flip")

# Main capture -> recognise -> track -> display loop. Cleanup in ``finally``
# always runs, whether the loop ends via 'q', Ctrl+C or an error.
try:
    while True:
        # Capture a frame from camera
        frame = picam2.capture_array()

        # Process the frame with the function and get target face center
        processed_frame, target_face_center = process_frame(frame)

        # If we found our target person, move servos to track them
        if target_face_center:
            new_pan, new_tilt = calculate_servo_angles(target_face_center[0], target_face_center[1])
            move_servos(new_pan, new_tilt)

        # Get the text and boxes to be drawn based on the processed frame
        display_frame = draw_results(processed_frame, target_face_center)

        # Calculate and update FPS
        current_fps = calculate_fps()

        # Attach FPS counter to the text and boxes
        cv2.putText(display_frame, f"FPS: {current_fps:.1f}",
                    (display_frame.shape[1] - 150, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Add servo angle info
        cv2.putText(display_frame, f"Pan: {current_pan_angle:.1f}° Tilt: {current_tilt_angle:.1f}°",
                    (10, display_frame.shape[0] - 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        # Add tracking target info
        cv2.putText(display_frame, f"Tracking Target: {TARGET_NAME}", (10, 80),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

        # Display everything over the video feed
        cv2.imshow('Face Tracking Turret', display_frame)

        # Handle keyboard input
        key = cv2.waitKey(1) & 0xFF
        if key == ord("q"):
            break
        elif key == ord("c"):
            # Return to center on 'c' key press
            print("Returning servos to center position...")
            move_servos_to_center()
        elif key == ord("f"):
            # Toggle camera flip on 'f' key press. This rebinds the module-level
            # flag read by flip_frame() and calculate_servo_angles().
            CAMERA_UPSIDE_DOWN = not CAMERA_UPSIDE_DOWN
            flip_status = "ENABLED" if CAMERA_UPSIDE_DOWN else "DISABLED"
            print(f"Camera flip {flip_status}")

except KeyboardInterrupt:
    print("\nKeyboard interrupt - stopping face tracking...")

except Exception as e:
    print(f"Error occurred: {e}")
    # The one-line message hides where the failure happened; show the trace too.
    traceback.print_exc()

finally:
    # By breaking the loop we run this code here which closes everything
    print("Shutting down face tracking turret...")
    move_servos_to_center()  # Return servos to center before exit
    cv2.destroyAllWindows()

    # Guard the camera stop so a failure here cannot skip releasing the
    # servos and the audio mixer below.
    try:
        picam2.stop()
    except Exception as e:
        print(f"Error stopping camera: {e}")

    # Close servo connections
    if pan_servo is not None:
        pan_servo.close()
    if tilt_servo is not None:
        tilt_servo.close()

    # Close the audio mixer whenever it is initialized. This includes the case
    # where the mixer started but loading the alert sound failed.
    if pygame.mixer.get_init():
        pygame.mixer.quit()

    print("Face tracking turret stopped safely.")
