# Forked from kubapok/auta-public (original listing: 25 lines, 888 B, Python).
|
import numpy
|
||
|
import pandas
|
||
|
import sys
|
||
|
from sklearn.linear_model import LinearRegression
|
||
|
|
||
|
def trainModel(filePath):
    """Fit a linear price model on a headerless tab-separated training file.

    The file's columns are, in order: price, mileage, year, brand,
    engineType, engineCapacity. Only the numeric ``year``, ``mileage`` and
    ``engineCapacity`` columns are used as features (in that order), with
    ``price`` as the single-column target.

    :param filePath: path to the training TSV file.
    :return: the fitted ``LinearRegression`` instance.
    """
    columnNames = ['price', 'mileage', 'year', 'brand', 'engineType', 'engineCapacity']
    frame = pandas.read_csv(filePath, sep='\t', names=columnNames)

    regressor = LinearRegression()
    # The target is selected with double brackets, i.e. as a one-column
    # DataFrame rather than a Series; sklearn then treats it as a 2-D target.
    regressor.fit(frame[['year', 'mileage', 'engineCapacity']], frame[['price']])
    return regressor
|
||
|
|
||
|
def predictModel(model, filePathIn, filePathOut):
    """Predict prices for a headerless TSV file and write one integer per line.

    The input columns are, in order: mileage, year, brand, engineType,
    engineCapacity (there is no price column). The same three numeric
    features used for training -- ``year``, ``mileage``, ``engineCapacity``,
    in that order -- are passed to ``model.predict``.

    :param model: a fitted estimator exposing ``predict`` (e.g. the value
        returned by ``trainModel``).
    :param filePathIn: path to the input TSV file.
    :param filePathOut: path of the output file, written with
        ``numpy.savetxt``.
    """
    inFrame = pandas.read_csv(
        filePathIn,
        sep='\t',
        names=['mileage', 'year', 'brand', 'engineType', 'engineCapacity'],
    )
    data = inFrame[['year', 'mileage', 'engineCapacity']]

    # predict may return shape (n,) or (n, 1) depending on how the target was
    # shaped at fit time; flatten so each prediction becomes one output line.
    prediction = numpy.ravel(model.predict(data))

    # Round to the nearest integer before formatting: '%d' applied to a float
    # truncates toward zero (1.9 -> "1"), which skews every non-integer
    # prediction. The default newline='\n' already puts one value per line;
    # the former delimiter='\n' only separated columns and had no effect here.
    numpy.savetxt(filePathOut, numpy.rint(prediction), fmt='%d')
|
||
|
|
||
|
# Train once on the training split, then write predictions for every
# evaluation split.
# NOTE(review): this runs at import time; consider wrapping it in an
# ``if __name__ == "__main__":`` guard if the module is ever imported.
model = trainModel("train/train.tsv")
for pathIn, pathOut in (
    ("dev-0/in.tsv", "dev-0/out.tsv"),
    ("test-A/in.tsv", "test-A/out.tsv"),
):
    predictModel(model, pathIn, pathOut)
|
||
|
|
||
|
|