Added CreateEngine() in ML

This commit is contained in:
BOTLester 2020-05-06 17:48:11 +02:00
parent ccf064ca06
commit 7809c46bdc

View File

@ -1,17 +1,11 @@
using System; using System;
using System.IO; using System.IO;
using System.Collections.Generic;
using System.Diagnostics;
using System.ComponentModel;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.ML; using Microsoft.ML;
using Microsoft.ML.Data; using Microsoft.ML.Data;
using Microsoft.ML.Trainers.LightGbm; using Microsoft.ML.Trainers.LightGbm;
using Microsoft.Xna.Framework;
using Microsoft.Xna.Framework.Content;
using Game1;
namespace Game1.Sources.ML namespace Game1.Sources.ML
{ {
@ -25,6 +19,7 @@ namespace Game1.Sources.ML
private static string modelpathBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodelBig"; private static string modelpathBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/MLmodelBig";
private static string reportBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report_BigModel"; private static string reportBig = "C:/Users/Oskar/source/repos/PotatoPlanFinal/Game1/Content/ML/report_BigModel";
// Loading data, creatin and saving ML model for smaller dataset (100)
public static void CreateModel() public static void CreateModel()
{ {
@ -40,6 +35,7 @@ namespace Game1.Sources.ML
SaveModel(mlContext, MLModel, modelpath, trainingDataView.Schema); SaveModel(mlContext, MLModel, modelpath, trainingDataView.Schema);
} }
// ... for bigger dataset (1600)
public static void CreateBigModel() public static void CreateBigModel()
{ {
@ -55,6 +51,7 @@ namespace Game1.Sources.ML
SaveModel(mlContext, MLModel, modelpathBig, trainingDataView.Schema); SaveModel(mlContext, MLModel, modelpathBig, trainingDataView.Schema);
} }
// Building and training ML model, very small dataset (100 entries)
public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, ModelInput sample, string reportPath) public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, ModelInput sample, string reportPath)
{ {
@ -84,25 +81,12 @@ namespace Game1.Sources.ML
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel")); .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
Evaluate(mlContext, trainingDataView, pipeline, 10, reportPath, "Fertilizer_NameF"); Evaluate(mlContext, trainingDataView, pipeline, 10, reportPath, "Fertilizer_NameF");
ITransformer MLModel = pipeline.Fit(trainingDataView); ITransformer MLModel = pipeline.Fit(trainingDataView);
var predEng = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(MLModel);
ModelOutput predResult = predEng.Predict(sample);
ModelOutput predResult2 = predEng.Predict(new ModelInput()
{
Temperature = 24,
Humidity = 59,
Moisture = 57,
Soil_Type = "Loamy",
Crop_Type = "Sugarcane",
Nitrogen = 34,
Potassium = 0,
Phosporous = 3
});
return MLModel; return MLModel;
} }
//Building and training ML model, moderate size dataset (1600 entries)
public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, BigModelInput sample, string reportPath) public static ITransformer BuildAndTrain(MLContext mLContext, IDataView trainingDataView, BigModelInput sample, string reportPath)
{ {
@ -130,10 +114,7 @@ namespace Game1.Sources.ML
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel")); .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
Evaluate(mlContext, trainingDataView, pipeline, 8, reportPath, "ClassF"); Evaluate(mlContext, trainingDataView, pipeline, 8, reportPath, "ClassF");
ITransformer MLModel = pipeline.Fit(trainingDataView); ITransformer MLModel = pipeline.Fit(trainingDataView);
var predEng = mlContext.Model.CreatePredictionEngine<BigModelInput, BigModelOutput>(MLModel);
BigModelOutput predResult = predEng.Predict(sample);
return MLModel; return MLModel;
} }
@ -144,6 +125,7 @@ namespace Game1.Sources.ML
return model; return model;
} }
// Evaluate and save results to a text file
public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName) public static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline, int folds, string reportPath, string labelColumnName)
{ {
var crossVal = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName); var crossVal = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: folds, labelColumnName: labelColumnName);
@ -163,6 +145,7 @@ namespace Game1.Sources.ML
report.Close(); report.Close();
} }
public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema) public static void SaveModel(MLContext mlContext, ITransformer Model, string modelPath, DataViewSchema modelInputSchema)
{ {
mlContext.Model.Save(Model, modelInputSchema, modelPath); mlContext.Model.Save(Model, modelInputSchema, modelPath);
@ -176,5 +159,11 @@ namespace Game1.Sources.ML
return mlContext.Model.Load(modelpath, out DataViewSchema inputSchema); return mlContext.Model.Load(modelpath, out DataViewSchema inputSchema);
} }
public static Microsoft.ML.PredictionEngine<ModelInput, ModelOutput> CreateEngine()
{
ITransformer mlModel = LoadModel(false);
return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);
}
} }
} }