RoBERTa large with a regression layer on top
parent 46e06b748e
commit 400b65c4f8

72530  dev-0/out.tsv  Normal file → Executable file
File diff suppressed because it is too large
179  roberta_large_regressor_head/predict.py  Normal file
@@ -0,0 +1,179 @@
import os
import torch
import random
import copy
from fairseq.models.roberta import RobertaModel, RobertaHubInterface
from fairseq import hub_utils
from fairseq.data.data_utils import collate_tokens
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import MinMaxScaler


EVAL_OFTEN = True
EVAL_EVERY = 10000
BATCH_SIZE = 1
model_type = 'large' # base or large


roberta = torch.hub.load('pytorch/fairseq', f'roberta.{model_type}')
roberta.cuda()
device = 'cuda'


# LOAD DATA
train_in = [l.rstrip('\n') for l in open('../train/in.tsv', newline='\n').readlines()] # shuffled
dev_in = [l.rstrip('\n') for l in open('../dev-0/in.tsv', newline='\n').readlines()] # shuffled

train_year = [float(l.rstrip('\n')) for l in open('../train/expected.tsv', newline='\n').readlines()]
dev_year = [float(l.rstrip('\n')) for l in open('../dev-0/expected.tsv', newline='\n').readlines()]

dev_in_not_shuffled = copy.deepcopy(dev_in) # not shuffled
test_in = [l.rstrip('\n') for l in open('../test-A/in.tsv', newline='\n').readlines()] # not shuffled

# SHUFFLE DATA
c = list(zip(train_in, train_year))
random.shuffle(c)
train_in, train_year = zip(*c)
c = list(zip(dev_in, dev_year))
random.shuffle(c)
dev_in, dev_year = zip(*c)

# SCALE DATA
scaler = MinMaxScaler()
train_year_scaled = scaler.fit_transform(np.array(train_year).reshape(-1, 1))
dev_year_scaled = scaler.transform(np.array(dev_year).reshape(-1, 1))


class RegressorHead(torch.nn.Module):
    def __init__(self):
        super(RegressorHead, self).__init__()
        in_dim = 768 if model_type == 'base' else 1024
        self.linear = torch.nn.Linear(in_dim, 1)
        self.m = torch.nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.linear(x)
        x = self.m(x)
        x = -self.m(-x + 1) + 1
        return x


def get_features_and_year(dataset_in, dataset_y):
    for i in tqdm(range(0, len(dataset_in), BATCH_SIZE)):
        batch_of_text = dataset_in[i:i+BATCH_SIZE]

        batch = collate_tokens([roberta.encode(p)[:512] for p in batch_of_text], pad_idx=1)
        features = roberta.extract_features(batch).mean(1)
        years = torch.FloatTensor(dataset_y[i:i+BATCH_SIZE]).to(device)

        yield features, years


def eval_dev(short=False):
    criterion_eval = torch.nn.MSELoss(reduction='sum')
    roberta.eval()
    regressor_head.eval()

    loss = 0.0
    loss_clipped = 0.0
    loss_scaled = 0.0

    if short:
        dataset_in = dev_in[:1000]
        dataset_years = dev_year_scaled[:1000]
    else:
        dataset_in = dev_in
        dataset_years = dev_year_scaled

    predictions_sum = 0
    for batch, year in tqdm(get_features_and_year(dataset_in, dataset_years)):

        predictions_sum += year.shape[0]
        x = regressor_head(batch.to(device))
        x_clipped = torch.clamp(x, 0.0, 1.0)

        original_x = torch.FloatTensor(scaler.inverse_transform(x.detach().cpu().numpy().reshape(1, -1)))
        original_x_clipped = torch.FloatTensor(scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1)))
        original_year = torch.FloatTensor(scaler.inverse_transform(year.detach().cpu().numpy().reshape(1, -1)))

        loss_scaled += criterion_eval(x, year).item()
        loss += criterion_eval(original_x, original_year).item()
        loss_clipped += criterion_eval(original_x_clipped, original_year).item()

    print('valid loss scaled: ' + str(np.sqrt(loss_scaled/predictions_sum)))
    print('valid loss: ' + str(np.sqrt(loss/predictions_sum)))
    print('valid loss clipped: ' + str(np.sqrt(loss_clipped/predictions_sum)))


def train_one_epoch():
    roberta.train()
    regressor_head.train()
    loss_value = 0.0
    iteration = 0
    for batch, year in get_features_and_year(train_in, train_year_scaled):
        iteration += 1
        roberta.zero_grad()
        regressor_head.zero_grad()

        predictions = regressor_head(batch.to(device))

        loss = criterion(predictions, year)
        loss_value += loss.item()
        loss.backward()
        optimizer.step()

        roberta.zero_grad()
        regressor_head.zero_grad()

        if EVAL_OFTEN and (iteration > 1) and (iteration % EVAL_EVERY == 1):
            print('train loss: ' + str(np.sqrt(loss_value / (EVAL_EVERY*BATCH_SIZE))))
            eval_dev(True)
            roberta.train()
            regressor_head.train()
            loss_value = 0.0


def predict(dataset='dev'):
    if dataset == 'dev':
        f_out_path = '../dev-0/out.tsv'
        dataset_in_not_shuffled = dev_in_not_shuffled
    elif dataset == 'test':
        f_out_path = '../test-A/out.tsv'
        dataset_in_not_shuffled = test_in
    roberta.eval()
    regressor_head.eval()
    f_out = open(f_out_path, 'w')
    for batch, year in tqdm(get_features_and_year(dataset_in_not_shuffled, dev_year_scaled)):
        x = regressor_head(batch)
        x_clipped = torch.clamp(x, 0.0, 1.0)
        original_x_clipped = scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1))
        for y in original_x_clipped[0]:
            f_out.write(str(y) + '\n')
    f_out.close()


regressor_head = RegressorHead().to(device)

optimizer = torch.optim.Adam(list(roberta.parameters()) + list(regressor_head.parameters()), lr=1e-6)
criterion = torch.nn.MSELoss(reduction='sum').to(device)


# for i in range(100):
#     print('epoch ' + str(i))
#     train_one_epoch()
#
#     print(f'epoch {i} done, EVALUATION ON FULL DEV:')
#     eval_dev()
#     print('evaluation done')
#     predict('dev')
#     predict('test')
#
#     torch.save(roberta.state_dict(), 'checkpoints/roberta_to_regressor' + str(i) + '.pt')
#     torch.save(regressor_head.state_dict(), 'checkpoints/regressor_head' + str(i) + '.pt')


roberta.load_state_dict(torch.load('checkpoints/roberta_to_regressor2.pt'))
regressor_head.load_state_dict(torch.load('checkpoints/regressor_head2.pt'))
predict('dev')
predict('test')
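The double LeakyReLU in RegressorHead.forward acts as a soft clamp onto the scaled target range: inside [0, 1] the head passes the linear output through unchanged, while values below 0 or above 1 are only damped by the 0.1 slope instead of being cut off, so gradients keep flowing during training. A minimal standalone sketch of that piecewise behaviour (illustrative only, not part of the committed files):

import torch

leaky = torch.nn.LeakyReLU(0.1)

def soft_clamp(x):
    # same composition as RegressorHead.forward after the linear layer
    y = leaky(x)
    return -leaky(-y + 1) + 1

# identity inside [0, 1], slope 0.1 outside:
print(soft_clamp(torch.tensor([-1.0, 0.5, 2.0])))  # tensor([-0.1000,  0.5000,  1.1000])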
236  roberta_large_regressor_head/train.logs  Normal file
@@ -0,0 +1,236 @@
epoch 0
|
||||
train loss: 0.10178063126471196
|
||||
valid loss scaled: 0.08694887030887369
|
||||
valid loss: 14.393254678046581
|
||||
valid loss clipped: 14.393254678046581
|
||||
train loss: 0.08234276634246003
|
||||
valid loss scaled: 0.0941885423619173
|
||||
valid loss: 15.591687502080195
|
||||
valid loss clipped: 15.591687502080195
|
||||
train loss: 0.07840566897624576
|
||||
valid loss scaled: 0.10380323876951889
|
||||
valid loss: 17.18327815406788
|
||||
valid loss clipped: 17.18327815406788
|
||||
train loss: 0.07422883822886046
|
||||
valid loss scaled: 0.10199038650759494
|
||||
valid loss: 16.883180520434045
|
||||
valid loss clipped: 16.883180520434045
|
||||
train loss: 0.0715406045230223
|
||||
valid loss scaled: 0.08943446505286708
|
||||
valid loss: 14.8047131664572
|
||||
valid loss clipped: 14.8047131664572
|
||||
train loss: 0.07021876769740436
|
||||
valid loss scaled: 0.0936252853835668
|
||||
valid loss: 15.4984476046049
|
||||
valid loss clipped: 15.498044701389192
|
||||
train loss: 0.0698000478439966
|
||||
valid loss scaled: 0.0732450713218563
|
||||
valid loss: 12.124766595640176
|
||||
valid loss clipped: 12.1245106015308
|
||||
train loss: 0.06787076175194112
|
||||
valid loss scaled: 0.08167600706479376
|
||||
valid loss: 13.520399589246514
|
||||
valid loss clipped: 13.520399589246514
|
||||
train loss: 0.06687028820999129
|
||||
valid loss scaled: 0.07860967955406604
|
||||
valid loss: 13.012812666808495
|
||||
valid loss clipped: 13.012812666808495
|
||||
train loss: 0.06485297715840126
|
||||
valid loss scaled: 0.06548883445499123
|
||||
valid loss: 10.840826069687282
|
||||
valid loss clipped: 10.840826069687282
|
||||
train loss: 0.06536078446808377
|
||||
valid loss scaled: 0.06285778998197525
|
||||
valid loss: 10.405289848890389
|
||||
valid loss clipped: 10.402398559859241
|
||||
train loss: 0.06282014927278873
|
||||
valid loss scaled: 0.0596510515677964
|
||||
valid loss: 9.874458507430498
|
||||
valid loss clipped: 9.87346211871712
|
||||
train loss: 0.06211401904304488
|
||||
valid loss scaled: 0.08676563866997164
|
||||
valid loss: 14.362925503357548
|
||||
valid loss clipped: 14.362736109627358
|
||||
train loss: 0.06245454523502525
|
||||
valid loss scaled: 0.061674062543468834
|
||||
valid loss: 10.209341911143197
|
||||
valid loss clipped: 10.209341911143197
|
||||
train loss: 0.06125578036664975
|
||||
valid loss scaled: 0.07651956140507003
|
||||
valid loss: 12.666819321757812
|
||||
valid loss clipped: 12.665205579153804
|
||||
train loss: 0.059952159220873956
|
||||
valid loss scaled: 0.060298738548431624
|
||||
valid loss: 9.981674467979698
|
||||
valid loss clipped: 9.981674467979698
|
||||
train loss: 0.058864656505384534
|
||||
valid loss scaled: 0.06224854018969625
|
||||
valid loss: 10.304438001377102
|
||||
valid loss clipped: 10.304438001377102
|
||||
train loss: 0.06033644216377098
|
||||
valid loss scaled: 0.06860725018131864
|
||||
valid loss: 11.357033358430243
|
||||
valid loss clipped: 11.355784159072005
|
||||
train loss: 0.05914900476107049
|
||||
valid loss scaled: 0.06558627462957999
|
||||
valid loss: 10.856954814387644
|
||||
valid loss clipped: 10.855340320357298
|
||||
train loss: 0.05801354450611692
|
||||
valid loss scaled: 0.05499361709851459
|
||||
valid loss: 9.103479951662178
|
||||
valid loss clipped: 9.1033016162492
|
||||
train loss: 0.05722296967684818
|
||||
valid loss scaled: 0.057276662331433935
|
||||
valid loss: 9.481407818900754
|
||||
valid loss clipped: 9.481123608784491
|
||||
train loss: 0.05615422017549816
|
||||
valid loss scaled: 0.058946735056267775
|
||||
valid loss: 9.757867643250142
|
||||
valid loss clipped: 9.757867643250142
|
||||
train loss: 0.0568488666286758
|
||||
valid loss scaled: 0.05493375345527444
|
||||
valid loss: 9.093570558704585
|
||||
valid loss clipped: 9.093570558704585
|
||||
train loss: 0.056790975754785165
|
||||
valid loss scaled: 0.05718280344549578
|
||||
valid loss: 9.465868729560718
|
||||
valid loss clipped: 9.465666290138364
|
||||
train loss: 0.056136715463394736
|
||||
valid loss scaled: 0.053061117528704765
|
||||
valid loss: 8.783576806818123
|
||||
valid loss clipped: 8.783576806818123
|
||||
train loss: 0.0553076835671036
|
||||
valid loss scaled: 0.05756376005144751
|
||||
valid loss: 9.528932262971411
|
||||
valid loss clipped: 9.528929186851473
|
||||
train loss: 0.05608131737157289
|
||||
valid loss scaled: 0.05928617719392673
|
||||
valid loss: 9.814053006356623
|
||||
valid loss clipped: 9.813086909431942
|
||||
train loss: 0.054209274725265544
|
||||
valid loss scaled: 0.059807781220226314
|
||||
valid loss: 9.900402019838845
|
||||
valid loss clipped: 9.899803739726012
|
||||
train loss: 0.05412653998906017
|
||||
valid loss scaled: 0.05289933204229587
|
||||
valid loss: 8.75679344111035
|
||||
valid loss clipped: 8.75679344111035
|
||||
train loss: 0.05518407843528995
|
||||
valid loss scaled: 0.05688392864725153
|
||||
valid loss: 9.416393057892193
|
||||
valid loss clipped: 9.416393057892193
|
||||
train loss: 0.054482056935617085
|
||||
valid loss scaled: 0.05216977017299728
|
||||
valid loss: 8.636023503699716
|
||||
valid loss clipped: 8.636023503699716
|
||||
train loss: 0.05399237402441963
|
||||
valid loss scaled: 0.05409103409701342
|
||||
valid loss: 8.954068173354056
|
||||
valid loss clipped: 8.954068173354056
|
||||
train loss: 0.05278416502551686
|
||||
valid loss scaled: 0.05505897657177228
|
||||
valid loss: 9.114298554400836
|
||||
valid loss clipped: 9.114128444401343
|
||||
train loss: 0.05283459274043561
|
||||
valid loss scaled: 0.05330468837007633
|
||||
valid loss: 8.823897722342918
|
||||
valid loss clipped: 8.823469673021748
|
||||
train loss: 0.05349978882139668
|
||||
valid loss scaled: 0.05340672519610164
|
||||
valid loss: 8.840791683139289
|
||||
valid loss clipped: 8.840791683139289
|
||||
train loss: 0.05244616329396176
|
||||
valid loss scaled: 0.05313163808650701
|
||||
valid loss: 8.795251288181513
|
||||
valid loss clipped: 8.795251288181513
|
||||
train loss: 0.05458822252512801
|
||||
valid loss scaled: 0.05058266187593807
|
||||
valid loss: 8.373303749237403
|
||||
valid loss clipped: 8.372144978003309
|
||||
train loss: 0.052178348999718266
|
||||
valid loss scaled: 0.05477028984676988
|
||||
valid loss: 9.066509788893315
|
||||
valid loss clipped: 9.066509788893315
|
||||
train loss: 0.05264708004419216
|
||||
valid loss scaled: 0.05061255628492586
|
||||
valid loss: 8.378252333491037
|
||||
valid loss clipped: 8.378252333491037
|
||||
epoch 0 done, EVALUATION ON FULL DEV:
|
||||
valid loss scaled: 0.05304482890568026
|
||||
valid loss: 8.780882018378279
|
||||
valid loss clipped: 8.780282676016194
|
||||
evaluation done
|
||||
epoch 1
|
||||
train loss: 0.05193906537483793
|
||||
valid loss scaled: 0.05525216534943099
|
||||
valid loss: 9.146276710629154
|
||||
valid loss clipped: 9.146276710629154
|
||||
train loss: 0.05052301682721696
|
||||
valid loss scaled: 0.05013342859224921
|
||||
valid loss: 8.298936329028248
|
||||
valid loss clipped: 8.298936329028248
|
||||
train loss: 0.05098356105392436
|
||||
valid loss scaled: 0.05003979649887808
|
||||
valid loss: 8.283437402089449
|
||||
valid loss clipped: 8.283437402089449
|
||||
train loss: 0.050088782917609716
|
||||
valid loss scaled: 0.049286247248487895
|
||||
valid loss: 8.15869774398909
|
||||
valid loss clipped: 8.15869774398909
|
||||
train loss: 0.04976562104718302
|
||||
valid loss scaled: 0.049223300781573644
|
||||
valid loss: 8.148276889380307
|
||||
valid loss clipped: 8.148276889380307
|
||||
train loss: 0.04991738679144195
|
||||
valid loss scaled: 0.057702729077406346
|
||||
valid loss: 9.551934437355772
|
||||
valid loss clipped: 9.550634254813586
|
||||
train loss: 0.049231458019302426
|
||||
valid loss scaled: 0.04882268479584859
|
||||
valid loss: 8.081957979949655
|
||||
valid loss clipped: 8.081957979949655
|
||||
train loss: 0.04899937393720371
|
||||
valid loss scaled: 0.05770907825550645
|
||||
valid loss: 9.552989047406287
|
||||
valid loss clipped: 9.552741958522342
|
||||
train loss: 0.04908777283465607
|
||||
valid loss scaled: 0.05134136848123586
|
||||
valid loss: 8.498893804182858
|
||||
valid loss clipped: 8.49871165029239
|
||||
train loss: 0.04764838316443966
|
||||
valid loss scaled: 0.05261156122562749
|
||||
valid loss: 8.709157880799538
|
||||
valid loss clipped: 8.709138480005972
|
||||
train loss: 0.048582658014159445
|
||||
valid loss scaled: 0.058413308832729716
|
||||
valid loss: 9.669566298033944
|
||||
valid loss clipped: 9.668254240676646
|
||||
train loss: 0.04743617173694683
|
||||
valid loss scaled: 0.051687093210397744
|
||||
valid loss: 8.556127664837945
|
||||
valid loss clipped: 8.556127664837945
|
||||
train loss: 0.04656181514429912
|
||||
valid loss scaled: 0.0507906026516647
|
||||
valid loss: 8.407722978198628
|
||||
valid loss clipped: 8.407722978198628
|
||||
train loss: 0.04762343910575039
|
||||
valid loss scaled: 0.04810272429882445
|
||||
valid loss: 7.962782613217737
|
||||
valid loss clipped: 7.962782613217737
|
||||
train loss: 0.048021365618576804
|
||||
valid loss scaled: 0.051694526258934385
|
||||
valid loss: 8.55735450556255
|
||||
valid loss clipped: 8.55735450556255
|
||||
train loss: 0.04710475399633604
|
||||
valid loss scaled: 0.04621696889011772
|
||||
valid loss: 7.650616536714813
|
||||
valid loss clipped: 7.650616536714813
|
||||
train loss: 0.04589946658188546
|
||||
valid loss scaled: 0.04597697057310632
|
||||
valid loss: 7.610890791241026
|
||||
valid loss clipped: 7.610890791241026
|
||||
train loss: 0.047086027235787566
|
||||
valid loss scaled: 0.04816656697145074
|
||||
valid loss: 7.973350046367094
|
||||
valid loss clipped: 7.973350046367094
|
||||
train loss: 0.04701215419012016
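Each evaluation block above logs three RMSE values: 'valid loss scaled' is computed on the MinMaxScaler-normalised targets in [0, 1], while 'valid loss' and 'valid loss clipped' are the same errors mapped back to years through scaler.inverse_transform. Since MinMaxScaler is a linear map, the two differ only by the span of the training years; a back-of-the-envelope check against the first evaluation of epoch 0 (illustrative arithmetic, not part of the repository):

rmse_scaled = 0.08694887030887369   # 'valid loss scaled', first evaluation above
rmse_years = 14.393254678046581     # 'valid loss' from the same evaluation
print(rmse_years / rmse_scaled)     # ~165.5, the implied min-max span of the training years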
365  roberta_large_regressor_head/train.logs2  Normal file
@@ -0,0 +1,365 @@
epoch 1
|
||||
train loss: 0.04859849993016349
|
||||
valid loss scaled: 0.049774259863821814
|
||||
valid loss: 8.239480139682152
|
||||
valid loss clipped: 8.239480139682152
|
||||
train loss: 0.048566044032315264
|
||||
valid loss scaled: 0.05388942821210208
|
||||
valid loss: 8.920693268651688
|
||||
valid loss clipped: 8.920693268651688
|
||||
train loss: 0.048645840325942186
|
||||
valid loss scaled: 0.05395363235295337
|
||||
valid loss: 8.931319168623325
|
||||
valid loss clipped: 8.931319168623325
|
||||
train loss: 0.04814379534040526
|
||||
valid loss scaled: 0.0505393534476585
|
||||
valid loss: 8.366129820820378
|
||||
valid loss clipped: 8.366112817712773
|
||||
train loss: 0.04762666338418628
|
||||
valid loss scaled: 0.0559599088881008
|
||||
valid loss: 9.263433639650358
|
||||
valid loss clipped: 9.263433639650358
|
||||
train loss: 0.047656307189366594
|
||||
valid loss scaled: 0.05023614907590337
|
||||
valid loss: 8.315944918367261
|
||||
valid loss clipped: 8.315944918367261
|
||||
train loss: 0.04783966918808628
|
||||
valid loss scaled: 0.05002744041130589
|
||||
valid loss: 8.281391949540064
|
||||
valid loss clipped: 8.281391949540064
|
||||
train loss: 0.0470857543381118
|
||||
valid loss scaled: 0.05007423239511511
|
||||
valid loss: 8.289137880244128
|
||||
valid loss clipped: 8.289137880244128
|
||||
train loss: 0.04748013890490943
|
||||
valid loss scaled: 0.048425306807186
|
||||
valid loss: 8.01617955284889
|
||||
valid loss clipped: 8.01617955284889
|
||||
train loss: 0.04756126793491012
|
||||
valid loss scaled: 0.05838172127632288
|
||||
valid loss: 9.664334319586205
|
||||
valid loss clipped: 9.664334319586205
|
||||
train loss: 0.04653691036105432
|
||||
valid loss scaled: 0.04918750869152843
|
||||
valid loss: 8.14234734821747
|
||||
valid loss clipped: 8.14225860100907
|
||||
train loss: 0.047854342213726093
|
||||
valid loss scaled: 0.04825916678037553
|
||||
valid loss: 7.988678617267059
|
||||
valid loss clipped: 7.988678617267059
|
||||
train loss: 0.04773736108071371
|
||||
valid loss scaled: 0.04693122838927694
|
||||
valid loss: 7.768854512974581
|
||||
valid loss clipped: 7.768764906393828
|
||||
train loss: 0.04762865603151333
|
||||
valid loss scaled: 0.05211583033031525
|
||||
valid loss: 8.627096447506986
|
||||
valid loss clipped: 8.627096447506986
|
||||
train loss: 0.04704306549118133
|
||||
valid loss scaled: 0.048404824273360625
|
||||
valid loss: 8.012792317972876
|
||||
valid loss clipped: 8.012792317972876
|
||||
train loss: 0.047052753271308353
|
||||
valid loss scaled: 0.047940905819710194
|
||||
valid loss: 7.935991222098737
|
||||
valid loss clipped: 7.935991222098737
|
||||
train loss: 0.047227875012458086
|
||||
valid loss scaled: 0.04819175795767938
|
||||
valid loss: 7.977517309701107
|
||||
valid loss clipped: 7.977427346370881
|
||||
train loss: 0.04710633027268819
|
||||
valid loss scaled: 0.06383742791518349
|
||||
valid loss: 10.567450933423165
|
||||
valid loss clipped: 10.567450933423165
|
||||
train loss: 0.04722847116132579
|
||||
valid loss scaled: 0.050155074146505475
|
||||
valid loss: 8.302519880598078
|
||||
valid loss clipped: 8.302519880598078
|
||||
train loss: 0.04704662696396088
|
||||
valid loss scaled: 0.049731758847849555
|
||||
valid loss: 8.232442667326458
|
||||
valid loss clipped: 8.231381949146131
|
||||
train loss: 0.04640111281869779
|
||||
valid loss scaled: 0.04439513521075078
|
||||
valid loss: 7.349034015457206
|
||||
valid loss clipped: 7.349034015457206
|
||||
train loss: 0.04608356887790085
|
||||
valid loss scaled: 0.04921432936453068
|
||||
valid loss: 8.146788631177284
|
||||
valid loss clipped: 8.146788631177284
|
||||
train loss: 0.04667254772723141
|
||||
valid loss scaled: 0.04490475325651902
|
||||
valid loss: 7.433397045419523
|
||||
valid loss clipped: 7.433397045419523
|
||||
train loss: 0.04732138091412676
|
||||
valid loss scaled: 0.050358595447874285
|
||||
valid loss: 8.336210587795785
|
||||
valid loss clipped: 8.336210587795785
|
||||
train loss: 0.046673220858461255
|
||||
valid loss scaled: 0.05013948923902952
|
||||
valid loss: 8.29994005436984
|
||||
valid loss clipped: 8.29994005436984
|
||||
train loss: 0.04553470963703993
|
||||
valid loss scaled: 0.04682243262330682
|
||||
valid loss: 7.750845583038412
|
||||
valid loss clipped: 7.750845583038412
|
||||
train loss: 0.046360710414935315
|
||||
valid loss scaled: 0.045567268129456086
|
||||
valid loss: 7.543068199989672
|
||||
valid loss clipped: 7.543068199989672
|
||||
train loss: 0.04552500833500083
|
||||
valid loss scaled: 0.04640085875333069
|
||||
valid loss: 7.681058866001286
|
||||
valid loss clipped: 7.681058866001286
|
||||
train loss: 0.0452476728574542
|
||||
valid loss scaled: 0.046615049866360565
|
||||
valid loss: 7.716513190971443
|
||||
valid loss clipped: 7.716513190971443
|
||||
train loss: 0.045669575255858894
|
||||
valid loss scaled: 0.04750326795304539
|
||||
valid loss: 7.8635460404556
|
||||
valid loss clipped: 7.8635460404556
|
||||
train loss: 0.04647987461932483
|
||||
valid loss scaled: 0.04860825732568691
|
||||
valid loss: 8.046460836562115
|
||||
valid loss clipped: 8.046460836562115
|
||||
train loss: 0.046793957499830896
|
||||
valid loss scaled: 0.04575770256354883
|
||||
valid loss: 7.574592460104707
|
||||
valid loss clipped: 7.574592460104707
|
||||
train loss: 0.04585571441358016
|
||||
valid loss scaled: 0.04501184464469181
|
||||
valid loss: 7.451126169626989
|
||||
valid loss clipped: 7.451126169626989
|
||||
train loss: 0.04610028065682122
|
||||
valid loss scaled: 0.046051780212043245
|
||||
valid loss: 7.623273458385178
|
||||
valid loss clipped: 7.623083864668157
|
||||
train loss: 0.046177845466100334
|
||||
valid loss scaled: 0.06025220972281953
|
||||
valid loss: 9.973968222380467
|
||||
valid loss clipped: 9.973968222380467
|
||||
train loss: 0.044866837962671394
|
||||
valid loss scaled: 0.04543754842407298
|
||||
valid loss: 7.5215926794190535
|
||||
valid loss clipped: 7.5215926794190535
|
||||
train loss: 0.044414652746250796
|
||||
valid loss scaled: 0.044867695763598454
|
||||
valid loss: 7.4272631776866636
|
||||
valid loss clipped: 7.427128339884669
|
||||
train loss: 0.04534780122814873
|
||||
valid loss scaled: 0.045269020539499126
|
||||
valid loss: 7.493695395829885
|
||||
valid loss clipped: 7.493620279412681
|
||||
train loss: 0.04439006437582844
|
||||
valid loss scaled: 0.04507342677914177
|
||||
valid loss: 7.461320119658475
|
||||
valid loss clipped: 7.461320119658475
|
||||
epoch 1 done, EVALUATION ON FULL DEV:
|
||||
valid loss scaled: 0.04480323708981578
|
||||
valid loss: 7.416593515309324
|
||||
valid loss clipped: 7.416387303935181
|
||||
evaluation done
|
||||
epoch 2
|
||||
train loss: 0.04156288899572168
|
||||
valid loss scaled: 0.04604067304159161
|
||||
valid loss: 7.621432637256123
|
||||
valid loss clipped: 7.621432637256123
|
||||
train loss: 0.04100707919178113
|
||||
valid loss scaled: 0.052927161282336356
|
||||
valid loss: 8.761404249024789
|
||||
valid loss clipped: 8.761404249024789
|
||||
train loss: 0.04102466219659839
|
||||
valid loss scaled: 0.0449000373685904
|
||||
valid loss: 7.432613525687441
|
||||
valid loss clipped: 7.432613525687441
|
||||
train loss: 0.04056007982480108
|
||||
valid loss scaled: 0.04401167129402493
|
||||
valid loss: 7.285561051000317
|
||||
valid loss clipped: 7.285561051000317
|
||||
train loss: 0.04043435892664155
|
||||
valid loss scaled: 0.0482961730545656
|
||||
valid loss: 7.9948042408749025
|
||||
valid loss clipped: 7.9948042408749025
|
||||
train loss: 0.040083140595004084
|
||||
valid loss scaled: 0.05474324524206743
|
||||
valid loss: 9.062036144331996
|
||||
valid loss clipped: 9.061846768141493
|
||||
train loss: 0.040820699932255504
|
||||
valid loss scaled: 0.0489683078886907
|
||||
valid loss: 8.106066491790516
|
||||
valid loss clipped: 8.106061770281157
|
||||
train loss: 0.04011373971574097
|
||||
valid loss scaled: 0.04683979676309513
|
||||
valid loss: 7.75372067780672
|
||||
valid loss clipped: 7.75372067780672
|
||||
train loss: 0.040472179534386435
|
||||
valid loss scaled: 0.04310859235241129
|
||||
valid loss: 7.13606575431109
|
||||
valid loss clipped: 7.13606575431109
|
||||
train loss: 0.040441015502005524
|
||||
valid loss scaled: 0.05050364392711698
|
||||
valid loss: 8.360220506489952
|
||||
valid loss clipped: 8.360220506489952
|
||||
train loss: 0.03942231891194775
|
||||
valid loss scaled: 0.0488088141837757
|
||||
valid loss: 8.079661712550005
|
||||
valid loss clipped: 8.079636855330948
|
||||
train loss: 0.04015028625372404
|
||||
valid loss scaled: 0.04516411367366714
|
||||
valid loss: 7.476330067311157
|
||||
valid loss clipped: 7.476182879344117
|
||||
train loss: 0.04084930318361188
|
||||
valid loss scaled: 0.04687766826105318
|
||||
valid loss: 7.7599889055078215
|
||||
valid loss clipped: 7.759971606690415
|
||||
train loss: 0.04066393936852439
|
||||
valid loss scaled: 0.04189752550714794
|
||||
valid loss: 6.935592127831488
|
||||
valid loss clipped: 6.935592127831488
|
||||
train loss: 0.03997029890541135
|
||||
valid loss scaled: 0.044517524088250444
|
||||
valid loss: 7.369297241211859
|
||||
valid loss clipped: 7.369000766181205
|
||||
train loss: 0.03993799675444034
|
||||
valid loss scaled: 0.04544862752523738
|
||||
valid loss: 7.523431609784254
|
||||
valid loss clipped: 7.523394388992367
|
||||
train loss: 0.040077832674748126
|
||||
valid loss scaled: 0.04967513279215436
|
||||
valid loss: 8.223067645108816
|
||||
valid loss clipped: 8.223067645108816
|
||||
train loss: 0.03990507605068726
|
||||
valid loss scaled: 0.05784423713065072
|
||||
valid loss: 9.575362301216071
|
||||
valid loss clipped: 9.575185345956458
|
||||
train loss: 0.040589384846144605
|
||||
valid loss scaled: 0.0496796067080015
|
||||
valid loss: 8.223812362251003
|
||||
valid loss clipped: 8.223812362251003
|
||||
train loss: 0.04017873370107866
|
||||
valid loss scaled: 0.046199073425154856
|
||||
valid loss: 7.647655631709408
|
||||
valid loss clipped: 7.647218458815767
|
||||
train loss: 0.04044891785946012
|
||||
valid loss scaled: 0.04150905843239381
|
||||
valid loss: 6.871281648789468
|
||||
valid loss clipped: 6.871281648789468
|
||||
train loss: 0.03947431928291259
|
||||
valid loss scaled: 0.05190902248718164
|
||||
valid loss: 8.592867224099864
|
||||
valid loss clipped: 8.592867224099864
|
||||
train loss: 0.03990533801732083
|
||||
valid loss scaled: 0.04276042828750439
|
||||
valid loss: 7.078432226253377
|
||||
valid loss clipped: 7.078432226253377
|
||||
train loss: 0.040009430050001356
|
||||
valid loss scaled: 0.04785424872916424
|
||||
valid loss: 7.92164959531741
|
||||
valid loss clipped: 7.92164959531741
|
||||
train loss: 0.03973275025215787
|
||||
valid loss scaled: 0.0463738630632847
|
||||
valid loss: 7.676589189399943
|
||||
valid loss clipped: 7.67643024694617
|
||||
train loss: 0.038768686683078965
|
||||
valid loss scaled: 0.04392041103001954
|
||||
valid loss: 7.270451050508478
|
||||
valid loss clipped: 7.270451050508478
|
||||
train loss: 0.039348882183647814
|
||||
valid loss scaled: 0.04542625701933914
|
||||
valid loss: 7.519721720555645
|
||||
valid loss clipped: 7.519588671593333
|
||||
train loss: 0.03904904227063566
|
||||
valid loss scaled: 0.04735228840997471
|
||||
valid loss: 7.838551573747202
|
||||
valid loss clipped: 7.8383747004264395
|
||||
train loss: 0.03885343865845939
|
||||
valid loss scaled: 0.04507136710901525
|
||||
valid loss: 7.460976738161493
|
||||
valid loss clipped: 7.460948397077222
|
||||
train loss: 0.039375428725538786
|
||||
valid loss scaled: 0.049669896283577444
|
||||
valid loss: 8.222201176726562
|
||||
valid loss clipped: 8.222057108632997
|
||||
train loss: 0.03995204566972607
|
||||
valid loss scaled: 0.048435733190966335
|
||||
valid loss: 8.017900004946629
|
||||
valid loss clipped: 8.017900004946629
|
||||
train loss: 0.0398491290780488
|
||||
valid loss scaled: 0.04720521208137839
|
||||
valid loss: 7.814206857210747
|
||||
valid loss clipped: 7.814180725425085
|
||||
train loss: 0.039705720615770324
|
||||
valid loss scaled: 0.04476360371587965
|
||||
valid loss: 7.410029695071655
|
||||
valid loss clipped: 7.410029695071655
|
||||
train loss: 0.03922532949685088
|
||||
valid loss scaled: 0.043110160856005976
|
||||
valid loss: 7.136322356669672
|
||||
valid loss clipped: 7.136322356669672
|
||||
train loss: 0.03944114218638488
|
||||
valid loss scaled: 0.05162181950303874
|
||||
valid loss: 8.54532049172949
|
||||
valid loss clipped: 8.545318087656034
|
||||
train loss: 0.038248045063862056
|
||||
valid loss scaled: 0.045729077826710006
|
||||
valid loss: 7.5698496787836405
|
||||
valid loss clipped: 7.5698496787836405
|
||||
train loss: 0.038367600972177264
|
||||
valid loss scaled: 0.045730969331967096
|
||||
valid loss: 7.570170291346379
|
||||
valid loss clipped: 7.570170291346379
|
||||
train loss: 0.03882607528483637
|
||||
valid loss scaled: 0.04907676118577929
|
||||
valid loss: 8.124020217147429
|
||||
valid loss clipped: 8.124020217147429
|
||||
train loss: 0.03804710767993723
|
||||
valid loss scaled: 0.043570846150311024
|
||||
valid loss: 7.2125900868664665
|
||||
valid loss clipped: 7.2125900868664665
|
||||
epoch 2 done, EVALUATION ON FULL DEV:
|
||||
valid loss scaled: 0.04305710652707711
|
||||
valid loss: 7.127543370821652
|
||||
valid loss clipped: 7.127414565514862
|
||||
evaluation done
|
||||
epoch 3
|
||||
train loss: 0.03567875976729889
|
||||
valid loss scaled: 0.0435919204035305
|
||||
valid loss: 7.216075509096017
|
||||
valid loss clipped: 7.216075509096017
|
||||
train loss: 0.03612129883597582
|
||||
valid loss scaled: 0.05422117357473846
|
||||
valid loss: 8.975609535535733
|
||||
valid loss clipped: 8.975609535535733
|
||||
train loss: 0.03610084946473125
|
||||
valid loss scaled: 0.04305326329624392
|
||||
valid loss: 7.126904973342827
|
||||
valid loss clipped: 7.126904973342827
|
||||
train loss: 0.0354239810019526
|
||||
valid loss scaled: 0.041925094025886615
|
||||
valid loss: 6.940151276296263
|
||||
valid loss clipped: 6.940151276296263
|
||||
train loss: 0.0352127675835109
|
||||
valid loss scaled: 0.0480450296753359
|
||||
valid loss: 7.9532327819071025
|
||||
valid loss clipped: 7.9532327819071025
|
||||
train loss: 0.03504781431121313
|
||||
valid loss scaled: 0.05108940815500588
|
||||
valid loss: 8.457186233906421
|
||||
valid loss clipped: 8.45717667639358
|
||||
train loss: 0.034990980345574955
|
||||
valid loss scaled: 0.05037606369447504
|
||||
valid loss: 8.339104845822392
|
||||
valid loss clipped: 8.339104845822392
|
||||
train loss: 0.034820230333430666
|
||||
valid loss scaled: 0.04548966136491771
|
||||
valid loss: 7.530221839492527
|
||||
valid loss clipped: 7.530221839492527
|
||||
train loss: 0.034814962549952325
|
||||
valid loss scaled: 0.044887413795706835
|
||||
valid loss: 7.430525821074159
|
||||
valid loss clipped: 7.430525821074159
|
||||
train loss: 0.035407486527384235
|
||||
valid loss scaled: 0.04845567947436388
|
||||
valid loss: 8.021209744958515
|
||||
valid loss clipped: 8.021177381420596
181  roberta_large_regressor_head/train.py  Normal file
@@ -0,0 +1,181 @@
import os
import torch
import random
import copy
from fairseq.models.roberta import RobertaModel, RobertaHubInterface
from fairseq import hub_utils
from fairseq.data.data_utils import collate_tokens
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import MinMaxScaler


EVAL_OFTEN = True
EVAL_EVERY = 10000
BATCH_SIZE = 1
model_type = 'large' # base or large


roberta = torch.hub.load('pytorch/fairseq', f'roberta.{model_type}')
roberta.cuda()
device = 'cuda'


# LOAD DATA
train_in = [l.rstrip('\n') for l in open('../train/in.tsv', newline='\n').readlines()] # shuffled
dev_in = [l.rstrip('\n') for l in open('../dev-0/in.tsv', newline='\n').readlines()] # shuffled

train_year = [float(l.rstrip('\n')) for l in open('../train/expected.tsv', newline='\n').readlines()]
dev_year = [float(l.rstrip('\n')) for l in open('../dev-0/expected.tsv', newline='\n').readlines()]

dev_in_not_shuffled = copy.deepcopy(dev_in) # not shuffled
test_in = [l.rstrip('\n') for l in open('../test-A/in.tsv', newline='\n').readlines()] # not shuffled

# SHUFFLE DATA
c = list(zip(train_in, train_year))
random.shuffle(c)
train_in, train_year = zip(*c)
c = list(zip(dev_in, dev_year))
random.shuffle(c)
dev_in, dev_year = zip(*c)

# SCALE DATA
scaler = MinMaxScaler()
train_year_scaled = scaler.fit_transform(np.array(train_year).reshape(-1, 1))
dev_year_scaled = scaler.transform(np.array(dev_year).reshape(-1, 1))


class RegressorHead(torch.nn.Module):
    def __init__(self):
        super(RegressorHead, self).__init__()
        in_dim = 768 if model_type == 'base' else 1024
        self.linear = torch.nn.Linear(in_dim, 1)
        self.m = torch.nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.linear(x)
        x = self.m(x)
        x = -self.m(-x + 1) + 1
        return x


def get_features_and_year(dataset_in, dataset_y):
    for i in tqdm(range(0, len(dataset_in), BATCH_SIZE)):
        batch_of_text = dataset_in[i:i+BATCH_SIZE]

        batch = collate_tokens([roberta.encode(p)[:512] for p in batch_of_text], pad_idx=1)
        features = roberta.extract_features(batch).mean(1)
        years = torch.FloatTensor(dataset_y[i:i+BATCH_SIZE]).to(device)

        yield features, years


def eval_dev(short=False):
    criterion_eval = torch.nn.MSELoss(reduction='sum')
    roberta.eval()
    regressor_head.eval()

    loss = 0.0
    loss_clipped = 0.0
    loss_scaled = 0.0

    if short:
        dataset_in = dev_in[:1000]
        dataset_years = dev_year_scaled[:1000]
    else:
        dataset_in = dev_in
        dataset_years = dev_year_scaled

    predictions_sum = 0
    for batch, year in tqdm(get_features_and_year(dataset_in, dataset_years)):

        predictions_sum += year.shape[0]
        x = regressor_head(batch.to(device))
        x_clipped = torch.clamp(x, 0.0, 1.0)

        original_x = torch.FloatTensor(scaler.inverse_transform(x.detach().cpu().numpy().reshape(1, -1)))
        original_x_clipped = torch.FloatTensor(scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1)))
        original_year = torch.FloatTensor(scaler.inverse_transform(year.detach().cpu().numpy().reshape(1, -1)))

        loss_scaled += criterion_eval(x, year).item()
        loss += criterion_eval(original_x, original_year).item()
        loss_clipped += criterion_eval(original_x_clipped, original_year).item()

    print('valid loss scaled: ' + str(np.sqrt(loss_scaled/predictions_sum)))
    print('valid loss: ' + str(np.sqrt(loss/predictions_sum)))
    print('valid loss clipped: ' + str(np.sqrt(loss_clipped/predictions_sum)))


def train_one_epoch():
    roberta.train()
    regressor_head.train()
    loss_value = 0.0
    iteration = 0
    for batch, year in get_features_and_year(train_in, train_year_scaled):
        iteration += 1
        roberta.zero_grad()
        regressor_head.zero_grad()

        predictions = regressor_head(batch.to(device))

        loss = criterion(predictions, year)
        loss_value += loss.item()
        loss.backward()
        optimizer.step()

        roberta.zero_grad()
        regressor_head.zero_grad()

        if EVAL_OFTEN and (iteration > 1) and (iteration % EVAL_EVERY == 1):
            print('train loss: ' + str(np.sqrt(loss_value / (EVAL_EVERY*BATCH_SIZE))))
            eval_dev(True)
            roberta.train()
            regressor_head.train()
            loss_value = 0.0


def predict(dataset='dev'):
    if dataset == 'dev':
        f_out_path = '../dev-0/out.tsv'
        dataset_in_not_shuffled = dev_in_not_shuffled
    elif dataset == 'test':
        f_out_path = '../test-A/out.tsv'
        dataset_in_not_shuffled = test_in
    roberta.eval()
    regressor_head.eval()
    f_out = open(f_out_path, 'w')
    for batch, year in tqdm(get_features_and_year(dataset_in_not_shuffled, dev_year_scaled)):
        x = regressor_head(batch)
        x_clipped = torch.clamp(x, 0.0, 1.0)
        original_x_clipped = scaler.inverse_transform(x_clipped.detach().cpu().numpy().reshape(1, -1))
        for y in original_x_clipped[0]:
            f_out.write(str(y) + '\n')
    f_out.close()


regressor_head = RegressorHead().to(device)

optimizer = torch.optim.Adam(list(roberta.parameters()) + list(regressor_head.parameters()), lr=1e-6)
criterion = torch.nn.MSELoss(reduction='sum').to(device)


roberta.load_state_dict(torch.load('checkpoints/roberta_to_regressor0.pt'))
regressor_head.load_state_dict(torch.load('checkpoints/regressor_head0.pt'))
for i in range(1, 100):
    print('epoch ' + str(i))
    train_one_epoch()

    print(f'epoch {i} done, EVALUATION ON FULL DEV:')
    eval_dev()
    print('evaluation done')
    predict('dev')
    predict('test')

    torch.save(roberta.state_dict(), 'checkpoints/roberta_to_regressor' + str(i) + '.pt')
    torch.save(regressor_head.state_dict(), 'checkpoints/regressor_head' + str(i) + '.pt')


roberta.load_state_dict(torch.load('checkpoints/roberta_to_regressor1.pt'))
regressor_head.load_state_dict(torch.load('checkpoints/regressor_head1.pt'))
predict('dev')
predict('test')
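One caveat worth noting for get_features_and_year: features are obtained with roberta.extract_features(batch).mean(1), i.e. a mean over every token position. With BATCH_SIZE = 1, as configured here, no padding is present and the plain mean is exact, but with a larger batch the positions padded by collate_tokens (pad_idx=1) would be averaged in as well. A minimal sketch of a padding-aware mean pool under that assumption (masked_mean_features is a hypothetical helper, not part of this commit):

def masked_mean_features(batch_tokens):
    feats = roberta.extract_features(batch_tokens)                      # (batch, seq_len, hidden)
    mask = (batch_tokens != 1).unsqueeze(-1).float().to(feats.device)   # 1.0 at real tokens, 0.0 at pad_idx=1
    return (feats * mask).sum(dim=1) / mask.sum(dim=1)                  # average over non-pad positions only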
73196  test-A/out.tsv  Normal file → Executable file
File diff suppressed because it is too large