// ModelTrainingEngine.java

package com.kapil.verbametrics.ml.engines;

import com.kapil.verbametrics.ml.classifiers.ModelTypeClassifier;
import com.kapil.verbametrics.ml.config.ClassValueManager;
import com.kapil.verbametrics.ml.config.MLModelProperties;
import com.kapil.verbametrics.ml.domain.ModelTrainingResult;
import com.kapil.verbametrics.ml.managers.ModelFileManager;
import com.kapil.verbametrics.ml.utils.MetricsCalculationUtils;
import com.kapil.verbametrics.ml.utils.WekaDatasetUtils;
import com.kapil.verbametrics.util.VerbaMetricsConstants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomTree;
import weka.core.Instances;

import java.time.LocalDateTime;
import java.util.*;

@Component
public class ModelTrainingEngine {

    private static final Logger LOGGER = LoggerFactory.getLogger(ModelTrainingEngine.class);

    private static final int CROSS_VALIDATION_SEED = 1;
    private static final int CROSS_VALIDATION_FOLDS = 5;

    private final MLModelProperties properties;
    private final ModelFileManager fileManager;
    private final ModelTypeClassifier modelTypeClassifier;
    private final ClassValueManager classValueManager;

    @Autowired
    public ModelTrainingEngine(MLModelProperties properties, ModelFileManager fileManager,
                               ModelTypeClassifier modelTypeClassifier, ClassValueManager classValueManager) {
        this.properties = properties;
        this.fileManager = fileManager;
        this.modelTypeClassifier = modelTypeClassifier;
        this.classValueManager = classValueManager;
    }

    /**
     * Trains a machine learning model based on the provided parameters.
     *
     * @param modelId      The unique identifier for the model
     * @param modelType    The type of model to train (e.g., "regression", "classification")
     * @param trainingData The training data as a list of maps
     * @param parameters   The training parameters as a map
     * @return The result of the model training
     */
    public ModelTrainingResult trainModel(String modelId, String modelType,
                                          List<Map<String, Object>> trainingData,
                                          Map<String, Object> parameters) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        Objects.requireNonNull(modelType, "Model type cannot be null");
        Objects.requireNonNull(trainingData, "Training data cannot be null");
        Objects.requireNonNull(parameters, "Parameters cannot be null");
        try {
            long startTime = System.currentTimeMillis();
            Object trainedModel = performModelTraining(modelType, trainingData, parameters);
            long trainingTime = System.currentTimeMillis() - startTime;
            fileManager.saveModelToFile(modelId, trainedModel);
            storeClassValuesForModel(modelId, trainingData);
            Map<String, Object> performanceMetrics = calculatePerformanceMetrics(trainedModel, trainingData, modelType);
            LOGGER.info("Model training completed successfully in {}ms for model: {}", trainingTime, modelId);
            return new ModelTrainingResult(
                    modelId,
                    modelType,
                    true,
                    (Double) performanceMetrics.get("accuracy"),
                    (Double) performanceMetrics.get("precision"),
                    (Double) performanceMetrics.get("recall"),
                    (Double) performanceMetrics.get("f1Score"),
                    trainingTime,
                    trainingData.size(),
                    0,
                    performanceMetrics,
                    null,
                    LocalDateTime.now()
            );
        } catch (Exception e) {
            LOGGER.error("Failed to train model: {}", modelId, e);
            return new ModelTrainingResult(
                    modelId,
                    modelType,
                    false,
                    0.0, 0.0, 0.0, 0.0,
                    System.currentTimeMillis(),
                    trainingData.size(),
                    0,
                    Map.of(),
                    e.getMessage(),
                    LocalDateTime.now()
            );
        }
    }

    /**
     * Performs the actual model training using appropriate ML library.
     *
     * @param modelType    The type of model to train
     * @param trainingData The training dataset
     * @param parameters   The training parameters
     * @return The trained model object
     * @throws Exception if training fails
     */
    private Object performModelTraining(String modelType, List<Map<String, Object>> trainingData, Map<String, Object> parameters) throws Exception {
        return switch (modelType.toUpperCase()) {
            case VerbaMetricsConstants.K_SENTIMENT -> trainSentimentModel(trainingData, parameters);
            case VerbaMetricsConstants.K_CLASSIFICATION -> trainClassificationModel(trainingData, parameters);
            case VerbaMetricsConstants.K_TOPIC_MODELING ->
                    throw new UnsupportedOperationException("Topic modeling not implemented yet");
            default -> throw new IllegalArgumentException("Unsupported model type: " + modelType);
        };
    }

    /**
     * Validates the training data for the specified model type to ensure it meets requirements.
     *
     * @param trainingData The training data as a list of maps
     * @param modelType    The type of model to train
     * @return An optional error message if validation fails (empty if valid)
     */
    public Optional<String> validateTrainingDataError(List<Map<String, Object>> trainingData, String modelType) {
        // Normalize model type once and orchestrate granular checks
        String normalizedType = modelType == null ? "" : modelType.toUpperCase();
        Optional<String> err = checkPreflight(trainingData, normalizedType);
        if (err.isPresent()) {
            return err;
        }
        List<String> required = modelTypeClassifier.getRequiredFields(normalizedType);
        // Determine if this is a classification-like model that needs label/feature validation
        boolean isClassificationModel = VerbaMetricsConstants.K_CLASSIFICATION.equals(normalizedType)
                || VerbaMetricsConstants.K_SENTIMENT.equals(normalizedType);
        // Validate data records: for classification models, check required fields, labels, and features in one pass
        if (isClassificationModel) {
            return validateClassificationData(trainingData, required);
        }
        // For non-classification models, only check required fields if any
        if (!required.isEmpty()) {
            return checkRequiredFields(trainingData, required);
        }
        return Optional.empty();
    }

    /**
     * Basic preflight checks that validate dataset presence/size and model type validity.
     *
     * @param trainingData   The training dataset
     * @param normalizedType The normalized model type
     * @return An optional error message if preflight checks fail
     */
    private Optional<String> checkPreflight(List<Map<String, Object>> trainingData, String normalizedType) {
        if (trainingData == null || trainingData.isEmpty()) {
            return Optional.of("Training data is empty");
        }
        int minDataSize = properties.getTrainingLimits().getOrDefault("min-data-size", 10);
        if (trainingData.size() < minDataSize) {
            return Optional.of("Need at least " + minDataSize + " records, got " + trainingData.size());
        }
        if (!modelTypeClassifier.isValidModelType(normalizedType)) {
            return Optional.of("Unsupported model type: " + normalizedType + ". Supported: "
                    + String.join(", ", getSupportedModelTypes()));
        }
        return Optional.empty();
    }

    /**
     * Checks for required fields in the training data.
     *
     * @param trainingData The training dataset
     * @param required     The list of required fields (non-empty)
     * @return An optional error message if required fields are missing
     */
    private Optional<String> checkRequiredFields(List<Map<String, Object>> trainingData, List<String> required) {
        for (int i = 0; i < trainingData.size(); i++) {
            Map<String, Object> record = trainingData.get(i);
            for (String field : required) {
                Object value = record.get(field);
                if (value == null || (value instanceof String s && s.isBlank())) {
                    return Optional.of("Record #" + (i + 1) + " missing required field '" + field + "'");
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Validates classification model data (required fields, labels, and features) in a single pass.
     *
     * @param trainingData The training dataset
     * @param required     The list of required fields (maybe empty)
     * @return An optional error message if validation fails
     */
    private Optional<String> validateClassificationData(List<Map<String, Object>> trainingData, List<String> required) {
        Set<String> classes = new LinkedHashSet<>();
        int expectedFeatureLength = -1;
        for (int i = 0; i < trainingData.size(); i++) {
            Map<String, Object> record = trainingData.get(i);
            int recordNum = i + 1;
            if (!required.isEmpty()) {
                for (String field : required) {
                    Object value = record.get(field);
                    if (value == null || (value instanceof String s && s.isBlank())) {
                        return Optional.of("Record #" + recordNum + " missing required field '" + field + "'");
                    }
                }
            }
            Object label = record.get("label");
            if (label == null) {
                return Optional.of("Record #" + recordNum + " missing label value");
            }
            if (label instanceof String labelStr && !labelStr.isBlank()) {
                classes.add(labelStr);
            } else {
                return Optional.of("Record #" + recordNum + " has invalid label value");
            }
            Object features = record.get("features");
            if (features == null) {
                return Optional.of("Record #" + recordNum + " missing features");
            }
            int featureLength = switch (features) {
                case double[] arr -> arr.length;
                case List<?> list -> {
                    boolean allNumbers = list.stream().allMatch(o -> o instanceof Number);
                    yield allNumbers ? list.size() : -1;
                }
                default -> -1;
            };
            if (featureLength == -1) {
                return Optional.of("Record #" + recordNum + " features must be an array of numbers");
            }
            // Check feature length consistency
            if (expectedFeatureLength == -1) {
                expectedFeatureLength = featureLength;
            } else if (featureLength != expectedFeatureLength) {
                return Optional.of("All feature vectors must be of the same length: expected " + expectedFeatureLength
                        + "elements, but record #" + recordNum + " has " + featureLength + " elements");
            }
        }
        // Validate class distribution
        if (classes.isEmpty()) {
            return Optional.of("No valid label values found in training data");
        }
        if (classes.size() == 1) {
            return Optional.of("At least 2 distinct label classes required; found " + classes);
        }
        return Optional.empty();
    }

    /**
     * Gets the list of supported model types.
     *
     * @return The list of supported model types
     */
    public List<String> getSupportedModelTypes() {
        return properties.getSupportedModelTypes();
    }

    /**
     * Gets the default parameters for a given model type.
     *
     * @param modelType The type of model
     * @return The default parameters as a map
     */
    public Map<String, Object> getDefaultParameters(String modelType) {
        Object defaultParams = properties.getDefaultParameters().get(modelType);
        if (defaultParams instanceof Map<?, ?> map) {
            Map<String, Object> params = new HashMap<>();
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                if (entry.getKey() instanceof String) {
                    params.put((String) entry.getKey(), entry.getValue());
                } else {
                    return Map.of();
                }
            }
            return params;
        }
        return Map.of();
    }

    /**
     * Trains a sentiment analysis model using Weka RandomTree.
     * Uses only numeric features for training, ignoring text attributes.
     *
     * @param trainingData The training dataset
     * @param parameters   The training parameters
     * @return The trained sentiment model
     * @throws Exception if training fails
     */
    private Object trainSentimentModel(List<Map<String, Object>> trainingData, Map<String, Object> parameters) throws Exception {
        Instances dataset = createWekaDataset(trainingData);
        Instances numericDataset = new Instances(dataset);
        // Remove text attribute (index 0) as RandomTree works only with numeric features
        numericDataset.deleteAttributeAt(0);
        RandomTree model = new RandomTree();
        // Add dataset size to parameters for adaptive configuration
        Map<String, Object> adaptiveParams = new HashMap<>(parameters);
        adaptiveParams.put("datasetSize", trainingData.size());
        configureRandomTreeModel(model, adaptiveParams);
        model.buildClassifier(numericDataset);
        return model;
    }

    /**
     * Trains a general classification model using Weka library.
     * Uses only numeric features for training, ignoring text attributes.
     *
     * @param trainingData The training dataset
     * @param parameters   The training parameters
     * @return The trained classification model
     * @throws Exception if training fails
     */
    private Object trainClassificationModel(List<Map<String, Object>> trainingData, Map<String, Object> parameters) throws Exception {
        Instances dataset = createWekaDataset(trainingData);
        Instances numericDataset = new Instances(dataset);
        // Remove text attribute (index 0) as RandomTree works only with numeric features
        numericDataset.deleteAttributeAt(0);
        RandomTree classifier = new RandomTree();
        // Add dataset size to parameters for adaptive configuration
        Map<String, Object> adaptiveParams = new HashMap<>(parameters);
        adaptiveParams.put("datasetSize", trainingData.size());
        configureRandomTreeModel(classifier, adaptiveParams);
        classifier.buildClassifier(numericDataset);
        return classifier;
    }

    /**
     * Creates a Weka dataset from training data.
     *
     * @param trainingData The training dataset
     * @return The Weka Instances object
     */
    private Instances createWekaDataset(List<Map<String, Object>> trainingData) {
        return WekaDatasetUtils.createDataset(trainingData, "ClassificationDataset");
    }

    /**
     * Stores class values for a model based on training data.
     *
     * @param modelId      The model ID
     * @param trainingData The training data
     */
    private void storeClassValuesForModel(String modelId, List<Map<String, Object>> trainingData) {
        try {
            Set<String> uniqueClasses = new LinkedHashSet<>();
            for (Map<String, Object> dataPoint : trainingData) {
                Object label = dataPoint.get("label");
                if (label instanceof String) {
                    uniqueClasses.add((String) label);
                }
            }
            List<String> classValues = new ArrayList<>(uniqueClasses);
            classValueManager.storeClassValues(modelId, classValues);
            LOGGER.debug("Stored class values for model {} ({} classes)", modelId, classValues.size());
        } catch (Exception e) {
            LOGGER.warn("Failed to store class values for model {}: {}", modelId, e.getMessage());
        }
    }

    /**
     * Calculates performance metrics for the trained model.
     *
     * @param model        The trained model
     * @param trainingData The training dataset
     * @param modelType    The type of the model
     * @return A map of performance metrics
     */
    private Map<String, Object> calculatePerformanceMetrics(Object model, List<Map<String, Object>> trainingData, String modelType) {
        double accuracy = calculateModelAccuracy(model, trainingData);
        double precision = calculatePrecision(model, trainingData);
        double recall = calculateRecall(model, trainingData);
        double f1Score = MetricsCalculationUtils.calculateF1Score(precision, recall);
        return Map.of(
                "accuracy", accuracy,
                "precision", precision,
                "recall", recall,
                "f1Score", f1Score,
                "modelType", modelType,
                "trainingSamples", trainingData.size()
        );
    }

    /**
     * Configures a RandomTree model with parameters from configuration.
     * Uses adaptive settings based on dataset size to prevent overfitting.
     *
     * @param model      The RandomTree model to configure
     * @param parameters The training parameters
     */
    private void configureRandomTreeModel(RandomTree model, Map<String, Object> parameters) {
        try {
            // Adaptive defaults based on dataset size
            int datasetSize = parameters.containsKey("datasetSize") ? (Integer) parameters.get("datasetSize") : 100;
            // Adjust depth based on dataset size
            int maxDepth = datasetSize < 20 ? 3 : (datasetSize < 50 ? 5 : 8);
            int minNum = datasetSize < 20 ? 2 : 1;
            model.setMaxDepth(maxDepth);
            model.setMinNum(minNum);
            model.setSeed(42);
            if (parameters.containsKey(VerbaMetricsConstants.PARAM_MAX_DEPTH)) {
                int paramMaxDepth = (Integer) parameters.get(VerbaMetricsConstants.PARAM_MAX_DEPTH);
                if (paramMaxDepth > 0) {
                    model.setMaxDepth(Math.min(paramMaxDepth, maxDepth));
                }
            }
            if (parameters.containsKey(VerbaMetricsConstants.PARAM_MIN_SAMPLES_SPLIT)) {
                int minSamplesSplit = (Integer) parameters.get(VerbaMetricsConstants.PARAM_MIN_SAMPLES_SPLIT);
                if (minSamplesSplit > 0) {
                    model.setMinNum(Math.max(minSamplesSplit, minNum));
                }
            }
            if (parameters.containsKey(VerbaMetricsConstants.PARAM_MIN_SAMPLES_LEAF)) {
                int minSamplesLeaf = (Integer) parameters.get(VerbaMetricsConstants.PARAM_MIN_SAMPLES_LEAF);
                if (minSamplesLeaf > 0) {
                    model.setMinVarianceProp(minSamplesLeaf / 100.0);
                }
            }
            if (parameters.containsKey(VerbaMetricsConstants.PARAM_RANDOM_STATE)) {
                int randomState = (Integer) parameters.get(VerbaMetricsConstants.PARAM_RANDOM_STATE);
                model.setSeed(randomState);
            }
        } catch (Exception e) {
            LOGGER.warn("Failed to configure RandomTree model, using defaults", e);
        }
    }

    /**
     * Prepares a numeric dataset from training data by removing the text attribute.
     * This is used for cross-validation as RandomTree works only with numeric features.
     *
     * @param trainingData The training dataset
     * @return Numeric Instances object ready for cross-validation
     */
    private Instances prepareNumericDataset(List<Map<String, Object>> trainingData) {
        Instances dataset = createWekaDataset(trainingData);
        Instances numericDataset = new Instances(dataset);
        // Remove text attribute (index 0) as RandomTree works only with numeric features
        numericDataset.deleteAttributeAt(0);
        return numericDataset;
    }

    /**
     * Calculates model accuracy using Weka's cross-validation.
     *
     * @param model        The trained model
     * @param trainingData The training dataset
     * @return Real accuracy using cross-validation
     */
    private double calculateModelAccuracy(Object model, List<Map<String, Object>> trainingData) {
        if (model instanceof Classifier) {
            try {
                Instances numericDataset = prepareNumericDataset(trainingData);
                weka.classifiers.Evaluation evaluation = new weka.classifiers.Evaluation(numericDataset);
                evaluation.crossValidateModel((Classifier) model, numericDataset, CROSS_VALIDATION_FOLDS, new Random(CROSS_VALIDATION_SEED));
                return evaluation.pctCorrect() / 100.0;
            } catch (Exception e) {
                LOGGER.warn("Failed to calculate model accuracy with cross-validation", e);
                Double fallbackAccuracy = properties.getPerformanceThresholds().get("min-accuracy");
                return fallbackAccuracy != null ? fallbackAccuracy : 0.6;
            }
        }
        Double fallbackAccuracy = properties.getPerformanceThresholds().get("min-accuracy");
        return fallbackAccuracy != null ? fallbackAccuracy : 0.6;
    }

    /**
     * Calculates precision using Weka's evaluation.
     *
     * @param model        The trained model
     * @param trainingData The training dataset
     * @return Real precision value
     */
    private double calculatePrecision(Object model, List<Map<String, Object>> trainingData) {
        if (model instanceof Classifier) {
            try {
                Instances numericDataset = prepareNumericDataset(trainingData);
                weka.classifiers.Evaluation evaluation = new weka.classifiers.Evaluation(numericDataset);
                evaluation.crossValidateModel((Classifier) model, numericDataset, CROSS_VALIDATION_FOLDS, new Random(CROSS_VALIDATION_SEED));
                // Use weighted precision (across classes) rather than precision for class index 0
                return evaluation.weightedPrecision();
            } catch (Exception e) {
                LOGGER.warn("Failed to calculate precision", e);
                Double fallbackPrecision = properties.getPerformanceThresholds().get("min-precision");
                return fallbackPrecision != null ? fallbackPrecision : 0.6;
            }
        }
        Double fallbackPrecision = properties.getPerformanceThresholds().get("min-precision");
        return fallbackPrecision != null ? fallbackPrecision : 0.6;
    }

    /**
     * Calculates recall using Weka's evaluation.
     *
     * @param model        The trained model
     * @param trainingData The training dataset
     * @return Real recall value
     */
    private double calculateRecall(Object model, List<Map<String, Object>> trainingData) {
        if (model instanceof Classifier) {
            try {
                Instances numericDataset = prepareNumericDataset(trainingData);
                weka.classifiers.Evaluation evaluation = new weka.classifiers.Evaluation(numericDataset);
                evaluation.crossValidateModel((Classifier) model, numericDataset, CROSS_VALIDATION_FOLDS, new Random(CROSS_VALIDATION_SEED));
                // Use weighted recall (across classes) rather than recall for class index 0
                return evaluation.weightedRecall();
            } catch (Exception e) {
                LOGGER.warn("Failed to calculate recall", e);
                Double fallbackRecall = properties.getPerformanceThresholds().get("min-recall");
                return fallbackRecall != null ? fallbackRecall : 0.6;
            }
        }
        Double fallbackRecall = properties.getPerformanceThresholds().get("min-recall");
        return fallbackRecall != null ? fallbackRecall : 0.6;
    }

}