# Header captured from a repository web view of this file:
# ireland-news-headlines-year.../hf_roberta_base_classification/regressor_head.py
# Jakub Pokrywka c67fa2f8af add heads
# 2021-12-29 13:40:11 +01:00
#
# 14 lines
# 664 B
# Python

import torch
class RegressorHead(torch.nn.Module):
    """Regression head: mean-pool over dim 1, project to one value, leaky-clamp.

    The projected value goes through a "leaky clamp". Values inside
    ``[lower, upper]`` are returned unchanged. Values outside that interval
    are not cut off; they continue with slope ``negative_slope``, so for a
    non-zero slope the gradient does not vanish there. With the defaults
    (``lower=0``, ``upper=1``, ``negative_slope=0.1``) the output is::

        y                  if 0 <= y <= 1
        0.1 * y            if y < 0
        1 + 0.1 * (y - 1)  if y > 1

    Args:
        in_dim: Size of the last dimension of the input (``in_features`` of
            ``self.linear``).
        negative_slope: Slope applied outside ``[lower, upper]``.
        lower: Lower knee of the leaky clamp.
        upper: Upper knee of the leaky clamp.

    Raises:
        ValueError: If ``lower`` is not strictly less than ``upper``.
    """

    def __init__(self, in_dim, *, negative_slope=0.1, lower=0.0, upper=1.0):
        super().__init__()
        if not lower < upper:
            raise ValueError(
                f"lower ({lower}) must be strictly less than upper ({upper})"
            )
        self.linear = torch.nn.Linear(in_dim, 1)
        # Parameter-free activation, reused for both knees of the clamp.
        self.m = torch.nn.LeakyReLU(negative_slope)
        # Plain floats (not registered buffers), so state_dict keys stay
        # exactly {"linear.weight", "linear.bias"} and existing checkpoints load.
        self.lower = float(lower)
        self.upper = float(upper)

    def forward(self, x):
        """Mean-pool ``x`` over dim 1, project it to one value, and leaky-clamp it.

        Args:
            x: Input tensor whose last dimension is ``in_dim``. Presumably the
                encoder hidden states of shape ``(batch, seq_len, in_dim)``;
                TODO confirm against the caller.

        Returns:
            Tensor with dim 1 averaged out and the last dimension projected
            to size 1, i.e. ``(batch, 1)`` for a 3-D input.
        """
        # NOTE(review): unmasked mean. If x contains padded positions, they
        # are averaged in as well; confirm this is intended.
        x = x.mean(1)
        x = self.linear(x)
        # Leaky floor: identity at/above `lower`, slope `negative_slope` below.
        x = self.m(x - self.lower) + self.lower
        # Leaky ceiling: identity at/below `upper`, slope `negative_slope` above.
        x = self.upper - self.m(self.upper - x)
        return x