diff --git a/main.py b/main.py index 64c35dd..ea3436e 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,59 @@ import pandas as pd +import matplotlib.pyplot as plt -labels = ["mileage", "year", "brand", "engine_type", "engine_capacity"] +# Read column names +col_names = [] +with open('names') as f: + col_names = f.read().strip().split('\t') # Read data -dev = pd.read_table('dev-0/in.tsv', error_bad_lines=False, header=None) -test = pd.read_table('test-A/in.tsv', error_bad_lines=False, header=None) -test_expected = pd.read_table( - 'dev-0/expected.tsv', error_bad_lines=False, header=None) -train = pd.read_table('train/train.tsv', error_bad_lines=False, header=None) +dev = pd.read_table('dev-0/in.tsv', error_bad_lines=False, + header=None, names=col_names[1:]) +test = pd.read_table('test-A/in.tsv', error_bad_lines=False, + header=None, names=col_names[1:]) +train = pd.read_table('train/train.tsv', error_bad_lines=False, + header=None, names=col_names) +test_expected = pd.read_table('dev-0/expected.tsv', error_bad_lines=False, + header=None) -print(dev) +# Create dummies for brand +train = pd.get_dummies(train, columns=['engineType']) + +# Sprawdzanie ile jest odstających wartości dla price +fig, ax = plt.subplots(1, 2) +fig.set_figheight(15) +fig.set_figwidth(20) +ax[0].boxplot(train['price']) +ax[0].set_title('price') +ax[1].boxplot(train['mileage']) +ax[1].set_title('mileage') +plt.show() + +# Usunięcie odstających wartości +priceMin = 0 +for price in train['price']: + if price < 1000: + priceMin += 1 +print("Price min cut: " + str(priceMin)) + +priceMax = 0 +for price in train['price']: + if price > 1000000: + priceMin += 1 +print("Price max cut: " + str(priceMax)) + +mileageMin = 0 +for m in train['mileage']: + if m < 100: + mileageMin += 1 +print("Mileage min cut: " + str(mileageMin)) + +train = train.loc[(train['price'] > 1000)] +train = train.loc[(train['mileage'] > 100)] + +# Split train set to X and Y +X = train.loc[:, train.columns != 'price'] +Y = train['price'] + +# print(train) +# print(col_names)