Compare commits

...

1 Commits

Author SHA1 Message Date
piotr6789
ac42ca9fd0 add script to linear regression 2021-05-18 18:11:26 +02:00
3 changed files with 2058 additions and 0 deletions

1000
dev-0/out.tsv Normal file

File diff suppressed because it is too large Load Diff

58
linear-regression.py Normal file
View File

@ -0,0 +1,58 @@
import pandas as pd
from pathlib import Path
from sklearn.linear_model import LinearRegression
def get_names():
DATA_DIR = Path('./')
with open(DATA_DIR / 'names') as f_names:
return f_names.read().rstrip('\n').split('\t')
def get_data(names):
df = pd.read_csv("train/train.tsv", header=None, sep="\t", error_bad_lines=False, names=names)
dev_data = pd.read_csv("dev-0/in.tsv", header=None, sep="\t", error_bad_lines=False, names=['mileage', 'year', 'brand', 'engineType','engineCapacity'])
test_data = pd.read_csv("test-A/in.tsv", header=None, sep="\t", error_bad_lines=False, names=['mileage', 'year','brand', 'engineType', 'engineCapacity'])
return df, dev_data, test_data
def get_train_data(df):
df = df.drop(['brand'], axis=1)
train = pd.get_dummies(df, columns=['engineType'])
train = train.loc[(train['price'] > 1000)]
return train.loc[(train['mileage'] > 100)]
def get_x(train):
return train.loc[:, train.columns != 'price']
def get_y(train):
return train['price']
def get_linear_regression(x,y):
return LinearRegression().fit(x, y)
def process_data(df):
data = df.drop(['brand'], axis=1)
return pd.get_dummies(data, columns=['engineType'])
def get_prediction(clf, data, type):
prediction = clf.predict(data)
if type == 'dev':
prediction.tofile("./dev-0/out.tsv", sep='\n')
elif type == 'test':
prediction.tofile("./test-A/out.tsv", sep='\n')
def main():
#prepare
df, dev_data, test_data = get_data(get_names())
train = get_train_data(df)
x = get_x(train)
y = get_y(train)
#linear regression
clf = get_linear_regression(x, y)
#predictions
dev = process_data(dev_data)
test = process_data(test_data)
get_prediction(clf, dev, 'dev')
get_prediction(clf, test, 'test')
main()

1000
test-A/out.tsv Normal file

File diff suppressed because it is too large Load Diff