using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;

namespace Game1.Sources.ML_Joel
{
    /// <summary>
    /// Loads the fertilizer CSV dataset, cross-validates and trains a LightGBM multiclass
    /// classifier that predicts <c>Fertilizer_Name</c>, and saves/loads the resulting
    /// ML.NET model using fixed file paths.
    /// </summary>
    class Model
    {
        // Context constructed with a fixed seed (seed: 1) so ML.NET's random operations are repeatable.
        private static readonly MLContext mlContext = new MLContext(seed: 1);

        // NOTE(review): absolute, user-specific paths; these only resolve on the original author's machine.
        private static readonly string path = "C:/Users/Joel/source/repos/Oskars Repo/Game1/Content/ML/Rainfall.csv";
        private static readonly string modelpath = "C:/Users/Joel/source/repos/Oskars Repo/Game1/Content/ML/MLmodel_Joel";
        private static readonly string report = "C:/Users/Joel/source/repos/Oskars Repo/Game1/Content/ML/report_Joel";

        /// <summary>
        /// Loads the training data from the CSV at <c>path</c>, trains the model (writing a
        /// cross-validation report to <c>report</c>) and saves it to <c>modelpath</c>.
        /// Per the original author this is the smaller (~100 row) dataset.
        /// </summary>
        public static void CreateModel()
        {
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
                path: path,
                hasHeader: true,
                separatorChar: ',',
                allowQuoting: true,
                allowSparse: false);

            // First data row; ElementAt(0) throws if the file contains no data rows.
            ModelInput sample = mlContext.Data.CreateEnumerable<ModelInput>(trainingDataView, false).ElementAt(0);
            ITransformer trainedModel = BuildAndTrain(mlContext, trainingDataView, sample, report);
            SaveModel(mlContext, trainedModel, modelpath, trainingDataView.Schema);
        }

        /// <summary>
        /// Builds the featurization + LightGBM (DART booster) pipeline, writes a 10-fold
        /// cross-validation report to <paramref name="reportPath"/>, then fits the pipeline
        /// on all of <paramref name="trainingDataView"/>.
        /// </summary>
        /// <param name="mLContext">Context used for every transform, trainer and the evaluation.</param>
        /// <param name="trainingDataView">Data to cross-validate and train on.</param>
        /// <param name="sample">Currently unused; kept for signature compatibility.</param>
        /// <param name="reportPath">File the averaged cross-validation metrics are written to (overwritten).</param>
        /// <returns>The fitted model; its <c>PredictedLabel</c> column is mapped back from a key to the original label value.</returns>
        public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, ModelInput sample, string reportPath)
        {
            var options = new LightGbmMulticlassTrainer.Options
            {
                MaximumBinCountPerFeature = 8,
                LearningRate = 0.00025,
                NumberOfIterations = 40000,
                NumberOfLeaves = 10,
                LabelColumnName = "Fertilizer_NameF",
                FeatureColumnName = "Features",

                Booster = new DartBooster.Options()
                {
                    MaximumTreeDepth = 10
                }
            };

            // Every stage uses the caller-supplied context (previously most stages used the static field).
            var pipeline = mLContext.Transforms
                    .Text.FeaturizeText("Soil_TypeF", "Soil_Type")
                    .Append(mLContext.Transforms.Text.FeaturizeText("Crop_TypeF", "Crop_Type"))
                    .Append(mLContext.Transforms.Concatenate("Features", "Temperature", "Humidity", "Moisture", "Soil_TypeF", "Crop_TypeF", "Nitrogen", "Potassium", "Phosphorous"))
                    // Scoped to TrainTest; presumably so scoring inputs do not need a label column — confirm.
                    .Append(mLContext.Transforms.Conversion.MapValueToKey("Fertilizer_NameF", "Fertilizer_Name"), TransformerScope.TrainTest)
                    .AppendCacheCheckpoint(mLContext)
                    .Append(mLContext.MulticlassClassification.Trainers.LightGbm(options))
                    // Convert the predicted key back to the original label value.
                    .Append(mLContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

            Evaluate(mLContext, trainingDataView, pipeline, 10, reportPath, "Fertilizer_NameF");
            return pipeline.Fit(trainingDataView);
        }

        /// <summary>
        /// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.
        /// </summary>
        /// <param name="mlContext">Currently unused.</param>
        /// <returns>The fitted transformer.</returns>
        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            ITransformer model = trainingPipeline.Fit(trainingDataView);
            return model;
        }

        /// <summary>
        /// Runs k-fold cross-validation of <paramref name="trainingPipeline"/> and writes the
        /// fold-averaged micro accuracy, log-loss and log-loss reduction to <paramref name="reportPath"/>.
        /// </summary>
        /// <param name="folds">Number of cross-validation folds.</param>
        /// <param name="reportPath">Report file; created or overwritten.</param>
        /// <param name="labelColumnName">Name of the (key-typed) label column used for evaluation.</param>
        public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
        {
            var crossVal = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName);

            // Materialize once; the metrics are read three times below.
            var foldMetrics = crossVal.Select(r => r.Metrics).ToList();

            string microAccuracyAverage = foldMetrics.Average(m => m.MicroAccuracy).ToString("0.######");
            string logLossAverage = foldMetrics.Average(m => m.LogLoss).ToString("0.######");
            string logLossReductionAverage = foldMetrics.Average(m => m.LogLossReduction).ToString("0.######");

            // 'using' releases the file handle even if a write throws. Plain Write(string) avoids
            // the format-string overload, which would throw on any '{' in the text.
            using (StreamWriter writer = File.CreateText(reportPath))
            {
                writer.Write("Micro Accuracy: " + microAccuracyAverage + '\n' +
                             "LogLoss Average: " + logLossAverage + '\n' +
                             "LogLoss Reduction: " + logLossReductionAverage);
            }
        }

        /// <summary>
        /// Saves <paramref name="Model"/> together with its input schema to <paramref name="modelPath"/>.
        /// </summary>
        public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
        {
            mlContext.Model.Save(Model, modelInputSchema, modelPath);
        }

        /// <summary>
        /// Loads the saved model from <c>modelpath</c>.
        /// </summary>
        /// <param name="isBig">Currently ignored; there is only one model path.</param>
        /// <returns>The loaded transformer.</returns>
        public static ITransformer LoadModel(bool isBig)
        {
            // The stored input schema is not needed here, so discard it.
            return mlContext.Model.Load(modelpath, out _);
        }

        /// <summary>
        /// Loads the saved model and creates a prediction engine for single-row predictions.
        /// The caller owns the returned engine.
        /// </summary>
        public static Microsoft.ML.PredictionEngine<ModelInput, ModelOutput> CreateEngine()
        {
            ITransformer loadedModel = LoadModel(false);
            return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(loadedModel);
        }

    }
}