roberta base with regression layer on top lr=1e-8 4 epochs
This commit is contained in:
parent
8a5bc51f44
commit
10f548680a
12488
dev-0/out.tsv
12488
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
@ -65,7 +65,7 @@ class RegressorHead(torch.nn.Module):
|
|||||||
|
|
||||||
regressor_head = RegressorHead().to(device)
|
regressor_head = RegressorHead().to(device)
|
||||||
|
|
||||||
optimizer = torch.optim.Adam(list(roberta.parameters()) + list(regressor_head.parameters()), lr=1e-6)
|
optimizer = torch.optim.Adam(list(roberta.parameters()) + list(regressor_head.parameters()), lr=1e-8)
|
||||||
criterion = torch.nn.MSELoss(reduction='sum').to(device)
|
criterion = torch.nn.MSELoss(reduction='sum').to(device)
|
||||||
|
|
||||||
BATCH_SIZE = 1
|
BATCH_SIZE = 1
|
||||||
@ -118,7 +118,8 @@ def eval_short():
|
|||||||
loss = 0.0
|
loss = 0.0
|
||||||
loss_clipped = 0.0
|
loss_clipped = 0.0
|
||||||
loss_scaled = 0.0
|
loss_scaled = 0.0
|
||||||
for batch, year in tqdm(get_train_batch(dev_in[:1000],dev_year_scaled[:1000])):
|
eval_num = 10000
|
||||||
|
for batch, year in tqdm(get_train_batch(dev_in[:eval_num],dev_year_scaled[:eval_num])):
|
||||||
|
|
||||||
x = regressor_head(batch.to(device)).squeeze()
|
x = regressor_head(batch.to(device)).squeeze()
|
||||||
x_clipped = torch.clamp(x,0.0,1.0)
|
x_clipped = torch.clamp(x,0.0,1.0)
|
||||||
@ -130,8 +131,8 @@ def eval_short():
|
|||||||
loss_scaled += criterion_eval(x, year).item()
|
loss_scaled += criterion_eval(x, year).item()
|
||||||
loss += criterion_eval(original_x, original_year).item()
|
loss += criterion_eval(original_x, original_year).item()
|
||||||
loss_clipped += criterion_eval(original_x_clipped, original_year).item()
|
loss_clipped += criterion_eval(original_x_clipped, original_year).item()
|
||||||
print('valid loss scaled: ' + str(np.sqrt(loss_scaled/1000)))
|
print('valid loss scaled: ' + str(np.sqrt(loss_scaled/eval_num)))
|
||||||
print('valid loss: ' + str(np.sqrt(loss/1000)))
|
print('valid loss: ' + str(np.sqrt(loss/eval_num)))
|
||||||
print('valid loss clipped: ' + str(np.sqrt(loss_clipped/len(dev_year))))
|
print('valid loss clipped: ' + str(np.sqrt(loss_clipped/len(dev_year))))
|
||||||
|
|
||||||
|
|
||||||
|
10722
test-A/out.tsv
10722
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user