from typing import Optional, Tuple

import torch


__all__ = ["Conformer"]


def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
        batch_size, max_length
    ) >= lengths.unsqueeze(1)
    return padding_mask
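

# Illustration (not exercised by the module itself): ``True`` marks padded
# positions, matching the ``key_padding_mask`` convention of
# ``torch.nn.MultiheadAttention``. For example:
#
#   >>> _lengths_to_padding_mask(torch.tensor([2, 4]))
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])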


class _ConvolutionModule(torch.nn.Module):
    r"""Conformer convolution module.

    Args:
        input_dim (int): input dimension.
        num_channels (int): number of depthwise convolution layer input channels.
        depthwise_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int,
        depthwise_kernel_size: int,
        dropout: float = 0.0,
        bias: bool = False,
        use_group_norm: bool = False,
    ) -> None:
        super().__init__()
        if (depthwise_kernel_size - 1) % 2 != 0:
            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
        self.layer_norm = torch.nn.LayerNorm(input_dim)
        self.sequential = torch.nn.Sequential(
            # Pointwise conv doubles the channels so that GLU halves them back.
            torch.nn.Conv1d(
                input_dim,
                2 * num_channels,
                1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.GLU(dim=1),
            # Depthwise conv; 'SAME' padding preserves the sequence length.
            torch.nn.Conv1d(
                num_channels,
                num_channels,
                depthwise_kernel_size,
                stride=1,
                padding=(depthwise_kernel_size - 1) // 2,
                groups=num_channels,
                bias=bias,
            ),
            torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
            if use_group_norm
            else torch.nn.BatchNorm1d(num_channels),
            torch.nn.SiLU(),
            # Pointwise conv projects back to the input dimension.
            torch.nn.Conv1d(
                num_channels,
                input_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.Dropout(dropout),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, D)`.

        Returns:
            torch.Tensor: output, with shape `(B, T, D)`.
        """
        x = self.layer_norm(input)
        x = x.transpose(1, 2)
        x = self.sequential(x)
        return x.transpose(1, 2)
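

# Shape sketch (illustrative sizes, not part of the module): the module maps
# `(B, T, D)` to `(B, T, D)`, transposing to channel-first internally for the
# `Conv1d` stack:
#
#   >>> conv = _ConvolutionModule(input_dim=64, num_channels=64, depthwise_kernel_size=31)
#   >>> conv(torch.rand(2, 100, 64)).shape
#   torch.Size([2, 100, 64])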


class _FeedForwardModule(torch.nn.Module):
    r"""Positionwise feed forward layer.

    Args:
        input_dim (int): input dimension.
        hidden_dim (int): hidden dimension.
        dropout (float, optional): dropout probability. (Default: 0.0)
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.sequential = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, hidden_dim, bias=True),
            torch.nn.SiLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, input_dim, bias=True),
            torch.nn.Dropout(dropout),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): with shape `(*, D)`.

        Returns:
            torch.Tensor: output, with shape `(*, D)`.
        """
        return self.sequential(input)
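

# Illustration (hypothetical sizes): hidden_dim / input_dim = 4 here, the
# feed-forward expansion factor used in the Conformer paper. The module
# preserves the trailing dimension:
#
#   >>> ffn = _FeedForwardModule(input_dim=64, hidden_dim=256)
#   >>> ffn(torch.rand(2, 100, 64)).shape
#   torch.Size([2, 100, 64])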


class ConformerLayer(torch.nn.Module):
    r"""Conformer layer that constitutes Conformer.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of feedforward network.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        ffn_dim: int,
        num_attention_heads: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)

        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        self.conv_module = _ConvolutionModule(
            input_dim=input_dim,
            num_channels=input_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            bias=True,
            use_group_norm=use_group_norm,
        )

        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first

    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
        # The convolution module expects batch-first input, so transpose
        # around it and add the residual connection here.
        residual = input
        input = input.transpose(0, 1)
        input = self.conv_module(input)
        input = input.transpose(0, 1)
        input = residual + input
        return input

    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.

        Returns:
            torch.Tensor: output, with shape `(T, B, D)`.
        """
        # Macaron-style half-step residual around the first feed-forward module.
        residual = input
        x = self.ffn1(input)
        x = x * 0.5 + residual

        if self.convolution_first:
            x = self._apply_convolution(x)

        residual = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.self_attn_dropout(x)
        x = x + residual

        if not self.convolution_first:
            x = self._apply_convolution(x)

        # Second half-step feed-forward module, followed by the final layer norm.
        residual = x
        x = self.ffn2(x)
        x = x * 0.5 + residual

        x = self.final_layer_norm(x)
        return x
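

# Usage sketch (illustrative values): the layer expects time-first `(T, B, D)`
# input, matching the default `batch_first=False` of
# `torch.nn.MultiheadAttention`, and preserves the input shape:
#
#   >>> layer = ConformerLayer(input_dim=64, ffn_dim=256, num_attention_heads=4, depthwise_conv_kernel_size=31)
#   >>> x = torch.rand(100, 2, 64)  # (T, B, D)
#   >>> layer(x, key_padding_mask=None).shape
#   torch.Size([100, 2, 64])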


class Conformer(torch.nn.Module):
    r"""Conformer architecture introduced in
    *Conformer: Convolution-augmented Transformer for Speech Recognition*
    :cite:`gulati2020conformer`.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)

    Examples:
        >>> conformer = Conformer(
        >>>     input_dim=80,
        >>>     num_heads=4,
        >>>     ffn_dim=128,
        >>>     num_layers=4,
        >>>     depthwise_conv_kernel_size=31,
        >>> )
        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
        >>> input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
        >>> output, output_lengths = conformer(input, lengths)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ):
        super().__init__()

        self.conformer_layers = torch.nn.ModuleList(
            [
                ConformerLayer(
                    input_dim,
                    ffn_dim,
                    num_heads,
                    depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, input_dim)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor)
                torch.Tensor
                    output frames, with shape `(B, T, input_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        encoder_padding_mask = _lengths_to_padding_mask(lengths)

        # Layers operate time-first; transpose in and out of (T, B, D).
        x = input.transpose(0, 1)
        for layer in self.conformer_layers:
            x = layer(x, encoder_padding_mask)
        return x.transpose(0, 1), lengths
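

if __name__ == "__main__":
    # Smoke test (a minimal sketch mirroring the docstring example; the sizes
    # are illustrative, not canonical).
    torch.manual_seed(0)
    conformer = Conformer(
        input_dim=80,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        depthwise_conv_kernel_size=31,
    )
    lengths = torch.randint(1, 400, (10,))  # (batch,)
    input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
    output, output_lengths = conformer(input, lengths)
    assert output.shape == input.shape  # frames keep their (B, T, D) shape
    assert torch.equal(output_lengths, lengths)  # lengths pass through unchanged
    print("output:", tuple(output.shape), "lengths:", tuple(output_lengths.shape))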