trenowanie modelu, wytrenowany model

This commit is contained in:
Anna Nowak 2021-04-25 17:39:38 +02:00
parent 0a38aa7d86
commit 1d7aa10b57
8 changed files with 3635 additions and 10 deletions

5
.gitignore vendored
View File

@ -57,6 +57,9 @@ docs/source/changelog.md
#fifa dataset #fifa dataset
fifa19* fifa19*
*.csv data.csv
test.csv
train.csv
dev.csv
stat.txt stat.txt
.venv/ .venv/

View File

@ -8,3 +8,4 @@ RUN pip3 install -r requirements.txt
COPY ["Zadanie 1.py", "."] COPY ["Zadanie 1.py", "."]
COPY ["stats.py", "."] COPY ["stats.py", "."]
COPY ["train.py", "."]

View File

@ -3,6 +3,236 @@ import os
import pandas as pd import pandas as pd
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
positions=['RF', 'ST', 'LW', 'GK', 'RCM', 'LF', 'RS', 'RCB', 'LCM', 'CB',
'LDM', 'CAM', 'CDM', 'LS', 'LCB', 'RM', 'LM', 'LB', 'RDM', 'RW',
'CM', 'RB', 'RAM', 'CF', 'LAM', 'RWB', 'LWB']
nationalities=['Argentina', 'Portugal', 'Brazil', 'Spain', 'Belgium', 'Croatia',
'Uruguay', 'Slovenia', 'Poland', 'Germany', 'France', 'England',
'Italy', 'Egypt', 'Denmark', 'Gabon', 'Wales', 'Senegal',
'Costa Rica', 'Slovakia', 'Netherlands', 'Bosnia Herzegovina',
'Morocco', 'Serbia', 'Algeria', 'Austria', 'Greece', 'Chile',
'Sweden', 'Colombia', 'Korea Republic', 'Finland', 'Guinea',
'Montenegro', 'Armenia', 'Switzerland', 'Norway', 'Czech Republic',
'Scotland', 'Ghana', 'Central African Rep.', 'DR Congo',
'Ivory Coast', 'Russia', 'Ukraine', 'Iceland', 'Mexico', 'Jamaica',
'Albania', 'Venezuela', 'Japan', 'Turkey', 'Ecuador', 'Paraguay',
'Mali', 'Nigeria', 'Cameroon', 'Dominican Republic', 'Israel',
'Kenya', 'Hungary', 'Republic of Ireland', 'Romania',
'United States', 'Cape Verde', 'Australia', 'Peru', 'Togo',
'Syria', 'Zimbabwe', 'Angola', 'Burkina Faso', 'Iran', 'Estonia',
'Tunisia', 'Equatorial Guinea', 'New Zealand', 'FYR Macedonia',
'United Arab Emirates', 'China PR', 'Guinea Bissau', 'Bulgaria',
'Kosovo', 'South Africa', 'Madagascar', 'Georgia', 'Tanzania',
'Gambia', 'Cuba', 'Belarus', 'Uzbekistan', 'Benin', 'Congo',
'Mozambique', 'Honduras', 'Canada', 'Northern Ireland', 'Cyprus',
'Saudi Arabia', 'Curacao', 'Moldova', 'Bolivia',
'Trinidad & Tobago', 'Sierra Leone', 'Zambia', 'Chad',
'Philippines', 'Haiti', 'Comoros', 'Libya', 'Panama',
'São Tomé & Príncipe', 'Eritrea', 'Oman', 'Iraq', 'Burundi',
'Fiji', 'New Caledonia', 'Lithuania', 'Luxembourg', 'Korea DPR',
'Liechtenstein', 'St Kitts Nevis', 'Latvia', 'Suriname', 'Uganda',
'El Salvador', 'Kuwait', 'Antigua & Barbuda', 'Thailand',
'Mauritius', 'Guatemala', 'Liberia', 'Kazakhstan', 'Niger',
'Mauritania', 'Montserrat', 'Namibia', 'Azerbaijan', 'Guam',
'Faroe Islands', 'Nicaragua', 'Barbados', 'Lebanon', 'Palestine',
'Guyana', 'Sudan', 'Ethiopia', 'Puerto Rico', 'Grenada', 'Jordan',
'Rwanda', 'Bermuda', 'Qatar', 'Afghanistan', 'Hong Kong',
'Andorra', 'Belize', 'South Sudan', 'Indonesia', 'Botswana']
clubs = ['FC Barcelona', 'Juventus', 'Paris Saint-Germain',
'Manchester United', 'Manchester City', 'Chelsea', 'Real Madrid',
'Atlético Madrid', 'FC Bayern München', 'Tottenham Hotspur',
'Liverpool', 'Napoli', 'Arsenal', 'Inter', 'Lazio',
'Borussia Dortmund', 'Vissel Kobe', 'Olympique Lyonnais', 'Roma',
'Valencia CF', 'FC Porto', 'FC Schalke 04', 'Beşiktaş JK',
'LA Galaxy', 'Sporting CP', 'Real Betis', 'Olympique de Marseille',
'RC Celta', 'Bayer 04 Leverkusen', 'Real Sociedad',
'Villarreal CF', 'Sevilla FC', 'SL Benfica', 'AS Saint-Étienne',
'AS Monaco', 'Leicester City', 'Atalanta', 'Grêmio',
'Atlético Mineiro', 'RB Leipzig', 'Ajax', 'Dalian YiFang FC',
'Everton', 'Milan', 'West Ham United', '1. FC Köln',
'TSG 1899 Hoffenheim', 'Shanghai SIPG FC', 'OGC Nice', 'Al Nassr',
'Wolverhampton Wanderers', 'Borussia Mönchengladbach',
'Hertha BSC', 'SV Werder Bremen', 'Cruzeiro',
'Athletic Club de Bilbao', 'Torino', 'Medipol Başakşehir FK',
'Beijing Sinobo Guoan FC', 'Crystal Palace', 'PFC CSKA Moscow',
'VfL Wolfsburg', 'Shakhtar Donetsk', 'Toronto FC',
'Lokomotiv Moscow', 'Sassuolo', 'New York City FC', 'Fluminense',
'PSV', 'Levante UD', 'Fulham', 'Watford', 'Atlanta United',
'Montpellier HSC', 'Galatasaray SK', 'Fenerbahçe SK', 'SD Eibar',
'Los Angeles FC', 'Sampdoria', 'Al Hilal', 'VfB Stuttgart',
'SC Braga', 'River Plate', 'Deportivo Alavés',
'Eintracht Frankfurt', 'Girona FC', 'Guangzhou R&F; FC', 'Burnley',
'Stoke City', 'Southampton', 'Tianjin Quanjian FC', 'Getafe CF',
'Beijing Renhe FC', 'Montreal Impact', 'Chievo Verona', 'Genoa',
'Portland Timbers', 'Tigres U.A.N.L.', 'RCD Espanyol',
'Hebei China Fortune FC', 'Cagliari', 'Chicago Fire', 'DC United',
'Sagan Tosu', 'Dynamo Kyiv', 'Santos', 'Internacional',
'América FC (Minas Gerais)', 'Independiente', 'Boca Juniors',
'Cruz Azul', '1. FSV Mainz 05', 'Bournemouth', 'Spartak Moscow',
'Racing Club', 'FC Augsburg', 'Fiorentina', 'FC Nantes',
'Feyenoord', 'Club Brugge KV', 'Brighton & Hove Albion',
'Guangzhou Evergrande Taobao FC', 'Al Ahli', 'Jiangsu Suning FC',
'SC Freiburg', 'PAOK', 'Stade Rennais FC', 'Trabzonspor', 'SPAL',
'Portimonense SC', 'Olympiacos CFP', 'Club Atlético Huracán',
'Kasimpaşa SK', 'Newcastle United', 'Querétaro', 'KRC Genk',
'Hannover 96', 'Stade Malherbe Caen', 'Godoy Cruz',
'Toulouse Football Club', 'RSC Anderlecht', 'Huddersfield Town',
'CD Tondela', 'Seattle Sounders FC', 'Hamburger SV',
'FC Red Bull Salzburg', 'Rio Ave FC', 'FC Girondins de Bordeaux',
'Melbourne Victory', 'Parma', 'FC Basel 1893', 'Al Wehda',
'BSC Young Boys', 'KAA Gent', 'Al Ittihad', 'Standard de Liège',
'Shanghai Greenland Shenhua FC', 'Colo-Colo', 'Junior FC',
'West Bromwich Albion', 'RC Strasbourg Alsace', 'Göztepe SK',
'Deportivo Cali', 'Deportivo Toluca', 'Bologna', 'Nagoya Grampus',
'Amiens SC', 'Changchun Yatai FC', 'Club Atlético Lanús',
'Botafogo', 'Club América', 'Udinese', 'Real Valladolid CF',
'CD Leganés', 'Club Atlético Banfield', 'Celtic',
'Vitória Guimarães', 'FC København', 'UD Las Palmas',
'Deportivo de La Coruña', 'Universidad Católica',
'San Lorenzo de Almagro', 'Rayo Vallecano', 'Monterrey',
'Columbus Crew SC', 'MKE Ankaragücü', 'Guizhou Hengfeng FC',
'Swansea City', 'Tianjin TEDA FC',
'Chongqing Dangdai Lifan FC SWM Team', 'AEK Athens', 'Al Taawoun',
'Melbourne City FC', 'En Avant de Guingamp',
'Akhisar Belediyespor', 'Foggia', 'LOSC Lille',
'Clube Sport Marítimo', 'Real Sporting de Gijón', 'BB Erzurumspor',
'Shandong Luneng TaiShan FC', 'Bahia', 'Once Caldas',
'FC Groningen', 'Angers SCO', 'Paraná', 'Antalyaspor',
'Minnesota United FC', 'Club León', 'Empoli', 'Leeds United',
'Viktoria Plzeň', 'Alanyaspor', 'Frosinone', 'Atlético Paranaense',
'Derby County', 'Kawasaki Frontale', 'Aston Villa', 'Guadalajara',
'Dijon FCO', 'Santos Laguna', 'Vitória', 'Çaykur Rizespor',
'U.N.A.M.', 'Nottingham Forest', 'Royal Antwerp FC',
'Club Tijuana', 'Sport Club do Recife', 'Real Salt Lake',
'AZ Alkmaar', 'SK Slavia Praha', 'Willem II', 'Middlesbrough',
'Dinamo Zagreb', 'Club Atlas', 'Granada CF', 'Sydney FC',
'Sporting Kansas City', 'SV Zulte-Waregem', 'Málaga CF',
'Real Oviedo', 'Pachuca', 'Boavista FC', 'Atiker Konyaspor',
'Kaizer Chiefs', 'GD Chaves', 'Palermo', 'Atlético Nacional',
'Puebla FC', 'Perth Glory', 'Panathinaikos FC', 'FC Sion',
'New York Red Bulls', 'Al Shabab', 'Club Atlético Colón',
'Monarcas Morelia', 'Albacete BP', 'Rangers FC', 'Sparta Praha',
'Philadelphia Union', 'Legia Warszawa', 'Urawa Red Diamonds',
'Rosario Central', 'Stade de Reims', 'ADO Den Haag', 'Chapecoense',
'FC Midtjylland', 'San Jose Earthquakes', 'Cardiff City',
'Belgrano de Córdoba', '1. FC Nürnberg', 'Brescia',
'Kashima Antlers', 'Vitória de Setúbal',
'CD Everton de Viña del Mar', 'Fortuna Düsseldorf', 'SD Huesca',
'Preston North End', 'Club Atlético Talleres', 'Benevento',
'Gimnasia y Esgrima La Plata', 'Houston Dynamo', 'Club Necaxa',
'Norwich City', 'Holstein Kiel', 'Ettifaq FC', 'Kayserispor',
'1. FC Heidenheim 1846', 'Vitesse', 'Brentford',
'Yeni Malatyaspor', 'Ceará Sporting Club', 'FC Ingolstadt 04',
'Estudiantes de La Plata', 'AIK', 'Queens Park Rangers',
'Suwon Samsung Bluewings', 'Heart of Midlothian', 'Reading',
'FC Dallas', 'Heracles Almelo', 'Bursaspor', 'Venezia FC',
'CD Lugo', 'Henan Jianye FC', 'Orlando City SC', 'CA Osasuna',
'Livorno', 'Universidad de Chile', 'Brøndby IF', 'Aberdeen',
'Defensa y Justicia', 'Atlético Tucumán', 'Blackburn Rovers',
'SV Darmstadt 98', 'Moreirense FC', 'Sanfrecce Hiroshima',
'CD Numancia', 'KV Oostende', 'Vancouver Whitecaps FC',
'Odense Boldklub', 'SC Heerenveen', 'Racing Club de Lens',
'Independiente Santa Fe', 'Sporting de Charleroi',
'Millonarios FC', 'Sheffield Wednesday', 'Perugia', 'Daegu FC',
'Vélez Sarsfield', 'Grasshopper Club Zürich', 'Sivasspor',
'Rosenborg BK', 'SK Sturm Graz', 'FC Metz',
'CD Universidad de Concepción', 'Brisbane Roar', 'CD Feirense',
'Hull City', 'Neuchâtel Xamax', 'Real Zaragoza', 'CD Aves',
'Millwall', 'Unión de Santa Fe', 'KAS Eupen', 'Cádiz CF',
'CD Tenerife', '1. FC Union Berlin', 'Al Fayha', 'AJ Auxerre',
'Nîmes Olympique', 'Patriotas Boyacá FC', 'Molde FK',
'Bristol City', 'CD Nacional', 'Sporting Lokeren', 'FC St. Pauli',
'Deportes Iquique', 'Al Qadisiyah', 'Sheffield United',
'Lobos BUAP', 'FC Utrecht', 'Club Atlético Tigre',
'FK Austria Wien', 'Patronato', 'Malmö FF', 'Kashiwa Reysol',
'US Cremonese', 'VfL Bochum 1848', 'SK Rapid Wien',
'Hellas Verona', 'Rionegro Águilas', 'Lecce', 'Santa Clara',
'BK Häcken', 'New England Revolution', 'Orlando Pirates',
'Atlético Huila', 'Western Sydney Wanderers', 'Kalmar FF',
'Independiente Medellín', 'Lech Poznań', 'Djurgårdens IF',
'CF Reus Deportiu', 'SK Brann', 'Ulsan Hyundai FC',
'Sint-Truidense VV', 'Al Fateh', 'Royal Excel Mouscron',
'AC Ajaccio', 'PEC Zwolle', 'Sunderland', 'Club Atlético Aldosivi',
'US Salernitana 1919', 'FC Lorient', 'Argentinos Juniors',
'AD Alcorcón', 'Crotone', 'Excelsior', 'Gimnàstic de Tarragona',
'FC Tokyo', 'KV Kortrijk', 'IFK Norrköping', 'Adelaide United',
'FC St. Gallen', 'Tiburones Rojos de Veracruz', 'CD Palestino',
'Jeju United FC', 'Deportes Tolima', 'Jeonbuk Hyundai Motors',
'Birmingham City', 'América de Cali', 'La Equidad', 'Spezia',
'Aalborg BK', 'Le Havre AC', 'KSV Cercle Brugge', 'Górnik Zabrze',
'Wigan Athletic', 'Jagiellonia Białystok', 'Cittadella',
'Hibernian', 'FC Lugano', 'San Martín de San Juan',
'Strømsgodset IF', "Newell's Old Boys", 'Al Faisaly',
'Colorado Rapids', 'IF Elfsborg', 'SV Sandhausen', 'Al Batin',
'VVV-Venlo', 'Stade Brestois 29', 'UD Almería', 'Gyeongnam FC',
'Yokohama F. Marinos', 'Kilmarnock', 'Pescara', 'Newcastle Jets',
'Central Coast Mariners', 'Córdoba CF', 'RCD Mallorca',
'Hammarby IF', 'Cerezo Osaka', 'KFC Uerdingen 05',
'Shimizu S-Pulse', 'MSV Duisburg', 'Os Belenenses',
'DSC Arminia Bielefeld', 'Ipswich Town', 'FC Seoul',
'Lechia Gdańsk', 'Gamba Osaka', 'CF Rayo Majadahonda', 'Carpi',
'LASK Linz', 'Bolton Wanderers', 'Al Raed', 'Extremadura UD',
'SC Paderborn 07', 'Wellington Phoenix', 'Unión Española',
'Alianza Petrolera', 'Cracovia', 'Gangwon FC', 'Júbilo Iwata',
'Elche CF', 'AS Béziers', 'La Berrichonne de Châteauroux',
'Clermont Foot 63', 'ESTAC Troyes', 'Pohang Steelers', 'Örebro SK',
'Arka Gdynia', 'SG Dynamo Dresden', 'SpVgg Greuther Fürth',
'Wisła Kraków', 'Stabæk Fotball', 'Eintracht Braunschweig',
'Valenciennes FC', 'FC Thun', 'San Luis de Quillota',
'Fortuna Sittard', ' SSV Jahn Regensburg', 'FC Nordsjælland',
'FC Erzgebirge Aue', 'Jeonnam Dragons', 'Wolfsberger AC',
'Chamois Niortais Football Club', 'Club Deportes Temuco',
'AS Nancy Lorraine', 'Red Star FC', 'Al Hazem', 'Pogoń Szczecin',
'Charlton Athletic', 'Grenoble Foot 38', 'FC Hansa Rostock',
'San Martin de Tucumán', 'Incheon United FC', 'Śląsk Wrocław',
'GFC Ajaccio', '1. FC Kaiserslautern', 'Waasland-Beveren',
'Deportivo Pasto', 'Lincoln City', 'Motherwell',
'Rotherham United', 'Burton Albion', 'Wisła Płock',
'CD Huachipato', 'FC Wacker Innsbruck', 'Atlético Bucaramanga',
'Peterborough United', 'Ascoli', 'FC Zürich', 'Fleetwood Town',
'Padova', 'SV Wehen Wiesbaden', 'FC Sochaux-Montbéliard',
'Unión La Calera', 'Scunthorpe United', 'NAC Breda',
'1. FC Magdeburg', "CD O'Higgins", 'CD Antofagasta',
'Plymouth Argyle', 'Aarhus GF', 'Lillestrøm SK', 'Karlsruher SC',
'GIF Sundsvall', 'FC Emmen', 'Barnsley', 'Audax Italiano',
'V-Varen Nagasaki', 'Paris FC', 'SpVgg Unterhaching', 'Hobro IK',
'De Graafschap', 'Hokkaido Consadole Sapporo', 'Tromsø IL',
'FC Luzern', 'FK Haugesund', 'Zagłębie Lubin', 'VfR Aalen',
'Dundalk', 'Piast Gliwice', 'Ohod Club', 'Östersunds FK',
'Crawley Town', 'FC Admira Wacker Mödling', 'Vålerenga Fotball',
'Oxford United', 'Dundee FC', 'Portsmouth', 'Envigado FC',
'Miedź Legnica', 'Odds BK', 'SC Fortuna Köln', 'Cosenza',
'US Orléans Loiret Football', 'Sarpsborg 08 FF',
'Jaguares de Córdoba', 'Bradford City', 'St. Johnstone FC',
'Boyacá Chicó FC', 'SV Mattersburg', 'Luton Town',
'Kristiansund BK', 'Sangju Sangmu FC', 'Walsall', 'Korona Kielce',
'Shonan Bellmare', 'FC Würzburger Kickers', 'FSV Zwickau',
'St. Mirren', 'AC Horsens', 'HJK Helsinki', 'Accrington Stanley',
'Southend United', 'Bristol Rovers', 'Hamilton Academical FC',
'TSV 1860 München', 'Curicó Unido', 'SCR Altach',
'Ranheim Fotball', 'Stevenage', 'SG Sonnenhof Großaspach',
'Oldham Athletic', 'Milton Keynes Dons', 'FK Bodø/Glimt',
'SC Preußen Münster', 'Vejle Boldklub', 'Vegalta Sendai', 'Bury',
'Randers FC', 'VfL Osnabrück', 'SønderjyskE', 'IFK Göteborg',
'Mansfield Town', 'Coventry City', 'Esbjerg fB', 'Waterford FC',
'Shrewsbury', 'IK Start', 'Rochdale', 'Gillingham',
'FC Energie Cottbus', 'FC Carl Zeiss Jena', 'Hallescher FC',
'Wycombe Wanderers', 'AFC Wimbledon', 'Blackpool',
'Doncaster Rovers', 'Sandefjord Fotball', 'VfL Sportfreunde Lotte',
'Cheltenham Town', 'IK Sirius', 'Vendsyssel FF', 'Swindon Town',
'SV Meppen', 'Notts County', 'SKN St. Pölten', 'Exeter City',
'Northampton Town', 'Shamrock Rovers', 'Colchester United',
'Livingston FC', 'TSV Hartberg', 'Tranmere Rovers',
'Cambridge United', 'Grimsby Town', 'Port Vale',
'Itagüí Leones FC', 'Forest Green Rovers', 'Dalkurd FF',
'Zagłębie Sosnowiec', 'Carlisle United', 'Trelleborgs FF',
"St. Patrick's Athletic", 'Morecambe', 'Cork City',
'IF Brommapojkarna', 'Crewe Alexandra', 'Yeovil Town',
'Bohemian FC', 'Macclesfield Town', 'Newport County',
'Sligo Rovers', 'Derry City', 'Limerick FC', 'Bray Wanderers']
if os.getenv("KAGGLE_KEY") is None or os.getenv("KAGGLE_USERNAME") is None: if os.getenv("KAGGLE_KEY") is None or os.getenv("KAGGLE_USERNAME") is None:
print("Brak zmiennych środowiskowych KAGGLE_KEY lub KAAGLE_USERNAME") print("Brak zmiennych środowiskowych KAGGLE_KEY lub KAAGLE_USERNAME")
exit() exit()
@ -17,9 +247,10 @@ df=pd.read_csv('data.csv')
df = df[df["Release Clause"].notna()] df = df[df["Release Clause"].notna()]
df = df[df["Release Clause"].notnull()] df = df[df["Release Clause"].notnull()]
if df["Overall"].mean() > 1: df["Age"]= df["Age"]/50
df["Overall"]= df["Overall"]/100 df["Nationality"] = df["Nationality"].apply(nationalities.index)/(len(nationalities)-1)
df["Position"] = df["Position"].apply(positions.index)/(len(positions)-1)
df["Club"] = df["Club"].apply(clubs.index)/(len(clubs)-1)
df["Release Clause"] = df["Release Clause"].str.replace("", "") df["Release Clause"] = df["Release Clause"].str.replace("", "")
df["Release Clause"] = (df["Release Clause"].replace(r'[KM]+$', '', regex=True).astype(float) * df["Release Clause"] = (df["Release Clause"].replace(r'[KM]+$', '', regex=True).astype(float) *
@ -28,7 +259,7 @@ df["Release Clause"] = (df["Release Clause"].replace(r'[KM]+$', '', regex=True).
df.to_csv('data.csv') df.to_csv('data.csv')
train, dev = train_test_split(df, train_size=0.6, test_size=0.4, shuffle=True) train, dev = train_test_split(df, train_size=0.6, test_size=0.4, shuffle=True)
dev, test = train_test_split(dev, train_size=0.5, test_size=0.5, shuffle=False) dev, test = train_test_split(dev, train_size=0.5, test_size=0.5, shuffle=True)
test.to_csv('test.csv') test.to_csv('test.csv')
dev.to_csv('dev.csv') dev.to_csv('dev.csv')

2
evaluation_result.txt Normal file
View File

@ -0,0 +1,2 @@
Train: 25.844615936279297
Test: 25.38555335998535

BIN
model.h5 Normal file

Binary file not shown.

View File

@ -1,4 +1,6 @@
kaggle kaggle==1.5.12
pandas pandas==1.2.4
numpy numpy==1.19.2
sklearn sklearn
tensorflow==2.4.1
jinja2==2.11.3

3330
results.csv Normal file

File diff suppressed because it is too large Load Diff

56
train.py Normal file
View File

@ -0,0 +1,56 @@
import pandas as pd
from os import path
from tensorflow import keras
from tensorflow.keras import layers
model_name = "model.h5"
train_data=pd.read_csv('train.csv')
input_columns=["Age","Nationality","Position","Club"]
X=train_data[input_columns].to_numpy()
Y=train_data[["Overall"]].to_numpy()
model = None
if path.exists(model_name):
model = keras.models.load_model(model_name)
else:
model = keras.Sequential(name="fifa_overall")
model.add(keras.Input(shape=(len(input_columns),), name="player_info"))
model.add(layers.Dense(4, activation="relu", name="layer1"))
model.add(layers.Dense(8, activation="relu", name="layer2"))
model.add(layers.Dense(8, activation="relu", name="layer3"))
model.add(layers.Dense(5, activation="relu", name="layer4"))
model.add(layers.Dense(1, activation="relu", name="output"))
model.compile(
optimizer=keras.optimizers.RMSprop(),
loss=keras.losses.MeanSquaredError(),
)
history = model.fit(
X,
Y,
batch_size=16,
epochs=15,
)
model.save(model_name)
test_data=pd.read_csv('test.csv')
X_test=test_data[input_columns].to_numpy()
Y_test=test_data[["Overall"]].to_numpy()
results_train = model.evaluate(X, Y, batch_size=128)
results_test = model.evaluate(X_test, Y_test, batch_size=128)
y_pred = model(X_test)
lines = ["Name;Overall;Predicted overall\n"]
for i in range(len(X_test)):
name = test_data["Name"][i]
lines.append(f"{name};{int(Y_test[i])};{int(y_pred[i])}\n")
with open('results.csv', 'w+', encoding="UTF-8") as f:
f.writelines(lines)
with open('evaluation_result.txt', 'w+', encoding="UTF-8") as f:
f.write(f"Train: {str(results_train)}\nTest: {str(results_test)}")