Added CreateEngine() in ML

This commit is contained in:
BOTLester 2020-05-06 17:48:11 +02:00
parent ccf064ca06
commit 7809c46bdc

View File

@ -1,17 +1,11 @@
using System;
using System.IO;
using System.Collections.Generic;
using System.Diagnostics;
using System.ComponentModel;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm;
using Microsoft.Xna.Framework;
using Microsoft.Xna.Framework.Content;
using Game1;
namespace Game1.Sources.ML
{
@ -25,6 +19,7 @@ namespace Game1.Sources.ML
private static string modelpathBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodelBig";
private static string reportBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report_BigModel";
// Loading data, creatin and saving ML model for smaller dataset (100)
public static void CreateModel()
{
@ -40,6 +35,7 @@ namespace Game1.Sources.ML
SaveModel(mlContext, MLModel, modelpath, trainingDataView.Schema);
}
// ... for bigger dataset (1600)
public static void CreateBigModel()
{
@ -55,6 +51,7 @@ namespace Game1.Sources.ML
SaveModel(mlContext, MLModel, modelpathBig, trainingDataView.Schema);
}
// Building and training ML model, very small dataset (100 entries)
public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, ModelInput sample, string reportPath)
{
@ -84,25 +81,12 @@ namespace Game1.Sources.ML
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
Evaluate(mlContext, trainingDataView, pipeline, 10, reportPath, "Fertilizer_NameF");
ITransformer MLModel = pipeline.Fit(trainingDataView);
var predEng = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(MLModel);
ModelOutput predResult = predEng.Predict(sample);
ModelOutput predResult2 = predEng.Predict(new ModelInput()
{
Temperature = 24,
Humidity = 59,
Moisture = 57,
Soil_Type = "Loamy",
Crop_Type = "Sugarcane",
Nitrogen = 34,
Potassium = 0,
Phosporous = 3
});
return MLModel;
}
//Building and training ML model, moderate size dataset (1600 entries)
public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, BigModelInput sample, string reportPath)
{
@ -130,10 +114,7 @@ namespace Game1.Sources.ML
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
Evaluate(mlContext, trainingDataView, pipeline, 8, reportPath, "ClassF");
ITransformer MLModel = pipeline.Fit(trainingDataView);
var predEng = mlContext.Model.CreatePredictionEngine<BigModelInput, BigModelOutput>(MLModel);
BigModelOutput predResult = predEng.Predict(sample);
return MLModel;
}
@ -144,6 +125,7 @@ namespace Game1.Sources.ML
return model;
}
// Evaluate and save results to a text file
public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
{
var crossVal = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName);
@ -163,6 +145,7 @@ namespace Game1.Sources.ML
report.Close();
}
public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
{
mlContext.Model.Save(Model, modelInputSchema, modelPath);
@ -176,5 +159,11 @@ namespace Game1.Sources.ML
return mlContext.Model.Load(modelpath, out DataViewSchema inputSchema);
}
public static Microsoft.ML.PredictionEngine<ModelInput, ModelOutput> CreateEngine()
{
ITransformer mlModel = LoadModel(false);
return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);
}
}
}