4.4 KiB
4.4 KiB
import lzma
import csv
import re
def readInput(dir):
    """Read a data file into a list with one entry per line.

    Args:
        dir: Path of the file to read (a file path, despite the name).
            A path whose name ends in ``.xz`` is treated as an
            LZMA-compressed, tab-separated file; anything else is read as
            plain UTF-8 text.

    Returns:
        For ``.xz`` files, a list of rows where each row is the list of
        tab-separated fields of one line. For other files, a list of
        strings, one per line. Line terminators are removed in both cases.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If a compressed file is not valid UTF-8 (the
            plain-text branch ignores undecodable bytes instead).
    """
    # Decide by suffix, not by substring: a path such as 'xz_data/in.tsv'
    # merely contains "xz" and must not be handed to the LZMA reader.
    if str(dir).lower().endswith('.xz'):
        with lzma.open(dir) as f:
            # Strip the terminator before splitting so the last field does
            # not carry a stray newline.
            return [
                line.decode('utf-8').rstrip('\r\n').split('\t')
                for line in f
            ]
    with open(dir, encoding='utf8', errors='ignore') as f:
        return [line.replace('\n', '') for line in f]
def writeOutput(output, dir):
    """Write rows to a CSV file, replacing any existing content.

    Args:
        output: Iterable of rows; each row is an iterable of cell values,
            as accepted by ``csv.writer(...).writerows``.
        dir: Path of the file to write (a file path, despite the name).

    Raises:
        OSError: If the file cannot be opened for writing.
    """
    # Write UTF-8 explicitly so the result does not depend on the platform
    # locale and matches the encoding readInput uses when reading.
    # newline='' lets the csv module control line terminators itself.
    with open(dir, 'w', newline='', encoding='utf-8') as f:
        csv.writer(f).writerows(output)
import pandas as pd
# Load the compressed training set; every row is expected to have the five
# tab-separated fields named below.
train = pd.DataFrame(
    readInput('train/train.tsv.xz'),
    columns=['Beginning', 'End', 'Title', 'Source', 'X'],
)
# The regression target is the midpoint of the [Beginning, End] interval.
# Computed column-wise instead of with a per-row Python lambda.
train['Y'] = (train['Beginning'].astype(float) + train['End'].astype(float)) / 2
# Only the text (X) and the target (Y) are used from here on.
train = train.drop(columns=['Beginning', 'End', 'Title', 'Source'])
from sklearn import linear_model
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
# Two-stage pipeline: raw text is turned into TF-IDF features, which feed a
# linear regression that predicts the numeric target Y.
estimators = [
    ('tfidf', TfidfVectorizer()),
    ('linearRegression', linear_model.LinearRegression()),
]
model = Pipeline(steps=estimators)
model.fit(train['X'], train['Y'])
# Imports come first: the metrics below need both names, and using them
# before the import raises NameError when the script runs top to bottom.
import sklearn.metrics
import numpy as np

# Evaluate the fitted model on the dev-0 split.
dev0X = readInput('dev-0/in.tsv')
# readInput returns the expected values as text; convert them to floats
# explicitly instead of relying on scikit-learn's deprecated implicit
# string-to-number conversion (the source of the FutureWarning).
dev0Expected = [float(value) for value in readInput('dev-0/expected.tsv')]
dev0Predicted = model.predict(dev0X)
print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))
print('Model score = ', model.score(dev0X, dev0Expected))
c:\software\python3\lib\site-packages\sklearn\utils\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead. return f(*args, **kwargs)
RMSE = 21.716380888138996 Model score = 0.8585103501633741
c:\software\python3\lib\site-packages\sklearn\utils\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead. return f(*args, **kwargs)