MLModelServiceImpl.java

package com.kapil.verbametrics.ml.services.impl;

import com.kapil.verbametrics.ml.domain.MLModel;
import com.kapil.verbametrics.ml.domain.ModelEvaluationResult;
import com.kapil.verbametrics.ml.domain.ModelTrainingResult;
import com.kapil.verbametrics.ml.engines.ModelEvaluationEngine;
import com.kapil.verbametrics.ml.engines.ModelPredictionEngine;
import com.kapil.verbametrics.ml.entities.MLModelEntity;
import com.kapil.verbametrics.ml.mapper.MLModelMapper;
import com.kapil.verbametrics.ml.repository.MLModelRepository;
import com.kapil.verbametrics.ml.services.MLModelService;
import com.kapil.verbametrics.ml.services.ModelTrainingService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Service implementation for machine learning model operations using Spring Boot.
 * Handles model training, evaluation, and prediction for text analysis.
 *
 * @author Kapil Garg
 */
@Service
@Transactional
public class MLModelServiceImpl implements MLModelService {

    private static final Logger LOGGER = LoggerFactory.getLogger(MLModelServiceImpl.class);

    private final MLModelMapper modelMapper;
    private final MLModelRepository modelRepository;
    private final ModelTrainingService trainingService;
    private final ModelEvaluationEngine evaluationEngine;
    private final ModelPredictionEngine predictionEngine;

    @Autowired
    public MLModelServiceImpl(MLModelRepository modelRepository,
                              MLModelMapper modelMapper,
                              ModelTrainingService trainingService,
                              ModelEvaluationEngine evaluationEngine,
                              ModelPredictionEngine predictionEngine) {
        this.modelRepository = modelRepository;
        this.modelMapper = modelMapper;
        this.trainingService = trainingService;
        this.evaluationEngine = evaluationEngine;
        this.predictionEngine = predictionEngine;
    }

    @Override
    public ModelTrainingResult trainModel(String modelType, List<Map<String, Object>> trainingData,
                                          Map<String, Object> parameters) {
        Objects.requireNonNull(modelType, "Model type cannot be null");
        Objects.requireNonNull(trainingData, "Training data cannot be null");
        Objects.requireNonNull(parameters, "Parameters cannot be null");
        try {
            ModelTrainingResult result = trainingService.trainModel(modelType, trainingData, parameters);
            MLModel model = createMLModelFromResult(result, modelType, parameters);
            MLModelEntity entity = modelMapper.toEntity(model);
            modelRepository.save(entity);
            return result;
        } catch (Exception e) {
            LOGGER.error("Failed to train model", e);
            throw new RuntimeException("Model training failed: " + e.getMessage(), e);
        }
    }

    @Override
    public ModelEvaluationResult evaluateModel(String modelId, List<Map<String, Object>> testData) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        Objects.requireNonNull(testData, "Test data cannot be null");
        try {
            if (!modelRepository.existsById(modelId)) {
                throw new IllegalArgumentException("Model not found: " + modelId);
            }
            if (testData.isEmpty()) {
                throw new IllegalArgumentException("Test data cannot be empty");
            }
            return evaluationEngine.evaluateModel(modelId, testData);
        } catch (Exception e) {
            LOGGER.error("Failed to evaluate model", e);
            throw new RuntimeException("Model evaluation failed: " + e.getMessage(), e);
        }
    }

    @Override
    public Map<String, Object> predict(String modelId, Map<String, Object> input) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        Objects.requireNonNull(input, "Input cannot be null");
        try {
            MLModel model = getModel(modelId);
            if (!model.isReadyForUse()) {
                throw new IllegalStateException("Model is not ready for use: " + modelId);
            }
            return predictionEngine.predict(modelId, input);
        } catch (Exception e) {
            LOGGER.error("Failed to make prediction", e);
            throw new RuntimeException("Prediction failed: " + e.getMessage(), e);
        }
    }

    @Override
    public MLModel getModel(String modelId) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        MLModelEntity entity = modelRepository.findById(modelId)
                .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
        return modelMapper.toDomain(entity);
    }

    @Override
    public List<MLModel> listModels() {
        return modelRepository.findAll().stream()
                .map(modelMapper::toDomain)
                .collect(Collectors.toList());
    }

    @Override
    public boolean deleteModel(String modelId) {
        Objects.requireNonNull(modelId, "Model ID cannot be null");
        if (modelRepository.existsById(modelId)) {
            modelRepository.deleteById(modelId);
            return true;
        } else {
            LOGGER.warn("Model not found for deletion: {}", modelId);
            return false;
        }
    }

    /**
     * Creates an ML model from training result.
     *
     * @param result     The training result containing performance metrics
     * @param modelType  The type of the model
     * @param parameters The parameters used for training
     * @return The created MLModel instance
     */
    private MLModel createMLModelFromResult(ModelTrainingResult result, String modelType, Map<String, Object> parameters) {
        return new MLModel(
                result.modelId(),
                modelType,
                (String) parameters.getOrDefault("name", "Model_" + result.modelId()),
                (String) parameters.getOrDefault("description", "Trained " + modelType + " model"),
                "1.0",
                LocalDateTime.now(),
                LocalDateTime.now(),
                parameters,
                Map.of("accuracy", result.accuracy(), "f1Score", result.f1Score()),
                "/models/" + result.modelId(),
                true,
                "system",
                result.trainingDataSize(),
                result.accuracy(),
                "TRAINED"
        );
    }

}