Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
c98058d49a | ||
|
1220888362 |
1000
dev-0/out.tsv
Normal file
1000
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
56
linreg.py
Normal file
56
linreg.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import sys
|
||||||
|
from sklearn .linear_model import LinearRegression
|
||||||
|
|
||||||
|
TRAIN_FILE_PATH = 'train/train.tsv'
|
||||||
|
|
||||||
|
DEV0_IN = 'dev-0/in.tsv'
|
||||||
|
DEV0_OUT = 'dev-0/out.tsv'
|
||||||
|
|
||||||
|
TEST_A_IN = 'test-A/in.tsv'
|
||||||
|
TEST_A_OUT = 'test-A/out.tsv'
|
||||||
|
|
||||||
|
|
||||||
|
def read_data_file(filepath, x_index, y_index):
|
||||||
|
df = pd.read_csv(filepath, sep='\t', header=None, index_col=None)
|
||||||
|
x = df[x_index].tolist() if x_index is not None else None
|
||||||
|
y = df[y_index].tolist() if y_index is not None else None
|
||||||
|
|
||||||
|
return {'x': x, 'y': y}
|
||||||
|
|
||||||
|
|
||||||
|
def to_numpy_2d(lst):
|
||||||
|
return np.array(lst).reshape(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trained_linreg_model(train_data):
|
||||||
|
x = to_numpy_2d(train_data.get('x'))
|
||||||
|
y = to_numpy_2d(train_data.get('y'))
|
||||||
|
|
||||||
|
model = LinearRegression()
|
||||||
|
model.fit(x, y)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def make_predictions(model, in_file, out_file):
|
||||||
|
input = read_data_file(in_file, 1, None)
|
||||||
|
input_x = to_numpy_2d(input.get('x'))
|
||||||
|
pred_y = model.predict(input_x)
|
||||||
|
|
||||||
|
with open(out_file, 'w') as f:
|
||||||
|
for pred in pred_y:
|
||||||
|
f.write(str(pred[0]) + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
train_data = read_data_file(TRAIN_FILE_PATH, 2, 0)
|
||||||
|
model = get_trained_linreg_model(train_data)
|
||||||
|
|
||||||
|
make_predictions(model, DEV0_IN, DEV0_OUT)
|
||||||
|
make_predictions(model, TEST_A_IN, TEST_A_OUT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
1000
test-A/out.tsv
Normal file
1000
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user