binary classifier evaluator

This commit is contained in:
marcin-szczepanski 2023-12-07 12:48:17 +01:00
parent 5d4da97dda
commit 5e52235dc4
2 changed files with 164 additions and 8 deletions

View File

@ -1,12 +1,19 @@
package net.sourceforge.jFuzzyLogic;
import org.antlr.runtime.RecognitionException;
import net.sourceforge.jFuzzyLogic.demo.tipper.TipperAnimation;
import net.sourceforge.jFuzzyLogic.plot.JFuzzyChart;
import net.sourceforge.jFuzzyLogic.rule.Rule;
import net.sourceforge.jFuzzyLogic.rule.RuleBlock;
import net.sourceforge.jFuzzyLogic.rule.Variable;
import org.antlr.runtime.RecognitionException;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
/**
* Main jFuzzyLogic class
@ -25,6 +32,7 @@ public class JFuzzyLogic {
public static final String VERSION_SHORT = VERSION_MAJOR + REVISION;
public static final String VERSION_NO_NAME = VERSION_SHORT + " (build " + BUILD + "), by " + Pcingola.BY;
public static final String VERSION = SOFTWARE_NAME + " " + VERSION_NO_NAME;
private static final String COMMA_DELIMITER = ";";
public static boolean debug = false;
@ -83,6 +91,150 @@ public class JFuzzyLogic {
}
}
double classifyBinary(int argNum) {
double result = -1.0;
List<List<String>> records = new ArrayList<>();
String testFileName = args[argNum++];
try (Scanner scanner = new Scanner(new File(testFileName))) {
while (scanner.hasNextLine()) {
records.add(getRecordFromLine(scanner.nextLine()));
}
} catch (FileNotFoundException e) {
throw new RuntimeException(e);
}
List<Integer> expected = new ArrayList<>();
List<Integer> predicted = new ArrayList<>();
int tp = 0, tn = 0, fp = 0, fn = 0;
String fileName = args[argNum++];
load(fileName);
double decisionPoint = Double.parseDouble(args[argNum++]);
String decisionType = args[argNum++];
for (List<String> record : records) {
String expectedValue = record.get(record.size() - 1);
expected.add(Integer.valueOf(expectedValue));
//---
// Assign values (parse command line)
//---
for (FunctionBlock fb : fis) {
int k = 0;
for (Variable var : fb.variablesSorted()) {
if (var.isInput()) {
var.setValue(Gpr.parseDoubleSafe(record.get(k)));
k++;
}
}
}
//---
// Evaluate and show results
//---
fis.evaluate(); // Evaluate
// Show values
for (FunctionBlock fb : fis) {
// Show variables
for (Variable var : fb.variablesSorted()) {
if (var.isOutput()) {
double res = var.getValue();
switch (decisionType) {
case "gt":
if (res > decisionPoint) {
predicted.add(1);
} else {
predicted.add(0);
}
break;
case "lt":
if (res < decisionPoint) {
predicted.add(1);
} else {
predicted.add(0);
}
break;
case "geq":
if (res >= decisionPoint) {
predicted.add(1);
} else {
predicted.add(0);
}
break;
case "leq":
if (res <= decisionPoint) {
predicted.add(1);
} else {
predicted.add(0);
}
break;
}
}
}
}
if (expected.get(expected.size() - 1).equals(predicted.get(predicted.size() - 1))) {
if (predicted.get(predicted.size() - 1).equals(1)) {
tp = tp + 1;
} else {
tn = tn + 1;
}
} else {
if (predicted.get(predicted.size() - 1).equals(1)) {
fp = fp + 1;
} else {
fn = fn + 1;
}
}
}
double prec = (tp * 1.0) / (tp + fp);
double rec = (tp * 1.0) / (tp + fn);
result = 2.0 * ((prec * rec * 1.0) / (prec + rec));
try {
saveCSV(testFileName.replace(".csv", "_out.csv"), records, predicted);
} catch (IOException e) {
throw new RuntimeException(e);
}
System.out.println("F1-score is: " + result);
return result;
}
public void saveCSV(String fileName, List<List<String>> records, List<Integer> colToAdd) throws IOException {
File csvOutputFile = new File(fileName);
csvOutputFile.createNewFile();
try (PrintWriter pw = new PrintWriter(csvOutputFile)) {
for (int k = 0; k < records.size(); k++) {
List<String> record = records.get(k);
StringBuilder stringBuilder = new StringBuilder();
for (String recordValue : record) {
stringBuilder.append(recordValue).append(COMMA_DELIMITER);
}
stringBuilder.append(colToAdd.get(k)).append("\n");
pw.printf(stringBuilder.toString());
}
}
}
private List<String> getRecordFromLine(String line) {
List<String> values = new ArrayList<>();
try (Scanner rowScanner = new Scanner(line)) {
rowScanner.useDelimiter(COMMA_DELIMITER);
while (rowScanner.hasNext()) {
values.add(rowScanner.next());
}
}
return values;
}
/**
* Evaluate an FCL file
*/
@ -192,6 +344,9 @@ public class JFuzzyLogic {
} else if (arg.equals("-e")) {
evaluate(i + 1);
return;
} else if (arg.equals("-b")) {
classifyBinary(i + 1);
return;
} else if (arg.equalsIgnoreCase("-noCharts")) {
// Do not use chart classes
JFuzzyChart.UseMockClass = true;
@ -225,6 +380,7 @@ public class JFuzzyLogic {
System.err.println("\t-c file.fcl : Compile. Generate C++ code from FCL file (to STDOUT)");
System.err.println("\t-j file.fcl : Compile. Generate JavaScript code from FCL file (to STDOUT)");
System.err.println("\t-e file.fcl in_1 in_2 ... in_N : Evaluate. Load FCL file, assign inputs i_1, i_2, ..., i_n and evaluate (variables sorted alphabetically).");
System.err.println("\t-b test.csv file.fcl decisionValue decisionType : Evaluate for CSV file - last column must be for \"expected values\"; decisionType: gt; lt; geq; leq; decisionValue: IF calculatedValue decisionType decisionValue THEN return 1; ELSE return 0;");
System.err.println("\t-noCharts : Use a mock class for charts. This is used when not compiled using JFreeCharts.");
System.err.println("\tdemo : Run a demo example (tipper.fcl)");
System.exit(1);