import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.interpolate import griddata
import pyvista as pv

# Load the star catalogue. Columns used below: 'ra' and 'dec' (passed through
# np.radians, so presumably degrees) and 'parallax' (assumed milliarcseconds,
# since 3262 / p is the distance in light years only for p in mas -- confirm
# against the CSV source).
df = pd.read_csv('24000stars.csv')

# Drop rows missing any input coordinate, then rows whose parallax is not
# strictly positive. A zero parallax has no finite distance. A negative one
# (possible in noisy astrometry) would give a negative distance, and the
# Cartesian conversion would then place the star at its mirror-image position
# through the origin.
df = df.dropna(subset=['ra', 'dec', 'parallax'])
df = df[df['parallax'] > 0]

# Convert parallax to distance (light years): d[pc] = 1000 / p[mas] and
# 1 pc ~= 3.262 ly, hence d[ly] ~= 3262 / p[mas]. assign() returns a new frame,
# avoiding pandas' chained-assignment warning on the filtered view.
df = df.assign(distance_ly=3262 / df['parallax'])

# Guard against overflow to inf for vanishingly small parallaxes, then
# renumber rows 0..n-1.
df = df[np.isfinite(df['distance_ly'])]
df = df.reset_index(drop=True)

# Project (RA, Dec, distance) onto Cartesian axes centred on Earth:
# +x points to RA = 0, Dec = 0; +z points to Dec = +90 deg; the units are
# those of distance_ly.
_dist = df['distance_ly']
ra_rad, dec_rad = np.radians(df['ra']), np.radians(df['dec'])
_cos_dec = np.cos(dec_rad)  # shared by the x and y projections

x = _dist * _cos_dec * np.cos(ra_rad)
y = _dist * _cos_dec * np.sin(ra_rad)
z = _dist * np.sin(dec_rad)

# Bin edges for the 3D density histogram: grid_size edges per axis spanning
# -130..+130 in the same units as x/y/z (light years). grid_size edges give
# grid_size - 1 bins per axis. Stars outside this cube are not counted.
grid_size = 80  # edges per axis; higher resolution gives smoother shells
x_range = np.linspace(-130, 130, grid_size)
y_range = np.linspace(-130, 130, grid_size)
z_range = np.linspace(-130, 130, grid_size)

# Count stars per cell. Because explicit edge arrays are passed as `bins`:
#   - H has shape (grid_size - 1,) * 3 and is indexed [ix, iy, iz];
#   - `edges` holds the three edge arrays (equal to x_range, y_range, z_range).
H, edges = np.histogramdd(
    np.column_stack([x, y, z]),
    bins=(x_range, y_range, z_range),
)

# Blur the raw per-cell star counts into a continuous density field so the
# isosurfaces come out as smooth shells. sigma is measured in bins.
_sigma_bins = 3
H_smooth = gaussian_filter(H, sigma=_sigma_bins)

# Diagnostics: these help pick sensible contour thresholds.
_dmin, _dmax, _dmean = H_smooth.min(), H_smooth.max(), H_smooth.mean()
print(f"Density stats: min={_dmin:.4f}, max={_dmax:.4f}, mean={_dmean:.4f}")
_n_positive = np.count_nonzero(H_smooth > 0)
print(f"Non-zero cells: {_n_positive}/{H_smooth.size}")

# Wrap the smoothed histogram in a PyVista uniform grid. Values are stored as
# POINT data, one point per histogram bin (H_smooth.shape points per axis), so
# that grid.contour() can interpolate isosurfaces through them.
grid = pv.ImageData()
grid.dimensions = H_smooth.shape  # one point per bin, not one per bin edge

# Bin widths along each axis, taken from the histogram's own edges.
_bin_widths = tuple(float(e[1] - e[0]) for e in edges)
grid.spacing = _bin_widths

# Put each point at its bin CENTRE (first edge + half a bin). Anchoring at the
# first edge instead would shift the whole density field by half a bin
# relative to the star coordinates and to the Earth marker at (0, 0, 0).
grid.origin = tuple(float(e[0]) + w / 2 for e, w in zip(edges, _bin_widths))

# VTK stores point data x-fastest. H_smooth is indexed [ix, iy, iz], so a
# Fortran-order flatten yields the matching layout.
grid.point_data['density'] = H_smooth.flatten(order='F')

# Scene renderer; the black backdrop makes the coloured shells stand out.
plotter = pv.Plotter()
plotter.set_background(color='black')

# Isosurface levels, in smoothed star counts per bin, each paired with its
# render style as (level, colour, opacity, legend label). Sparser levels are
# drawn more transparent. The levels are hand-tuned absolute values; compare
# them with the density stats printed above if the input data changes.
_density_levels = (0.01, 0.02, 0.05, 0.10, 0.20, 0.40)
_level_colors = ('blue', 'cyan', 'green', 'yellow', 'orange', 'red')
_level_opacities = (0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
_level_labels = ('Very sparse', 'Sparse', 'Medium', 'Dense', 'Very dense', 'Maximum')
thresholds = list(zip(_density_levels, _level_colors, _level_opacities, _level_labels))

print(f"Thresholds: {[level for level, *_ in thresholds]}")

# Extract one isosurface per density level and add it to the scene. Each level
# is best-effort: an empty contour is reported and skipped, and any error while
# contouring or adding the mesh is printed so the remaining levels still render.
surfaces_added = 0
for level, shell_color, shell_opacity, shell_label in thresholds:
    try:
        shell = grid.contour([level], scalars='density')
        if shell.n_points == 0:
            print(f"No surface at threshold {level:.3f}")
            continue

        plotter.add_mesh(
            shell,
            color=shell_color,
            opacity=shell_opacity,
            show_edges=False,
            smooth_shading=True,
            label=shell_label,
        )
        surfaces_added += 1
        print(f"Added surface at threshold {level:.3f}: {shell.n_points} points")
    except Exception as err:
        print(f"Failed at threshold {level:.3f}: {err}")

print(f"Total surfaces added: {surfaces_added}")

# Title and explanatory caption. The star count comes from the cleaned data
# frame, so it stays correct when the input CSV or the filtering changes.
plotter.add_text(
    f'3D Stellar Density Isosurfaces\n{len(df):,} Stars within 131 Light Years',
    position='upper_edge',
    font_size=12,
    color='white'
)

plotter.add_text(
    'Thin shells near Earth = Stellar Desert\nThick shells at distance = Dense regions',
    position='lower_left',
    font_size=9,
    color='white'
)

# show_axes() adds the small x/y/z orientation widget in the corner of the view.
plotter.show_axes()

# Mark Earth (the coordinate origin) with a small sphere.
origin_sphere = pv.Sphere(radius=5, center=(0, 0, 0))
plotter.add_mesh(origin_sphere, color='white', opacity=0.8, label='Earth (origin)')

# `label=` on add_mesh only registers a legend entry; the legend itself is
# drawn only when add_legend() is called, after every labelled mesh is added.
plotter.add_legend()

# Set camera
plotter.camera_position = 'iso'
plotter.camera.zoom(1.2)

# Still image - uncomment and run separately if needed
# plotter.show()  # Opens window, take manual screenshot, close window
# import sys; sys.exit()  # Stop before video generation

# Save a still image of the current scene to a file
# print("Saving still image...")
# plotter.screenshot('isosurface_still.png', window_size=[1920, 1080])
# print("Still image saved as isosurface_still.png")

# Create a rotating video: one orbit that zooms out from a close-up back to
# the starting view, then three faster orbits at that view.
print("Creating rotation + zoom-out video...")
plotter.open_movie('isosurface_rotation.mp4', framerate=30)

# Render once to initialize
plotter.render()

n_frames_zoom = 360        # frames in the first (zoom-out) orbit
frames_per_rotation = 180  # later orbits: half the frames per orbit = 2x faster
n_rotations_full = 3
n_frames_full = frames_per_rotation * n_rotations_full
start_zoom = 2.5           # extra magnification at the start of the first orbit

# generate_orbital_path returns ONE closed orbit with n_points camera positions.
# The phase-2 path therefore has frames_per_rotation points and is replayed
# n_rotations_full times, rather than one orbit stretched over all frames.
# Both paths share the same start point, so phase 2 continues seamlessly.
path_zoom = plotter.generate_orbital_path(n_points=n_frames_zoom, shift=0, viewup=[0, 0, 1], factor=2.5)
path_full = plotter.generate_orbital_path(n_points=frames_per_rotation, shift=0, viewup=[0, 0, 1], factor=2.5)

# Start zoomed in closer
plotter.camera.zoom(start_zoom)

# camera.zoom() is multiplicative, so the per-frame factor is the n-th root of
# the total change. After n_frames_zoom frames the cumulative zoom is
# 1 / start_zoom, exactly undoing the initial close-up.
zoom_per_frame = (1.0 / start_zoom) ** (1.0 / n_frames_zoom)

try:
    # Phase 1: Orbit with gradual zoom out
    for i in range(n_frames_zoom):
        plotter.camera.position = path_zoom.points[i]
        plotter.camera.zoom(zoom_per_frame)
        plotter.write_frame()

    print("Zoom-out complete, continuing with 3 full rotations...")

    # Phase 2: three more rotations at the restored zoom level (no zoom change)
    for i in range(n_frames_full):
        plotter.camera.position = path_full.points[i % frames_per_rotation]
        plotter.write_frame()
finally:
    # Close the plotter (and with it the movie writer) even if a frame fails,
    # so the file written so far is finalised.
    plotter.close()
print("Video saved as isosurface_rotation.mp4 (4 rotations total: 1 with zoom-out + 3 at full view)")

# To view the scene interactively instead of recording a video, call
# plotter.show() in place of the video section: the plotter is closed once
# the video has been written, so it cannot be shown from here.

# Summary. The reported grid shape is the one actually used: the histogram has
# one cell per bin (grid_size - 1 per axis), not one per edge.
print(f"Total stars: {len(df)}")
print(f"Grid resolution: {' x '.join(map(str, H_smooth.shape))} = {H_smooth.size} cells")
print(f"Density range: {H_smooth.min():.2f} to {H_smooth.max():.2f}")
print(f"Surfaces generated: {surfaces_added}")
