import json
import math

import joblib
import numpy as np
import paho.mqtt.client as mqtt

# =====================================================
# MQTT CONFIG
# =====================================================

# Broker shared by the sensor publisher and this predictor (plain TCP,
# see client.connect below).
# NOTE(review): broker.hivemq.com appears to be a public broker with no
# authentication; anyone knowing these topic names could inject sensor data
# or read predictions. Confirm this is intended.
MQTT_BROKER: str = "broker.hivemq.com"

# Topic this script subscribes to: JSON sensor messages.
TOPIC_INPUT: str = "malik/imprimante/ligne1/capteurs"

# Topic this script publishes its JSON prediction results to.
TOPIC_OUTPUT: str = "malik/imprimante/ligne1/predictions"

# =====================================================
# LOAD MODEL
# =====================================================

# Serialized objects used by on_message. The paths are relative, so they are
# resolved against the current working directory.
# NOTE(review): joblib.load unpickles its input and can execute arbitrary
# code; only point these paths at trusted files.
_MODEL_FILE = "random_forest_model.pkl"
_SCALER_FILE = "scaler.pkl"
_ENCODER_FILE = "label_encoder.pkl"

print("Chargement IA...")

model = joblib.load(_MODEL_FILE)  # used via predict / predict_proba
scaler = joblib.load(_SCALER_FILE)  # used via transform
encoder = joblib.load(_ENCODER_FILE)  # used via inverse_transform

print("IA chargée")

# =====================================================
# EXTRACT FEATURES
# =====================================================

# Ordered key paths into the sensor JSON message, one per model feature.
# NOTE: this order must match the column order the scaler and the model were
# fitted on. Never reorder, insert or remove entries without retraining.
_FEATURE_PATHS = (
    # ELECTRICAL
    ("electrical", "voltage"),
    ("electrical", "current"),
    # THERMAL
    ("thermal", "motor_bottom"),
    ("thermal", "motor_top"),
    # GAS
    ("gas", "mq2"),
    # VIBRATION
    ("vibration", "left_rms"),
    ("vibration", "right_rms"),
    # ACC
    ("motion", "acceleration", "x"),
    ("motion", "acceleration", "y"),
    ("motion", "acceleration", "z"),
    # GYRO
    ("motion", "gyroscope", "x"),
    ("motion", "gyroscope", "y"),
    ("motion", "gyroscope", "z"),
    # SYSTEM
    ("system", "wifi_rssi"),
    ("system", "heap"),
    # MARLIN VALUES
    ("marlin", "nozzle_temp"),
    ("marlin", "nozzle_target"),
    ("marlin", "bed_temp"),
    ("marlin", "bed_target"),
    ("marlin", "posX"),
    ("marlin", "posY"),
    ("marlin", "posZ"),
    ("marlin", "posE"),
)

# Value substituted for a missing reading (absent key or JSON null).
_DEFAULT_READING = 0.0


def _read_feature(payload, path):
    """Return the reading found at *path* inside *payload* as a float.

    Args:
        payload: decoded JSON object (dict) to read from.
        path: tuple of keys, outermost first.

    Returns:
        The reading as a float, or ``_DEFAULT_READING`` when the key or any
        enclosing section is absent or ``null``.

    Raises:
        TypeError: an enclosing section is neither an object nor null.
        ValueError: the reading is not a number, or is NaN / infinite.
        OverflowError: the reading is an integer too large for a float.
    """
    name = ".".join(path)
    node = payload

    for key in path:
        if node is None:
            # A null section means every reading below it is missing.
            return _DEFAULT_READING
        if not isinstance(node, dict):
            raise TypeError(
                f"{name}: expected a JSON object, got {type(node).__name__}"
            )
        node = node.get(key)

    if node is None:
        return _DEFAULT_READING

    # bool is accepted because it is an int subclass (True -> 1.0);
    # strings, lists and objects are rejected.
    if not isinstance(node, (int, float)):
        raise ValueError(f"{name}: expected a number, got {type(node).__name__}")

    value = float(node)

    # json.loads accepts NaN / Infinity literals; neither is a usable input
    # for the scaler or the model.
    if not math.isfinite(value):
        raise ValueError(f"{name}: non-finite value {value!r}")

    return value


def extract_features(payload):
    """Build the model's feature vector from one decoded sensor message.

    Each key path of ``_FEATURE_PATHS`` is looked up in *payload* (for
    example ``payload["motion"]["gyroscope"]["x"]``), in that fixed order.
    A missing reading (absent key, or JSON ``null`` at any level) is filled
    with ``0.0``.

    Args:
        payload: the JSON value decoded from the sensor topic; expected to be
            an object (dict).

    Returns:
        A list of ``len(_FEATURE_PATHS)`` floats. Returns ``None`` if the
        message is malformed: the payload or a section is not an object, or
        a reading is not a finite number. The reason is printed in that case.
    """
    try:
        if not isinstance(payload, dict):
            raise TypeError(
                f"payload: expected a JSON object, got {type(payload).__name__}"
            )

        return [_read_feature(payload, path) for path in _FEATURE_PATHS]

    except (TypeError, ValueError, OverflowError) as e:
        print("Erreur extraction :", e)
        return None

# =====================================================
# MQTT CALLBACK
# =====================================================

def on_message(client, userdata, msg):
    """Classify one sensor message and publish the prediction.

    paho-mqtt ``on_message`` callback. The message body is decoded and parsed
    as JSON, turned into a feature vector by ``extract_features``, transformed
    with ``scaler`` and classified with ``model``. The encoded class is mapped
    back to its label with ``encoder``. The result
    ``{"prediction": <label>, "confidence": <max class probability in %,
    2 decimals>}`` is printed and published as JSON on ``TOPIC_OUTPUT``.

    Args:
        client: the MQTT client that received the message; used to publish.
        userdata: unused.
        msg: the received message; ``msg.payload`` holds the JSON bytes.

    Any exception is caught and printed so it does not escape the callback;
    the message is then dropped without publishing anything.
    """

    try:

        payload = json.loads(msg.payload.decode())

        features = extract_features(payload)

        # Malformed message: extract_features has already printed why.
        if features is None:
            return

        # =========================================
        # PREPARE DATA (one sample -> shape (1, n_features))
        # =========================================

        X = np.array(features).reshape(1, -1)

        X_scaled = scaler.transform(X)

        # =========================================
        # PREDICTION
        # =========================================

        prediction = model.predict(X_scaled)[0]

        prediction_label = encoder.inverse_transform([prediction])[0]

        # Indexing the array returned by inverse_transform yields a NumPy
        # scalar (np.str_, np.int64, ...). Convert it to the native Python
        # type: json.dumps rejects np.int64, and NumPy 2 prints np.str_ as
        # "np.str_('...')" inside the result dict.
        if isinstance(prediction_label, np.generic):
            prediction_label = prediction_label.item()

        # =========================================
        # CONFIDENCE (highest class probability, in percent)
        # =========================================

        probabilities = model.predict_proba(X_scaled)[0]

        confidence = float(np.max(probabilities) * 100)

        # =========================================
        # RESULT
        # =========================================

        result = {
            "prediction": prediction_label,
            "confidence": round(confidence, 2),
        }

        print("\n========================")
        print("PREDICTION IA")
        print(result)
        print("========================")

        # =========================================
        # MQTT OUTPUT
        # =========================================

        info = client.publish(TOPIC_OUTPUT, json.dumps(result))

        # publish() signals failures (e.g. no connection) through its return
        # code rather than by raising; report them instead of dropping the
        # result silently.
        if info.rc != mqtt.MQTT_ERR_SUCCESS:
            print("Erreur publication MQTT :", mqtt.error_string(info.rc))

    except Exception as e:

        print("Erreur :", e)

# =====================================================
# MQTT CLIENT
# =====================================================

def on_connect(client, userdata, flags, rc, properties=None):
    """Subscribe to the sensor topic each time the broker accepts a connection.

    Subscribing here instead of once before ``loop_forever()`` renews the
    subscription after every automatic reconnect. With the default clean
    session, the broker forgets subscriptions when the connection drops, so a
    one-off subscribe would silently stop the data flow after the first
    disconnect.

    Args:
        client: the client that connected.
        userdata: unused.
        flags: CONNACK flags (unused).
        rc: CONNACK result code; 0 means the connection was accepted.
        properties: only passed by paho's v2 callback API; accepted so the
            callback works with either callback signature.
    """
    if rc != 0:
        print("Connexion MQTT refusée :", mqtt.connack_string(rc))
        return

    client.subscribe(TOPIC_INPUT)


client = mqtt.Client()

client.on_connect = on_connect
client.on_message = on_message

print("Connexion MQTT...")

# Plain TCP on port 1883, 60 s keep-alive.
client.connect(MQTT_BROKER, 1883, 60)

print("En écoute MQTT...")

# Blocks forever: dispatches callbacks and reconnects automatically.
client.loop_forever()