2020-12-08 20:14:00 +01:00
|
|
|
import pickle
|
|
|
|
import sys
|
|
|
|
import torch
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
def read_data_file(filepath):
    """Load a headerless TSV and return the feature columns used by the model.

    Column 7 of the file becomes ``biggy`` (converted to float) and column 10
    becomes ``type``. A constant ``bias`` column of 1 is appended as the
    intercept input.

    Args:
        filepath: Path (or buffer) accepted by ``pandas.read_csv``.

    Returns:
        A new, independent DataFrame with columns ``['biggy', 'type', 'bias']``.

    Raises:
        ValueError: if a ``biggy`` value cannot be converted to float.
    """
    df = pd.read_csv(filepath, sep='\t', header=None, index_col=None)
    # Take an explicit copy. Otherwise pandas treats the result as a possible
    # slice of ``df`` and warns (SettingWithCopyWarning) when columns are
    # added or overwritten below.
    dataframe = df.iloc[:, [7, 10]].copy()
    dataframe.columns = ['biggy', 'type']
    dataframe['bias'] = 1
    # NOTE(review): an earlier revision stripped spaces from this column before
    # converting it; astype(float) raises ValueError on values like "1 234".
    dataframe['biggy'] = dataframe['biggy'].astype(float)
    return dataframe
|
|
|
|
|
|
|
|
def dataframe_to_arrays(dataframe):
    """Return a deep copy of *dataframe* with ``type`` replaced by integer codes.

    The codes come from ``astype('category')``. They index the distinct
    ``type`` values present in *this* frame (sorted when sortable), so the
    same label may receive a different code in a different file.

    Despite the name, the result is still a DataFrame. The input frame is
    left unmodified.
    """
    type_codes = dataframe["type"].astype("category").cat.codes
    encoded = dataframe.copy(deep=True)
    encoded["type"] = type_codes
    return encoded
|
|
|
|
|
2020-12-08 20:50:35 +01:00
|
|
|
#PREDICT_FILE_PATH = 'dev-0/in.tsv'
|
2020-12-08 20:14:00 +01:00
|
|
|
# Input split to score; point this at "dev-0/in.tsv" to score the dev split instead.
PREDICT_FILE_PATH = "test-A/in.tsv"
|
|
|
|
|
|
|
|
def main(model_path='model.pkl', predict_file_path=None):
    """Load the pickled weights and print one prediction per input row.

    For every row of the input TSV, the vector ``[biggy, type_code, bias]``
    is multiplied by the loaded weights. The scalar result is printed on its
    own line, in input order.

    Args:
        model_path: Path of the pickled weights. Defaults to ``'model.pkl'``.
        predict_file_path: TSV file to score. Defaults to ``PREDICT_FILE_PATH``.
    """
    if predict_file_path is None:
        predict_file_path = PREDICT_FILE_PATH

    # NOTE: unpickling can execute arbitrary code; only load model files you trust.
    with open(model_path, 'rb') as model_file:
        w = pickle.load(model_file)

    data = dataframe_to_arrays(read_data_file(predict_file_path))

    # Read columns by name. Positional ``row[0]`` on a label-indexed Series is
    # deprecated in pandas and will become a label lookup (KeyError).
    for biggy, type_code, bias in zip(data['biggy'], data['type'], data['bias']):
        x = torch.tensor([float(biggy), float(type_code), float(bias)])
        # y.item() needs a single-element result, so w is presumably a
        # length-3 weight vector (shape (3,) or (3, 1)) -- confirm against training.
        y = x @ w
        print(y.item())
|
|
|
|
|
|
|
|
# Run predictions only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|