mieszkania5/predict.py

38 lines
1.1 KiB
Python
Raw Normal View History

2020-12-08 20:14:00 +01:00
import pickle
import sys
import torch
import pandas as pd
def read_data_file(filepath, biggy_column=7, type_column=10):
    """Load the two feature columns used by the model from a header-less TSV.

    Args:
        filepath: Path to a tab-separated file without a header row.
        biggy_column: Positional index of the numeric feature column
            (default 7, the column this script has always used).
        type_column: Positional index of the categorical column
            (default 10, the column this script has always used).

    Returns:
        A new DataFrame with columns ``biggy`` (cast to float), ``type``
        (values as read from the file) and ``bias`` (constant 1).

    Raises:
        ValueError: If a ``biggy`` value cannot be converted to float.
    """
    df = pd.read_csv(filepath, sep='\t', header=None, index_col=None)
    # Take an explicit copy so the column writes below go to an independent
    # frame rather than a slice of ``df`` (avoids SettingWithCopyWarning and
    # chained-assignment ambiguity).
    dataframe = df.iloc[:, [biggy_column, type_column]].copy()
    dataframe.columns = ['biggy', 'type']
    dataframe['bias'] = 1
    dataframe['biggy'] = dataframe['biggy'].astype(float)
    return dataframe
def dataframe_to_arrays(dataframe, categories=None):
    """Return a copy of *dataframe* with its ``type`` column integer-encoded.

    Despite the name, the result is a DataFrame, not arrays.

    Args:
        dataframe: DataFrame containing a ``type`` column. It is not modified.
        categories: Optional ordered sequence of known ``type`` values. When
            given, each value is encoded as its position in this sequence, and
            values not in it become -1. The same value then maps to the same
            code in every file.
            When omitted (the original behavior), codes are derived from the
            categories present in *dataframe* itself. The same value can
            therefore receive different codes in different files.

    Returns:
        A deep copy of *dataframe* whose ``type`` column holds integer codes.
    """
    encoded = dataframe.copy(deep=True)
    if categories is None:
        encoded["type"] = encoded["type"].astype('category').cat.codes
    else:
        encoded["type"] = pd.Categorical(
            encoded["type"], categories=categories
        ).codes
    return encoded
2020-12-08 20:50:35 +01:00
#PREDICT_FILE_PATH = 'dev-0/in.tsv'
2020-12-08 20:14:00 +01:00
PREDICT_FILE_PATH = 'test-A/in.tsv'


def main(predict_file_path=PREDICT_FILE_PATH, model_path='model.pkl'):
    """Print one prediction per input row, computed as ``[biggy, type, 1] @ w``.

    Args:
        predict_file_path: TSV file to predict on (defaults to
            ``PREDICT_FILE_PATH``).
        model_path: Pickled weight object ``w`` (defaults to ``model.pkl``).
            NOTE(review): presumably a float32 torch tensor of shape (3,) or
            (3, 1), since it is matrix-multiplied with a 3-element float
            tensor. Confirm against the training script.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only load model files
    # that you produced yourself.
    with open(model_path, 'rb') as model_file:
        w = pickle.load(model_file)

    data = dataframe_to_arrays(read_data_file(predict_file_path))

    # Read the named columns directly instead of positional row[0]/row[1];
    # positional Series indexing by integer key is deprecated in pandas.
    for biggy, type_code in zip(data['biggy'], data['type']):
        # The trailing 1 is the bias term.
        x = torch.tensor([float(biggy), float(type_code), 1])
        y = x @ w
        print(y.item())


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()