# mieszkania5/train.py
#
# 59 lines
# 1.8 KiB
# Python
# Raw Normal View History
#
# 2020-12-08 20:14:00 +01:00
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import pickle
# 2020-12-08 20:50:35 +01:00
# 2020-12-08 20:14:00 +01:00
# Location of the tab-separated training data read by read_data_file().
TRAIN_FILE_PATH = "train/train.tsv"
#prepare data methods
def read_data_file(filepath):
df = pd.read_csv(filepath, sep='\t', header=None, index_col=None)
dataframe = df.iloc[:, [0,8,11]]
dataframe.columns = ['price','biggy','type']
2020-12-08 20:50:35 +01:00
2020-12-08 20:14:00 +01:00
for x in range(len(dataframe)):
dataframe['biggy'].loc[x] = dataframe['biggy'].loc[x].replace(" ","")
#such dumb solution, well, but at least it works
dataframe['bias'] = 1
dataframe['biggy'] = dataframe['biggy'].astype(float)
return dataframe
def dataframe_to_arrays(dataframe):
dataframe1 = dataframe.copy(deep=True)
2020-12-08 23:13:06 +01:00
#dataframe1["type"] = dataframe1["type"].astype('category').cat.codes
dataframe1 = pd.get_dummies(dataframe1, columns =['type'])
print(dataframe1.columns)
input_cols = dataframe1.columns.values[1:]
output_cols = dataframe1.columns.values[:1]
2020-12-08 20:14:00 +01:00
inputs_array = dataframe1[input_cols].to_numpy()
targets_array = dataframe1[output_cols].to_numpy()
return inputs_array, targets_array
data = read_data_file(TRAIN_FILE_PATH)
inputs_array_training, targets_array_training = dataframe_to_arrays(data)
# Wrap both NumPy arrays as torch tensors and cast them to float32.
inputs_training, targets_training = (
    torch.from_numpy(array).to(torch.float32)
    for array in (inputs_array_training, targets_array_training)
)
# One weight per input column. dataframe_to_arrays orders the inputs as
# biggy, bias, then the type_* dummies, so index 1 is the bias term: it starts
# at 300000 and every other weight at 1.0. Sizing from the data (instead of a
# hard-coded 9-element list) keeps this working for any number of types.
weights = torch.ones(inputs_training.shape[1])
weights[1] = 300000.0
weights.requires_grad_()
# NOTE(review): this step size was chosen while the loss broadcast to an
# (N, N) matrix and gradients accumulated across iterations, both of which
# inflated the gradient; it likely needs retuning.
learning_rate = torch.tensor(0.0000000000005)
EPOCHS = 300

# inputs_training @ weights is (N,), while targets_training is (N, 1).
# Subtracting them directly would broadcast to (N, N) and compare every
# prediction with every target, so flatten the targets to (N,) first.
targets_vector = targets_training.reshape(-1)

print("training started")
for _ in range(EPOCHS):
    y_predicted = inputs_training @ weights
    cost = torch.sum((y_predicted - targets_vector) ** 2)
    cost.backward()
    with torch.no_grad():
        weights -= learning_rate * weights.grad
    # backward() accumulates into .grad; reset it so each step uses only the
    # current iteration's gradient.
    weights.grad.zero_()

print(weights)
# Use a context manager so the output file is flushed and closed.
with open('model.pkl', 'wb') as model_file:
    pickle.dump(weights, model_file)