from ultralytics import YOLO
import torch
import argparse
from pathlib import Path


# Absolute directory containing this script. The default dataset path, relative
# --data values and the "Models" output directory are all resolved against it.
SCRIPT_DIR = Path(__file__).resolve().parent


# Named augmentation presets selectable with --aug-preset. The chosen preset's
# key/value pairs are unpacked as keyword arguments into model.train(), so every
# key must be a setting name that the training call accepts.
AUGMENTATION_PRESETS = {
    # "none": every setting is zeroed or switched off.
    "none": {
        "augment": False,
        "hsv_h": 0.0,
        "hsv_s": 0.0,
        "hsv_v": 0.0,
        "degrees": 0.0,
        "translate": 0.0,
        "scale": 0.0,
        "shear": 0.0,
        "perspective": 0.0,
        "fliplr": 0.0,
        "flipud": 0.0,
        "mosaic": 0.0,
        "close_mosaic": 0,
        "mixup": 0.0,
        "cutmix": 0.0,
        "copy_paste": 0.0,
        "erasing": 0.0,
    },
    # "coin-medium": the default preset. Colour, geometric and mosaic settings
    # are non-zero; both flips, mixup, cutmix, copy-paste and erasing stay at 0.
    "coin-medium": {
        "augment": True,
        "hsv_h": 0.005,         # Hue shift.
        "hsv_s": 0.25,          # Saturation shift.
        "hsv_v": 0.18,          # Brightness shift.
        "degrees": 25.0,        # Rotation.
        "translate": 0.10,      # Translation.
        "scale": 0.35,          # Scale.
        "shear": 3.0,           # Shear.
        "perspective": 0.001,   # Perspective warp.
        "fliplr": 0.0,          # Horizontal flip.
        "flipud": 0.0,          # Vertical flip.
        "mosaic": 0.15,         # Mosaic.
        "close_mosaic": 20,     # Mosaic closing epochs.
        "mixup": 0.0,           # MixUp.
        "cutmix": 0.0,          # CutMix.
        "copy_paste": 0.0,      # Copy-paste.
        "erasing": 0.0,         # Random erasing.
    },
}


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing a list
            makes the parser usable from other code and from tests.

    Returns:
        argparse.Namespace with the attributes ``model``, ``data``, ``epochs``,
        ``imgsz``, ``batch``, ``name``, ``device``, ``aug_preset``, ``augment``
        and ``latest``.
    """
    p = argparse.ArgumentParser(
        description="Train an Ultralytics YOLO model",
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--model", default="yolo26s.pt", help="Weights file or model name to load")
    p.add_argument(
        "--data",
        default=str(SCRIPT_DIR / "Data" / "data_roboflow" / "data.yaml"),
        help="Dataset YAML; relative paths are resolved against the script directory",
    )
    p.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    p.add_argument("--imgsz", type=int, default=640, help="Training image size")
    p.add_argument("--batch", type=int, default=16, help="Batch size")
    p.add_argument("--name", default="yolo_train", help="Run name inside the Models directory")
    p.add_argument(
        "--device",
        default="auto",
        help="auto, cpu, gpu/cuda, cuda:N or a GPU index; other values are passed through",
    )
    p.add_argument(
        "--aug-preset",
        choices=sorted(AUGMENTATION_PRESETS),
        default="coin-medium",
        help="Named augmentation preset",
    )
    p.add_argument("--augment", action="store_true", help="Shortcut for --aug-preset coin-medium")
    p.add_argument("--latest", action="store_true", help="Ignore --model and load yolo26s.pt")
    return p.parse_args(argv)

def device_tofind(arg_device):
    """Translate a user-supplied device request into a value for training.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        arg_device: Requested device. Recognised forms are ``"auto"`` /
            ``"default"`` / ``"gpu"`` / ``"cuda"`` (any CUDA device),
            ``"cuda:N"`` (a specific CUDA device) and a bare decimal index
            such as ``"0"``. Anything else (``"cpu"``, ``"mps"``, ``"0,1"``,
            ``None``, ...) is returned unchanged.

    Returns:
        ``"cpu"`` whenever a CUDA device was requested but CUDA is not
        available. Otherwise ``"cuda"`` for the generic aliases, the
        normalised ``"cuda:N"`` string, an ``int`` for a bare index, or
        ``arg_device`` itself for unrecognised requests.
    """
    key = str(arg_device).strip().lower()
    generic_cuda = key in ("auto", "default", "gpu", "cuda")
    # isdecimal() rather than isdigit(): int() rejects characters such as
    # superscript digits that isdigit() accepts.
    is_index = key.isdecimal()

    if not (generic_cuda or is_index or key.startswith("cuda")):
        # Not a CUDA request: hand it on untouched and skip the CUDA probe.
        return arg_device

    if not torch.cuda.is_available():
        # Every CUDA-style request, including a bare GPU index, degrades to
        # the CPU instead of failing later inside the training call.
        return "cpu"

    if generic_cuda:
        return "cuda"
    if is_index:
        return int(key)
    # "cuda:N" style: return the normalised string for str input, and keep
    # non-string objects (e.g. a device object) as they were given.
    return key if isinstance(arg_device, str) else arg_device

def main():
    """Run one training job configured from the command line.

    Steps, in order: parse the CLI options, pick the weights to load, resolve
    the training device, resolve the dataset YAML path against SCRIPT_DIR,
    make sure the ``Models`` output directory exists, load the model and start
    training with the selected augmentation preset.
    """
    cli = parse_args()

    # Either the --latest flag or the literal model name "latest" selects the
    # built-in default weights; otherwise --model is used verbatim.
    default_weights = "yolo26s.pt"
    wants_default = cli.latest or str(cli.model).lower() == "latest"
    weights = default_weights if wants_default else cli.model

    device = device_tofind(cli.device)
    print(f"Using device for training: {device}")

    # Relative dataset paths are anchored at the script directory, not the CWD.
    dataset_yaml = Path(cli.data)
    if not dataset_yaml.is_absolute():
        dataset_yaml = SCRIPT_DIR / dataset_yaml

    output_root = SCRIPT_DIR / "Models"
    output_root.mkdir(parents=True, exist_ok=True)

    model = YOLO(weights)

    # --augment takes precedence over whatever --aug-preset was given.
    preset_name = "coin-medium" if cli.augment else cli.aug_preset
    augmentation = dict(AUGMENTATION_PRESETS[preset_name])
    print(f"Using augmentation preset: {preset_name}")

    run_settings = {
        "data": str(dataset_yaml),
        "epochs": cli.epochs,
        "imgsz": cli.imgsz,
        "batch": cli.batch,
        "name": cli.name,
        "project": str(output_root),
        "device": device,
    }
    model.train(**run_settings, **augmentation)


# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
