// File: PotatoPlan/Game1/Sources/ML_Joel/Model.cs
// (134 lines, 6.3 KiB, C#; copied from a repository web view: Raw | Normal View | History)
// Revision timestamp: 2020-05-23 20:39:05 +02:00
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;
namespace Game1.Sources.ML_Joel
{
    /// <summary>
    /// Trains, saves and loads an ML.NET LightGBM regression model whose label is the
    /// "Production" column and whose features come from "Season", "Crop" and "Rainfall".
    /// </summary>
    class Model
    {
        // Fixed seed so random operations that derive from the context (e.g. the
        // train/test split below) are repeatable across runs.
        private static readonly MLContext mlContext = new MLContext(seed: 1);

        // NOTE(review): absolute, machine-specific paths; they only resolve on the
        // developer machine they were written for.
        private static readonly string path = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/Rainfall.csv";
        private static readonly string modelpath = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodel_Joel";
        private static readonly string report = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report_Joel";

        /// <summary>
        /// Loads the CSV at <c>path</c>, splits it 80/20 into training and test data,
        /// trains the model (writing test metrics to <c>report</c>) and saves it to <c>modelpath</c>.
        /// </summary>
        public static void CreateModel()
        {
            IDataView fullDataView = mlContext.Data.LoadFromTextFile<Input>(
                path: path,
                hasHeader: true,
                separatorChar: ',',
                allowQuoting: true,
                allowSparse: false);

            // Hold out 20 % of the rows for evaluation.
            var splitData = mlContext.Data.TrainTestSplit(fullDataView, testFraction: 0.2);
            IDataView trainingDataView = splitData.TrainSet;
            IDataView testDataView = splitData.TestSet;

            // BuildAndTrain's signature requires a sample row; it is not used for training.
            Input sample = mlContext.Data.CreateEnumerable<Input>(trainingDataView, reuseRowObject: false).ElementAt(0);

            ITransformer mlModel = BuildAndTrain(mlContext, trainingDataView, testDataView, sample, report);
            SaveModel(mlContext, mlModel, modelpath, trainingDataView.Schema);
        }

        /// <summary>
        /// Builds the featurization + LightGBM regression pipeline, fits it on
        /// <paramref name="trainingDataView"/> and evaluates it on <paramref name="testDataView"/>.
        /// </summary>
        /// <param name="mLContext">Context used for every pipeline component and for evaluation.</param>
        /// <param name="trainingDataView">Rows used to fit the pipeline.</param>
        /// <param name="testDataView">Held-out rows scored to compute regression metrics.</param>
        /// <param name="sample">Unused; kept for signature compatibility.</param>
        /// <param name="reportPath">File the test-set metrics are written to (overwritten); skipped when null or empty.</param>
        /// <returns>The fitted model.</returns>
        public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, IDataView testDataView, Input sample, string reportPath)
        {
            var options = new LightGbmRegressionTrainer.Options
            {
                MaximumBinCountPerFeature = 40,
                LearningRate = 0.00020,
                NumberOfIterations = 20000,
                NumberOfLeaves = 55,
                LabelColumnName = "Production",
                FeatureColumnName = "Features",
                Booster = new DartBooster.Options()
                {
                    MaximumTreeDepth = 10
                }
            };

            // "Season" and "Crop" are featurized as text, then concatenated with
            // "Rainfall" into the "Features" column the trainer reads.
            var pipeline = mLContext.Transforms.Text.FeaturizeText("SeasonF", "Season")
                .Append(mLContext.Transforms.Text.FeaturizeText("CropF", "Crop"))
                .Append(mLContext.Transforms.Concatenate("Features", "SeasonF", "CropF", "Rainfall"))
                .Append(mLContext.Regression.Trainers.LightGbm(options));

            ITransformer mlModel = pipeline.Fit(trainingDataView);

            IDataView scoredTestData = mlModel.Transform(testDataView);
            var metrics = mLContext.Regression.Evaluate(scoredTestData, labelColumnName: "Production");

            if (!string.IsNullOrEmpty(reportPath))
            {
                WriteReport(reportPath,
                    $"R-Squared: {metrics.RSquared:0.######}\n" +
                    $"Mean Absolute Error: {metrics.MeanAbsoluteError:0.######}\n" +
                    $"Root Mean Squared Error: {metrics.RootMeanSquaredError:0.######}");
            }

            return mlModel;
        }

        /// <summary>Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.</summary>
        /// <returns>The fitted model.</returns>
        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            ITransformer model = trainingPipeline.Fit(trainingDataView);
            return model;
        }

        /// <summary>
        /// Cross-validates <paramref name="trainingPipeline"/> as a multiclass classifier and writes
        /// the fold-averaged micro accuracy, log loss and log-loss reduction to <paramref name="reportPath"/>.
        /// </summary>
        /// <param name="folds">Number of cross-validation folds.</param>
        /// <param name="reportPath">File to create or overwrite with the averaged metrics.</param>
        /// <param name="labelColumnName">Label column passed to cross-validation.</param>
        public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
        {
            var foldMetrics = mlContext.MulticlassClassification
                .CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName)
                .Select(r => r.Metrics)
                .ToList();

            string microAccuracyAverage = foldMetrics.Average(m => m.MicroAccuracy).ToString("0.######");
            string logLossAverage = foldMetrics.Average(m => m.LogLoss).ToString("0.######");
            string logLossReductionAverage = foldMetrics.Average(m => m.LogLossReduction).ToString("0.######");

            WriteReport(reportPath,
                $"Micro Accuracy: {microAccuracyAverage}\n" +
                $"LogLoss Average: {logLossAverage}\n" +
                $"LogLoss Reduction: {logLossReductionAverage}");
        }

        // Creates/overwrites reportPath with contents; the writer is disposed even if writing throws.
        private static void WriteReport(string reportPath, string contents)
        {
            using (StreamWriter writer = File.CreateText(reportPath))
            {
                writer.Write(contents);
            }
        }

        /// <summary>Saves <paramref name="Model"/> together with its input schema to <paramref name="modelPath"/>.</summary>
        public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
        {
            mlContext.Model.Save(Model, modelInputSchema, modelPath);
        }

        /// <summary>Loads the saved model from <c>modelpath</c>.</summary>
        /// <param name="isBig">Unused; the same model file is always loaded.</param>
        public static ITransformer LoadModel(bool isBig)
        {
            return mlContext.Model.Load(modelpath, out _);
        }

        /// <summary>Loads the saved model and wraps it in a prediction engine for single-row predictions.</summary>
        public static Microsoft.ML.PredictionEngine<Input, Output> CreateEngine()
        {
            ITransformer mlModel = LoadModel(false);
            return mlContext.Model.CreatePredictionEngine<Input, Output>(mlModel);
        }
    }
}