import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter
import pyvista as pv

# Load data
df = pd.read_csv('1-million-stars.csv')

# Remove rows with missing position or parallax before any arithmetic
df = df.dropna(subset=['ra', 'dec', 'parallax'])

# Keep only strictly positive parallaxes. A zero parallax would divide to
# infinity, and a negative one would produce a *negative* distance that
# silently passes the `<= MAX_DISTANCE_LY` cut below and lands on the wrong
# side of the sky. `.copy()` lets the new column be assigned without a
# chained-assignment warning.
df = df[df['parallax'] > 0].copy()

# Convert parallax to distance (light years).
# NOTE(review): the constant 3262 (~1000 * 3.2616 ly/pc) implies the parallax
# column is in milliarcseconds — confirm against the CSV's documentation.
df['distance_ly'] = 3262 / df['parallax']
df = df[np.isfinite(df['distance_ly'])]

# ── DISTANCE FILTER ──────────────────────────────────────────
MAX_DISTANCE_LY = 388  # Actual max from data: 388.33 ly
df = df[df['distance_ly'] <= MAX_DISTANCE_LY]
df = df.reset_index(drop=True)
print(f"Stars within {MAX_DISTANCE_LY} ly: {len(df):,}")

# Spherical (RA, Dec, distance) -> Cartesian (x, y, z), all in light years.
# RA/Dec are read as degrees and converted to radians first; z points along
# Dec = +90 and x along RA = 0 in the Dec = 0 plane.
ra_rad = np.deg2rad(df['ra'])
dec_rad = np.deg2rad(df['dec'])

radial_ly = df['distance_ly']
cos_dec = np.cos(dec_rad)

x = radial_ly * cos_dec * np.cos(ra_rad)
y = radial_ly * cos_dec * np.sin(ra_rad)
z = radial_ly * np.sin(dec_rad)

# ── 3D GRID ──────────────────────────────────────────────────
grid_size = 150
lim = MAX_DISTANCE_LY + 5  # small margin beyond the distance cap

# The three axes share identical bin edges: `grid_size` edges per axis,
# i.e. `grid_size - 1` histogram cells per axis.
x_range, y_range, z_range = (
    np.linspace(-lim, lim, grid_size) for _ in range(3)
)

# Count stars per cell (3D histogram over the Cartesian positions)
star_positions = np.column_stack([x, y, z])
H, edges = np.histogramdd(star_positions, bins=(x_range, y_range, z_range))

# Blur the raw counts into a smooth density field (sigma is in cells)
H_smooth = gaussian_filter(H, sigma=2.5)

print(f"Density stats: min={H_smooth.min():.4f}, max={H_smooth.max():.4f}, mean={H_smooth.mean():.4f}")
print(f"Non-zero cells: {np.sum(H_smooth > 0):,}/{H_smooth.size:,}")

# ── DIAGNOSTIC: RA/Dec of densest regions ────────────────────
# Prints the 20 highest-density cells of the smoothed field, densest first,
# with each cell centre expressed both in Cartesian ly and as RA/Dec/distance.
print("\n── Top 20 Densest Grid Cells ──")

# Cell centres are the midpoints of consecutive bin edges
x_centers = 0.5 * (x_range[:-1] + x_range[1:])
y_centers = 0.5 * (y_range[:-1] + y_range[1:])
z_centers = 0.5 * (z_range[:-1] + z_range[1:])

# Ascending sort -> take the last 20 -> reverse so rank 1 is the densest
flat_indices = np.argsort(H_smooth.ravel())[-20:][::-1]
ix, iy, iz = np.unravel_index(flat_indices, H_smooth.shape)

print(f"{'Rank':<5} {'Density':<10} {'X (ly)':<10} {'Y (ly)':<10} {'Z (ly)':<10} {'RA (deg)':<10} {'Dec (deg)':<10} {'Dist (ly)':<10}")
for rank, (i, j, k) in enumerate(zip(ix, iy, iz), start=1):
    cx, cy, cz = x_centers[i], y_centers[j], z_centers[k]
    density = H_smooth[i, j, k]

    # Back-convert the cell centre to spherical coordinates
    dist = np.sqrt(cx * cx + cy * cy + cz * cz)
    ra_deg = np.degrees(np.arctan2(cy, cx)) % 360  # wrap into [0, 360)
    if dist > 0:
        # clip guards arcsin against round-off pushing the ratio past ±1
        dec_deg = np.degrees(np.arcsin(np.clip(cz / dist, -1, 1)))
    else:
        dec_deg = 0

    print(f"{rank:<5} {density:<10.2f} {cx:<10.1f} {cy:<10.1f} {cz:<10.1f} {ra_deg:<10.1f} {dec_deg:<10.1f} {dist:<10.1f}")

print("\nLook up these RA/Dec coordinates to identify known stellar structures.")
print("Hint: Scorpius-Centaurus ~ RA 240, Dec -40 | Orion OB1 ~ RA 84, Dec 0")
print("      Local Bubble boundary ~ 200-300 ly  | Pleiades  ~ RA 56,  Dec +24\n")

# ── PYVISTA GRID ─────────────────────────────────────────────
# The histogram values are per *cell*, but they are attached to the grid as
# point data so that `contour` can interpolate them. Each grid point must
# therefore sit at a cell centre, not at the cell's lower edge; otherwise the
# whole density field is shifted by half a cell relative to the Earth marker
# and the distance rings.
grid = pv.ImageData()
grid.dimensions = H_smooth.shape

# One step per histogram cell: (number of edges - 1) cells span each axis.
_cell_size = tuple(
    float((axis.max() - axis.min()) / (len(axis) - 1))
    for axis in (x_range, y_range, z_range)
)
grid.spacing = _cell_size

# First point = centre of the first cell = lower edge + half a cell.
grid.origin = tuple(
    float(axis.min()) + 0.5 * step
    for axis, step in zip((x_range, y_range, z_range), _cell_size)
)

# Fortran order so that H_smooth[i, j, k] maps to point (i, j, k) with x
# varying fastest, as the original flatten(order='F') already did.
grid.point_data['density'] = H_smooth.flatten(order='F')

# ── DENSITY THRESHOLDS ───────────────────────────────────────
# Isosurface levels are expressed as fractions of the peak smoothed density,
# so they rescale automatically with the data.
d_max = H_smooth.max()

# (fraction of d_max, surface colour, surface opacity, label)
_iso_levels = (
    (0.005, 'darkblue', 0.15, 'Ultra sparse (0.5%)'),
    (0.01,  'blue',     0.20, 'Very sparse (1%)'),
    (0.02,  'cyan',     0.30, 'Sparse (2%)'),
    (0.05,  'green',    0.45, 'Medium (5%)'),
    (0.10,  'yellow',   0.55, 'Dense (10%)'),
    (0.20,  'orange',   0.70, 'Very dense (20%)'),
    (0.40,  'red',      0.85, 'Maximum (40%)'),
)

# Each entry: (absolute density level, colour, opacity, label)
thresholds = [
    (d_max * fraction, level_color, level_opacity, level_label)
    for fraction, level_color, level_opacity, level_label in _iso_levels
]

print(f"Auto-scaled thresholds (max density = {d_max:.2f}):")
for t in thresholds:
    print(f"  {t[3]}: {t[0]:.4f}")

# ── PLOTTER ──────────────────────────────────────────────────
plotter = pv.Plotter()
plotter.set_background('black')

# Draw one isosurface per threshold. A level that fails or yields no
# geometry is reported and skipped so the remaining levels still render.
surfaces_added = 0
for threshold, color, opacity, label in thresholds:
    try:
        isosurface = grid.contour([threshold], scalars='density')

        # Nothing crosses this level: report it and move on
        if isosurface.n_points <= 0:
            print(f"No surface at threshold {threshold:.4f}")
            continue

        plotter.add_mesh(
            isosurface,
            color=color,
            opacity=opacity,
            show_edges=False,
            smooth_shading=True,
            label=label
        )
        surfaces_added += 1
        print(f"Added surface at {threshold:.4f}: {isosurface.n_points:,} points")
    except Exception as e:
        # Best effort: one bad level must not abort the whole figure
        print(f"Failed at threshold {threshold:.4f}: {e}")

print(f"\nTotal surfaces added: {surfaces_added}")

# ── ANNOTATIONS ──────────────────────────────────────────────
# Title: the star count is taken from the filtered DataFrame so it always
# matches what is actually plotted, rather than a number hard-coded for one
# particular input file.
plotter.add_text(
    f'3D Stellar Density Isosurfaces\n{len(df):,} Stars within {MAX_DISTANCE_LY} Light Years (Actual Data Extent)',
    position='upper_edge',
    font_size=12,
    color='white'
)
# Reading guide for the nested shells
plotter.add_text(
    'Thin shells near Earth = Stellar Desert\nThick shells at distance = Dense regions',
    position='lower_left',
    font_size=9,
    color='white'
)

plotter.show_axes()

# Earth marker: a white sphere at the coordinate origin plus a text label.
origin_sphere = pv.Sphere(radius=8, center=(0, 0, 0))
plotter.add_mesh(origin_sphere, color='white', opacity=0.9)

# point_size=0 and shape_opacity=0 hide the label's own point glyph and
# background box, leaving only the text.
_earth_label_style = dict(
    font_size=12,
    text_color='white',
    point_color='white',
    point_size=0,
    always_visible=True,
    shape_opacity=0,
)
plotter.add_point_labels([[0, 0, 0]], ['Earth'], **_earth_label_style)


# ── DISTANCE SCALE RINGS ──────────────────────────────────────
# Flat circles in the XY plane at 100, 200 and 300 ly (those that fit inside
# the distance cap) plus one at the cap itself. Because z = d * sin(Dec) in
# the Cartesian conversion, the XY plane is the Dec = 0 plane (celestial
# equator), not the galactic plane.
ring_distances = [d for d in (100, 200, 300) if d < MAX_DISTANCE_LY]
ring_distances.append(MAX_DISTANCE_LY)

# Geometry shared by every ring, built once: a unit circle and the segment
# connectivity. endpoint=False avoids a duplicate point at 2*pi, which would
# otherwise create a zero-length segment where the circle closes.
_n_ring_points = 180
_ring_theta = np.linspace(0, 2 * np.pi, _n_ring_points, endpoint=False)
_unit_circle = np.column_stack([
    np.cos(_ring_theta),
    np.sin(_ring_theta),
    np.zeros_like(_ring_theta),
])
# VTK line cells: [2, start, end] per segment; roll(-1) joins last -> first
_ring_start = np.arange(_n_ring_points)
_ring_lines = np.column_stack([
    np.full(_n_ring_points, 2),
    _ring_start,
    np.roll(_ring_start, -1),
]).ravel()

for dist in ring_distances:
    ring = pv.PolyData(dist * _unit_circle)
    ring.lines = _ring_lines.copy()  # each mesh gets its own connectivity
    plotter.add_mesh(ring, color='gray', opacity=0.35, line_width=1)
    # Label at positive X edge of each ring
    plotter.add_point_labels(
        [[dist, 0, 0]],
        [f'{dist} ly'],
        font_size=9,
        text_color='lightgray',
        point_color='lightgray',
        point_size=0,
        always_visible=True,
        shape_opacity=0,
    )

# ── SCO-CEN LABEL ─────────────────────────────────────────────
# Direction used for the label: RA 269 deg, Dec -35 deg (noted by the author
# as the centroid seen in the diagnostic output, at roughly 350 ly).
# The label itself is placed at 410 ly along that direction so it sits just
# outside the 388 ly data boundary instead of inside the surfaces.
sco_cen_ra = np.deg2rad(269.0)
sco_cen_dec = np.deg2rad(-35.0)
sco_cen_dist = 410.0

# Same spherical -> Cartesian convention as the star positions
_sco_cen_cos_dec = np.cos(sco_cen_dec)
sco_cen_x = sco_cen_dist * _sco_cen_cos_dec * np.cos(sco_cen_ra)
sco_cen_y = sco_cen_dist * _sco_cen_cos_dec * np.sin(sco_cen_ra)
sco_cen_z = sco_cen_dist * np.sin(sco_cen_dec)

_sco_cen_label_style = dict(
    font_size=11,
    text_color='yellow',
    point_color='yellow',
    point_size=12,
    always_visible=True,
    shape_opacity=0,
)
plotter.add_point_labels(
    [[sco_cen_x, sco_cen_y, sco_cen_z]],
    ['Scorpius-Centaurus\nOB Association\n(~350-420 ly)'],
    **_sco_cen_label_style
)

# ── COLOR LEGEND ─────────────────────────────────────────────
# Explicit [label, colour] pairs handed to the legend, listed from the
# sparsest to the densest isosurface level.
legend_entries = [
    [entry_text, entry_color]
    for entry_text, entry_color in (
        ("Ultra sparse (0.5%)", "darkblue"),
        ("Very sparse (1%)", "blue"),
        ("Sparse (2%)", "cyan"),
        ("Medium density (5%)", "green"),
        ("Dense (10%)", "yellow"),
        ("Very dense (20%)", "orange"),
        ("Maximum (40%)", "red"),
    )
]

_legend_style = dict(
    bcolor=(0.1, 0.1, 0.1),  # dark panel behind the entries
    border=True,
    loc="center left",
    size=(0.18, 0.22),
)
plotter.add_legend(legend_entries, **_legend_style)

# ── INTERACTIVE WINDOW ───────────────────────────────────────
# Mouse / keyboard controls in the render window:
#   Left-click + drag   → rotate
#   Right-click + drag  → zoom
#   Middle-click + drag → pan
#   Scroll wheel        → zoom in/out
#   R                   → reset camera
#   Q or close window   → quit
plotter.camera_position = 'iso'  # start from an isometric viewpoint
plotter.camera.zoom(1.2)
print("\nOpening interactive window — rotate and zoom freely...")
plotter.show()

# Run summary, printed after the call to show() returns
_summary_lines = [
    f"\nSummary:",
    f"  Total stars loaded:  {len(df):,}",
    f"  Distance cap:        {MAX_DISTANCE_LY} ly",
    f"  Grid resolution:     {grid_size}^3 = {grid_size**3:,} cells",
    f"  Density range:       {H_smooth.min():.2f} to {H_smooth.max():.2f}",
    f"  Surfaces generated:  {surfaces_added}",
]
for _line in _summary_lines:
    print(_line)
