From f688e6e9667d4570eaf311c3822a7925d392fc1c Mon Sep 17 00:00:00 2001 From: SzamanFL Date: Wed, 16 Dec 2020 00:12:24 +0100 Subject: [PATCH] szu --- src/predict.py | 41 ++- src/train.py | 28 +- test-A/out.tsv | 834 +++++++++++++++++++++++++------------------------ 3 files changed, 452 insertions(+), 451 deletions(-) diff --git a/src/predict.py b/src/predict.py index 5cf87e9..ed5ef32 100644 --- a/src/predict.py +++ b/src/predict.py @@ -21,12 +21,10 @@ dict_column_23 = {"":9, "beton":1, "beton komórkowy":2, "cegła":3, "inne":4, " def read_data(in_file): print("Reading in") - all_data = pd.read_csv(in_file, sep='\t', keep_default_na=False, header = 1) - expected = all_data.iloc[:,0] - data = all_data.iloc[:,1:] + all_data = pd.read_csv(in_file, sep='\t', keep_default_na=False, header = 0) print("Data read") - return expected, data + return all_data def clean_df(data_): print("Cleaning data") @@ -43,6 +41,7 @@ def clean_df(data_): if col == 'parter': data_.iloc[i,14] = 22 + # clear money for i, col in enumerate(data_.iloc[:,1]): try: @@ -50,9 +49,10 @@ def clean_df(data_): data_.iloc[i,1] = 1.0 else: data_.iloc[i,1] = float(col.replace("zł", "").replace(" ", "")) - except ValueError: + except AttributeError: import ipdb; ipdb.set_trace() + # deleting columns deleted_columns = [4,13,18,20,22,24] data_.drop(data_.columns[deleted_columns], axis = 1, inplace=True) @@ -71,20 +71,11 @@ def clean_df(data_): data_.iloc[i,4] = 1 for i, col in enumerate(data_.iloc[:,6]): - data_.iloc[i,6] = col.replace(' ', '') - print("Data cleaned") - return data_ - -def clear(data_): - for i, col in enumerate(data_.iloc[:,2]): try: - if col == "": - data_.iloc[i,2] = 1.0 - else: - data_.iloc[i,2] = float(col.replace("zł", "").replace(" ", "")) + data_.iloc[i,6] = col.replace(' ', '') except AttributeError: - data_.iloc[i,2] = float(data_.iloc[i,2]) - #import ipdb; ipdb.set_trace() + pass + print("Data cleaned") return data_ def main(): @@ -94,9 +85,8 @@ def main(): parser.add_argument("--out") args = parser.parse_args() - expected, data = read_data(args.in_file) - clean_data = clear(data) - clean_data = clean_data.iloc[:,2] + data = read_data(args.in_file) + clean_data = clean_df(data) import ipdb; ipdb.set_trace() #model = Network(len(clean_data.columns)) @@ -105,8 +95,15 @@ def main(): print(f"Loading model : {args.checkpoint}") model.load_state_dict(torch.load(args.checkpoint)) with open(args.out, 'w+') as f: - for i in clean_data: - tensor = torch.tensor([i]) + for i in range(len(clean_data.index)): + data_arr = clean_data.loc[i].to_numpy() + data_arr[data_arr == ""] = 0 + try: + data_arr = pd.to_numeric(data_arr) + except ValueError: + ipdb.set_trace() + data_arr = np.sum(data_arr) + tensor = torch.tensor([data_arr]) y = model(tensor.float()) try: f.write(str(y.item()) + '\n') diff --git a/src/train.py b/src/train.py index 05283b6..2e3aa46 100644 --- a/src/train.py +++ b/src/train.py @@ -21,7 +21,7 @@ dict_column_23 = {"":9, "beton":1, "beton komórkowy":2, "cegła":3, "inne":4, " def read_data(in_file): print("Reading in") - all_data = pd.read_csv(in_file, sep='\t', keep_default_na=False, header = 1) + all_data = pd.read_csv(in_file, sep='\t', keep_default_na=False, header = 0) expected = all_data.iloc[:,0] data = all_data.iloc[:,1:] @@ -83,7 +83,7 @@ def main(): expected, data = read_data(args.in_file) clean_data = clean_df(data) - clean_data = clean_data.iloc[:,1] + #clean_data = clean_data.iloc[:,6] import ipdb; ipdb.set_trace() #model = Network(len(clean_data.columns)) @@ -91,7 +91,7 @@ def main(): if args.checkpoint: print(f"Loading model : {args.checkpoint}") model.load_state_dict(torch.load(args.checkpoint)) - lr = 10 + lr = 0.1 optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = torch.nn.MSELoss() @@ -102,21 +102,23 @@ def main(): counter = 0 l = [i for i in range(len(clean_data.index))] #import ipdb; ipdb.set_trace() - for j in range(100): + for j in range(500): random.shuffle(l) for i in l: - data_arr = [clean_data[i]] - #data_arr = clean_data.loc[i].to_numpy() - #data_arr[data_arr == ""] = 0 - #try: - # data_arr = pd.to_numeric(data_arr) - #except ValueError: - # import ipdb; ipdb.set_trace() + #data_arr = [float(clean_data[i])] + data_arr = clean_data.loc[i].to_numpy() + data_arr[data_arr == ""] = 0 + try: + data_arr = pd.to_numeric(data_arr) + except ValueError: + import ipdb; ipdb.set_trace() #import ipdb; ipdb.set_trace() + data_arr = np.sum(data_arr) + #$import ipdb; ipdb.set_trace() expected_arr = float(expected.loc[i]) #tensor = torch.from_numpy(data_arr) - tensor = torch.tensor(data_arr) + tensor = torch.tensor([data_arr]) y = torch.tensor([expected_arr]) optimizer.zero_grad() @@ -134,7 +136,7 @@ def main(): counter += 1 print(f"Saving last model model-final-{lr}-{random_number}") - torch.save(model.state_dict(), f"model-{counter}-{lr}-{random_number}.ckpt") + torch.save(model.state_dict(), f"model-{counter}-{lr}-{random_number}-final.ckpt") main() diff --git a/test-A/out.tsv b/test-A/out.tsv index 717d487..cafaf1a 100644 --- a/test-A/out.tsv +++ b/test-A/out.tsv @@ -1,416 +1,418 @@ --495323.9375 --950439.375 --643593.0 --499419.0 --515375.5625 --348890.5625 --377979.5625 --594170.0 --580049.125 --424013.5625 --586827.125 --507609.0625 --588945.25 --690333.125 --865996.625 --130440.78125 --850746.125 --943237.75 --456915.1875 --761078.625 --813043.375 --839873.0 --755006.625 --615351.3125 --904405.375 --652771.5625 --1003675.125 --523848.0625 --224203.3125 --415682.25 --99657.3125 --1266040.75 --682707.875 --614927.6875 --457480.0 --509162.375 --740179.75 --604619.4375 --569034.875 --503514.0625 --861054.375 --427120.1875 --844391.75 --474142.625 --320225.1875 --553219.5 --755147.875 --430085.5625 --615351.3125 --1034034.875 --1147849.125 --424296.0 --550677.75 --658419.9375 --615351.3125 --1064536.0 --510856.875 --704454.0 --184100.0625 --656019.375 --690333.125 --804429.625 --581320.0 --882094.375 --572847.5 --594170.0 --671269.9375 --754441.75 --561409.5625 --480920.6875 --530626.125 --474142.625 --909206.5 --709113.875 --697393.5 --650653.4375 --615351.3125 --640062.8125 --644581.5 --631307.875 --603489.75 --491934.9375 --730153.875 --923892.25 --629472.125 --175627.5625 --629330.9375 --589368.875 --615351.3125 --601230.4375 --506620.625 --388852.625 --498148.125 --239877.5 --591204.625 --286758.75 --683131.5 --558867.8125 --556749.6875 --615351.3125 --465952.5625 --392806.4375 --524836.5625 --653054.0 --366541.625 --756560.0 --715750.625 --510292.0625 --332934.0 --530626.125 --708831.375 --1035447.0 --583720.5625 --982635.0 --502384.375 --604478.25 --502384.375 --600524.375 --1014548.125 --631307.875 --792003.25 --700076.5 --550395.3125 --756560.0 --1338480.75 --580049.125 --523565.6875 --427120.1875 --1105768.875 --476119.5625 --611679.875 --637097.4375 --474566.25 --897768.625 --486427.8125 --474142.625 --876022.5 --953687.25 --517352.5 --752606.125 --465670.125 --460163.0 --649382.5625 --843262.0 --418788.875 --397183.9375 --743145.125 --522577.1875 --784660.5 --700782.5 --710384.75 --425708.0625 --676353.375 --468070.6875 --770680.75 --742439.125 --545029.375 --608573.25 --383486.6875 --657713.875 --950721.875 --728318.25 --421895.4375 --683837.5 --580049.125 --597417.8125 --479508.5625 --537827.75 --596146.9375 --498289.3125 --580190.3125 --709537.5 --687932.5 --488969.5625 --615351.3125 --991672.375 --976704.25 --693016.0 --607867.25 --493629.4375 --523706.875 --500125.0625 --675929.75 --645569.9375 --493488.25 --1139800.25 --530061.25 --488122.3125 --526107.4375 --608290.875 --597700.1875 --547853.5625 --828858.75 --389135.0625 --546865.125 --756560.0 --707136.875 --427967.4375 --728318.25 --701488.625 --584426.625 --489675.5625 --870797.75 --389417.4375 --799063.75 --298620.3125 --537262.9375 --544746.9375 --474142.625 --750770.375 --1625416.75 --454514.625 --647123.25 --219967.0625 --237476.9375 --798922.5 --728318.25 --210364.875 --549124.4375 --756560.0 --766585.75 --707136.875 --1355990.625 --527801.9375 --768845.125 --356515.8125 --718010.0 --798216.5 --635544.125 --613939.1875 --692592.375 --937165.75 --493488.25 --645993.5625 --460021.75 --182546.8125 --837190.125 --654042.4375 --527378.3125 --383486.6875 --488545.9375 --582449.6875 --365976.8125 --692451.25 --1745020.5 --589368.875 --615351.3125 --523706.875 --328697.75 --856818.125 --502666.8125 --651359.5 --483744.8125 --1116077.125 --824340.125 --369365.8125 --724223.125 --791862.125 --769974.75 --841285.125 --654183.6875 --498571.75 --495182.75 --841285.125 --498854.125 --427826.1875 --927563.625 --580049.125 --701771.0 --629472.125 --427543.8125 --500689.875 --448442.6875 --746816.5 --758960.5 --646558.375 --507185.4375 --950721.875 --558867.8125 --820527.5 --394924.5625 --516505.25 --798922.5 --768562.625 --603348.5625 --364423.5 --656584.1875 --678612.75 --609702.9375 --572847.5 --463693.1875 --401279.0 --475837.125 --645146.3125 --404385.5625 --663644.625 --682990.25 --714197.375 --473295.375 --500407.4375 --1003675.125 --815867.625 --785931.375 --923751.0 --781412.625 --868256.0 --756560.0 --646982.0 --554631.5625 --931799.875 --441523.4375 --615351.3125 --580049.125 --458327.25 --1180185.875 --286617.5625 --580049.125 --6459.59375 --755147.875 --601230.4375 --672964.4375 --540651.9375 --728318.25 --678895.125 --1358673.625 --364564.75 --421613.0 --1029657.5 --672540.8125 --577224.9375 --378685.625 --684402.375 --577931.0 --833236.25 --697675.875 --601230.4375 --656725.4375 --911889.5 --685955.625 --492358.5625 --717021.5 --562680.4375 --544746.9375 --289865.375 --498854.125 --493488.25 --661385.3125 --1250790.25 --628766.125 --424296.0 --784801.625 --403538.3125 --1102521.125 --467082.1875 --302291.6875 --722669.875 --567199.125 --473154.1875 --551807.375 --615351.3125 --558867.8125 --516505.25 --678612.75 --510856.875 --613656.8125 --243972.5625 --332934.0 --587109.5625 --800758.25 --251597.8125 --735519.875 --963289.375 --700076.5 --788755.5 --428955.875 --734672.625 --589933.75 --502384.375 --930529.0 --792144.5 --481909.125 --1391998.875 --565928.25 --401843.8125 --445900.9375 --561692.0 --594170.0 --594170.0 --608290.875 --551525.0 --608290.875 --580049.125 --629472.125 --544746.9375 --563527.6875 --332086.75 --629472.125 --940131.125 --629472.125 --601230.4375 --587109.5625 --587109.5625 --161365.5 --636532.5625 --728318.25 --724929.25 --388005.375 --615351.3125 --629472.125 --516646.4375 --664774.3125 --636532.5625 --544605.75 --961453.75 --943237.75 --544746.9375 +328260.59375 +290434.46875 +292780.09375 +452705.0 +339881.5625 +294927.78125 +318433.34375 +247435.4375 +256481.875 +323397.09375 +319579.84375 +290749.3125 +342362.5625 +313328.3125 +321604.28125 +377242.15625 +447839.75 +175605.453125 +416553.0 +434855.46875 +311214.34375 +376477.53125 +393614.1875 +414439.03125 +376927.3125 +330591.34375 +423059.09375 +342557.3125 +452864.25 +301894.875 +207854.734375 +279639.75 +169623.359375 +545860.0 +351379.78125 +331166.625 +288680.34375 +325787.25 +369730.8125 +327621.4375 +316553.25 +296276.6875 +429754.0625 +294572.46875 +402384.90625 +286140.875 +267815.0 +309493.03125 +374080.21875 +273927.53125 +330916.09375 +473756.125 +518154.0 +304233.75 +329205.59375 +344319.125 +331068.125 +472727.9375 +296890.625 +358333.375 +194842.5625 +357496.78125 +354438.28125 +392084.9375 +343979.53125 +415878.34375 +348865.5 +324410.0 +347691.5625 +378960.34375 +321109.53125 +288545.40625 +334535.46875 +286560.0625 +437049.53125 +385360.71875 +356336.375 +341169.75 +330366.46875 +338291.15625 +341802.15625 +336891.0 +354904.71875 +292480.96875 +365907.6875 +448952.0625 +334696.5 +205673.296875 +352486.25 +337121.71875 +330743.375 +324751.40625 +315910.0625 +257960.3125 +293178.125 +237045.5 +324865.1875 +225967.40625 +352553.71875 +312955.0 +336716.90625 +330464.96875 +283372.90625 +281213.96875 +323102.03125 +343063.34375 +280764.1875 +376876.9375 +364843.96875 +297226.15625 +257735.421875 +304042.59375 +359707.46875 +482760.71875 +322166.5 +460096.25 +295476.0625 +345905.9375 +294572.46875 +340940.34375 +457986.8125 +352838.84375 +386454.09375 +356980.4375 +310644.90625 +384384.6875 +560150.0 +319088.65625 +301582.75 +271183.875 +509396.75 +286544.78125 +328620.875 +336402.0625 +286130.0625 +419490.0625 +288995.1875 +305996.90625 +413809.34375 +478307.90625 +300194.28125 +373154.5625 +312734.59375 +303855.9375 +343868.4375 +403194.53125 +267333.75 +263762.5 +371305.0625 +318693.75 +383351.09375 +374903.3125 +359748.40625 +301049.3125 +349895.5 +306356.71875 +383538.65625 +374082.03125 +307865.28125 +330361.53125 +257915.328125 +344678.03125 +454784.375 +366526.125 +270568.5625 +350796.40625 +320226.625 +342384.15625 +287331.0 +306716.09375 +335129.1875 +324631.3125 +319871.28125 +377107.21875 +353235.125 +319987.78125 +331454.46875 +450475.46875 +477543.28125 +369924.6875 +341529.5625 +306176.8125 +331454.46875 +308119.875 +348750.34375 +340099.28125 +292077.53125 +501291.6875 +308650.59375 +315550.21875 +303163.28125 +328932.5625 +325427.40625 +331589.40625 +408048.53125 +259449.078125 +323763.21875 +386327.75 +359210.46875 +270958.96875 +370279.5625 +385158.3125 +336441.65625 +311407.75 +414214.15625 +259488.21875 +388089.09375 +242570.15625 +320508.625 +317673.21875 +285978.5 +391455.25 +652008.75 +304260.71875 +339820.40625 +206505.390625 +211452.96875 +387616.8125 +365439.9375 +206685.296875 +310269.8125 +375180.8125 +390420.75 +359662.46875 +588864.0 +301818.4375 +380269.625 +250763.8125 +378231.6875 +403734.25 +352998.96875 +330484.3125 +377129.71875 +450556.4375 +292212.4375 +341207.96875 +282172.46875 +213297.078125 +400990.59375 +365642.3125 +302893.40625 +258994.796875 +311934.0 +320033.1875 +262008.328125 +374871.8125 +709585.1875 +322099.03125 +331011.90625 +303671.53125 +259116.25 +417767.40625 +294442.46875 +361059.0625 +289220.0625 +491450.5 +398354.90625 +263937.0 +365631.09375 +385387.6875 +379055.6875 +402700.1875 +343847.28125 +312941.5 +16417158.0 +401973.34375 +293798.84375 +283552.84375 +432250.34375 +319760.1875 +358756.1875 +334504.875 +304647.5625 +294888.65625 +279347.375 +372384.53125 +375613.9375 +340661.5 +319805.15625 +436748.15625 +313418.28125 +422580.0625 +262098.28125 +298942.96875 +387841.71875 +379760.9375 +327046.625 +251523.9375 +344947.90625 +374048.71875 +328937.5 +315614.09375 +281888.21875 +280888.3125 +288109.09375 +339343.625 +265021.875 +362174.5 +373733.875 +361927.125 +307769.03125 +307004.40625 +453070.71875 +413141.875 +391198.84375 +442505.34375 +385788.0 +411106.1875 +375663.4375 +358126.46875 +312175.0625 +431710.625 +276041.5 +330350.71875 +320290.90625 +281139.3125 +513075.96875 +227483.15625 +318900.1875 +137194.171875 +374488.15625 +325707.625 +349265.8125 +322998.59375 +366007.09375 +349900.0 +566824.75 +252248.09375 +270432.71875 +465381.1875 +370607.90625 +317736.15625 +256179.171875 +354525.09375 +320007.5625 +399551.28125 +371179.125 +325746.78125 +365570.34375 +424458.78125 +353677.6875 +303248.28125 +363015.59375 +314392.5 +309851.53125 +229255.28125 +293753.84375 +292392.375 +345505.1875 +535908.1875 +334153.15625 +304233.75 +401305.4375 +264437.15625 +484245.0 +299754.84375 +250828.578125 +364333.46875 +327136.59375 +289534.90625 +309634.28125 +329670.65625 +312779.125 +299207.0 +373643.90625 +297068.3125 +357919.59375 +215006.234375 +240958.59375 +322157.5 +407107.59375 +219079.0 +368633.375 +450425.96875 +357406.40625 +405623.34375 +271993.46875 +368542.0625 +341306.9375 +294221.1875 +434104.8125 +386309.71875 +290960.71875 +578506.9375 +314542.28125 +264093.96875 +276502.53125 +312560.09375 +324631.3125 +324586.3125 +330631.40625 +311647.9375 +327977.6875 +319125.09375 +335225.90625 +308579.09375 +314317.84375 +240164.265625 +334759.9375 +433824.59375 +335073.40625 +327641.6875 +321804.0 +321961.40625 +187614.59375 +336919.3125 +367617.75 +369609.375 +258921.9375 +330701.09375 +334808.0625 +298486.90625 +345846.5625 +336715.59375 +307528.84375 +441907.59375 +436236.75 +308643.40625