from typing import Optional, Tuple

import torch

__all__ = ["Conformer"]


def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
        batch_size, max_length
    ) >= lengths.unsqueeze(1)
    return padding_mask


class _ConvolutionModule(torch.nn.Module):
    r"""Conformer convolution module.

    Args:
        input_dim (int): input dimension.
        num_channels (int): number of depthwise convolution layer input channels.
        depthwise_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
        use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int,
        depthwise_kernel_size: int,
        dropout: float = 0.0,
        bias: bool = False,
        use_group_norm: bool = False,
    ) -> None:
        super().__init__()
        if (depthwise_kernel_size - 1) % 2 != 0:
            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
        self.layer_norm = torch.nn.LayerNorm(input_dim)
        self.sequential = torch.nn.Sequential(
            # Pointwise conv doubles the channel count so GLU can halve it back.
            torch.nn.Conv1d(
                input_dim,
                2 * num_channels,
                1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.GLU(dim=1),
            # Depthwise conv: groups == channels; 'SAME' padding preserves T.
            torch.nn.Conv1d(
                num_channels,
                num_channels,
                depthwise_kernel_size,
                stride=1,
                padding=(depthwise_kernel_size - 1) // 2,
                groups=num_channels,
                bias=bias,
            ),
            torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
            if use_group_norm
            else torch.nn.BatchNorm1d(num_channels),
            torch.nn.SiLU(),
            torch.nn.Conv1d(
                num_channels,
                input_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.Dropout(dropout),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, D)`.

        Returns:
            torch.Tensor: output, with shape `(B, T, D)`.
        """
        x = self.layer_norm(input)
        x = x.transpose(1, 2)
        x = self.sequential(x)
        return x.transpose(1, 2)


class _FeedForwardModule(torch.nn.Module):
    r"""Positionwise feed forward layer.

    Args:
        input_dim (int): input dimension.
        hidden_dim (int): hidden dimension.
        dropout (float, optional): dropout probability. (Default: 0.0)
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.sequential = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, hidden_dim, bias=True),
            torch.nn.SiLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, input_dim, bias=True),
            torch.nn.Dropout(dropout),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): with shape `(*, D)`.

        Returns:
            torch.Tensor: output, with shape `(*, D)`.
        """
        return self.sequential(input)
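
# Illustrative sketch (an addition, not part of the original module): a tiny
# self-check of ``_lengths_to_padding_mask``. Positions at or beyond each
# sequence's valid length are marked ``True``, i.e. "this frame is padding".
def _demo_lengths_to_padding_mask() -> None:
    lengths = torch.tensor([2, 4])
    mask = _lengths_to_padding_mask(lengths)  # shape (B, max(lengths)) == (2, 4)
    expected = torch.tensor(
        [
            [False, False, True, True],
            [False, False, False, False],
        ]
    )
    assert torch.equal(mask, expected)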

class ConformerLayer(torch.nn.Module):
    r"""Conformer layer that constitutes Conformer.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of feedforward network.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        ffn_dim: int,
        num_attention_heads: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)

        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        self.conv_module = _ConvolutionModule(
            input_dim=input_dim,
            num_channels=input_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            bias=True,
            use_group_norm=use_group_norm,
        )

        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first

    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
        residual = input
        # The convolution module expects batch-first input, so hop to (B, T, D) and back.
        input = input.transpose(0, 1)
        input = self.conv_module(input)
        input = input.transpose(0, 1)
        input = residual + input
        return input

    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.

        Returns:
            torch.Tensor: output, with shape `(T, B, D)`.
        """
        # First half-step feed-forward module; the 0.5 residual scaling follows
        # the Conformer "macaron" structure.
        residual = input
        x = self.ffn1(input)
        x = x * 0.5 + residual

        if self.convolution_first:
            x = self._apply_convolution(x)

        residual = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.self_attn_dropout(x)
        x = x + residual

        if not self.convolution_first:
            x = self._apply_convolution(x)

        # Second half-step feed-forward module, followed by the final layer norm.
        residual = x
        x = self.ffn2(x)
        x = x * 0.5 + residual

        x = self.final_layer_norm(x)
        return x
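
# Usage sketch for a single ``ConformerLayer`` (an illustrative addition;
# the tensor sizes are arbitrary). Note the time-first layout `(T, B, D)`
# expected by the layer, while ``key_padding_mask`` is batch-first `(B, T)`.
def _demo_conformer_layer() -> None:
    layer = ConformerLayer(
        input_dim=16,
        ffn_dim=32,
        num_attention_heads=4,
        depthwise_conv_kernel_size=31,
    )
    lengths = torch.tensor([3, 5])
    x = torch.rand(5, 2, 16)  # (T, B, D)
    mask = _lengths_to_padding_mask(lengths)  # (B, T)
    out = layer(x, mask)
    assert out.shape == (5, 2, 16)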

class Conformer(torch.nn.Module):
    r"""Conformer architecture introduced in
    *Conformer: Convolution-augmented Transformer for Speech Recognition*
    :cite:`gulati2020conformer`.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)

    Examples:
        >>> conformer = Conformer(
        >>>     input_dim=80,
        >>>     num_heads=4,
        >>>     ffn_dim=128,
        >>>     num_layers=4,
        >>>     depthwise_conv_kernel_size=31,
        >>> )
        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
        >>> input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
        >>> output, lengths = conformer(input, lengths)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ):
        super().__init__()

        self.conformer_layers = torch.nn.ModuleList(
            [
                ConformerLayer(
                    input_dim,
                    ffn_dim,
                    num_heads,
                    depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, input_dim)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor)
                torch.Tensor
                    output frames, with shape `(B, T, input_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        encoder_padding_mask = _lengths_to_padding_mask(lengths)

        # The layer stack operates on time-first tensors, so transpose to
        # (T, B, D) for the layers and back to (B, T, D) for the caller.
        x = input.transpose(0, 1)
        for layer in self.conformer_layers:
            x = layer(x, encoder_padding_mask)
        return x.transpose(0, 1), lengths
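
# End-to-end usage sketch, adapted from the docstring example above (the
# hyperparameters are arbitrary). Run this file directly to execute it.
if __name__ == "__main__":
    conformer = Conformer(
        input_dim=80,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        depthwise_conv_kernel_size=31,
    )
    lengths = torch.randint(1, 400, (10,))  # (batch,)
    input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
    output, output_lengths = conformer(input, lengths)
    assert output.shape == input.shape
    assert torch.equal(output_lengths, lengths)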