retroc2/retro.ipynb
2022-05-08 23:30:43 +02:00

4.4 KiB

import lzma
import csv
import re

def readInput(dir):
    """Read a data file into a list with one entry per line.

    Args:
        dir: Path of the file to read (a string or path-like object). The
            parameter name is kept for backward compatibility even though it
            shadows the ``dir`` builtin.

    Returns:
        For a path ending in ``.xz``: a list of rows, each row being the list
        of tab-separated fields of one UTF-8 decoded line. The line terminator
        is not stripped, so it stays attached to the last field of each row.
        For any other path: a list of strings, one per line, with the newline
        removed; undecodable bytes are silently dropped.
    """
    path = str(dir)
    # Decide by file extension. A plain substring test ('xz' in path) also
    # matches uncompressed files such as 'xz_notes.tsv' and then fails inside
    # lzma with a format error.
    if path.endswith('.xz'):
        with lzma.open(path) as f:
            return [line.decode('utf-8').split('\t') for line in f]
    with open(path, encoding='utf8', errors='ignore') as f:
        return [line.replace('\n', '') for line in f]

def writeOutput(output, dir):
    """Write ``output`` to the file at ``dir`` in CSV format.

    Args:
        output: Iterable of rows; each row is an iterable of field values.
        dir: Path of the file to create or overwrite.
    """
    # newline='' hands line-ending control to the csv module.
    with open(dir, mode='w', newline='') as sink:
        csv.writer(sink).writerows(output)
import pandas as pd

# Load the compressed training set; every row is expected to carry exactly
# these five tab-separated fields.
train = pd.DataFrame(readInput('train/train.tsv.xz'), columns=['Beginning', 'End', 'Title', 'Source', 'X'])
# Regression target: midpoint of the Beginning/End values. Computed with
# vectorised column arithmetic instead of a row-wise apply, which runs a
# Python lambda once per row.
train['Y'] = (train['Beginning'].astype(float) + train['End'].astype(float)) / 2
# Only the text (X) and the target (Y) are used from here on.
train = train.drop(columns=['Beginning', 'End', 'Title', 'Source'])
from sklearn import linear_model
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline

# Two-stage pipeline: TF-IDF features from the raw text, then an ordinary
# least-squares regression fitted against the target column.
estimators = [
    ('tfidf', TfidfVectorizer()),
    ('linearRegression', linear_model.LinearRegression()),
]
model = Pipeline(steps=estimators)
model.fit(train['X'], train['Y'])
# Imported before first use: the metric calls below need both names, and the
# original order raised NameError when the file was run top to bottom.
import sklearn.metrics
import numpy as np

# Evaluate the fitted model on the dev-0 split.
dev0X = readInput('dev-0/in.tsv')
# The expected values are read as text; convert them to floats explicitly
# rather than relying on scikit-learn's deprecated implicit string-to-number
# conversion (the source of the FutureWarning in the recorded output).
dev0Expected = [float(value) for value in readInput('dev-0/expected.tsv')]
dev0Predicted = model.predict(dev0X)

print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))
print('Model score = ', model.score(dev0X, dev0Expected))
c:\software\python3\lib\site-packages\sklearn\utils\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.
  return f(*args, **kwargs)
RMSE =  21.716380888138996
Model score =  0.8585103501633741
c:\software\python3\lib\site-packages\sklearn\utils\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.
  return f(*args, **kwargs)