import cv2
import time
import torch
import serial
import torchvision.models as models
from PIL import Image
from torchvision import transforms

# Index handed to cv2.VideoCapture to select the camera.
CAM_INDEX = 1
# Minimum softmax confidence required before a command is sent to the ESP32.
CONF_THRESHOLD = 0.75
# Largest frame-difference contour area (as reported by cv2.contourArea) above
# which a new object is considered detected.
CONTOUR_AREA_THRESHOLD = 5000
# Once the largest contour area drops below this, the detector re-arms for the
# next object.
OBJECT_CLEAR_THRESHOLD = 2000

# The checkpoint is loaded onto, and inference runs on, the CPU.
device = torch.device("cpu")

# Serial link to the ESP32. Stays None when the port cannot be opened, in which
# case send_serial() only logs the command instead of transmitting it.
esp32 = None

try:
    esp32 = serial.Serial(
        "/dev/ttyUSB0",
        115200,
        timeout=1
    )

    # NOTE(review): presumably gives the board time to reboot after the port
    # is opened - confirm against the ESP32 firmware/boot behaviour.
    time.sleep(2)
    print("ESP32 connected")

except Exception as e:
    # A missing ESP32 is not fatal: the script keeps running without serial.
    print("SERIAL ERROR:", e)

def send_serial(cmd, retries=3, ack_timeout=3.0):
    """Send a newline-terminated command to the ESP32 and wait for its ACK.

    Each attempt clears the serial input buffer, writes ``cmd`` followed by a
    newline, then polls the port until a line containing ``"Command Valid"``
    arrives or ``ack_timeout`` seconds have elapsed.

    Args:
        cmd: Command text to send, without the trailing newline.
        retries: Maximum number of send attempts.
        ack_timeout: Seconds to wait for the acknowledgement on each attempt.

    Returns:
        True if an acknowledgement was received; False if the serial link is
        unavailable or every attempt failed or timed out.
    """
    if esp32 is None:
        print("SERIAL OFF:", cmd)
        return False

    for attempt in range(retries):
        try:
            # Drop stale input so an old reply cannot be mistaken for this ACK.
            esp32.reset_input_buffer()
            esp32.write((cmd + "\n").encode())
            print(f"SENT: {cmd} (attempt {attempt + 1})")

            # Use the monotonic clock so that system clock adjustments can
            # neither cut the wait short nor extend it.
            deadline = time.monotonic() + ack_timeout
            while time.monotonic() < deadline:
                if esp32.in_waiting:
                    response = esp32.readline().decode("utf-8", errors="replace").strip()
                    print(f"ESP32 >> {response}")
                    if "Command Valid" in response:
                        return True
                time.sleep(0.05)

            print("No ACK, retrying...")

        except Exception as e:
            # Deliberately broad: a failing serial link must not take down the
            # caller; the failure is reported and the next attempt proceeds.
            print(f"SERIAL FAIL: {e}")

    print("Failed after all retries")
    return False


# Load the MobileNetV2 classifier weights and label list from the checkpoint.
print("Loading MobileNetV2...")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the checkpoint contents allow it.
ckpt = torch.load("original.pth", map_location=device)

# The label list is stored in the checkpoint; the predicted class index is
# later used to look up a name in class_names.
class_names = ckpt["class_names"]
NUM_CLASSES = len(class_names)

# Build the architecture without pretrained weights; the weights are taken
# from the checkpoint's state dict below.
model = models.mobilenet_v2(weights=None)

in_features = model.classifier[1].in_features

# Replace the final classifier layer so its output size equals NUM_CLASSES,
# which must be done before load_state_dict so the layer shapes line up.
model.classifier[1] = torch.nn.Linear(
    in_features,
    NUM_CLASSES
)

model.load_state_dict(ckpt["model_state_dict"])

model.to(device)
# Evaluation mode: the model is only used for inference in this script.
model.eval()

# Preprocessing applied to each PIL image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalise per channel.
# NOTE(review): the mean/std values match the ImageNet statistics listed in
# the torchvision docs - confirm they match what was used in training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])


def map_waste(label):
    """Translate a classifier label into the bin category sent to the ESP32.

    Returns one of "ORGANIC", "PAPER", "HAZARD" or "ANORGANIC"; any label
    that is not listed maps to "UNKNOWN".
    """
    category_members = (
        ("ORGANIC", ("Vegetation", "Food Organics")),
        ("PAPER", ("Paper", "Cardboard")),
        ("HAZARD", ("Glass", "Metal")),
        ("ANORGANIC", ("Plastic", "Textile Trash", "Miscellaneous Trash")),
    )

    # First category whose member list contains the label wins.
    for category, members in category_members:
        if label in members:
            return category

    return "UNKNOWN"


# Open the camera selected by CAM_INDEX; the script cannot run without it.
cap = cv2.VideoCapture(CAM_INDEX)

if not cap.isOpened():
    raise RuntimeError("Camera not found")

ret, prev_frame = cap.read()

if not ret:
    # Release the device before bailing out so it is not left held open.
    cap.release()
    raise RuntimeError("Cannot read camera")

# Baseline for frame differencing. It is blurred with the same kernel the main
# loop applies to every frame; comparing a raw baseline against a blurred
# frame would make the very first diff light up along image edges and could
# trigger a false detection at startup.
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
prev_gray = cv2.GaussianBlur(prev_gray, (5, 5), 0)

# Seconds during which new detections are ignored after a command has been
# acknowledged by the ESP32.
COOLDOWN_AFTER_SEND = 8.0

# True from the moment an object triggers inference until the scene has
# settled below OBJECT_CLEAR_THRESHOLD again.
object_detected = False
# time.time() value before which the main loop only displays frames.
cooldown_until  = 0.0

print("Monitoring contour...")

# Seconds to wait after motion is first seen before the frame used for
# classification is captured.
SETTLE_DELAY = 3.0

# Main loop: watch for motion between consecutive frames, classify the object
# that caused it, and send the resulting category to the ESP32.
# Press "q" in a preview window or hit Ctrl+C to stop.
try:
    while True:

        ret, frame = cap.read()
        if not ret:
            # Dropped frame: try again rather than aborting.
            continue

        # Motion detection: threshold the difference between the blurred
        # grayscale version of this frame and the previous one.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (5, 5), 0)

        diff = cv2.absdiff(prev_gray, gray)

        _, thresh = cv2.threshold(
            diff, 25, 255, cv2.THRESH_BINARY
        )

        contours, _ = cv2.findContours(
            thresh,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        # Area of the largest changed region; 0 when nothing changed.
        max_area = max(
            (cv2.contourArea(cnt) for cnt in contours),
            default=0
        )

        now = time.time()
        in_cooldown = now < cooldown_until

        if in_cooldown:
            # While cooling down, only show the countdown and keep the
            # baseline current; detection and inference are skipped.
            remaining = cooldown_until - now
            cv2.putText(
                frame,
                f"COOLDOWN {remaining:.1f}s",
                (20, 120),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8,
                (0, 100, 255),
                2
            )
            prev_gray = gray
            cv2.imshow("ZEUS Contour Trigger", frame)
            cv2.imshow("Threshold", thresh)
            # Mask to the low byte: some platforms set higher bits in the
            # key code returned by waitKey.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            continue

        if max_area > CONTOUR_AREA_THRESHOLD and not object_detected:

            object_detected = True
            time.sleep(SETTLE_DELAY)

            # NOTE(review): some capture backends buffer frames, so this read
            # may return an image captured before the delay elapsed - confirm
            # on the target hardware.
            ok, settled_frame = cap.read()
            if not ok:
                # Without a valid frame there is nothing to classify; re-arm
                # the detector instead of crashing in cvtColor.
                print("Frame grab failed after settle delay, skipping inference")
                object_detected = False
                prev_gray = gray
                continue
            frame = settled_frame

            print("\nObject detected")
            print("Running inference...")

            # The model pipeline expects an RGB PIL image; OpenCV gives BGR.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(rgb)

            x = transform(img).unsqueeze(0).to(device)

            with torch.no_grad():
                out = model(x)
                prob = torch.softmax(out, dim=1)
                conf, pred = torch.max(prob, 1)

            idx = pred.item()
            label = class_names[idx]
            confidence = conf.item()

            waste = map_waste(label)

            print(f"Prediction : {label}")
            print(f"Confidence : {confidence * 100:.1f}%")
            print(f"Category   : {waste}")

            if confidence >= CONF_THRESHOLD:
                success = send_serial(waste)
                if success:
                    # Only an acknowledged command starts the cooldown.
                    cooldown_until = time.time() + COOLDOWN_AFTER_SEND
                    print(f"Cooldown active for {COOLDOWN_AFTER_SEND:.0f} seconds")
            else:
                print("Confidence too low, not sending command")

        elif max_area < OBJECT_CLEAR_THRESHOLD and object_detected:
            # Scene is quiet again: allow the next object to trigger.
            object_detected = False
            print("Ready for next object")

        # Draw the overlay on a copy so the captured frame stays untouched.
        display = frame.copy()

        cv2.putText(
            display,
            f"Area: {int(max_area)}",
            (20, 40),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (0, 255, 0),
            2
        )

        status = "OBJECT DETECTED" if object_detected else "WAITING"

        cv2.putText(
            display,
            status,
            (20, 80),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (0, 255, 255),
            2
        )

        cv2.imshow("ZEUS Contour Trigger", display)
        cv2.imshow("Threshold", thresh)

        prev_gray = gray

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break

except KeyboardInterrupt:
    print("\nStopping")

finally:
    # Always release the camera, the preview windows and the serial port.
    cap.release()
    cv2.destroyAllWindows()

    if esp32 is not None:
        esp32.close()

    print("Cleanup complete")