import os
import shutil
import numpy as np
from tensorflow.keras.models import load_model  # type: ignore
from tensorflow.keras.preprocessing.image import load_img, img_to_array  # type: ignore
from tensorflow.keras.optimizers import Adam  # type: ignore

# Root of the glyph-generator dataset; every path below except the INPUT_PATH
# override is built underneath it.
DATA_PATH = '/home/mark/Projects/glyph-generator/data2'

# Trained Keras classifier loaded by this script.
MODEL_PATH = '/'.join((DATA_PATH, 'models', 'classifier-8x8-V2.keras'))

# Directory whose images get classified. The dataset default is kept for
# reference, but it is immediately replaced by the override on the next line,
# so only the override is actually used.
INPUT_PATH = '/'.join((DATA_PATH, 'input', 'glyphs-unclassified', ''))
INPUT_PATH = '/home/mark/Pictures/dupes'  # NOTE(review): machine-specific override of the dataset input dir — confirm intended

# Destinations for images classified as 'good' and 'bad'.
GOOD_PATH = '/'.join((DATA_PATH, 'output', 'glyphs-good'))
BAD_PATH = '/'.join((DATA_PATH, 'output', 'glyphs-bad'))

# (height, width) each image is resized to before classification; must match
# the size used for training.
# NOTE(review): the model filename mentions 8x8 but this is 16x16 — confirm.
IMAGE_SIZE = (16, 16)

# Load the trained model. This script only runs inference (model.predict), which
# does not require compiling, so the model is used exactly as it was saved.
# Re-compiling here would replace the saved training configuration; in particular,
# a categorical_crossentropy loss does not fit the single probability that
# classify_image reads and thresholds.
model = load_model(MODEL_PATH)

# Ensure both output directories exist before any image is moved into them.
os.makedirs(GOOD_PATH, exist_ok=True)
os.makedirs(BAD_PATH, exist_ok=True)

def classify_image(image_path: str, threshold: float = 0.99) -> int:
    """Classify a single image as 'good' (1) or 'bad' (0).

    The image is loaded and resized to IMAGE_SIZE, normalized to [0, 1], and
    passed through the module-level ``model``. The first output value of the
    single batch item is treated as the 'good' score and is also printed.

    Args:
        image_path: Path to the image file to classify.
        threshold: Minimum score for the image to count as 'good'. Defaults to
            0.99, the cut-off this script has always used.

    Returns:
        1 if the score is >= ``threshold``, otherwise 0.
    """
    img = load_img(image_path, target_size=IMAGE_SIZE)
    img_array = img_to_array(img) / 255.0  # Normalize to [0, 1] (assumes 0-255 pixel values)
    batch = np.expand_dims(img_array, axis=0)  # Add the leading batch axis the model expects

    # verbose=0 suppresses the Keras progress bar that predict() would otherwise
    # print for every single image.
    score = float(model.predict(batch, verbose=0)[0][0])
    print(f'Prediction for {image_path}: {score:.2f}')  # Debug: log the prediction score
    return 1 if score >= threshold else 0

# Classify every image file in INPUT_PATH and move it into GOOD_PATH (result 1)
# or BAD_PATH (result 0).
image_files = sorted(
    f for f in os.listdir(INPUT_PATH)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    # Skip directories that merely have an image-like name.
    and os.path.isfile(os.path.join(INPUT_PATH, f))
)
for filename in image_files:
    file_path = os.path.join(INPUT_PATH, filename)
    result = classify_image(file_path)

    # Route every image, not only the 'good' ones: BAD_PATH is created for this
    # purpose and the completion message below promises both destinations.
    target_dir = GOOD_PATH if result == 1 else BAD_PATH
    # NOTE(review): shutil.move may overwrite an existing file with the same
    # name in target_dir (os.rename semantics) — confirm that is acceptable.
    shutil.move(file_path, os.path.join(target_dir, filename))

print("Classification complete. Images moved to 'good' and 'bad' directories.")
