using System;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;


/// <summary>
/// Builds, evaluates, saves and loads the LightGBM multiclass models used for
/// fertilizer prediction. Two variants exist: a small model trained on
/// <c>Fertilizer_Prediction.csv</c> and a "big" model trained on
/// <c>BigFertPredict.csv</c>.
/// </summary>
class MLModel
{
    // Single shared context with a fixed seed.
    private static readonly MLContext mlContext = new MLContext(seed: 1);

    // NOTE(review): these are absolute, machine-specific paths; they are kept
    // verbatim so existing behavior does not change.
    private static readonly string path = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/Fertilizer_Prediction.csv";
    private static readonly string modelpath = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodel";
    private static readonly string report = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report";
    private static readonly string pathBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/BigFertPredict.csv";
    private static readonly string modelpathBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodelBig";
    private static readonly string reportBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report_BigModel";

    /// <summary>
    /// Loads the smaller dataset (100 entries), trains and evaluates a model on it,
    /// and saves the trained model to <c>modelpath</c>.
    /// </summary>
    public static void CreateModel()
    {
        IDataView trainingDataView = LoadTrainingData<ModelInput>(path);

        // The sample row only selects the ModelInput overload of BuildAndTrain.
        ModelInput sample = mlContext.Data.CreateEnumerable<ModelInput>(trainingDataView, false).ElementAt(0);
        ITransformer trainedModel = BuildAndTrain(mlContext, trainingDataView, sample, report);
        SaveModel(mlContext, trainedModel, modelpath, trainingDataView.Schema);
    }

    /// <summary>
    /// Loads the bigger dataset (1600 entries), trains and evaluates a model on it,
    /// and saves the trained model to <c>modelpathBig</c>.
    /// </summary>
    public static void CreateBigModel()
    {
        IDataView trainingDataView = LoadTrainingData<BigModelInput>(pathBig);

        // The sample row only selects the BigModelInput overload of BuildAndTrain.
        BigModelInput sample = mlContext.Data.CreateEnumerable<BigModelInput>(trainingDataView, false).ElementAt(0);
        ITransformer trainedModel = BuildAndTrain(mlContext, trainingDataView, sample, reportBig);
        SaveModel(mlContext, trainedModel, modelpathBig, trainingDataView.Schema);
    }

    // Loads a comma-separated file with a header row into an IDataView typed by TInput.
    private static IDataView LoadTrainingData<TInput>(string csvPath)
        where TInput : class, new()
    {
        return mlContext.Data.LoadFromTextFile<TInput>(
            path: csvPath,
            hasHeader: true,
            separatorChar: ',',
            allowQuoting: true,
            allowSparse: false);
    }

    /// <summary>
    /// Builds the pipeline for the small dataset (100 entries), cross-validates it
    /// with 10 folds (writing the metrics to <paramref name="reportPath"/>) and then
    /// fits it on the full <paramref name="trainingDataView"/>.
    /// </summary>
    /// <param name="mLContext">Context used for every transform, the trainer and the evaluation.</param>
    /// <param name="trainingDataView">Data to cross-validate and train on.</param>
    /// <param name="sample">Not read; its type selects this overload.</param>
    /// <param name="reportPath">File that receives the cross-validation metrics.</param>
    /// <returns>The pipeline fitted on <paramref name="trainingDataView"/>.</returns>
    public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, ModelInput sample, string reportPath)
    {
        var options = new LightGbmMulticlassTrainer.Options
        {
            MaximumBinCountPerFeature = 8,
            LearningRate = 0.00025,
            NumberOfIterations = 40000,
            NumberOfLeaves = 10,
            LabelColumnName = "Fertilizer_NameF",
            FeatureColumnName = "Features",

            Booster = new DartBooster.Options()
            {
                MaximumTreeDepth = 10
            }
        };

        // The supplied context is used throughout (the original alternated between
        // the parameter and the static field, which differ only in letter case).
        var pipeline = mLContext.Transforms
                .Text.FeaturizeText("Soil_TypeF", "Soil_Type")
                .Append(mLContext.Transforms.Text.FeaturizeText("Crop_TypeF", "Crop_Type"))
                .Append(mLContext.Transforms.Concatenate("Features", "Temperature", "Humidity", "Moisture", "Soil_TypeF", "Crop_TypeF", "Nitrogen", "Potassium", "Phosphorous"))
                .Append(mLContext.Transforms.Conversion.MapValueToKey("Fertilizer_NameF", "Fertilizer_Name"), TransformerScope.TrainTest)
                .AppendCacheCheckpoint(mLContext)
                .Append(mLContext.MulticlassClassification.Trainers.LightGbm(options))
                .Append(mLContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        Evaluate(mLContext, trainingDataView, pipeline, 10, reportPath, "Fertilizer_NameF");

        return pipeline.Fit(trainingDataView);
    }

    /// <summary>
    /// Builds the pipeline for the moderate-size dataset (1600 entries),
    /// cross-validates it with 8 folds (writing the metrics to
    /// <paramref name="reportPath"/>) and then fits it on the full
    /// <paramref name="trainingDataView"/>.
    /// </summary>
    /// <param name="mLContext">Context used for every transform, the trainer and the evaluation.</param>
    /// <param name="trainingDataView">Data to cross-validate and train on.</param>
    /// <param name="sample">Not read; its type selects this overload.</param>
    /// <param name="reportPath">File that receives the cross-validation metrics.</param>
    /// <returns>The pipeline fitted on <paramref name="trainingDataView"/>.</returns>
    public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, BigModelInput sample, string reportPath)
    {
        var options = new LightGbmMulticlassTrainer.Options
        {
            MaximumBinCountPerFeature = 10,
            LearningRate = 0.001,
            NumberOfIterations = 10000,
            NumberOfLeaves = 12,
            LabelColumnName = "ClassF",
            FeatureColumnName = "Features",

            Booster = new DartBooster.Options()
            {
                MaximumTreeDepth = 12
            }
        };

        // The supplied context is used throughout (see the small-dataset overload).
        var pipeline = mLContext.Transforms
                .Concatenate("Features", "Ca", "Mg", "K", "S", "N", "Lime", "C", "P", "Moisture")
                .Append(mLContext.Transforms.NormalizeMinMax("Features"))
                .Append(mLContext.Transforms.Conversion.MapValueToKey("ClassF", "Class"), TransformerScope.TrainTest)
                .AppendCacheCheckpoint(mLContext)
                .Append(mLContext.MulticlassClassification.Trainers.LightGbm(options))
                .Append(mLContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        Evaluate(mLContext, trainingDataView, pipeline, 8, reportPath, "ClassF");

        return pipeline.Fit(trainingDataView);
    }

    /// <summary>
    /// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.
    /// </summary>
    /// <param name="mlContext">Not used by this method; kept for signature compatibility.</param>
    /// <param name="trainingDataView">Data to fit on.</param>
    /// <param name="trainingPipeline">Estimator to fit.</param>
    /// <returns>The fitted transformer.</returns>
    public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        return trainingPipeline.Fit(trainingDataView);
    }

    /// <summary>
    /// Cross-validates <paramref name="trainingPipeline"/> and writes the averaged
    /// micro accuracy, log-loss and log-loss reduction to <paramref name="reportPath"/>.
    /// </summary>
    /// <param name="mlContext">Context used to run the cross-validation.</param>
    /// <param name="trainingDataView">Data that is split into folds.</param>
    /// <param name="trainingPipeline">Pipeline that is trained once per fold.</param>
    /// <param name="folds">Number of cross-validation folds.</param>
    /// <param name="reportPath">File to create (or overwrite) with the metrics.</param>
    /// <param name="labelColumnName">Name of the label column used for evaluation.</param>
    /// <exception cref="ArgumentNullException"><paramref name="reportPath"/> is null.</exception>
    public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
    {
        // Fail before the cross-validation runs rather than after it, when the
        // report could no longer be written anyway.
        if (reportPath == null)
        {
            throw new ArgumentNullException(nameof(reportPath));
        }

        var crossVal = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName);

        // Materialize once; the metrics are averaged three times below.
        var foldMetrics = crossVal.Select(r => r.Metrics).ToList();

        string microAccuracyAverage = foldMetrics.Average(m => m.MicroAccuracy).ToString("0.######");
        string logLossAverage = foldMetrics.Average(m => m.LogLoss).ToString("0.######");
        string logLossReductionAverage = foldMetrics.Average(m => m.LogLossReduction).ToString("0.######");

        string reportText = "Micro Accuracy: " + microAccuracyAverage + '\n'
            + "LogLoss Average: " + logLossAverage + '\n'
            + "LogLoss Reduction: " + logLossReductionAverage;

        // 'using' closes the file even if the write throws. The text goes through
        // Write(string), not the composite-format overload, so it is never
        // interpreted as a format string.
        using (StreamWriter writer = File.CreateText(reportPath))
        {
            writer.Write(reportText);
        }
    }

    /// <summary>
    /// Saves <paramref name="Model"/> together with its input schema to <paramref name="modelPath"/>.
    /// </summary>
    /// <param name="mlContext">Context whose model catalog performs the save.</param>
    /// <param name="Model">Trained transformer to persist.</param>
    /// <param name="modelPath">Destination file path.</param>
    /// <param name="modelInputSchema">Schema of the data the model expects as input.</param>
    public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
    {
        mlContext.Model.Save(Model, modelInputSchema, modelPath);
    }

    /// <summary>
    /// Loads a previously saved model from disk.
    /// </summary>
    /// <param name="isBig"><c>true</c> loads from <c>modelpathBig</c>; <c>false</c> loads from <c>modelpath</c>.</param>
    /// <returns>The loaded transformer.</returns>
    public static ITransformer LoadModel(bool isBig)
    {
        string selectedPath = isBig ? modelpathBig : modelpath;

        // The input schema returned by Load is not needed here.
        return mlContext.Model.Load(selectedPath, out _);
    }

    /// <summary>
    /// Creates a prediction engine backed by the small model loaded from disk.
    /// </summary>
    /// <returns>An engine mapping <c>ModelInput</c> to <c>ModelOutput</c>.</returns>
    public static Microsoft.ML.PredictionEngine<ModelInput, ModelOutput> CreateEngine()
    {
        ITransformer mlModel = LoadModel(false);
        return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);
    }

}