ClassValueManager.java

package com.kapil.verbametrics.ml.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Manages class values for ML models to ensure consistent mapping between training and prediction.
 *
 * @author Kapil Garg
 */
@Component
public class ClassValueManager {

    private static final Logger LOGGER = LoggerFactory.getLogger(ClassValueManager.class);

    private final Map<String, List<String>> modelClassValues = new ConcurrentHashMap<>();

    /**
     * Stores class values for a specific model.
     *
     * @param modelId     the model ID
     * @param classValues the class values in the order they were used during training
     */
    public void storeClassValues(String modelId, List<String> classValues) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        Objects.requireNonNull(classValues, "Class values cannot be null");
        List<String> storedValues = new ArrayList<>(classValues);
        modelClassValues.put(modelId, storedValues);
        LOGGER.debug("Stored class values for model {} ({} classes)", modelId, storedValues.size());
    }

    /**
     * Retrieves class values for a specific model.
     *
     * @param modelId the model ID
     * @return the class values for the model, or empty list if not found
     */
    public List<String> getClassValues(String modelId) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        List<String> classValues = modelClassValues.get(modelId);
        if (classValues == null) {
            LOGGER.warn("No class values found for model: {}", modelId);
            return new ArrayList<>();
        }
        return new ArrayList<>(classValues);
    }

}