PotatoPlan/Game1/Sources/ML/MLModel.cs

187 lines
7.6 KiB
C#
Raw Permalink Normal View History

2020-05-06 16:27:14 +02:00
using System;
using System.IO;
using System.Linq;
2020-06-15 01:45:16 +02:00
using System.Reflection;
2020-05-06 16:27:14 +02:00
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;
2020-05-10 01:38:08 +02:00
/// <summary>
/// Builds, cross-validates, saves and loads the LightGBM multiclass models used for fertilizer
/// prediction. Two models exist: a small one trained on <c>Content/ML/Fertilizer_Prediction.csv</c>
/// and a bigger one trained on <c>Content/ML/BigFertPredict.csv</c>. All paths are resolved
/// relative to the <c>Game1</c> directory located by <see cref="FindPath"/>.
/// </summary>
class MLModel
{
    // Fixed seed so that training and cross-validation results are repeatable between runs.
    private static readonly MLContext mlContext = new MLContext(seed: 1);

    // Must stay declared before the path fields below: static field initializers run in textual order.
    private static readonly string path_base = FindPath();

    // Small model: training data, saved model, cross-validation report.
    private static readonly string path = System.IO.Path.Combine(path_base, "Content/ML/Fertilizer_Prediction.csv");
    private static readonly string modelpath = System.IO.Path.Combine(path_base, "Content/ML/MLmodel");
    private static readonly string report = System.IO.Path.Combine(path_base, "Content/ML/report");

    // Big model: training data, saved model, cross-validation report.
    private static readonly string pathBig = System.IO.Path.Combine(path_base, "Content/ML/BigFertPredict.csv");
    private static readonly string modelpathBig = System.IO.Path.Combine(path_base, "Content/ML/MLmodelBig");
    private static readonly string reportBig = System.IO.Path.Combine(path_base, "Content/ML/report_BigModel");

    /// <summary>
    /// Loads the small dataset (about 100 rows), trains and cross-validates the small model,
    /// and saves it to <c>Content/ML/MLmodel</c>.
    /// </summary>
    public static void CreateModel()
    {
        IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
            path: path,
            hasHeader: true,
            separatorChar: ',',
            allowQuoting: true,
            allowSparse: false);

        // The sample row only selects the ModelInput overload of BuildAndTrain;
        // reading it also fails fast when the CSV has no data rows.
        ModelInput sample = mlContext.Data.CreateEnumerable<ModelInput>(trainingDataView, false).First();
        ITransformer model = BuildAndTrain(mlContext, trainingDataView, sample, report);
        SaveModel(mlContext, model, modelpath, trainingDataView.Schema);
    }

    /// <summary>
    /// Loads the bigger dataset (about 1600 rows), trains and cross-validates the big model,
    /// and saves it to <c>Content/ML/MLmodelBig</c>.
    /// </summary>
    public static void CreateBigModel()
    {
        IDataView trainingDataView = mlContext.Data.LoadFromTextFile<BigModelInput>(
            path: pathBig,
            hasHeader: true,
            separatorChar: ',',
            allowQuoting: true,
            allowSparse: false);

        // The sample row only selects the BigModelInput overload of BuildAndTrain;
        // reading it also fails fast when the CSV has no data rows.
        BigModelInput sample = mlContext.Data.CreateEnumerable<BigModelInput>(trainingDataView, false).First();
        ITransformer model = BuildAndTrain(mlContext, trainingDataView, sample, reportBig);
        SaveModel(mlContext, model, modelpathBig, trainingDataView.Schema);
    }

    /// <summary>
    /// Builds the pipeline for the small dataset (about 100 rows), writes its 10-fold
    /// cross-validation report, then fits the pipeline on the whole dataset.
    /// </summary>
    /// <param name="mLContext">Context used for every transform, the trainer and the evaluation.</param>
    /// <param name="trainingDataView">Rows loaded as <see cref="ModelInput"/>.</param>
    /// <param name="sample">Not read; it only selects this overload instead of the <see cref="BigModelInput"/> one.</param>
    /// <param name="reportPath">File the cross-validation report is written to (overwritten).</param>
    /// <returns>The fitted pipeline; its <c>PredictedLabel</c> column is mapped back from key to value.</returns>
    public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, ModelInput sample, string reportPath)
    {
        var options = new LightGbmMulticlassTrainer.Options
        {
            MaximumBinCountPerFeature = 8,
            LearningRate = 0.00025,
            NumberOfIterations = 40000,
            NumberOfLeaves = 10,
            LabelColumnName = "Fertilizer_NameF",
            FeatureColumnName = "Features",
            Booster = new DartBooster.Options()
            {
                MaximumTreeDepth = 10
            }
        };

        // Every stage is built from the caller's context so no stage mixes in the static one.
        var pipeline = mLContext.Transforms.Text.FeaturizeText("Soil_TypeF", "Soil_Type")
            .Append(mLContext.Transforms.Text.FeaturizeText("Crop_TypeF", "Crop_Type"))
            .Append(mLContext.Transforms.Concatenate("Features", "Temperature", "Humidity", "Moisture", "Soil_TypeF", "Crop_TypeF", "Nitrogen", "Potassium", "Phosphorous"))
            // The label is only needed while training/testing, hence TransformerScope.TrainTest.
            .Append(mLContext.Transforms.Conversion.MapValueToKey("Fertilizer_NameF", "Fertilizer_Name"), TransformerScope.TrainTest)
            .AppendCacheCheckpoint(mLContext)
            .Append(mLContext.MulticlassClassification.Trainers.LightGbm(options))
            .Append(mLContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        Evaluate(mLContext, trainingDataView, pipeline, folds: 10, reportPath: reportPath, labelColumnName: "Fertilizer_NameF");
        return pipeline.Fit(trainingDataView);
    }

    /// <summary>
    /// Builds the pipeline for the bigger dataset (about 1600 rows), writes its 8-fold
    /// cross-validation report, then fits the pipeline on the whole dataset.
    /// </summary>
    /// <param name="mLContext">Context used for every transform, the trainer and the evaluation.</param>
    /// <param name="trainingDataView">Rows loaded as <see cref="BigModelInput"/>.</param>
    /// <param name="sample">Not read; it only selects this overload instead of the <see cref="ModelInput"/> one.</param>
    /// <param name="reportPath">File the cross-validation report is written to (overwritten).</param>
    /// <returns>The fitted pipeline; its <c>PredictedLabel</c> column is mapped back from key to value.</returns>
    public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, BigModelInput sample, string reportPath)
    {
        var options = new LightGbmMulticlassTrainer.Options
        {
            MaximumBinCountPerFeature = 10,
            LearningRate = 0.001,
            NumberOfIterations = 10000,
            NumberOfLeaves = 12,
            LabelColumnName = "ClassF",
            FeatureColumnName = "Features",
            Booster = new DartBooster.Options()
            {
                MaximumTreeDepth = 12
            }
        };

        // Every stage is built from the caller's context so no stage mixes in the static one.
        var pipeline = mLContext.Transforms.Concatenate("Features", "Ca", "Mg", "K", "S", "N", "Lime", "C", "P", "Moisture")
            .Append(mLContext.Transforms.NormalizeMinMax("Features"))
            // The label is only needed while training/testing, hence TransformerScope.TrainTest.
            .Append(mLContext.Transforms.Conversion.MapValueToKey("ClassF", "Class"), TransformerScope.TrainTest)
            .AppendCacheCheckpoint(mLContext)
            .Append(mLContext.MulticlassClassification.Trainers.LightGbm(options))
            .Append(mLContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        Evaluate(mLContext, trainingDataView, pipeline, folds: 8, reportPath: reportPath, labelColumnName: "ClassF");
        return pipeline.Fit(trainingDataView);
    }

    /// <summary>Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.</summary>
    /// <param name="mlContext">Not used; kept for signature compatibility.</param>
    public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        return trainingPipeline.Fit(trainingDataView);
    }

    /// <summary>
    /// Cross-validates <paramref name="trainingPipeline"/> and writes the fold averages of micro
    /// accuracy, log-loss and log-loss reduction to <paramref name="reportPath"/> (overwriting it).
    /// Numbers are formatted with the current culture, up to six decimal places.
    /// </summary>
    public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
    {
        var crossVal = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName);

        // Materialized once because three separate averages are taken over it.
        var foldMetrics = crossVal.Select(r => r.Metrics).ToList();

        string microAccuracyAverage = foldMetrics.Average(m => m.MicroAccuracy).ToString("0.######");
        string logLossAverage = foldMetrics.Average(m => m.LogLoss).ToString("0.######");
        string logLossReductionAverage = foldMetrics.Average(m => m.LogLossReduction).ToString("0.######");

        // Written verbatim: the text must not go through a composite-format Write overload, where
        // braces would be parsed as format items. File.WriteAllText also closes the file on failure.
        string text = "Micro Accuracy: " + microAccuracyAverage + '\n'
                    + "LogLoss Average: " + logLossAverage + '\n'
                    + "LogLoss Reduction: " + logLossReductionAverage;
        File.WriteAllText(reportPath, text);
    }

    /// <summary>Saves <paramref name="Model"/> together with its input schema to <paramref name="modelPath"/>.</summary>
    public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
    {
        mlContext.Model.Save(Model, modelInputSchema, modelPath);
    }

    /// <summary>Loads the saved big model when <paramref name="isBig"/> is true, otherwise the small one.</summary>
    public static ITransformer LoadModel(bool isBig)
    {
        string modelFile = isBig ? modelpathBig : modelpath;
        return mlContext.Model.Load(modelFile, out _);
    }

    /// <summary>Creates a prediction engine over the saved small model.</summary>
    public static PredictionEngine<ModelInput, ModelOutput> CreateEngine()
    {
        ITransformer model = LoadModel(false);
        return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(model);
    }

    /// <summary>
    /// Walks up from the directory containing the executing assembly and returns the full path of
    /// the first ancestor directory named <c>Game1</c>. The assembly's own directory is not checked.
    /// </summary>
    /// <exception cref="DirectoryNotFoundException">No ancestor directory is named <c>Game1</c>.</exception>
    private static string FindPath()
    {
        string start = System.IO.Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);
        if (string.IsNullOrEmpty(start))
        {
            throw new DirectoryNotFoundException("Could not determine the directory of the executing assembly.");
        }

        // Compare the directory name itself (separator-independent) and stop at the filesystem
        // root, where Parent becomes null, instead of dereferencing a missing parent.
        for (DirectoryInfo dir = Directory.GetParent(start); dir != null; dir = dir.Parent)
        {
            if (string.Equals(dir.Name, "Game1", StringComparison.Ordinal))
            {
                return dir.FullName;
            }
        }

        throw new DirectoryNotFoundException("Could not find a 'Game1' directory above " + start);
    }
}
2020-05-10 01:38:08 +02:00