RoBERTa large with regression layer on top

commit 400b65c4f8 (parent 46e06b748e)
Author: Jakub Pokrywka
Date:   2021-07-10 21:58:30 +00:00
6 changed files with 73824 additions and 72863 deletions

dev-0/out.tsv (72530 lines changed, Normal file → Executable file)
File diff suppressed because it is too large.

@@ -0,0 +1,179 @@
import copy
import random

import numpy as np
import torch
from fairseq.data.data_utils import collate_tokens
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm

EVAL_OFTEN = True   # evaluate on a dev subset during training
EVAL_EVERY = 10000  # iterations between mid-epoch evaluations
BATCH_SIZE = 1
model_type = 'large'  # 'base' or 'large'

roberta = torch.hub.load('pytorch/fairseq', f'roberta.{model_type}')
roberta.cuda()
device = 'cuda'
# LOAD DATA
train_in = [l.rstrip('\n') for l in open('../train/in.tsv', newline='\n').readlines()]  # shuffled below
dev_in = [l.rstrip('\n') for l in open('../dev-0/in.tsv', newline='\n').readlines()]  # shuffled below
train_year = [float(l.rstrip('\n')) for l in open('../train/expected.tsv', newline='\n').readlines()]
dev_year = [float(l.rstrip('\n')) for l in open('../dev-0/expected.tsv', newline='\n').readlines()]
dev_in_not_shuffled = copy.deepcopy(dev_in)  # keep original order for writing predictions
test_in = [l.rstrip('\n') for l in open('../test-A/in.tsv', newline='\n').readlines()]  # not shuffled

# SHUFFLE DATA (keep each text paired with its year)
c = list(zip(train_in, train_year))
random.shuffle(c)
train_in, train_year = zip(*c)
c = list(zip(dev_in, dev_year))
random.shuffle(c)
dev_in, dev_year = zip(*c)

# SCALE DATA: map years onto [0, 1] using the training range
scaler = MinMaxScaler()
train_year_scaled = scaler.fit_transform(np.array(train_year).reshape(-1, 1))
dev_year_scaled = scaler.transform(np.array(dev_year).reshape(-1, 1))
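# (A hypothetical example of the scaling, for illustration only: if the
# training years spanned 1900-2000, MinMaxScaler would map 1900 -> 0.0,
# 1950 -> 0.5 and 2000 -> 1.0; scaler.inverse_transform undoes this when
# converting predictions back to calendar years.)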
class RegressorHead(torch.nn.Module):
    """Single linear layer whose output is softly clamped towards (0, 1)."""

    def __init__(self):
        super(RegressorHead, self).__init__()
        in_dim = 768 if model_type == 'base' else 1024  # RoBERTa hidden size
        self.linear = torch.nn.Linear(in_dim, 1)
        self.m = torch.nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.linear(x)
        x = self.m(x)            # leaky clamp from below at 0
        x = -self.m(-x + 1) + 1  # leaky clamp from above at 1
        return x
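# (Worked illustration of the two-sided soft clamp with slope 0.1:
#   f(0.5)  = 0.5, the identity inside (0, 1);
#   f(1.5)  = -LeakyReLU(-1.5 + 1) + 1 = -(-0.05) + 1 = 1.05;
#   f(-0.5) = -LeakyReLU(0.05 + 1) + 1 = -1.05 + 1 = -0.05.
# Out-of-range values are compressed tenfold rather than hard-clipped,
# which keeps gradients alive during training.)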
def get_features_and_year(dataset_in, dataset_y):
    for i in tqdm(range(0, len(dataset_in), BATCH_SIZE)):
        batch_of_text = dataset_in[i:i + BATCH_SIZE]
        # encode to BPE ids, truncate to RoBERTa's 512-position limit,
        # and pad with fairseq's pad index (1)
        batch = collate_tokens([roberta.encode(p)[:512] for p in batch_of_text], pad_idx=1)
        features = roberta.extract_features(batch).mean(1)  # mean-pool over tokens
        years = torch.FloatTensor(dataset_y[i:i + BATCH_SIZE]).to(device)
        yield features, years
def eval_dev(short=False):
    criterion_eval = torch.nn.MSELoss(reduction='sum')
    roberta.eval()
    regressor_head.eval()
    loss = 0.0
    loss_clipped = 0.0
    loss_scaled = 0.0
    if short:
        dataset_in = dev_in[:1000]
        dataset_years = dev_year_scaled[:1000]
    else:
        dataset_in = dev_in
        dataset_years = dev_year_scaled
    predictions_sum = 0
    for batch, year in tqdm(get_features_and_year(dataset_in, dataset_years)):
        predictions_sum += year.shape[0]
        x = regressor_head(batch.to(device))
        x_clipped = torch.clamp(x, 0.0, 1.0)
        # map scaled predictions and targets back to calendar years
        original_x = torch.FloatTensor(scaler.inverse_transform(x.detach().cpu().numpy().reshape(1, -1)))
        original_x_clipped = torch.FloatTensor(scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1)))
        original_year = torch.FloatTensor(scaler.inverse_transform(year.detach().cpu().numpy().reshape(1, -1)))
        loss_scaled += criterion_eval(x, year).item()
        loss += criterion_eval(original_x, original_year).item()
        loss_clipped += criterion_eval(original_x_clipped, original_year).item()
    # summed squared errors over the example count, then sqrt: these are RMSEs
    print('valid loss scaled: ' + str(np.sqrt(loss_scaled / predictions_sum)))
    print('valid loss: ' + str(np.sqrt(loss / predictions_sum)))
    print('valid loss clipped: ' + str(np.sqrt(loss_clipped / predictions_sum)))
def train_one_epoch():
    roberta.train()
    regressor_head.train()
    loss_value = 0.0
    iteration = 0
    for batch, year in get_features_and_year(train_in, train_year_scaled):
        iteration += 1
        roberta.zero_grad()
        regressor_head.zero_grad()
        predictions = regressor_head(batch.to(device))
        loss = criterion(predictions, year)
        loss_value += loss.item()
        loss.backward()
        optimizer.step()
        roberta.zero_grad()
        regressor_head.zero_grad()
        if EVAL_OFTEN and (iteration > 1) and (iteration % EVAL_EVERY == 1):
            # running train RMSE over roughly the last EVAL_EVERY examples
            print('train loss: ' + str(np.sqrt(loss_value / (EVAL_EVERY * BATCH_SIZE))))
            eval_dev(True)
            roberta.train()
            regressor_head.train()
            loss_value = 0.0
def predict(dataset='dev'):
    if dataset == 'dev':
        f_out_path = '../dev-0/out.tsv'
        dataset_in_not_shuffled = dev_in_not_shuffled
    elif dataset == 'test':
        f_out_path = '../test-A/out.tsv'
        dataset_in_not_shuffled = test_in
    roberta.eval()
    regressor_head.eval()
    f_out = open(f_out_path, 'w')
    # dev_year_scaled only drives the generator's batching here;
    # the labels themselves are ignored when writing predictions
    for batch, year in tqdm(get_features_and_year(dataset_in_not_shuffled, dev_year_scaled)):
        x = regressor_head(batch)
        x_clipped = torch.clamp(x, 0.0, 1.0)
        original_x_clipped = scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1))
        for y in original_x_clipped[0]:
            f_out.write(str(y) + '\n')
    f_out.close()
regressor_head = RegressorHead().to(device)
# fine-tune all of RoBERTa together with the head, at a small learning rate
optimizer = torch.optim.Adam(list(roberta.parameters()) + list(regressor_head.parameters()), lr=1e-6)
criterion = torch.nn.MSELoss(reduction='sum').to(device)

# training loop, disabled in this run; only restore a checkpoint and predict
# for i in range(100):
#     print('epoch ' + str(i))
#     train_one_epoch()
#
#     print(f'epoch {i} done, EVALUATION ON FULL DEV:')
#     eval_dev()
#     print('evaluation done')
#     predict('dev')
#     predict('test')
#
#     torch.save(roberta.state_dict(), 'checkpoints/roberta_to_regressor' + str(i) + '.pt')
#     torch.save(regressor_head.state_dict(), 'checkpoints/regressor_head' + str(i) + '.pt')

roberta.load_state_dict(torch.load('checkpoints/roberta_to_regressor2.pt'))
regressor_head.load_state_dict(torch.load('checkpoints/regressor_head2.pt'))
predict('dev')
predict('test')
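A minimal sketch of dating one new text with the objects defined above (a hypothetical usage example, not part of the commit; it assumes the script has been run so that roberta, regressor_head and scaler exist, and predict_year is a name introduced here for illustration):

def predict_year(text):
    # encode, truncate, mean-pool, regress, clip to [0, 1], map back to a year
    tokens = roberta.encode(text)[:512].unsqueeze(0)
    features = roberta.extract_features(tokens).mean(1)
    x = torch.clamp(regressor_head(features), 0.0, 1.0)
    return float(scaler.inverse_transform(x.detach().cpu().numpy().reshape(1, -1))[0][0])

print(predict_year('The telegraph has lately been improved.'))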

@@ -0,0 +1,236 @@
epoch 0
train loss: 0.10178063126471196
valid loss scaled: 0.08694887030887369
valid loss: 14.393254678046581
valid loss clipped: 14.393254678046581
train loss: 0.08234276634246003
valid loss scaled: 0.0941885423619173
valid loss: 15.591687502080195
valid loss clipped: 15.591687502080195
train loss: 0.07840566897624576
valid loss scaled: 0.10380323876951889
valid loss: 17.18327815406788
valid loss clipped: 17.18327815406788
train loss: 0.07422883822886046
valid loss scaled: 0.10199038650759494
valid loss: 16.883180520434045
valid loss clipped: 16.883180520434045
train loss: 0.0715406045230223
valid loss scaled: 0.08943446505286708
valid loss: 14.8047131664572
valid loss clipped: 14.8047131664572
train loss: 0.07021876769740436
valid loss scaled: 0.0936252853835668
valid loss: 15.4984476046049
valid loss clipped: 15.498044701389192
train loss: 0.0698000478439966
valid loss scaled: 0.0732450713218563
valid loss: 12.124766595640176
valid loss clipped: 12.1245106015308
train loss: 0.06787076175194112
valid loss scaled: 0.08167600706479376
valid loss: 13.520399589246514
valid loss clipped: 13.520399589246514
train loss: 0.06687028820999129
valid loss scaled: 0.07860967955406604
valid loss: 13.012812666808495
valid loss clipped: 13.012812666808495
train loss: 0.06485297715840126
valid loss scaled: 0.06548883445499123
valid loss: 10.840826069687282
valid loss clipped: 10.840826069687282
train loss: 0.06536078446808377
valid loss scaled: 0.06285778998197525
valid loss: 10.405289848890389
valid loss clipped: 10.402398559859241
train loss: 0.06282014927278873
valid loss scaled: 0.0596510515677964
valid loss: 9.874458507430498
valid loss clipped: 9.87346211871712
train loss: 0.06211401904304488
valid loss scaled: 0.08676563866997164
valid loss: 14.362925503357548
valid loss clipped: 14.362736109627358
train loss: 0.06245454523502525
valid loss scaled: 0.061674062543468834
valid loss: 10.209341911143197
valid loss clipped: 10.209341911143197
train loss: 0.06125578036664975
valid loss scaled: 0.07651956140507003
valid loss: 12.666819321757812
valid loss clipped: 12.665205579153804
train loss: 0.059952159220873956
valid loss scaled: 0.060298738548431624
valid loss: 9.981674467979698
valid loss clipped: 9.981674467979698
train loss: 0.058864656505384534
valid loss scaled: 0.06224854018969625
valid loss: 10.304438001377102
valid loss clipped: 10.304438001377102
train loss: 0.06033644216377098
valid loss scaled: 0.06860725018131864
valid loss: 11.357033358430243
valid loss clipped: 11.355784159072005
train loss: 0.05914900476107049
valid loss scaled: 0.06558627462957999
valid loss: 10.856954814387644
valid loss clipped: 10.855340320357298
train loss: 0.05801354450611692
valid loss scaled: 0.05499361709851459
valid loss: 9.103479951662178
valid loss clipped: 9.1033016162492
train loss: 0.05722296967684818
valid loss scaled: 0.057276662331433935
valid loss: 9.481407818900754
valid loss clipped: 9.481123608784491
train loss: 0.05615422017549816
valid loss scaled: 0.058946735056267775
valid loss: 9.757867643250142
valid loss clipped: 9.757867643250142
train loss: 0.0568488666286758
valid loss scaled: 0.05493375345527444
valid loss: 9.093570558704585
valid loss clipped: 9.093570558704585
train loss: 0.056790975754785165
valid loss scaled: 0.05718280344549578
valid loss: 9.465868729560718
valid loss clipped: 9.465666290138364
train loss: 0.056136715463394736
valid loss scaled: 0.053061117528704765
valid loss: 8.783576806818123
valid loss clipped: 8.783576806818123
train loss: 0.0553076835671036
valid loss scaled: 0.05756376005144751
valid loss: 9.528932262971411
valid loss clipped: 9.528929186851473
train loss: 0.05608131737157289
valid loss scaled: 0.05928617719392673
valid loss: 9.814053006356623
valid loss clipped: 9.813086909431942
train loss: 0.054209274725265544
valid loss scaled: 0.059807781220226314
valid loss: 9.900402019838845
valid loss clipped: 9.899803739726012
train loss: 0.05412653998906017
valid loss scaled: 0.05289933204229587
valid loss: 8.75679344111035
valid loss clipped: 8.75679344111035
train loss: 0.05518407843528995
valid loss scaled: 0.05688392864725153
valid loss: 9.416393057892193
valid loss clipped: 9.416393057892193
train loss: 0.054482056935617085
valid loss scaled: 0.05216977017299728
valid loss: 8.636023503699716
valid loss clipped: 8.636023503699716
train loss: 0.05399237402441963
valid loss scaled: 0.05409103409701342
valid loss: 8.954068173354056
valid loss clipped: 8.954068173354056
train loss: 0.05278416502551686
valid loss scaled: 0.05505897657177228
valid loss: 9.114298554400836
valid loss clipped: 9.114128444401343
train loss: 0.05283459274043561
valid loss scaled: 0.05330468837007633
valid loss: 8.823897722342918
valid loss clipped: 8.823469673021748
train loss: 0.05349978882139668
valid loss scaled: 0.05340672519610164
valid loss: 8.840791683139289
valid loss clipped: 8.840791683139289
train loss: 0.05244616329396176
valid loss scaled: 0.05313163808650701
valid loss: 8.795251288181513
valid loss clipped: 8.795251288181513
train loss: 0.05458822252512801
valid loss scaled: 0.05058266187593807
valid loss: 8.373303749237403
valid loss clipped: 8.372144978003309
train loss: 0.052178348999718266
valid loss scaled: 0.05477028984676988
valid loss: 9.066509788893315
valid loss clipped: 9.066509788893315
train loss: 0.05264708004419216
valid loss scaled: 0.05061255628492586
valid loss: 8.378252333491037
valid loss clipped: 8.378252333491037
epoch 0 done, EVALUATION ON FULL DEV:
valid loss scaled: 0.05304482890568026
valid loss: 8.780882018378279
valid loss clipped: 8.780282676016194
evaluation done
epoch 1
train loss: 0.05193906537483793
valid loss scaled: 0.05525216534943099
valid loss: 9.146276710629154
valid loss clipped: 9.146276710629154
train loss: 0.05052301682721696
valid loss scaled: 0.05013342859224921
valid loss: 8.298936329028248
valid loss clipped: 8.298936329028248
train loss: 0.05098356105392436
valid loss scaled: 0.05003979649887808
valid loss: 8.283437402089449
valid loss clipped: 8.283437402089449
train loss: 0.050088782917609716
valid loss scaled: 0.049286247248487895
valid loss: 8.15869774398909
valid loss clipped: 8.15869774398909
train loss: 0.04976562104718302
valid loss scaled: 0.049223300781573644
valid loss: 8.148276889380307
valid loss clipped: 8.148276889380307
train loss: 0.04991738679144195
valid loss scaled: 0.057702729077406346
valid loss: 9.551934437355772
valid loss clipped: 9.550634254813586
train loss: 0.049231458019302426
valid loss scaled: 0.04882268479584859
valid loss: 8.081957979949655
valid loss clipped: 8.081957979949655
train loss: 0.04899937393720371
valid loss scaled: 0.05770907825550645
valid loss: 9.552989047406287
valid loss clipped: 9.552741958522342
train loss: 0.04908777283465607
valid loss scaled: 0.05134136848123586
valid loss: 8.498893804182858
valid loss clipped: 8.49871165029239
train loss: 0.04764838316443966
valid loss scaled: 0.05261156122562749
valid loss: 8.709157880799538
valid loss clipped: 8.709138480005972
train loss: 0.048582658014159445
valid loss scaled: 0.058413308832729716
valid loss: 9.669566298033944
valid loss clipped: 9.668254240676646
train loss: 0.04743617173694683
valid loss scaled: 0.051687093210397744
valid loss: 8.556127664837945
valid loss clipped: 8.556127664837945
train loss: 0.04656181514429912
valid loss scaled: 0.0507906026516647
valid loss: 8.407722978198628
valid loss clipped: 8.407722978198628
train loss: 0.04762343910575039
valid loss scaled: 0.04810272429882445
valid loss: 7.962782613217737
valid loss clipped: 7.962782613217737
train loss: 0.048021365618576804
valid loss scaled: 0.051694526258934385
valid loss: 8.55735450556255
valid loss clipped: 8.55735450556255
train loss: 0.04710475399633604
valid loss scaled: 0.04621696889011772
valid loss: 7.650616536714813
valid loss clipped: 7.650616536714813
train loss: 0.04589946658188546
valid loss scaled: 0.04597697057310632
valid loss: 7.610890791241026
valid loss clipped: 7.610890791241026
train loss: 0.047086027235787566
valid loss scaled: 0.04816656697145074
valid loss: 7.973350046367094
valid loss clipped: 7.973350046367094
train loss: 0.04701215419012016

@@ -0,0 +1,365 @@
epoch 1
train loss: 0.04859849993016349
valid loss scaled: 0.049774259863821814
valid loss: 8.239480139682152
valid loss clipped: 8.239480139682152
train loss: 0.048566044032315264
valid loss scaled: 0.05388942821210208
valid loss: 8.920693268651688
valid loss clipped: 8.920693268651688
train loss: 0.048645840325942186
valid loss scaled: 0.05395363235295337
valid loss: 8.931319168623325
valid loss clipped: 8.931319168623325
train loss: 0.04814379534040526
valid loss scaled: 0.0505393534476585
valid loss: 8.366129820820378
valid loss clipped: 8.366112817712773
train loss: 0.04762666338418628
valid loss scaled: 0.0559599088881008
valid loss: 9.263433639650358
valid loss clipped: 9.263433639650358
train loss: 0.047656307189366594
valid loss scaled: 0.05023614907590337
valid loss: 8.315944918367261
valid loss clipped: 8.315944918367261
train loss: 0.04783966918808628
valid loss scaled: 0.05002744041130589
valid loss: 8.281391949540064
valid loss clipped: 8.281391949540064
train loss: 0.0470857543381118
valid loss scaled: 0.05007423239511511
valid loss: 8.289137880244128
valid loss clipped: 8.289137880244128
train loss: 0.04748013890490943
valid loss scaled: 0.048425306807186
valid loss: 8.01617955284889
valid loss clipped: 8.01617955284889
train loss: 0.04756126793491012
valid loss scaled: 0.05838172127632288
valid loss: 9.664334319586205
valid loss clipped: 9.664334319586205
train loss: 0.04653691036105432
valid loss scaled: 0.04918750869152843
valid loss: 8.14234734821747
valid loss clipped: 8.14225860100907
train loss: 0.047854342213726093
valid loss scaled: 0.04825916678037553
valid loss: 7.988678617267059
valid loss clipped: 7.988678617267059
train loss: 0.04773736108071371
valid loss scaled: 0.04693122838927694
valid loss: 7.768854512974581
valid loss clipped: 7.768764906393828
train loss: 0.04762865603151333
valid loss scaled: 0.05211583033031525
valid loss: 8.627096447506986
valid loss clipped: 8.627096447506986
train loss: 0.04704306549118133
valid loss scaled: 0.048404824273360625
valid loss: 8.012792317972876
valid loss clipped: 8.012792317972876
train loss: 0.047052753271308353
valid loss scaled: 0.047940905819710194
valid loss: 7.935991222098737
valid loss clipped: 7.935991222098737
train loss: 0.047227875012458086
valid loss scaled: 0.04819175795767938
valid loss: 7.977517309701107
valid loss clipped: 7.977427346370881
train loss: 0.04710633027268819
valid loss scaled: 0.06383742791518349
valid loss: 10.567450933423165
valid loss clipped: 10.567450933423165
train loss: 0.04722847116132579
valid loss scaled: 0.050155074146505475
valid loss: 8.302519880598078
valid loss clipped: 8.302519880598078
train loss: 0.04704662696396088
valid loss scaled: 0.049731758847849555
valid loss: 8.232442667326458
valid loss clipped: 8.231381949146131
train loss: 0.04640111281869779
valid loss scaled: 0.04439513521075078
valid loss: 7.349034015457206
valid loss clipped: 7.349034015457206
train loss: 0.04608356887790085
valid loss scaled: 0.04921432936453068
valid loss: 8.146788631177284
valid loss clipped: 8.146788631177284
train loss: 0.04667254772723141
valid loss scaled: 0.04490475325651902
valid loss: 7.433397045419523
valid loss clipped: 7.433397045419523
train loss: 0.04732138091412676
valid loss scaled: 0.050358595447874285
valid loss: 8.336210587795785
valid loss clipped: 8.336210587795785
train loss: 0.046673220858461255
valid loss scaled: 0.05013948923902952
valid loss: 8.29994005436984
valid loss clipped: 8.29994005436984
train loss: 0.04553470963703993
valid loss scaled: 0.04682243262330682
valid loss: 7.750845583038412
valid loss clipped: 7.750845583038412
train loss: 0.046360710414935315
valid loss scaled: 0.045567268129456086
valid loss: 7.543068199989672
valid loss clipped: 7.543068199989672
train loss: 0.04552500833500083
valid loss scaled: 0.04640085875333069
valid loss: 7.681058866001286
valid loss clipped: 7.681058866001286
train loss: 0.0452476728574542
valid loss scaled: 0.046615049866360565
valid loss: 7.716513190971443
valid loss clipped: 7.716513190971443
train loss: 0.045669575255858894
valid loss scaled: 0.04750326795304539
valid loss: 7.8635460404556
valid loss clipped: 7.8635460404556
train loss: 0.04647987461932483
valid loss scaled: 0.04860825732568691
valid loss: 8.046460836562115
valid loss clipped: 8.046460836562115
train loss: 0.046793957499830896
valid loss scaled: 0.04575770256354883
valid loss: 7.574592460104707
valid loss clipped: 7.574592460104707
train loss: 0.04585571441358016
valid loss scaled: 0.04501184464469181
valid loss: 7.451126169626989
valid loss clipped: 7.451126169626989
train loss: 0.04610028065682122
valid loss scaled: 0.046051780212043245
valid loss: 7.623273458385178
valid loss clipped: 7.623083864668157
train loss: 0.046177845466100334
valid loss scaled: 0.06025220972281953
valid loss: 9.973968222380467
valid loss clipped: 9.973968222380467
train loss: 0.044866837962671394
valid loss scaled: 0.04543754842407298
valid loss: 7.5215926794190535
valid loss clipped: 7.5215926794190535
train loss: 0.044414652746250796
valid loss scaled: 0.044867695763598454
valid loss: 7.4272631776866636
valid loss clipped: 7.427128339884669
train loss: 0.04534780122814873
valid loss scaled: 0.045269020539499126
valid loss: 7.493695395829885
valid loss clipped: 7.493620279412681
train loss: 0.04439006437582844
valid loss scaled: 0.04507342677914177
valid loss: 7.461320119658475
valid loss clipped: 7.461320119658475
epoch 1 done, EVALUATION ON FULL DEV:
valid loss scaled: 0.04480323708981578
valid loss: 7.416593515309324
valid loss clipped: 7.416387303935181
evaluation done
epoch 2
train loss: 0.04156288899572168
valid loss scaled: 0.04604067304159161
valid loss: 7.621432637256123
valid loss clipped: 7.621432637256123
train loss: 0.04100707919178113
valid loss scaled: 0.052927161282336356
valid loss: 8.761404249024789
valid loss clipped: 8.761404249024789
train loss: 0.04102466219659839
valid loss scaled: 0.0449000373685904
valid loss: 7.432613525687441
valid loss clipped: 7.432613525687441
train loss: 0.04056007982480108
valid loss scaled: 0.04401167129402493
valid loss: 7.285561051000317
valid loss clipped: 7.285561051000317
train loss: 0.04043435892664155
valid loss scaled: 0.0482961730545656
valid loss: 7.9948042408749025
valid loss clipped: 7.9948042408749025
train loss: 0.040083140595004084
valid loss scaled: 0.05474324524206743
valid loss: 9.062036144331996
valid loss clipped: 9.061846768141493
train loss: 0.040820699932255504
valid loss scaled: 0.0489683078886907
valid loss: 8.106066491790516
valid loss clipped: 8.106061770281157
train loss: 0.04011373971574097
valid loss scaled: 0.04683979676309513
valid loss: 7.75372067780672
valid loss clipped: 7.75372067780672
train loss: 0.040472179534386435
valid loss scaled: 0.04310859235241129
valid loss: 7.13606575431109
valid loss clipped: 7.13606575431109
train loss: 0.040441015502005524
valid loss scaled: 0.05050364392711698
valid loss: 8.360220506489952
valid loss clipped: 8.360220506489952
train loss: 0.03942231891194775
valid loss scaled: 0.0488088141837757
valid loss: 8.079661712550005
valid loss clipped: 8.079636855330948
train loss: 0.04015028625372404
valid loss scaled: 0.04516411367366714
valid loss: 7.476330067311157
valid loss clipped: 7.476182879344117
train loss: 0.04084930318361188
valid loss scaled: 0.04687766826105318
valid loss: 7.7599889055078215
valid loss clipped: 7.759971606690415
train loss: 0.04066393936852439
valid loss scaled: 0.04189752550714794
valid loss: 6.935592127831488
valid loss clipped: 6.935592127831488
train loss: 0.03997029890541135
valid loss scaled: 0.044517524088250444
valid loss: 7.369297241211859
valid loss clipped: 7.369000766181205
train loss: 0.03993799675444034
valid loss scaled: 0.04544862752523738
valid loss: 7.523431609784254
valid loss clipped: 7.523394388992367
train loss: 0.040077832674748126
valid loss scaled: 0.04967513279215436
valid loss: 8.223067645108816
valid loss clipped: 8.223067645108816
train loss: 0.03990507605068726
valid loss scaled: 0.05784423713065072
valid loss: 9.575362301216071
valid loss clipped: 9.575185345956458
train loss: 0.040589384846144605
valid loss scaled: 0.0496796067080015
valid loss: 8.223812362251003
valid loss clipped: 8.223812362251003
train loss: 0.04017873370107866
valid loss scaled: 0.046199073425154856
valid loss: 7.647655631709408
valid loss clipped: 7.647218458815767
train loss: 0.04044891785946012
valid loss scaled: 0.04150905843239381
valid loss: 6.871281648789468
valid loss clipped: 6.871281648789468
train loss: 0.03947431928291259
valid loss scaled: 0.05190902248718164
valid loss: 8.592867224099864
valid loss clipped: 8.592867224099864
train loss: 0.03990533801732083
valid loss scaled: 0.04276042828750439
valid loss: 7.078432226253377
valid loss clipped: 7.078432226253377
train loss: 0.040009430050001356
valid loss scaled: 0.04785424872916424
valid loss: 7.92164959531741
valid loss clipped: 7.92164959531741
train loss: 0.03973275025215787
valid loss scaled: 0.0463738630632847
valid loss: 7.676589189399943
valid loss clipped: 7.67643024694617
train loss: 0.038768686683078965
valid loss scaled: 0.04392041103001954
valid loss: 7.270451050508478
valid loss clipped: 7.270451050508478
train loss: 0.039348882183647814
valid loss scaled: 0.04542625701933914
valid loss: 7.519721720555645
valid loss clipped: 7.519588671593333
train loss: 0.03904904227063566
valid loss scaled: 0.04735228840997471
valid loss: 7.838551573747202
valid loss clipped: 7.8383747004264395
train loss: 0.03885343865845939
valid loss scaled: 0.04507136710901525
valid loss: 7.460976738161493
valid loss clipped: 7.460948397077222
train loss: 0.039375428725538786
valid loss scaled: 0.049669896283577444
valid loss: 8.222201176726562
valid loss clipped: 8.222057108632997
train loss: 0.03995204566972607
valid loss scaled: 0.048435733190966335
valid loss: 8.017900004946629
valid loss clipped: 8.017900004946629
train loss: 0.0398491290780488
valid loss scaled: 0.04720521208137839
valid loss: 7.814206857210747
valid loss clipped: 7.814180725425085
train loss: 0.039705720615770324
valid loss scaled: 0.04476360371587965
valid loss: 7.410029695071655
valid loss clipped: 7.410029695071655
train loss: 0.03922532949685088
valid loss scaled: 0.043110160856005976
valid loss: 7.136322356669672
valid loss clipped: 7.136322356669672
train loss: 0.03944114218638488
valid loss scaled: 0.05162181950303874
valid loss: 8.54532049172949
valid loss clipped: 8.545318087656034
train loss: 0.038248045063862056
valid loss scaled: 0.045729077826710006
valid loss: 7.5698496787836405
valid loss clipped: 7.5698496787836405
train loss: 0.038367600972177264
valid loss scaled: 0.045730969331967096
valid loss: 7.570170291346379
valid loss clipped: 7.570170291346379
train loss: 0.03882607528483637
valid loss scaled: 0.04907676118577929
valid loss: 8.124020217147429
valid loss clipped: 8.124020217147429
train loss: 0.03804710767993723
valid loss scaled: 0.043570846150311024
valid loss: 7.2125900868664665
valid loss clipped: 7.2125900868664665
epoch 2 done, EVALUATION ON FULL DEV:
valid loss scaled: 0.04305710652707711
valid loss: 7.127543370821652
valid loss clipped: 7.127414565514862
evaluation done
epoch 3
train loss: 0.03567875976729889
valid loss scaled: 0.0435919204035305
valid loss: 7.216075509096017
valid loss clipped: 7.216075509096017
train loss: 0.03612129883597582
valid loss scaled: 0.05422117357473846
valid loss: 8.975609535535733
valid loss clipped: 8.975609535535733
train loss: 0.03610084946473125
valid loss scaled: 0.04305326329624392
valid loss: 7.126904973342827
valid loss clipped: 7.126904973342827
train loss: 0.0354239810019526
valid loss scaled: 0.041925094025886615
valid loss: 6.940151276296263
valid loss clipped: 6.940151276296263
train loss: 0.0352127675835109
valid loss scaled: 0.0480450296753359
valid loss: 7.9532327819071025
valid loss clipped: 7.9532327819071025
train loss: 0.03504781431121313
valid loss scaled: 0.05108940815500588
valid loss: 8.457186233906421
valid loss clipped: 8.45717667639358
train loss: 0.034990980345574955
valid loss scaled: 0.05037606369447504
valid loss: 8.339104845822392
valid loss clipped: 8.339104845822392
train loss: 0.034820230333430666
valid loss scaled: 0.04548966136491771
valid loss: 7.530221839492527
valid loss clipped: 7.530221839492527
train loss: 0.034814962549952325
valid loss scaled: 0.044887413795706835
valid loss: 7.430525821074159
valid loss clipped: 7.430525821074159
train loss: 0.035407486527384235
valid loss scaled: 0.04845567947436388
valid loss: 8.021209744958515
valid loss clipped: 8.021177381420596

@@ -0,0 +1,181 @@
import copy
import random

import numpy as np
import torch
from fairseq.data.data_utils import collate_tokens
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm

EVAL_OFTEN = True
EVAL_EVERY = 10000
BATCH_SIZE = 1
model_type = 'large'  # base or large

roberta = torch.hub.load('pytorch/fairseq', f'roberta.{model_type}')
roberta.cuda()
device = 'cuda'
# LOAD DATA
train_in = [l.rstrip('\n') for l in open('../train/in.tsv', newline='\n').readlines()]  # shuffled below
dev_in = [l.rstrip('\n') for l in open('../dev-0/in.tsv', newline='\n').readlines()]  # shuffled below
train_year = [float(l.rstrip('\n')) for l in open('../train/expected.tsv', newline='\n').readlines()]
dev_year = [float(l.rstrip('\n')) for l in open('../dev-0/expected.tsv', newline='\n').readlines()]
dev_in_not_shuffled = copy.deepcopy(dev_in)  # keep original order for predictions
test_in = [l.rstrip('\n') for l in open('../test-A/in.tsv', newline='\n').readlines()]  # not shuffled

# SHUFFLE DATA
c = list(zip(train_in, train_year))
random.shuffle(c)
train_in, train_year = zip(*c)
c = list(zip(dev_in, dev_year))
random.shuffle(c)
dev_in, dev_year = zip(*c)

# SCALE DATA
scaler = MinMaxScaler()
train_year_scaled = scaler.fit_transform(np.array(train_year).reshape(-1, 1))
dev_year_scaled = scaler.transform(np.array(dev_year).reshape(-1, 1))
class RegressorHead(torch.nn.Module):
    def __init__(self):
        super(RegressorHead, self).__init__()
        in_dim = 768 if model_type == 'base' else 1024
        self.linear = torch.nn.Linear(in_dim, 1)
        self.m = torch.nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.linear(x)
        x = self.m(x)
        x = -self.m(-x + 1) + 1
        return x

def get_features_and_year(dataset_in, dataset_y):
    for i in tqdm(range(0, len(dataset_in), BATCH_SIZE)):
        batch_of_text = dataset_in[i:i + BATCH_SIZE]
        batch = collate_tokens([roberta.encode(p)[:512] for p in batch_of_text], pad_idx=1)
        features = roberta.extract_features(batch).mean(1)
        years = torch.FloatTensor(dataset_y[i:i + BATCH_SIZE]).to(device)
        yield features, years

def eval_dev(short=False):
    criterion_eval = torch.nn.MSELoss(reduction='sum')
    roberta.eval()
    regressor_head.eval()
    loss = 0.0
    loss_clipped = 0.0
    loss_scaled = 0.0
    if short:
        dataset_in = dev_in[:1000]
        dataset_years = dev_year_scaled[:1000]
    else:
        dataset_in = dev_in
        dataset_years = dev_year_scaled
    predictions_sum = 0
    for batch, year in tqdm(get_features_and_year(dataset_in, dataset_years)):
        predictions_sum += year.shape[0]
        x = regressor_head(batch.to(device))
        x_clipped = torch.clamp(x, 0.0, 1.0)
        original_x = torch.FloatTensor(scaler.inverse_transform(x.detach().cpu().numpy().reshape(1, -1)))
        original_x_clipped = torch.FloatTensor(scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1)))
        original_year = torch.FloatTensor(scaler.inverse_transform(year.detach().cpu().numpy().reshape(1, -1)))
        loss_scaled += criterion_eval(x, year).item()
        loss += criterion_eval(original_x, original_year).item()
        loss_clipped += criterion_eval(original_x_clipped, original_year).item()
    print('valid loss scaled: ' + str(np.sqrt(loss_scaled / predictions_sum)))
    print('valid loss: ' + str(np.sqrt(loss / predictions_sum)))
    print('valid loss clipped: ' + str(np.sqrt(loss_clipped / predictions_sum)))

def train_one_epoch():
    roberta.train()
    regressor_head.train()
    loss_value = 0.0
    iteration = 0
    for batch, year in get_features_and_year(train_in, train_year_scaled):
        iteration += 1
        roberta.zero_grad()
        regressor_head.zero_grad()
        predictions = regressor_head(batch.to(device))
        loss = criterion(predictions, year)
        loss_value += loss.item()
        loss.backward()
        optimizer.step()
        roberta.zero_grad()
        regressor_head.zero_grad()
        if EVAL_OFTEN and (iteration > 1) and (iteration % EVAL_EVERY == 1):
            print('train loss: ' + str(np.sqrt(loss_value / (EVAL_EVERY * BATCH_SIZE))))
            eval_dev(True)
            roberta.train()
            regressor_head.train()
            loss_value = 0.0

def predict(dataset='dev'):
    if dataset == 'dev':
        f_out_path = '../dev-0/out.tsv'
        dataset_in_not_shuffled = dev_in_not_shuffled
    elif dataset == 'test':
        f_out_path = '../test-A/out.tsv'
        dataset_in_not_shuffled = test_in
    roberta.eval()
    regressor_head.eval()
    f_out = open(f_out_path, 'w')
    for batch, year in tqdm(get_features_and_year(dataset_in_not_shuffled, dev_year_scaled)):
        x = regressor_head(batch)
        x_clipped = torch.clamp(x, 0.0, 1.0)
        original_x_clipped = scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1))
        for y in original_x_clipped[0]:
            f_out.write(str(y) + '\n')
    f_out.close()
regressor_head = RegressorHead().to(device)
optimizer = torch.optim.Adam(list(roberta.parameters()) + list(regressor_head.parameters()), lr=1e-6)
criterion = torch.nn.MSELoss(reduction='sum').to(device)

# resume from the epoch-0 checkpoint and continue training from epoch 1
roberta.load_state_dict(torch.load('checkpoints/roberta_to_regressor0.pt'))
regressor_head.load_state_dict(torch.load('checkpoints/regressor_head0.pt'))
for i in range(1, 100):
    print('epoch ' + str(i))
    train_one_epoch()
    print(f'epoch {i} done, EVALUATION ON FULL DEV:')
    eval_dev()
    print('evaluation done')
    predict('dev')
    predict('test')
    torch.save(roberta.state_dict(), 'checkpoints/roberta_to_regressor' + str(i) + '.pt')
    torch.save(regressor_head.state_dict(), 'checkpoints/regressor_head' + str(i) + '.pt')

roberta.load_state_dict(torch.load('checkpoints/roberta_to_regressor1.pt'))
regressor_head.load_state_dict(torch.load('checkpoints/regressor_head1.pt'))
predict('dev')
predict('test')

test-A/out.tsv (73196 lines changed, Normal file → Executable file)
File diff suppressed because it is too large.