# Listing metadata from the original page: 16 lines, 564 B, Python.
import torch
import torch.nn as nn


class MyNeuralNetwork(nn.Module):
    """Sum of two tanh units applied to affine projections of the input.

    Computes ``u1 * tanh(w1(x)) + u2 * tanh(w2(x))``, where ``w1`` and ``w2``
    are ``nn.Linear(vlen, 1)`` layers and ``u1`` and ``u2`` are learnable
    output weights stored as shape-``(1,)`` parameters.

    Args:
        vlen: Size of the last dimension of the input ``x``.
    """

    def __init__(self, vlen: int) -> None:
        super().__init__()
        # Two independent affine projections vlen -> 1 (nn.Linear includes a
        # bias by default).
        self.w1 = nn.Linear(vlen, 1)
        self.w2 = nn.Linear(vlen, 1)
        # Output weights, initialised uniformly in [0, 1). nn.Parameter already
        # defaults to requires_grad=True, so it is not repeated on the tensor.
        self.u1 = nn.Parameter(torch.rand(1, dtype=torch.float))
        self.u2 = nn.Parameter(torch.rand(1, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``u1 * tanh(w1(x)) + u2 * tanh(w2(x))``.

        Args:
            x: Tensor whose last dimension has size ``vlen``.

        Returns:
            Tensor of shape ``x.shape[:-1]`` broadcast against the ``(1,)``
            output weights. For example, ``(B, vlen) -> (B,)``, and a 1-D
            ``x`` yields shape ``(1,)``.
        """
        # Drop only the size-1 output-feature axis produced by nn.Linear.
        # A bare .squeeze() would also collapse any singleton batch/sequence
        # axis, e.g. (B, 1, vlen) -> (B,) instead of (B, 1).
        h1 = torch.tanh(self.w1(x).squeeze(-1))
        h2 = torch.tanh(self.w2(x).squeeze(-1))
        return self.u1 * h1 + self.u2 * h2