diff --git a/dev-0/out.tsv b/dev-0/out.tsv new file mode 100644 index 0000000..bc38ef9 --- /dev/null +++ b/dev-0/out.tsv @@ -0,0 +1,462 @@ +753233.5625 +717278.0 +761310.75 +737726.625 +581629.3125 +633398.5625 +734318.5 +742838.8125 +635378.125 +736022.625 +626772.6875 +726735.5 +755602.1875 +660751.5625 +752824.625 +737743.6875 +728132.8125 +1178018.125 +735153.5 +534563.3125 +741015.5 +903060.0 +739771.5 +954096.4375 +753063.125 +688970.6875 +977016.0 +737845.9375 +751359.125 +768007.75 +727502.3125 +730910.4375 +761072.25 +647391.75 +761072.25 +909365.0 +789121.0 +577675.875 +721061.0 +793636.75 +727894.25 +942082.875 +891472.375 +739243.25 +784162.1875 +577062.4375 +624591.5 +644341.5 +725832.375 +624216.5625 +620229.0625 +899992.625 +716937.125 +747951.0 +776238.3125 +729206.375 +636997.0 +704648.0 +687266.625 +526043.0 +707684.125 +582992.5625 +613992.25 +1065286.125 +802480.75 +657002.625 +845082.1875 +780805.1875 +752364.5 +682154.5 +705690.375 +715948.8125 +739430.75 +919589.3125 +725951.625 +985195.5 +764139.5 +970915.5 +730160.625 +943786.875 +895732.5 +538721.1875 +724094.1875 +725798.25 +899992.625 +738646.875 +532859.25 +754767.25 +712165.8125 +534563.3125 +794128.0 +777976.5 +594205.25 +942082.875 +886360.25 +911921.0625 +880225.625 +754750.1875 +737897.0625 +547684.5 +951114.375 +732614.5 +738749.125 +724043.125 +526434.9375 +525617.0 +952818.375 +724366.875 +724486.125 +733466.5 +732273.6875 +774173.5 +895851.8125 +752040.75 +724298.6875 +732273.6875 +735681.75 +738067.5 +720345.25 +570348.5 +709609.75 +939185.9375 +1076822.5 +824463.125 +744542.875 +701444.375 +717448.375 +663409.875 +728712.1875 +758175.3125 +642961.1875 +1082383.625 +945965.1875 +740282.75 +734829.75 +886258.0 +529451.125 +894999.75 +717278.0 +741475.5625 +636465.875 +787144.25 +580402.375 +609218.0 +764991.5625 +536437.75 +658297.6875 +682648.6875 +648260.8125 +971563.0 +897538.8125 +737266.5625 +992863.75 +736022.625 +967575.5 +717278.0 +907916.5 +763117.125 +727195.625 +663409.875 +735187.625 +675338.25 +735937.375 +730910.4375 +737845.9375 +708042.0 +978924.5625 +888916.3125 +763287.5 +613992.25 +724043.125 +552166.1875 +737743.6875 +753063.125 +705789.75 +554228.125 +951404.0 +710666.25 +739345.5 +730041.375 +698104.4375 +548195.75 +514114.625 +713733.5 +700953.125 +901696.75 +738646.875 +613259.5 +762946.6875 +742838.8125 +803523.125 +770069.625 +727502.3125 +595738.875 +728320.25 +771807.75 +732614.5 +745531.25 +759368.125 +776919.9375 +727502.3125 +880140.375 +708076.0625 +753063.125 +904116.5 +724094.1875 +742838.8125 +632736.875 +662285.1875 +723634.125 +754767.25 +778624.0 +677601.75 +534563.3125 +763287.5 +555898.0625 +743520.4375 +783054.5 +739788.5625 +732614.5 +740197.5 +546679.125 +503890.28125 +556443.375 +548076.4375 +734318.5 +589076.0625 +717278.0 +730041.375 +707735.25 +721435.875 +708416.875 +724094.1875 +629328.75 +762810.375 +583742.375 +734318.5 +1117450.125 +723412.5625 +634440.9375 +538244.0625 +526179.3125 +633827.4375 +705349.625 +554500.75 +733125.6875 +733636.875 +639893.875 +737556.25 +665284.3125 +743861.25 +723276.25 +708076.0625 +802855.6875 +729632.375 +733602.875 +739430.75 +642719.75 +697491.0 +663409.875 +619513.375 +528939.875 +594546.0625 +703798.875 +552967.125 +805888.875 +523060.90625 +612816.4375 +1039725.25 +710938.875 +707053.625 +919418.875 +1036317.125 +696249.9375 +747030.8125 +663665.5 +677059.375 +670226.125 +736022.625 +992522.9375 +912773.125 +947467.6875 +522634.90625 +747899.875 +592501.1875 +717278.0 +640626.625 +521459.09375 +761583.4375 +718641.25 +1005473.75 +749229.0 +734318.5 +920373.1875 +908683.375 +700595.25 +728081.6875 +740282.75 +735698.8125 +636145.0 +923849.4375 +792648.375 +961338.6875 +804627.875 +717278.0 +645176.5 +553188.625 +520930.84375 +573092.0 +544872.8125 +733926.625 +522345.21875 +706181.6875 +720305.375 +570348.5 +717278.0 +743690.875 +772370.125 +616088.25 +768399.625 +904440.25 +976675.1875 +580572.8125 +532859.25 +544787.625 +713921.0 +1184289.0 +726411.75 +910217.0 +886019.375 +732614.5 +704548.6875 +956888.1875 +760901.8125 +599147.0 +1111295.625 +753063.125 +917033.25 +735085.375 +892051.75 +773136.9375 +949188.75 +557482.875 +729206.375 +606324.0 +783906.5625 +606324.0 +570246.25 +745585.25 +732819.0 +728422.5 +738578.6875 +533336.375 +742458.125 +611197.5625 +538244.0625 +717295.0 +616497.1875 +712063.5625 +651481.5 +760561.0 +599147.0 +669183.75 +906808.875 +753233.5625 +560124.125 +746076.5 +1146211.75 +570246.25 +734301.5 +739430.75 +708842.875 +724946.25 +845593.375 +580061.5625 +722816.125 +627624.6875 +906808.875 +726343.5625 +660001.75 +541652.1875 +915138.8125 +615642.25 +733994.75 +636806.625 +895831.875 +558249.6875 +703884.125 +740569.5 +704071.5625 +540646.75 +539573.1875 +771807.75 +574438.1875 +556716.0 +636806.625 +675338.25 +574659.75 +895050.875 +749859.5 +743486.375 +733276.1875 +1011949.125 +969518.125 +936459.5 +1255859.375 +714173.6875 +566940.375 +725713.0625 +654889.5625 +807490.75 +613157.25 +580794.3125 +634492.0 +797453.8125 +598499.5 +992673.375 +744031.625 +580760.25 +591700.3125 +895542.1875 +724043.125 +595298.75 +764395.125 +731114.875 +546662.0625 +738698.0 +738698.0 +733279.0625 +738425.3125 +796669.9375 +795920.1875 +644324.4375 +689822.75 +663069.0625 +787076.125 +992352.5 +617380.4375 +835519.5 +736056.6875 +559749.25 +630893.5625 +580095.6875 +638019.4375 +654395.4375 +675048.5625 +868899.125 +763696.5 +714074.375 +730893.375 +713852.8125 +585667.9375 +713852.8125 +733347.25 +574012.1875 +776545.0625 +597408.875 +776545.0625 +557107.9375 +744815.5 +1171488.75 +775045.5 diff --git a/geval b/geval new file mode 100755 index 0000000..b68b316 Binary files /dev/null and b/geval differ diff --git a/model.pkl b/model.pkl new file mode 100644 index 0000000..e4ca2a9 Binary files /dev/null and b/model.pkl differ diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..4cf3cb0 --- /dev/null +++ b/predict.py @@ -0,0 +1,37 @@ +import pickle +import sys +import torch +import pandas as pd + +def read_data_file(filepath): + df = pd.read_csv(filepath, sep='\t', header=None, index_col=None) + dataframe = df.iloc[:, [7,10]] + dataframe.columns = ['biggy','type'] + #print(dataframe.size[0]) + # for x in range(len(dataframe)): + # dataframe['biggy'].loc[x] = dataframe['biggy'].loc[x].replace(" ","") + #such dumb solution, well, but at least it works + dataframe['bias'] = 1 + dataframe['biggy'] = dataframe['biggy'].astype(float) + return dataframe + +def dataframe_to_arrays(dataframe): + dataframe1 = dataframe.copy(deep=True) + dataframe1["type"] = dataframe1["type"].astype('category').cat.codes + return dataframe1 + +PREDICT_FILE_PATH = 'test-A/in.tsv' + +def main(): + w = pickle.load(open('model.pkl', 'rb')) + + data = read_data_file(PREDICT_FILE_PATH) + data = dataframe_to_arrays(data) + + for index, row in data.iterrows(): + #print(row[0], row[1]) + x = torch.tensor([float(row[0]), float(row[1]), 1]) + y = x @ w + print(y.item()) + +main() \ No newline at end of file diff --git a/test-A/out.tsv b/test-A/out.tsv new file mode 100644 index 0000000..36c42d6 --- /dev/null +++ b/test-A/out.tsv @@ -0,0 +1,418 @@ +758158.25 +761583.4375 +739686.3125 +639161.125 +728115.75 +810830.625 +545639.625 +880205.6875 +1147683.0 +735784.0 +756454.25 +758175.3125 +735000.125 +717278.0 +734318.5 +736022.625 +806911.3125 +834196.125 +703645.5 +939049.625 +778624.0 +705349.625 +646369.3125 +718982.0 +618062.0625 +735698.8125 +760987.0 +1111420.75 +607704.25 +735681.75 +1005644.125 +944298.125 +1056765.75 +620621.0 +636145.0 +776238.3125 +922145.375 +747951.0 +725798.25 +681779.625 +761072.25 +563344.8125 +532007.1875 +754767.25 +715573.9375 +738646.875 +756641.6875 +713409.75 +704838.375 +628286.375 +753114.25 +610243.3125 +627454.3125 +757834.5 +734318.5 +946036.25 +665812.5625 +622410.25 +643881.375 +723412.5625 +620958.9375 +923849.4375 +736022.625 +976675.1875 +718078.875 +653185.5 +701069.5 +560771.6875 +917033.25 +719322.875 +893176.4375 +575460.625 +941742.0 +933835.25 +730399.25 +759027.375 +724435.0 +910217.0 +735698.8125 +746246.9375 +816215.5 +797147.125 +674264.6875 +760731.375 +707053.625 +745105.25 +724230.5 +757323.3125 +735341.0 +717278.0 +746570.6875 +894829.375 +757152.875 +917033.25 +549899.8125 +1070398.25 +828893.625 +942082.875 +534733.6875 +530643.9375 +742838.8125 +737726.625 +741134.75 +753063.125 +724605.375 +731080.8125 +759879.375 +622969.75 +622066.5625 +543083.5625 +1022684.6875 +746332.125 +703901.125 +629328.75 +625730.3125 +703815.9375 +543236.9375 +743350.0 +820563.75 +539505.0625 +822608.625 +726548.0 +730586.625 +736908.6875 +681029.8125 +886019.375 +628647.125 +737845.9375 +752824.625 +754767.25 +936118.625 +566122.4375 +717278.0 +718982.0 +760390.625 +983340.9375 +730910.4375 +737726.625 +558420.0625 +716235.625 +561811.125 +710666.25 +761242.625 +762265.0625 +800776.75 +544787.625 +728183.9375 +1022684.6875 +759879.375 +738220.8125 +705042.875 +725798.25 +712268.0625 +766695.625 +532859.25 +732614.5 +1072085.25 +1051210.625 +736942.75 +825826.375 +766695.625 +537971.375 +734318.5 +657871.6875 +767240.875 +595298.75 +753693.625 +747951.0 +742838.8125 +754750.1875 +518715.5625 +774023.0625 +738527.5625 +718982.0 +533285.25 +727604.5625 +549899.8125 +639893.875 +677042.3125 +626448.875 +717278.0 +946002.1875 +778624.0 +743179.625 +718198.1875 +721026.875 +683057.625 +982449.0625 +742838.8125 +765673.1875 +841012.375 +558300.8125 +551603.875 +1031205.0 +706917.3125 +570518.875 +725798.25 +920441.375 +727672.75 +741134.75 +676340.75 +570348.5 +722185.625 +733125.6875 +597613.375 +744491.75 +696249.9375 +906519.1875 +721776.6875 +1299764.25 +735903.3125 +644665.25 +541209.125 +950091.9375 +720686.125 +997123.875 +961338.6875 +696249.9375 +708076.0625 +897907.9375 +736022.625 +919027.0 +906808.875 +703901.125 +651481.5 +720856.5 +748479.25 +1029500.9375 +727502.3125 +727502.3125 +736022.625 +772540.5 +716425.9375 +725798.25 +1139844.375 +654634.0 +968683.1875 +1203314.625 +656593.625 +741305.125 +573756.5625 +583980.9375 +736022.625 +729325.625 +578868.75 +663921.0625 +613992.25 +717278.0 +726070.875 +794471.75 +849683.125 +705349.625 +648260.8125 +722731.0 +961338.6875 +697831.8125 +669203.6875 +781827.625 +654634.0 +512580.96875 +721231.375 +722407.1875 +784639.3125 +757152.875 +934585.0 +727246.6875 +732273.6875 +766695.625 +690674.75 +768399.625 +766695.625 +727195.625 +802480.75 +1147512.625 +975823.125 +751359.125 +713529.0625 +745735.75 +578868.75 +642961.1875 +555011.9375 +688970.6875 +764991.5625 +738016.375 +696249.9375 +549899.8125 +549899.8125 +580402.375 +551603.875 +583980.9375 +715642.125 +708451.0 +744525.875 +776511.0 +816624.4375 +643642.8125 +604429.625 +534563.3125 +734318.5 +538823.4375 +760561.0 +712847.4375 +629158.375 +724366.875 +722901.375 +644665.25 +624878.25 +710291.375 +659831.375 +736022.625 +543697.0 +561828.1875 +580572.8125 +535347.125 +549559.0 +745122.25 +623174.1875 +798050.25 +646164.8125 +734318.5 +522242.96875 +639945.0 +553307.875 +556716.0 +540033.3125 +712336.25 +651481.5 +583980.9375 +756982.5 +624196.625 +672782.1875 +715573.9375 +1073826.25 +585565.6875 +595057.25 +620808.4375 +724435.0 +741305.125 +674486.25 +522242.96875 +754920.5625 +910029.5625 +750847.875 +559544.75 +825826.375 +786803.5 +732273.6875 +591700.3125 +562492.75 +922176.5625 +724094.1875 +757834.5 +1077595.25 +788848.375 +686605.0 +733602.875 +1009802.0625 +720686.125 +727502.3125 +1090847.0 +732392.9375 +530030.5 +740776.875 +751427.25 +675338.25 +735153.5 +769081.25 +1038021.25 +551603.875 +560635.375 +734318.5 +714426.4375 +530473.5625 +879714.375 +546474.625 +695786.9375 +737726.625 +540646.75 +649862.625 +538295.1875 +845474.125 +714892.3125 +1046030.25 +568184.3125 +644648.1875 +985826.0 +734744.5625 +725661.9375 +1041278.875 +1041278.875 +1099898.375 +771194.3125 +1002767.1875 +554807.5 +745991.3125 +731609.125 +736022.625 +892903.75 +733449.5 +663409.875 +736806.5 +782424.0625 +743588.625 +743588.625 +1152573.625 +721538.125 +709453.4375 +1077725.75 +769047.1875 +751785.125 +736976.875 +520981.96875 +713835.75 +712114.6875 +736124.875 +783242.0 +593574.75 +764718.875 diff --git a/train.py b/train.py new file mode 100644 index 0000000..72fb5c2 --- /dev/null +++ b/train.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +import numpy as np +import pandas as pd +import sys +import torch.nn.functional as F +import pickle +from torch.utils.data import TensorDataset, DataLoader + +TRAIN_FILE_PATH = 'train/train.tsv' + + +#prepare data methods +def read_data_file(filepath): + df = pd.read_csv(filepath, sep='\t', header=None, index_col=None) + dataframe = df.iloc[:, [0,8,11]] + dataframe.columns = ['price','biggy','type'] + #print(dataframe.size[0]) + for x in range(len(dataframe)): + dataframe['biggy'].loc[x] = dataframe['biggy'].loc[x].replace(" ","") + #such dumb solution, well, but at least it works + dataframe['bias'] = 1 + dataframe['biggy'] = dataframe['biggy'].astype(float) + return dataframe + + +def dataframe_to_arrays(dataframe): + dataframe1 = dataframe.copy(deep=True) + dataframe1["type"] = dataframe1["type"].astype('category').cat.codes + inputs_array = dataframe1[input_cols].to_numpy() + targets_array = dataframe1[output_cols].to_numpy() + return inputs_array, targets_array + +data = read_data_file(TRAIN_FILE_PATH) +input_cols = data.columns.values[1:] +output_cols = data.columns.values[:1] + + + +inputs_array_training, targets_array_training = dataframe_to_arrays(data) + +inputs_training = torch.from_numpy(inputs_array_training).type(torch.float32) +targets_training = torch.from_numpy(targets_array_training).type(torch.float32) + +print(inputs_training) +w = torch.tensor([7201.61492633873, 1,7201.500], requires_grad=True) +learning_rate = torch.tensor(0.000000000005) +print("training started") +for i in range(10000): + y_predicted = inputs_training @ w + cost = torch.sum((y_predicted - targets_training) ** 2) + cost.backward() + with torch.no_grad(): + w -= learning_rate * w.grad + w.requires_grad = True + +print(w) +pickle.dump(w, open('model.pkl', 'wb')) \ No newline at end of file