aitech-eks-pub/wyk/pytorch-regression/my_neural_network.py

16 lines
564 B
Python
Raw Normal View History

2021-05-05 13:35:25 +02:00
import torch.nn as nn
import torch
class MyNeuralNetwork(nn.Module):
    """Tiny regression network with two tanh hidden units.

    Computes ``u1 * tanh(w1(x)) + u2 * tanh(w2(x))``, where ``w1`` and ``w2``
    are ``nn.Linear(vlen, 1)`` layers (weight plus bias) and ``u1``/``u2`` are
    learnable one-element output weights.

    Args:
        vlen: Size of the last dimension of the input (feature-vector length).
    """

    def __init__(self, vlen):
        super().__init__()
        self.w1 = nn.Linear(vlen, 1)
        self.w2 = nn.Linear(vlen, 1)
        # nn.Parameter already marks its tensor as requiring gradients, so no
        # explicit requires_grad=True is needed on the raw tensor. The
        # creation order (w1, w2, u1, u2) is kept so that a fixed RNG seed
        # yields the same initial weights as before.
        self.u1 = nn.Parameter(torch.rand(1, dtype=torch.float))
        self.u2 = nn.Parameter(torch.rand(1, dtype=torch.float))

    def forward(self, x):
        """Return the network's prediction for ``x``.

        Args:
            x: Tensor whose last dimension has size ``vlen``, e.g. a batch of
                shape ``(batch, vlen)`` or a single vector of shape ``(vlen,)``.

        Returns:
            Tensor with shape ``x.shape[:-1]`` broadcast against ``(1,)``:
            ``(batch,)`` for a 2-D batch and ``(1,)`` for a single vector.
        """
        # squeeze(-1) drops only the size-1 output-feature dim produced by the
        # Linear layers; a bare squeeze() would also collapse any leading
        # size-1 dims of the input.
        h1 = torch.tanh(self.w1(x).squeeze(-1))
        h2 = torch.tanh(self.w2(x).squeeze(-1))
        return self.u1 * h1 + self.u2 * h2