import math
import socket
import struct
import threading
import time
from collections import deque
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import TextBox, Button, RadioButtons

# ---------------------------------------------------------------------------
# Network: accept datagrams from the ESP32 on every local interface.
# ---------------------------------------------------------------------------
UDP_IP = "0.0.0.0"
UDP_PORT = 8888
BUFFER_SIZE = 65535  # largest possible UDP datagram

# ---------------------------------------------------------------------------
# Sampling and display
# ---------------------------------------------------------------------------
MEASUREMENT_FREQUENCY = 250  # Hz, as per the updated Arduino code
WINDOW_SIZE = 5000           # samples kept per stream (20 s at 250 Hz)
DISPLAY_WINDOW = 8.0         # seconds shown on screen by default

# Initial y-axis limits for the raw-data plots
DEFAULT_Y_MIN = -200
DEFAULT_Y_MAX = 200

# Normalization methods: key -> description.
# Key order is significant: it is the order of the radio buttons and the
# order the 'n' key cycles through; the first key is the initial method.
NORM_METHODS = {
    'None': 'No normalization',
    'MinMax': 'Min-Max scaling (0-1)',
    'ZScore': 'Z-score normalization',
    'Baseline': 'Subtract baseline',
    'Percent': 'Percent of max value'
}

# ---------------------------------------------------------------------------
# Bounded sample buffers (oldest samples fall off once WINDOW_SIZE is reached)
# ---------------------------------------------------------------------------
# Processed per-channel measurements received from the ESP32
a_long_data, b_long_data, a_short_data, b_short_data = (
    deque(maxlen=WINDOW_SIZE) for _ in range(4)
)
# Normalized versions of the two short channels
a_short_norm_data, b_short_norm_data = (
    deque(maxlen=WINDOW_SIZE) for _ in range(2)
)
# Timestamp (seconds since start) of each stored measurement
measurement_times = deque(maxlen=WINDOW_SIZE)

class PlotControl:
    """Mutable state shared by the plots, their widgets and the data pipeline.

    One module-level instance is read and written by the widget callbacks,
    update_plot and normalize_data.
    """
    def __init__(self):
        # Y-axis limits for the raw-data plots (applied in manual scaling mode)
        self.manual_y_scale = False
        self.y_min = DEFAULT_Y_MIN
        self.y_max = DEFAULT_Y_MAX
        self.y_scale_step = 50  # step the keyboard handler moves a y-limit by

        # Visible time span of every plot, in seconds
        self.time_window = DISPLAY_WINDOW

        # Normalization settings
        self.norm_method = 'None'  # one of the NORM_METHODS keys
        self.norm_window = 100     # rolling window (samples) for MinMax/ZScore/Percent
        self.show_raw = True       # raw vs. normalized view flag for the short channels
        self.norm_y_min = 0        # Y-axis min for normalized plots
        self.norm_y_max = 1        # Y-axis max for normalized plots
        self.baseline_samples = 50 # samples averaged into the 'Baseline' reference

        # Reference levels for the 'Baseline' method. normalize_data fills
        # these in when it (re)computes the baseline; declaring them here
        # guarantees the attributes always exist (None = not computed yet).
        self.a_baseline = None
        self.b_baseline = None

        # References to plot elements, populated by init_plot
        self.fig = None
        self.axes = []
        self.text_boxes = {}
        self.buttons = {}
        self.radio_buttons = {}
        
# Serializes access to the sample/normalized deques between the UDP receiver
# thread and the plotting/UI callbacks.
data_lock = threading.Lock()

# The single shared plot/UI state object.
plot_control = PlotControl()

# The UDP server loop keeps running while this is True.
running = True

def udp_server():
    """Receive ESP32 datagrams and hand them to process_packet until ``running`` is False.

    Binds a UDP socket on UDP_IP:UDP_PORT with a 50 ms receive timeout, so the
    loop re-checks the module-level ``running`` flag regularly instead of
    blocking indefinitely. After more than 5 s of traffic since the last
    report, it prints a sample-rate estimate derived from the stored
    measurement timestamps.
    """
    # The context manager closes the socket on every exit path, including a
    # failing bind() (for example when the port is already in use).
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind((UDP_IP, UDP_PORT))
        sock.settimeout(0.05)  # 50ms timeout for responsive UI

        print(f"UDP server listening on {UDP_IP}:{UDP_PORT}")

        # Monotonic clock: receive times are unaffected by wall-clock
        # adjustments (NTP sync, manual clock changes).
        start_time = time.monotonic()
        last_report_time = 0

        try:
            while running:
                try:
                    data, _addr = sock.recvfrom(BUFFER_SIZE)
                except socket.timeout:
                    # No datagram within the timeout: loop to re-check ``running``.
                    continue

                receive_time = time.monotonic() - start_time  # seconds since start
                process_packet(data, receive_time)

                # Report the observed sample rate occasionally
                if receive_time - last_report_time > 5.0:
                    with data_lock:
                        if measurement_times:
                            samples = len(measurement_times)
                            duration = (max(measurement_times) - min(measurement_times)
                                        if samples > 1 else 0)
                            rate = samples / duration if duration > 0 else 0
                            print(f"Receiving at approximately {rate:.1f} Hz")
                    last_report_time = receive_time
        except KeyboardInterrupt:
            print("UDP server stopped by user")

def process_packet(data, receive_time):
    """
    Parse one ESP32 UDP datagram and append its valid samples to the buffers.

    Packet layout (little-endian):
    - Header (16 bytes): timestamp (uint32), period_count (uint32),
      packet_num (uint16), total_packets (uint16), reserved (4 bytes)
    - Zero or more 12-byte measurements, each:
      channel_A_long, channel_B_long, channel_A_short, channel_B_short
      (int16 each) followed by a uint32 timestamp.

    Sample times are synthesized from ``receive_time``. The last measurement
    in the packet is stamped ``receive_time``, and earlier ones step back
    1 / MEASUREMENT_FREQUENCY seconds each. The ESP32 timestamps are parsed
    but not used. Trailing bytes that do not form a whole measurement are
    ignored.

    Only measurements with both A-short and B-short > 0 are stored; each
    stored one is followed by a normalize_data() call. All appends happen
    under ``data_lock``. Errors are printed (with traceback), not raised.

    Args:
        data: Raw datagram bytes.
        receive_time: Seconds since the server started at which it arrived.
    """
    header_size = 16
    measurement_size = 12
    try:
        if len(data) < header_size:
            print(f"Received packet too small: {len(data)} bytes")
            return

        # timestamp, period_count, packet_num, total_packets, reserved
        _esp_timestamp, _period_count, packet_num, total_packets, _reserved = (
            struct.unpack_from("<IIHH4s", data, 0)
        )

        # Skip invalid packets
        if total_packets == 0 or total_packets > 1000:
            print(f"Invalid packet: {packet_num}/{total_packets}")
            return

        num_measurements = (len(data) - header_size) // measurement_size
        # Slice to an exact multiple of the record size, as iter_unpack requires.
        payload = data[header_size:header_size + num_measurements * measurement_size]

        if packet_num == 0:  # Only log first packet of a group to reduce console spam
            print(f"Received packet {packet_num+1}/{total_packets} with {num_measurements} measurements")

        sample_period = 1.0 / MEASUREMENT_FREQUENCY
        filtered_count = 0

        with data_lock:
            records = struct.iter_unpack("<hhhhI", payload)
            for i, (a_long, b_long, a_short, b_short, _esp_sample_us) in enumerate(records):
                # Only keep measurements where both short channels are positive
                if a_short <= 0 or b_short <= 0:
                    filtered_count += 1
                    continue

                # Newest sample gets receive_time; older ones step back one period each.
                sample_time = receive_time - ((num_measurements - 1 - i) * sample_period)

                a_long_data.append(a_long)
                b_long_data.append(b_long)
                a_short_data.append(a_short)
                b_short_data.append(b_short)
                measurement_times.append(sample_time)

                # Normalize the sample just appended using the current method
                normalize_data()

        # Log filtered measurements if any were removed
        if filtered_count > 0 and packet_num == 0:
            print(f"Filtered {filtered_count} measurements with A-short or B-short values ≤ 0")

    except Exception as e:
        print(f"Error processing packet: {e}")
        import traceback
        traceback.print_exc()

def _newest(values, count):
    """Return up to *count* (at least one) of the newest items of *values*, newest first."""
    # Walks only the tail of the deque instead of copying the whole buffer.
    return list(islice(reversed(values), max(1, count)))


def normalize_data():
    """
    Normalize the newest short-channel sample and append it to the norm buffers.

    Takes the latest values of ``a_short_data``/``b_short_data`` and appends
    one normalized value each to ``a_short_norm_data``/``b_short_norm_data``,
    according to ``plot_control.norm_method``:

    - 'None' (or an unknown method): the raw value.
    - 'MinMax': (x - min) / (max - min) over the newest ``norm_window``
      samples, with the range clamped to at least 1.
    - 'ZScore': (x - mean) / std over the newest ``norm_window`` samples,
      using np.std (population std) clamped to at least 1.
    - 'Baseline': x minus a fixed reference. The reference is (re)computed
      whenever the norm buffer is empty, as the mean of the oldest
      ``baseline_samples`` stored samples.
    - 'Percent': x / max(|v|) over the newest ``norm_window`` samples,
      clamped to at least 1.

    Does nothing if either short buffer is empty. The visible callers
    (process_packet, on_norm_method_change) invoke it while holding
    ``data_lock``.
    """
    if not a_short_data or not b_short_data:
        return  # No data to normalize

    # Values are > 0 here because process_packet filters the rest out.
    method = plot_control.norm_method
    a_last = a_short_data[-1]
    b_last = b_short_data[-1]

    if method == 'MinMax':
        a_window = _newest(a_short_data, plot_control.norm_window)
        b_window = _newest(b_short_data, plot_control.norm_window)
        a_min = min(a_window)
        b_min = min(b_window)
        # Clamp the range to avoid division by zero on a flat window
        a_norm = (a_last - a_min) / max(1, max(a_window) - a_min)
        b_norm = (b_last - b_min) / max(1, max(b_window) - b_min)

    elif method == 'ZScore':
        a_window = _newest(a_short_data, plot_control.norm_window)
        b_window = _newest(b_short_data, plot_control.norm_window)
        a_mean = sum(a_window) / len(a_window)
        b_mean = sum(b_window) / len(b_window)
        a_norm = (a_last - a_mean) / max(1, np.std(a_window))
        b_norm = (b_last - b_mean) / max(1, np.std(b_window))

    elif method == 'Baseline':
        a_baseline = getattr(plot_control, 'a_baseline', None)
        b_baseline = getattr(plot_control, 'b_baseline', None)
        if not a_short_norm_data or a_baseline is None or b_baseline is None:
            # (Re)start: average the oldest available samples as the reference.
            count = max(1, min(plot_control.baseline_samples, len(a_short_data)))
            a_baseline = sum(islice(a_short_data, count)) / count
            b_baseline = sum(islice(b_short_data, count)) / count
            plot_control.a_baseline = a_baseline
            plot_control.b_baseline = b_baseline
        # Always report the true offset of the newest sample, including the
        # first one after a restart (no artificial zero point).
        a_norm = a_last - a_baseline
        b_norm = b_last - b_baseline

    elif method == 'Percent':
        a_window = _newest(a_short_data, plot_control.norm_window)
        b_window = _newest(b_short_data, plot_control.norm_window)
        a_norm = a_last / max(1, max(abs(v) for v in a_window))
        b_norm = b_last / max(1, max(abs(v) for v in b_window))

    else:
        # 'None' or an unrecognised method: pass the raw value through
        a_norm, b_norm = a_last, b_last

    a_short_norm_data.append(a_norm)
    b_short_norm_data.append(b_norm)

# Text box and button callbacks
def on_y_min_change(text):
    """Handle submission of the Y-min text box.

    A finite float becomes the manual lower y-limit: it is stored, manual
    scaling is enabled and the three raw-data axes are updated. Anything
    else (non-numeric, NaN, inf) is reported and the box is reset to the
    current y_min without touching any state.

    The box is rewritten with ``eventson`` cleared because TextBox.set_val
    otherwise re-fires this submit handler.
    """
    box = plot_control.text_boxes['y_min']
    try:
        value = float(text)
        if not math.isfinite(value):
            raise ValueError("Y-min must be finite")
    except ValueError:
        print(f"Invalid Y-min value: {text}")
        shown = str(plot_control.y_min)  # Reset to previous value
    else:
        plot_control.y_min = value
        plot_control.manual_y_scale = True
        print(f"Y-min changed to {value}")
        for ax in plot_control.axes[:3]:  # Only update raw data axes
            ax.set_ylim(bottom=value)
        shown = str(value)  # Show the parsed (normalized) form

    # Update the box without re-triggering its submit observers.
    previous = box.eventson
    box.eventson = False
    try:
        box.set_val(shown)
    finally:
        box.eventson = previous

def on_y_max_change(text):
    """Handle submission of the Y-max text box.

    A finite float becomes the manual upper y-limit: it is stored, manual
    scaling is enabled and the three raw-data axes are updated. Anything
    else (non-numeric, NaN, inf) is reported and the box is reset to the
    current y_max without touching any state.

    The box is rewritten with ``eventson`` cleared because TextBox.set_val
    otherwise re-fires this submit handler.
    """
    box = plot_control.text_boxes['y_max']
    try:
        value = float(text)
        if not math.isfinite(value):
            raise ValueError("Y-max must be finite")
    except ValueError:
        print(f"Invalid Y-max value: {text}")
        shown = str(plot_control.y_max)  # Reset to previous value
    else:
        plot_control.y_max = value
        plot_control.manual_y_scale = True
        print(f"Y-max changed to {value}")
        for ax in plot_control.axes[:3]:  # Only update raw data axes
            ax.set_ylim(top=value)
        shown = str(value)  # Show the parsed (normalized) form

    # Update the box without re-triggering its submit observers.
    previous = box.eventson
    box.eventson = False
    try:
        box.set_val(shown)
    finally:
        box.eventson = previous

def on_time_window_change(text):
    """Handle submission of the time-window text box.

    Accepts a positive, finite number of seconds, stores it in
    ``plot_control.time_window`` and immediately adjusts every axis' x-range
    (update_plot recomputes it on the next frame anyway). Invalid input is
    reported and the box is reset to the current window.

    The box is rewritten with ``eventson`` cleared because TextBox.set_val
    otherwise re-fires this submit handler.
    """
    box = plot_control.text_boxes['time_window']
    try:
        value = float(text)
        if not math.isfinite(value):
            raise ValueError("Time window must be finite")
        if value <= 0:
            raise ValueError("Time window must be positive")
    except ValueError as e:
        print(f"Invalid time window value: {text} - {e}")
        shown = str(plot_control.time_window)  # Reset to previous value
    else:
        plot_control.time_window = value
        print(f"Time window changed to {value} seconds")

        for ax in plot_control.axes:
            current_min, current_max = ax.get_xlim()
            if current_max > value:
                # Keep the right edge fixed when it leaves room for the window
                ax.set_xlim(current_max - value, current_max)
            else:
                # Otherwise anchor the left edge and extend to the right
                ax.set_xlim(current_min, current_min + value)
        shown = str(value)

    # Update the box without re-triggering its submit observers.
    previous = box.eventson
    box.eventson = False
    try:
        box.set_val(shown)
    finally:
        box.eventson = previous

def on_norm_window_change(text):
    """Handle submission of the normalization-window text box.

    Accepts a positive integer number of samples and stores it in
    ``plot_control.norm_window``. Other input is reported and the box is reset
    to the current value.

    The box is rewritten with ``eventson`` cleared because TextBox.set_val
    otherwise re-fires this handler. Without that, an invalid entry would
    also print a misleading "changed" message.
    """
    box = plot_control.text_boxes['norm_window']
    try:
        value = int(text)
        if value <= 0:
            raise ValueError("Normalization window must be positive")
    except ValueError as e:
        print(f"Invalid normalization window value: {text} - {e}")
        shown = str(plot_control.norm_window)  # Reset to previous value
    else:
        plot_control.norm_window = value
        print(f"Normalization window changed to {value} samples")
        shown = str(value)

    # Update the box without re-triggering its submit observers.
    previous = box.eventson
    box.eventson = False
    try:
        box.set_val(shown)
    finally:
        box.eventson = previous

def on_baseline_samples_change(text):
    """Handle submission of the baseline-samples text box.

    Accepts a positive integer and stores it in
    ``plot_control.baseline_samples`` (used the next time a 'Baseline'
    reference is computed). Other input is reported and the box is reset to
    the current value.

    The box is rewritten with ``eventson`` cleared because TextBox.set_val
    otherwise re-fires this handler. Without that, an invalid entry would
    also print a misleading "changed" message.
    """
    box = plot_control.text_boxes['baseline_samples']
    try:
        value = int(text)
        if value <= 0:
            raise ValueError("Baseline samples must be positive")
    except ValueError as e:
        print(f"Invalid baseline samples value: {text} - {e}")
        shown = str(plot_control.baseline_samples)  # Reset to previous value
    else:
        plot_control.baseline_samples = value
        print(f"Baseline samples changed to {value}")
        shown = str(value)

    # Update the box without re-triggering its submit observers.
    previous = box.eventson
    box.eventson = False
    try:
        box.set_val(shown)
    finally:
        box.eventson = previous

def toggle_auto_scale(event):
    """Button callback: flip between automatic and manual y-scaling.

    Entering manual mode applies the stored y_min/y_max to the three raw-data
    axes right away. Auto mode is picked up by update_plot on its next frame.
    """
    manual = not plot_control.manual_y_scale
    plot_control.manual_y_scale = manual
    print(f"Switching to {'Manual' if manual else 'Auto'} scaling mode")

    if not manual:
        return

    # Normalized axes (index 3 onwards) are always auto-scaled; skip them.
    for raw_ax in plot_control.axes[:3]:
        raw_ax.set_ylim(plot_control.y_min, plot_control.y_max)

def toggle_raw_norm_view(event):
    """Button/key callback: flip ``plot_control.show_raw`` and relabel the button.

    NOTE(review): in the code visible here, update_plot never reads show_raw,
    so this only changes the button label and the console message. Confirm
    whether the short-channel plots were meant to switch content.
    """
    plot_control.show_raw = not plot_control.show_raw
    view_name = "Normalized"
    if plot_control.show_raw:
        view_name = "Raw"
    print(f"Switching to {view_name} view for short channels")

    plot_control.buttons['toggle_view'].label.set_text(f'View: {view_name}')

    # Ask for a redraw once the figure exists
    figure = plot_control.fig
    if figure:
        figure.canvas.draw_idle()

def on_norm_method_change(label):
    """Radio-button callback: switch normalization method and restart the normalized series.

    The method is changed and the normalized buffers are cleared in one
    ``data_lock`` section. process_packet holds that lock while normalizing a
    whole packet. If the method changed outside the lock, that thread could
    normalize samples with the new method before the buffers were reset. For
    'Baseline' that means using a baseline that has not been computed.

    If raw short-channel data exists, only the newest sample is normalized
    right away, which gives the normalized plots a starting point. Older
    samples are not re-normalized.
    """
    with data_lock:
        plot_control.norm_method = label
        a_short_norm_data.clear()
        b_short_norm_data.clear()

        if a_short_data and b_short_data:
            normalize_data()

    print(f"Normalization method changed to: {label}")

def init_plot():
    """Build the figure: five stacked plots above a strip of interactive controls.

    From top to bottom the plots are:
    - all four raw channels
    - A_short
    - B_short
    - normalized A_short
    - normalized B_short

    A frameless control area sits below them. The figure, axes and every
    widget are stored on ``plot_control`` so callbacks can reach them.

    Returns:
        ``(fig, lines)``. ``lines['all_lines']`` is
        ``(line_a_long, line_b_long, line_a_short, line_b_short, ax1)``, and
        'a_short', 'b_short', 'a_short_norm', 'b_short_norm' map to
        ``(line, ax)`` pairs, as unpacked by update_plot.
    """
    fig = plt.figure(figsize=(12, 12))  # Tall figure to fit the normalized plots
    plot_control.fig = fig

    # Five plot rows plus a shorter row reserved for controls
    grid = plt.GridSpec(6, 1, height_ratios=[3, 3, 3, 3, 3, 1.5])

    raw_limits = (plot_control.y_min, plot_control.y_max)
    norm_limits = (-1, 1)  # Default range for normalized data

    def dress(ax, title, ylabel, ylim):
        # Shared axis setup, applied in the same order for every plot.
        ax.set_xlim(0, plot_control.time_window)
        ax.set_ylim(*ylim)
        ax.set_title(title)
        ax.set_xlabel('Time (seconds)')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # Plot 1: every raw channel
    ax1 = fig.add_subplot(grid[0])
    line_a_long, = ax1.plot([], [], 'r-', label='A_long (850nm on long)')
    line_b_long, = ax1.plot([], [], 'g-', label='B_long (700nm on long)')
    line_a_short, = ax1.plot([], [], 'b-', label='A_short (850nm on short)')
    line_b_short, = ax1.plot([], [], 'm-', label='B_short (700nm on short)')
    dress(ax1, 'All Measurements', 'Measurement Value', raw_limits)

    # Plot 2: A_short alone
    ax2 = fig.add_subplot(grid[1])
    line_a_short_only, = ax2.plot([], [], 'b-', label='A_short (850nm on short)')
    dress(ax2, 'A_short Measurement (850nm LED effect on short detector)',
          'Measurement Value', raw_limits)

    # Plot 3: B_short alone
    ax3 = fig.add_subplot(grid[2])
    line_b_short_only, = ax3.plot([], [], 'm-', label='B_short (700nm on short)')
    dress(ax3, 'B_short Measurement (700nm LED effect on short detector)',
          'Measurement Value', raw_limits)

    # Plot 4: normalized A_short
    ax4 = fig.add_subplot(grid[3])
    line_a_short_norm, = ax4.plot([], [], 'b-', label='A_short Normalized')
    dress(ax4, 'Normalized A_short', 'Normalized Value', norm_limits)

    # Plot 5: normalized B_short
    ax5 = fig.add_subplot(grid[4])
    line_b_short_norm, = ax5.plot([], [], 'm-', label='B_short Normalized')
    dress(ax5, 'Normalized B_short', 'Normalized Value', norm_limits)

    plot_axes = [ax1, ax2, ax3, ax4, ax5]

    # Status box on every plot (texts[0]); update_plot rewrites it each frame
    for ax in plot_axes:
        ax.text(0.02, 0.95, 'Auto-scaling active', transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    # Data-rate box on the first plot only (ax1.texts[1])
    ax1.text(0.75, 0.95, 'Data rate: -- Hz', transform=ax1.transAxes,
             verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))

    plot_control.axes = plot_axes

    # Frameless placeholder occupying the control row of the grid
    control_ax = fig.add_subplot(grid[5])
    control_ax.set_frame_on(False)
    control_ax.set_xticks([])
    control_ax.set_yticks([])

    # Text boxes: (key, axes rect, label, initial value, submit callback)
    textbox_specs = (
        ('y_min', [0.15, 0.05, 0.1, 0.025], 'Y Min: ',
         plot_control.y_min, on_y_min_change),
        ('y_max', [0.15, 0.09, 0.1, 0.025], 'Y Max: ',
         plot_control.y_max, on_y_max_change),
        ('time_window', [0.35, 0.05, 0.1, 0.025], 'Time Window (s): ',
         plot_control.time_window, on_time_window_change),
        ('norm_window', [0.35, 0.09, 0.1, 0.025], 'Norm Window: ',
         plot_control.norm_window, on_norm_window_change),
        ('baseline_samples', [0.55, 0.09, 0.1, 0.025], 'Baseline Samples: ',
         plot_control.baseline_samples, on_baseline_samples_change),
    )
    for key, rect, caption, initial, callback in textbox_specs:
        box = TextBox(plt.axes(rect), caption, initial=str(initial))
        box.on_submit(callback)
        plot_control.text_boxes[key] = box

    # Buttons: (key, axes rect, label, click callback)
    button_specs = (
        ('auto_scale', [0.55, 0.05, 0.1, 0.025], 'Toggle Scale', toggle_auto_scale),
        ('toggle_view', [0.75, 0.05, 0.15, 0.025], 'View: Raw', toggle_raw_norm_view),
    )
    for key, rect, caption, callback in button_specs:
        button = Button(plt.axes(rect), caption)
        button.on_clicked(callback)
        plot_control.buttons[key] = button

    # Normalization method selector; index 0 matches the default 'None'
    norm_radio = RadioButtons(plt.axes([0.75, 0.08, 0.15, 0.14]),
                              list(NORM_METHODS.keys()), active=0)
    norm_radio.on_clicked(on_norm_method_change)
    plot_control.radio_buttons['norm_method'] = norm_radio

    # Keyboard help line along the bottom edge
    fig.text(0.5, 0.01,
             "Press 'q' to quit | 'n' to cycle normalization methods | 'r' to toggle raw/normalized view",
             ha='center', fontsize=9)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.18)  # Make room for controls

    return fig, {
        'all_lines': (line_a_long, line_b_long, line_a_short, line_b_short, ax1),
        'a_short': (line_a_short_only, ax2),
        'b_short': (line_b_short_only, ax3),
        'a_short_norm': (line_a_short_norm, ax4),
        'b_short_norm': (line_b_short_norm, ax5)
    }

def _padded_limits(values, min_padding, min_range):
    """Return ``(low, high)`` axis limits covering non-empty *values* with a margin.

    The margin on each side is 10% of the data range, but at least
    *min_padding*. If the padded span is still narrower than *min_range*, it
    is widened symmetrically around its centre.
    """
    low, high = min(values), max(values)
    padding = max((high - low) * 0.1, min_padding)
    low -= padding
    high += padding
    if high - low < min_range:
        center = (low + high) / 2
        low, high = center - min_range / 2, center + min_range / 2
    return low, high


def _set_textbox_quietly(box, text):
    """Show *text* in a TextBox without firing its change/submit observers."""
    previous = box.eventson
    box.eventson = False
    try:
        box.set_val(text)
    finally:
        box.eventson = previous


def update_plot(frame, lines_dict):
    """FuncAnimation callback: redraw all plots from a snapshot of the shared buffers.

    Args:
        frame: Frame index supplied by FuncAnimation (unused).
        lines_dict: The mapping returned by init_plot.

    Returns:
        The eight Line2D artists that are updated, in a fixed order.

    NOTE(review): ``plot_control.show_raw`` is not consulted here. The raw
    (ax2/ax3) and normalized (ax4/ax5) plots are always drawn together.
    """
    line_a_long, line_b_long, line_a_short, line_b_short, _ax1 = lines_dict['all_lines']
    line_a_short_only, _ax2 = lines_dict['a_short']
    line_b_short_only, _ax3 = lines_dict['b_short']
    line_a_short_norm, ax4 = lines_dict['a_short_norm']
    line_b_short_norm, ax5 = lines_dict['b_short_norm']
    updated_lines = (line_a_long, line_b_long, line_a_short, line_b_short,
                     line_a_short_only, line_b_short_only,
                     line_a_short_norm, line_b_short_norm)

    # Copy everything under the lock; the rest works on the snapshot.
    with data_lock:
        times = list(measurement_times)
        a_long = list(a_long_data)
        b_long = list(b_long_data)
        a_short = list(a_short_data)
        b_short = list(b_short_data)
        a_short_norm = list(a_short_norm_data)
        b_short_norm = list(b_short_norm_data)

    if not times:
        return updated_lines

    # Time axis starts at the oldest buffered sample
    min_time = min(times)
    times_normalized = [t - min_time for t in times]
    max_time = max(times_normalized)

    # Stored-sample rate, once there is enough data to be meaningful
    data_rate = 0
    if len(times) > 10:
        data_rate = len(times) / max_time if max_time > 0 else 0

    # Raw-data lines (plot 1 shows all channels, plots 2/3 the short ones)
    for line, values in ((line_a_long, a_long), (line_b_long, b_long),
                         (line_a_short, a_short), (line_b_short, b_short),
                         (line_a_short_only, a_short), (line_b_short_only, b_short)):
        line.set_data(times_normalized, values)

    # The normalized series restarts on a method change, so it may be shorter
    # than the raw one; align it with the newest timestamps.
    if a_short_norm and b_short_norm:
        line_a_short_norm.set_data(times_normalized[-len(a_short_norm):], a_short_norm)
        line_b_short_norm.set_data(times_normalized[-len(b_short_norm):], b_short_norm)
        ax4.set_title(f'Normalized A_short ({plot_control.norm_method})')
        ax5.set_title(f'Normalized B_short ({plot_control.norm_method})')

    # Auto-scale limits for raw data (at least 10 units padding, 50 units span)
    all_data = a_long + b_long + a_short + b_short
    if all_data:
        data_min, data_max = _padded_limits(all_data, 10, 50)
    else:
        data_min, data_max = DEFAULT_Y_MIN, DEFAULT_Y_MAX

    # Auto-scale limits for normalized data (at least 0.1 padding, 0.2 span)
    norm_data = a_short_norm + b_short_norm
    if norm_data:
        norm_min, norm_max = _padded_limits(norm_data, 0.1, 0.2)
    elif plot_control.norm_method == 'ZScore':
        norm_min, norm_max = -3, 3
    elif plot_control.norm_method == 'Baseline':
        norm_min, norm_max = -50, 50
    else:
        norm_min, norm_max = -1, 1

    # X-axis: fixed window from 0 until it fills, then slide with the data
    window_size = plot_control.time_window
    if max_time <= window_size:
        x_limits = (0, window_size)
    else:
        x_limits = (max_time - window_size, max_time)
    for ax in plot_control.axes:
        ax.set_xlim(*x_limits)

    # Y-axis for the raw-data plots
    if plot_control.manual_y_scale:
        y_low, y_high = plot_control.y_min, plot_control.y_max
        scale_label = f'Y-scale: {plot_control.y_min}-{plot_control.y_max} (MANUAL)'
        scale_box = dict(boxstyle='round', facecolor='lightcoral', alpha=0.6)
    else:
        y_low, y_high = data_min, data_max
        scale_label = f'Y-scale: {data_min:.1f}-{data_max:.1f} (AUTO)'
        scale_box = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

        # Mirror the auto limits into the Y-min/Y-max boxes. TextBox.set_val
        # fires submit observers, and on_y_min_change/on_y_max_change switch
        # to manual scaling, so events are suppressed here; otherwise
        # auto-scaling would stop after the first frame. The stored limits
        # are kept equal to the displayed text, so switching to manual
        # freezes the current view.
        boxes = plot_control.text_boxes
        if 'y_min' in boxes and 'y_max' in boxes:
            min_text = f"{data_min:.1f}"
            max_text = f"{data_max:.1f}"
            plot_control.y_min = float(min_text)
            plot_control.y_max = float(max_text)
            _set_textbox_quietly(boxes['y_min'], min_text)
            _set_textbox_quietly(boxes['y_max'], max_text)

    for idx, ax in enumerate(plot_control.axes[:3]):
        ax.set_ylim(y_low, y_high)
        if ax.texts:
            ax.texts[0].set_text(scale_label)
            ax.texts[0].set_bbox(dict(scale_box))
            # Data-rate box exists only on the first plot
            if idx == 0 and len(ax.texts) > 1:
                ax.texts[1].set_text(f'Data rate: {data_rate:.1f} Hz')

    # Normalized plots are always auto-scaled
    for ax in plot_control.axes[3:]:
        ax.set_ylim(norm_min, norm_max)
        if ax.texts:
            ax.texts[0].set_text(f'Norm method: {plot_control.norm_method}')
            ax.texts[0].set_bbox(dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

    return updated_lines

def _sync_text_box(name, value):
    """Mirror ``value`` into the TextBox widget ``name`` if that widget exists."""
    # Guarded like update_plot's text_box access so a missing widget is not a KeyError.
    text_boxes = getattr(plot_control, 'text_boxes', None) or {}
    if name in text_boxes:
        text_boxes[name].set_val(str(value))


def key_press_handler(event):
    """Handle keyboard shortcuts for the live plot window.

    Keys:
        q              close all figures and clear the global ``running`` flag
        m              toggle manual Y-scale mode
        a              switch back to automatic Y-scale
        r              toggle raw/normalized view
        n              cycle to the next method in ``NORM_METHODS``
        right / left   grow (x1.5, max 30 s) / shrink (/1.5, min 1 s) the time window
        up / down      raise / lower ``y_max`` by one step (enables manual mode)
        pageup / pagedown
                       raise / lower ``y_min`` by one step (enables manual mode)
        + / -          double (max 1000) / halve (min 1) the Y-scale step

    Steps that would make ``y_min >= y_max`` are refused, so the y-range
    can never collapse to zero height or invert.

    Args:
        event: matplotlib key press event; only ``event.key`` is read.
    """
    global running, plot_control
    key = event.key

    # Quit the application
    if key == 'q':
        plt.close('all')
        running = False

    # Toggle manual Y-scale mode
    elif key == 'm':
        plot_control.manual_y_scale = not plot_control.manual_y_scale
        print(f"Manual Y-scale mode: {'ON' if plot_control.manual_y_scale else 'OFF'}")

    # Toggle raw/normalized view
    elif key == 'r':
        toggle_raw_norm_view(None)

    # Cycle through normalization methods
    elif key == 'n':
        methods = list(NORM_METHODS.keys())
        # Unknown current method: start the cycle from the first entry.
        if plot_control.norm_method in methods:
            current_idx = methods.index(plot_control.norm_method)
        else:
            current_idx = -1
        next_idx = (current_idx + 1) % len(methods)
        on_norm_method_change(methods[next_idx])
        # NOTE(review): RadioButtons.set_active triggers the widget's on_clicked
        # callbacks; if on_norm_method_change is registered there it runs twice.
        radio_buttons = getattr(plot_control, 'radio_buttons', None) or {}
        if 'norm_method' in radio_buttons:
            radio_buttons['norm_method'].set_active(next_idx)

    # Change time window
    elif key == 'right':
        plot_control.time_window = min(30, plot_control.time_window * 1.5)  # max 30s
        _sync_text_box('time_window', plot_control.time_window)
        print(f"Time window increased to: {plot_control.time_window:.1f}s")

    elif key == 'left':
        plot_control.time_window = max(1, plot_control.time_window / 1.5)  # min 1s
        _sync_text_box('time_window', plot_control.time_window)
        print(f"Time window decreased to: {plot_control.time_window:.1f}s")

    # Reset to auto-scaling
    elif key == 'a':
        plot_control.manual_y_scale = False
        print("Auto Y-scale mode enabled")

    # Y-max adjustments
    elif key == 'up':
        plot_control.y_max += plot_control.y_scale_step
        _sync_text_box('y_max', plot_control.y_max)
        plot_control.manual_y_scale = True
        print(f"Y-max increased to: {plot_control.y_max}")

    elif key == 'down':
        # Keep y_max at least one step (as before) and strictly above y_min.
        new_max = max(plot_control.y_scale_step, plot_control.y_max - plot_control.y_scale_step)
        if new_max <= plot_control.y_min:
            print(f"Y-max cannot go below Y-min ({plot_control.y_min})")
        else:
            plot_control.y_max = new_max
            _sync_text_box('y_max', plot_control.y_max)
            plot_control.manual_y_scale = True
            print(f"Y-max decreased to: {plot_control.y_max}")

    # Y-min adjustments
    elif key == 'pageup':
        # Keep y_min strictly below y_max.
        new_min = plot_control.y_min + plot_control.y_scale_step
        if new_min >= plot_control.y_max:
            print(f"Y-min cannot go above Y-max ({plot_control.y_max})")
        else:
            plot_control.y_min = new_min
            _sync_text_box('y_min', plot_control.y_min)
            plot_control.manual_y_scale = True
            print(f"Y-min increased to: {plot_control.y_min}")

    elif key == 'pagedown':
        plot_control.y_min -= plot_control.y_scale_step
        _sync_text_box('y_min', plot_control.y_min)
        plot_control.manual_y_scale = True
        print(f"Y-min decreased to: {plot_control.y_min}")

    # Change scale step size
    elif key == '+':
        plot_control.y_scale_step = min(1000, plot_control.y_scale_step * 2)
        print(f"Y-scale step size increased to: {plot_control.y_scale_step}")

    elif key == '-':
        plot_control.y_scale_step = max(1, plot_control.y_scale_step // 2)
        print(f"Y-scale step size decreased to: {plot_control.y_scale_step}")

def main():
    """Run the UDP receiver thread and the live matplotlib visualization.

    Starts ``udp_server`` on a daemon thread, builds the figure with
    ``init_plot``, connects ``key_press_handler`` to key presses, and
    animates ``update_plot`` every 100 ms. When the window closes (or setup
    fails), the global ``running`` flag is cleared and the server thread is
    given up to one second to finish.
    """
    global running

    # Start UDP server in a separate daemon thread so it never blocks exit.
    server_thread = threading.Thread(target=udp_server, daemon=True)
    server_thread.start()

    try:
        fig, lines = init_plot()
        fig.canvas.mpl_connect('key_press_event', key_press_handler)

        # The animation object must stay referenced; matplotlib's docs warn
        # an unreferenced Animation is garbage collected and stops updating.
        # blit=False: update_plot changes axis limits and text artists each
        # frame but returns only the Line2D objects, so blitting would leave
        # tick labels and the scale/rate text boxes stale.
        ani = FuncAnimation(fig, update_plot, fargs=(lines,), interval=100,
                            blit=False, cache_frame_data=False)

        plt.show()
    finally:
        # Clean up even if plot setup or the GUI loop raised.
        running = False
        server_thread.join(timeout=1.0)
        print("Program terminated")

if __name__ == '__main__':
    # Launch the receiver and plot only when run as a script, not on import.
    main()