add heads

This commit is contained in:
Jakub Pokrywka 2021-12-29 13:40:11 +01:00
parent c51a389af7
commit c67fa2f8af
2 changed files with 23 additions and 0 deletions

View File

@ -0,0 +1,10 @@
import torch
class YearClassificationHead(torch.nn.Module):
def __init__(self, in_dim, MIN_YEAR, MAX_YEAR):
super(YearClassificationHead, self).__init__()
self.linear = torch.nn.Linear(in_dim, MAX_YEAR - MIN_YEAR + 1)
def forward(self, x):
x = x.mean(1)
x = self.linear(x)
return x

View File

@ -0,0 +1,13 @@
import torch
class RegressorHead(torch.nn.Module):
def __init__(self, in_dim):
super(RegressorHead, self).__init__()
self.linear = torch.nn.Linear(in_dim, 1)
self.m = torch.nn.LeakyReLU(0.1)
def forward(self, x):
x = x.mean(1)
x = self.linear(x)
x = self.m(x)
x = - self.m(-x + 1 ) +1
return x