diff --git a/.idea/ai-project.iml b/.idea/ai-project.iml
index a23e703..0c80114 100644
--- a/.idea/ai-project.iml
+++ b/.idea/ai-project.iml
@@ -5,7 +5,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index dc9ea49..3fc3916 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/assets/data/test.csv b/assets/data/test.csv
new file mode 100644
index 0000000..8fd5029
--- /dev/null
+++ b/assets/data/test.csv
@@ -0,0 +1,55 @@
+stan_nawodnienia,rodzaj_gleby,stan_nawiezienia,stopien_rozwoju,rodzaj_rosliny,rodzaj_nawozu,to_water
+0.32,piaszczyste,0.06,0.40,ziemniak,mineralny,1
+0.60,piaszczyste,0.72,0.79,kaktus,sztuczny,0
+0.76,czarnoziemy,0.49,0.57,pszenica,organiczny,0
+0.66,czarnoziemy,0.49,0.57,brak,organiczny,0
+0.25,brunatne,0.18,0.23,ziemniak,sztuczny,1
+0.39,czarnoziemy,0.88,0.85,kaktus,organiczny,0
+0.02,brunatne,0.58,0.85,kaktus,sztuczny,1
+0.96,piaszczyste,0.70,0.51,ziemniak,organiczny,0
+0.32,brunatne,0.69,0.23,pszenica,sztuczny,1
+0.50,piaszczyste,0.52,0.39,kaktus,sztuczny,0
+0.45,brunatne,0.91,0.93,ziemniak,organiczny,0
+0.31,czarnoziemy,0.82,0.74,ziemniak,mineralny,1
+0.21,piaszczyste,0.64,0.24,pszenica,sztuczny,1
+0.69,brunatne,0.82,0.65,ziemniak,mineralny,0
+0.44,piaszczyste,0.38,0.84,pszenica,organiczny,1
+0.11,brunatne,0.22,0.79,kaktus,organiczny,1
+0.83,czarnoziemy,0.64,0.74,pszenica,mineralny,0
+0.73,brunatne,0.54,0.94,ziemniak,mineralny,0
+0.90,piaszczyste,0.35,0.23,pszenica,mineralny,0
+0.78,piaszczyste,0.5,0.66,ziemniak,organiczny,0
+0.31,czarnoziemy,0.42,0.34,ziemniak,mineralny,1
+0.76,czarnoziemy,0.49,0.57,brak,sztuczny,0
+0.71,czarnoziemy,0.67,0.50,pszenica,sztuczny,0
+0.75,piaszczyste,0.53,0.35,kaktus,mineralny,0
+0.41,czarnoziemy,0.12,0.05,ziemniak,sztuczny,1
+0.85,czarnoziemy,0.64,0.36,pszenica,organiczny,0
+0.77,czarnoziemy,0.34,0.65,ziemniak,sztuczny,0
+0.85,brunatne,0.25,0.0,brak,sztuczny,0
+0.68,czarnoziemy,0.66,0.77,pszenica,organiczny,0
+0.28,czarnoziemy,0.55,0,brak,sztuczny,1
+0.63,brunatne,0.88,0.85,ziemniak,mineralny,0
+0.56,brunatne,0.88,0.75,pszenica,organiczny,0
+0.45,piaszczyste,0.77,0.65,kaktus,organiczny,0
+0.83,brunatne,0.66,0.41,ziemniak,mineralny,0
+0.34,piaszczyste,0.72,0.53,pszenica,mineralny,1
+0.64,piaszczyste,0.79,0.33,pszenica,sztuczny,0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/assets/data/test.csv.bak b/assets/data/test.csv.bak
new file mode 100644
index 0000000..3bad4de
--- /dev/null
+++ b/assets/data/test.csv.bak
@@ -0,0 +1,55 @@
+stan_nawodnienia,rodzaj_gleby,stan_nawiezienia,stopien_rozwoju,rodzaj_rosliny,rodzaj_nawozu,to_water
+0.32,piaszczyste,0.06,0.40,ziemniak,mineralny,1
+0.60,piaszczyste,0.72,0.79,kaktus,sztuczny,0
+0.76,czarnoziemy,0.49,0.57,pszenica,organiczny,0
+0.66,czarnoziemy,0.49,0.57,brak,organiczny,0
+0.25,brunatne,0.18,0.23,ziemniak,sztuczny,1
+0.39,czarnoziemy,0.88,0.85,kaktus,organiczny,0
+0.02,brunatne,0.38,0.85,kaktus,sztuczny,1
+0.96,piaszczyste,0.70,0.51,ziemniak,organiczny,0
+0.32,brunatne,0.69,0.23,pszenica,sztuczny,1
+0.50,piaszczyste,0.52,0.39,kaktus,sztuczny,0
+0.45,brunatne,0.91,0.93,ziemniak,organiczny,0
+0.31,czarnoziemy,0.82,0.74,ziemniak,mineralny,1
+0.21,piaszczyste,0.64,0.24,pszenica,sztuczny,1
+0.69,brunatne,0.82,0.65,ziemniak,mineralny,0
+0.44,piaszczyste,0.38,0.84,pszenica,organiczny,1
+0.11,brunatne,0.22,0.79,kaktus,organiczny,1
+0.83,czarnoziemy,0.64,0.74,pszenica,mineralny,0
+0.73,brunatne,0.54,0.94,ziemniak,mineralny,0
+0.90,piaszczyste,0.35,0.23,pszenica,mineralny,0
+0.78,piaszczyste,0.5,0.66,ziemniak,organiczny,0
+0.31,czarnoziemy,0.42,0.34,ziemniak,mineralny,1
+0.76,czarnoziemy,0.49,0.57,brak,sztuczny,0
+0.71,czarnoziemy,0.67,0.50,pszenica,sztuczny,0
+0.75,piaszczyste,0.53,0.35,kaktus,mineralny,0
+0.41,czarnoziemy,0.12,0.05,ziemniak,sztuczny,1
+0.85,czarnoziemy,0.64,0.36,pszenica,organiczny,0
+0.77,czarnoziemy,0.34,0.65,ziemniak,sztuczny,0
+0.85,brunatne,0.25,0.0,brak,sztuczny,0
+0.68,czarnoziemy,0.66,0.77,pszenica,organiczny,0
+0.28,czarnoziemy,0.55,0,brak,sztuczny,1
+0.63,brunatne,0.88,0.85,ziemniak,mineralny,0
+0.56,brunatne,0.88,0.75,pszenica,organiczny,0
+0.45,piaszczyste,0.77,0.65,kaktus,organiczny,0
+0.83,brunatne,0.66,0.41,ziemniak,mineralny,0
+0.34,piaszczyste,0.72,0.53,pszenica,mineralny,1
+0.64,piaszczyste,0.79,0.33,pszenica,sztuczny,0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/assets/data/train.csv b/assets/data/train.csv
new file mode 100644
index 0000000..e487d60
--- /dev/null
+++ b/assets/data/train.csv
@@ -0,0 +1,178 @@
+stan_nawodnienia,rodzaj_gleby,stan_nawiezienia,stopien_rozwoju,rodzaj_rosliny,rodzaj_nawozu,to_water
+0.32,brunatne,0.01,0.75,ziemniak,mineralny,1
+0.38,brunatne,0.15,0.85,pszenica,organiczny,1
+0.17,piaszczyste,0.22,0.13,brak,organiczny,1
+0.67,czarnoziemy,0.64,0.55,pszenica,sztuczny,0
+0.45,brunatne,0.16,0.16,brak,organiczny,1
+0.25,czarnoziemy,0.16,0.23,pszenica,organiczny,1
+0.10,brunatne,0.02,0.17,brak,mineralny,1
+0.25,brunatne,0.39,0.17,kaktus,organiczny,1
+0.92,brunatne,0.68,0.77,ziemniak,sztuczny,0
+0.70,piaszczyste,0.72,0.79,kaktus,sztuczny,0
+0.54,czarnoziemy,0.95,0.59,ziemniak,organiczny,0
+0.21,piaszczyste,0.05,0.03,brak,mineralny,1
+0.23,brunatne,0.11,0.86,ziemniak,organiczny,1
+0.61,czarnoziemy,0.76,0.76,ziemniak,sztuczny,0
+0.87,brunatne,0.53,0.85,pszenica,organiczny,0
+0.21,brunatne,0.01,0.26,kaktus,mineralny,1
+0.32,piaszczyste,0.06,0.40,ziemniak,mineralny,1
+0.76,czarnoziemy,0.79,0.67,pszenica,mineralny,0
+0.30,brunatne,0.21,0.10,brak,mineralny,1
+0.66,piaszczyste,0.93,0.95,kaktus,sztuczny,0
+0.71,brunatne,0.52,0.62,ziemniak,mineralny,0
+0.43,piaszczyste,0.38,0.35,pszenica,sztuczny,1
+0.66,brunatne,0.72,0.63,pszenica,organiczny,0
+0.21,piaszczyste,0.47,0.96,brak,sztuczny,1
+0.78,brunatne,0.59,0.68,pszenica,organiczny,0
+0.85,piaszczyste,0.77,0.64,pszenica,organiczny,0
+0.84,piaszczyste,0.75,0.86,ziemniak,sztuczny,0
+0.71,czarnoziemy,0.72,0.74,ziemniak,mineralny,1
+0.75,piaszczyste,0.77,0.74,ziemniak,sztuczny,0
+0.76,piaszczyste,0.75,0.81,pszenica,mineralny,0
+0.72,brunatne,0.74,0.84,kaktus,sztuczny,0
+0.59,brunatne,0.82,0.88,pszenica,organiczny,0
+0.48,czarnoziemy,0.15,0.27,ziemniak,organiczny,0
+0.04,piaszczyste,0.90,0.8,pszenica,organiczny,0
+0.23,brunatne,0.12,0.56,brak,organiczny,1
+0.02,piaszczyste,0.20,0.18,brak,organiczny,1
+0.02,piaszczyste,0.35,0.24,ziemniak,mineralny,1
+0.76,piaszczyste,0.78,0.85,ziemniak,mineralny,0
+0.72,czarnoziemy,0.67,0.76,brak,mineralny,0
+0.06,brunatne,0.02,0.71,ziemniak,mineralny,1
+0.96,czarnoziemy,0.66,0.80,ziemniak,sztuczny,0
+0.34,brunatne,0.22,0.15,kaktus,mineralny,0
+0.78,piaszczyste,0.62,0.68,pszenica,mineralny,0
+0.01,czarnoziemy,0.22,0.05,pszenica,sztuczny,1
+0.02,czarnoziemy,0.31,0.80,pszenica,sztuczny,1
+0.11,piaszczyste,0.14,0.14,pszenica,organiczny,1
+0.95,piaszczyste,0.61,0.94,brak,sztuczny,0
+0.78,brunatne,0.86,0.63,pszenica,organiczny,0
+0.03,brunatne,0.27,0.25,pszenica,sztuczny,1
+0.62,czarnoziemy,0.75,0.75,ziemniak,sztuczny,0
+0.24,piaszczyste,0.02,0.23,pszenica,organiczny,1
+0.46,czarnoziemy,0.79,0.27,ziemniak,sztuczny,0
+0.24,piaszczyste,0.74,0.53,pszenica,mineralny,1
+0.82,czarnoziemy,0.77,0.86,ziemniak,mineralny,0
+0.25,brunatne,0.18,0.23,pszenica,sztuczny,1
+0.73,czarnoziemy,0.94,0.54,brak,organiczny,0
+0.63,piaszczyste,0.84,0.92,kaktus,sztuczny,0
+0.78,czarnoziemy,0.75,0.94,brak,sztuczny,0
+0.84,piaszczyste,0.79,0.80,pszenica,sztuczny,0
+0.78,czarnoziemy,0.95,0.63,pszenica,mineralny,0
+0.43,brunatne,0.38,0.52,pszenica,organiczny,1
+0.88,piaszczyste,0.84,0.78,ziemniak,sztuczny,0
+0.69,brunatne,0.86,0.55,ziemniak,organiczny,0
+0.77,piaszczyste,0.78,0.55,ziemniak,organiczny,0
+0.26,brunatne,0.24,0.12,ziemniak,mineralny,1
+0.75,czarnoziemy,0.64,0.73,brak,mineralny,0
+0.14,piaszczyste,0.38,0.44,pszenica,sztuczny,1
+0.84,czarnoziemy,0.85,0.67,ziemniak,organiczny,0
+0.33,brunatne,0.11,0.12,brak,organiczny,1
+0.89,brunatne,0.67,0.88,brak,organiczny,0
+0.15,brunatne,0.21,0.23,pszenica,mineralny,1
+0.07,piaszczyste,0.24,0.67,brak,sztuczny,1
+0.29,czarnoziemy,0.10,0.14,ziemniak,sztuczny,1
+0.99,piaszczyste,0.99,0.75,kaktus,mineralny,0
+0.62,piaszczyste,0.65,0.75,brak,organiczny,0
+0.33,czarnoziemy,0.23,0.23,brak,sztuczny,1
+0.16,brunatne,0.28,0.38,brak,organiczny,1
+0.35,piaszczyste,0.18,0.13,ziemniak,mineralny,0
+0.85,czarnoziemy,0.64,0.74,pszenica,sztuczny,0
+0.78,czarnoziemy,0.59,0.72,ziemniak,organiczny,0
+0.24,czarnoziemy,0.14,0.18,brak,organiczny,1
+0.25,czarnoziemy,0.13,0.16,pszenica,sztuczny,1
+0.87,czarnoziemy,0.87,0.76,pszenica,mineralny,0
+0.28,brunatne,0.44,0.27,brak,mineralny,1
+0.47,brunatne,0.09,0.05,ziemniak,mineralny,1
+0.68,piaszczyste,0.72,0.67,brak,sztuczny,0
+0.23,czarnoziemy,0.27,0.25,pszenica,sztuczny,1
+0.05,piaszczyste,0.11,0.16,brak,organiczny,1
+0.12,piaszczyste,0.21,0.22,brak,mineralny,1
+0.86,brunatne,0.55,0.67,ziemniak,organiczny,0
+0.68,brunatne,0.76,0.87,pszenica,organiczny,0
+0.04,brunatne,0.32,0.26,kaktus,sztuczny,1
+0.73,piaszczyste,0.73,0.73,pszenica,sztuczny,0
+0.01,czarnoziemy,0.42,0.34,ziemniak,mineralny,1
+0.96,piaszczyste,0.70,0.71,ziemniak,organiczny,0
+0.74,piaszczyste,0.66,0.79,kaktus,mineralny,0
+0.75,brunatne,0.66,0.82,pszenica,mineralny,0
+0.70,brunatne,0.69,0.63,pszenica,organiczny,0
+0.66,brunatne,0.85,0.72,brak,sztuczny,0
+0.73,piaszczyste,0.58,0.95,ziemniak,mineralny,0
+0.36,brunatne,0.14,0.02,brak,mineralny,0
+0.61,piaszczyste,0.70,0.66,pszenica,mineralny,0
+0.83,piaszczyste,0.68,0.75,ziemniak,sztuczny,0
+0.55,brunatne,0.68,0.79,ziemniak,sztuczny,0
+0.75,czarnoziemy,0.64,0.58,kaktus,organiczny,0
+0.99,brunatne,0.57,0.82,pszenica,organiczny,0
+0.88,czarnoziemy,0.68,0.65,brak,sztuczny,0
+0.58,czarnoziemy,0.68,0.71,ziemniak,organiczny,1
+0.14,piaszczyste,0.46,0.17,brak,sztuczny,1
+0.06,czarnoziemy,0.23,0.25,kaktus,sztuczny,1
+0.56,brunatne,0.88,0.65,ziemniak,organiczny,0
+0.32,brunatne,0.39,0.23,pszenica,organiczny,1
+0.64,piaszczyste,0.85,0.99,ziemniak,mineralny,0
+0.25,czarnoziemy,0.26,0.45,kaktus,organiczny,0
+0.78,piaszczyste,0.85,0.92,brak,mineralny,0
+0.14,brunatne,0.09,0.13,pszenica,organiczny,1
+0.82,brunatne,0.81,0.80,brak,mineralny,0
+0.65,piaszczyste,0.58,0.78,pszenica,mineralny,0
+0.32,piaszczyste,0.25,0.04,ziemniak,organiczny,1
+0.33,czarnoziemy,0.12,0.46,kaktus,sztuczny,0
+0.58,czarnoziemy,0.76,0.66,pszenica,mineralny,0
+0.01,czarnoziemy,0.32,0.33,ziemniak,mineralny,1
+0.65,brunatne,0.88,0.66,ziemniak,organiczny,0
+0.12,czarnoziemy,0.16,0.13,brak,mineralny,1
+0.86,brunatne,0.98,0.88,brak,sztuczny,0
+0.23,brunatne,0.15,0.15,ziemniak,mineralny,1
+0.24,piaszczyste,0.37,0.34,kaktus,mineralny,0
+0.05,brunatne,0.18,0.39,ziemniak,mineralny,1
+0.23,czarnoziemy,0.26,0.35,ziemniak,organiczny,1
+0.55,czarnoziemy,0.78,0.89,brak,organiczny,0
+0.25,czarnoziemy,0.13,0.47,brak,mineralny,1
+0.62,czarnoziemy,0.82,0.68,brak,sztuczny,0
+0.85,czarnoziemy,0.64,0.56,brak,mineralny,0
+0.58,piaszczyste,0.94,0.74,brak,sztuczny,0
+0.17,piaszczyste,0.34,0.15,ziemniak,mineralny,1
+0.26,czarnoziemy,0.21,0.15,pszenica,mineralny,1
+0.22,czarnoziemy,0.35,0.45,pszenica,mineralny,1
+0.21,piaszczyste,0.26,0.26,pszenica,organiczny,1
+0.38,piaszczyste,0.36,0.14,pszenica,mineralny,1
+0.17,piaszczyste,0.38,0.27,pszenica,organiczny,1
+0.11,piaszczyste,0.21,0.23,pszenica,sztuczny,1
+0.48,czarnoziemy,0.29,0.28,kaktus,organiczny,0
+0.16,czarnoziemy,0.21,0.56,ziemniak,organiczny,1
+0.22,piaszczyste,0.22,0.19,ziemniak,sztuczny,1
+0.63,czarnoziemy,0.87,0.77,pszenica,mineralny,0
+0.69,piaszczyste,0.85,0.78,brak,mineralny,0
+0.55,brunatne,0.50,0.79,brak,organiczny,0
+0.14,piaszczyste,0.14,0.35,ziemniak,mineralny,1
+0.95,czarnoziemy,0.82,0.92,pszenica,organiczny,0
+0.68,piaszczyste,0.57,0.62,pszenica,mineralny,0
+0.60,brunatne,0.66,0.85,ziemniak,mineralny,0
+0.50,brunatne,0.91,0.83,ziemniak,organiczny,0
+0.36,piaszczyste,0.08,0.75,kaktus,mineralny,0
+0.13,brunatne,0.26,0.38,kaktus,sztuczny,1
+0.32,piaszczyste,0.25,0.24,ziemniak,sztuczny,1
+0.84,brunatne,0.54,0.50,brak,mineralny,0
+0.31,czarnoziemy,0.49,0.14,pszenica,organiczny,1
+0.62,piaszczyste,0.74,0.85,kaktus,sztuczny,0
+0.85,czarnoziemy,0.67,0.74,brak,organiczny,0
+0.74,piaszczyste,0.69,0.98,ziemniak,organiczny,0
+0.16,piaszczyste,0.16,0.36,pszenica,mineralny,1
+0.01,czarnoziemy,0.12,0.29,ziemniak,organiczny,1
+0.55,brunatne,0.67,0.74,brak,sztuczny,0
+0.75,czarnoziemy,0.86,0.90,ziemniak,sztuczny,0
+0.42,czarnoziemy,0.25,0.39,kaktus,sztuczny,0
+0.25,brunatne,0.30,0.24,ziemniak,organiczny,1
+0.12,czarnoziemy,0.20,0.49,ziemniak,mineralny,1
+0.75,brunatne,0.75,0.54,pszenica,organiczny,0
+0.65,czarnoziemy,0.82,0.76,pszenica,mineralny,0
+0.79,czarnoziemy,0.66,0.77,pszenica,sztuczny,0
+0.62,brunatne,0.78,0.84,brak,mineralny,0
+0.05,czarnoziemy,0.19,0.17,brak,organiczny,1
+0.22,czarnoziemy,0.24,0.25,kaktus,organiczny,1
+0.17,brunatne,0.31,0.35,ziemniak,mineralny,1
+0.91,brunatne,0.58,0.89,ziemniak,sztuczny,0
+0.07,brunatne,0.27,0.27,ziemniak,organiczny,1
+0.32,piaszczyste,0.12,0.68,brak,organiczny,1
\ No newline at end of file
diff --git a/assets/data/train.csv.bak b/assets/data/train.csv.bak
new file mode 100644
index 0000000..8b68f7b
--- /dev/null
+++ b/assets/data/train.csv.bak
@@ -0,0 +1,178 @@
+stan_nawodnienia,rodzaj_gleby,stan_nawiezienia,stopien_rozwoju,rodzaj_rosliny,rodzaj_nawozu,to_water
+0.32,brunatne,0.01,0.75,ziemniak,mineralny,1
+0.38,brunatne,0.15,0.85,pszenica,organiczny,1
+0.17,piaszczyste,0.22,0.13,brak,organiczny,1
+0.67,czarnoziemy,0.64,0.55,pszenica,sztuczny,0
+0.45,brunatne,0.16,0.16,brak,organiczny,1
+0.25,czarnoziemy,0.16,0.23,pszenica,organiczny,1
+0.10,brunatne,0.02,0.17,brak,mineralny,1
+0.25,brunatne,0.39,0.17,kaktus,organiczny,1
+0.92,brunatne,0.68,0.77,ziemniak,sztuczny,0
+0.70,piaszczyste,0.72,0.79,kaktus,sztuczny,0
+0.54,czarnoziemy,0.95,0.59,ziemniak,organiczny,0
+0.21,piaszczyste,0.05,0.03,brak,mineralny,1
+0.23,brunatne,0.11,0.86,ziemniak,organiczny,1
+0.61,czarnoziemy,0.76,0.76,ziemniak,sztuczny,0
+0.87,brunatne,0.53,0.85,pszenica,organiczny,0
+0.21,brunatne,0.01,0.26,kaktus,mineralny,1
+0.32,piaszczyste,0.06,0.40,ziemniak,mineralny,1
+0.76,czarnoziemy,0.79,0.67,pszenica,mineralny,0
+0.30,brunatne,0.21,0.10,brak,mineralny,1
+0.66,piaszczyste,0.93,0.95,kaktus,sztuczny,0
+0.71,brunatne,0.52,0.62,ziemniak,mineralny,0
+0.43,piaszczyste,0.38,0.35,pszenica,sztuczny,1
+0.66,brunatne,0.72,0.63,pszenica,organiczny,0
+0.21,piaszczyste,0.47,0.96,brak,sztuczny,1
+0.78,brunatne,0.59,0.68,pszenica,organiczny,0
+0.85,piaszczyste,0.77,0.64,pszenica,organiczny,0
+0.84,piaszczyste,0.75,0.86,ziemniak,sztuczny,0
+0.71,czarnoziemy,0.72,0.74,ziemniak,mineralny,1
+0.75,piaszczyste,0.77,0.74,ziemniak,sztuczny,0
+0.76,piaszczyste,0.75,0.81,pszenica,mineralny,0
+0.72,brunatne,0.74,0.84,kaktus,sztuczny,0
+0.59,brunatne,0.82,0.88,pszenica,organiczny,0
+0.48,czarnoziemy,0.15,0.27,ziemniak,organiczny,0
+0.04,piaszczyste,0.90,0.8,pszenica,organiczny,0
+0.23,brunatne,0.12,0.56,brak,organiczny,1
+0.02,piaszczyste,0.20,0.18,brak,organiczny,1
+0.02,piaszczyste,0.35,0.24,ziemniak,mineralny,1
+0.76,piaszczyste,0.78,0.85,ziemniak,mineralny,0
+0.72,czarnoziemy,0.67,0.76,brak,mineralny,0
+0.06,brunatne,0.02,0.71,ziemniak,mineralny,1
+0.96,czarnoziemy,0.66,0.80,ziemniak,sztuczny,0
+0.34,brunatne,0.22,0.15,kaktus,mineralny,0
+0.78,piaszczyste,0.62,0.68,pszenica,mineralny,0
+0.01,czarnoziemy,0.22,0.05,pszenica,sztuczny,1
+0.02,czarnoziemy,0.31,0.80,pszenica,sztuczny,1
+0.11,piaszczyste,0.14,0.14,pszenica,organiczny,1
+0.95,piaszczyste,0.61,0.94,brak,sztuczny,0
+0.78,brunatne,0.86,0.63,pszenica,organiczny,0
+0.03,brunatne,0.27,0.25,pszenica,sztuczny,1
+0.62,czarnoziemy,0.75,0.75,ziemniak,sztuczny,0
+0.24,piaszczyste,0.02,0.23,pszenica,organiczny,1
+0.46,czarnoziemy,0.79,0.27,ziemniak,sztuczny,0
+0.24,piaszczyste,0.74,0.53,pszenica,mineralny,1
+0.82,czarnoziemy,0.77,0.86,ziemniak,mineralny,0
+0.25,brunatne,0.18,0.23,pszenica,sztuczny,1
+0.73,czarnoziemy,0.94,0.54,brak,organiczny,0
+0.63,piaszczyste,0.84,0.92,kaktus,sztuczny,0
+0.78,czarnoziemy,0.75,0.94,brak,sztuczny,0
+0.84,piaszczyste,0.79,0.80,pszenica,sztuczny,0
+0.78,czarnoziemy,0.95,0.63,pszenica,mineralny,0
+0.43,brunatne,0.38,0.52,pszenica,organiczny,1
+0.88,piaszczyste,0.84,0.78,ziemniak,sztuczny,0
+0.69,brunatne,0.86,0.55,ziemniak,organiczny,0
+0.77,piaszczyste,0.78,0.55,ziemniak,organiczny,0
+0.26,brunatne,0.24,0.12,ziemniak,mineralny,1
+0.75,czarnoziemy,0.64,0.73,brak,mineralny,0
+0.14,piaszczyste,0.38,0.44,pszenica,sztuczny,1
+0.84,czarnoziemy,0.85,0.67,ziemniak,organiczny,0
+0.33,brunatne,0.11,0.12,brak,organiczny,1
+0.89,brunatne,0.67,0.88,brak,organiczny,0
+0.15,brunatne,0.21,0.23,pszenica,mineralny,1
+0.07,piaszczyste,0.24,0.67,brak,sztuczny,1
+0.29,czarnoziemy,0.10,0.14,ziemniak,sztuczny,1
+0.99,piaszczyste,0.99,0.75,kaktus,mineralny,0
+0.62,piaszczyste,0.65,0.75,brak,organiczny,0
+0.33,czarnoziemy,0.23,0.23,brak,sztuczny,1
+0.16,brunatne,0.28,0.38,brak,organiczny,1
+0.35,piaszczyste,0.18,0.13,ziemniak,mineralny,0
+0.85,czarnoziemy,0.64,0.74,pszenica,sztuczny,0
+0.78,czarnoziemy,0.59,0.72,ziemniak,organiczny,0
+0.24,czarnoziemy,0.14,0.18,brak,organiczny,1
+0.25,czarnoziemy,0.13,0.16,pszenica,sztuczny,1
+0.87,czarnoziemy,0.87,0.76,pszenica,mineralny,0
+0.28,brunatne,0.44,0.27,brak,mineralny,1
+0.47,brunatne,0.09,0.05,ziemniak,mineralny,1
+0.68,piaszczyste,0.72,0.67,brak,sztuczny,0
+0.23,czarnoziemy,0.27,0.25,pszenica,sztuczny,1
+0.05,piaszczyste,0.11,0.16,brak,organiczny,1
+0.12,piaszczyste,0.21,0.22,brak,mineralny,1
+0.86,brunatne,0.55,0.67,ziemniak,organiczny,0
+0.68,brunatne,0.76,0.87,pszenica,organiczny,0
+0.04,brunatne,0.32,0.26,kaktus,sztuczny,1
+0.73,piaszczyste,0.73,0.73,pszenica,sztuczny,0
+0.01,czarnoziemy,0.42,0.34,ziemniak,mineralny,1
+0.96,piaszczyste,0.70,0.71,ziemniak,organiczny,0
+0.74,piaszczyste,0.66,0.79,kaktus,mineralny,0
+0.75,brunatne,0.66,0.82,pszenica,mineralny,0
+0.70,brunatne,0.69,0.63,pszenica,organiczny,0
+0.66,brunatne,0.85,0.72,brak,sztuczny,0
+0.73,piaszczyste,0.58,0.95,ziemniak,mineralny,0
+0.36,brunatne,0.14,0.02,brak,mineralny,0
+0.61,piaszczyste,0.70,0.66,pszenica,mineralny,0
+0.83,piaszczyste,0.68,0.75,ziemniak,sztuczny,0
+0.55,brunatne,0.68,0.79,ziemniak,sztuczny,0
+0.75,czarnoziemy,0.64,0.58,kaktus,organiczny,0
+0.99,brunatne,0.57,0.82,pszenica,organiczny,0
+0.88,czarnoziemy,0.68,0.65,brak,sztuczny,0
+0.58,czarnoziemy,0.68,0.71,ziemniak,organiczny,1
+0.14,piaszczyste,0.46,0.17,brak,sztuczny,1
+0.06,czarnoziemy,0.23,0.25,kaktus,sztuczny,1
+0.56,brunatne,0.88,0.65,ziemniak,organiczny,0
+0.32,brunatne,0.39,0.23,pszenica,organiczny,1
+0.64,piaszczyste,0.85,0.99,ziemniak,mineralny,0
+0.25,czarnoziemy,0.26,0.45,kaktus,organiczny,0
+0.78,piaszczyste,0.85,0.92,brak,mineralny,0
+0.14,brunatne,0.09,0.13,pszenica,organiczny,1
+0.82,brunatne,0.81,0.80,brak,mineralny,0
+0.65,piaszczyste,0.58,0.78,pszenica,mineralny,0
+0.32,piaszczyste,0.25,0.04,ziemniak,organiczny,1
+0.33,czarnoziemy,0.12,0.46,kaktus,sztuczny,0
+0.58,czarnoziemy,0.76,0.66,pszenica,mineralny,0
+0.01,czarnoziemy,0.32,0.33,ziemniak,mineralny,1
+0.65,brunatne,0.88,0.66,ziemniak,organiczny,0
+0.12,czarnoziemy,0.16,0.13,brak,mineralny,1
+0.86,brunatne,0.98,0.88,brak,sztuczny,0
+0.23,brunatne,0.15,0.15,ziemniak,mineralny,1
+0.24,piaszczyste,0.37,0.34,kaktus,mineralny,0
+0.05,brunatne,0.18,0.39,ziemniak,mineralny,1
+0.23,czarnoziemy,0.26,0.35,ziemniak,organiczny,1
+0.55,czarnoziemy,0.78,0.89,brak,organiczny,0
+0.25,czarnoziemy,0.13,0.47,brak,mineralny,1
+0.62,czarnoziemy,0.82,0.68,brak,sztuczny,0
+0.85,czarnoziemy,0.64,0.56,brak,mineralny,0
+0.58,piaszczyste,0.94,0.74,brak,sztuczny,0
+0.17,piaszczyste,0.34,0.15,ziemniak,mineralny,1
+0.26,czarnoziemy,0.21,0.15,pszenica,mineralny,1
+0.22,czarnoziemy,0.35,0.45,pszenica,mineralny,1
+0.21,piaszczyste,0.26,0.26,pszenica,organiczny,1
+0.38,piaszczyste,0.36,0.14,pszenica,mineralny,1
+0.17,piaszczyste,0.38,0.27,pszenica,organiczny,1
+0.11,piaszczyste,0.21,0.23,pszenica,sztuczny,1
+0.48,czarnoziemy,0.29,0.28,kaktus,organiczny,0
+0.16,czarnoziemy,0.21,0.56,ziemniak,organiczny,1
+0.22,piaszczyste,0.22,0.19,ziemniak,sztuczny,1
+0.63,czarnoziemy,0.87,0.77,pszenica,mineralny,0
+0.69,piaszczyste,0.85,0.78,brak,mineralny,0
+0.55,brunatne,0.50,0.79,brak,organiczny,0
+0.14,piaszczyste,0.14,0.35,ziemniak,mineralny,1
+0.95,czarnoziemy,0.82,0.92,pszenica,organiczny,0
+0.68,piaszczyste,0.57,0.62,pszenica,mineralny,0
+0.60,brunatne,0.66,0.85,ziemniak,mineralny,0
+0.50,brunatne,0.91,0.83,ziemniak,organiczny,0
+0.36,piaszczyste,0.08,0.75,kaktus,mineralny,0
+0.13,brunatne,0.26,0.38,kaktus,sztuczny,1
+0.32,piaszczyste,0.25,0.24,ziemniak,sztuczny,1
+0.84,brunatne,0.54,0.50,brak,mineralny,0
+0.31,czarnoziemy,0.49,0.14,pszenica,organiczny,1
+0.62,piaszczyste,0.74,0.85,kaktus,sztuczny,0
+0.85,czarnoziemy,0.67,0.74,brak,organiczny,0
+0.74,piaszczyste,0.69,0.98,ziemniak,organiczny,0
+0.16,piaszczyste,0.16,0.36,pszenica,mineralny,1
+0.01,czarnoziemy,0.12,0.29,ziemniak,organiczny,1
+0.55,brunatne,0.67,0.74,brak,sztuczny,0
+0.75,czarnoziemy,0.86,0.90,ziemniak,sztuczny,0
+0.42,czarnoziemy,0.25,0.39,kaktus,sztuczny,0
+0.25,brunatne,0.30,0.24,ziemniak,organiczny,1
+0.12,czarnoziemy,0.20,0.49,ziemniak,mineralny,1
+0.75,brunatne,0.75,0.54,pszenica,organiczny,0
+0.65,czarnoziemy,0.82,0.76,pszenica,mineralny,0
+0.79,czarnoziemy,0.66,0.77,pszenica,sztuczny,0
+0.62,brunatne,0.78,0.84,brak,mineralny,0
+0.05,czarnoziemy,0.19,0.17,brak,organiczny,1
+0.22,czarnoziemy,0.24,0.25,kaktus,organiczny,1
+0.17,brunatne,0.31,0.35,ziemniak,mineralny,1
+0.91,brunatne,0.58,0.89,ziemniak,sztuczny,0
+0.07,brunatne,0.27,0.27,ziemniak,organiczny,1
+0.32,piaszczyste,0.12,0.18,brak,organiczny,1
\ No newline at end of file
diff --git a/assets/images/cactus.png b/assets/images/cactus.png
new file mode 100644
index 0000000..7bba73c
Binary files /dev/null and b/assets/images/cactus.png differ
diff --git a/assets/images/dirt_cactus.jpg b/assets/images/dirt_cactus.jpg
new file mode 100644
index 0000000..9199996
Binary files /dev/null and b/assets/images/dirt_cactus.jpg differ
diff --git a/assets/images/dirt.jpeg b/assets/images/dirt_empty.jpeg
similarity index 100%
rename from assets/images/dirt.jpeg
rename to assets/images/dirt_empty.jpeg
diff --git a/assets/images/dirt_potato.jpg b/assets/images/dirt_potato.jpg
new file mode 100644
index 0000000..a16501f
Binary files /dev/null and b/assets/images/dirt_potato.jpg differ
diff --git a/assets/images/dirt_wheat.jpg b/assets/images/dirt_wheat.jpg
new file mode 100644
index 0000000..b7fb009
Binary files /dev/null and b/assets/images/dirt_wheat.jpg differ
diff --git a/assets/images/farmland_cactus.jpg b/assets/images/farmland_cactus.jpg
new file mode 100644
index 0000000..c2a9ca2
Binary files /dev/null and b/assets/images/farmland_cactus.jpg differ
diff --git a/assets/images/farmland_empty.png b/assets/images/farmland_empty.png
new file mode 100644
index 0000000..8c16b9c
Binary files /dev/null and b/assets/images/farmland_empty.png differ
diff --git a/assets/images/farmland_potato.jpg b/assets/images/farmland_potato.jpg
new file mode 100644
index 0000000..d7b7d87
Binary files /dev/null and b/assets/images/farmland_potato.jpg differ
diff --git a/assets/images/farmland_wheat.jpg b/assets/images/farmland_wheat.jpg
new file mode 100644
index 0000000..236989d
Binary files /dev/null and b/assets/images/farmland_wheat.jpg differ
diff --git a/assets/images/potato.png b/assets/images/potato.png
new file mode 100644
index 0000000..d26f1b0
Binary files /dev/null and b/assets/images/potato.png differ
diff --git a/assets/images/wheat.png b/assets/images/wheat.png
new file mode 100644
index 0000000..b108d58
Binary files /dev/null and b/assets/images/wheat.png differ
diff --git a/assets/model/xgboost_model.pkl b/assets/model/xgboost_model.pkl
new file mode 100644
index 0000000..0a371df
Binary files /dev/null and b/assets/model/xgboost_model.pkl differ
diff --git a/main.py b/main.py
index f69b071..f9bcc00 100644
--- a/main.py
+++ b/main.py
@@ -1,30 +1,31 @@
import pygame
+from src.utils.xgb_model import Model
from src.world import World
from src.tractor import Tractor
from src.settings import Settings
-from utils.astar import a_star_search
-
+from src.utils.bfs import BFSSearcher
+from src.constants import Constants as C
def main():
pygame.init()
- settings = Settings() # ustawienia pygame
- world = World(settings) # stworzenie mapy na bazie ustawień pygame
- tractor = Tractor("Spalinowy", "Nawóz 1", settings, 8 * settings.tile_size, 8 * settings.tile_size) # stworzenie traktora z podanymi argumentami
- obstacles = [tile for tile in world.tiles if tile.type == 'rock'] # stworzenie listy z przeszkodami, kamień = przeszkoda
+ settings = Settings()
+ model = Model()
+ world = World(settings, model)
+ tractor = Tractor("Spalinowy", "Nawóz 1", settings, 0 * settings.tile_size, 0 * settings.tile_size, C.RIGHT)
+ plants_to_water = [tile for tile in world.tiles if tile.to_water == 1]
clock = pygame.time.Clock() # FPS purpose
- screen = pygame.display.set_mode((settings.screen_width, settings.screen_height)) # tworzenie ekranu
- pygame.display.set_caption('TRAKTOHOLIK') # nazwa okna
-
- start_cords = (8, 1)
- goals = [(8, 7), (7, 7), (0, 0)]
- end_cords = goals[0]
- start_dir = tractor.curr_direction # przypisanie początkowego ustawienia traktora do zmiennej
- # path = BFSSearcher().search(start_cords, end_cords, start_dir) # wygenerowanie listy ruchów na bazie BFS
- path = a_star_search(start_cords, end_cords, start_dir, world) # generowanie ścieżki za pomocą A*
+ screen = pygame.display.set_mode((settings.screen_width, settings.screen_height))
+ pygame.display.set_caption('TRAKTOHOLIK')
+ start_cords = tractor.curr_position
+ goals = [plant.position for plant in plants_to_water]
+ cords_idx = tractor.find_nearest_cords(tractor.curr_position, goals)
+ end_cords = goals[cords_idx]
+ start_dir = tractor.curr_direction
+ path = BFSSearcher().search(start_cords, end_cords, start_dir)
run = True
while run:
@@ -37,22 +38,46 @@ def main():
if event.type == pygame.QUIT:
run = False
- # iteracja przez listę ruchów
if path:
- action = path.pop(0) # pobranie pierwszego ruchu z listy
- tractor.update(action) # wykonanie ruchu przez traktor
+ action = path.pop(0)
+ tractor.update(action)
else:
- if len(goals) > 1: # sprawdzenie czy są inne cele
- new_start = goals.pop(0) # pobierz współrzędne pierwszego celu i ustaw jako początkowe
- end_cords = goals[0] # ustaw kolejny cel
- start_dir = tractor.curr_direction # aktualizacja kierunku traktora
- # path = BFSSearcher().search(start_cord, end_cords, start_dir) # generacja nowej ścieżki
- path = a_star_search(new_start, end_cords, start_dir, world)
+ tractor.water_plant(world, end_cords)
+ if len(goals) > 1:
+ start_cord = goals.pop(cords_idx)
+ cords_idx = tractor.find_nearest_cords(tractor.curr_position, goals)
+ end_cords = goals[cords_idx]
+ start_dir = tractor.curr_direction
+ path = BFSSearcher().search(start_cord, end_cords, start_dir)
pygame.time.wait(settings.freeze_time)
pygame.display.update()
pygame.quit()
-if __name__ == '__main__':
- main()
+# if __name__ == '__main__':
+#
+# # inicjalizacja array z zerami
+# rows = 10
+# cols = 10
+# field = np.zeros((rows, cols), dtype=int)
+#
+# # tworzenie ścian w array
+# for i in range(0, 9):
+# field[1, i] = 1
+#
+# field[2, 8] = 1
+# field[3, 8] = 1
+# field[3, 7] = 1
+# field[3, 6] = 1
+#
+# print(field)
+#
+# start = (0, 0)
+# goals = [(2, 7)]
+# while goals:
+# goal = goals.pop(0)
+# path = a_star(field, start, goal)
+# print(path)
+
+main()
\ No newline at end of file
diff --git a/src/settings.py b/src/settings.py
index 9dabdcb..34c1140 100644
--- a/src/settings.py
+++ b/src/settings.py
@@ -10,6 +10,9 @@ class Settings:
self.screen_width = 700
self.screen_height = 700
+ # World settings
+ self.world_size = 10
+
# Tile settings
self.tile_size = 70
diff --git a/src/tile.py b/src/tile.py
index 7c7323f..83f806b 100644
--- a/src/tile.py
+++ b/src/tile.py
@@ -4,7 +4,8 @@ from pygame.sprite import Sprite
class Tile(Sprite):
""" Class to represent single board tile """
- def __init__(self, type, row_id, col_id, image, rect, cost):
+ def __init__(self, type, row_id, col_id, image, rect, cost, stan_nawodnienia, rodzaj_gleby,
+ stan_nawiezienia, stopien_rozwoju, rodzaj_rosliny, rodzaj_nawozu, to_water):
super().__init__()
self.type = type
self.row_id = row_id
@@ -13,18 +14,15 @@ class Tile(Sprite):
self.image = image
self.rect = rect
self.cost = cost
- self.plant = None
- self.hydration = 0
- self.fertilizer = None
- self.is_fertilized = False
-
- def add_plant(self, plant):
- self.plant = plant
-
- def remove_plant(self):
- self.plant = None
-
+ # nwm, ten kod po polsku moze kiedys do zmiany, póki co mi sie nie chce
+ self.stan_nawodnienia = stan_nawodnienia
+ self.rodzaj_gleby = rodzaj_gleby
+ self.stan_nawiezienia = stan_nawiezienia
+ self.stopien_rozwoju = stopien_rozwoju
+ self.rodzaj_rosliny = rodzaj_rosliny
+ self.rodzaj_nawozu = rodzaj_nawozu
+ self.to_water = to_water
diff --git a/src/tractor.py b/src/tractor.py
index 4046428..cd1129c 100644
--- a/src/tractor.py
+++ b/src/tractor.py
@@ -1,19 +1,24 @@
import pygame
+import numpy as np
+from scipy import spatial
from pygame.sprite import Sprite
from constants import Constants as C
+
class Tractor(Sprite):
""" Class to represent our agent """
- def __init__(self, engine, fertilizer, settings, x, y):
+ def __init__(self, engine, fertilizer, settings, initial_x, initial_y, initial_direction):
super().__init__()
self.settings = settings
- self.image = pygame.transform.scale(pygame.image.load('assets/images/tractor/tractor-transparent-up.png'),
- (self.settings.tile_size, self.settings.tile_size))
+ self.image = pygame.transform.scale(pygame.image.load('assets/images/tractor/tractor-transparent-right.png'),
+ (self.settings.tile_size - 1, self.settings.tile_size - 1))
self.rect = self.image.get_rect()
- self.rect.x = x
- self.rect.y = y
- self.curr_direction = C.UP # wektor w ukladzie wspolrzednych wskazujacy kierunek traktora
+ self.rect.x = initial_x
+ self.rect.y = initial_y
+ self.curr_position = (round(self.rect.x / self.settings.tile_size),
+ self.settings.world_size - round(self.rect.y / self.settings.tile_size) - 1)
+ self.curr_direction = initial_direction # wektor w ukladzie wspolrzednych wskazujacy kierunek traktora
self.engine = engine
self.fertilizer = fertilizer
@@ -80,15 +85,56 @@ class Tractor(Sprite):
pygame.time.wait(self.settings.freeze_time) # bez tego sie kreci jak hot-wheels
- def check_collision(self, obstacles):
- if pygame.sprite.spritecollideany(self, obstacles):
- print('yes')
- self.rect.x -= self.curr_direction[0] * self.settings.tile_size # no to troche prymitywne jest, ale
- self.rect.y += self.curr_direction[1] * self.settings.tile_size # jak wejdzie na kolizje to cofamy ruch
+ def water_plant(self, world, position):
+ plant = world.get_tile(position[0], position[1])
+
+ if plant.rodzaj_rosliny == 'brak':
+ plant.image = pygame.transform.scale(world.farmland_empty,
+ (self.settings.tile_size, self.settings.tile_size))
+ elif plant.rodzaj_rosliny == 'kaktus':
+ plant.image = pygame.transform.scale(world.farmland_cactus,
+ (self.settings.tile_size, self.settings.tile_size))
+ elif plant.rodzaj_rosliny == 'pszenica':
+ plant.image = pygame.transform.scale(world.farmland_wheat,
+ (self.settings.tile_size, self.settings.tile_size))
+ elif plant.rodzaj_rosliny == 'ziemniak':
+ plant.image = pygame.transform.scale(world.farmland_potato,
+ (self.settings.tile_size, self.settings.tile_size))
+
+ def find_nearest_cords(self, curr_cords, cords_lst):
+ # source https://stackoverflow.com/questions/39107896/efficiently-finding-the-closest-coordinate-pair-from-a-set-in-python
+ tree = spatial.KDTree(cords_lst)
+ return tree.query(curr_cords)[1]
+
+ # moze sie jeszcze kiedys przyda
+ # def water_plants(self, plants, plants_lst, goals, world):
+ # hit_list = pygame.sprite.spritecollide(sprite=self, group=plants, dokill=False)
+ #
+ # print(goals)
+ #
+ # for plant in hit_list:
+ # if plant.rodzaj_rosliny == 'brak':
+ # plant.image = pygame.transform.scale(world.farmland_empty, (self.settings.tile_size, self.settings.tile_size))
+ # elif plant.rodzaj_rosliny == 'kaktus':
+ # plant.image = pygame.transform.scale(world.farmland_cactus, (self.settings.tile_size, self.settings.tile_size))
+ # elif plant.rodzaj_rosliny == 'pszenica':
+ # plant.image = pygame.transform.scale(world.farmland_wheat, (self.settings.tile_size, self.settings.tile_size))
+ # elif plant.rodzaj_rosliny == 'ziemniak':
+ # plant.image = pygame.transform.scale(world.farmland_potato, (self.settings.tile_size, self.settings.tile_size))
+ #
+ # if plant.position in goals:
+ # goals.remove(plant.position)
+
+
+ # if pygame.sprite.spritecollideany(self, obstacles):
+ # print(len(obstacles))
+ # self.rect.x -= self.curr_direction[0] * self.settings.tile_size # no to troche prymitywne jest, ale
+ # self.rect.y += self.curr_direction[1] * self.settings.tile_size # jak wejdzie na kolizje to cofamy ruch
# w przyszlosci mozna zmienic
def update(self, action):
-
+ self.curr_position = (round(self.rect.x / self.settings.tile_size),
+ self.settings.world_size - round(self.rect.y / self.settings.tile_size) - 1)
if action == C.ROTATE_RIGHT and self.rect.x:
self.turn_right()
elif action == C.ROTATE_LEFT:
diff --git a/src/utils/.ipynb_checkpoints/create_model-checkpoint.ipynb b/src/utils/.ipynb_checkpoints/create_model-checkpoint.ipynb
new file mode 100644
index 0000000..38ee33f
--- /dev/null
+++ b/src/utils/.ipynb_checkpoints/create_model-checkpoint.ipynb
@@ -0,0 +1,7123 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e044c818",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import seaborn as sns\n",
+ "import numpy as np\n",
+ "import warnings\n",
+ "import pickle\n",
+ "from matplotlib import pyplot as plt\n",
+ "from xgboost import XGBClassifier\n",
+ "from sklearn.model_selection import cross_val_score, cross_validate, KFold\n",
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score\n",
+ "from hyperopt import hp, fmin, tpe, Trials, STATUS_OK\n",
+ "warnings.filterwarnings('ignore')\n",
+ "%config InlineBackend.figure_format = 'svg'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "605e1308",
+ "metadata": {},
+ "source": [
+ "# EDA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "37a1a88d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " stan_nawodnienia | \n",
+ " rodzaj_gleby | \n",
+ " stan_nawiezienia | \n",
+ " stopien_rozwoju | \n",
+ " rodzaj_rosliny | \n",
+ " rodzaj_nawozu | \n",
+ " to_water | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.32 | \n",
+ " brunatne | \n",
+ " 0.01 | \n",
+ " 0.75 | \n",
+ " ziemniak | \n",
+ " mineralny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.38 | \n",
+ " brunatne | \n",
+ " 0.15 | \n",
+ " 0.85 | \n",
+ " pszenica | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.17 | \n",
+ " piaszczyste | \n",
+ " 0.22 | \n",
+ " 0.13 | \n",
+ " brak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.67 | \n",
+ " czarnoziemy | \n",
+ " 0.64 | \n",
+ " 0.55 | \n",
+ " pszenica | \n",
+ " sztuczny | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.45 | \n",
+ " brunatne | \n",
+ " 0.16 | \n",
+ " 0.16 | \n",
+ " brak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 172 | \n",
+ " 0.22 | \n",
+ " czarnoziemy | \n",
+ " 0.24 | \n",
+ " 0.25 | \n",
+ " kaktus | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 173 | \n",
+ " 0.17 | \n",
+ " brunatne | \n",
+ " 0.31 | \n",
+ " 0.35 | \n",
+ " ziemniak | \n",
+ " mineralny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 174 | \n",
+ " 0.91 | \n",
+ " brunatne | \n",
+ " 0.58 | \n",
+ " 0.89 | \n",
+ " ziemniak | \n",
+ " sztuczny | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 175 | \n",
+ " 0.07 | \n",
+ " brunatne | \n",
+ " 0.27 | \n",
+ " 0.27 | \n",
+ " ziemniak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 176 | \n",
+ " 0.32 | \n",
+ " piaszczyste | \n",
+ " 0.12 | \n",
+ " 0.68 | \n",
+ " brak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
177 rows × 7 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " stan_nawodnienia rodzaj_gleby stan_nawiezienia stopien_rozwoju \\\n",
+ "0 0.32 brunatne 0.01 0.75 \n",
+ "1 0.38 brunatne 0.15 0.85 \n",
+ "2 0.17 piaszczyste 0.22 0.13 \n",
+ "3 0.67 czarnoziemy 0.64 0.55 \n",
+ "4 0.45 brunatne 0.16 0.16 \n",
+ ".. ... ... ... ... \n",
+ "172 0.22 czarnoziemy 0.24 0.25 \n",
+ "173 0.17 brunatne 0.31 0.35 \n",
+ "174 0.91 brunatne 0.58 0.89 \n",
+ "175 0.07 brunatne 0.27 0.27 \n",
+ "176 0.32 piaszczyste 0.12 0.68 \n",
+ "\n",
+ " rodzaj_rosliny rodzaj_nawozu to_water \n",
+ "0 ziemniak mineralny 1 \n",
+ "1 pszenica organiczny 1 \n",
+ "2 brak organiczny 1 \n",
+ "3 pszenica sztuczny 0 \n",
+ "4 brak organiczny 1 \n",
+ ".. ... ... ... \n",
+ "172 kaktus organiczny 1 \n",
+ "173 ziemniak mineralny 1 \n",
+ "174 ziemniak sztuczny 0 \n",
+ "175 ziemniak organiczny 1 \n",
+ "176 brak organiczny 1 \n",
+ "\n",
+ "[177 rows x 7 columns]"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train = pd.read_csv('../../assets/data/train.csv')\n",
+ "train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ddda0ee1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('to_water', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "823b28a3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('rodzaj_gleby', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ff2991fb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('rodzaj_rosliny', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "c2815299",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('rodzaj_nawozu', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "b66c9226",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.boxplot(train.to_water, train.stan_nawodnienia);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8335c0ba",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.boxplot(train.to_water, train.stan_nawiezienia);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1b7c9621",
+ "metadata": {},
+ "source": [
+ "# Features Engineering"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "8da84497",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " stan_nawodnienia | \n",
+ " stan_nawiezienia | \n",
+ " stopien_rozwoju | \n",
+ " to_water | \n",
+ " rodzaj_gleby_brunatne | \n",
+ " rodzaj_gleby_czarnoziemy | \n",
+ " rodzaj_gleby_piaszczyste | \n",
+ " rodzaj_rosliny_brak | \n",
+ " rodzaj_rosliny_kaktus | \n",
+ " rodzaj_rosliny_pszenica | \n",
+ " rodzaj_rosliny_ziemniak | \n",
+ " rodzaj_nawozu_mineralny | \n",
+ " rodzaj_nawozu_organiczny | \n",
+ " rodzaj_nawozu_sztuczny | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.32 | \n",
+ " 0.01 | \n",
+ " 0.75 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.38 | \n",
+ " 0.15 | \n",
+ " 0.85 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.17 | \n",
+ " 0.22 | \n",
+ " 0.13 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.67 | \n",
+ " 0.64 | \n",
+ " 0.55 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.45 | \n",
+ " 0.16 | \n",
+ " 0.16 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 172 | \n",
+ " 0.22 | \n",
+ " 0.24 | \n",
+ " 0.25 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 173 | \n",
+ " 0.17 | \n",
+ " 0.31 | \n",
+ " 0.35 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 174 | \n",
+ " 0.91 | \n",
+ " 0.58 | \n",
+ " 0.89 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 175 | \n",
+ " 0.07 | \n",
+ " 0.27 | \n",
+ " 0.27 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 176 | \n",
+ " 0.32 | \n",
+ " 0.12 | \n",
+ " 0.68 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
177 rows × 14 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " stan_nawodnienia stan_nawiezienia stopien_rozwoju to_water \\\n",
+ "0 0.32 0.01 0.75 1 \n",
+ "1 0.38 0.15 0.85 1 \n",
+ "2 0.17 0.22 0.13 1 \n",
+ "3 0.67 0.64 0.55 0 \n",
+ "4 0.45 0.16 0.16 1 \n",
+ ".. ... ... ... ... \n",
+ "172 0.22 0.24 0.25 1 \n",
+ "173 0.17 0.31 0.35 1 \n",
+ "174 0.91 0.58 0.89 0 \n",
+ "175 0.07 0.27 0.27 1 \n",
+ "176 0.32 0.12 0.68 1 \n",
+ "\n",
+ " rodzaj_gleby_brunatne rodzaj_gleby_czarnoziemy \\\n",
+ "0 1 0 \n",
+ "1 1 0 \n",
+ "2 0 0 \n",
+ "3 0 1 \n",
+ "4 1 0 \n",
+ ".. ... ... \n",
+ "172 0 1 \n",
+ "173 1 0 \n",
+ "174 1 0 \n",
+ "175 1 0 \n",
+ "176 0 0 \n",
+ "\n",
+ " rodzaj_gleby_piaszczyste rodzaj_rosliny_brak rodzaj_rosliny_kaktus \\\n",
+ "0 0 0 0 \n",
+ "1 0 0 0 \n",
+ "2 1 1 0 \n",
+ "3 0 0 0 \n",
+ "4 0 1 0 \n",
+ ".. ... ... ... \n",
+ "172 0 0 1 \n",
+ "173 0 0 0 \n",
+ "174 0 0 0 \n",
+ "175 0 0 0 \n",
+ "176 1 1 0 \n",
+ "\n",
+ " rodzaj_rosliny_pszenica rodzaj_rosliny_ziemniak \\\n",
+ "0 0 1 \n",
+ "1 1 0 \n",
+ "2 0 0 \n",
+ "3 1 0 \n",
+ "4 0 0 \n",
+ ".. ... ... \n",
+ "172 0 0 \n",
+ "173 0 1 \n",
+ "174 0 1 \n",
+ "175 0 1 \n",
+ "176 0 0 \n",
+ "\n",
+ " rodzaj_nawozu_mineralny rodzaj_nawozu_organiczny rodzaj_nawozu_sztuczny \n",
+ "0 1 0 0 \n",
+ "1 0 1 0 \n",
+ "2 0 1 0 \n",
+ "3 0 0 1 \n",
+ "4 0 1 0 \n",
+ ".. ... ... ... \n",
+ "172 0 1 0 \n",
+ "173 1 0 0 \n",
+ "174 0 0 1 \n",
+ "175 0 1 0 \n",
+ "176 0 1 0 \n",
+ "\n",
+ "[177 rows x 14 columns]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train = pd.get_dummies(train)\n",
+ "train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "b22a5623",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_train = train.drop('to_water', axis=1)\n",
+ "y_train = train['to_water']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1f7e3d8e",
+ "metadata": {},
+ "source": [
+ "# Training XGBoost model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "58187f92",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
+ " colsample_bynode=1, colsample_bytree=1, enable_categorical=False,\n",
+ " eval_metric='mlogloss', gamma=0, gpu_id=-1, importance_type=None,\n",
+ " interaction_constraints='', learning_rate=0.300000012,\n",
+ " max_delta_step=0, max_depth=6, min_child_weight=1, missing=nan,\n",
+ " monotone_constraints='()', n_estimators=100, n_jobs=12,\n",
+ " num_parallel_tree=1, predictor='auto', random_state=0,\n",
+ " reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,\n",
+ " tree_method='exact', validate_parameters=1, verbosity=None)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xgb = XGBClassifier(eval_metric='mlogloss')\n",
+ "xgb.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "ebaba036",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cv_model(X_train, y_train, model):\n",
+ " kfold = KFold(n_splits=5, shuffle=True, random_state=42)\n",
+ " scoring = {'precision': 'precision_macro',\n",
+ " 'recall': 'recall_macro',\n",
+ " 'accuracy': 'accuracy',\n",
+ " 'auc': 'roc_auc'}\n",
+ " cv_results = []\n",
+ " \n",
+ " cv_model = cross_validate(model, X_train, y_train, scoring=scoring, cv=kfold)\n",
+ " cv_results.append([cv_model['test_precision'].mean(), cv_model['test_recall'].mean(),\n",
+ " cv_model['test_accuracy'].mean(), cv_model['test_auc'].mean()])\n",
+ " \n",
+ " return cv_results\n",
+ "\n",
+ "def show_cv_results(cv_results):\n",
+ " cvr_df = pd.DataFrame(index=['XGBoost'],\n",
+ " data=cv_results, columns=['precision', 'recall', 'accuracy', 'AUC'])\n",
+ " return cvr_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "9aca9ade",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " precision | \n",
+ " recall | \n",
+ " accuracy | \n",
+ " AUC | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " XGBoost | \n",
+ " 0.924244 | \n",
+ " 0.920625 | \n",
+ " 0.920952 | \n",
+ " 0.973928 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall accuracy AUC\n",
+ "XGBoost 0.924244 0.920625 0.920952 0.973928"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cv_results = cv_model(X_train, y_train, xgb)\n",
+ "show_cv_results(cv_results)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "2faf6ef1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "feat_importances = pd.Series(xgb.feature_importances_, index=X_train.columns).sort_values().nlargest(10)\n",
+ "feat_importances.sort_values().plot(kind='barh', figsize=[9, 7]);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6bc9eb1e",
+ "metadata": {},
+ "source": [
+ "# XGBoost parameter tuning using HyperOpt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "9cb0e365",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params_bounds = {'learning_rate': hp.uniform('learning_rate', 0.01, 1.0),\n",
+ " 'n_estimators': hp.uniform('n_estimators', 100.0, 1000.0),\n",
+ " 'max_depth': hp.uniform('max_depth', 4.0, 10.0), \n",
+ " 'subsample': hp.uniform('subsample', 0.5, 1.0),\n",
+ " 'gamma': hp.uniform('gamma', 0.0, 5.0)}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "0a6867f5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def fun(params_bounds):\n",
+ " model = XGBClassifier(eval_metric='mlogloss',\n",
+ " learning_rate = params_bounds['learning_rate'],\n",
+ " n_estimators = round(params_bounds['n_estimators']),\n",
+ " max_depth = round(params_bounds['max_depth']),\n",
+ " subsample = params_bounds['subsample'],\n",
+ " gamma = params_bounds['gamma'])\n",
+ " score = cross_val_score(model, X_train, y_train, cv=5, scoring='roc_auc', error_score='raise').mean()\n",
+ " \n",
+ " return {'loss': -score, 'status': STATUS_OK}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "532b8dcd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "100%|███████████████████████████████████████████████| 30/30 [00:20<00:00, 1.47trial/s, best loss: -0.9791677631578948]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'gamma': 2.877864670923052,\n",
+ " 'learning_rate': 0.01623819939058474,\n",
+ " 'max_depth': 6,\n",
+ " 'n_estimators': 378,\n",
+ " 'subsample': 0.8352886879422001}"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trials = Trials()\n",
+ "best = fmin(fn=fun,\n",
+ " space=params_bounds,\n",
+ " algo=tpe.suggest,\n",
+ " max_evals=30,\n",
+ " trials=trials)\n",
+ "best['max_depth'] = round(best['max_depth'])\n",
+ "best['n_estimators'] = round(best['n_estimators'])\n",
+ "best"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "df612fb1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " precision | \n",
+ " recall | \n",
+ " accuracy | \n",
+ " AUC | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " XGBoost | \n",
+ " 0.936778 | \n",
+ " 0.930734 | \n",
+ " 0.932063 | \n",
+ " 0.975984 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall accuracy AUC\n",
+ "XGBoost 0.936778 0.930734 0.932063 0.975984"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xgb = XGBClassifier(eval_metric='mlogloss', **best)\n",
+ "xgb.fit(X_train, y_train)\n",
+ "cv_results = cv_model(X_train, y_train, xgb)\n",
+ "show_cv_results(cv_results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1e8dcbad",
+ "metadata": {},
+ "source": [
+ "**As expected, no meaningful improvement: every metric changes by only about one percentage point or less**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "39a2330b",
+ "metadata": {},
+ "source": [
+ "# Testing Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "1c28788a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test = pd.read_csv('../../assets/data/test.csv')\n",
+ "test = pd.get_dummies(test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "41f3bab6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_test = test.drop('to_water', axis=1)\n",
+ "y_test = test['to_water']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "21f0fe78",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prediction = xgb.predict(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "2fb11624",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " precision | \n",
+ " recall | \n",
+ " accuracy | \n",
+ " AUC | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " XGBoost | \n",
+ " 1.0 | \n",
+ " 0.916667 | \n",
+ " 0.972222 | \n",
+ " 0.958333 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall accuracy AUC\n",
+ "XGBoost 1.0 0.916667 0.972222 0.958333"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data=[[precision_score(y_test, prediction),\n",
+ " recall_score(y_test, prediction),\n",
+ " accuracy_score(y_test, prediction),\n",
+ " roc_auc_score(y_test, prediction)]]\n",
+ "result_df = pd.DataFrame(index=['XGBoost'],\n",
+ " data=data, columns=['precision', 'recall', 'accuracy', 'AUC'])\n",
+ "result_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "69deba28",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "34 1\n",
+ "Name: to_water, dtype: int64"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_test[y_test != prediction]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a53b99d2",
+ "metadata": {},
+ "source": [
+ "**Only one wrong prediction:**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "6a4a8707",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "stan_nawodnienia 0.34\n",
+ "stan_nawiezienia 0.72\n",
+ "stopien_rozwoju 0.53\n",
+ "rodzaj_gleby_brunatne 0.00\n",
+ "rodzaj_gleby_czarnoziemy 0.00\n",
+ "rodzaj_gleby_piaszczyste 1.00\n",
+ "rodzaj_rosliny_brak 0.00\n",
+ "rodzaj_rosliny_kaktus 0.00\n",
+ "rodzaj_rosliny_pszenica 1.00\n",
+ "rodzaj_rosliny_ziemniak 0.00\n",
+ "rodzaj_nawozu_mineralny 1.00\n",
+ "rodzaj_nawozu_organiczny 0.00\n",
+ "rodzaj_nawozu_sztuczny 0.00\n",
+ "Name: 34, dtype: float64"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_test.iloc[34]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fec680e9",
+ "metadata": {},
+ "source": [
+ "# Save XGBoost model using pickle"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "e043273b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filename = '../../assets/model/xgboost_model.pkl'\n",
+ "pickle.dump(xgb, open(filename, 'wb'))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": false,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "303.825px"
+ },
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/utils/create_model.html b/src/utils/create_model.html
new file mode 100644
index 0000000..50e99f1
--- /dev/null
+++ b/src/utils/create_model.html
@@ -0,0 +1,22011 @@
+
+
+
+
+
+create_model
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[2]:
+
+
+
+
+
+
+
+
+
+ |
+ stan_nawodnienia |
+ rodzaj_gleby |
+ stan_nawiezienia |
+ stopien_rozwoju |
+ rodzaj_rosliny |
+ rodzaj_nawozu |
+ to_water |
+
+
+
+
+ 0 |
+ 0.32 |
+ brunatne |
+ 0.01 |
+ 0.75 |
+ ziemniak |
+ mineralny |
+ 1 |
+
+
+ 1 |
+ 0.38 |
+ brunatne |
+ 0.15 |
+ 0.85 |
+ pszenica |
+ organiczny |
+ 1 |
+
+
+ 2 |
+ 0.17 |
+ piaszczyste |
+ 0.22 |
+ 0.13 |
+ brak |
+ organiczny |
+ 1 |
+
+
+ 3 |
+ 0.67 |
+ czarnoziemy |
+ 0.64 |
+ 0.55 |
+ pszenica |
+ sztuczny |
+ 0 |
+
+
+ 4 |
+ 0.45 |
+ brunatne |
+ 0.16 |
+ 0.16 |
+ brak |
+ organiczny |
+ 1 |
+
+
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+
+
+ 172 |
+ 0.22 |
+ czarnoziemy |
+ 0.24 |
+ 0.25 |
+ kaktus |
+ organiczny |
+ 1 |
+
+
+ 173 |
+ 0.17 |
+ brunatne |
+ 0.31 |
+ 0.35 |
+ ziemniak |
+ mineralny |
+ 1 |
+
+
+ 174 |
+ 0.91 |
+ brunatne |
+ 0.58 |
+ 0.89 |
+ ziemniak |
+ sztuczny |
+ 0 |
+
+
+ 175 |
+ 0.07 |
+ brunatne |
+ 0.27 |
+ 0.27 |
+ ziemniak |
+ organiczny |
+ 1 |
+
+
+ 176 |
+ 0.32 |
+ piaszczyste |
+ 0.12 |
+ 0.68 |
+ brak |
+ organiczny |
+ 1 |
+
+
+
+
177 rows × 7 columns
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[10]:
+
+
+
+
+
+
+
+
+
+ |
+ stan_nawodnienia |
+ stan_nawiezienia |
+ stopien_rozwoju |
+ to_water |
+ rodzaj_gleby_brunatne |
+ rodzaj_gleby_czarnoziemy |
+ rodzaj_gleby_piaszczyste |
+ rodzaj_rosliny_brak |
+ rodzaj_rosliny_kaktus |
+ rodzaj_rosliny_pszenica |
+ rodzaj_rosliny_ziemniak |
+ rodzaj_nawozu_mineralny |
+ rodzaj_nawozu_organiczny |
+ rodzaj_nawozu_sztuczny |
+
+
+
+
+ 0 |
+ 0.32 |
+ 0.01 |
+ 0.75 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+
+
+ 1 |
+ 0.38 |
+ 0.15 |
+ 0.85 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+
+
+ 2 |
+ 0.17 |
+ 0.22 |
+ 0.13 |
+ 1 |
+ 0 |
+ 0 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+
+
+ 3 |
+ 0.67 |
+ 0.64 |
+ 0.55 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+
+
+ 4 |
+ 0.45 |
+ 0.16 |
+ 0.16 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+
+
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+ ... |
+
+
+ 172 |
+ 0.22 |
+ 0.24 |
+ 0.25 |
+ 1 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+
+
+ 173 |
+ 0.17 |
+ 0.31 |
+ 0.35 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+
+
+ 174 |
+ 0.91 |
+ 0.58 |
+ 0.89 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 0 |
+ 1 |
+
+
+ 175 |
+ 0.07 |
+ 0.27 |
+ 0.27 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+ 1 |
+ 0 |
+
+
+ 176 |
+ 0.32 |
+ 0.12 |
+ 0.68 |
+ 1 |
+ 0 |
+ 0 |
+ 1 |
+ 1 |
+ 0 |
+ 0 |
+ 0 |
+ 0 |
+ 1 |
+ 0 |
+
+
+
+
177 rows × 14 columns
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[12]:
+
+
+
+
+
+
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
+ colsample_bynode=1, colsample_bytree=1, enable_categorical=False,
+ eval_metric='mlogloss', gamma=0, gpu_id=-1, importance_type=None,
+ interaction_constraints='', learning_rate=0.300000012,
+ max_delta_step=0, max_depth=6, min_child_weight=1, missing=nan,
+ monotone_constraints='()', n_estimators=100, n_jobs=12,
+ num_parallel_tree=1, predictor='auto', random_state=0,
+ reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
+ tree_method='exact', validate_parameters=1, verbosity=None)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[14]:
+
+
+
+
+
+
+
+
+
+ |
+ precision |
+ recall |
+ accuracy |
+ AUC |
+
+
+
+
+ XGBoost |
+ 0.924244 |
+ 0.920625 |
+ 0.920952 |
+ 0.973928 |
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
100%|███████████████████████████████████████████████| 30/30 [00:18<00:00, 1.63trial/s, best loss: -0.9784769736842105]
+
+
+
+
+
+
+
+
Out[18]:
+
+
+
+
+
+
{'gamma': 3.5918955252412506,
+ 'learning_rate': 0.01498664732425289,
+ 'max_depth': 6,
+ 'n_estimators': 520,
+ 'subsample': 0.6447636481124372}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[19]:
+
+
+
+
+
+
+
+
+
+ |
+ precision |
+ recall |
+ accuracy |
+ AUC |
+
+
+
+
+ XGBoost |
+ 0.931876 |
+ 0.925972 |
+ 0.926508 |
+ 0.978721 |
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[23]:
+
+
+
+
+
+
+
+
+
+ |
+ precision |
+ recall |
+ accuracy |
+ AUC |
+
+
+
+
+ XGBoost |
+ 1.0 |
+ 0.916667 |
+ 0.972222 |
+ 0.958333 |
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[24]:
+
+
+
+
+
+
34 1
+Name: to_water, dtype: int64
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Out[25]:
+
+
+
+
+
+
stan_nawodnienia 0.34
+stan_nawiezienia 0.72
+stopien_rozwoju 0.53
+rodzaj_gleby_brunatne 0.00
+rodzaj_gleby_czarnoziemy 0.00
+rodzaj_gleby_piaszczyste 1.00
+rodzaj_rosliny_brak 0.00
+rodzaj_rosliny_kaktus 0.00
+rodzaj_rosliny_pszenica 1.00
+rodzaj_rosliny_ziemniak 0.00
+rodzaj_nawozu_mineralny 1.00
+rodzaj_nawozu_organiczny 0.00
+rodzaj_nawozu_sztuczny 0.00
+Name: 34, dtype: float64
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/utils/create_model.ipynb b/src/utils/create_model.ipynb
new file mode 100644
index 0000000..b1956ba
--- /dev/null
+++ b/src/utils/create_model.ipynb
@@ -0,0 +1,7123 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e044c818",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import seaborn as sns\n",
+ "import numpy as np\n",
+ "import warnings\n",
+ "import pickle\n",
+ "from matplotlib import pyplot as plt\n",
+ "from xgboost import XGBClassifier\n",
+ "from sklearn.model_selection import cross_val_score, cross_validate, KFold\n",
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score\n",
+ "from hyperopt import hp, fmin, tpe, Trials, STATUS_OK\n",
+ "warnings.filterwarnings('ignore')\n",
+ "%config InlineBackend.figure_format = 'svg'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "605e1308",
+ "metadata": {},
+ "source": [
+ "# EDA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "37a1a88d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " stan_nawodnienia | \n",
+ " rodzaj_gleby | \n",
+ " stan_nawiezienia | \n",
+ " stopien_rozwoju | \n",
+ " rodzaj_rosliny | \n",
+ " rodzaj_nawozu | \n",
+ " to_water | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.32 | \n",
+ " brunatne | \n",
+ " 0.01 | \n",
+ " 0.75 | \n",
+ " ziemniak | \n",
+ " mineralny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.38 | \n",
+ " brunatne | \n",
+ " 0.15 | \n",
+ " 0.85 | \n",
+ " pszenica | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.17 | \n",
+ " piaszczyste | \n",
+ " 0.22 | \n",
+ " 0.13 | \n",
+ " brak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.67 | \n",
+ " czarnoziemy | \n",
+ " 0.64 | \n",
+ " 0.55 | \n",
+ " pszenica | \n",
+ " sztuczny | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.45 | \n",
+ " brunatne | \n",
+ " 0.16 | \n",
+ " 0.16 | \n",
+ " brak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 172 | \n",
+ " 0.22 | \n",
+ " czarnoziemy | \n",
+ " 0.24 | \n",
+ " 0.25 | \n",
+ " kaktus | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 173 | \n",
+ " 0.17 | \n",
+ " brunatne | \n",
+ " 0.31 | \n",
+ " 0.35 | \n",
+ " ziemniak | \n",
+ " mineralny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 174 | \n",
+ " 0.91 | \n",
+ " brunatne | \n",
+ " 0.58 | \n",
+ " 0.89 | \n",
+ " ziemniak | \n",
+ " sztuczny | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 175 | \n",
+ " 0.07 | \n",
+ " brunatne | \n",
+ " 0.27 | \n",
+ " 0.27 | \n",
+ " ziemniak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 176 | \n",
+ " 0.32 | \n",
+ " piaszczyste | \n",
+ " 0.12 | \n",
+ " 0.68 | \n",
+ " brak | \n",
+ " organiczny | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
177 rows × 7 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " stan_nawodnienia rodzaj_gleby stan_nawiezienia stopien_rozwoju \\\n",
+ "0 0.32 brunatne 0.01 0.75 \n",
+ "1 0.38 brunatne 0.15 0.85 \n",
+ "2 0.17 piaszczyste 0.22 0.13 \n",
+ "3 0.67 czarnoziemy 0.64 0.55 \n",
+ "4 0.45 brunatne 0.16 0.16 \n",
+ ".. ... ... ... ... \n",
+ "172 0.22 czarnoziemy 0.24 0.25 \n",
+ "173 0.17 brunatne 0.31 0.35 \n",
+ "174 0.91 brunatne 0.58 0.89 \n",
+ "175 0.07 brunatne 0.27 0.27 \n",
+ "176 0.32 piaszczyste 0.12 0.68 \n",
+ "\n",
+ " rodzaj_rosliny rodzaj_nawozu to_water \n",
+ "0 ziemniak mineralny 1 \n",
+ "1 pszenica organiczny 1 \n",
+ "2 brak organiczny 1 \n",
+ "3 pszenica sztuczny 0 \n",
+ "4 brak organiczny 1 \n",
+ ".. ... ... ... \n",
+ "172 kaktus organiczny 1 \n",
+ "173 ziemniak mineralny 1 \n",
+ "174 ziemniak sztuczny 0 \n",
+ "175 ziemniak organiczny 1 \n",
+ "176 brak organiczny 1 \n",
+ "\n",
+ "[177 rows x 7 columns]"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train = pd.read_csv('../../assets/data/train.csv')\n",
+ "train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ddda0ee1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('to_water', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "823b28a3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('rodzaj_gleby', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ff2991fb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('rodzaj_rosliny', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "c2815299",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.countplot('rodzaj_nawozu', data=train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "b66c9226",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.boxplot(train.to_water, train.stan_nawodnienia);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8335c0ba",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.boxplot(train.to_water, train.stan_nawiezienia);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1b7c9621",
+ "metadata": {},
+ "source": [
+ "# Feature Engineering"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "8da84497",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " stan_nawodnienia | \n",
+ " stan_nawiezienia | \n",
+ " stopien_rozwoju | \n",
+ " to_water | \n",
+ " rodzaj_gleby_brunatne | \n",
+ " rodzaj_gleby_czarnoziemy | \n",
+ " rodzaj_gleby_piaszczyste | \n",
+ " rodzaj_rosliny_brak | \n",
+ " rodzaj_rosliny_kaktus | \n",
+ " rodzaj_rosliny_pszenica | \n",
+ " rodzaj_rosliny_ziemniak | \n",
+ " rodzaj_nawozu_mineralny | \n",
+ " rodzaj_nawozu_organiczny | \n",
+ " rodzaj_nawozu_sztuczny | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.32 | \n",
+ " 0.01 | \n",
+ " 0.75 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.38 | \n",
+ " 0.15 | \n",
+ " 0.85 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.17 | \n",
+ " 0.22 | \n",
+ " 0.13 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.67 | \n",
+ " 0.64 | \n",
+ " 0.55 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.45 | \n",
+ " 0.16 | \n",
+ " 0.16 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 172 | \n",
+ " 0.22 | \n",
+ " 0.24 | \n",
+ " 0.25 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 173 | \n",
+ " 0.17 | \n",
+ " 0.31 | \n",
+ " 0.35 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 174 | \n",
+ " 0.91 | \n",
+ " 0.58 | \n",
+ " 0.89 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 175 | \n",
+ " 0.07 | \n",
+ " 0.27 | \n",
+ " 0.27 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 176 | \n",
+ " 0.32 | \n",
+ " 0.12 | \n",
+ " 0.68 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
177 rows × 14 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " stan_nawodnienia stan_nawiezienia stopien_rozwoju to_water \\\n",
+ "0 0.32 0.01 0.75 1 \n",
+ "1 0.38 0.15 0.85 1 \n",
+ "2 0.17 0.22 0.13 1 \n",
+ "3 0.67 0.64 0.55 0 \n",
+ "4 0.45 0.16 0.16 1 \n",
+ ".. ... ... ... ... \n",
+ "172 0.22 0.24 0.25 1 \n",
+ "173 0.17 0.31 0.35 1 \n",
+ "174 0.91 0.58 0.89 0 \n",
+ "175 0.07 0.27 0.27 1 \n",
+ "176 0.32 0.12 0.68 1 \n",
+ "\n",
+ " rodzaj_gleby_brunatne rodzaj_gleby_czarnoziemy \\\n",
+ "0 1 0 \n",
+ "1 1 0 \n",
+ "2 0 0 \n",
+ "3 0 1 \n",
+ "4 1 0 \n",
+ ".. ... ... \n",
+ "172 0 1 \n",
+ "173 1 0 \n",
+ "174 1 0 \n",
+ "175 1 0 \n",
+ "176 0 0 \n",
+ "\n",
+ " rodzaj_gleby_piaszczyste rodzaj_rosliny_brak rodzaj_rosliny_kaktus \\\n",
+ "0 0 0 0 \n",
+ "1 0 0 0 \n",
+ "2 1 1 0 \n",
+ "3 0 0 0 \n",
+ "4 0 1 0 \n",
+ ".. ... ... ... \n",
+ "172 0 0 1 \n",
+ "173 0 0 0 \n",
+ "174 0 0 0 \n",
+ "175 0 0 0 \n",
+ "176 1 1 0 \n",
+ "\n",
+ " rodzaj_rosliny_pszenica rodzaj_rosliny_ziemniak \\\n",
+ "0 0 1 \n",
+ "1 1 0 \n",
+ "2 0 0 \n",
+ "3 1 0 \n",
+ "4 0 0 \n",
+ ".. ... ... \n",
+ "172 0 0 \n",
+ "173 0 1 \n",
+ "174 0 1 \n",
+ "175 0 1 \n",
+ "176 0 0 \n",
+ "\n",
+ " rodzaj_nawozu_mineralny rodzaj_nawozu_organiczny rodzaj_nawozu_sztuczny \n",
+ "0 1 0 0 \n",
+ "1 0 1 0 \n",
+ "2 0 1 0 \n",
+ "3 0 0 1 \n",
+ "4 0 1 0 \n",
+ ".. ... ... ... \n",
+ "172 0 1 0 \n",
+ "173 1 0 0 \n",
+ "174 0 0 1 \n",
+ "175 0 1 0 \n",
+ "176 0 1 0 \n",
+ "\n",
+ "[177 rows x 14 columns]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train = pd.get_dummies(train)\n",
+ "train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "b22a5623",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_train = train.drop('to_water', axis=1)\n",
+ "y_train = train['to_water']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1f7e3d8e",
+ "metadata": {},
+ "source": [
+ "# Training XGBoost model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "58187f92",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
+ " colsample_bynode=1, colsample_bytree=1, enable_categorical=False,\n",
+ " eval_metric='mlogloss', gamma=0, gpu_id=-1, importance_type=None,\n",
+ " interaction_constraints='', learning_rate=0.300000012,\n",
+ " max_delta_step=0, max_depth=6, min_child_weight=1, missing=nan,\n",
+ " monotone_constraints='()', n_estimators=100, n_jobs=12,\n",
+ " num_parallel_tree=1, predictor='auto', random_state=0,\n",
+ " reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,\n",
+ " tree_method='exact', validate_parameters=1, verbosity=None)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xgb = XGBClassifier(eval_metric='mlogloss')\n",
+ "xgb.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "ebaba036",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cv_model(X_train, y_train, model):\n",
+ " kfold = KFold(n_splits=5, shuffle=True, random_state=42)\n",
+ " scoring = {'precision': 'precision_macro',\n",
+ " 'recall': 'recall_macro',\n",
+ " 'accuracy': 'accuracy',\n",
+ " 'auc': 'roc_auc'}\n",
+ " cv_results = []\n",
+ " \n",
+ " cv_model = cross_validate(model, X_train, y_train, scoring=scoring, cv=kfold)\n",
+ " cv_results.append([cv_model['test_precision'].mean(), cv_model['test_recall'].mean(),\n",
+ " cv_model['test_accuracy'].mean(), cv_model['test_auc'].mean()])\n",
+ " \n",
+ " return cv_results\n",
+ "\n",
+ "def show_cv_results(cv_results):\n",
+ " cvr_df = pd.DataFrame(index=['XGBoost'],\n",
+ " data=cv_results, columns=['precision', 'recall', 'accuracy', 'AUC'])\n",
+ " return cvr_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "9aca9ade",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " precision | \n",
+ " recall | \n",
+ " accuracy | \n",
+ " AUC | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " XGBoost | \n",
+ " 0.924244 | \n",
+ " 0.920625 | \n",
+ " 0.920952 | \n",
+ " 0.973928 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall accuracy AUC\n",
+ "XGBoost 0.924244 0.920625 0.920952 0.973928"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cv_results = cv_model(X_train, y_train, xgb)\n",
+ "show_cv_results(cv_results)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "2faf6ef1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\r\n",
+ "\r\n",
+ "\r\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "feat_importances = pd.Series(xgb.feature_importances_, index=X_train.columns).sort_values().nlargest(10)\n",
+ "feat_importances.sort_values().plot(kind='barh', figsize=[9, 7]);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6bc9eb1e",
+ "metadata": {},
+ "source": [
+ "# XGBoost parameter tuning using HyperOpt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "9cb0e365",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params_bounds = {'learning_rate': hp.uniform('learning_rate', 0.01, 1.0),\n",
+ " 'n_estimators': hp.uniform('n_estimators', 100.0, 1000.0),\n",
+ " 'max_depth': hp.uniform('max_depth', 4.0, 10.0), \n",
+ " 'subsample': hp.uniform('subsample', 0.5, 1.0),\n",
+ " 'gamma': hp.uniform('gamma', 0.0, 5.0)}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "0a6867f5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def fun(params_bounds):\n",
+ " model = XGBClassifier(eval_metric='mlogloss',\n",
+ " learning_rate = params_bounds['learning_rate'],\n",
+ " n_estimators = round(params_bounds['n_estimators']),\n",
+ " max_depth = round(params_bounds['max_depth']),\n",
+ " subsample = params_bounds['subsample'],\n",
+ " gamma = params_bounds['gamma'])\n",
+ " score = cross_val_score(model, X_train, y_train, cv=5, scoring='roc_auc', error_score='raise').mean()\n",
+ " \n",
+ " return {'loss': -score, 'status': STATUS_OK}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "532b8dcd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "100%|███████████████████████████████████████████████| 30/30 [00:20<00:00, 1.47trial/s, best loss: -0.9791677631578948]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'gamma': 2.877864670923052,\n",
+ " 'learning_rate': 0.01623819939058474,\n",
+ " 'max_depth': 6,\n",
+ " 'n_estimators': 378,\n",
+ " 'subsample': 0.8352886879422001}"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trials = Trials()\n",
+ "best = fmin(fn=fun,\n",
+ " space=params_bounds,\n",
+ " algo=tpe.suggest,\n",
+ " max_evals=30,\n",
+ " trials=trials)\n",
+ "best['max_depth'] = round(best['max_depth'])\n",
+ "best['n_estimators'] = round(best['n_estimators'])\n",
+ "best"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "df612fb1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " precision | \n",
+ " recall | \n",
+ " accuracy | \n",
+ " AUC | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " XGBoost | \n",
+ " 0.936778 | \n",
+ " 0.930734 | \n",
+ " 0.932063 | \n",
+ " 0.975984 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall accuracy AUC\n",
+ "XGBoost 0.936778 0.930734 0.932063 0.975984"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xgb = XGBClassifier(eval_metric='mlogloss', **best)\n",
+ "xgb.fit(X_train, y_train)\n",
+ "cv_results = cv_model(X_train, y_train, xgb)\n",
+ "show_cv_results(cv_results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1e8dcbad",
+ "metadata": {},
+ "source": [
+ "**As expected, no improvement at all**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "39a2330b",
+ "metadata": {},
+ "source": [
+ "# Testing Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "1c28788a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test = pd.read_csv('../../assets/data/test.csv')\n",
+ "test = pd.get_dummies(test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "41f3bab6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_test = test.drop('to_water', axis=1)\n",
+ "y_test = test['to_water']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "21f0fe78",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prediction = xgb.predict(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "2fb11624",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " precision | \n",
+ " recall | \n",
+ " accuracy | \n",
+ " AUC | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " XGBoost | \n",
+ " 1.0 | \n",
+ " 0.916667 | \n",
+ " 0.972222 | \n",
+ " 0.958333 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall accuracy AUC\n",
+ "XGBoost 1.0 0.916667 0.972222 0.958333"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data=[[precision_score(y_test, prediction),\n",
+ " recall_score(y_test, prediction),\n",
+ " accuracy_score(y_test, prediction),\n",
+ " roc_auc_score(y_test, prediction)]]\n",
+ "result_df = pd.DataFrame(index=['XGBoost'],\n",
+ " data=data, columns=['precision', 'recall', 'accuracy', 'AUC'])\n",
+ "result_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "69deba28",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "34 1\n",
+ "Name: to_water, dtype: int64"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_test[y_test != prediction]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a53b99d2",
+ "metadata": {},
+ "source": [
+ "**Only one wrong prediction:**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "6a4a8707",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "stan_nawodnienia 0.34\n",
+ "stan_nawiezienia 0.72\n",
+ "stopien_rozwoju 0.53\n",
+ "rodzaj_gleby_brunatne 0.00\n",
+ "rodzaj_gleby_czarnoziemy 0.00\n",
+ "rodzaj_gleby_piaszczyste 1.00\n",
+ "rodzaj_rosliny_brak 0.00\n",
+ "rodzaj_rosliny_kaktus 0.00\n",
+ "rodzaj_rosliny_pszenica 1.00\n",
+ "rodzaj_rosliny_ziemniak 0.00\n",
+ "rodzaj_nawozu_mineralny 1.00\n",
+ "rodzaj_nawozu_organiczny 0.00\n",
+ "rodzaj_nawozu_sztuczny 0.00\n",
+ "Name: 34, dtype: float64"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_test.iloc[34]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fec680e9",
+ "metadata": {},
+ "source": [
+ "# Save XGBoost model using pickle"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "e043273b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filename = '../../assets/model/xgboost_model.pkl'\n",
+ "pickle.dump(xgb, open(filename, 'wb'))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": false,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "303.825px"
+ },
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/src/utils/xgb_model.py b/src/utils/xgb_model.py
new file mode 100644
index 0000000..5763c87
--- /dev/null
+++ b/src/utils/xgb_model.py
@@ -0,0 +1,27 @@
+import pandas as pd
+import pickle
+import xgboost
+
+
+class Model:
+    """Trained XGBoost model together with the board data it predicts on.
+
+    On construction the input CSV is read, the pickled classifier is loaded
+    and the predicted ``to_water`` labels are written back into ``self.df``.
+
+    Attributes:
+        input_path: path of the CSV file that was read.
+        model_path: path of the pickled classifier that was loaded.
+        df: the input rows; after construction its ``to_water`` column holds
+            the model's predictions instead of the labels from the file.
+        model: the unpickled classifier (anything exposing ``predict``).
+        X_test: one-hot encoded feature frame passed to ``model.predict``.
+        y_test: the original ``to_water`` column taken from the CSV.
+    """
+
+    def __init__(self, input_path='assets/data/test.csv',
+                 model_path='assets/model/xgboost_model.pkl'):
+        """Load the data and the model, then run the prediction.
+
+        Args:
+            input_path: CSV file with the feature columns and ``to_water``.
+            model_path: pickle file containing the trained classifier.
+        """
+        self.input_path = input_path
+        self.model_path = model_path
+        self.df = pd.read_csv(self.input_path)
+        # SECURITY: unpickling executes arbitrary code stored in the file, so
+        # only load model files that come from a trusted source.
+        # The context manager closes the file handle, which the previous
+        # `pickle.load(open(...))` form left open.
+        with open(self.model_path, 'rb') as model_file:
+            self.model = pickle.load(model_file)
+        self.X_test = None
+        self.y_test = None
+        self.parse_input()
+        self.predict_data()
+
+    def parse_input(self):
+        """Split ``self.df`` into labels (``y_test``) and encoded features (``X_test``).
+
+        The ``to_water`` column is removed from ``self.df``; the remaining
+        columns are one-hot encoded with ``pd.get_dummies``.
+        """
+        self.y_test = self.df['to_water']
+        self.df = self.df.drop('to_water', axis=1)
+        # NOTE(review): get_dummies only creates columns for categories that
+        # occur in this file; presumably the model expects the same columns as
+        # during training - confirm the input always contains every category.
+        self.X_test = pd.get_dummies(self.df)
+
+    def predict_data(self):
+        """Predict ``to_water`` for every row and store it in ``self.df``."""
+        prediction = self.model.predict(self.X_test)
+        # The board is built from the predicted labels, not from the labels
+        # that were in the CSV (those remain available in self.y_test).
+        self.df['to_water'] = prediction
+
diff --git a/src/world.py b/src/world.py
index 30f622e..a7831ac 100644
--- a/src/world.py
+++ b/src/world.py
@@ -5,40 +5,85 @@ from src.tile import Tile
class World:
""" Class to represent complete game board, storing Tile classes inside Sprite Group """
- def __init__(self, settings):
- self.world_data = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
- self.dirt_image = pygame.image.load('assets/images/dirt.jpeg')
- self.rock_image = pygame.image.load('assets/images/cobblestone.jpg')
+ def __init__(self, settings, model):
+ """Store settings and model, define the board grid, load the tile images and create the tiles.
+
+ settings -- object whose tile_size is read when the tiles are created
+ model -- object whose df attribute supplies one row of crop data per field tile
+ """
self.settings = settings
+ self.model = model
+ # Board layout, one entry per tile: 1 becomes a 'dirt' (crop field) tile, 0 becomes a 'rock' tile.
+ self.world_data = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 1, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 1, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 1, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 1, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+ # Source images, scaled to tile_size in create_tiles: farmland_* is used when to_water == 0, dirt_* when to_water == 1.
+ self.rock = pygame.image.load('assets/images/cobblestone.jpg')
+ self.dirt_empty = pygame.image.load('assets/images/dirt_empty.jpeg')
+ self.dirt_cactus = pygame.image.load('assets/images/dirt_cactus.jpg')
+ self.dirt_wheat = pygame.image.load('assets/images/dirt_wheat.jpg')
+ self.dirt_potato = pygame.image.load('assets/images/dirt_potato.jpg')
+ self.farmland_empty = pygame.image.load('assets/images/farmland_empty.png')
+ self.farmland_cactus = pygame.image.load('assets/images/farmland_cactus.jpg')
+ self.farmland_wheat = pygame.image.load('assets/images/farmland_wheat.jpg')
+ self.farmland_potato = pygame.image.load('assets/images/farmland_potato.jpg')
self.tiles = pygame.sprite.Group() # tiles are kept in a Sprite Group; this will be useful later for collisions etc.
self.create_tiles()
def create_tiles(self):
+        """Create one Tile per entry of ``self.world_data`` and add it to ``self.tiles``.
+
+        A 1 in the grid becomes a 'dirt' (crop field) tile that consumes the
+        next row of ``self.model.df``; a 0 becomes a 'rock' tile whose crop
+        attributes are all ``None``. Each tile is drawn at
+        ``(col * tile_size, row * tile_size)`` and gets the board position
+        ``(col, len(world_data) - row - 1)``, i.e. the y axis is flipped.
+
+        Raises:
+            ValueError: a field row has an unknown ``rodzaj_rosliny`` or a
+                ``to_water`` value other than 0 or 1 (previously such a row
+                silently reused the image of the preceding tile).
+            IndexError: ``self.model.df`` has fewer rows than the grid has
+                field tiles.
+        """
row_count = 0
+        df_idx = 0  # next row of model.df to consume; advanced once per field tile
+        scaled_size = (self.settings.tile_size, self.settings.tile_size)
+        # plant name -> (image used when to_water == 0, image used when to_water == 1)
+        plant_images = {
+            'brak': (self.farmland_empty, self.dirt_empty),
+            'kaktus': (self.farmland_cactus, self.dirt_cactus),
+            'pszenica': (self.farmland_wheat, self.dirt_wheat),
+            'ziemniak': (self.farmland_potato, self.dirt_potato),
+        }
for row in self.world_data:
col_count = 0
for tile in row:
if tile == 1:
- img = pygame.transform.scale(self.dirt_image, (self.settings.tile_size, self.settings.tile_size))
- type = 'dirt'
- cost = 1
+                    tile_type = 'dirt'  # 'dirt' only marks the tile as a crop field; the plant details follow below
+                    cost = 0
+                    record = self.model.df.iloc[df_idx]  # fetch the row once instead of once per column
+                    stan_nawodnienia = record['stan_nawodnienia']
+                    rodzaj_gleby = record['rodzaj_gleby']
+                    stan_nawiezienia = record['stan_nawiezienia']
+                    stopien_rozwoju = record['stopien_rozwoju']
+                    rodzaj_rosliny = record['rodzaj_rosliny']
+                    rodzaj_nawozu = record['rodzaj_nawozu']
+                    to_water = record['to_water']
+
+                    # Fail loudly on unexpected data rather than drawing a stale image.
+                    if rodzaj_rosliny not in plant_images or to_water not in (0, 1):
+                        raise ValueError(
+                            f'Unsupported field data in row {df_idx}: '
+                            f'rodzaj_rosliny={rodzaj_rosliny!r}, to_water={to_water!r}')
+                    img = pygame.transform.scale(plant_images[rodzaj_rosliny][int(to_water)], scaled_size)
+                    df_idx += 1
+
elif tile == 0:
- img = pygame.transform.scale(self.rock_image, (self.settings.tile_size, self.settings.tile_size))
- type = 'rock'
- cost = 1000
+                    img = pygame.transform.scale(self.rock, scaled_size)
+                    tile_type = 'rock'  # as above; nothing is grown on rock tiles
+                    cost = 100
+
+                    stan_nawodnienia = None
+                    rodzaj_gleby = None
+                    stan_nawiezienia = None
+                    stopien_rozwoju = None
+                    rodzaj_rosliny = None
+                    rodzaj_nawozu = None
+                    to_water = None
+
img_rect = img.get_rect()
img_rect.x = col_count * self.settings.tile_size
img_rect.y = row_count * self.settings.tile_size
- tile = Tile(type, col_count, len(self.world_data) - row_count - 1, img, img_rect, cost)
+                tile = Tile(tile_type, col_count, len(self.world_data) - row_count - 1, img, img_rect, cost,
+                            stan_nawodnienia, rodzaj_gleby, stan_nawiezienia, stopien_rozwoju, rodzaj_rosliny,
+                            rodzaj_nawozu, to_water)
self.tiles.add(tile)
col_count += 1
row_count += 1
@@ -53,9 +98,12 @@ class World:
pygame.draw.line(screen, (255, 255, 255), (line * self.settings.tile_size, 0),
(line * self.settings.tile_size, self.settings.screen_height))
+    def get_tile(self, x, y):
+        """Return the first tile in ``self.tiles`` whose ``position`` equals ``(x, y)``, or ``None`` if no tile matches."""
+        wanted = (x, y)
+        return next((candidate for candidate in self.tiles if candidate.position == wanted), None)
+
def get_tile_cost(self, x, y):
+ """Return the cost of the first tile whose position equals (x, y); falls through and returns None when no tile matches."""
for tile in self.tiles:
if tile.position == (x, y):
- return tile.cost
-
-
+ return tile.cost
\ No newline at end of file