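// PotatoPlan/Game1/Sources/ML_Joel/Model.cs (C#)
// NOTE(review): this copy was taken from a repository web viewer's blame page (114 lines, 4.7 KiB);
// viewer chrome and commit timestamps such as "2020-05-23 20:39:05 +02:00" were mixed into the text.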
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;
namespace Game1.Sources.ML_Joel
{
    /// <summary>
    /// Trains, evaluates, saves and loads the ML.NET LightGBM regression model that predicts
    /// the "Production" column from the "Season", "Crop" and "Rainfall" columns of the CSV at
    /// <see cref="path"/>.
    /// </summary>
    class Model
    {
        // Shared context. The fixed seed makes ML.NET's random operations (such as
        // TrainTestSplit) reproducible between runs.
        private static readonly MLContext mlContext = new MLContext(seed: 1);

        // NOTE(review): these absolute paths only exist on one developer machine; consider
        // resolving them relative to the game's Content/ML directory instead.
        private static readonly string path = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/Rainfall.csv";
        private static readonly string modelpath = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodel_Joel";
        private static readonly string report = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report_Joel";

        /// <summary>
        /// Loads the CSV at <see cref="path"/>, holds out 20% of the rows as a test set, trains the
        /// model on the remaining rows, writes test-set metrics to <see cref="report"/> and saves the
        /// trained model to <see cref="modelpath"/>.
        /// </summary>
        public static void CreateModel()
        {
            // The CSV has a header row, comma separators and may contain quoted fields.
            IDataView fullData = mlContext.Data.LoadFromTextFile<Input>(
                path: path,
                hasHeader: true,
                separatorChar: ',',
                allowQuoting: true,
                allowSparse: false);

            var splitData = mlContext.Data.TrainTestSplit(fullData, testFraction: 0.2);
            IDataView trainingDataView = splitData.TrainSet;
            IDataView testDataView = splitData.TestSet;

            // Only needed to satisfy BuildAndTrain's signature; throws if the training split is empty.
            Input sample = mlContext.Data.CreateEnumerable<Input>(trainingDataView, reuseRowObject: false).First();

            ITransformer trainedModel = BuildAndTrain(mlContext, trainingDataView, testDataView, sample, report);

            // The model is saved together with the schema of the rows it was trained on.
            SaveModel(mlContext, trainedModel, modelpath, trainingDataView.Schema);
        }

        /// <summary>
        /// Builds the featurization + LightGBM (DART booster) regression pipeline, fits it on
        /// <paramref name="trainingDataView"/>, and writes metrics computed on
        /// <paramref name="testDataView"/> to <paramref name="reportPath"/>.
        /// </summary>
        /// <param name="mLContext">Context used for every pipeline stage and for the evaluation.</param>
        /// <param name="trainingDataView">Rows used to fit the pipeline.</param>
        /// <param name="testDataView">Held-out rows, used only for the metrics report.</param>
        /// <param name="sample">Currently unused; kept for signature compatibility.</param>
        /// <param name="reportPath">File the evaluation metrics are written to (created or overwritten).</param>
        /// <returns>The fitted pipeline (featurizers + regression model).</returns>
        public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, IDataView testDataView, Input sample, string reportPath)
        {
            var options = new LightGbmRegressionTrainer.Options
            {
                MaximumBinCountPerFeature = 40,
                LearningRate = 0.00020,
                NumberOfIterations = 20000,
                NumberOfLeaves = 55,
                LabelColumnName = "Production",
                FeatureColumnName = "Features",
                Booster = new DartBooster.Options()
                {
                    MaximumTreeDepth = 10
                }
            };

            // Season and Crop are featurized as text, then concatenated with Rainfall into the
            // "Features" vector. Every stage uses the caller-supplied context.
            var pipeline = mLContext.Transforms.Text.FeaturizeText("SeasonF", "Season")
                .Append(mLContext.Transforms.Text.FeaturizeText("CropF", "Crop"))
                .Append(mLContext.Transforms.Concatenate("Features", "SeasonF", "CropF", "Rainfall"))
                // Cache the featurized rows so the trainer's multiple passes reuse them.
                .AppendCacheCheckpoint(mLContext)
                .Append(mLContext.Regression.Trainers.LightGbm(options));

            ITransformer trainedModel = pipeline.Fit(trainingDataView);

            // Score the held-out rows and report how well the model generalizes.
            IDataView testPredictions = trainedModel.Transform(testDataView);
            Evaluate(mLContext, testPredictions, pipeline, 10, reportPath, "Production");

            return trainedModel;
        }

        /// <summary>
        /// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.
        /// Not referenced elsewhere in this file.
        /// </summary>
        /// <param name="mlContext">Unused.</param>
        /// <param name="trainingDataView">Rows to fit on.</param>
        /// <param name="trainingPipeline">Estimator to fit.</param>
        /// <returns>The fitted transformer.</returns>
        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            ITransformer model = trainingPipeline.Fit(trainingDataView);
            return model;
        }

        /// <summary>
        /// Computes regression metrics on already-scored data and writes MAE, MSE and R² to a text file.
        /// </summary>
        /// <param name="mlContext">Context used to run the evaluation.</param>
        /// <param name="trainingDataView">
        /// Scored data, i.e. the output of a fitted model's Transform. Despite the name, the only
        /// caller in this file passes predictions on the held-out test set.
        /// </param>
        /// <param name="trainingPipeline">Unused; kept for signature compatibility.</param>
        /// <param name="folds">Unused; kept for signature compatibility.</param>
        /// <param name="reportPath">File to create or overwrite with the metrics.</param>
        /// <param name="labelColumnName">Name of the label column in the scored data.</param>
        public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
        {
            var metrics = mlContext.Regression.Evaluate(trainingDataView, labelColumnName: labelColumnName);

            // The using block disposes (flushes and closes) the writer even if writing throws.
            using (StreamWriter writer = File.CreateText(reportPath))
            {
                // Plain Write(string). Write(format, arg0, arg1) would route this text through
                // string.Format. Numbers are formatted with the current culture, as in string concatenation.
                writer.Write(
                    "Mean Absolute Error: " + metrics.MeanAbsoluteError + '\n' +
                    "Mean Squared Error: " + metrics.MeanSquaredError + '\n' +
                    "R Squared: " + metrics.RSquared);
            }
        }

        /// <summary>Saves <paramref name="Model"/> and its input schema to <paramref name="modelPath"/>.</summary>
        /// <param name="mlContext">Context used to save the model.</param>
        /// <param name="Model">Fitted transformer to save.</param>
        /// <param name="modelPath">Destination file.</param>
        /// <param name="modelInputSchema">Schema of the data the model expects as input.</param>
        public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
        {
            mlContext.Model.Save(Model, modelInputSchema, modelPath);
        }

        /// <summary>Loads the model previously saved to <see cref="modelpath"/>.</summary>
        /// <param name="isBig">Ignored: there is currently only one model file.</param>
        /// <returns>The loaded transformer; the stored input schema is discarded.</returns>
        public static ITransformer LoadModel(bool isBig)
        {
            return mlContext.Model.Load(modelpath, out _);
        }

        /// <summary>
        /// Loads the saved model and wraps it in a prediction engine for single-row predictions.
        /// The model file is re-read on every call, so callers should keep the returned engine.
        /// ML.NET documents PredictionEngine as not thread-safe.
        /// </summary>
        /// <returns>A prediction engine mapping <see cref="Input"/> rows to <see cref="Output"/>.</returns>
        public static Microsoft.ML.PredictionEngine<Input, Output> CreateEngine()
        {
            ITransformer mlModel = LoadModel(false);
            return mlContext.Model.CreatePredictionEngine<Input, Output>(mlModel);
        }
    }
}