1
0
forked from s425077/PotatoPlan
JoelForkTest/Game1/Sources/ML_Joel/Model.cs

191 lines
7.5 KiB
C#
Raw Normal View History

2020-05-23 20:39:05 +02:00
using System;
using System.Collections.Generic;
using System.IO;
2020-06-15 01:45:16 +02:00
using System.Reflection;
2020-05-23 20:39:05 +02:00
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;
namespace Game1.Sources.ML_Joel
{
class Model
{
// Shared ML.NET context, created with a fixed seed (1).
private static readonly MLContext mlContext = new MLContext(seed: 1);

// Directory returned by FindPath(); every path below is resolved relative to it.
// It must stay declared before the fields that use it, because static field
// initializers run in textual order.
private static readonly string path_base = FindPath();

// Smaller dataset: training CSV, saved model file and evaluation report file.
private static readonly string path = System.IO.Path.Combine(path_base, "Content/ML/Rainfall.csv");
private static readonly string modelpath = System.IO.Path.Combine(path_base, "Content/ML/MLmodel_Joel");
private static readonly string report = System.IO.Path.Combine(path_base, "Content/ML/report_Joel");

// "Area" dataset: training CSV, saved model file and evaluation report file.
private static readonly string path_area = System.IO.Path.Combine(path_base, "Content/ML/Rainfall_area.csv");
private static readonly string modelpath_area = System.IO.Path.Combine(path_base, "Content/ML/MLmodel_Joel_area");
private static readonly string report_area = System.IO.Path.Combine(path_base, "Content/ML/report_Joel_area");
2020-05-23 20:39:05 +02:00
/// <summary>
/// Loads the smaller rainfall dataset from CSV, holds back 20% of the rows for testing,
/// trains the regression model on the rest and saves the fitted model to disk.
/// </summary>
public static void CreateModel()
{
    var allRows = mlContext.Data.LoadFromTextFile<Input>(
        path: path,
        separatorChar: ',',
        hasHeader: true,
        allowQuoting: true,
        allowSparse: false);

    var split = mlContext.Data.TrainTestSplit(allRows, testFraction: 0.2);
    IDataView trainSet = split.TrainSet;
    IDataView testSet = split.TestSet;

    // First row of the training split; forwarded to BuildAndTrain as its sample argument.
    Input firstRow = mlContext.Data
        .CreateEnumerable<Input>(trainSet, reuseRowObject: false)
        .ElementAt(0);

    ITransformer trained = BuildAndTrain(mlContext, trainSet, testSet, firstRow, report);

    // The model is stored together with the schema of the training split.
    SaveModel(mlContext, trained, modelpath, trainSet.Schema);
}
2020-05-24 01:57:27 +02:00
/// <summary>
/// Builds the pipeline for the smaller dataset (text featurization of "Season" and "Crop",
/// concatenated with "Rainfall", followed by a LightGBM regression trainer predicting
/// "Production"), fits it on the training data and writes an evaluation report for the test data.
/// </summary>
/// <param name="mLContext">ML.NET context used for every transform, the trainer and the evaluation.</param>
/// <param name="trainingDataView">Rows the pipeline is fitted on.</param>
/// <param name="testDataView">Held-out rows that are scored by the fitted model and then evaluated.</param>
/// <param name="sample">Not used by this method; kept so that existing callers keep compiling.</param>
/// <param name="reportPath">Path of the text file the evaluation report is written to.</param>
/// <returns>The fitted model.</returns>
public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, IDataView testDataView, Input sample, string reportPath)
{
    var options = new LightGbmRegressionTrainer.Options
    {
        MaximumBinCountPerFeature = 40,
        LearningRate = 0.00020,
        NumberOfIterations = 20000,
        NumberOfLeaves = 55,
        LabelColumnName = "Production",
        FeatureColumnName = "Features",
        Booster = new DartBooster.Options()
        {
            MaximumTreeDepth = 10
        }
    };

    // Every stage is created from the context passed in by the caller. The previous version
    // mixed this parameter with the class-level static context, which only worked because
    // callers happened to pass that same instance.
    var pipeline = mLContext.Transforms
        .Text.FeaturizeText("SeasonF", "Season")
        .Append(mLContext.Transforms.Text.FeaturizeText("CropF", "Crop"))
        .Append(mLContext.Transforms.Concatenate("Features", "SeasonF", "CropF", "Rainfall"))
        .AppendCacheCheckpoint(mLContext)
        .Append(mLContext.Regression.Trainers.LightGbm(options));

    ITransformer MLModel = pipeline.Fit(trainingDataView);

    // Score the held-out rows with the fitted model; Evaluate reads the scored view.
    var testEval = MLModel.Transform(testDataView);
    Evaluate(mLContext, testEval, pipeline, 10, reportPath, "Production");

    return MLModel;
}
2020-05-24 21:00:54 +02:00
/// <summary>
/// Loads the "area" rainfall dataset from CSV, holds back 20% of the rows for testing,
/// trains the area regression model on the rest and saves the fitted model to disk.
/// </summary>
public static void CreateModelArea()
{
    var allRows = mlContext.Data.LoadFromTextFile<InputArea>(
        path: path_area,
        separatorChar: ',',
        hasHeader: true,
        allowQuoting: true,
        allowSparse: false);

    var split = mlContext.Data.TrainTestSplit(allRows, testFraction: 0.2);
    IDataView trainSet = split.TrainSet;
    IDataView testSet = split.TestSet;

    ITransformer trained = BuildAndTrainArea(mlContext, trainSet, testSet, report_area);

    // The model is stored together with the schema of the training split.
    SaveModel(mlContext, trained, modelpath_area, trainSet.Schema);
}
/// <summary>
/// Builds the pipeline for the "area" dataset (text featurization of "Season" and "Crop",
/// concatenated with "Area" and "Rainfall", followed by a LightGBM regression trainer
/// predicting "Production"), fits it on the training data and writes an evaluation report
/// for the test data.
/// </summary>
/// <param name="mLContext">ML.NET context used for every transform, the trainer and the evaluation.</param>
/// <param name="trainingDataView">Rows the pipeline is fitted on.</param>
/// <param name="testDataView">Held-out rows that are scored by the fitted model and then evaluated.</param>
/// <param name="reportPath">Path of the text file the evaluation report is written to.</param>
/// <returns>The fitted model.</returns>
public static ITransformer BuildAndTrainArea(MLContext mLContext, IDataView trainingDataView, IDataView testDataView, string reportPath)
{
    var options = new LightGbmRegressionTrainer.Options
    {
        MaximumBinCountPerFeature = 40,
        LearningRate = 0.00020,
        NumberOfIterations = 20000,
        NumberOfLeaves = 55,
        LabelColumnName = "Production",
        FeatureColumnName = "Features",
        Booster = new DartBooster.Options()
        {
            MaximumTreeDepth = 10
        }
    };

    // Every stage is created from the context passed in by the caller. The previous version
    // mixed this parameter with the class-level static context, which only worked because
    // callers happened to pass that same instance.
    var pipeline = mLContext.Transforms
        .Text.FeaturizeText("SeasonF", "Season")
        .Append(mLContext.Transforms.Text.FeaturizeText("CropF", "Crop"))
        .Append(mLContext.Transforms.Concatenate("Features", "SeasonF", "CropF", "Area", "Rainfall"))
        .AppendCacheCheckpoint(mLContext)
        .Append(mLContext.Regression.Trainers.LightGbm(options));

    ITransformer MLModel = pipeline.Fit(trainingDataView);

    // Score the held-out rows with the fitted model; Evaluate reads the scored view.
    var testEval = MLModel.Transform(testDataView);
    Evaluate(mLContext, testEval, pipeline, 10, reportPath, "Production");

    return MLModel;
}
2020-05-23 20:39:05 +02:00
/// <summary>
/// Fits the given pipeline on the given data and returns the fitted model.
/// </summary>
/// <param name="mlContext">Not used by this method.</param>
/// <param name="trainingDataView">Rows the pipeline is fitted on.</param>
/// <param name="trainingPipeline">Estimator chain to fit.</param>
/// <returns>The transformer produced by fitting the pipeline.</returns>
public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    => trainingPipeline.Fit(trainingDataView);
/// <summary>
/// Computes regression metrics for an already-scored data view and writes the mean absolute
/// error and R squared to a text file.
/// </summary>
/// <param name="mlContext">ML.NET context whose regression catalog performs the evaluation.</param>
/// <param name="trainingDataView">Scored data to evaluate (the callers in this class pass the transformed test set).</param>
/// <param name="trainingPipeline">Not used by this method.</param>
/// <param name="folds">Not used by this method.</param>
/// <param name="reportPath">Path of the report file; it is created, or overwritten if it already exists.</param>
/// <param name="labelColumnName">Name of the column holding the true label values.</param>
public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
{
    var eval = mlContext.Regression.Evaluate(trainingDataView, labelColumnName: labelColumnName);

    // The using block closes (and thereby flushes) the file even if writing throws.
    using (StreamWriter report = File.CreateText(reportPath))
    {
        // Write the text verbatim. The previous call passed two extra "0" arguments, which
        // selected the composite-format overload and treated this text as a format string.
        report.Write("Mean Absolute Error: " + eval.MeanAbsoluteError + '\n' + "R Squared: " + eval.RSquared);
    }
}
/// <summary>
/// Saves a fitted model, together with its input schema, to the given file path.
/// </summary>
/// <param name="mlContext">ML.NET context whose model catalog performs the save.</param>
/// <param name="Model">Fitted model to persist.</param>
/// <param name="modelPath">Destination file path.</param>
/// <param name="modelInputSchema">Schema of the data the model expects as input.</param>
public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
    => mlContext.Model.Save(Model, modelInputSchema, modelPath);
2020-05-24 21:00:54 +02:00
/// <summary>
/// Loads the model for the smaller dataset from its saved file.
/// </summary>
/// <returns>The loaded model; the input schema stored with it is discarded.</returns>
public static ITransformer LoadModel()
{
    ITransformer loaded = mlContext.Model.Load(modelpath, out DataViewSchema _);
    return loaded;
}
2020-06-15 01:45:16 +02:00
/// <summary>
/// Loads the saved model for the smaller dataset and wraps it in a prediction engine.
/// </summary>
/// <returns>A prediction engine mapping Input rows to Output rows.</returns>
public static PredictionEngine<Input, Output> CreateEngine()
    => mlContext.Model.CreatePredictionEngine<Input, Output>(LoadModel());
2020-05-24 21:00:54 +02:00
/// <summary>
/// Loads the model for the "area" dataset from its saved file.
/// </summary>
/// <returns>The loaded model; the input schema stored with it is discarded.</returns>
public static ITransformer LoadModelArea()
{
    ITransformer loaded = mlContext.Model.Load(modelpath_area, out DataViewSchema _);
    return loaded;
}
2020-06-15 01:45:16 +02:00
/// <summary>
/// Loads the saved model for the "area" dataset and wraps it in a prediction engine.
/// </summary>
/// <returns>A prediction engine mapping InputArea rows to OutputArea rows.</returns>
public static PredictionEngine<InputArea, OutputArea> CreateEngineArea()
    => mlContext.Model.CreatePredictionEngine<InputArea, OutputArea>(LoadModelArea());
2020-06-15 01:45:16 +02:00
/// <summary>
/// Walks upwards from the directory containing the executing assembly and returns the full
/// path of the nearest ancestor directory named "Game1".
/// </summary>
/// <returns>Full path of the ancestor directory named "Game1".</returns>
/// <exception cref="DirectoryNotFoundException">
/// The assembly location is unavailable, or no ancestor directory is named "Game1".
/// </exception>
private static string FindPath()
{
    string assemblyLocation = Assembly.GetExecutingAssembly().Location;
    string assemblyDirectory = string.IsNullOrEmpty(assemblyLocation)
        ? null
        : System.IO.Path.GetDirectoryName(assemblyLocation);

    if (string.IsNullOrEmpty(assemblyDirectory))
    {
        throw new DirectoryNotFoundException("Cannot determine the directory of the executing assembly.");
    }

    // As before, the search starts at the parent of the assembly directory. Directory.GetParent
    // and DirectoryInfo.Parent yield null at the filesystem root, which ends the loop instead of
    // dereferencing null forever-looping code as the previous version did.
    for (DirectoryInfo candidate = Directory.GetParent(assemblyDirectory); candidate != null; candidate = candidate.Parent)
    {
        // Comparing the directory name (rather than a "\\Game1" path suffix) gives the same
        // result on Windows and also works with other directory separators.
        if (string.Equals(candidate.Name, "Game1", StringComparison.Ordinal))
        {
            return candidate.FullName;
        }
    }

    throw new DirectoryNotFoundException(
        "No ancestor directory named \"Game1\" was found above \"" + assemblyDirectory + "\".");
}
2020-05-23 20:39:05 +02:00
}
}