478841
This commit is contained in:
parent
647c099815
commit
805283fe2d
20000
dev-0/out.tsv
Normal file
20000
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
11563
dev-1/out.tsv
Normal file
11563
dev-1/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
49
run.py
Normal file
49
run.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import lzma
|
||||
import pandas as pd
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
# Dataset files in processing order: the first entry is the training set that
# the model is fitted on; every later entry is an input file for which
# predictions are generated.
PATHS = ['train/train.tsv.xz', 'dev-0/in.tsv', 'dev-1/in.tsv', 'test-A/in.tsv']
|
||||
|
||||
|
||||
def read_data(path):
    """Load one dataset file, picking the parser from the file name suffix.

    Args:
        path: Location of the file, as a string or an ``os.PathLike`` object.

    Returns:
        For a path ending in ``xz``: a ``pandas.DataFrame`` parsed from the
        LZMA-compressed, tab-separated, header-less file, with columns 1-3
        dropped and columns 0 and 4 renamed to ``'start'`` and ``'text'``.
        For any other path: a list holding one whitespace-stripped string per
        line of the UTF-8 text file.

    Raises:
        KeyError: If a compressed file has fewer than the expected columns
            (raised by ``DataFrame.drop``).
    """
    print(f"I am reading the data from {path}...")
    # str() keeps the suffix test working for pathlib.Path arguments too.
    if str(path).endswith('xz'):
        with lzma.open(path, 'rt', encoding='utf-8') as f:
            # NOTE(review): read_csv runs with its default quoting rules; if
            # the text column can contain stray double quotes, confirm whether
            # quoting=csv.QUOTE_NONE is needed here.
            frame = pd.read_csv(f, delimiter='\t', header=None)
        data = frame.drop(columns=[1, 2, 3]).rename(
            columns={0: 'start', 4: 'text'})
    else:
        with open(path, 'r', encoding='utf-8') as f:
            # Iterate the handle directly instead of materialising readlines().
            data = [line.strip() for line in f]
    print("Data loaded")
    return data
|
||||
|
||||
|
||||
def save_predictions(path, preds):
    """Write predictions to ``out.tsv`` in the directory of the input file.

    Args:
        path: ``/``-separated path of the input file the predictions belong
            to; the output file is created alongside it.
        preds: Iterable of predictions; each item is formatted with f-string
            rules and written on its own line.
    """
    # Replace only the final path component. Using the first component
    # (``path.split('/')[0]``) would send nested paths to the wrong directory
    # and absolute paths to ``/out.tsv``.
    directory, separator, _ = str(path).rpartition('/')
    new_path = f"{directory}{separator}out.tsv"
    print(f"Saving predictions to {new_path}")
    # Explicit encoding so the output does not depend on the platform locale.
    with open(new_path, 'w', encoding='utf-8') as f:
        f.writelines(f'{line}\n' for line in preds)
|
||||
|
||||
|
||||
# * Load training data
data = read_data(PATHS[0])
x_train, y_train = data['text'], data['start']

# * Loading pipeline & model training
# TF-IDF features over the raw text feed a plain linear regression that
# predicts the 'start' value.
pipeline = make_pipeline(TfidfVectorizer(), LinearRegression())
print("Now I will train the model...")
pipeline.fit(x_train, y_train)
print("Training completed!\n\n")

# * Making predictions
for path in PATHS[1:]:
    X = read_data(path)
    print(f"I will make predictions for {path}")
    predictions = pipeline.predict(X)
    # Only the non-test splits are scored; this assumes each of them has an
    # expected.tsv next to in.tsv.
    if 'test' not in path:
        split_name = path.split('/')[0]
        # read_data returns the expected values as strings; convert them
        # explicitly because scikit-learn metrics require numeric targets.
        expected = [float(value)
                    for value in read_data(split_name + '/expected.tsv')]
        # mean_squared_error returns the MSE; take the square root so the
        # printed number really is the RMSE the label announces.
        rmse = mean_squared_error(expected, predictions) ** 0.5
        print(f"RMSE for {split_name}: {rmse}\n")
    save_predictions(path, predictions)
|
14220
test-A/out.tsv
Normal file
14220
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user