// ModelEvaluationEngine.java

package com.kapil.verbametrics.ml.engines;

import com.kapil.verbametrics.ml.config.ClassValueManager;
import com.kapil.verbametrics.ml.domain.ModelEvaluationResult;
import com.kapil.verbametrics.ml.managers.ModelFileManager;
import com.kapil.verbametrics.util.TypeSafeCastUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import weka.classifiers.Classifier;
import weka.core.Instances;

import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Engine for model evaluation operations.
 * Loads a previously trained model through {@link ModelFileManager}, converts the supplied
 * test rows into a Weka dataset whose class attribute uses the class values stored for the
 * model, and reports the metrics computed by Weka's {@code Evaluation}.
 *
 * <p>Evaluation failures are not propagated: {@link #evaluateModel(String, List)} logs the
 * exception and returns a result flagged as unsuccessful.
 *
 * @author Kapil Garg
 */
@Component
public class ModelEvaluationEngine {

    private static final Logger LOGGER = LoggerFactory.getLogger(ModelEvaluationEngine.class);

    private final ModelFileManager fileManager;
    private final ClassValueManager classValueManager;

    /**
     * Creates the engine with its collaborators.
     *
     * @param fileManager       used to load the persisted model by ID
     * @param classValueManager used to look up the class values associated with a model ID
     */
    @Autowired
    public ModelEvaluationEngine(ModelFileManager fileManager, ClassValueManager classValueManager) {
        this.fileManager = fileManager;
        this.classValueManager = classValueManager;
    }

    /**
     * Evaluates a trained model using test data.
     *
     * <p>Each test row is a map that is read through the keys {@code "features"}
     * (a {@code double[]} or a {@code List} of numbers) and {@code "label"}.
     *
     * @param modelId  the ID of the trained model
     * @param testData the test dataset
     * @return evaluation result with performance metrics; when evaluation fails for any
     *         reason (including empty test data or a missing model) a result flagged as
     *         unsuccessful is returned, carrying the error message and the elapsed time
     * @throws NullPointerException if {@code modelId} or {@code testData} is null
     */
    public ModelEvaluationResult evaluateModel(String modelId, List<Map<String, Object>> testData) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        Objects.requireNonNull(testData, "Test data cannot be null");
        // Captured here so that both the success and the failure path report elapsed time.
        long startTime = System.currentTimeMillis();
        try {
            return doEvaluateModel(modelId, testData, startTime);
        } catch (Exception e) {
            LOGGER.error("Failed to evaluate model: {}", modelId, e);
            long elapsedMillis = System.currentTimeMillis() - startTime;
            return buildFailedEvaluationResult(modelId, testData, e, elapsedMillis);
        }
    }

    /**
     * Performs the actual model evaluation.
     *
     * @param modelId   the ID of the trained model
     * @param testData  the test dataset
     * @param startTime wall-clock time in milliseconds at which the evaluation started
     * @return evaluation result with performance metrics
     * @throws IllegalArgumentException if the test data is empty, the model cannot be found,
     *                                  or the loaded object is not a Weka {@link Classifier}
     * @throws Exception                if building the dataset or the Weka evaluation fails
     */
    private ModelEvaluationResult doEvaluateModel(String modelId, List<Map<String, Object>> testData,
                                                  long startTime) throws Exception {
        if (testData.isEmpty()) {
            throw new IllegalArgumentException("Test data cannot be empty");
        }
        Object model = fileManager.loadModelFromFile(modelId)
                .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
        if (!(model instanceof Classifier classifier)) {
            throw new IllegalArgumentException("Loaded model is not a Weka Classifier: " + model.getClass().getSimpleName());
        }
        Map<String, Object> evaluationMetrics = evaluateWekaModel(classifier, testData, modelId);
        long evaluationTime = System.currentTimeMillis() - startTime;
        LOGGER.info("Model evaluation completed in {}ms for model: {}", evaluationTime, modelId);
        return new ModelEvaluationResult(
                modelId,
                "EVALUATION",
                true,
                (Double) evaluationMetrics.get("accuracy"),
                (Double) evaluationMetrics.get("precision"),
                (Double) evaluationMetrics.get("recall"),
                (Double) evaluationMetrics.get("f1Score"),
                (Double) evaluationMetrics.get("auc"),
                evaluationTime,
                testData.size(),
                TypeSafeCastUtil.safeCastToMap(evaluationMetrics.get("confusionMatrix")),
                TypeSafeCastUtil.safeCastToMap(evaluationMetrics.get("additionalMetrics")),
                null,
                LocalDateTime.now()
        );
    }

    /**
     * Evaluates a Weka-based model against the test data.
     *
     * @param model    the Weka classifier
     * @param testData the test dataset
     * @param modelId  the ID of the trained model, used to look up its class values
     * @return immutable map with the keys {@code accuracy}, {@code precision}, {@code recall},
     *         {@code f1Score}, {@code auc} (all {@code Double}), {@code confusionMatrix} and
     *         {@code additionalMetrics} (both nested maps)
     * @throws Exception if evaluation fails
     */
    private Map<String, Object> evaluateWekaModel(Classifier model, List<Map<String, Object>> testData,
                                                  String modelId) throws Exception {
        Instances testDataset = createAlignedEvaluationDataset(testData, modelId);
        weka.classifiers.Evaluation evaluation = new weka.classifiers.Evaluation(testDataset);
        evaluation.evaluateModel(model, testDataset);

        // pctCorrect() is a percentage; the result object carries accuracy as a fraction.
        double accuracy = evaluation.pctCorrect() / 100.0;
        double precision = evaluation.weightedPrecision();
        double recall = evaluation.weightedRecall();
        double f1Score = evaluation.weightedFMeasure();
        double auc = evaluation.weightedAreaUnderROC();
        Map<String, Object> confusionMap = extractConfusionMatrix(evaluation.confusionMatrix());

        return Map.of(
                "accuracy", accuracy,
                "precision", precision,
                "recall", recall,
                "f1Score", f1Score,
                "auc", auc,
                "confusionMatrix", confusionMap,
                "additionalMetrics", Map.of(
                        "correctPredictions", (int) evaluation.correct(),
                        "totalPredictions", testData.size(),
                        "incorrectPredictions", (int) evaluation.incorrect()
                )
        );
    }

    /**
     * Safely extracts confusion matrix values, handling cases with single class or variable dimensions.
     *
     * <p>Only the top-left 2x2 corner of the matrix is read, with index 0 treated as the
     * "positive" class: {@code [0][0]} is TP, {@code [1][1]} is TN, {@code [1][0]} is FP and
     * {@code [0][1]} is FN. Cells belonging to any further classes are ignored.
     *
     * @param confusionMatrix the Weka confusion matrix
     * @return map containing TP, TN, FP, FN values (all zero for an empty matrix)
     */
    private Map<String, Object> extractConfusionMatrix(double[][] confusionMatrix) {
        int rows = confusionMatrix.length;
        int cols = rows > 0 ? confusionMatrix[0].length : 0;
        if (rows == 0 || cols == 0) {
            return Map.of("TP", 0, "TN", 0, "FP", 0, "FN", 0);
        }
        int tp = (int) confusionMatrix[0][0];
        if (rows == 1 && cols == 1) {
            return Map.of("TP", tp, "TN", 0, "FP", 0, "FN", 0);
        }
        // The dimension guards keep non-square matrices (1xN or Nx1) from indexing out of bounds.
        int tn = rows > 1 && cols > 1 ? (int) confusionMatrix[1][1] : 0;
        int fp = rows > 1 ? (int) confusionMatrix[1][0] : 0;
        int fn = cols > 1 ? (int) confusionMatrix[0][1] : 0;
        return Map.of("TP", tp, "TN", tn, "FP", fp, "FN", fn);
    }

    /**
     * Creates an aligned evaluation dataset compatible with the trained model.
     *
     * <p>The number of numeric feature attributes is taken from the first test row; the
     * final attribute is a nominal {@code label} attribute built from the class values
     * stored for the model (or a default sentiment ordering when none are stored).
     *
     * @param testData the test dataset; must not be empty
     * @param modelId  the ID of the trained model
     * @return aligned Weka Instances dataset with the class index set to the label attribute
     */
    private Instances createAlignedEvaluationDataset(List<Map<String, Object>> testData, String modelId) {
        int featureCount = countFeatures(testData.getFirst().get("features"));
        if (featureCount == 0) {
            LOGGER.warn("No usable features found in the first test sample for model: {}", modelId);
        }

        ArrayList<weka.core.Attribute> attributes = new ArrayList<>(featureCount + 1);
        for (int i = 0; i < featureCount; i++) {
            attributes.add(new weka.core.Attribute("feature_" + i));
        }

        List<String> classValues = classValueManager.getClassValues(modelId);
        if (classValues.isEmpty()) {
            // Fallback to common sentiment ordering if none stored
            classValues = List.of("negative", "neutral", "positive");
        }
        attributes.add(new weka.core.Attribute("label", new ArrayList<>(classValues)));

        Instances dataset = new Instances("EvaluationDataset", attributes, testData.size());
        // The label is the last attribute, i.e. it sits right after the features.
        dataset.setClassIndex(featureCount);
        for (Map<String, Object> dataPoint : testData) {
            dataset.add(buildInstance(dataPoint, dataset, featureCount));
        }
        return dataset;
    }

    /**
     * Determines how many features a raw {@code "features"} value holds.
     *
     * @param features the raw value; a {@code double[]} or a {@code List} is understood
     * @return the array length or list size, or 0 for any other value (including null)
     */
    private static int countFeatures(Object features) {
        if (features instanceof double[] values) {
            return values.length;
        }
        if (features instanceof List<?> values) {
            return values.size();
        }
        return 0;
    }

    /**
     * Converts one test row into a Weka instance bound to the given dataset.
     *
     * <p>Rows with fewer features than {@code featureCount} are padded with 0.0, extra
     * features are ignored, and non-numeric list entries are replaced by 0.0. A label that
     * is absent or not among the dataset's class values is marked as missing.
     *
     * @param dataPoint    the test row, read through the keys {@code "features"} and {@code "label"}
     * @param dataset      the dataset providing the attribute definitions
     * @param featureCount the number of feature attributes; also the index of the label attribute
     * @return the populated instance
     */
    private static weka.core.DenseInstance buildInstance(Map<String, Object> dataPoint, Instances dataset,
                                                         int featureCount) {
        weka.core.DenseInstance instance = new weka.core.DenseInstance(featureCount + 1);
        instance.setDataset(dataset);

        Object features = dataPoint.get("features");
        if (features instanceof double[] values) {
            for (int i = 0; i < featureCount; i++) {
                instance.setValue(i, i < values.length ? values[i] : 0.0);
            }
        } else if (features instanceof List<?> values) {
            for (int i = 0; i < featureCount; i++) {
                Object value = i < values.size() ? values.get(i) : null;
                instance.setValue(i, value instanceof Number number ? number.doubleValue() : 0.0);
            }
        }
        // Any other "features" value leaves the feature slots at DenseInstance's initial state.

        Object label = dataPoint.get("label");
        String labelText = label != null ? label.toString() : null;
        if (labelText != null && dataset.classAttribute().indexOfValue(labelText) >= 0) {
            instance.setValue(featureCount, labelText);
        } else {
            instance.setMissing(featureCount);
        }
        return instance;
    }

    /**
     * Builds a failed evaluation result in case of exceptions.
     *
     * @param modelId       the ID of the trained model
     * @param testData      the test dataset
     * @param e             the exception encountered
     * @param elapsedMillis time in milliseconds spent before the failure occurred
     * @return evaluation result indicating failure, with all metrics set to 0.0
     */
    private ModelEvaluationResult buildFailedEvaluationResult(String modelId, List<Map<String, Object>> testData,
                                                              Exception e, long elapsedMillis) {
        // Some exceptions (e.g. NullPointerException) carry no message; fall back to the
        // exception type so a failed result always explains itself.
        String errorMessage = e.getMessage() != null ? e.getMessage() : e.getClass().getSimpleName();
        return new ModelEvaluationResult(
                modelId,
                "EVALUATION",
                false,
                0.0, 0.0, 0.0, 0.0, 0.0,
                elapsedMillis,
                testData.size(),
                Map.of(),
                Map.of(),
                errorMessage,
                LocalDateTime.now()
        );
    }

}