85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
|
import torch
|
||
|
|
||
|
__all__ = ["DeepSpeech"]
|
||
|
|
||
|
|
||
|
class FullyConnected(torch.nn.Module):
|
||
|
"""
|
||
|
Args:
|
||
|
n_feature: Number of input features
|
||
|
n_hidden: Internal hidden unit size.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
|
||
|
super(FullyConnected, self).__init__()
|
||
|
self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
|
||
|
self.relu_max_clip = relu_max_clip
|
||
|
self.dropout = dropout
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
x = self.fc(x)
|
||
|
x = torch.nn.functional.relu(x)
|
||
|
x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
|
||
|
if self.dropout:
|
||
|
x = torch.nn.functional.dropout(x, self.dropout, self.training)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class DeepSpeech(torch.nn.Module):
|
||
|
"""DeepSpeech architecture introduced in
|
||
|
*Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.
|
||
|
|
||
|
Args:
|
||
|
n_feature: Number of input features
|
||
|
n_hidden: Internal hidden unit size.
|
||
|
n_class: Number of output classes
|
||
|
"""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
n_feature: int,
|
||
|
n_hidden: int = 2048,
|
||
|
n_class: int = 40,
|
||
|
dropout: float = 0.0,
|
||
|
) -> None:
|
||
|
super(DeepSpeech, self).__init__()
|
||
|
self.n_hidden = n_hidden
|
||
|
self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
|
||
|
self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
|
||
|
self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
|
||
|
self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
|
||
|
self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
|
||
|
self.out = torch.nn.Linear(n_hidden, n_class)
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
"""
|
||
|
Args:
|
||
|
x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
|
||
|
Returns:
|
||
|
Tensor: Predictor tensor of dimension (batch, time, class).
|
||
|
"""
|
||
|
# N x C x T x F
|
||
|
x = self.fc1(x)
|
||
|
# N x C x T x H
|
||
|
x = self.fc2(x)
|
||
|
# N x C x T x H
|
||
|
x = self.fc3(x)
|
||
|
# N x C x T x H
|
||
|
x = x.squeeze(1)
|
||
|
# N x T x H
|
||
|
x = x.transpose(0, 1)
|
||
|
# T x N x H
|
||
|
x, _ = self.bi_rnn(x)
|
||
|
# The fifth (non-recurrent) layer takes both the forward and backward units as inputs
|
||
|
x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :]
|
||
|
# T x N x H
|
||
|
x = self.fc4(x)
|
||
|
# T x N x H
|
||
|
x = self.out(x)
|
||
|
# T x N x n_class
|
||
|
x = x.permute(1, 0, 2)
|
||
|
# N x T x n_class
|
||
|
x = torch.nn.functional.log_softmax(x, dim=2)
|
||
|
# N x T x n_class
|
||
|
return x
|