import math, time, random
from collections import deque

import pydafruit_gfx as pyfx

## OLED
## --------------------------------------------------------------
# Logical display geometry in pixels, plus the only two colours this demo uses.
# NOTE(review): 0x0000 / 0xFFFF look like 16-bit colour words — confirm the
# colour format expected by pydafruit_gfx.
W = 128
H = 64
BLACK, WHITE = 0x0000, 0xFFFF

## SDL2
## --------------------------------------------------------------
SCALE = 6        # passed to GFXDisplay(scale=...); presumably window magnification
FPS = 30         # target frame rate; drives SPF below
SPF = 1 / FPS    # nominal seconds per frame, used for frame pacing

# Create and open the emulator window up front; every scene draws through `disp`.
disp = pyfx.GFXDisplay(W, H, scale=SCALE)
if not disp.begin("pydafruit_gfx — Retro OLED"):
    raise RuntimeError("Failed to open SDL2 window")

def transition_wipe():
    """Play a short circular "wipe to black" over the current frame, then clear it.

    Black circle outlines of growing radius, centred on the bottom edge of the
    screen, are drawn over whatever is already displayed. One set is drawn per
    refresh, with a 10 ms pause between refreshes. The screen is then filled
    with black.

    Visual idea borrowed from:
    https://hackaday.io/project/203611-ssid-silly-space-invaders-dashboard

    Returns:
        True if the transition ran to completion. False if the window closed,
        or if ``disp.handle_events()`` returned a falsy value (the same stop
        contract ``run_scene`` uses).
    """
    cx = W // 2  # loop-invariant horizontal centre of the wipe
    for i in range(0, 80, 2):
        if not disp.is_open():
            return False
        # Honour the event pump's stop request instead of discarding it,
        # so a quit during the wipe behaves like a quit during a scene.
        if not disp.handle_events():
            return False
        # Rings grow at 1x (two copies, 1 px apart), 2x and 3x the base rate.
        # Presumably this fills the area quickly without gaps between outlines.
        disp.draw_circle(cx,     H,     i,     BLACK)
        disp.draw_circle(cx - 1, H,     i,     BLACK)
        disp.draw_circle(cx,     H - 1, i * 2, BLACK)
        disp.draw_circle(cx,     H - 2, i * 3, BLACK)
        disp.flush()
        time.sleep(0.01)
    disp.fill_screen(BLACK)
    disp.flush()
    return True

def run_scene(draw_fn, duration):
    """Drive ``draw_fn`` at roughly FPS frames per second for ``duration`` seconds.

    ``draw_fn(t, frame)`` receives the seconds elapsed since the scene started
    and a 0-based frame counter, and renders one frame into ``disp``. After each
    frame the display is flushed. The loop then sleeps for whatever remains of
    the SPF budget, measured from just before event handling.

    Returns:
        True once the scene has played for ``duration`` seconds. False if the
        window closed or ``disp.handle_events()`` returned a falsy value.
    """
    start = time.monotonic()
    frame_no = 0
    while True:
        if not disp.is_open():
            return False

        elapsed = time.monotonic() - start
        if elapsed >= duration:
            return True

        frame_start = time.monotonic()
        if not disp.handle_events():
            return False

        draw_fn(elapsed, frame_no)
        disp.flush()

        # Frame pacing: only sleep if this frame finished early.
        remaining = SPF - (time.monotonic() - frame_start)
        if remaining > 0:
            time.sleep(remaining)
        frame_no += 1


## Coin Spin 
## --------------------------------------------------------------
def scene_insert_coin():
    """Arcade-style "INSERT COIN" attract screen, followed by a wipe transition.

    The screen shows:
    - a two-line centred title;
    - a small "coin slot" bar that widens from 0 to 7 px, repeating every 8 frames;
    - a "press Q to skip" hint;
    - after frame 60 (~2 s at FPS=30), a "-- INSERT COIN --" prompt that
      blinks 30 frames on, 30 frames off.

    The scene runs for 4 s.

    Returns:
        True if both the scene and the closing ``transition_wipe`` completed.
        False if either one was interrupted.
    """
    def print_centered(text, y):
        # Print `text` horizontally centred, cursor at row y, using the
        # currently selected font and colour.
        _, _, tw, _ = disp.get_text_bounds(text, 0, 0)
        disp.set_cursor((W - tw) // 2, y)
        disp.print(text)

    def draw(t, frame):
        disp.fill_screen(BLACK)
        disp.set_text_color(WHITE)  # every element below is white

        disp.set_font("FreeSansBold9")
        disp.set_text_size(1)
        print_centered("INSERT", 15)
        print_centered("COIN", 32)

        disp.set_font("TomThumb")
        print_centered("press Q to skip", 53)

        # Coin slot: width cycles through 0..7 px (nothing is drawn at width 0).
        slot_w = frame % 8
        if slot_w > 0:
            disp.fill_round_rect(W // 2 - slot_w // 2, 36, slot_w, 10, 2, WHITE)

        # Blinking prompt, still in TomThumb.
        if frame > 60 and (frame // 30) % 2 == 0:
            print_centered("-- INSERT COIN --", 62)

    if not run_scene(draw, 4.0):
        return False
    # Propagate an interruption during the wipe instead of discarding it.
    return transition_wipe()


## Intro
## --------------------------------------------------------------
def scene_boot():
    """Boot splash whose elements are revealed in timed stages; lasts 3.5 s.

    Stages:
    - t > 0.4 s: a double border;
    - t > 0.8 s: the "PYDAFRUIT" / "GFX" title;
    - t > 1.4 s: a subtitle;
    - t > 2.0 s: a "< DEMO >" tag that toggles every 1/3 s.

    Returns the result of ``run_scene``.
    """
    def text_block(font, *items):
        # Select `font` in white once, then print each (x, y, text) item.
        disp.set_font(font)
        disp.set_text_color(WHITE)
        for x, y, text in items:
            disp.set_cursor(x, y)
            disp.print(text)

    def draw(t, frame):
        disp.fill_screen(BLACK)

        show_border = t > 0.4
        show_title = t > 0.8
        show_subtitle = t > 1.4
        show_tag = t > 2.0 and int(t * 3) % 2 == 0

        if show_border:
            # Outer rectangle, then one inset by 2 px on every side.
            for inset in (0, 2):
                disp.draw_rect(inset, inset, W - 2 * inset, H - 2 * inset, WHITE)

        if show_title:
            text_block("FreeSansBold9", (6, 19, "PYDAFRUIT"), (30, 35, "GFX"))

        if show_subtitle:
            text_block("TomThumb", (8, 48, "SDL2+pybind11"))

        if show_tag:
            text_block("TomThumb", (8, 58, "< DEMO >"))

    return run_scene(draw, 3.5)

## star field
## --------------------------------------------------------------
def scene_starfield():
    """Warp-speed starfield that accelerates over its first 1.5 s; runs 5 s.

    The 70 stars are generated once from a fixed seed, so the pattern is the
    same on every run. Each star has:
    - a direction (ox, oy), each component in [-1, 1];
    - a speed factor in [0.3, 1.0].

    A star's normalised distance from the centre grows with time and wraps
    every 2.0 units. Stars more than half-way out are drawn as short streaks
    pointing back toward the centre; the rest are drawn as single pixels.

    Returns the result of ``run_scene``.
    """
    # Private generator: reproducible layout without reseeding (and thereby
    # disturbing) the process-wide `random` state. Random(7) yields the same
    # sequence as random.seed(7), so the star layout is unchanged.
    rng = random.Random(7)
    N = 70
    STREAK = 0.18  # streak length, as a fraction of the normalised travel
    stars = [(rng.uniform(-1, 1), rng.uniform(-1, 1),
              rng.uniform(0.3, 1.0)) for _ in range(N)]
    cx, cy = W // 2, H // 2

    def draw(t, frame):
        disp.fill_screen(BLACK)
        # Speed multiplier ramps linearly from 0.4 to 4.4 over 1.5 s, then holds.
        speed = min(t / 1.5, 1.0) * 4.0 + 0.4

        for ox, oy, spd in stars:
            z = (t * spd * speed) % 2.0
            sc = z / 2.0  # normalised distance from the centre, in [0, 1)
            x = int(cx + ox * sc * W)
            y = int(cy + oy * sc * H)
            if not (0 <= x < W and 0 <= y < H):
                continue
            if sc > 0.5:
                tail = max(0, sc - STREAK)
                px = int(cx + ox * tail * W)
                py = int(cy + oy * tail * H)
                disp.draw_line(px, py, x, y, WHITE)
            else:
                disp.draw_pixel(x, y, WHITE)

        disp.set_font("TomThumb")
        disp.set_text_color(WHITE)
        disp.set_cursor(2, H - 1)
        disp.print("STARFIELD")

    return run_scene(draw, 5.0)

## ball bounce
## --------------------------------------------------------------
def scene_bounce():
    """Bouncing ball with a dotted trail and a bounce counter; runs 6 s.

    The ball is a filled circle of radius R with an outer ring at R+2. It moves
    at a constant velocity and reflects off the screen edges. Every frame in
    which at least one wall is hit adds 1 to the on-screen SCORE.

    Returns the result of ``run_scene``.
    """
    R     = 5
    bx    = float(W // 2)
    by    = float(H // 2)
    vx    = 38.0   # px/s
    vy    = 27.0   # px/s
    score = 0
    TRAIL = 14
    # Recent ball centres, oldest first; maxlen discards the oldest automatically.
    trail = deque(maxlen=TRAIL)
    # Integrate with real elapsed time (like the t-driven scenes) rather than
    # the nominal SPF, so the ball does not slow down when frames run late.
    # The step is capped so a long stall (e.g. a window drag) can't teleport it.
    MAX_DT = 4 * SPF
    last_t = 0.0

    def draw(t, frame):
        nonlocal bx, by, vx, vy, score, last_t

        dt = min(max(t - last_t, 0.0), MAX_DT)
        last_t = t

        bx += vx * dt
        by += vy * dt

        # Reflect off each wall and clamp the ball back fully on-screen.
        bounced = False
        if bx - R < 0:
            bx, vx, bounced = float(R), abs(vx), True
        if bx + R >= W:
            bx, vx, bounced = float(W - R - 1), -abs(vx), True
        if by - R < 0:
            by, vy, bounced = float(R), abs(vy), True
        if by + R >= H:
            by, vy, bounced = float(H - R - 1), -abs(vy), True
        if bounced:
            score += 1  # one point per frame with a hit, even for corner hits

        trail.append((int(bx), int(by)))

        disp.fill_screen(BLACK)
        disp.draw_rect(0, 0, W, H, WHITE)

        # Every other stored position, oldest first, skipping the newest
        # (which sits under the ball itself).
        for i in range(0, len(trail) - 1, 2):
            tx, ty = trail[i]
            disp.draw_pixel(tx, ty, WHITE)

        disp.fill_circle(int(bx), int(by), R, WHITE)
        disp.draw_circle(int(bx), int(by), R + 2, WHITE)

        disp.set_font("TomThumb")
        disp.set_text_color(WHITE)
        disp.set_cursor(3, 7)
        disp.print(f"SCORE {score:03d}")

    return run_scene(draw, 6.0)

## sprite race
## --------------------------------------------------------------
# Size in pixels shared by all three sprite bitmaps below.
BMP_W = BMP_H = 32

# 1-bit sprite bitmaps for the sprite-race scene. Each is 32x32 px packed into
# 128 bytes, i.e. 32 rows x 4 bytes per row.
# NOTE(review): presumably row-major with the most significant bit as the
# leftmost pixel of each byte, as Adafruit_GFX drawBitmap expects — confirm
# against pydafruit_gfx's draw_bitmap.

# Python logo (going by the name).
bitmap_python = bytes([
  0x00,0x1F,0xF0,0x00,0x00,0x7F,0xFC,0x00,0x00,0x7F,0xFE,0x00,0x00,0xC7,0xFF,0x00,
  0x00,0xC7,0xFF,0x00,0x00,0xFF,0xFF,0x00,0x00,0xFF,0xFF,0x00,0x00,0x00,0xFF,0x00,
  0x0F,0xFF,0xFF,0x78,0x3F,0xFF,0xFF,0x7C,0x7F,0xFF,0xFF,0x7E,0x7F,0xFF,0xFF,0x7E,
  0xFF,0xFF,0xFF,0x7F,0xFF,0xFF,0xFE,0x7F,0xFF,0xFF,0xFC,0xFF,0xFF,0xF0,0x01,0xFF,
  0xFF,0x80,0x0F,0xFF,0xFF,0x3F,0xFF,0xFF,0xFE,0x7F,0xFF,0xFF,0xFE,0xFF,0xFF,0xFF,
  0x7E,0xFF,0xFF,0xFE,0x7E,0xFF,0xFF,0xFE,0x3E,0xFF,0xFF,0xFC,0x1E,0xFF,0xFF,0xF0,
  0x00,0xFF,0x00,0x00,0x00,0xFF,0xFF,0x00,0x00,0xFF,0xFF,0x00,0x00,0xFF,0xE3,0x00,
  0x00,0xFF,0xE3,0x00,0x00,0x7F,0xFE,0x00,0x00,0x3F,0xFE,0x00,0x00,0x0F,0xF8,0x00,
])
# Adafruit logo (going by the name).
bitmap_adafruit = bytes([
  0x00,0x00,0x60,0x00,0x00,0x00,0xE0,0x00,0x00,0x01,0xE0,0x00,0x00,0x01,0xF0,0x00,
  0x00,0x03,0xF0,0x00,0x00,0x07,0xF0,0x00,0x00,0x07,0xF8,0x00,0x00,0x0F,0xF8,0x00,
  0x7F,0x0F,0xF8,0x00,0xFF,0xEF,0xF8,0x00,0xFF,0xFF,0xF8,0x00,0x7F,0xFE,0x7F,0xC0,
  0x3F,0xFE,0x7F,0xF8,0x1F,0xFE,0x7F,0xFF,0x1F,0xC6,0xFF,0xFF,0x0F,0xE3,0xC7,0xFE,
  0x07,0xFF,0x87,0xFC,0x01,0xFF,0xFF,0xF0,0x01,0xF3,0x7F,0xE0,0x03,0xE3,0x3F,0x80,
  0x07,0xE7,0x3C,0x00,0x07,0xFF,0xBE,0x00,0x07,0xFF,0xFE,0x00,0x0F,0xFF,0xFE,0x00,
  0x0F,0xFF,0xFF,0x00,0x0F,0xF9,0xFF,0x00,0x1F,0xF1,0xFF,0x00,0x1F,0x80,0xFF,0x00,
  0x1C,0x00,0x7F,0x00,0x00,0x00,0x1F,0x00,0x00,0x00,0x0F,0x00,0x00,0x00,0x06,0x00,
])
# Space-invader sprite (going by the name).
bitmap_invader = bytes([
  0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0xE0,0x07,0xC0,
  0x03,0xE0,0x07,0xC0,0x03,0xE0,0x07,0xC0,0x03,0xFC,0x3F,0xC0,0x03,0xFC,0x3F,0xC0,
  0x00,0x7C,0x3E,0x00,0x00,0x7C,0x3E,0x00,0x03,0xFF,0xFF,0xC0,0x03,0xFF,0xFF,0xC0,
  0x03,0xFF,0xFF,0xC0,0x1F,0xFF,0xFF,0xF8,0x1F,0xFF,0xFF,0xF8,0x1F,0x83,0xC1,0xF8,
  0xFF,0x83,0xC1,0xFF,0xFF,0x83,0xC1,0xFF,0xFF,0x83,0xC1,0xFF,0xFF,0x83,0xC1,0xFF,
  0xFF,0xFF,0xFF,0xFF,0xFB,0xFF,0xFF,0xDF,0xFB,0xFF,0xFF,0xDF,0xFB,0xFF,0xFF,0xDF,
  0xFB,0xE0,0x07,0xDF,0xFB,0xFE,0x7F,0xDF,0xFB,0xFE,0x7F,0xDF,0x00,0x7E,0x7E,0x00,
  0x00,0x7E,0x7E,0x00,0x00,0x7E,0x7E,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
])

def scene_sprites():
    """Race three 32x32 bitmap sprites left-to-right across the screen; runs 7 s.

    Each sprite has its own lane (top y in LANE_Y) and its own speed in px/s.
    On the 64 px tall screen, consecutive 32 px lanes overlap by half a sprite.
    When a sprite passes W + 4 it wraps to just off the left edge, and its lap
    counter (drawn as "xN" near the right edge of its lane) increments.

    Returns the result of ``run_scene``.
    """
    LANE_Y = [0, H//2 - BMP_H//2, H - BMP_H]
    SPEEDS = [20.0, 34.0, 19.0]  # px/s
    BMPS   = [bitmap_python, bitmap_adafruit, bitmap_invader]

    # Staggered starting x positions, all fully off-screen to the left.
    xs    = [float(-BMP_W - 20), float(-BMP_W - 60), float(-BMP_W - 40)]
    laps  = [0, 0, 0]
    # Advance by real elapsed time (like the t-driven scenes) instead of the
    # nominal SPF, so sprites keep their px/s speed when frames run late.
    # The per-frame step is capped so a long stall can't skip ahead wildly.
    MAX_DT = 4 * SPF
    last_t = 0.0

    def draw(t, frame):
        nonlocal last_t
        dt = min(max(t - last_t, 0.0), MAX_DT)
        last_t = t

        disp.fill_screen(BLACK)

        # Dotted guide rows and a solid finish line at x = W-2.
        for x in range(0, W, 6):
            disp.draw_pixel(x, LANE_Y[1] - 1, WHITE)
            disp.draw_pixel(x, LANE_Y[2] - 1, WHITE)
            disp.draw_pixel(x, H - BMP_H//2, WHITE)
        disp.draw_fast_vline(W-2, 0, H, WHITE)

        # All text in this scene shares one font and colour; set them once.
        disp.set_font("TomThumb")
        disp.set_text_color(WHITE)

        for i in range(3):
            xs[i] += SPEEDS[i] * dt
            if xs[i] > W + 4:
                xs[i] = float(-BMP_W)
                laps[i] += 1

            ix = int(xs[i])
            iy = LANE_Y[i]
            # Only draw when some part of the sprite is on-screen.
            if -BMP_W < ix < W:
                disp.draw_bitmap(ix, iy, BMPS[i], BMP_W, BMP_H, WHITE)

            disp.set_cursor(W - 18, iy + BMP_H - 2)
            disp.print(f"x{laps[i]}")

        disp.set_cursor(2, 5)
        disp.print("SPRITE RACE")

    return run_scene(draw, 7.0)

## star wars credits
## --------------------------------------------------------------
def scene_credits():
    """Scroll "Star Wars"-style credits upward, each line centred in its own font.

    Lines are LINE_H px apart and scroll at SCROLL_SPEED px/s, starting below
    the bottom edge. The scene lasts until the last line reaches the top edge,
    plus 1 s. The top and bottom 4 px rows are blanked after drawing.

    Returns the result of ``run_scene``.
    """
    TITLE, BODY, SMALL = "FreeSansBold9", "Org01", "TomThumb"
    lines = [
        ("PYDAFRUIT",           TITLE),
        ("",                    BODY),
        ("GFX",                 TITLE),
        ("",                    BODY),
        ("SDL2 backend",        BODY),
        ("pybind11 bindings",   BODY),
        ("Adafruit_GFX port",   BODY),
        ("",                    BODY),
        ("draw_pixel",          BODY),
        ("draw_line",           BODY),
        ("fill_circle",         BODY),
        ("draw_bitmap",         BODY),
        ("set_font",            BODY),
        ("get_text_bounds",     BODY),
        ("",                    SMALL),
        ("github:pydafruitGFX", SMALL),
        ("MIT License",         SMALL),
        ("",                    SMALL),
        ("* FIN *",             TITLE),
        ("thanks!",             SMALL),
    ]

    SCROLL_SPEED = 18.0  # px/s
    LINE_H       = 12    # px between consecutive lines

    def draw(t, frame):
        disp.fill_screen(BLACK)
        # Cursor row of line 0: starts one line-height below the bottom edge.
        origin = H + LINE_H - int(t * SCROLL_SPEED)

        for idx, (text, font) in enumerate(lines):
            y = origin + idx * LINE_H
            # Only render lines within one line-height of the visible area.
            if -LINE_H <= y <= H + LINE_H:
                disp.set_font(font)
                _, _, width, _ = disp.get_text_bounds(text, 0, y)
                disp.set_text_color(WHITE)
                disp.set_cursor((W - width) // 2, y)
                disp.print(text)

        disp.fill_rect(0, 0,   W, 4, BLACK)
        disp.fill_rect(0, H-4, W, 4, BLACK)

    return run_scene(draw, (len(lines) * LINE_H + H) / SCROLL_SPEED + 1.0)

## main loop
## --------------------------------------------------------------
# Playlist, played in order and repeated until a scene reports an abort
# (returns False) or the window is closed before a scene starts.
scenes = [
    scene_insert_coin,
    scene_boot,
    scene_starfield,
    scene_bounce,
    scene_sprites,
    scene_credits,
]

running = True
while running:
    for scene in scenes:
        # Short-circuit: a closed window stops us before the scene is started.
        running = disp.is_open() and scene()
        if not running:
            break