Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
c98058d49a | ||
|
1220888362 |
@ -1 +1 @@
|
||||
--precision 1
|
||||
--metric RMSE --precision 1
|
||||
|
1000
dev-0/out.tsv
Normal file
1000
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
56
linreg.py
Normal file
56
linreg.py
Normal file
@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sys
|
||||
from sklearn .linear_model import LinearRegression
|
||||
|
||||
TRAIN_FILE_PATH = 'train/train.tsv'
|
||||
|
||||
DEV0_IN = 'dev-0/in.tsv'
|
||||
DEV0_OUT = 'dev-0/out.tsv'
|
||||
|
||||
TEST_A_IN = 'test-A/in.tsv'
|
||||
TEST_A_OUT = 'test-A/out.tsv'
|
||||
|
||||
|
||||
def read_data_file(filepath, x_index, y_index):
|
||||
df = pd.read_csv(filepath, sep='\t', header=None, index_col=None)
|
||||
x = df[x_index].tolist() if x_index is not None else None
|
||||
y = df[y_index].tolist() if y_index is not None else None
|
||||
|
||||
return {'x': x, 'y': y}
|
||||
|
||||
|
||||
def to_numpy_2d(lst):
|
||||
return np.array(lst).reshape(-1, 1)
|
||||
|
||||
|
||||
def get_trained_linreg_model(train_data):
|
||||
x = to_numpy_2d(train_data.get('x'))
|
||||
y = to_numpy_2d(train_data.get('y'))
|
||||
|
||||
model = LinearRegression()
|
||||
model.fit(x, y)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def make_predictions(model, in_file, out_file):
|
||||
input = read_data_file(in_file, 1, None)
|
||||
input_x = to_numpy_2d(input.get('x'))
|
||||
pred_y = model.predict(input_x)
|
||||
|
||||
with open(out_file, 'w') as f:
|
||||
for pred in pred_y:
|
||||
f.write(str(pred[0]) + '\n')
|
||||
|
||||
|
||||
def main():
|
||||
train_data = read_data_file(TRAIN_FILE_PATH, 2, 0)
|
||||
model = get_trained_linreg_model(train_data)
|
||||
|
||||
make_predictions(model, DEV0_IN, DEV0_OUT)
|
||||
make_predictions(model, TEST_A_IN, TEST_A_OUT)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
1000
test-A/out.tsv
Normal file
1000
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user