import pandas as pd
import numpy as np
import joblib

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler

from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix
)

# =====================================================
# CONFIG
# =====================================================

# Labelled training data (CSV) read below.
DATASET = "dataset_labeled.csv"

# Destination paths for the three artifacts dumped with joblib at the end of
# the run: fitted classifier, fitted scaler and fitted label encoder.
MODEL_FILE, SCALER_FILE, ENCODER_FILE = (
    "random_forest_model.pkl",
    "scaler.pkl",
    "label_encoder.pkl",
)

# =====================================================
# LOAD DATASET
# =====================================================

print("\nChargement du dataset...")
df = pd.read_csv(DATASET)
print("Dataset chargé : {} lignes".format(len(df)))

# =====================================================
# FEATURES
# =====================================================

# Model inputs, grouped by source. The order of this list fixes the column
# order of the feature matrix the scaler and model are fitted on; presumably
# the same order must be used when the saved artifacts are applied later.
FEATURES = [

    # ELECTRICAL
    "voltage",
    "current",

    # THERMAL
    "temp_motor_bottom",
    "temp_motor_top",

    # GAS
    "gas",

    # VIBRATION
    "vibration_left",
    "vibration_right",

    # ACC
    "accX",
    "accY",
    "accZ",

    # GYRO
    "gyroX",
    "gyroY",
    "gyroZ",

    # SYSTEM
    "wifi_rssi",
    "heap",

    # MARLIN
    "nozzle_temp",
    "nozzle_target",

    "bed_temp",
    "bed_target",

    "posX",
    "posY",
    "posZ",
    "posE"
]

# Column holding the class label to predict.
TARGET = "fault_label"

# =====================================================
# CHECK FEATURES
# =====================================================

# Validate the schema (features AND target) before touching the data, so a
# missing column is reported clearly instead of surfacing as a KeyError.
missing = [
    col for col in [*FEATURES, TARGET]
    if col not in df.columns
]

if missing:
    print("\nColonnes manquantes :")
    print(missing)
    # Non-zero status so shells/CI see the failure. SystemExit needs no
    # import, unlike the site-provided exit() helper.
    raise SystemExit(1)

# =====================================================
# CLEAN
# =====================================================

# Only missing values in the columns actually used make a row unusable;
# NaNs in unrelated columns must not discard otherwise valid samples.
n_before_clean = len(df)
df = df.dropna(subset=[*FEATURES, TARGET])
print(f"Lignes supprimées (valeurs manquantes) : {n_before_clean - len(df)}")

if df.empty:
    print("\nAucune ligne exploitable après nettoyage")
    raise SystemExit(1)

# =====================================================
# X / Y
# =====================================================

X = df[FEATURES]

y = df[TARGET]

# =====================================================
# LABEL ENCODER
# =====================================================

# Map fault labels to integers 0..n_classes-1; encoder.classes_ keeps the
# original label for each integer (used for reporting and saved to disk).
encoder = LabelEncoder()

y_encoded = encoder.fit_transform(y)

# =====================================================
# SPLIT
# =====================================================

# Split BEFORE scaling so the held-out rows play no part in fitting any
# preprocessing. stratify keeps class proportions similar in both splits.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=0.2,
    random_state=42,
    stratify=y_encoded,
)

# =====================================================
# NORMALIZATION
# =====================================================

# Fit mean/std on the training split only (fitting on the full dataset leaks
# test-set statistics into evaluation), then apply the same transform to the
# test split. The fitted scaler is the one persisted for later use.
scaler = StandardScaler()

X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# =====================================================
# RANDOM FOREST
# =====================================================

# 300 trees with depth capped at 20 and minimum split/leaf sizes that
# regularise each tree. class_weight="balanced" weights classes inversely to
# their frequency in the training labels; random_state makes runs repeatable;
# n_jobs=-1 lets scikit-learn use all available cores.
model = RandomForestClassifier(
    n_jobs=-1,
    random_state=42,
    class_weight="balanced",
    n_estimators=300,
    max_depth=20,
    min_samples_leaf=2,
    min_samples_split=5,
)

# =====================================================
# TRAIN
# =====================================================

print("\nEntraînement IA...")
model.fit(X_train, y_train)
print("Training terminé")

# =====================================================
# EVALUATION
# =====================================================

y_pred = model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)

print(f"\nAccuracy : {accuracy * 100:.2f}%")

# List every encoded class explicitly. Even with stratification, a very rare
# class can be absent from y_test/y_pred; without `labels`,
# classification_report then rejects target_names of a different length.
# Names are stringified because target_names must be strings, and the raw
# labels may be numeric.
all_labels = np.arange(len(encoder.classes_))
class_names = [str(c) for c in encoder.classes_]

print("\nClassification Report :\n")

print(
    classification_report(
        y_test,
        y_pred,
        labels=all_labels,
        target_names=class_names,
        zero_division=0,
    )
)

# Rows = true class, columns = predicted class.
cm = confusion_matrix(y_test, y_pred, labels=all_labels)

print("\nConfusion Matrix :\n")

print(pd.DataFrame(cm, index=class_names, columns=class_names))

# =====================================================
# FEATURE IMPORTANCE
# =====================================================

# feature_importances_ follows the column order of the training matrix,
# which is the order of FEATURES.
importance = pd.DataFrame({
    "Feature": FEATURES,
    "Importance": model.feature_importances_
}).sort_values(by="Importance", ascending=False)

print("\nTop Features :\n")

print(importance.head(15))

# =====================================================
# SAVE MODEL
# =====================================================

# Persist each fitted object to its configured path (reload with joblib.load).
artifacts = (
    (model, MODEL_FILE),
    (scaler, SCALER_FILE),
    (encoder, ENCODER_FILE),
)

for obj, path in artifacts:
    joblib.dump(obj, path)

print("\nModèle sauvegardé")
print("Fichiers générés :")

for _, path in artifacts:
    print("-", path)